/* This work is a modification of code written by Jens Wetzl and Oliver Taubmann in 2012.
 * The original work can be found here: https://github.com/jwetzl/CudaLBFGS 
 * This work is not endorsed by the authors. */

/**
 *   ___ _   _ ___   _     _       ___ ___ ___ ___
 *  / __| | | |   \ /_\   | |  ___| _ ) __/ __/ __|
 * | (__| |_| | |) / _ \  | |_|___| _ \ _| (_ \__ \
 *  \___|\___/|___/_/ \_\ |____|  |___/_| \___|___/
 *
 * File lbfgs_cuda.cu: Implementation of class lbfgs
 *
 **/

#include "lbfgs.h"
#include "timer.h"

#include <limits>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <device_launch_parameters.h>
#include <device_functions.h>
#include <fstream>
#include <sstream>

namespace gpu_lbfgs {

	// Variables
	//
	// Scalars that live in device memory. lbfgs::gpu_lbfgs() either looks up
	// their device addresses with cudaGetSymbolAddress or fills them from the
	// host with cudaMemcpyToSymbol, so that intermediate results can stay on
	// the GPU between cuBLAS calls and the helper kernels below.

	__device__ double fkm1;  // f_{k-1}: function value at the previous iterate
	__device__ double fk;    // f_k:     function value at the current iterate
	__device__ double tmp;   // scratch scalar for intermediate results (dot products etc.)

	__device__ double H0;    // diagonal value of the initial inverse Hessian approximation
	__device__ double step;  // current step length
	__device__ double tmp2;  // second scratch scalar
	__device__ int status;   // return code for communication device -> host
	__device__ double maxStep;   // filled from m_maxStep at the start of gpu_lbfgs(); not read in this file
	__device__ double maxFRise;  // filled from m_maxFRise at the start of gpu_lbfgs(); not read in this file

	// Rigid-body parameters; only written when m_atomRigidT is false.
	__device__ int cmax;         // number of rigid bodies (from m_nRigidBody)
	__device__ int maxsite;      // maximum number of sites in a rigid body (from m_maxSite)
	__device__ int degfreedoms;  // number of degrees of freedom (from m_degFreedoms)

	// Small helper kernels for scalar operations in device memory needed during updates.
	// What they're used for is documented by comments in the places they are executed.
	// *** Use with a single thread only! ***
	// (Every launch in this file uses <<<1, 1>>>; the kernels perform no thread indexing.)

	__global__ void update1   (double *alpha_out, const double *sDotZ, const double *rho, double *minusAlpha_out);       // first  update loop: alpha = sDotZ * rho, minusAlpha = -alpha
	__global__ void update2   (double *alphaMinusBeta_out, const double *rho, const double *yDotZ, const double *alpha); // second update loop: out = alpha - rho * yDotZ
	__global__ void update3   (double *rho_out, double *H0_out, double *yDotS, double *yDotY);               // after line search: rho = 1 / yDotS, H0 = yDotS / yDotY
}

// linesearch_cuda.h is not a real header: it contains
// part of the implementation and must be included
// after the variables above have been declared.
#include "linesearch_cuda.h" 

	// Creates a minimiser bound to the cost function `cf` and seeds every
	// tunable with its default value:
	//   - at most 10000 iterations, unlimited function/gradient evaluations
	//   - maximum step 0.2, gradient tolerance 1e-4, maximum function rise 1e-4
	//   - initial inverse-Hessian diagonal guess 0.1
	//   - atomistic mode (no rigid bodies), history length 4
	// A cuBLAS handle is created here and released again by the destructor.
	lbfgs::lbfgs(cost_function& cf) :
		m_costFunction(cf),
		m_maxIter(10000),
		m_maxEvals(std::numeric_limits<size_t>::max()),
		m_maxStep(0.2),
		m_gradientEps(1e-4f),
		m_maxFRise(1e-4f),
		m_dGuess(0.1),
		m_atomRigidT(true),
		m_degFreedoms(0),
		m_nRigidBody(0),
		m_maxSite(0),
		m_mUpdate(4)
{
	// All vector operations in this class go through this handle.
	CublasSafeCall( cublasCreate(&m_cublasHandle) );
}

