"""high-level refinement tools
"""

from sys import modules
from potList import PotList
from pdbTool import PDBTool


class StructureLoop:
    """ class which performs loop over structure calculations.
    Constructor: StructureLoop()
      arguments:
        numStructures       -    number of structures to calculate
        startStructure      -    id of the first structure to calculate
        structureNums       -    sequence of explicit structure numbers to
                                 calculate. 
        structLoopAction    -    a user-defined function which takes
                                 one argument: an instance of this class.
                                 If this argument is omitted, new structures
                                 are not calculated. Rather, existing structures
                                 in files specified by pdbTemplate are read-in,
                                 and analyzed.
        pdbTemplate         -    template string used to create a filename
                                 for the pdbFile method. The filename is
                                 generated using the makeFilename method.
                                 Note that the template string should always
                                 include the STRUCTURE literal so that distinct
                                 structure files are generated.
                                 
      if numStructures<0,   existing files matching pdbTemplate are processed,
      starting at startStructure, and stopping when a file does not exist.
      This mode of operation does not work with structure parallelism, but
      it does work with ensemble parallelism.

     There are additional arguments if you would like an average structure
     to be calculated. If averaging is enabled, the output structure files
     will be fit to each other.
     
       averageFilename      - name for output structure file. If not specified,
                              the average structure will not be calculated.
       averagePotList       - potential terms to use for average structure
                              calculation. These terms are reported on in the
                              .stats file.
       averageRegularize    - flag determining whether or not structure
                              regularization (gradient minimization) is
                              carried out on the average structure, by
                              minimizing against averagePotList.
                              [default: True]
       averageFixedRegions  - sequence of regions held fixed in space during
                              structure regularization.
       averageRigidRegions  - sequence of regions held rigid during structure
                              regularization.
       averageSortPots      - potential terms used for sorting structures. The
                              top fraction or number of structures is reported
                              on in the .stats file.
                              [defaults to averagePotList]
       averageCrossTerms    - potential terms to report on, but not to use
                              in refinement of average structure.
       averageContext       - function to call to setup average minimization
                              run.  
       averageFitSel        - atom selection used to fit structures in
                              calculation of average structure [name CA].
       averageCompSel       - atom selection used for heavy atom rmsd
                              comparison metric [not name H* and not PSEUDO].
       averageTopFraction   - fraction of structures to use in calculation of
                              average structure. The structures are sorted
                              by energy, and the top fraction is retained.
                              [1 - all structures].
       averageTopNum        - number of structures to use in calculation of
                              average structure. Specify only one of 
                              averageTopFraction or averageTopNum.

       averageAccept        - function to call to assess whether a structure
                              is acceptable, i.e. meets violation, rmsd
                              requirements. The function has a single argument:
                              averagePotList [defaults to accepting all
                              structures.]

       averageRefineSteps   - number of minimization steps to take during
                              regularization refinement of the average
                              structure [50]

       genViolationStats    - flag controlling whether statistics are gathered
                              (over all structures) on which restraints are
                              violated most often. The results are collected
                              in a file named pdbTemplate.stats. Statistics
                              will be gathered for all terms in
                              averagePotList, so this attribute must be
                              specified for this facility to work.
       averageRestrain      - flag to control the inclusion of a probDist
                              energy potential in calculating the average
                              structure from the ensemble. averageRestrainSel
                              selects the atoms from which the density map is
                              created. inconsistentAveStruct is a flag that
                              gets set to 1 if the energy or violations of the
                              calculated average structure are greater than
                              those of the ensemble.
                              

    method: run():
       performs numStructures loop of action. In each pass, the coordinates
       of the current Simulation are reset to their initial values and the
       instance variable count is incremented. Also, the global random seed
       is incremented. If the current simulation is an EnsembleSimulation,
       the seed-setting takes this into account.

       After run() has completed, average structure coordinates are left in
       the current Simulation, if they have been calculated. If restraint
       statistics are generated, the StructureLoop instance will have the
       following members when run() returns:
         restraintStats 
         restraintStatsCross
       These are <m restraintStats>.RestraintStats objects corresponding to
       potential terms in averagePotList and averageCrossTerms, respectively.
       The precision of the calculated structures is stored in the members
         fitRMSD
         compRMSD
       corresponding to averageFitSel and averageCompSel, respectively.
       Also, the cpu time spent within the run() method will be contained
       in the members
         cpuTime
         cpuTimes
         cpuTimeTot
       The first contains the local process's cpu time, the second is an array
       of times from each process, and the third is the sum.
    """
    def __init__(s,numStructures=-1,startStructure=0,
                 structureNums=None,
                 structLoopAction=0,pdbTemplate="",
                 genViolationStats=0,
                 averageFilename="",
                 averagePotList=None,
                 averageRegularize=True,
                 averageFixedRegions=None,
                 averageRigidRegions=None,
                 averageSortPots=None,
                 averageCrossTerms=None,
                 averageContext=lambda : 1,
                 averageFitSel="name CA",
                 averageCompSel="not name H* and not PSEUDO",
                 averageTopFraction=1,
                 averageTopNum=    -1,
                 averageAccept=lambda potList: 1,
                 averageRefineSteps=50,
                 averageRestrain=False,
                 averageRestrainSel="name CA or name C or name N or name O"):
        """ store the loop configuration and assign structures to processes.

        The arguments are described in the class docstring. Container
        arguments left as None are replaced by fresh empty containers (an
        empty PotList for averagePotList), so that separate instances never
        share a default object.

        If numStructures>=0 (or structureNums is given), the structure
        numbers are divided into contiguous slices, one per process, which
        are stored in procStructMap; the slice belonging to this process is
        stored in structNums. A process which is left without work calls
        sys.exit(). If numStructures<0, only the members start and stop are
        set up, for the processing of existing files.

        Raises Exception if both numStructures>0 and structureNums are
        specified.
        """
        import sys
        from simulation import Simulation_currentSimulation
        import xplor

        # None stands for "empty": a literal [] or PotList() default would be
        # created once and then shared by every instance storing it.
        if structureNums is None:       structureNums = []
        if averagePotList is None:      averagePotList = PotList()
        if averageFixedRegions is None: averageFixedRegions = []
        if averageRigidRegions is None: averageRigidRegions = []
        if averageCrossTerms is None:   averageCrossTerms = []

        if numStructures>0 and len(structureNums):
            raise Exception("specify only one of numStructures or structureNums")
        if len(structureNums):
            numStructures = len(structureNums)
        s.numStructures=numStructures
        s.structLoopAction = structLoopAction
        s.pdbTemplate = pdbTemplate
        s.processID = xplor.p_processID
        proc = s.processID
        s.numProcs = xplor.p_numProcs

        s.genViolationStats=genViolationStats
        s.averageRestrain=averageRestrain
        s.averageRestrainSel=averageRestrainSel
        s.averageFilename =averageFilename
        s.averagePotList = convertToPotList( averagePotList )
        s.averageRegularize = averageRegularize
        s.averageFixedRegions = averageFixedRegions
        s.averageRigidRegions = averageRigidRegions
        # sorting falls back on the terms used for the average structure
        if not averageSortPots:
            averageSortPots = averagePotList
        s.averageSortPots = convertToPotList( averageSortPots )
        s.averageCrossTerms=convertToPotList( averageCrossTerms )
        s.averageContext   =averageContext
        s.averageFitSel    =averageFitSel
        s.averageCompSel   =averageCompSel
        s.averageTopFraction=averageTopFraction
        s.averageTopNum     =averageTopNum
        s.averageAccept     =averageAccept
        s.averageRefineSteps=averageRefineSteps
        s.inconsistentAveStruct=0

        sim = Simulation_currentSimulation()
        s.esim = 0
        s.initCoords = sim.atomPosArr()
        s.skip=0
        s.structNum=-1
        if sim.type() == "EnsembleSimulation":
            from ensembleSimulation import EnsembleSimulation_currentSimulation
            s.esim = EnsembleSimulation_currentSimulation()
            s.skip = s.esim.member().memberIndex()

        s.procStructMap=[]
        if numStructures>=0:
            # this logic used if number of jobs is greater than the desired
            # number of structures
            s.numProcs = min(s.numProcs,numStructures)
            if proc >= s.numProcs:
                print('StructureLoop: this process has no work. Exiting...')
                xplor.p_numProcs=-1 # to avoid barrier error in xplorFinal:34
                sys.exit()

            if not structureNums:
                structureNums = range(startStructure,
                                      startStructure+numStructures)

            for i in range(s.numProcs):
                # floor division: the results are used as slice indices
                i_start = (i     * numStructures) // s.numProcs
                i_stop  = ((i+1) * numStructures) // s.numProcs
                s.procStructMap.append( structureNums[i_start:i_stop] )
            s.structNums = s.procStructMap[proc]
        else:
            # setup for processing existing files- note that neither
            # structNums nor entries of procStructMap are set in this mode.
            if s.esim: s.skip=-2000
            #FIX: check this
            s.start=startStructure
            s.stop=-1

        # barrier across all parallel processes- this is required after
        # extra processes have shutdown so that further communication is
        # successful.
        from ensembleSimulation import commBarrier
        commBarrier(xplor.p_comm)
        return

    def run(s):
        """ perform the loop over structures, then the optional analysis.

        Nothing is calculated if structLoopAction is not set. Otherwise, for
        each structure number assigned to this process (or, if
        numStructures<0, for consecutive numbers beginning at start for as
        long as the file generated from pdbTemplate exists):
          - structNum and count are set to the structure number,
          - the random seed is set to
              initSeed + structNum + skip*numStructures,
          - the coordinates are reset to their values at entry to run(),
          - structLoopAction is called with this object as its argument. If
            structLoopAction is a string it is instead executed in the
            namespace of the caller of run(), where this object is
            temporarily available under the name structLoopInfo.

        Afterwards genAveStats() is called if averageFilename or
        genViolationStats is set, and the members cpuTime, cpuTimes and
        cpuTimeTot are filled in.

        Returns this StructureLoop object.
        """
        import os
        import simulationWorld
        from simulation import Simulation_currentSimulation
        from inspect import currentframe, getouterframes
        simWorld = simulationWorld.SimulationWorld_world()
        sim = Simulation_currentSimulation()
        initCoords = sim.atomPosArr()
        s.initSeed = simWorld.random.seed()
        s.cpuTime = simWorld.cpuTime()

        # the constructor assigns structNums only if numStructures>=0; an
        # empty list selects the existing-files logic below.
        structNums = getattr(s,'structNums',[])

        cnt = 0
        while s.structLoopAction:
            if structNums:
                if cnt==len(structNums): break
                s.structNum = structNums[cnt]
            elif s.numStructures<0:
                s.structNum = s.start + cnt
                #logic for processing existing files- and the number
                # of structures is not specified.
                #stop if the file doesn't exist.
                try:
                    os.stat( s.makeFilename(s.pdbTemplate) )
                except OSError:
                    break
            else:
                # nothing has been assigned to this process
                break
            # count is for backward compatibility
            s.count = s.structNum
            if simWorld.logLevel()!='none':
                print("StructureLoop: calculating structure %d" % s.structNum)
            s.randomSeed = s.initSeed + s.structNum + s.skip*s.numStructures
            simWorld.setRandomSeed( s.randomSeed )
            sim.setAtomPosArr( initCoords )

            if isinstance(s.structLoopAction, str):
                # run the code string in the namespace of run()'s caller
                callerFrame = getouterframes( currentframe() )[1][0]
                global_dict = callerFrame.f_globals
                local_dict = callerFrame.f_locals
                local_dict["structLoopInfo"] = s
                exec( s.structLoopAction, global_dict, local_dict )
                del local_dict["structLoopInfo"]
            else:
                s.structLoopAction(s)
            if s.esim: s.esim.barrier()
            cnt += 1

        import xplor
        comm = xplor.p_comm
        if s.averageFilename or s.genViolationStats:
            s.genAveStats(comm,sim)

        from ensembleSimulation import singleThread, multiThread
        s.cpuTime = simWorld.cpuTime() - s.cpuTime
        s.cpuTimeTot=0
        s.cpuTimes=[]
        if singleThread():
            s.cpuTimes=comm.collect( s.cpuTime )
            s.cpuTimeTot = sum(s.cpuTimes)
        multiThread()
        return s

    def genAveStats(s,comm,sim):
        """ generate averages, statistics for calculated structures

        arguments:
          comm - communicator passed to commBarrier(); it is also queried
                 for the host of any process which did not report.
          sim  - the Simulation whose coordinates are read and overwritten.

        All processes take part in the initial barrier; the remaining work
        is done only by the process with processID 0:
          - each structure file of the reporting processes is read, its
            sort energy is evaluated using averageSortPots, it is fit to the
            first structure read using averageFitSel and written back.
            Files which raise IOError are skipped.
          - the structures are sorted by energy, filtered using
            averageAccept and truncated to averageTopNum (or
            averageTopFraction) entries.
          - the average structure is computed by calcAverageStruct() and
            the members fitRMSD and compRMSD are set.
          - if genViolationStats is set, a .stats file is written and the
            members restraintStats and restraintStatsCross are set.
          - if averageFilename is set, the average structure is written.

        If averageRestrain is set, inconsistentAveStruct is set to 1 when
        the average structure has a larger sort energy or more violations
        than any of the retained structures.
        """
        #FIX: this method is a mess. It should be split into a separate
        # function which performs the analysis - it could then be used
        # in standalone mode without StructureLoop.
        #
        from ensembleSimulation import singleThread, multiThread
        from ensembleSimulation import commBarrier
        procs = commBarrier(comm)

        from atomSelAction import Fit
        from atomSel import AtomSel
        if s.processID!=0:
            return

        # NOTE(review): the constructor fills procStructMap only if
        # numStructures>=0 - confirm the intended behavior of this loop for
        # the existing-files mode.
        structIDs=[]
        for proc in range(s.numProcs):
            pStructIDs=s.procStructMap[proc]
            if proc in procs:
                structIDs += pStructIDs
            else:
                host = comm.info(proc).remoteHost
                print("StructureLoop: no results obtained from "
                      " process %d running on %s" % (proc,host))
                print("\tskipping structures %s" % str(pStructIDs))

        structs = []     # (sort energy, filename, coordinates)
        sortViols = {}   # filename -> violations of averageSortPots
        from restraintStats import RestraintStats
        if s.genViolationStats:
            rStats = RestraintStats()
            rStatsCross = RestraintStats()
        fitStruct=None
        for s.structNum in structIDs:
            try:
                #read files
                s.pdbFile().read()
                sim.sync()
                s.averageContext()
                #calc energy
                sortViol=0
                try:
                    sortEnergy   = s.averageSortPots.calcEnergy()
                    if s.averageRestrain:
                        sortViol = s.averageSortPots.violations()
                except AttributeError:
                    sortEnergy = 0
                    sortViol = 0
                # all structures are fit to the first one read
                if fitStruct is None: fitStruct = sim.atomPosArr()
                AtomSel("known").apply( Fit(fitStruct,s.averageFitSel) )

                s.pdbFile().setMakeBackup(0) # don't backup unfit structures
                s.pdbFile().write()
                structFilename = s.pdbFile().filename()
                sortViols[structFilename] = sortViol
                structs.append( (sortEnergy, structFilename,
                                 sim.atomPosArr()) )
            except IOError:
                #file is missing, and can't be read
                pass

        if not (s.averageFilename or s.genViolationStats):
            return

        #sort structures by energy (ties by filename); the coordinate
        # arrays never take part in the comparison
        structs.sort( key=lambda struct: struct[:2] )

        numCalcdStructs=len(structs)

        #take top goodFraction structures
        from math import ceil
        if s.averageTopNum>=0:
            goodNum = s.averageTopNum
        else:
            goodNum = int(ceil(len(structs) * s.averageTopFraction))

        # filter by the accept function
        def applyAccept(struct):
            sim.setAtomPosArr( struct[2] )
            return s.averageAccept(s.averagePotList)

        structs = [struct for struct in structs if applyAccept(struct)]
        structs = structs[:goodNum]

        if not structs:
            print("no acceptable structures:")
            print("   averages and statistics not collected!")
            return

        #  - perform average
        ( aveCoords,fRMSD_array,cRMSD_array,bFactors, remarks ) = \
          calcAverageStruct( [struct[2] for struct in structs],
                             fitSel=s.averageFitSel,
                             compSel=s.averageCompSel,
                             potList=(s.averagePotList
                                      if s.averageRegularize else []),
                             regularizeSteps=s.averageRefineSteps,
                             averageRestrainSel=s.averageRestrainSel,
                             averageRestrain=s.averageRestrain,
                             fixedRegions=s.averageFixedRegions,
                             rigidRegions=s.averageRigidRegions,
                             )

        structureDetails = "%-30s  %10s  %9s  %9s\n" % ("","sort ",
                                                        "fit ",
                                                        "   comparison")
        structureDetails += "%-30s  %10s  %9s  %9s\n" % ("Filename:",
                                                         "energy",
                                                         "RMSD","RMSD")

        fRMSD_ave=0.
        cRMSD_ave=0.
        for i,(e,structFilename,coords) in enumerate(structs):
            fRMSD = fRMSD_array[i]
            cRMSD = cRMSD_array[i]
            fRMSD_ave += fRMSD
            cRMSD_ave += cRMSD
            structureDetails += "%-30s  %10.2f  %9.3f  %9.3f\n" % \
                                (structFilename,e,fRMSD,cRMSD)

        if structs:
            fRMSD_ave /= len(structs)
            cRMSD_ave /= len(structs)

        structureDetails += "\n%-30s  %10s  %9.3f  %9.3f\n" % \
                            ("average:","",fRMSD_ave,cRMSD_ave)

        s.fitRMSD  = fRMSD_ave
        s.compRMSD = cRMSD_ave

        if s.averageRestrain:
            sim.setAtomPosArr(aveCoords)
            aveEnergy= s.averageSortPots.calcEnergy()
            aveViol  = s.averageSortPots.violations()
            # violations are looked up by filename: structs has been sorted
            # and filtered, so positions no longer match the read order
            for (e,structFilename,coords) in structs:
                if aveEnergy>e or aveViol>sortViols[structFilename]:
                    s.inconsistentAveStruct=1
                    break
            if s.inconsistentAveStruct:
                structureDetails+= "\n"
                structureDetails += "Warning: Calculated Average have energy or violation \n"
                structureDetails += "greater than the calculated structures.\n"
                # plain widths: a precision on %s would truncate the titles
                structureDetails += "%-30s  %10s  %10s  \n" % ("Filename:",
                                                         "Energy","Violations")
                structureDetails += "%-30s  %10.2f  %10.3f  \n" % \
                                 ("average:",aveEnergy,aveViol)
                structureDetails+="\n"

        structureDetails += "\n  fit selection: "
        structureDetails += s.averageFitSel + '\n'
        structureDetails += "\n  comparison selection: "
        structureDetails += s.averageCompSel + '\n'

        if s.genViolationStats:
            for struct in structs:
                sim.setAtomPosArr( struct[2] )
                s.averageContext()
                rStats.accumulate(s.averagePotList)
                rStatsCross.accumulate(s.averageCrossTerms)

            out = "\n  Results for the top %d (of %d) structures\n\n" % \
                   (len(structs),numCalcdStructs)
            out+=rStats.summarizeTerms()
            if s.averageCrossTerms:
                out += "\n Cross Validated Terms\n\n"
                out+=rStatsCross.summarizeTerms()
            extraQuantityStats = rStats.summarizeExtraQuantities()
            if extraQuantityStats!='':
                out += "\n Statistics for Additional Quantities\n\n"
                out+=extraQuantityStats
            extraQuantityStats = rStatsCross.summarizeExtraQuantities()
            if extraQuantityStats!='':
                out += "\n Statistics for Additional Cross-Validated Quantities\n\n"
                out+=extraQuantityStats
            out += rStats.summarizeViolations()
            if s.averageCrossTerms:
                out += "\n Cross Validated Terms\n\n"
                out+=rStatsCross.summarizeViolations()

            out+="\n Energy terms used for sorting:\n "
            for pot in s.averageSortPots:
                out += " " + pot.instanceName()
            out += "\n"

            out += "\n" + structureDetails
            if singleThread():
                statsFile = open(genFilename(s.pdbTemplate,'##')+".stats",
                                 'w')
                try:
                    statsFile.write(out)
                finally:
                    # close explicitly so that the contents are flushed
                    statsFile.close()
            s.restraintStats      = rStats
            s.restraintStatsCross = rStatsCross
            multiThread()

        if s.averageFilename:
            sim.setAtomPosArr( aveCoords )
            #FIX: should fixupCovalentGeom be called?
            #  --it must be proven to be quite robust

            remarks += structureDetails

            avePDB = PDBTool(s.makeFilename(s.averageFilename))
            for atom in AtomSel('known'):
                avePDB.setAux2(atom,bFactors[atom.index()])
            avePDB.addRemarks(remarks)
            pl = s.averagePotList
            avePDB.addRemarks(analyze(pl,s.averageCrossTerms,
                 outFilename=s.makeFilename(s.averageFilename+'.viols')))
            if singleThread():
                avePDB.write()
            multiThread()

        return
    
    def pdbFile(s):
        """ return the PDBTool object for the current structure file.

        The filename is that returned by the filename() method. The
        PDBTool is cached in the member pdbFileObj, and is replaced
        whenever its filename no longer matches the current one.
        """
        # the cached reference keeps the PDBTool object alive, so that
        # access like info.pdbFile().filename() is valid: without it the
        # object is reaped before the filename() string is returned.
        wanted = s.filename()
        if hasattr(s,'pdbFileObj') and s.pdbFileObj.filename() == wanted:
            return s.pdbFileObj
        s.pdbFileObj = PDBTool( wanted )
        return s.pdbFileObj
    def filename(s):
        """ return filename generated by makeFilename() from the pdbTemplate
        argument of the class constructor.
        """
        if not s.pdbTemplate:
            raise Exception(
                "pdbFile: pdbTemplate must be specified in the\n\t" +
                   "StructureLoop constructor")
        return s.makeFilename(s.pdbTemplate)
    def makeFilename(s,template):
        """ generate a filename from template, by passing it to genFilename
        together with the current structure number and the ensemble member
        index (0 if the current simulation is not an EnsembleSimulation).

        In the template, the following substitutions are made:
             SCRIPT    -> name of input script (without .py suffix)
             STRUCTURE -> the structure number
             MEMBER    -> the ensemble member index
        """
        from simulation import Simulation_currentSimulation
        if Simulation_currentSimulation().type() != "EnsembleSimulation":
            return genFilename(template,s.structNum,0)
        from ensembleSimulation import EnsembleSimulation_currentSimulation
        member = EnsembleSimulation_currentSimulation().member()
        return genFilename(template,s.structNum,member.memberIndex())
    def analyze(s,potList,altPotList=PotList()):
        """ call the module-level analyze function on potList and
        altPotList, directing violation output to the file named
        filename() + '.viols', and return its result (the summary
        information, as a string).
        """
        violsFilename = s.filename() + ".viols"
        return analyze(potList,altPotList,violsFilename)
    def writeStructure(s,potList=None, 
                       altPotList=None,
                       extraRemarks=""):
        """analyze the current structure using the analyze() method, then
        write it out with the analysis summary included as remarks.

        The output filename is generated from the pdbTemplate argument of
        the StructureLoop constructor.

        The summary covers each term in potList (and in altPotList, if
        specified); more detailed violation information goes to a file
        named filename + '.viols' .

        potList defaults to s.averagePotList, and altPotList defaults to
        s.averageCrossTerms.

        extraRemarks, if specified, is extra (string) information placed in
        the REMARKS section of the PDB file, after the usual summary.

        The remarks end with the random seed information, the user name
        and the current date.
        """
        if not potList:    potList = s.averagePotList
        if not altPotList: altPotList = s.averageCrossTerms
        import protocol
        # list elements are evaluated in order: analysis first, then seed
        header = [ s.analyze(potList,altPotList),
                   "  generated by",
                   " simulationTools.StructureLoop.writeStructure\n",
                   "  seed info: initial: %d structure id: %d\n\n" %
                   (protocol.initialRandomSeed(), s.structNum) ]
        if extraRemarks:
            header.append( "\n" + extraRemarks )
        remarks = "".join(header)

        import os, pwd, time
        # prefer the password database entry; fall back on the environment
        try:
            user = pwd.getpwuid(os.geteuid())[0]
        except (IndexError,KeyError):
            user = os.environ.get("LOGNAME","unknown")
        rule = "-"*65 + '\n'
        remarks += rule
        remarks += "user: %-15s            date:" % user
        remarks += " " + time.asctime() + '\n'
        remarks += rule

        outFile=s.pdbFile()
        outFile.addRemarks( remarks )
        outFile.write()
        return
    def structureNum(s):
        " return the current structure number, or -1."
        return s.structNum
    pass

