#! /usr/bin/env python

import os, sys, cmd, logging
from math import sqrt, log, exp

# Usage text handed to OptionParser below; "%prog" is expanded by optparse
# to the name of the running script.
usage = """Query the LHAPDF library

Usage: %prog [options] [<cmd> [<args> ...]]

With no command, start an interactive "lhapdf>" shell; otherwise run
the single given command and exit. Type "help" inside the shell for
the list of commands.

The PDF value commands (xfx, xfxa, xfxp, xfxphoton) take the arguments
  <setId|name:memberId> <x> <Q> [<flav>]
and print the list of xf(x,Q2) values for each parton; specifying a
flavour index <flav> prints the single xf(x,Q2) value for that flavour.

TODO:
  * use friendly names for flavours
"""

## Init LHAPDF
import lhapdf
# Silence LHAPDF's own console output; abort the script if the library
# cannot be configured.
try:
    lhapdf.setVerbosity(lhapdf.SILENT)
except Exception:
    # Catch Exception rather than everything so that Ctrl-C and explicit
    # exits are not misreported as an initialisation failure.
    sys.stderr.write("Error when initialising LHAPDF\n")
    sys.exit(1)


def getPDFSetId(name):
    """Resolve a PDF set specifier into a (file name, member number) pair.

    name is either "<file>:<member>", which is split on ":" with the first
    field used as the file name and the second converted to an int, or a
    numeric set ID, which is looked up with lhapdf.getPDFSetInfo().

    Returns the tuple (pdffile, pdfmember). On any parsing or lookup
    error a message is written to stderr and the process exits with
    status 1 (i.e. SystemExit is raised).
    """
    try:
        if ":" in name:
            # Split once and reuse; fields after the second are ignored.
            fields = name.split(":")
            pdffile = str(fields[0])
            pdfmember = int(fields[1])
        else:
            info = lhapdf.getPDFSetInfo(int(name))
            pdffile = info.file
            pdfmember = info.memberId
        return pdffile, pdfmember
    except Exception:
        # Wrap name in a 1-tuple so that a tuple-valued name cannot break
        # the formatting of the error message itself.
        sys.stderr.write("Error in parsing PDF set ID: %s\n" % (name,))
        sys.exit(1)


def getPDFVals(lhafn, name, x, Q):
    """Initialise the PDF set given by name and evaluate lhafn at (x, Q).

    lhafn is called as lhafn(float(x), float(Q)) after the set resolved by
    getPDFSetId(name) has been initialised with lhapdf.initPDFSet(); its
    return value is passed back unchanged.

    If initialisation or evaluation fails, a message is written to stderr
    and the process exits with status 1 (i.e. SystemExit is raised). An
    unparseable name exits the same way from inside getPDFSetId().
    """
    PDFNAME, PDFNSET = getPDFSetId(name)
    try:
        lhapdf.initPDFSet(PDFNAME, PDFNSET)
        return lhafn(float(x), float(Q))
    except Exception:
        # Exception (not a bare except) so that Ctrl-C is not reported as
        # a query failure.
        sys.stderr.write("Error when querying set %s:%d\n" % (PDFNAME, PDFNSET))
        sys.exit(1)


def printPDFVals(vals, flav=None):
    """Write PDF values to standard output.

    With flav left as None, every entry of vals is written on a single
    line, separated by single spaces. Otherwise only vals[flav] is
    written. The output always ends with a newline.

    Uses sys.stdout.write instead of the print statement so that the
    function is valid under both Python 2 and Python 3.
    """
    if flav is None:
        sys.stdout.write(" ".join([str(val) for val in vals]) + "\n")
    else:
        sys.stdout.write("%s\n" % (vals[flav],))



