/**
 *   ___ _   _ ___   _     _       ___ ___ ___ ___
 *  / __| | | |   \ /_\   | |  ___| _ ) __/ __/ __|
 * | (__| |_| | |) / _ \  | |_|___| _ \ _| (_ \__ \
 *  \___|\___/|___/_/ \_\ |____|  |___/_| \___|___/
 *
 * File setup_cuda.cu: Communicates with GMIN and sets up and starts the minimization 
 *
 **/

#include "CudaLBFGS/lbfgs.h"
#include "CudaLBFGS/potential.h"

#include <stdio.h>
#include <iomanip>
#include <iostream>
#include <stdbool.h>

	template <class SomeType> 	// Can call this function with any potential class
/**
 * Configure an lbfgs minimizer for the given potential, run the minimization
 * on the GPU and copy the resulting coordinates back to the host.
 *
 * All scalar settings arrive as pointers and are dereferenced once at the top
 * of the function.
 *
 * @param ndim          Number of coordinates in x (length of the coordinate array).
 * @param x             Host coordinates; copied to the GPU before minimizing and
 *                      overwritten with the final coordinates afterwards.
 * @param eps           Convergence criterion for the RMS force.
 * @param mflag         Output: set to 1 if the returned status equals 0
 *                      (convergence), otherwise set to 0.
 * @param energy        Forwarded unchanged to lbfgs::minimize.
 * @param itmax         Maximum number of lbfgs iterations.
 * @param itdone        Forwarded unchanged to lbfgs::minimize.
 * @param maxbfgs       Maximum step size.
 * @param maxerise      Maximum amount the energy may rise by when taking a step.
 * @param outrms        Forwarded unchanged to lbfgs::minimize.
 * @param potential     Potential object handed to the minimizer's constructor.
 * @param debug         If true, the end status is written via writetodebug;
 *                      also forwarded to lbfgs::minimize.
 * @param cudatimet     Forwarded unchanged to lbfgs::minimize.
 * @param ecalls        Forwarded unchanged to lbfgs::minimize.
 * @param atomrigidt    If true, atomistic coordinates; if false, rigid bodies.
 * @param degfree       No. of degrees of freedom for the rigid body framework.
 * @param nrigid        No. of rigid bodies.
 * @param nrigidsites   Host array (nrigid ints): no. of sites per rigid body.
 * @param rigidgroups   Host array (nrigid * maxsite ints): atoms in rigid bodies.
 * @param maxsite       Max. no. of sites in a rigid body.
 * @param sitesrigid    Host array (nrigid * 3 * maxsite doubles): rigid body site coordinates.
 * @param rigidsingles  Host array (degfree/3 - 2*nrigid ints): atoms not in rigid bodies.
 * @param coldfusiont   Forwarded unchanged to lbfgs::minimize.
 * @param coldfusionlim Forwarded unchanged to lbfgs::minimize.
 * @param dguess        Guess for the initial inverse Hessian diagonal elements.
 * @param mupdate       History size for lbfgs.
 */