def convertToPotList(obj):
    """ return obj as a <m potList>.PotList.

    An object without a length is treated as a single potential term and
    is first wrapped in a list. An object whose class name contains
    'PotList' is returned unchanged; otherwise the members of the sequence
    are appended, in order, to a new PotList.
    """
    from potList import PotList
    try:
        len(obj)
    except TypeError:
        # a single term
        obj=[obj]
    if "PotList" in obj.__class__.__name__:
        return obj
    converted = PotList()
    for term in obj:
        converted.append(term)
    return converted
        

def calcAverageStruct(structs,fitSel,compSel,
                      potList=[],regularizeSteps=50,
                      averageRestrainSel="",averageRestrain=False,
                      fixedRegions=[],
                      rigidRegions=[]): 
    """compute the average structure of the coordinate sets in structs.
    The structures are first fit using fitSel, and analysis is performed
    using compSel.

    For homogeneous <m ensembleSimulation>.EnsembleSimulations with Ne>1, this
    routine calculates the average of the ensemble averages.

    if potList is set, the sum of the specified terms 
    will be minimized with respect to the average coordinates, after
    the straight average structure has been calculated. The minimization is
    performed by minimizeRefine, with regularizeSteps passed as its
    refineSteps argument.

    if averageRestrain is True, an additional term is created from the
    distribution of the atoms selected by averageRestrainSel over the input
    structures (using atomProb.AtomProb and
    probDistPotTools.create_probDistPot), and added to a local copy of
    potList, so minimization is then always performed. The potList argument
    itself is never modified.

    fixedRegions and rigidRegions arguments are passed to minimizeRefine.

    The return value is the tuple
       (aveCoords, rmsdArray, rmsdCompArray, bFactors, remarks)
    where aveCoords are the (possibly minimized) average coordinates,
    rmsdArray and rmsdCompArray hold, for each input structure, the RMSD
    to the average structure computed over fitSel and compSel,
    respectively, bFactors is derived from the coordinate variance, and
    remarks is a text summary.

    Side effect: the coordinates of the current simulation are overwritten;
    on return they hold the last entry of structs, fit to the average.
    """

    from cdsVector import vec_norm, sqrt
    from atomSelAction import Fit, RMSD
    from atomSel import AtomSel
    #Local copy of potlist
    avepotList=list(potList)
    from simulation import Simulation_currentSimulation
    sim = Simulation_currentSimulation()
                                      
    # the first structure serves as the reference to which all the others
    # are fit
    sim.setAtomPosArr( structs[0] )
    if averageRestrain:
       atmCoordList=[]
       # note: the first structure is recorded as given, before the
       # ensemble averaging below
       atmCoordList.append(sim.atomPosArr())
    fitTo = ensembleAvePos()
    sim.setAtomPosArr( fitTo )
    aveCoords = sim.atomPosArr() # must be separate copy
    # var accumulates the squared norms of the coordinates; it is turned
    # into a variance after the loop
    var = vec_norm(aveCoords)**2
    for struct in structs[1:]:
        sim.setAtomPosArr(struct)
        sim.setAtomPosArr( ensembleAvePos() )
        AtomSel("known").apply( Fit(fitTo,fitSel) )
        if averageRestrain: 
           atmCoordList.append(sim.atomPosArr()) 
        aveCoords += sim.atomPosArr()
        var += vec_norm(sim.atomPosArr())**2
        pass
    if averageRestrain: 
       atmSel=averageRestrainSel
       #Calculating the target Grid
       from atomProb import AtomProb
       from selectTools import convertToAtomSel
       map = AtomProb(convertToAtomSel(atmSel),atmCoordList)      
       map.calc()
       from atomProb import Grid
       targetMap=map.getGrid()
       # Converting Grid to DensityGrid
       from atomDensity import DGridhelp
       Dhelp=DGridhelp()
       DMap=Dhelp.atomProb2atomDensity(targetMap)
       #from probDistPot import ProbDistPot
       #prob=ProbDistPot("prob",DMap,atmSel)
       #prob.setScale(200)
       # NOTE(review): at this point aveCoords is still the sum over
       # structures - the division by len(structs) happens below - confirm
       # that create_probDistPot does not depend on the current coordinates
       sim.setAtomPosArr(aveCoords)
       from probDistPotTools import create_probDistPot
       prob=create_probDistPot("prob",DMap,AtomSel(atmSel),\
                                     potType="cross_correlation",scale=100)
       avepotList.append(prob)
    # convert the sums into means
    aveCoords /= len(structs)
    var /= len(structs)
    # mean of squared norm minus squared norm of the mean: the coordinate
    # variance (presumably one value per atom - depends on vec_norm),
    # then scaled by 8*pi^2
    bFactors = var - vec_norm(aveCoords)**2
    from math import pi
    bFactors.scale( 8*pi**2 )
    sim.setAtomPosArr( aveCoords )

    #get the tensor atoms in reasonable shape-
    #   averaging will have scrambled them
    from varTensorTools import getRegisteredTerms, calcTensor
    for t in getRegisteredTerms(sim):
        calcTensor(t)