// Releases the cuBLAS handle that the constructor created.
lbfgs::~lbfgs()
{
	CublasSafeCall( cublasDestroy(m_cublasHandle) );
}

// Maps a termination status code to the human-readable message that is
// written to the debug log. Codes without a dedicated message yield
// "Unknown status".
std::string lbfgs::statusToString(lbfgs::status stat)
{
	const char *message = "Unknown status";

	switch (stat)
	{
		case LBFGS_BELOW_GRADIENT_EPS:
			message = "Below gradient epsilon";
			break;
		case LBFGS_REACHED_MAX_ITER:
			message = "Reached maximum number of iterations";
			break;
		case LBFGS_REACHED_MAX_EVALS:
			message = "Reached maximum number of function/gradient evaluations";
			break;
		case LBFGS_LINE_SEARCH_FAILED:
			message = "Line search failed";
			break;
		case LBFGS_COLD_FUSION_DIAGNOSED:
			message = "Cold fusion diagnosed";
			break;
		default:
			break;
	}

	return std::string(message);
}

// Public entry point of the minimiser. Every argument is forwarded unchanged
// to gpu_lbfgs(), which does the actual work; the returned status is the one
// produced there.
lbfgs::status lbfgs::minimize(double *d_x, double *energy, double *outrms, _Bool *debug, int *itdone, _Bool *cudatimet, int *ecalls, _Bool *coldfusiont, double *coldfusionlim)
{
	return gpu_lbfgs(d_x, energy, outrms, debug, itdone, cudatimet, ecalls, coldfusiont, coldfusionlim);
}

namespace
{
	// Renders `value` with default std::ostream formatting (used for the debug log).
	// Replaces the former `&(std::ostringstream() << value)` idiom, which takes the
	// address of the insertion result and does not compile when the rvalue-stream
	// operator<< overload is selected.
	template <typename T>
	std::string gpuLbfgsToString(const T &value)
	{
		std::ostringstream formatter;
		formatter << value;
		return formatter.str();
	}
}

/**
 * L-BFGS minimisation with all vectors kept in device memory.
 *
 * @param d_x            Device pointer to the coordinates; updated in place. In rigid-body
 *                       mode (m_atomRigidT == false) its first m_degFreedoms entries are
 *                       treated as rigid-body coordinates.
 * @param energy         Host pointer; receives the last function value f_k.
 * @param outrms         Host pointer; receives sqrt(|g|^2 / N), where N is the number of
 *                       unknowns, or m_degFreedoms in rigid-body mode. It is refreshed at
 *                       the top of every iteration.
 * @param debug          If *debug is set, progress lines are appended via writetodebug().
 * @param itdone         Host pointer; receives the number of completed iterations.
 * @param cudatimet      If *cudatimet is set, the timers are run and their measurements saved.
 * @param ecalls         Host pointer; receives the number of function/gradient evaluations.
 * @param coldfusiont    Forwarded unchanged to the cost function and the line search.
 * @param coldfusionlim  Forwarded unchanged to the cost function and the line search.
 *
 * @return LBFGS_BELOW_GRADIENT_EPS when the rms gradient drops below m_gradientEps,
 *         LBFGS_REACHED_MAX_ITER when the iteration budget is used up, or whatever
 *         status gpu_linesearch() stored in `stat` before asking for termination.
 */
