"""
 Tools for nonbonded interaction analysis.
"""

from cdsMatrix import CDSMatrix_double
from cdsVector import CDSVector_int

# Symmetric matrix of pairwise van der Waals distances, indexed by
# chemical-type index; filled in by readXplorRadii.
chemTypeDist = CDSMatrix_double(0, 0, 0.)
# Chemical-type index of each atom (indexed by atom ID); filled in by
# readXplorRadii.
lookupVector = CDSVector_int(0)
# For each atom ID, the list of atom IDs whose nonbonded interaction with it
# is excluded; rebuilt by readXplorRadii.
exclList = []

def initializeRadii():
    """Populate the module-level nonbonded tables from the current XPLOR state.

    The current XPLOR parameters are written (with XPLOR output suppressed)
    to a temporary file, which is parsed by readXplorRadii and then removed.
    XPLOR output is re-enabled and the temporary file is deleted even if the
    write or the parse fails.
    """
    import xplor, os
    from simulationTools import mktemp
    tmpFilename = mktemp(suffix=".xplor")
    try:
        outputState = xplor.disableOutput()
        try:
            xplor.fastCommand("write params output=%s end" % tmpFilename)
        finally:
            # never leave XPLOR output disabled after a failed write
            xplor.enableOutput(outputState)
        readXplorRadii(tmpFilename)
    finally:
        # don't leave the temporary parameter file behind on failure
        if os.path.exists(tmpFilename):
            os.unlink(tmpFilename)
def _readChemTypeDistances(filename):
    """Fill the module-level chemTypeDist matrix from an XPLOR parameter file.

    NONBonded lines define a chemical type and its sigma; the distance for a
    pair of types is 2**(1/6) times the mean of their sigmas.  NBFIx lines
    override the distance for one pair using that pair's own A, B values.

    Returns a dict mapping chemical-type name -> index into chemTypeDist.
    """
    import re
    nonbPattern = re.compile(
        r"^\s*nonb.*\s+([a-z0-9_]+)\s+([0-9.]+)\s+([0-9.]+)", re.IGNORECASE)
    nbfixPattern = re.compile(
        r"^\s*nbfi.*\s+([a-z0-9_]+)\s+([a-z0-9_]+)\s+([0-9.]+)\s+([0-9.]+)",
        re.IGNORECASE)
    chemTypeSigma = []
    chemTypeLookup = {}
    paramFile = open(filename)
    try:
        for line in paramFile:
            match = nonbPattern.search(line)
            if match:
                sigma = float(match.group(3))
                i = len(chemTypeSigma)
                chemTypeSigma.append(sigma)
                chemTypeLookup[match.group(1)] = i
                chemTypeDist.resize(i + 1, i + 1)
                for j in range(i + 1):
                    # 0.5**(5/6)*(si+sj) == 2**(1/6) * (si+sj)/2
                    chemTypeDist[i, j] = 0.5**(5./6) * (sigma +
                                                        chemTypeSigma[j])
                    chemTypeDist[j, i] = chemTypeDist[i, j]

            match = nbfixPattern.search(line)
            if match:
                i = chemTypeLookup[match.group(1)]
                j = chemTypeLookup[match.group(2)]
                A = float(match.group(3))
                B = float(match.group(4))
                sigma = 0
                if B > 0:
                    sigma = (A / B)**(1./6)
                # sigma here is already the pair value, so it is scaled by
                # 2**(1/6) directly (no averaging factor of 1/2 as above).
                chemTypeDist[i, j] = 2.**(1./6) * sigma
                chemTypeDist[j, i] = chemTypeDist[i, j]
    finally:
        paramFile.close()
    return chemTypeLookup


def _queryNBXMod():
    """Return XPLOR's current NBXMOD value, queried with output suppressed.

    XPLOR output and print state are restored even if the query fails.
    """
    import xplor
    outputState = xplor.disableOutput()
    xplor.fastCommand("set print none end")
    try:
        return int(xplor.fastCommand("param nbond ? end end", "NBXMOD")[0])
    finally:
        xplor.fastCommand("set print $prev_print_file end")
        xplor.enableOutput(outputState)