#        t.setFreedom("varyDa, varyRh")       #allow all tensor parameters float
        pass
    
    # regularize the average coordinates only if there are terms to minimize
    if len(avepotList)>0:
        minimizeRefine(avepotList,
                       fixedRegions=fixedRegions,
                       rigidRegions=rigidRegions,
                       refineSteps=regularizeSteps)
        pass

    # make sure final tensors are consistent with structure
    for t in getRegisteredTerms(sim):
        calcTensor(t)
        pass
    
    aveCoords = sim.atomPosArr()
        

    # second pass: fit every input structure to the final average structure
    # and collect the RMSD statistics
    fitTo=aveCoords
    aveRMSD=0.
    aveRMSDcomp=0.
    rmsdArray=[]
    rmsdCompArray=[]
    for struct in structs:
        sim.setAtomPosArr(struct)
        sim.setAtomPosArr( ensembleAvePos() )
        AtomSel("known").apply( Fit(fitTo,fitSel) )
                    
        comparer=RMSD(fitTo)
        AtomSel(fitSel).apply(comparer)
        rmsd = comparer.rmsd()
        aveRMSD += rmsd
        rmsdArray.append( rmsd )
                    
        # the same comparer object is reused for the comparison selection;
        # presumably rmsd() reports the most recent apply() - TODO confirm
        AtomSel(compSel).apply(comparer)
        rmsdComp = comparer.rmsd()
        rmsdCompArray.append( rmsdComp )
        aveRMSDcomp += rmsdComp
        pass
    
    aveRMSD      /= len(structs)
    aveRMSDcomp /= len(structs)

    remarks  = "average structure over %d files\n" % len(structs)
    remarks += "fitted using atoms: %s\n" % fitSel
    remarks += "RMSD diff. for fitted atoms: %f\n" % aveRMSD
    remarks += "comparison atoms: %s\n" % compSel
    remarks += "RMSD diff. for comparison atoms: %f\n" % aveRMSDcomp
    remarks += "B array (last column) is rms diff from mean\n"
    return (aveCoords,rmsdArray,rmsdCompArray,bFactors, remarks)