lbfgs::status lbfgs::gpu_lbfgs(double *d_x, double *energy, double *outrms, _Bool *debug, int *itdone, _Bool *cudatimet, int *ecalls, _Bool *coldfusiont, double *coldfusionlim)
{
	timer timer_total     ("GPU_LBFGS_total"     );
	timer timer_evals     ("GPU_LBFGS_evals"     );
	timer timer_updates   ("GPU_LBFGS_updates"   );
	timer timer_linesearch("GPU_LBFGS_linesearch");
	if (*cudatimet){
		timer_total.start();
	}

	using namespace gpu_lbfgs;
	const size_t NX = m_costFunction.getNumberOfUnknowns();

	// Number of components the rms gradient is averaged over.
	const double rmsDivisor = m_atomRigidT ? static_cast<double>(NX)
		: static_cast<double>(m_degFreedoms);

	double *d_fkm1 = NULL, *d_fk = NULL;  // f_{k-1}, f_k, function values at x_{k-1} and x_k
	double *d_gkm1 = NULL, *d_gk = NULL;  // g_{k-1}, g_k, gradients       at x_{k-1} and x_k
	double *d_z    = NULL;                // z,            search direction
	double *d_H0   = NULL;                // H_0,          initial inverse Hessian (diagonal, same value for all elements)

	double *d_step = NULL;                // step          current step length
	double *d_tmp  = NULL, *d_tmp2 = NULL; // tmp, tmp2    temporary storage for intermediate results
	int    *d_status = NULL;              // status        return code for communication device -> host

	// Rigid bodies
	// These stay NULL in atomistic mode; they are still handed to the line
	// search by value, so they must not be left uninitialised.

	double *d_xrigid  = NULL;  // rigid body coordinates
	double *d_gkrigid = NULL;  // rigid body gradient
	int    *d_cmax    = NULL;  // no. of rigid bodies
	int    *d_maxsite = NULL;  // maximum number of sites in a rigid body

	// Ring buffers for history

	double *d_s     = NULL;    // s,            history of solution updates
	double *d_y     = NULL;    // y,            history of gradient updates
	double *d_alpha = NULL;    // alpha,        history of alphas (needed for z updates)
	double *d_rho   = NULL;    // rho,          history of rhos   (needed for z updates)

	// Allocations

	CudaSafeCall( cudaMalloc(&d_gk,   NX * sizeof(double)) );
	CudaSafeCall( cudaMalloc(&d_gkm1, NX * sizeof(double)) );
	CudaSafeCall( cudaMalloc(&d_z,    NX * sizeof(double)) );

	CudaSafeCall( cudaMalloc(&d_s,    m_mUpdate * NX * sizeof(double)) );
	CudaSafeCall( cudaMalloc(&d_y,    m_mUpdate * NX * sizeof(double)) );
	CudaSafeCall( cudaMalloc(&d_alpha,    m_mUpdate * sizeof(double)) );
	CudaSafeCall( cudaMalloc(&d_rho,    m_mUpdate * sizeof(double)) );

	// Addresses of global symbols

	CudaSafeCall( cudaGetSymbolAddress((void**)&d_fkm1,  gpu_lbfgs::fkm1  ) );
	CudaSafeCall( cudaGetSymbolAddress((void**)&d_fk,    gpu_lbfgs::fk    ) );
	CudaSafeCall( cudaGetSymbolAddress((void**)&d_tmp,   gpu_lbfgs::tmp   ) );
	CudaSafeCall( cudaGetSymbolAddress((void**)&d_tmp2,  gpu_lbfgs::tmp2  ) );
	CudaSafeCall( cudaGetSymbolAddress((void**)&d_H0,    gpu_lbfgs::H0    ) );
	CudaSafeCall( cudaGetSymbolAddress((void**)&d_step,  gpu_lbfgs::step  ) );
	CudaSafeCall( cudaGetSymbolAddress((void**)&d_status,gpu_lbfgs::status) );

	CudaSafeCall( cudaMemcpyToSymbol(maxStep, &m_maxStep,  sizeof(double)) );
	CudaSafeCall( cudaMemcpyToSymbol(maxFRise, &m_maxFRise,  sizeof(double)) );

	if (! m_atomRigidT){
		CudaSafeCall( cudaMalloc(&d_xrigid,    m_degFreedoms * sizeof(double)) );
		CudaSafeCall( cudaMalloc(&d_gkrigid,    m_degFreedoms * sizeof(double)) );

		CudaSafeCall( cudaGetSymbolAddress((void**)&d_cmax, gpu_lbfgs::cmax ) );
		CudaSafeCall( cudaGetSymbolAddress((void**)&d_maxsite, gpu_lbfgs::maxsite ) );

		CudaSafeCall( cudaMemcpyToSymbol(gpu_lbfgs::cmax, &m_nRigidBody,  sizeof(int)) );
		CudaSafeCall( cudaMemcpyToSymbol(gpu_lbfgs::maxsite, &m_maxSite,  sizeof(int)) );
		CudaSafeCall( cudaMemcpyToSymbol(gpu_lbfgs::degfreedoms, &m_degFreedoms,  sizeof(int)) );

		// Keep a copy of the rigid-body coordinates before d_x is overwritten.
		CudaSafeCall( cudaMemcpy(d_xrigid, d_x, m_degFreedoms * sizeof(double), cudaMemcpyDeviceToDevice) );

		// Rigid bodies coordinate transformation
		transform_rigidtoc(m_nRigidBody, d_x, d_xrigid, m_degFreedoms, m_maxSite, d_cmax, d_maxsite);
	}

	// Initialize

	if (*cudatimet){
		timer_evals.start();
	}

	m_costFunction.f_gradf(d_x, d_fk, d_gk, coldfusiont, coldfusionlim);

	CudaCheckError();
	CudaSafeCall( cudaDeviceSynchronize() );

	if (*cudatimet){
		timer_evals.stop();
	}

	if (! m_atomRigidT){
		const size_t rigidBytes = m_degFreedoms * sizeof(double);
		const size_t fullBytes  = NX * sizeof(double);

		// An all-zero byte pattern is 0.0 for doubles, so cudaMemset replaces
		// the former host-side zero buffer and its host-to-device copies.
		CudaSafeCall( cudaMemset(d_gkrigid, 0, rigidBytes) );

		// Rigid bodies gradient transformation
		transform_grad(d_gk, d_xrigid, d_gkrigid, m_nRigidBody, d_cmax, m_maxSite, d_maxsite, m_degFreedoms);

		// From here on d_x / d_gk hold the rigid-body quantities in their first
		// m_degFreedoms entries and zeros in the remainder.
		CudaSafeCall( cudaMemset(d_x,  0, fullBytes) );
		CudaSafeCall( cudaMemset(d_gk, 0, fullBytes) );

		CudaSafeCall( cudaMemcpy(d_x, d_xrigid, rigidBytes, cudaMemcpyDeviceToDevice) );
		CudaSafeCall( cudaMemcpy(d_gk, d_gkrigid, rigidBytes, cudaMemcpyDeviceToDevice) );
	}

	if (m_maxIter == 0){
		// The loop below will not run, so report the rms gradient of the start point here.
		double firstgkNormSq;
		dispatch_dot(NX, &firstgkNormSq, d_gk, d_gk, false);
		*outrms = sqrt(firstgkNormSq / rmsDivisor);
	}

	size_t evals = 1;

	status stat = LBFGS_REACHED_MAX_ITER;

	if (*debug){
		writetodebug(std::string("lbfgs::gpu_lbfgs()"));
	}

	// Initial guess for inverse Hessian diagonal elements
	const double guess = m_dGuess;
	CudaSafeCall( cudaMemcpy(d_H0, &guess, sizeof(double), cudaMemcpyHostToDevice) );

	size_t it;

	for (it = 0; it < m_maxIter; ++it)
	{
		// Check for convergence
		// ---------------------

		// |g_k|^2 is computed once and shared by the debug output and the test below.
		double gkNormSquared;
		dispatch_dot(NX, &gkNormSquared, d_gk, d_gk, false);
		*outrms = sqrt(gkNormSquared / rmsDivisor);

		if (*debug){
			double h_y;
			CudaSafeCall( cudaMemcpy(&h_y, d_fk, sizeof(double), cudaMemcpyDeviceToHost) );

			writetodebug("f(x) = " + gpuLbfgsToString(h_y) + ", rms = " + gpuLbfgsToString(*outrms));
		}

		if (*outrms < m_gradientEps)
		{
			stat = LBFGS_BELOW_GRADIENT_EPS;
			break;
		}

		// Find search direction
		// ---------------------

		if (*cudatimet){
			timer_updates.start();
		}

		const double minusOne = -1.0;
		dispatch_scale(NX, d_z, d_gk, &minusOne, false); // z = -gk

		// Number of history entries available so far (at most m_mUpdate).
		const size_t MAX_IDX = std::min<size_t>(it, m_mUpdate);

		// First loop: newest history entry to oldest.
		for (size_t i = 1; i <= MAX_IDX; ++i)
		{
			size_t idx = ((it - i) % m_mUpdate);

			dispatch_dot(NX, d_tmp, d_s + idx * NX, d_z); // tmp = sDotZ

			// alpha = tmp * rho
			// tmp = -alpha
			update1<<<1, 1>>>(d_alpha + idx, d_tmp, d_rho + idx, d_tmp);

			CudaCheckError();
			CudaSafeCall( cudaDeviceSynchronize() );

			// z += tmp * y
			dispatch_axpy(NX, d_z, d_z, d_y + idx * NX, d_tmp);
		}

		dispatch_scale(NX, d_z, d_z, d_H0); // z = H0 * z

		// Second loop: oldest history entry to newest.
		for (size_t i = MAX_IDX; i > 0; --i)
		{
			size_t idx = ((it - i) % m_mUpdate);

			dispatch_dot(NX, d_tmp, d_y + idx * NX, d_z); // tmp = yDotZ

			// beta = rho * tmp
			// tmp = alpha - beta
			update2<<<1, 1>>>(d_tmp, d_rho + idx, d_tmp, d_alpha + idx);

			CudaCheckError();
			CudaSafeCall( cudaDeviceSynchronize() );

			// z += tmp * s
			dispatch_axpy(NX, d_z, d_z, d_s + idx * NX, d_tmp);
		}

		if (*cudatimet){
			timer_updates.stop();
			timer_linesearch.start();
		}

		CudaSafeCall( cudaMemcpy(d_fkm1, d_fk, 1  * sizeof(double), cudaMemcpyDeviceToDevice) ); // fkm1 = fk;
		CudaSafeCall( cudaMemcpy(d_gkm1, d_gk, NX * sizeof(double), cudaMemcpyDeviceToDevice) ); // gkm1 = gk;

		// The line search only receives timers when timing is enabled.
		timer *t_evals = NULL, *t_linesearch = NULL;

		if (*cudatimet){
			t_evals = &timer_evals;
			t_linesearch = &timer_linesearch;
		}
		// (line search defined in linesearch_cuda.h)
		// NOTE(review): on failure we leave the loop without stopping
		// timer_linesearch here - confirm gpu_linesearch() takes care of that.
		if (!gpu_linesearch(d_x, d_z, d_fk, d_gk, evals, stat, d_step,
					m_maxEvals, t_evals, t_linesearch, d_status,
					debug, cudatimet, m_atomRigidT, m_degFreedoms,
					m_nRigidBody, d_cmax, d_maxsite, d_xrigid, d_gkrigid,
					coldfusiont, coldfusionlim, it))
		{
			break;
		}

		if (*cudatimet){
			timer_linesearch.stop();
			timer_updates.start();
		}

		// Update s, y, rho and H_0
		// ------------------------

		// s   = x_k - x_{k-1} = step * z
		// y   = g_k - g_{k-1}
		// rho = 1 / (y^T s)
		// H_0 = (y^T s) / (y^T y)

		const size_t slot = it % m_mUpdate; // ring-buffer slot written in this iteration

		double *d_curS = d_s + slot * NX;
		double *d_curY = d_y + slot * NX;

		dispatch_scale(NX, d_curS, d_z,  d_step);                   // s = step * z
		dispatch_axpy (NX, d_curY, d_gk, d_gkm1, &minusOne, false); // y = gk - gkm1

		dispatch_dot(NX, d_tmp,  d_curY, d_curS); // tmp  = yDotS
		dispatch_dot(NX, d_tmp2, d_curY, d_curY); // tmp2 = yDotY

		// rho = 1 / tmp
		//   H0 = tmp / tmp2
		update3<<<1, 1>>>(d_rho + slot, d_H0, d_tmp, d_tmp2);

		CudaCheckError();
		CudaSafeCall( cudaDeviceSynchronize() );

		if (*cudatimet){
			timer_updates.stop();
		}
	}
	CudaSafeCall( cudaMemcpy(energy, d_fk, sizeof(double), cudaMemcpyDeviceToHost) );
	*itdone = static_cast<int>(it);
	*ecalls = static_cast<int>(evals);

	// Deallocations

	CudaSafeCall( cudaFree(d_gk)   );
	CudaSafeCall( cudaFree(d_gkm1) );
	CudaSafeCall( cudaFree(d_z)    );

	CudaSafeCall( cudaFree(d_s)    );
	CudaSafeCall( cudaFree(d_y)    );
	CudaSafeCall( cudaFree(d_alpha));
	CudaSafeCall( cudaFree(d_rho)  );

	if (! m_atomRigidT){
		CudaSafeCall( cudaFree(d_xrigid)    );
		CudaSafeCall( cudaFree(d_gkrigid)    );
	}

	if (*cudatimet){
		timer_total.stop();

		timer_total.saveMeasurement();
		timer_evals.saveMeasurement();
		timer_updates.saveMeasurement();
		timer_linesearch.saveMeasurement();
	}

	if (*debug){
		writetodebug("Number of iterations: " + gpuLbfgsToString(it));
		writetodebug("Number of function/gradient evaluations: " + gpuLbfgsToString(evals));
		writetodebug("Reason for termination: " + statusToString(stat));
	}
	return stat;
}