void setup(const size_t& ndim, double *x, double *eps, _Bool *mflag, double *energy, int *itmax, int *itdone, double *maxbfgs, double *maxerise, double *outrms, SomeType &potential, _Bool *debug, _Bool *cudatimet, int *ecalls, _Bool *atomrigidt, int *degfree, int *nrigid, int *nrigidsites, int *rigidgroups, int *maxsite, double *sitesrigid, int *rigidsingles, _Bool *coldfusiont, double *coldfusionlim, double *dguess, int *mupdate)
{
	double *d_x = NULL;	// Coordinates on the GPU
	int *d_nrigidsites = NULL;	// No. of rigid body sites, on the GPU
	int *d_rigidgroups = NULL;	// List of atoms in rigid bodies, on the GPU
	double *d_sitesrigid = NULL;	// Coordinates of the rigid body sites, on the GPU
	int *d_rigidsingles = NULL;	// List of atoms not in rigid bodies, on the GPU
	lbfgs::status stat;	// Status returned by lbfgs which indicates whether the minimization was successful (see lbfgs.h)

	const size_t maxIter = *itmax;	// Maximum number of iterations for lbfgs
	const double maxStep = *maxbfgs; 	// Maximum step size
	const double gradientEps = *eps;	// Convergence criterion for RMS force
	const double maxFRise = *maxerise;	// The maximum amount the energy is allowed to rise by when taking a step
	const double dGuess = *dguess;          // Guess for initial inverse Hessian diagonal elements
	const bool atomRigidT = *atomrigidt;	// If true, atomistic coordinates, if false, rigid bodies
	const int degFreedoms = *degfree;	// No. of degrees of freedom for rigid body framework
	const int nRigidBody = *nrigid;		// No. of rigid bodies
	const int maxSite = *maxsite;		// Max. no. of site in a rigid body
	const int mUpdate = *mupdate;		// History size for lbfgs

	lbfgs minimizer(potential);	// Pass our potential object to the minimizer

	minimizer.setMaxIterations(maxIter);
	minimizer.setGradientEpsilon(gradientEps);
	minimizer.setMaxStep(maxStep);
	minimizer.setMaxFRise(maxFRise);
	minimizer.setDGuess(dGuess);
	minimizer.setAtomRigidT(atomRigidT);
	minimizer.setMUpdate(mUpdate);

	if (! atomRigidT){
		// Byte counts are computed once so that each allocation and its matching copy
		// can never disagree. The products are widened to size_t before multiplying
		// to avoid overflowing int for large systems.
		const size_t nRigidSitesBytes = static_cast<size_t>(nRigidBody) * sizeof(int);
		const size_t sitesRigidBytes = static_cast<size_t>(nRigidBody) * 3 * maxSite * sizeof(double);
		const size_t rigidGroupsBytes = static_cast<size_t>(nRigidBody) * maxSite * sizeof(int);
		// NOTE(review): the count degFreedoms/3 - 2*nRigidBody presumably assumes 6 degrees
		// of freedom per rigid body and 3 per free atom - confirm against the caller.
		const size_t rigidSinglesBytes = (degFreedoms/3 - 2*nRigidBody) * sizeof(int);

		CudaSafeCall( cudaMalloc(&d_nrigidsites, nRigidSitesBytes) );
		CudaSafeCall( cudaMalloc(&d_sitesrigid, sitesRigidBytes) );
		CudaSafeCall( cudaMalloc(&d_rigidgroups, rigidGroupsBytes) );
		CudaSafeCall( cudaMalloc(&d_rigidsingles, rigidSinglesBytes) );

		CudaSafeCall( cudaMemcpy(d_nrigidsites, nrigidsites, nRigidSitesBytes, cudaMemcpyHostToDevice) );
		CudaSafeCall( cudaMemcpy(d_sitesrigid, sitesrigid, sitesRigidBytes, cudaMemcpyHostToDevice) );
		CudaSafeCall( cudaMemcpy(d_rigidgroups, rigidgroups, rigidGroupsBytes, cudaMemcpyHostToDevice) );
		CudaSafeCall( cudaMemcpy(d_rigidsingles, rigidsingles, rigidSinglesBytes, cudaMemcpyHostToDevice) );

		minimizer.setDegFreedoms(degFreedoms);
		minimizer.setnRigidBody(nRigidBody);
		minimizer.setDevnRigidSites(d_nrigidsites);
		minimizer.setDevSitesRigid(d_sitesrigid);
		minimizer.setDevRigidGroups(d_rigidgroups);
		minimizer.setMaxSite(maxSite);
		minimizer.setDevRigidSingles(d_rigidsingles);
	}

	const size_t coordBytes = ndim * sizeof(double);

	CudaSafeCall( cudaMalloc(&d_x, coordBytes) );        // Allocate memory for coordinates on the GPU
	CudaSafeCall( cudaMemcpy(d_x, x, coordBytes, cudaMemcpyHostToDevice) );      // Copy our coordinates to the GPU

	stat = minimizer.minimize(d_x, energy, outrms, debug, itdone, cudatimet, ecalls, coldfusiont, coldfusionlim);	// Perform the minimization

	CudaSafeCall( cudaMemcpy(x, d_x, coordBytes, cudaMemcpyDeviceToHost) );	// Copy our coordinates back from the GPU to the CPU

	if (stat == 0){
		*mflag = 1;	// Convergence achieved (LBFGS_BELOW_GRADIENT_EPS)
	}
	else {
		*mflag = 0;	// Did not converge
	}

	if (*debug){
		// NOTE(review): this header is emitted for every potential type, not only Lennard-Jones.
		std::string aline("GPU LJ:");
		minimizer.writetodebug(aline);

		std::string bline = minimizer.statusToString(stat).c_str();
		minimizer.writetodebug(bline);	// Output end status of the minimizer
	}

	CudaSafeCall(cudaFree(d_x));	// Delete memory allocated for coordinates
	if (! atomRigidT){
		CudaSafeCall(cudaFree(d_nrigidsites));
		CudaSafeCall(cudaFree(d_sitesrigid));
		CudaSafeCall(cudaFree(d_rigidgroups));
		CudaSafeCall(cudaFree(d_rigidsingles));
	}
}

