/**
 *   ___ _   _ ___   _     _       ___ ___ ___ ___
 *  / __| | | |   \ /_\   | |  ___| _ ) __/ __/ __|
 * | (__| |_| | |) / _ \  | |_|___| _ \ _| (_ \__ \
 *  \___|\___/|___/_/ \_\ |____|  |___/_| \___|___/
 *
 * File linesearch_cuda.h: Line search for GPU implementation.
 * NOTE: Included from lbfgs.cu, not to be used on its own!
 **/

#ifndef LINESEARCH_GPU_H
#define LINESEARCH_GPU_H

#include "cost_function.h"
#include "timer.h"
#include "lbfgs.h"

#include <cmath>

namespace gpu_lbfgs
{
	// Device-resident scalars used by gpu_linesearch (whose addresses it obtains with
	// cudaGetSymbolAddress / which it sets with cudaMemcpyToSymbol) and by the
	// line-search kernels defined at the end of this file.
	__device__ double fnew;      // energy at the proposed point xnew
	__device__ double factor;    // scaling applied to the search direction z
	__device__ double stepsize;  // Euclidean norm of the search direction z

	// Line-search kernels; gpu_linesearch launches each one as a single thread (<<<1,1>>>).
	__global__ void adjust_stepsize(double *d_factor, double *d_stepsize);  // clip factor so factor*stepsize <= maxStep
	__global__ void reduce_stepsize(double *d_factor, double *d_stepsize);  // factor /= 10
	__global__ void check_maxfrise(double *d_fnew, double *d_fk);           // status = (fnew - fk < maxFRise) ? 1 : 0
}

/**
 * Backtracking line search along the direction d_z, starting from d_x (GPU version).
 *
 * Starting factor: min(1/|g|, |g|) when it == 0, otherwise 1. It is then clipped on the
 * device so that factor * |z| does not exceed maxStep. If z points uphill (z.g > 0),
 * z is negated in place first.
 *
 * Each attempt evaluates energy and gradient at xnew = x + factor * z. If
 * check_maxfrise does not accept the energy change (fnew - fk vs. maxFRise), factor is
 * divided by 10 and another attempt is made, up to nred_max = 10 evaluations. The last
 * evaluated point is accepted even if no attempt was accepted (a debug warning is
 * written in that case).
 *
 * When m_atomRigidT is false, the proposed coordinates are transformed from rigid-body
 * to Cartesian form before the energy call. The gradient is transformed back
 * afterwards, and xnew/gnew are replaced by the rigid-body vectors of length
 * m_degFreedoms, zero-padded up to NX.
 *
 * Device pointers are prefixed d_.
 *   d_x, d_gk, d_fk  Current point, gradient and energy. They are overwritten with the
 *                    accepted point/gradient/energy unless cold fusion is diagnosed.
 *   d_z              Search direction; may be negated in place.
 *   evals            Incremented once per energy/gradient evaluation.
 *   stat             Set to LBFGS_COLD_FUSION_DIAGNOSED or LBFGS_REACHED_MAX_EVALS
 *                    when returning false.
 *   d_step           Receives the device value of `factor`; only written when
 *                    returning true.
 *   d_status         Read after check_maxfrise; 1 means the step is accepted.
 *   coldfusiont      Checked after each energy evaluation; true aborts the search.
 *
 * NOTE(review): atomRigidT, degFreedoms and nRigidBody are not used. The body reads the
 * members m_atomRigidT, m_degFreedoms, m_nRigidBody and m_maxSite instead. Confirm that
 * these always agree with the arguments.
 *
 * Returns true on success. Returns false if cold fusion was diagnosed or
 * evals >= maxEvals. The temporary device buffers are released on every exit path.
 */