// Vector operations
// -----------------

// Computes d_dst = d_y + a * d_x for vectors of length n using cuBLAS.
//   a              : scalar factor, read from device memory when aDevicePointer
//                    is true and from host memory otherwise
//   d_dst, d_y     : may be the same buffer, in which case the update is in place
// NOTE(review): if d_dst aliases d_x but not d_y, the copy below overwrites
// d_x before it is used - no caller in this file does that.
void lbfgs::dispatch_axpy(const size_t n, double *d_dst, const double *d_y, const double *d_x, const double *a, bool aDevicePointer) const
{
	// Tell cuBLAS where the scalar lives before it gets dereferenced.
	cublasPointerMode_t scalarLocation = CUBLAS_POINTER_MODE_HOST;
	if (aDevicePointer) {
		scalarLocation = CUBLAS_POINTER_MODE_DEVICE;
	}
	CublasSafeCall( cublasSetPointerMode(m_cublasHandle, scalarLocation) );

	// cublasDaxpy accumulates into its output vector, so seed d_dst with d_y first.
	const bool updateInPlace = (d_dst == d_y);
	if (!updateInPlace) {
		CudaSafeCall( cudaMemcpy(d_dst, d_y, n * sizeof(double), cudaMemcpyDeviceToDevice) );
	}

	CublasSafeCall( cublasDaxpy(m_cublasHandle, int(n), a, d_x, 1, d_dst, 1) );
}