def ensembleAvePos():
    """return the (unregularized) average coordinates of the current ensemble.

    If the current simulation is not an EnsembleSimulation, its coordinates
    are returned as-is. Otherwise the result is the sum over ensemble
    members of weight()*atomPosArr(). If two consecutive members differ in
    their number of atoms (a heterogeneous ensemble) the average is not
    meaningful, and the coordinates of the current simulation are returned
    instead.
    """
    from simulation import Simulation_currentSimulation
    sim = Simulation_currentSimulation()
    if sim.type() != "EnsembleSimulation":
        return sim.atomPosArr()

    from ensembleSimulation import EnsembleSimulation_currentSimulation
    esim = EnsembleSimulation_currentSimulation()
    if not esim:
        return sim.atomPosArr()

    first = esim.members(0)
    ave = first.weight() * first.atomPosArr()
    for index in range(1,esim.size()):
        member = esim.members(index)
        if member.numAtoms() != esim.members(index-1).numAtoms():
            # heterogeneous ensemble: give up on averaging
            return sim.atomPosArr()
        ave += member.weight() * member.atomPosArr()
        pass
    return ave
    

def minimizeRefine(potList,
                   refineSteps=50,
                   xplorPots=['BOND','ANGL','IMPR'],
                   scaleMultiplier=0.001,
                   rigidRegions=(),
                   fixedRegions=(),
                   translateRegions=(),
                   ):
    """ simple refinement using gradient minimization in Cartesian coordinates.
    Some potential terms are held fixed during minimization, while others are
    scaled up from a smaller value progressively, during each round of
    minimization.

    potList holds the potential terms to minimize. It is first passed
    through flattenPotList (defined elsewhere in this module).

    refineSteps specifies how many rounds of minimization to perform. If it
    is <=0, this function returns without performing any minimization.
    xplorPots are XPLOR terms which are to always be present, and for which
              the scale constant is held constant. Any of these terms not
              found in potList is created using XplorPot. The initial
              minimization refers to the BOND and ANGL terms by name, so
              those two must be present in xplorPots.
    scaleMultiplier specifies the initial value for the scale constant, as a
    fraction of each term's current scale value.

    rigidRegions specifies selections of atoms which do not move relative to
    each other.

    fixedRegions specifies selections of atoms which do not move at all.

    translateRegions specifies selections of atoms which can translate
    as a rigid body, but not rotate.

    There is no return value.
    """
    pots = flattenPotList(potList)

    # refine here
    #  remove bond, angle, impr, vdw terms- if they exist- use them as is.
    # if they don't exist, add them in with default scale values.
    #
    # for rest of terms, loop over them, with MultRamp running from .01 .. 1
    # times the nominal scale values.
    from potList import PotList
    # NOTE: the name minPots must not be changed: the action strings given
    # to MultRamp below refer to it, and they are executed in the namespace
    # of this function's frame (see RampedParameter).
    minPots = PotList()
    hasReqdTerms = {}
    for p in xplorPots: hasReqdTerms[p] = 0
    rampedParams = []
    for pot in pots:
        reqdTerm=0
        for pType in xplorPots:
            if potType(pot) == 'XplorPot' and pot.instanceName() == pType:
                minPots.append(pot)
                hasReqdTerms[pType] = 1
                reqdTerm=1
                # this continue only advances the inner loop over xplorPots;
                # the reqdTerm flag is what skips the ramping code below
                continue
            pass
        if reqdTerm: continue
        # every other term is ramped from scaleMultiplier times its current
        # scale up to its current scale
        minPots.append( pot )
        rampedParams.append( MultRamp( scaleMultiplier*pot.scale(),
                                       pot.scale(),
                                       "minPots['%s'].setScale(VALUE)"%
                                       pot.instanceName() ) )
        pass
    from xplorPot import XplorPot
    for pType in xplorPots:
        # FIX: if this is required for EnsembleSimulations, it will break
        if not hasReqdTerms[pType]: minPots.append( XplorPot(pType) )
        pass
    
    from ivm import IVM
    import protocol

    # configure the minimizer: rigid groups, fixed atoms and
    # translation-only bodies, then Cartesian topology for the rest
    minc = IVM()
    for aSel in rigidRegions: minc.group(aSel)
    for aSel in fixedRegions: minc.fix(aSel)
    for aSel in translateRegions:
        minc.group(aSel)
        minc.hinge('translate',aSel)
        pass
    protocol.cartesianTopology(minc)

    if refineSteps<=0:
        return
    
    #run initial powell minimization with Bond & Angles only - 
    # so impropers are well-defined.
    protocol.initMinimize(minc,
                           potList=[minPots['BOND'],minPots['ANGL']],
                           numSteps=200)

    minc.run()

    protocol.initMinimize(minc,
                          numSteps=100,
                          potList=minPots,
                          dEPred=10)


    # AnnealIVM is imported by module name; the remarks written by
    # StructureLoop.writeStructure indicate that simulationTools is this
    # module itself. With initTemp=finalTemp=0 the temperature stays at zero
    # and only the ramped scale factors change between minimization rounds.
    from simulationTools import AnnealIVM
    AnnealIVM(initTemp =0, finalTemp=0,
              numSteps=refineSteps,
              ivm=minc,
              rampedParams = rampedParams).run()
                             
    return

#used when creating temporary files to avoid races in EnsembleSimulation
#calculations
#threadSuffix=""
def mktemp(suffix='', prefix='tmp', dir=None):
    """ return a temporary file name which is unique for each process and
    thread within an ensemble simulation.

    The given suffix is extended by '.' followed by
    ensembleSimulation.threadSuffix, and the name is then generated by
    tempfile.mktemp. Only a name is returned; no file is created here.
    """
    import tempfile
    from ensembleSimulation import threadSuffix
    return tempfile.mktemp(suffix + '.' + threadSuffix, prefix, dir)

def potType(pot):
    """ return the potential's type.

    Normally this is the string returned by the term's potName() accessor.
    A name starting with 'ave_' (as for <m avePot>.AvePot) marks a wrapper
    term: in that case the wrapped term, obtained from subPot(), is examined
    instead, repeatedly, until a name without that prefix is found.
    """
    name = pot.potName()
    while name.startswith('ave_'):
        pot = pot.subPot()
        name = pot.potName()
        pass
    return name
    


def genFilename(template,
                structure=0,
                member=0):
    """ create a filename given a template.

    The following literals in template are replaced, in this order:
         SCRIPT    -> xplor.scriptName with its last suffix (e.g. .py)
                      removed, or 'stdin' if that is empty
         STRUCTURE -> the structure argument; an integer is converted
                      to its decimal representation
         MEMBER    -> the member argument, formatted as an integer
    """
    import re
    import xplor

    if isinstance(structure,int):
        structure = "%d" % structure
        pass

    # strip the last dot-suffix; fall back to 'stdin' if nothing is left
    scriptName = re.sub(r'\.[^.]*$','',xplor.scriptName) or 'stdin'

    return template.replace("SCRIPT","%s" % scriptName) \
                   .replace("STRUCTURE",structure) \
                   .replace("MEMBER","%d" % member)
    