/**
 * C-linkage entry point: selects the potential class from the first character
 * of cudapot and hands every other argument through to setup<>().
 *
 *   'L' - Lennard-Jones (gpu_lj); exits if *n exceeds 3 * 1024 coordinates.
 *   'A' - Amber (gpu_amber).
 *   any other character - prints an error to stderr and exits with EXIT_FAILURE.
 *
 * *n is the number of coordinates (3 * number of atoms).
 */
extern "C" void cuda_setup(int *n, double *x, double *eps, _Bool *mflag, double *energy, int *itmax, int *itdone, double *maxbfgs, double *maxerise, double *outrms, char *cudapot, _Bool *debug, _Bool *cudatimet, int *ecalls, _Bool *atomrigidt, int *degfree, int *nrigid, int *nrigidsites, int* rigidgroups, int *maxsite, double *sitesrigid, int *rigidsingles, _Bool *coldfusiont, double *coldfusionlim, double *dguess, int *mupdate)
{
	const size_t ndim = *n; 	// 3 * number of atoms

	switch (*cudapot) {
		case 'L': {	// Lennard-Jones potential
			const size_t atom_max = 1024;
			const size_t coord_max = 3*atom_max;
			if (ndim > coord_max){
				std::cerr << "Lennard-Jones is currently only supported for a maximum of 1024 atoms. " << std::endl;
				exit(EXIT_FAILURE);
			}
			gpu_lj potential(ndim);	// Potential object whose type selects the setup<> instantiation
			setup<gpu_lj>(ndim, x, eps, mflag, energy, itmax, itdone, maxbfgs, maxerise, outrms, potential, debug, cudatimet, ecalls, atomrigidt, degfree, nrigid, nrigidsites, rigidgroups, maxsite, sitesrigid, rigidsingles, coldfusiont, coldfusionlim, dguess, mupdate);
			break;
		}
		case 'A': {	// Amber potential
			gpu_amber potential(ndim);
			setup<gpu_amber>(ndim, x, eps, mflag, energy, itmax, itdone, maxbfgs, maxerise, outrms, potential, debug, cudatimet, ecalls, atomrigidt, degfree, nrigid, nrigidsites, rigidgroups, maxsite, sitesrigid, rigidsingles, coldfusiont, coldfusionlim, dguess, mupdate);
			break;
		}
		default:	// Unknown potential specifier
			std::cerr << "The specified potential has not been recognised" << std::endl;
			exit(EXIT_FAILURE);
	}
}