// Computes d_dst = a * d_x for vectors of length n using cuBLAS.
//   a              : scalar factor, read from device memory when aDevicePointer
//                    is true and from host memory otherwise
//   d_dst, d_x     : may be the same buffer, in which case d_x is scaled in place
void lbfgs::dispatch_scale(const size_t n, double *d_dst, const double *d_x, const double *a, bool aDevicePointer) const
{
	// Tell cuBLAS where the scalar lives before it gets dereferenced.
	cublasPointerMode_t scalarLocation = CUBLAS_POINTER_MODE_HOST;
	if (aDevicePointer) {
		scalarLocation = CUBLAS_POINTER_MODE_DEVICE;
	}
	CublasSafeCall( cublasSetPointerMode(m_cublasHandle, scalarLocation) );

	// cublasDscal scales its vector in place, so copy the input over first if needed.
	const bool scaleInPlace = (d_dst == d_x);
	if (!scaleInPlace) {
		CudaSafeCall( cudaMemcpy(d_dst, d_x, n * sizeof(double), cudaMemcpyDeviceToDevice) );
	}

	CublasSafeCall( cublasDscal(m_cublasHandle, int(n), a, d_dst, 1) );
}


// Computes *dst = d_x . d_y for vectors of length n using cuBLAS.
//   dst : result location, in device memory when dstDevicePointer is true
//         and in host memory otherwise
void lbfgs::dispatch_dot(const size_t n, double *dst, const double *d_x, const double *d_y, bool dstDevicePointer) const
{
	// Tell cuBLAS where the result has to be written.
	cublasPointerMode_t resultLocation = CUBLAS_POINTER_MODE_HOST;
	if (dstDevicePointer) {
		resultLocation = CUBLAS_POINTER_MODE_DEVICE;
	}
	CublasSafeCall( cublasSetPointerMode(m_cublasHandle, resultLocation) );

	CublasSafeCall( cublasDdot(m_cublasHandle, int(n), d_x, 1, d_y, 1, dst) );
}