def waitForProcesses(numProcs,esim=0,
                     timeOut=-1):
    """

    DEPRECATED: will be removed soon.
    
    wait for all processes calculating structures in parallel to complete
    their current computation.

    numProcs is the total number of processes. If it is 1, this routine
    returns 0 immediately.

    esim, if specified, is used to create an EnsembleSharedObj through which
    an error status is shared.

    If timeOut>=0, this specifies how long (in sec) to wait before abandoning
    a process.

    This routine acts like a barrier, suspending computation until all
    processes have reached this point.

    this routine assumes write access to the local directory, and that this
    directory is shared by all processes.

    The process id and parent process id are read from the environment
    variables XPLOR_PROCESS and XPLOR_PPID. Processes communicate through
    files named .<script>-barrier-<ppid>-<process id> whose contents step
    through "waiting.", "wake.", "waking." and "run.".

    Returns the list procs: on process 0 it holds 0 and the ids of the
    processes which reached the barrier; on other processes it is [0].
    On error, sys.exit(1) is called.
    """
    pollInterval = 5 #secs to wait between polling for process completion
    if numProcs==1:
        return 0

    sharedErr=0
    err=0
    if esim:
        # NOTE(review): EnsembleSharedObj is not imported in this function;
        # it must be available at module scope (not visible here) - confirm
        sharedErr = EnsembleSharedObj(esim,0)
    # only a single thread should execute this code
    from ensembleSimulation import singleThread, multiThread
    from os import environ as env
    # NOTE(review): procs is only assigned inside the branch below; if
    # singleThread() is false, the final return raises NameError - confirm
    # whether that path can be reached
    if singleThread():
        procs = [0]

        processID = int( env["XPLOR_PROCESS"] )
        ppid = int( env["XPLOR_PPID"] )
        
        import re
        from sys import argv
        try:
            scriptName =  re.sub(r'\.[^.]*$','',argv[0])
        except IndexError:
            scriptName = "stdin"
            pass
        if scriptName.endswith(__name__):
            scriptName = "stdin"
            pass

        barrierFilePref = ".%s-barrier-%d-" % (scriptName,ppid)
        barrierFile = barrierFilePref+"%s" % processID
        
        # if a barrier file is left over from an earlier call, wait until
        # its contents are "run." (or until it no longer exists, which
        # raises the IOError caught below) before reusing it
        import os, time
        try:
            while 1:
                if open(barrierFile).read() == "run.\n": break
                time.sleep(pollInterval)
                pass
            #        print "waitForProcesses: WARNING: barrier file already exists!"
            #        os.unlink(barrierFile)
        except IOError:
            pass

        # announce that this process has reached the barrier
        file = open(barrierFile,"w"); file.write("waiting.\n"); file.close()

        if processID==0:
            #wait until everyone is done.
            for cnt in range(1,numProcs):
                file = barrierFilePref+"%s" % cnt
                # trys is decremented before it is compared with 0 below, so
                # a start value <= 0 (including the default timeOut=-1, given
                # integer division) never matches and the wait is unbounded
                trys = timeOut / pollInterval #FIX: must be >= 1
                ok=1
                while 1:
                    trys -= 1
                    try:
                        print "trying ", file
                        if open(file).read()=="waiting.\n": break
                        pass
                    except IOError:
                        pass
                    print "waiting for process %d" % cnt
                    time.sleep(pollInterval)
                    if trys==0:
                        print 'abandoning process %d' % cnt
                        ok=0
                        break
                    pass
                if ok: procs.append( cnt )
                pass

            print "waking all processes"
            #wake all processes
            for cnt in procs[1:]:
                file = barrierFilePref+"%s" % cnt
                open(file,"w").write("wake.\n")
                pass
            #wait til process is running again, then delete its barrier file
            for cnt in procs[1:]:
                file = barrierFilePref+"%s" % cnt
                while open(file).read() != "waking.\n":
                    time.sleep(1)
                    pass
                open(file,'w').write('run.\n')
                try:
                    os.unlink(file)
                except:
                    pass
                pass
            # finally release and remove this process's own barrier file
            open(barrierFile,'w').write('run.\n')
            try:
                os.unlink(barrierFile)
            except:
                pass
            pass
        else: # processID!=0
            #wait until awakened by process 0
            trys = timeOut / pollInterval #FIX: must be >= 1
            while 1:
                trys -= 1
                if open(barrierFile).read() == "wake.\n":
                    break
                time.sleep(pollInterval)
                if trys==0:
                    print 'ERROR: waited too long for wakeup call.'
                    err=1
                    # NOTE(review): if esim was not specified, sharedErr is
                    # still the integer 0 here, and this call raises
                    # AttributeError
                    sharedErr.set(1)
                    break
                pass
            #tell process 0 that we're awake.
            print "process awoken"
            open(barrierFile,"w").write("waking.\n")
            pass
        pass

    multiThread()
    # NOTE(review): barrerGet looks like a misspelling (of barrierGet?) -
    # verify against the EnsembleSharedObj interface
    if err or (sharedErr and sharedErr.barrerGet()):
        print "waitForProcess: shutting down due to errors."
        import sys
        sys.exit(1)
        pass
    return procs

class RampedParameter:
    """Common machinery for parameters which are ramped during a calculation.

    The action given to the constructor is either a callable, invoked as
    action(value), or a string of Python code. A string is executed in the
    namespace of the frame recorded at construction time, after each
    occurrence of the literal VALUE has been replaced by the current value
    (formatted using %e). While the string runs, the names ParameterRampInfo
    (this object) and caller are defined.

    methods:
      init(ns)        - set number of steps, set val to startValue and
                        run the action
      update()        - run the action. In this base class the value is
                        not changed, and 0. is returned.
      value()         - return the current value
      setNumSteps(ns) - does nothing in this base class
      finalize()      - set value to stopValue and run the action
    """
    def __init__(s,action):
        import inspect
        s.action = action
        # outer frame 0 is this constructor and frame 1 is its caller (the
        # constructor of a derived class, for the ramps in this module), so
        # frame 2 is where string actions will be executed. This lookup
        # must stay directly in this method to keep the depth correct.
        s.callingFrame = inspect.getouterframes( inspect.currentframe() )[2][0]
        s.val = s.startValue = s.stopValue = 0
        return
    def init(s,numSteps,caller=0):
        s.setNumSteps(numSteps)
        s.val = s.startValue
        s.runAction(caller)
        return
    def finalize(s):
        s.val = s.stopValue
        s.runAction()
        return
    def runAction(s,caller=0):
        action = s.action
        if not action:
            return
        if type(action) != type("string"):
            # callable action: hand it the current value
            action(s.val)
            return
        frame = s.callingFrame
        scope = frame.f_locals
        scope["ParameterRampInfo"] = s
        scope["caller"] = caller
        exec( action.replace("VALUE","%e" %s.val),
              frame.f_globals, scope )
        # only ParameterRampInfo is removed again; caller is left behind
        del scope["ParameterRampInfo"]
        return
    def setNumSteps(s,numSteps):
        return
    def value(s):
        return s.val
    def update(s,caller=0):
        s.runAction(caller)
        return 0.
    pass


class MultRamp(RampedParameter):
    """convenience class for multiplicatively (geometrically)
    ramping a value from startValue to stopValue over numberSteps
    constructor: MultRamp(startValue, stopValue, action)
      startValue and stopValue values smaller than 1e-30 are replaced by
      1e-30 so that the ratio between them is well defined.
      action is a function or string, as described for RampedParameter.
    methods:
      update() - multiplies the value by the per-step factor, runs the
                 action and returns the new value. It will not change the
                 value beyond that specified by stopValue.
                 setNumSteps() or init() must have been called first.
      value()  - return the current value
      setNumSteps(ns) - set number of steps
      init(ns)        - set number of steps and initialize val to startValue
      finalize()      - set value to the final value.
    """
    def __init__(s,startValue,stopValue,action=None):
        # the base class constructor must be called directly from here: it
        # records the frame of the code which called this constructor
        RampedParameter.__init__(s,action)

        s.val = startValue
        s.startValue = max(startValue,1e-30)
        s.stopValue  = max(stopValue,1e-30)
        s.dirIncreasing=0
        if stopValue>startValue: s.dirIncreasing=1
        return
    def update(s,caller=0):
        s.val *= s.factor
        # clamp, so that the value never goes past stopValue
        if (s.dirIncreasing and s.val>s.stopValue) or \
           (not s.dirIncreasing and s.val<s.stopValue):
            s.val = s.stopValue
            pass
        s.runAction(caller)
            
        return s.val
    def setNumSteps(s,numSteps):
        s.factor=1
        # float() forces true division: with two integer end values the
        # ratio would otherwise be truncated under Python 2 - to 0 for a
        # decreasing ramp such as MultRamp(10,1) - giving a useless factor.
        if numSteps:
            s.factor = (float(s.stopValue)/s.startValue)**(1./numSteps)
            pass
        return
    pass

class LinRamp(RampedParameter):
    """convenience class for linearly
    ramping a value from startValue to stopValue over numberSteps
    constructor: LinRamp(startValue, stopValue, action)
      action is a function or string, as described for RampedParameter.
    methods:
      update() - adds the per-step increment to the value, runs the action
                 and returns the new value. It will not change the value
                 beyond that specified by stopValue.
                 setNumSteps() or init() must have been called first.
      value()  - return the current value
      setNumSteps(ns) - set number of steps
      init(ns)        - set number of steps and initialize val to startValue
      finalize()      - set value to the final value.
    """
    def __init__(s,startValue,stopValue,action=None):
        # the base class constructor must be called directly from here: it
        # records the frame of the code which called this constructor
        RampedParameter.__init__(s,action)

        s.val = startValue
        s.startValue = startValue
        s.stopValue  = stopValue
        s.dirIncreasing=0
        if stopValue>startValue: s.dirIncreasing=1
        return
    def update(s,caller=0):
        s.val += s.factor
        # clamp, so that the value never goes past stopValue
        if (s.dirIncreasing and s.val>s.stopValue) or \
           (not s.dirIncreasing and s.val<s.stopValue):
            s.val = s.stopValue
            pass
        s.runAction(caller)
            
        return s.val
    def setNumSteps(s,numSteps):
        s.factor=0
        # float() forces true division: with integer end values, e.g.
        # LinRamp(0,1), the increment would otherwise be truncated to 0
        # under Python 2 and the value would never change.
        if numSteps>0:
            s.factor = (s.stopValue-s.startValue)/float(numSteps)
            pass
        return
    pass