bool lbfgs::gpu_linesearch(double *d_x, double *d_z, double *d_fk, double *d_gk,
		size_t &evals, lbfgs::status &stat, double *d_step, const size_t maxEvals,
		timer *timer_evals, timer *timer_linesearch, int *d_status, _Bool *debug, 
		_Bool *cudatimet, const bool atomRigidT, const int degFreedoms, const int nRigidBody, 
		int *d_cmax, int *d_maxsite, double *d_xrigid, double *d_gkrigid, _Bool *coldfusiont, 
		double *coldfusionlim, size_t it)
{
	using namespace gpu_lbfgs;

	const size_t NX = m_costFunction.getNumberOfUnknowns(); 	// number of unknowns (original note: 3*natoms)

	double *d_xnew, *d_gnew, *d_fnew, *d_factor, *d_stepsize;

	const double one = 1.0;
	const double minusone = -1.0;

	double gnorm, dotted;

	// Temporary buffers for the proposed point and its gradient; freed before every return.
	CudaSafeCall( cudaMalloc((void**)&d_xnew,   NX * sizeof(double)) );
	CudaSafeCall( cudaMalloc((void**)&d_gnew,   NX * sizeof(double)) );

	// Device addresses of the __device__ scalars manipulated by the line-search kernels.
	CudaSafeCall( cudaGetSymbolAddress((void**)&d_fnew,     gpu_lbfgs::fnew ) );
	CudaSafeCall( cudaGetSymbolAddress((void**)&d_factor,   gpu_lbfgs::factor ) );
	CudaSafeCall( cudaGetSymbolAddress((void**)&d_stepsize, gpu_lbfgs::stepsize ) );

	// Initial step scaling: min(1/|g|, |g|) on the first iteration, 1 afterwards.
	if (it == 0){
		dispatch_nrm2(NX, &gnorm, d_gk, false);     // gnorm = sqrt(gkDotgk)
		double firstfactor = std::min(1/gnorm, gnorm);
		CudaSafeCall( cudaMemcpyToSymbol(gpu_lbfgs::factor, &firstfactor,  sizeof(double)) );
	} else{
		CudaSafeCall( cudaMemcpyToSymbol(gpu_lbfgs::factor, &one,  sizeof(double)) );
	}

	// If the step is pointing uphill, invert it
	dispatch_dot(NX, &dotted, d_z, d_gk, false);	// dotted = zDotgk
	if (dotted > 0.0){
		if (*debug){
			std::string aline("Warning: step direction was uphill. Inverting. ");
			writetodebug(aline);
		}
		dispatch_scale(NX, d_z, d_z, &minusone, false);		// z = -1 * z
	}

	dispatch_nrm2(NX, d_stepsize, d_z);	// stepsize = sqrt(zDotZ), written directly to the device scalar
	// Make sure the step is no larger than maxStep
	adjust_stepsize<<<1,1>>>(d_factor, d_stepsize);
	CudaCheckError();
	CudaSafeCall( cudaDeviceSynchronize() );

	int nred;
	const int nred_max = 10;
	for (nred = 0; nred < nred_max; ++nred){
		dispatch_axpy(NX, d_xnew, d_x, d_z, d_factor);	// xnew = x + z*factor

		if (*cudatimet){
			timer_linesearch->stop();
			timer_evals->start();
		}

		if (! m_atomRigidT){
			CudaSafeCall( cudaMemcpy(d_xrigid, d_xnew, m_degFreedoms * sizeof(double), cudaMemcpyDeviceToDevice) );

			// Rigid bodies coordinate transformation
			transform_rigidtoc(m_nRigidBody, d_xnew, d_xrigid, m_degFreedoms, m_maxSite, d_cmax, d_maxsite);
		}

		// Calculate energy and gradient for proposed step
		m_costFunction.f_gradf(d_xnew, d_fnew, d_gnew, coldfusiont, coldfusionlim);
		CudaCheckError();
		CudaSafeCall( cudaDeviceSynchronize() );

		if (*coldfusiont){
			stat = lbfgs::LBFGS_COLD_FUSION_DIAGNOSED;
			// NOTE(review): when *cudatimet is set, timer_evals is still running and
			// timer_linesearch stopped at this point; confirm the caller expects that.
			CudaSafeCall( cudaFree(d_xnew) );
			CudaSafeCall( cudaFree(d_gnew) );
			return false;
		}

		if (! m_atomRigidT){
			// Zero-fill on the device rather than allocating and uploading a host buffer of
			// zeros every iteration; an all-zero byte pattern is +0.0 for IEEE doubles.
			CudaSafeCall( cudaMemset(d_gkrigid, 0, m_degFreedoms * sizeof(double)) );

			// Rigid bodies gradient transformation
			transform_grad(d_gnew, d_xrigid, d_gkrigid, m_nRigidBody, d_cmax, m_maxSite, d_maxsite, m_degFreedoms);

			// Replace xnew/gnew with the rigid-body vectors, zero-padded up to NX.
			CudaSafeCall( cudaMemset(d_xnew, 0, NX * sizeof(double)) );
			CudaSafeCall( cudaMemset(d_gnew, 0, NX * sizeof(double)) );

			CudaSafeCall( cudaMemcpy(d_xnew, d_xrigid, m_degFreedoms * sizeof(double), cudaMemcpyDeviceToDevice) );
			CudaSafeCall( cudaMemcpy(d_gnew, d_gkrigid, m_degFreedoms * sizeof(double), cudaMemcpyDeviceToDevice) );
		}

		++evals;

		if (*cudatimet){
			timer_evals->stop();
			timer_linesearch->start();
		}

		// Check whether energy has risen too much; the kernel writes its verdict to status.
		check_maxfrise<<<1,1>>>(d_fnew, d_fk);
		CudaCheckError();
		CudaSafeCall( cudaDeviceSynchronize() );

		int ret = 0;
		CudaSafeCall( cudaMemcpy(&ret, d_status, sizeof(int), cudaMemcpyDeviceToHost) );

		if (ret == 1){
			break;
		}

		// If energy rose too much, reduce stepsize and continue the loop
		reduce_stepsize<<<1,1>>>(d_factor, d_stepsize);
		CudaCheckError();
		CudaSafeCall( cudaDeviceSynchronize() );
		if (*debug){
			double testfactor, teststepsize;
			CudaSafeCall( cudaMemcpy(&testfactor, d_factor, sizeof(double), cudaMemcpyDeviceToHost) );
			CudaSafeCall( cudaMemcpy(&teststepsize, d_stepsize, sizeof(double), cudaMemcpyDeviceToHost) );

			std::ostringstream msg;
			msg << "Function increased, reducing step size to " << testfactor * teststepsize;
			std::string aline(msg.str());
			writetodebug(aline);
		}
	}

	if (nred >= nred_max && *debug){
		std::string aline("Warning: a step size cannot be found where the maximum allowed function increase is not exceeded. ");
		writetodebug(aline);
	}

	// Accept the last evaluated point.
	CudaSafeCall( cudaMemcpy(d_x, d_xnew, NX * sizeof(double), cudaMemcpyDeviceToDevice) );
	CudaSafeCall( cudaMemcpy(d_gk, d_gnew, NX * sizeof(double), cudaMemcpyDeviceToDevice) );
	CudaSafeCall( cudaMemcpy(d_fk, d_fnew, sizeof(double), cudaMemcpyDeviceToDevice) );

	// The temporaries are no longer needed. Free them before either return so the
	// max-evals path does not leak them.
	CudaSafeCall( cudaFree(d_xnew) );
	CudaSafeCall( cudaFree(d_gnew) );

	if (evals >= maxEvals){
		stat = lbfgs::LBFGS_REACHED_MAX_EVALS;
		return false;
	}

	// NOTE(review): if every attempt was rejected, reduce_stepsize ran once more after the
	// last evaluation, so this factor is 10x smaller than the one that produced d_x.
	CudaSafeCall( cudaMemcpy(d_step, d_factor, sizeof(double), cudaMemcpyDeviceToDevice) );
	return true;
}


namespace gpu_lbfgs
{
	// Single-thread line-search kernels (gpu_linesearch launches each with <<<1,1>>>).
	// They read and write the __device__ scalars factor, stepsize and fnew directly,
	// together with fk, status, maxStep and maxFRise, which are declared outside this
	// file (presumably in lbfgs.cu). The pointer parameters are not used by the bodies.

	// Clip factor so that the actual step length factor * stepsize does not exceed maxStep.
	__global__ void adjust_stepsize(double *d_factor, double *d_stepsize)
	{
		if (factor * stepsize > maxStep){
			factor = maxStep / stepsize;
		}
	}

	// status = 1 when the energy change fnew - fk is below maxFRise (step accepted),
	// otherwise 0 (including when the difference is NaN).
	__global__ void check_maxfrise(double *d_fnew, double *d_fk)
	{
		const double rise = fnew - fk;
		status = (rise < maxFRise) ? 1 : 0;
	}

	// One backtracking step: shrink the step scaling by a factor of ten.
	__global__ void reduce_stepsize(double *d_factor, double *d_stepsize)
	{
		factor = factor / 10.0;
	}
}

#endif // LINESEARCH_GPU_H