// Computes *dst = ||d_x||_2 for a vector of length n using cuBLAS.
//   dst : result location, in device memory when aDevicePointer is true
//         and in host memory otherwise
void lbfgs::dispatch_nrm2(const size_t n, double *dst, const double *d_x, bool aDevicePointer) const
{
	// Tell cuBLAS where the result has to be written.
	cublasPointerMode_t resultLocation = CUBLAS_POINTER_MODE_HOST;
	if (aDevicePointer) {
		resultLocation = CUBLAS_POINTER_MODE_DEVICE;
	}
	CublasSafeCall( cublasSetPointerMode(m_cublasHandle, resultLocation) );

	CublasSafeCall( cublasDnrm2(m_cublasHandle, int(n), d_x, 1, dst) );
}


// -----------------

// Device / kernel functions
// -------------------------

namespace gpu_lbfgs
{
	// Returns `value` unchanged unless its magnitude is below 1e-20, in which case
	// +1e-20 (for value >= 0) or -1e-20 (otherwise) is returned so that it can be
	// used safely as a divisor. Uses fabs() rather than abs() so that the
	// comparison can never be performed on a value truncated by an integer overload.
	inline __host__ __device__ double lbfgsClampAwayFromZero(double value)
	{
		const double tiny = 1.0e-20;

		if (fabs(value) < tiny){
			return (value >= 0) ? tiny : -tiny;
		}
		return value;
	}