## Internal command parser
class LHAPDFCmd(cmd.Cmd):
    """Process commands that query the LHAPDF library.

    Each do_<name> method implements the shell command <name>, and each
    help_<name> method prints the text shown by "help <name>". The help
    methods are invoked by cmd.Cmd without arguments, so any argument
    they declare must be optional.
    """

    intro = "Query the LHAPDF parton density collection"
    prompt = "lhapdf> "

    def _logSpaced(self, lo, hi, nsteps):
        """Return nsteps+1 points from lo to hi, evenly spaced in log."""
        loglo = log(lo)
        step = (log(hi) - loglo) / float(nsteps)
        return [exp(loglo + i * step) for i in range(nsteps + 1)]

    def _matchingSetInfos(self, argstr):
        """Return the PDF set infos matching the regex argstr.

        An info matches when the case-insensitive pattern is found in its
        file name or its description; an empty argstr matches everything.
        An invalid pattern is reported on stderr and matches nothing.
        """
        import re
        try:
            regex = re.compile(argstr or ".*", re.I)
        except re.error:
            sys.stderr.write("Invalid pattern: %s\n" % argstr)
            return []
        return [info for info in lhapdf.getAllPDFSetInfo()
                if regex.search(info.file) or regex.search(info.description)]

    def _plotPDFFn(self, pdffn, argstr):
        """Plot xf(x,Q) against x on log-log axes for one flavour.

        argstr is "<setId|name:memberId> <flav> [<x_range> [<Qs>]]".
        <x_range> is a 2-element Python sequence expression and <Qs> a
        sequence of Q values; both must be written without whitespace
        because argstr is split on whitespace. When omitted, the x range
        and five log-spaced Q values are taken from the limits reported
        by lhapdf. One curve is drawn per Q value, sampled at 51
        log-spaced x points.

        With fewer than two arguments the line is passed to default().
        Other errors (bad numbers, missing pylab, ...) propagate.
        """
        _args = argstr.split()
        if len(_args) < 2:
            return self.default(argstr)
        import pylab as mpl
        name = _args[0]
        flav = int(_args[1])
        #
        PDFNAME, PDFNSET = getPDFSetId(name)
        lhapdf.initPDFSet(PDFNAME, PDFNSET)
        # NOTE(review): eval() executes whatever expression the user typed.
        # That is tolerable for a local interactive shell, but must not be
        # exposed to untrusted input.
        if len(_args) > 2:
            print(_args[2])
            x_range = eval(_args[2])
            if len(x_range) != 2:
                raise ValueError("x_range '%s' is not a 2-element list" % (x_range,))
        else:
            x_range = (lhapdf.getXmin(1), lhapdf.getXmax(1))
        if len(_args) > 3:
            print(_args[3])
            Qs = eval(_args[3])
        else:
            Qs = self._logSpaced(sqrt(lhapdf.getQ2min(1)),
                                 sqrt(lhapdf.getQ2max(1)), 4)
        #
        xs = self._logSpaced(x_range[0], x_range[1], 50)
        for Q in Qs:
            xfxs = [getPDFVals(pdffn, name, x, Q)[flav] for x in xs]
            print("Q, x, xf(x) = %s %s" % (Q, list(zip(xs, xfxs))))
            mpl.loglog(xs, xfxs)
        mpl.show()


    def _printPDFFn(self, pdffn, argstr):
        """Evaluate pdffn for "<setId|name:memberId> <x> <Q> [<flav>]" and print.

        Prints all values, or only the one at index <flav> when given.
        Any failure, including the SystemExit raised by getPDFVals() on a
        bad set or query, is turned into a call to default() so that a
        mistyped command does not terminate the interactive shell.
        """
        _args = argstr.split()
        try:
            name, x, Q = _args[0:3]
            flav = None
            if len(_args) > 3:
                flav = int(_args[3])
            vals = getPDFVals(pdffn, name, x, Q)
            printPDFVals(vals, flav)
        except (Exception, SystemExit):
            return self.default(argstr)


    def emptyline(self):
        """Do nothing on an empty line (instead of repeating the last command)."""
        return False

    def do_EOF(self, line):
        """Leave the shell."""
        return True
    def do_exit(self, argstr):
        """Leave the shell."""
        return self.do_EOF(argstr)
    def do_quit(self, argstr):
        """Leave the shell."""
        return self.do_EOF(argstr)


    def do_version(self, argstr):
        "Show library version code"
        print(lhapdf.getVersion())
    def help_version(self, argstr=None):
        print("Show library version code")


    def do_listsets(self, argstr):
        "List PDF sets, with an optional pattern match argument"
        for info in self._matchingSetInfos(argstr):
            print(info.toString())
    def help_listsets(self, argstr=None):
        print("List PDF sets, with an optional pattern match argument")


    def do_listids(self, argstr):
        "List PDF set ID codes for all members"
        for info in lhapdf.getAllPDFSetInfo():
            print(info.id)
    def help_listids(self, argstr=None):
        print("List PDF set ID codes for all members")


    def do_listnames(self, argstr):
        "List PDF set file names, with an optional pattern match argument"
        for info in self._matchingSetInfos(argstr):
            print("%s:%d" % (info.file, info.memberId))
    def help_listnames(self, argstr=None):
        print("List PDF set file names, with an optional pattern match argument")


    def do_xfx(self, argstr):
        "Calc nucleon xf(x,Q2) for set <name> with flavour index <flav>"
        self._printPDFFn(lhapdf.xfx, argstr)
    def help_xfx(self):
        print("xfx <setId|name:memberId> <x> <Q> [<flav>]")
        print("Calc nucleon xf(x,Q2) for set <name> with flavour index <flav>")


    def do_plotxfx(self, argstr):
        "Plot nucleon xf(x,Q2) for set <name> with flavour index <flav>"
        self._plotPDFFn(lhapdf.xfx, argstr)
    def help_plotxfx(self):
        print("plotxfx <setId|name:memberId> <flav> [<x_range> [<Qs>]]")
        print("Plot nucleon xf(x,Q2) for set <name> with flavour index <flav>")


    def do_xfxa(self, argstr):
        "Calc nuclear xf(x,Q2) for set <name> with flavour index <flav>"
        self._printPDFFn(lhapdf.xfxa, argstr)
    def help_xfxa(self):
        print("xfxa <setId|name:memberId> <x> <Q> [<flav>]")
        print("Calc nuclear xf(x,Q2) for set <name> with flavour index <flav>")


    def do_xfxp(self, argstr):
        "Calc photon xf(x,Q2) for set <name> with flavour index <flav>"
        self._printPDFFn(lhapdf.xfxp, argstr)
    def help_xfxp(self):
        print("xfxp <setId|name:memberId> <x> <Q> [<flav>]")
        print("Calc photon xf(x,Q2) for set <name> with flavour index <flav>")


    def do_xfxphoton(self, argstr):
        "Calc MRST nucleon QED xf(x,Q2) for set <name> with shifted flavour index <flav>"
        self._printPDFFn(lhapdf.xfxphoton, argstr)
    def help_xfxphoton(self):
        print("xfxphoton <setId|name:memberId> <x> <Q> [<flav>]")
        print("Calc MRST nucleon QED xf(x,Q2) for set <name> with shifted flavour index <flav>")