def _addBondShell(exclList, bondList):
    """Return a copy of exclList extended by one bond.

    exclList - per-atom lists of atom IDs already excluded
    bondList - per-atom lists of directly bonded atom IDs

    For each atom i, every atom bonded to a member of exclList[i] is appended
    to the copy of exclList[i], unless it is i itself or already present.
    exclList itself is not modified.
    """
    newList = [list(excluded) for excluded in exclList]
    for i in range(len(exclList)):
        extended = newList[i]
        for j in exclList[i]:
            for k in bondList[j]:
                # check this atom's own list so each partner appears once
                if k != i and k not in extended:
                    extended.append(k)
    return newList


def _parsePsfExclusions(psfBody, numAtoms):
    """Parse the explicit nonbonded exclusions from PSF text.

    psfBody  - PSF contents following the title lines
    numAtoms - number of atoms (length of the per-atom index table)

    Returns a list, indexed by atom ID, of lists of 0-based partner atom IDs.
    """
    (before, after) = psfBody.split('!NNB')
    numNNB = int(before.split()[-1])
    # The section runs up to the '!NGRP' tag; the last two integers before
    # that tag are not part of the table (presumably the NGRP header counts).
    combTable = [int(e) for e in after.split('!NGRP')[0].split()][:-2]
    (nnbTable, idxTable) = (combTable[:numNNB], combTable[numNNB:])

    # idxTable[i] is the (cumulative) end of atom i's run in nnbTable;
    # nnbTable holds 1-based atom numbers.
    exclusions = [[] for _ in range(numAtoms)]
    start = 0
    for i in range(numAtoms):
        end = idxTable[i]
        for j in range(start, end):
            exclusions[i].append(nnbTable[j] - 1)
        start = end
    return exclusions


def _readExplicitExclusions(numAtoms):
    """Write the current PSF to a temporary file and return its explicit
    nonbonded exclusions (see _parsePsfExclusions).  The temporary file is
    always removed."""
    import xplor, os
    from simulationTools import mktemp
    tmpFilename = mktemp(suffix=".xplor")
    try:
        outputState = xplor.disableOutput()
        try:
            xplor.fastCommand("write psf output=%s end" % tmpFilename)
        finally:
            xplor.enableOutput(outputState)
        psfFile = open(tmpFilename)
        try:
            # skip the first two lines; the third starts with the number of
            # title lines that follow
            psfFile.readline()
            psfFile.readline()
            numTitle = int(psfFile.readline().split()[0])
            for _ in range(numTitle):
                psfFile.readline()
            psfBody = psfFile.read()
        finally:
            psfFile.close()
    finally:
        if os.path.exists(tmpFilename):
            os.unlink(tmpFilename)
    return _parsePsfExclusions(psfBody, numAtoms)


def readXplorRadii(filename):
    """Read van der Waals distances and nonbonded exclusions into module
    globals.

    filename - XPLOR parameter file, in the format written by XPLOR's
               ``write params`` command (see initializeRadii).

    Sets these module-level tables, which vdwViolations consults:
      chemTypeDist - symmetric matrix of pairwise van der Waals distances,
                     indexed by chemical-type index.
      lookupVector - chemical-type index of each atom, indexed by atom ID.
      exclList     - for each atom ID, the atom IDs whose nonbonded
                     interaction with it is excluded: 1-2, 1-3 and 1-4
                     partners when |NBXMOD| > 1, > 2 and > 3 respectively,
                     plus (if NBXMOD > 0) the explicit PSF exclusions.

    Raises KeyError if an atom's chemical type (or an NBFIx type) has no
    NONBonded entry in the file.
    """
    global exclList
    chemTypeLookup = _readChemTypeDistances(filename)

    from xplor import simulation as sim
    numAtoms = sim.numAtoms()
    lookupVector.resize(numAtoms)
    for i in range(numAtoms):
        lookupVector[i] = chemTypeLookup[sim.atomByID(i).chemType()]

    nbxmod = _queryNBXMod()

    # bondList[i]: atoms directly bonded to atom i (only needed if 1-2
    # exclusions are in effect)
    bondList = [[] for _ in range(numAtoms)]
    if abs(nbxmod) > 1:
        for b in range(sim.numBonds()):
            pair = sim.bondPairByID(b)
            bondList[pair[0]].append(pair[1])
            bondList[pair[1]].append(pair[0])

    excl = [list(bonded) for bonded in bondList]
    # add in atoms with 1-3 relationship
    if abs(nbxmod) > 2:
        excl = _addBondShell(excl, bondList)
    # add in atoms with 1-4 relationship
    if abs(nbxmod) > 3:
        excl = _addBondShell(excl, bondList)
    # add in explicitly excluded interactions; an exclusion applies to the
    # pair, so record it for both atoms
    if nbxmod > 0:
        explicit = _readExplicitExclusions(numAtoms)
        for i in range(numAtoms):
            for j in explicit[i]:
                if j not in excl[i]:
                    excl[i].append(j)
                if i not in excl[j]:
                    excl[j].append(i)

    exclList = excl

