import argparse
import numpy as np

import minimization_tools as opt

from pele.optimize import mylbfgs, lbfgs_scipy
from pele.potentials import LJ

def parse_command_line(argv=None):
    """Parse the benchmark's command-line options.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse.  ``None`` (the default) makes argparse read
        ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        Parsed options.  The defaults reproduce the settings that used to be
        hard-coded: tol 0.01 applied to the gradient norm, structures read
        from ./lj38-clusters, 1000 structures.
    """
    parser = argparse.ArgumentParser(
        description="compute benchmarks for lbfgs with pele")
    parser.add_argument("-M", type=int, default=4,
                        help="lbfgs history length")
    # single-dash long option kept for backward compatibility
    parser.add_argument("-natoms", type=int, default=38,
                        help="number of atoms")
    parser.add_argument("--maxstep", type=float, default=0.1,
                        help="lbfgs maximum step size")
    parser.add_argument("--lbfgs-scipy", action="store_true",
                        help="use scipy lbfgs")
    parser.add_argument("--tol", type=float, default=0.01,
                        help="stop tolerance; a target for the gradient norm "
                             "unless --stop-on-max-force is given, in which "
                             "case it is passed to the minimizer unchanged")
    parser.add_argument("--stop-on-max-force", action="store_true",
                        help="use the maximum force on an atom as the stop "
                             "criterion instead of the gradient norm")
    parser.add_argument("--structuredir", default="./lj38-clusters",
                        help="structure directory passed to QuenchBenchmark")
    parser.add_argument("--nstructures", type=int, default=1000,
                        help="number of structures passed to QuenchBenchmark")
    return parser.parse_args(argv)


def minimizer_tol(norm_tol, natoms):
    """Convert a gradient-norm tolerance into the ``tol`` given to the minimizer.

    The benchmark wants to stop when |grad| < ``norm_tol``.  A gradient with
    3*natoms components satisfies |grad| = sqrt(3*natoms) * rms(grad), so the
    equivalent rms threshold is ``norm_tol / sqrt(3*natoms)``.

    NOTE(review): this assumes the minimizer's ``tol`` is an rms-gradient
    tolerance -- TODO confirm in minimization_tools / pele.
    """
    return norm_tol / np.sqrt(3. * natoms)


def summarize_ncalls(ncalls):
    """Return the summary lines printed for one minimizer's call counts.

    ``ncalls`` is a sequence of per-quench counts (``minimizer.ncalls``).
    An empty sequence yields a single notice line instead of raising, because
    ``np.max`` / ``np.min`` fail on empty arrays.
    """
    ncalls = np.asarray(ncalls)
    if ncalls.size == 0:
        return ["no quenches recorded"]
    # "%s" matches the spacing/str() formatting of the old
    # Python 2 `print "label", value` statements.
    return [
        "mean ncalls %s" % np.mean(ncalls),
        "max ncalls %s" % np.max(ncalls),
        "min ncalls %s" % np.min(ncalls),
    ]


def report(minimizer):
    """Print call-count statistics and write ``<label>.ncalls`` for a minimizer.

    The output file holds ``opt.QuenchResult.header()`` followed by one
    ``datastring()`` line per entry of ``minimizer.quench_results``.
    """
    # Single-argument print(...) behaves identically on Python 2 and 3.
    print("")
    print("")
    print(minimizer.label)
    for line in summarize_ncalls(minimizer.ncalls):
        print(line)

    fname = minimizer.label + ".ncalls"
    with open(fname, "w") as fout:
        fout.write(opt.QuenchResult.header() + "\n")
        for qr in minimizer.quench_results:
            fout.write(qr.datastring() + "\n")


def main(argv=None):
    """Run the LJ quench benchmark with one lbfgs minimizer and report results."""
    args = parse_command_line(argv)

    pot = opt.PotWrapper(LJ())
    benchmarker = opt.QuenchBenchmark(args.structuredir, args.nstructures)

    kwargs = dict()
    if args.stop_on_max_force:
        # use the maximum force on an atom as the stop criterion
        kwargs["alternate_stop_criterion"] = opt.MaxForceOnAtom()
        tol = args.tol
    else:
        # use the norm of the gradient, which is sqrt(3*natoms) times the rms
        tol = minimizer_tol(args.tol, args.natoms)

    if args.lbfgs_scipy:
        minimizer = opt.Minimizer("results_data", pot, lbfgs_scipy,
                                  tol=tol, **kwargs)
    else:
        minimizer = opt.Minimizer("results_data", pot, mylbfgs,
                                  M=args.M, tol=tol, maxstep=args.maxstep,
                                  **kwargs)
    benchmarker.addMinimizer(minimizer)

    benchmarker.run()

    for m in benchmarker.minimizers:
        report(m)


if __name__ == "__main__":
    main()
                    
                    