if __name__ == '__main__':
    ## Parse command line options
    # -q and -v share one destination, so the last one given wins; with
    # neither, the level stays at INFO.
    from optparse import OptionParser
    parser = OptionParser(usage=usage)
    parser.add_option("-q", "--quiet", help="Suppress normal messages", dest="LOGLEVEL",
                      action="store_const", default=logging.INFO, const=logging.WARNING)
    parser.add_option("-v", "--verbose", help="Add extra debug messages", dest="LOGLEVEL",
                      action="store_const", default=logging.INFO, const=logging.DEBUG)
    opts, args = parser.parse_args()

    ## Configure logging
    try:
        logging.basicConfig(level=opts.LOGLEVEL, format="%(message)s")
    except Exception:
        # Manual equivalent of the call above. Presumably a fallback for
        # logging versions whose basicConfig() rejects keyword arguments
        # -- TODO confirm. Exception (not a bare except) keeps Ctrl-C and
        # explicit exits from being swallowed here.
        logging.getLogger().setLevel(opts.LOGLEVEL)
        h = logging.StreamHandler()
        h.setFormatter(logging.Formatter("%(message)s"))
        logging.getLogger().addHandler(h)

    ## Run command parser
    # Positional arguments form a single command to run once; without any,
    # start the interactive shell.
    shell = LHAPDFCmd()
    if args:
        shell.onecmd(" ".join(args))
    else:
        shell.cmdloop()