# Scale factor applied to the van der Waals distances by vdwViolations when
# the XPLOR REPEl value is unavailable.
defaultRadiusScale = 0.9

def vdwViolations(threshold):
    """Determine nonbonded violations:
      atoms closer than the sums of the appropriate scaled vdw radii.

      threshold - passed unchanged to nbond.findCloseAtoms.

      The scale factor is taken from the XPLOR REPEl parameter, if it is used
      (i.e. it can be read and is positive).  Otherwise the module variable
      defaultRadiusScale is used (default value is 0.9).

      The XPLOR NBXMod parameter is consulted and the appropriate 1-2,3,4
      interactions are excluded. Also, explicitly excluded interactions are
      excluded, if NBXMod>0.

      Atoms in residues named ANI are not considered.

      Returns a list of Violation objects.
    """
    from xplor import simulation as sim
    if len(lookupVector) != sim.numAtoms():
        initializeRadii()

    from atomSel import AtomSel
    atoms = AtomSel("not resname ANI")

    import xplor
    outputState = xplor.disableOutput()
    xplor.fastCommand("set print none end")
    try:
        try:
            radiusScale = float(xplor.fastCommand("param nbond ? end end",
                                                  "repel")[0])
        except Exception:
            radiusScale = defaultRadiusScale
    finally:
        # restore XPLOR print/output state even if the query misbehaves
        xplor.fastCommand("set print $prev_print_file end")
        xplor.enableOutput(outputState)
    if radiusScale <= 0:
        # a non-positive scale shrinks every distance to <= 0, so no
        # violation could ever be reported: treat it as "REPEl not used"
        radiusScale = defaultRadiusScale

    distanceMatrix = CDSMatrix_double(chemTypeDist)
    distanceMatrix.scale(radiusScale)

    violations = []
    from vec3 import norm
    from nbond import findCloseAtoms
    for (i, j) in findCloseAtoms(atoms, lookupVector,
                                 distanceMatrix, threshold):
        # NOTE(review): i, j are used both as selection indices (atoms[i])
        # and as atom IDs (exclList, lookupVector).  These agree only if no
        # ANI atom precedes a selected atom -- confirm findCloseAtoms'
        # index convention.
        if j in exclList[i]:
            continue
        atomi = atoms[i]
        atomj = atoms[j]
        violations.append(Violation(atomi, atomj,
                                    norm(atomi.pos() - atomj.pos()),
                                    distanceMatrix[lookupVector[i],
                                                   lookupVector[j]]))
    return violations

#def vdwViolationsPrint(threshold):
#    """print out vdw violation info, and return the number of violations.
#    """
#    violations = vdwViolations(threshold)
#
#    print
#    print "  Nonbonded Violations"
#    print "        threshold: %f" % threshold
#    print
#    print "%26s %26s %6s %6s" % ("atom1     ", "atom2     ",
#                                 "dist", "vdwDist")
#    print "_"*70
#
#    for v in violations:
#        print "%26s %26s %6.2f %6.2f" % (v.atomi.string(), v.atomj.string(),
#                                         v.dist, v.dist0)
#        pass
#        
#    return len(violations)

class Violation:
    """One nonbonded violation between a pair of atoms.

    atomi, atomj - the two atoms involved
    dist         - their actual separation
    dist0        - the reference (scaled vdw) distance they are compared to
    """
    def __init__(self, atomi, atomj, dist, dist0):
        (self.atomi, self.atomj) = (atomi, atomj)
        (self.dist, self.dist0) = (dist, dist0)

    def name(self):
        """Return a label of the form '( <atomi> ) ( <atomj> )'."""
        labels = (self.atomi.string(), self.atomj.string())
        return "( %s ) ( %s )" % labels

    def diff(self):
        """Return dist - dist0 (negative when dist < dist0)."""
        return self.dist - self.dist0
    
        