class StaticRamp(RampedParameter):
    """convenience class for an action which is run without ramping a value.
    The value itself is never changed by this class.
      update() - run the action, subject to the stride setting, and
                 return the current value
      value()  - return the current value
      init(ns) - initialize val to startValue and run the action
    """
    def __init__(s,action,stride=1):
        """
        action is a function or string to be executed.
        stride specifies how often the function is called. A stride value of
        1 specifies that the function is called every time update is called.
        Larger values of stride specify that action is called by update()
        if caller.step%stride=0.
        """
        # the base class constructor must be called directly from here: it
        # records the frame of the code which called this constructor
        RampedParameter.__init__(s,action)
        s.stride=stride
        return
    def update(s,caller=0):
        # the stride applies only if a caller with a step attribute is
        # given. A skipped update returns None.
        if caller and 'step' in dir(caller):
            if caller.step%s.stride!=0:
                return
            pass
        s.runAction(caller)

        return s.val
    pass

class InitialParams:
    """Callable which sets a list of ramped parameters to their initial
    values.

    The constructor stores the list and immediately applies it. Calling
    the object afterwards, with no arguments, does the same again. Each
    parameter is initialized by calling its init method with a step
    count of 0.
    """
    def __init__(s,pList):
        s.pList = pList
        s()
        return
    def __call__(s):
        for param in s.pList:
            param.init(0)
            pass
        return
    pass

class FinalParams:
    """Callable which sets a list of ramped parameters to their final
    values.

    The constructor stores the list and immediately applies it. Calling
    the object afterwards, with no arguments, does the same again. Each
    parameter is set by calling its finalize method.
    """
    def __init__(s,pList):
        s.pList = pList
        s()
        return
    def __call__(s):
        for param in s.pList:
            param.finalize()
            pass
        return
    pass
        
#from potList import PotList
#class PotListWithContext(PotList):
#    """a <m potList>.PotList with an initializing function called before
#    calcEnergy/calcEnergyAndDerivs
#    """
#    def __init__(s,name="",
#                 potList=None,
#                 context=None):
#        if not potList: potList = PotList()
#        PotList.__init__(s,name)
#        if potList: s.copy(potList)
#        s.context = context
#        from inspect import currentframe, getouterframes
#        s.callingFrame = getouterframes( currentframe() )[1][0]
#        return
#    def calcEnergy(s):
#        if s.context: s.context()
#        return PotList.calcEnergy(s)
#    def calcEnergyAndDerivs(s,derivs):
#        if s.context: s.context()
#        return PotList.calcEnergyAndDerivs(s,derivs)
    



class AnnealIVM:
    """class to perform simulated annealing using molecular dynamics. """
    def __init__(s,  initTemp,  finalTemp,
                 ivm=0,
                 numSteps=None,   tempStep=None,
                 rampedParams   ={},
                 extraCommands  =0,
                 toleranceFactor=1000):
        """construct by specifying the initial and final annealing
        temperatures, and an <m ivm>.IVM object.

        if tempStep is specified, it will be used to determine the number of
          dynamics runs at different temperatures. If it is omitted, numSteps
          will be used (and tempStep implicitly determined). If neither
          is specified (or both are zero), a TypeError is raised.

        rampedParams is a sequence of MultRamp and LinRamp objects which
        specify refinement parameters to adjust during the simulated
        annealing run. It is only iterated over, never modified.
        extraCommands is a function or string which is run before dynamics at
        each temperature. If it is a function, it is passed the current
        AnnealIVM instance as the argument. If it is a string, it is executed
        in the namespace of the code which called this constructor, with
        the current AnnealIVM instance available as annealLoopInfo.

        toleranceFactor is used to adjust the <m ivm>.IVM's energy tolerance
        as the temperature changes. The energy tolerance is calculated as
        
            eTolerance = temp / toleranceFactor

        If toleranceFactor is not positive, the tolerance is not changed.

        The ivm argument doesn't actually have to be an <m ivm>.IVM, but it
        must be an object which has the following methods defined:
          setBathTemp
          setETolerance
          run
        If ivm is omitted, no dynamics is run: only the temperature,
        the ramped parameters and the extra commands are processed.

        """
        from inspect import currentframe, getouterframes
        s.initTemp = initTemp
        s.finalTemp = finalTemp
        s.extraCommands = extraCommands
        if tempStep:
            s.tempStep=tempStep
            s.numSteps = int( (initTemp - finalTemp)/float(tempStep) )
        elif numSteps:
            s.numSteps = numSteps
            s.tempStep = (initTemp - finalTemp)/float(numSteps)
        else:
            # raising a plain string is not a valid exception: the
            # interpreter reports a TypeError instead, and the message is
            # lost. Raise that same exception type explicitly.
            raise TypeError(
                "AnnealIVM: neither numSteps nor tempStep is defined")
        s.ivm = ivm
        s.params = rampedParams
        s.toleranceFactor = toleranceFactor
        # frame of the code which called this constructor: string-valued
        # extraCommands are executed in its namespace. This lookup must
        # stay directly in this method to keep the frame depth correct.
        s.callingFrame = getouterframes( currentframe() )[1][0]
        return
    def run(s):
        """perform the annealing run: one pass at initTemp, followed by
        numSteps passes, each tempStep lower in temperature than the
        previous one. During the latter passes, the step attribute holds
        the current pass index, starting at 0.
        """
        s.bathTemp = s.initTemp
        s.printTemp()

        s.initParameters()
        
        s.runExtraCommands()
        s.runIVM(s.bathTemp) #run at initial temperature
        
        for s.step in range(0,s.numSteps):
            s.bathTemp -= s.tempStep
            s.printTemp()

            s.updateParameters()
            s.runExtraCommands()

            s.runIVM(s.bathTemp)

            pass
        return

    def printTemp(s):
        "print the current bath temperature, unless the log level is 'none'"
        import simulationWorld
        simWorld = simulationWorld.SimulationWorld_world()
        if simWorld.logLevel() != 'none':
            # a single parenthesized argument prints identically whether
            # print is a statement or a function
            print("AnnealLoop: current temperature: %.2f" % s.bathTemp)
            pass
        return

    def runExtraCommands(s):
        "run extraCommands, if specified (see the constructor documentation)"
        if s.extraCommands:
            if type(s.extraCommands) == type("string"):
                global_dict = s.callingFrame.f_globals
                local_dict = s.callingFrame.f_locals
                local_dict["annealLoopInfo"] = s
                exec( s.extraCommands, global_dict, local_dict )
                del local_dict["annealLoopInfo"]
            else:
                s.extraCommands(s)
                pass
            pass
        return
            
    def initParameters(s):
        "initialize ramped parameters"
        for param in s.params:
            param.init( s.numSteps ) #calls runAction
            pass
    def finalParameters(s):
        "sets parameters to final values"
        for param in s.params:
            param.val = param.stopValue
            param.runAction(s)
            pass
    def updateParameters(s):
        "advance each ramped parameter by one step"
        for param in s.params:
            param.update(s)
            pass
        return

    def runIVM(s,temp):
        """set the bath temperature (and, if toleranceFactor>0, the energy
        tolerance) of the ivm object, and run it. Does nothing if no ivm
        object was specified.
        """

        if not s.ivm:
            return

        s.ivm.setBathTemp( temp )
        if s.toleranceFactor>0:
            s.ivm.setETolerance( temp / float(s.toleranceFactor) )
            pass
        s.ivm.run()
        
        return
    pass



def verticalFormat(name,
                   nameValuePairs,
                   minWidth=6,
                   numericFormat=".2f"):
    """Format name/value pairs as two aligned lines of text.

    The return value is the tuple (line1, line2). line2 starts with
    " name:" followed by the values; line1 is blank underneath that heading
    and then holds the label of each value. Each column is preceded by a
    space, is right-justified, and is as wide as its label, but no narrower
    than minWidth. Values are formatted using numericFormat, a printf-style
    conversion given without the leading % or a width.
    """
    heading = " " + name + ":"
    labelCols = [" " * len(heading)]
    valueCols = [heading]
    for (label,value) in nameValuePairs:
        width = "%d" % max(minWidth,len(label))
        labelCols.append( (" %" + width + "s") % label )
        valueCols.append( (" %" + width + numericFormat) % value )
        pass
    return ("".join(labelCols), "".join(valueCols))

    
                   