	// First update loop: alpha = sDotZ * rho, minusAlpha = -alpha.
	// Launch with a single thread. sDotZ and minusAlpha_out may point to the same
	// scalar (they do at the call site): sDotZ is read before minusAlpha_out is written.
	__global__ void update1(double *alpha_out, const double *sDotZ, const double *rho, double *minusAlpha_out)
	{
		*alpha_out      = *sDotZ * *rho;
		*minusAlpha_out = -*alpha_out;
	}

	// Second update loop: out = alpha - rho * yDotZ.
	// Launch with a single thread. yDotZ and alphaMinusBeta_out may point to the
	// same scalar (they do at the call site): yDotZ is read before the output is written.
	__global__ void update2(double *alphaMinusBeta_out, const double *rho, const double *yDotZ, const double *alpha)
	{
		const double beta = *rho * *yDotZ;
		*alphaMinusBeta_out = *alpha - beta;
	}

	// After the line search: rho = 1 / yDotS and H0 = yDotS / yDotY.
	// Launch with a single thread. Both inputs are first pushed away from zero
	// (see lbfgsClampAwayFromZero) and the adjusted values are written back
	// through yDotS / yDotY.
	__global__ void update3(double *rho_out, double *H0_out, double *yDotS, double *yDotY)
	{
		*yDotS   = lbfgsClampAwayFromZero(*yDotS);
		*rho_out = 1.0 / *yDotS;

		*yDotY  = lbfgsClampAwayFromZero(*yDotY);
		*H0_out = *yDotS / *yDotY;
	}
}
// ------------------

// Function for writing to debug file

// Appends `debugline` plus a newline to the file "debug_out.txt" in the
// current working directory. The file is opened and closed on every call;
// a failure to open it is not reported.
void lbfgs::writetodebug(std::string debugline) const
{
	// Append mode keeps the lines written by earlier calls.
	std::ofstream debugFile("debug_out.txt", std::ios_base::app);
	debugFile << debugline << std::endl;
	// The stream is closed when debugFile goes out of scope.
}
// ------------------