# Self-tests, run when this module is executed as a script. Each test
# collects text in the module-level variable result and compares it with an
# expected string, printing "ok" or "FAILED".
if __name__ == "__main__":
    #
    # tests
    #

    
    from ensembleSimulation import EnsembleSimulation
    import sys
    import simulationWorld
    esim = EnsembleSimulation("esim",2)

    initSeed=751
    simWorld = simulationWorld.SimulationWorld_world()
    simWorld.setRandomSeed(751)

    result = ""
    expectedResult="""running structure 0 in process 0  thread: 0 seed: 751
running structure 1 in process 0  thread: 0 seed: 752
running structure 2 in process 0  thread: 0 seed: 753
running structure 3 in process 0  thread: 0 seed: 754
running structure 4 in process 0  thread: 0 seed: 755
running structure 5 in process 0  thread: 0 seed: 756
running structure 6 in process 0  thread: 0 seed: 757
running structure 7 in process 0  thread: 0 seed: 758
running structure 8 in process 0  thread: 0 seed: 759
running structure 9 in process 0  thread: 0 seed: 760
"""
    # helper which lets the lambda below append to result
    def concat(str):
        global result
        result += str
        return 0

    
    # test 1: StructureLoop with a function as structLoopAction
    sys.stdout.write("StructureLoop: action as function...")
    StructureLoop(numStructures=10,
                  structLoopAction=lambda loopInfo: \
                  concat("running structure %d" % loopInfo.count +
                         " in process %d " % loopInfo.processID +
                         " thread: %d" % esim.member().memberIndex() +
                         " seed: %d\n" % loopInfo.randomSeed)
                  ).run()

    if result == expectedResult:
        print "ok"
    else:
        print "FAILED"
        pass


    # reset the seed so that the second run reproduces the same seeds
    simWorld.setRandomSeed(751)
    result=""
        
    # test 2: as test 1, but with structLoopAction given as a string of code
    sys.stdout.write("StructureLoop: action as string...")
        
    StructureLoop(numStructures=10,
                  structLoopAction=r'''global result
result += "running structure %d" % structLoopInfo.count + \
" in process %d " % structLoopInfo.processID + \
" thread: %d" % esim.member().memberIndex() + \
" seed: %d\n" % structLoopInfo.randomSeed
'''
                  ).run()
        
    if result == expectedResult:
        print "ok"
    else:
        print "FAILED"
        pass
            
    del esim

    # test 3: MultRamp with a function action. init() runs the action once
    # and each of the 10 updates runs it again, giving 11 values.
    sys.stdout.write("MultRamp: action as function...")
    expectedResult = "0.10 0.16 0.25 0.40 0.63 1.00 1.58 2.51 3.98 6.31 10.00 "
    result = ""
    def multProc(v):
        global result
        result +=  "%.2f " % v
        return
    
    param = MultRamp(0.1,10,multProc)
    numSteps=10
    param.init(numSteps)
    for i in range(numSteps):
        param.update()
        pass

    if result == expectedResult:
        print "ok"
    else:
        print "FAILED"
        pass
    

    # test 4: as test 3, but with the action given as a string of code
    sys.stdout.write("MultRamp: action as string...")
    expectedResult = "0.10 0.16 0.25 0.40 0.63 1.00 1.58 2.51 3.98 6.31 10.00 "
    result = ""
    # this redefinition of multProc is not used by this test
    def multProc(v):
        global result
        result +=  "%.2f " % v
        return
    
    param = MultRamp(0.1,10,"global result; result += '%.2f ' % VALUE")
    numSteps=10
    param.init(numSteps)
    for i in range(numSteps):
        param.update()
        pass

    if result == expectedResult:
        print "ok"
    else:
        print "FAILED"
        pass

    # expected output of the two AnnealIVM tests: each line has the value
    # appended by the ramped parameter of test 4, followed by the text
    # appended by the extra command at that temperature
    expectedResult="""0.10 extra: 1000.000000
0.32 extra: 900.000000
1.00 extra: 800.000000
3.16 extra: 700.000000
10.00 extra: 600.000000
"""

    # test 5: AnnealIVM without an ivm object and with a function as
    # extraCommands
    sys.stdout.write("AnnealIVM: extraCmd as function...")
    result=""
        
    def coolProc(s):
        global result
        result += 'extra: %f\n' % s.bathTemp
        return

    AnnealIVM(initTemp=1000,
              finalTemp=600,
              tempStep=100,
              rampedParams=(param,),
              extraCommands=coolProc).run()

    if result == expectedResult:
        print "ok"
    else:
        print "FAILED"
        pass


    # test 6: as test 5, but with extraCommands given as a string of code.
    # The string is executed in the namespace of this module-level code, so
    # result can be modified without a global declaration.
    sys.stdout.write("AnnealIVM: extraCmd as string...")
    result=""
        
    AnnealIVM(initTemp=1000,
              finalTemp=600,
              tempStep=100,
              rampedParams=(param,),
              extraCommands=
              r"result += 'extra: %f\n' % annealLoopInfo.bathTemp").run()

    if result == expectedResult:
        print "ok"
    else:
        print "FAILED"
        pass

    pass

def testGradient(pots,
                 eachTerm=0,
                 alwaysPrint=0,
                 components=[0],
                 tolerance=1e-3,
                 eTolerance=1e-8,
                 epsilon=1e-7,
                 sim=0):
    """
    check the analytic gradient in the given potential terms against finite
    difference values.

    If the eachTerm argument is set, each term in the potList is tested
    individually.

    If alwaysPrint is set, all gradient terms will always be printed.

    components specificies which of the three gradient components (x, y, z) to
    test. By default this is just the x (0) component.

    tolerance specifies the agreement expected between numerical and
    analytic gradient

    eTolerance specifies the ageement relative to the magnitude of the energy

    epsilon specifies the stepsize used in the finite different gradient
    calculation:

            dE/dqi(xyz) = E(qi(xyz)+epsilon) - E(qi(xyz)) / epsilon

    where E(0) is the energy evaluated at the nominal coordinate value
    qi(xyz).

    For each atom, the following is printed:
    atom identifying string  numerical gradient  analytical gradient
    """

    from potList import PotList
    potList = PotList()

    try:
        len(pots)
        for term in pots:
            potList.append(term)
            pass
        pass
    except TypeError:
        potList.append(pots)
        pass


    if not sim:
        from simulation import currentSimulation
        sim = currentSimulation()
        pass
    
    #is this an EnsembleSimulation?
    from ensembleSimulation import EnsembleSimulation_currentSimulation
    esim = EnsembleSimulation_currentSimulation()
    if esim and sim.name() == esim.name():
        sim = esim
    else:
        esim=0
        pass
    

    if eachTerm:
        ret=1
        for term in potList:
            print "testing gradient of potential term:", term.instanceName()
            if not testGradient(term):
                ret=0
            pass
        return ret
    
    from vec3 import Vec3

    # total gradient
    
    from derivList import DerivList
    dlist = DerivList()
    dlist.init(sim)
    energy = potList.calcEnergyAndDerivs(dlist)
    derivs = dlist.get(sim)
    
    ret=1
    if esim:
        header = "\n   Gradient Error Report\n\n"
        header += "%19s  ens %3s  %8s     %8s" %("Atom       ","xyz",
                                               "Numerical","From pot")
        isHeader=False
        sharedObj = esim.sharedObj()

        for j in range( esim.size() ):
            member = esim.members(j)
            for i in range( member.numAtoms() ):
                for xyz in components:
                    if j == esim.member().memberIndex():
                        initPos = Vec3( member.atomPos(i) )
                        pos = Vec3( initPos )
                        pos[xyz] += epsilon
                        member.setAtomPos(i,pos)
                        pass
                    denergy = potList.calcEnergy()-energy
                    if j == esim.member().memberIndex():
                        sharedObj.set( None )
                        grad = denergy/epsilon
                        if (alwaysPrint or
                            (abs(grad-derivs[i][xyz])>
                             tolerance*(1+
                                        max(abs(grad),abs(derivs[i][xyz]))) and
                             abs(grad-derivs[i][xyz])>eTolerance*energy)):
                            sharedObj.set( (grad, derivs[i][xyz]) )
                            pass
                        member.setAtomPos(i,initPos)
                        pass
                    result=sharedObj.barrierGet()
                    if result!=None:
                        ret=0
                        if not isHeader:
                            print header
                            isHeader=True
                            pass
                        print "%-25s"%member.atomByID(i).string(),\
                              j,xyz,result
                        pass
                    pass
                pass
            pass
        pass
    else:
        #not an EnsembleSimulation
        header = "\n   Gradient Error Report\n\n"
        header += "%19s  %3s %8s     %8s" %("Atom       ","xyz",
                                        "Numerical","From pot")
        isHeader=False
        for i in range( sim.numAtoms() ):
            for xyz in components:
                initPos = Vec3( sim.atomPos(i) )
                pos = Vec3( initPos )
            
                pos[xyz] += epsilon
                sim.setAtomPos(i,pos)
                denergy = potList.calcEnergy()-energy
                grad = denergy/epsilon
                if (alwaysPrint or
                    (abs(grad-derivs[i][xyz]) >
                     tolerance*(1+
                                max(abs(grad),abs(derivs[i][xyz]))) and
                     abs(grad-derivs[i][xyz]) > tolerance and
                     abs(grad-derivs[i][xyz])>eTolerance*abs(energy))):
                    if not isHeader:
                        print header
                        isHeader=True
                        pass
                    ret=0
                    print "%s %3d %10.5f %10.5f" %(sim.atomByID(i).string(),
                                                   xyz,
                                                   grad,
                                                   derivs[i][xyz])
                    pass
                sim.setAtomPos(i,initPos)
                pass
            pass
        pass
    return ret

def summarizePotentials(potList):
    """collect rms and violation summaries from a sequence of potential terms.

    Returns a list with one tuple per term in potList:
        (potName, instanceName, rms, violations)
    An rms or violations value of -1 means that no value is available for
    the term.

    For a non-empty term whose potName() is 'PotList' the function recurses:
    rms is the mean rms of the sub-terms which report rms>=0, and violations
    is the sum over those same sub-terms. If no sub-term reports an rms, both
    values are left at 0. If the term itself has an rms() or violations()
    method, the value it returns is used instead.
    """
    summaries = []

    for term in potList:
        rms, viols = -1, -1

        if term.potName()=='PotList' and len(term):
            # aggregate over those sub-terms which actually report an rms
            reporting = [(subRms,subViols)
                         for (_,_,subRms,subViols) in summarizePotentials(term)
                         if subRms>=0]
            rms   = sum([entry[0] for entry in reporting], 0.)
            viols = sum([entry[1] for entry in reporting], 0)
            if reporting:
                rms /= len(reporting)

        # values supplied by the term itself take precedence
        try:
            rms = term.rms()
        except AttributeError:
            pass
        try:
            viols = term.violations()
        except AttributeError:
            pass

        summaries.append( (term.potName(),term.instanceName(),rms,viols) )

    return summaries

def flattenPotList(potList):
    """return a Python list of the terms in potList, with any nested
    PotLists replaced by their members [recursively].

    A term whose potName() is 'PotList' is wrapped using PotListPtr(pot.this)
    before its members are collected - presumably so that the term can be
    iterated over as a PotList [to be confirmed].
    """
    from potList import PotListPtr
    flat = []
    for term in potList:
        if term.potName()!='PotList':
            flat.append(term)
            continue
        nested = PotListPtr( term.this )
        flat.extend( flattenPotList(nested) )
    return flat

def getPotTerms(potList,names):
    """return a list of those terms in potList for which potType(term) is
    equal to one of names. Nested PotLists are searched, as potList is first
    flattened using flattenPotList.

    names can be a single string or a sequence of strings. The returned terms
    are grouped in the order in which the names are given.
    """
    allTerms = flattenPotList(potList)
    if type(names)==type('string'):
        names = (names,)
    matches = []
    for wanted in names:
        matches.extend( [term for term in allTerms if potType(term)==wanted] )
    return matches
                

# (analyzeFunc,termTitle,termPrefix) tuples, in order of registration
registeredTerms=[]
def registerTerm(analyzeFunc,termTitle,termPrefix):
    """ add an analysis function to those called by the analyze() function.

    analyzeFunc is called with a list of potential terms as its only
    argument. It should return a string containing an analysis summary; a
    more detailed analysis can be printed to stdout.

    termTitle and termPrefix are two strings which identify the potential
    term: the long form and the short form, respectively. analyze() places
    termTitle above the summary, and termPrefix at the start of each of its
    lines.
    """
    entry = (analyzeFunc,termTitle,termPrefix)
    registeredTerms.append( entry )

def _toPotList(terms):
    """return terms as a PotList: a single potential term [an object for
    which len() fails] or a Python list/tuple of terms is copied into a new
    PotList. Any other object is returned unchanged.
    """
    try:
        len(terms)
    except (TypeError, AttributeError):
        terms = [terms]
    if isinstance(terms,(list,tuple)):
        ret = PotList()
        for term in terms:
            ret.append(term)
        return ret
    return terms

def _summarizeEnergies(pots):
    """evaluate the energy of the PotList pots and format one 'summary' line
    for each of its terms.

    Returns the tuple (total energy, summary text, total violation count).
    """
    totEnergy = pots.calcEnergy()
    reports = pots.energyReports()

    text=""
    totViols=0
    for (pName,name,rms,viols) in summarizePotentials(pots):
        # energy of the first report whose name matches this term
        energy = [report for report in reports if report[0]==name][0][1]
        rmsString=" "*8
        if rms>-1: rmsString = "%8.3f" % rms
        violString=' '*8
        if viols>-1:
            violString = "%8d" % viols
            totViols += viols
        text += "summary %-10s %10.2f %s %s\n" % (name,
                                                energy,rmsString,violString)
    return (totEnergy,text,totViols)

def analyze(potList,extraTerms=PotList(),
            outFilename=0):
    """ pretty print appropriate terms from the given PotList and return
    them in a remarks string.
    The optional extraTerms is a PotList of terms which do not
    contribute to the total energy, but which should be analyzed all the same-
    use this for cross-validated potential terms. The potList and extraTerms
    arguments can have type of Pot, PotList, list or tuple.

    If outFilename is specified, the verbose violations info is written to
    the specified file: the file is [re]created, and whatever the registered
    analysis functions [see registerTerm] print to stdout is appended to it.
    stdout is restored when the analysis functions are done, even if one of
    them raises an exception.

    The returned string contains an energy/rms/violations table with one
    line per term, the same for extraTerms [if there are any], followed by
    the summaries returned by the registered analysis functions.
    """
    import sys

    potList    = _toPotList(potList)
    extraTerms = _toPotList(extraTerms)

    if outFilename:
        outfile = open(outFilename,"w")
        try:
            outfile.write("\n  Violation Analysis \n\n")
        finally:
            outfile.close()

    rule = "-"*60 + '\n'

    (totEnergy,termLines,totViols) = _summarizeEnergies(potList)

    ret  = rule
    ret += "summary    Name       Energy      RMS     Violations\n"
    ret += "summary %-10s %10.2f %8s %8d\n" % ("total",
                                              totEnergy,'',totViols)
    ret += termLines
    ret += rule

    if len(extraTerms):
        ret += "\nCross-validated terms:\n"
        ret += _summarizeEnergies(extraTerms)[1]

        terms = flattenPotList(extraTerms)

        # if extraTerms contains nested PotLists, also list the names of all
        # of the individual terms, wrapped at about 71 characters
        if len(terms) > len(extraTerms):
            ret += "summary\n"

            lineBeg = "summary cross-validated terms: "
            line=lineBeg
            for term in terms:
                if (len(line)+len(term.instanceName()))>71:
                    ret += line + '\n'
                    line=lineBeg
                line += "%s" % term.instanceName()
                if term != terms[-1]: line += ", "
            if line != lineBeg: ret += line + '\n'

        ret += rule

    # while the registered analysis functions run, their verbose output goes
    # to outFilename [if specified] instead of stdout
    old_stdout = sys.stdout
    redirect = None
    if outFilename:
        redirect = open(outFilename,"a")
        sys.stdout = redirect
    try:
        for (analyzeFunc,termTitle,termPrefix) in registeredTerms:
            tmp = analyzeFunc(list(potList)+list(extraTerms))
            if tmp:
                ret += ' %s\n' % termTitle
                ret += '\n'.join([termPrefix+" "+line
                                  for line in tmp.splitlines()])
                ret += '\n' + rule
    finally:
        if redirect is not None:
            sys.stdout = old_stdout
            redirect.close()

    # first print overall energies, energies of each term

    #go through each potential type for analysis:
    # print out results if verbose flag is set.
    # bond
    # angle
    # impropers
    #

    # rms, # violations

    # NOE

    # rdc:
    #
    # deal with ensemble, non-ensemble cases
    return ret
    
#dictionary keyed by potName each member of which is a list of three-membered
# tuples: (name,function,supportsList)
# where function is to be called on a Pot term, like this function(term) and
# return a floating value, and supportsList states whether function also
# supports a list of terms
extraStats={}
def registerExtraStats(potType,name,function,supportsList=False):
    """register extra terms, averages over selected structures will be
    reported by <m restraintStats>. 

    The four arguments are the potential type as given by the potName
    accessor, the name of the property to be averaged, a
    function to be called on the pot term, like this: function(term) and
    which returns a floating value, and whether this function supports a
    list of terms.

    The entry is appended to the list stored in extraStats under potType, so
    that entries for a given potential type are kept in registration order.
    """
    # setdefault creates the per-type list on first use; this replaces the
    # deprecated dict.has_key test
    extraStats.setdefault(potType,[]).append( (name,function,supportsList) )
    return


def saRefine(potList,
             refineSteps=50,
             xplorPots=['BOND','ANGL','IMPR','RAMA'],
             initTemp=10,
             finalTime=0.2,
             numSteps=100,
             htFinalTime=10,
             htNumSteps=1000,
             initVel=1,
             scaleMultiplier=0.001,
             rigidRegions=(),
             fixedRegions=()):
    """ refine the current structure against the terms in potList using
    Cartesian dynamics followed by simulated annealing.
    [explanation added by Robin A Thottungal, 07/16/09]

    refineSteps is passed to AnnealIVM as the number of annealing steps. If
                it is not positive, the simulated annealing stage is skipped.
    xplorPots are XPLOR terms which are to always be present, and for which
              the scale constant is held constant. Those not found in potList
              are created using XplorPot.
    scaleMultiplier specifies the initial value for the scale constant, as a
    fraction of the scale value each term has on entry: the scale of every
    term not named in xplorPots is ramped from scaleMultiplier times this
    value up to the value itself.

    initTemp specifies the initial temperature for high-temperature dynamics,
    and at the start of simulated annealing. The parameters finalTime, and
    numSteps specify dynamics duration and number of steps at each step of
    simulated annealing, while htFinalTime and htNumSteps specify these
    parameters for initial dynamics. initVel is passed to
    protocol.initDynamics as initVelocities for both stages.

    rigidRegions specifies selections of atoms which do not move relative to
    each other.

    fixedRegions specifies selections of atoms which do not move at all.

    There is no return value.
    """
    pots = flattenPotList(potList)

    #first get the tensor atoms in reasonable shape-
    # averaging will have scrambled them
    from varTensorTools import getVarTensors, calcTensor
    varTensors = getVarTensors(pots)

    for t in varTensors:
        calcTensor(t)


    # refine here
    #  required XPLOR terms [xplorPots]- if they exist- use them as is.
    # if they don't exist, add them in with default scale values.
    #
    # for rest of terms, ramp the scale using MultRamp, running from
    # scaleMultiplier .. 1 times the nominal scale values.
    from potList import PotList
    minPots = PotList()
    hasReqdTerms = {}
    for p in xplorPots: hasReqdTerms[p] = 0
    rampedParams = []
    for pot in pots:
        if potType(pot) == 'XplorPot' and pot.instanceName() in xplorPots:
            # required term: used as is, and added only once even if its
            # name is listed more than once in xplorPots
            minPots.append(pot)
            hasReqdTerms[pot.instanceName()] = 1
            continue
        minPots.append( pot )
        # NOTE(review): the action string refers to the local minPots by
        # name. How MultRamp resolves names in a string action is not visible
        # here, so the name and the place of construction are left unchanged.
        rampedParams.append( MultRamp( scaleMultiplier*pot.scale(),
                                       pot.scale(),
                                       "minPots['%s'].setScale(VALUE)"%
                                       pot.instanceName() ) )
    from xplorPot import XplorPot
    for pType in xplorPots:
        if not hasReqdTerms[pType]:
            minPots.append( XplorPot(pType) )
            hasReqdTerms[pType] = 1 # so that it is not added a second time

    #Cartesian topology
    from ivm import IVM
    minc = IVM()
    for aSel in rigidRegions: minc.group(aSel)
    for aSel in fixedRegions: minc.fix(aSel)
    import protocol
    protocol.cartesianTopology(minc)

    #high-temp dynamics
    protocol.initDynamics(minc,
                          potList=minPots,
                          bathTemp=initTemp,
                          initVelocities=initVel,
                          finalTime=htFinalTime,# stops at htFinalTime or
                          numSteps=htNumSteps,  # htNumSteps, whichever first
                          printInterval=100)

    # ramped scales are set to their initial values before dynamics is run
    InitialParams(rampedParams)
    minc.run()

    #simulated annealing
    protocol.initDynamics(minc,
                          bathTemp=initTemp,
                          finalTime=finalTime,
                          numSteps=numSteps,
                          initVelocities=initVel,
                          potList=minPots)


    import varTensorTools
    for m in varTensors:
        m.setFreedom("varyDa, varyRh")       #allow all tensor parameters float
        varTensorTools.topologySetup(minc,m) #setup tensor topology

    if refineSteps>0:
        from simulationTools import AnnealIVM
        AnnealIVM(initTemp, finalTemp=0,
                  numSteps=refineSteps,
                  ivm=minc,
                  rampedParams = rampedParams).run()

    # make sure final tensor is consistent with structure
    for t in varTensors:
        calcTensor(t)

    return
