# SIMULATED ANNEALING REFINE PROTOCOL
#===============================================================================
# to run type: xplor -py -o refine.out refine_REPEL.py
#			   xplor -py -smp 10 -o refine.out refine_REPEL.py
#===============================================================================

xplor.requireVersion('2.40')

import protocol

# Seed the random-number generator with a fixed value (3421).
protocol.initRandomSeed(3421)

# How many structures to compute, and the base name for the output PDB files.
Nstructures = 100
outFilename = "SCRIPT_STRUCTURE.sa"

# Well-ordered residue ranges; used for the CA fit in the ensemble statistics.
orderedRegion="""resid 3:15 or resid 27:36 or resid 41:51 or resid 68:80 or
                 resid 85:98 or resid 110:126 or resid 129:139 or resid 144:155"""

# True: use the EEFx/IMMx implicit-solvent parameter and topology files.
# False: keep the default parameters for the REPEL potential.
useEEFx = False

# Initialize parameters, topology and starting model.
if useEEFx:
    protocol.parameters['protein'] = "eefx/protein_eef.par"
    protocol.topology['protein'] = "eefx/protein_eef22.top"

# Glob pattern for the input structures (also passed to StructureLoop below).
inputStructures = "fold_REPEL_*.sa.best"

from glob import glob
# glob() returns matches in arbitrary (filesystem) order. Sort them so the
# structure used to initialize the energy terms is the same on every run and
# in every -smp process. Fail with a clear message instead of a bare
# IndexError when nothing matches.
_inputMatches = sorted(glob(inputStructures))
if not _inputMatches:
    raise RuntimeError("no input structures match pattern %r" % inputStructures)
ini_model = _inputMatches[0]  # structure used in initializing energy terms

protocol.loadPDB(ini_model, deleteUnknownAtoms=True)

#this next line may significantly change the structure
#protocol.fixupCovalentGeom(maxIters=100,useVDW=1)


# SET UP POTLISTS, PARAMETERS AND FORCE CONSTANTS TO RAMP DURING STRUCTURE CALCULATIONS.
#===============================================================================
from simulationTools import InitialParams
from simulationTools import StaticRamp, MultRamp

pots = PotList()    # every energy term used in the calculation is appended here

# Lists of parameter-setting actions, each consumed at a different stage.
highTempParams = []   # high-temperature MD stage
rampedParams = []     # annealing (cooling) stage
tensorInit = []       # initial alignment-tensor settings
tensorFinal = []      # final alignment-tensor settings


# SET UP ALIGNMENT TENSOR AND TENSOR CALCULATION DURING SIMULATED ANNEALING.
# rdcPotTools uses NH bond distance r(NH)=1.042A (Da=10.76; DaMAX = 21523.28).
# To calculate tensor orientation, Da and Rh:   calcTensor(oTensor).
# To calculate tensor orientation only:         calcTensorOrientation(oTensor).
# To analyze pot terms with given setFreedom:   VarTensor_analyze(pots).
# To fix all tensor parameters use:             oTensor.setFreedom("fix").
#===============================================================================
# Tensor freedom applied after the initial orientation is set (last tensorInit
# action below): axis fixed, Da allowed to vary, Rh fixed.
memfreedom="fixAxis, varyDa, fixRh"

if True:  # always enabled; "mem" is referenced by the RDC/CSA blocks below
    from varTensorTools import create_VarTensor, alignTensorToZ
    from varTensorTools import calcTensorOrientation
    from simulationTools import analyze

    mem = create_VarTensor("mem")   # This is the main membrane order tensor
    mem.setDa(10.76)                # based on rNH=1.042.
    mem.setRh(0.00)                 # axially symmetric bilayer.

    # Per-structure tensor initialization, applied in order by
    # InitialParams(tensorInit): estimate the orientation, align it to Z,
    # print a term analysis, then apply memfreedom.
    # NOTE: "print analyze(pots)" is Python 2 print-statement syntax; the
    # string is presumably executed when the StaticRamp action runs.
    tensorInit.append(StaticRamp("calcTensorOrientation(mem)"))
    tensorInit.append(StaticRamp("alignTensorToZ(mem)"))
    tensorInit.append(StaticRamp("print analyze(pots)"))
    tensorInit.append(StaticRamp("mem.setFreedom(memfreedom)"))
    # Also print the analysis during annealing and at the end of each structure.
    rampedParams.append(StaticRamp("print analyze(pots)"))
    tensorFinal.append(StaticRamp("print analyze(pots)"))



# EXPERIMENTAL AND STATISTICAL RESTRAINING TERMS.
#===============================================================================
# RDC restraints -- disabled. Before enabling, define rdcNH1_data (the RDC
# table file); it is not assigned anywhere in this script.
# NOTE(review): the ramp holds the "rdc" PotList scale at 0 throughout
# annealing, so even when enabled these terms presumably contribute no force
# (analysis only) -- confirm this is intended.
if False:
    from rdcPotTools import create_RDCPot
    rdcs = PotList('rdc')
    for (name, file, tensor, scale) in [('rdcNH_F', rdcNH1_data, mem, 1)]:
        rdc=create_RDCPot(name=name, file=file, oTensor=tensor)
        rdc.setScale(scale)
        rdc.setThreshold(1.0)     # dflt [0.0]
        rdc.setShowAllRestraints(1)
        rdcs.append(rdc)
    pots.append(rdcs)
    rampedParams.append(MultRamp(0, 0, "rdcs.setScale(VALUE)"))

# CSA restraints -- disabled. Before enabling, define csaN1_data (the CSA table
# file); it is not assigned anywhere in this script.
if False:
    # Import here: the only other import of create_CSAPot is in a later
    # (also disabled) block, so enabling this block alone would raise NameError.
    from csaPotTools import create_CSAPot
    csas = PotList('csa')
    for (name, file, tensor, scale) in [('csaN_F', csaN1_data,  mem, 1)]:
        csa=create_CSAPot(name=name, file=file, oTensor=tensor)
        csa.setScale(scale)
        csa.setThreshold(1.0)     # dflt [0.0]
        csa.setShowAllRestraints(1)
        csa.setVerbose(True)
        csa.setDaScale(-21523.28)       # based on rNH=1.042 ***NOTE SIGN.
        csas.append(csa)
    pots.append(csas)
    rampedParams.append(MultRamp(0, 0, "csas.setScale(VALUE)"))

# Second RDC data set ("cross" terms) -- disabled. Before enabling, define
# rdcNH2_data (the RDC table file); it is not assigned anywhere in this script.
if False:
    # Import locally so this block works on its own; otherwise create_RDCPot
    # is only imported inside the disabled RDC block above.
    from rdcPotTools import create_RDCPot
    rdcsCross = PotList('rdcCross')
    for (name, file, tensor, scale) in [('rdcNH_V', rdcNH2_data, mem, 1)]:
        rdcCross=create_RDCPot(name=name, file=file, oTensor=tensor)
        rdcCross.setScale(scale)
        rdcCross.setShowAllRestraints(1)
        rdcsCross.append(rdcCross)
    pots.append(rdcsCross)
    rampedParams.append(MultRamp(0, 0, "rdcsCross.setScale(VALUE)"))

# Second CSA data set ("cross" terms) -- disabled. Before enabling, define
# csaN2_data (the CSA table file); it is not assigned anywhere in this script.
# As with the other optional restraint blocks, the scale ramp is held at 0.
if False:
    from csaPotTools import create_CSAPot
    csasCross = PotList('csaCross')
    for (name, file, tensor, scale) in [('csaN_V', csaN2_data,  mem, 1)]:
        csaCross=create_CSAPot(name=name, file=file, oTensor=tensor)
        csaCross.setScale(scale)
        csaCross.setShowAllRestraints(1)
        csaCross.setVerbose(True)
        csaCross.setDaScale(-21523.28)       # based on rNH=1.042 ***NOTE SIGN.
        csasCross.append(csaCross)
    pots.append(csasCross)
    rampedParams.append(MultRamp(0, 0, "csasCross.setScale(VALUE)"))

if True:
    from noePotTools import create_NOEPot
    # Distance restraints, grouped under the PotList name 'dst'.
    dsts = PotList('dst')
    # (term name, restraint table, per-term scale)
    distanceTables = [('noe', "noe_Ail.tbl", 1),
                      ('hbn', "hbn_Ail.tbl", 1)]
    for termName, tblFile, termScale in distanceTables:
        dst = create_NOEPot(termName, tblFile)
        dst.setScale(termScale)
        dst.setThreshold(0.5)     # dflt [0.5]
        dsts.append(dst)
    pots.append(dsts)
    # The whole distance-restraint group is ramped from 2 to 30 during annealing.
    rampedParams.append(MultRamp(2, 30, "dsts.setScale(VALUE)"))

from xplorPot import XplorPot

# Dihedral-angle restraints (XPLOR CDIH term), read from cdi_data.
cdi_data="cdi_Ail.tbl"
protocol.initDihedrals(cdi_data)
pots.append(XplorPot('CDIH'))
# Weak (10) during high-temperature dynamics, set to 200 for annealing.
highTempParams.append(StaticRamp("pots['CDIH'].setScale(10)"))
rampedParams.append(StaticRamp("pots['CDIH'].setScale(200)"))

# CONFORMATIONAL ENERGY TERMS.
#===============================================================================
from torsionDBPotTools import create_TorsionDBPot
# Statistical torsion-angle database potential; its scale is ramped 0.002 -> 2
# over the annealing stage.
torsionDB = create_TorsionDBPot(name='torsionDB')
pots.append(torsionDB)
rampedParams.append(MultRamp(0.002, 2, "torsionDB.setScale(VALUE)"))

# NONBONDED TERM FOR STANDARD REPEL CALCULATIONS.
# Dflt settings:
# [cutnb=4.5, rcon=4.0, nbxmod=3, selStr='all', tol=0.5, repel=0.8, onlyCA=0, simulation=0]
# rcon = energy constant for REPEL function.
# repel = factor by which to multiply the vdw radius.
#
# nbxmod=5: Use with no torsionDB. Exclude 1-2, 1-3 and compute 1-4 interactions.
# nbxmod=4: Use with torsionDB. Exclude 1-2, 1-3 and 1-4 interactions.
#===============================================================================
# XPLOR VDW term, evaluated with the REPEL settings configured below.
pots.append(XplorPot('VDW'))

# High-temperature stage: long cutoff (cutnb=100), enlarged radii (repel=1.2),
# very weak force constant (rcon=0.004), nbxmod=4 (torsionDB is in use) and
# onlyCA=1 -- presumably restricting the repulsion to CA atoms; confirm
# against the protocol.initNBond documentation.
highTempParams.append(StaticRamp("""protocol.initNBond(cutnb=100,
                                                     repel=1.2,
                                                     rcon=0.004,
                                                     nbxmod=4,
                                                     tolerance=45,
                                                     onlyCA=1)"""))

# Annealing stage: re-initialize with defaults apart from nbxmod=4, then ramp
# the radius factor 0.9 -> 0.8 and the REPEL force constant 0.004 -> 4.0.
rampedParams.append(StaticRamp("protocol.initNBond(nbxmod=4)"))
rampedParams.append(MultRamp(0.9,  0.8,"xplor.command('param nbonds repel VALUE end end')"))
rampedParams.append(MultRamp(0.004, 4.0,"xplor.command('param nbonds rcon VALUE end end')"))

# Covalent-geometry terms, each at the default scale [1] when created.
for covalentTerm in ('BOND', 'ANGL', 'IMPR'):
    pots.append(XplorPot(covalentTerm))

# Ramp each term from its starting scale up to 1 during annealing.
for covalentTerm, startScale in (('BOND', 0.4),
                                 ('ANGL', 0.4),
                                 ('IMPR', 0.1)):
    rampedParams.append(MultRamp(startScale, 1,
                                 "pots['%s'].setScale(VALUE)" % covalentTerm))

# THRESHOLD VALUES FOR VIOLATION ANALYSIS OF THE POTLIST TERMS.
# Use default values unless specified (cdih [5.0], noe [0.5], rdc [0.5]).
#===============================================================================
pots['BOND'].setThreshold(0.05)     # dflt [0.05]
pots['ANGL'].setThreshold(5.00)     # dflt [2.0]
pots['IMPR'].setThreshold(5.00)     # dflt [2.0]
pots['CDIH'].setThreshold(5.00)     # dflt [5.0]
# NOTE(review): the individual 'noe'/'hbn' terms were given threshold 0.5 when
# they were created, and the noe default is listed as 0.5. This sets 5.0 on
# the 'dst' PotList instead. If PotList.setThreshold propagates to its member
# terms, this loosens the distance violation cutoff tenfold -- confirm which
# value is intended.
dsts.setThreshold(5.00)             # dflt [5.0]

# NONBONDED ENERGY TERM - EEFx/IMMx IMPLICIT SOLVATION.
# Membrane hydrophobic thickness: DMPC (25.4A), DPPC (28.6A), POPC (27.0A), DOPC (29.6A)
# See Marsh, Handbook of lipid bilayers, 2nd Ed. p. 379.
# IMMn<10 generates a steeper polar-apolar membrane transition gradient.
# IMMa>0.85 increases dielectric screening (reduces electrostatic effects) in membrane.
# EEFx/IMMx uses nbxmod=4 (exclude 1-2, 1-3 and 1-4 interactions).
#===============================================================================
#from eefxPotTools import create_EEFxPot, param_LK
#eefx=create_EEFxPot("eefx","All",
#                    paramSet=param_LK,
#                    verbose=False)
#eefx.setScale(1)
#eefx.setVerbose(1)
#eefx.setIMMx(1)
#eefx.useGROUp(1)
#eefx.setMoveTol(0.5)
#print eefx.showParam()
#eefx.setThickness(25) # IMMx membrane thickness [25.4 DMPC; 28.6 DPPC; 27.0 POPC; 29.6 DOPC].
#eefx.setProfileN(10)  # IMMx n parameter of membrane profile (use n<3 in early stages).
#eefx.setA(0.85)       # IMMx a value that scales dielectric screening. Dflt=[0.85].
#pots.append(eefx)

from eefxPotTools import setCenter, setCenterXY

# Translate the CA center of mass of residues 1-156 so it sits at Z = Zpos
# relative to the IMMx membrane center.
IMM_com = "resid 1:156 and (name CA)"   # center-of-mass selection for IMMx positioning
Zpos = 0                                # Z offset from the IMMx membrane center; dflt [0]
setCenter(IMM_com, Zpos)


# SET UP IVM OBJECTS (dyn, minc) THAT PERFORM DYNAMICS AND MINIMIZATION.
# IVM (internal variable module) is used to perform dynamics and minimization in
# both torsion-angle and Cartesian space. Bonds, angles and many impropers cannot
# change with internal torsion-angle dynamics.
#===============================================================================
# Annealing temperature range: start hot, cool to the final temperature.
ini_temp = 3000.0
fin_temp = 25.0

# Give atoms uniform weights except for the tensor axes.
protocol.massSetup()

from ivm import IVM
from simulationTools import AnnealIVM

# Torsion-angle dynamics / minimization object.
dyn = IVM()
dyn.reset()                             # start from a clean IVM topology
protocol.torsionTopology(dyn)

# Cartesian minimization object.
minc = IVM()
protocol.cartesianTopology(minc)

# Cooling loop: runs dyn while stepping the temperature from ini_temp down to
# fin_temp in steps of 12.5, updating the rampedParams at each step.
cool = AnnealIVM(initTemp=ini_temp,
                 finalTemp=fin_temp,
                 tempStep=12.5,
                 ivm=dyn,
                 rampedParams=rampedParams)


def accept(potList):
    """
    Return True if the current structure meets the acceptance criteria.

    potList -- PotList holding the terms checked below, looked up by name:
               'dst' (distance restraints), 'CDIH', 'BOND', 'ANGL', 'IMPR'.

    Criteria: distance-restraint rms <= 0.08, CDIH rms <= 1.75, no BOND or
    ANGL violations, and at most one IMPR violation.
    """
    # Terms are looked up by their registered PotList name. The distance
    # restraints are created as PotList('dst'); 'dsts' is only the Python
    # variable name and is not a term name.
    if potList['dst'].rms() > 0.08:
    # Stricter alternative: reject on any violation,
    # potList['dst'].violations() > 0.
        return False
    if potList['CDIH'].rms() > 1.75:
    # Stricter alternative: potList['CDIH'].violations() > 0.
        return False
    if potList['BOND'].violations() > 0:
        return False
    if potList['ANGL'].violations() > 0:
        return False
    if potList['IMPR'].violations() > 1:
        return False
    return True


# CALCULATE STRUCTURE MODULE.
#===============================================================================
def calcOneStructure(loopInfo):
    """
    Calculate one refined structure (the StructureLoop per-structure action).

    Sequence: reset parameters -> initial Cartesian minimization ->
    high-temperature torsion-angle dynamics at ini_temp -> cooling loop from
    ini_temp to fin_temp -> torsion-angle minimization -> Cartesian
    minimization -> final tensor analysis and XY re-centering.

    loopInfo -- per-structure information supplied by StructureLoop; it is
                not used inside this function.

    No analysis or PDB writing happens here. StructureLoop is run with
    doWriteStructures=True and handles both after this function returns.
    """
    # Order matters: highTempParams is applied last so that its settings (the
    # CDIH scale and the high-temperature nonbonded setup) override the
    # matching rampedParams entries for the high-temperature stage.
    InitialParams(tensorInit)       # Initial tensor orientation/freedom for this
                                    # structure.
    InitialParams(rampedParams)     # Starting values of the annealing parameters.
    InitialParams(highTempParams)   # High-temperature overrides of some rampedParams.

    # Initial Cartesian minimization.
    #===========================================================================
    protocol.initMinimize(minc,
                          numSteps=500,         # dflt [500 steps]
                          potList=pots,
                          printInterval=50,
                          dEPred=10)
    minc.run()

    # High-temperature torsion-angle dynamics.
    #===========================================================================
    #setCenter(IMM_com, Zpos)            # translate selected center of mass to IMMx Zpos.
    protocol.initDynamics(dyn,
                          potList=pots,         # potential terms to use
                          bathTemp=ini_temp,    # set bath temperature.
                          initVelocities=1,     # uniform initial velocities.
                          finalTime=20,         # run for finalTime ps or
                          numSteps=20001,       # numSteps * 0.001, whichever is less
                          printInterval=100)
    dyn.setETolerance(ini_temp/100)             # used to det. stepsize, dflt [temp/1000]
    dyn.run()

    # Simulated annealing. Re-apply rampedParams (undoing the high-temperature
    # overrides), configure dyn -- which the cooling loop "cool" drives -- for
    # short runs, then run the cooling loop.
    #===========================================================================
    InitialParams(rampedParams)
    protocol.initDynamics(dyn,
                          potList=pots,
                          finalTime=0.2,        # run for finalTime ps or
                          numSteps=201,         # numSteps * 0.001, whichever is less
                          printInterval=100)
    cool.run()                                  # Run cooling loop.

    # Final minimization.
    #===========================================================================
    # Torsion angle minimization.
    protocol.initMinimize(dyn,
                          numSteps=500,         # dflt [500 steps]
                          potList=pots,
                          printInterval=50)
    dyn.run()

    # Cartesian all-atom minimization.
    protocol.initMinimize(minc,
                          numSteps=500,         # dflt [500 steps]
                          potList=pots,
                          printInterval=50,
                          dEPred=10)
    minc.run()

    # Final tensor actions (the tensorFinal list), then translate the
    # coordinates to the XY center (EEFx/IMMx membrane frame).
    #===========================================================================
    InitialParams(tensorFinal)
    #setCenter(IMM_com, Zpos)            # translate selected center of mass to IMMx Zpos.
    setCenterXY()           # translate protein coordinates to XY center.

    # Analysis and structure writing are done by StructureLoop after return.
    pass

# LOOP CONTROL.
#===========================================================================
from simulationTools import StructureLoop, FinalParams

# Run calcOneStructure once per structure, then analyze and write the results.
# numStructures follows the Nstructures setting at the top of the script
# instead of a separately hardcoded count.
StructureLoop(structLoopAction=calcOneStructure,
              pdbFilesIn=inputStructures,
              numStructures=Nstructures,
              pdbTemplate=outFilename,
              calcMissingStructs=True,
              doWriteStructures=True,           # analyze and write coords after calc
              genViolationStats=True,           # print stats file
              averageContext=FinalParams(rampedParams),
              # selection for backbone fit and rmsd [CA of the ordered region]
              averageFitSel="(%s) and name CA" % orderedRegion,
              averageCompSel="not name H*",     # selection for heavy atom rmsd
              averageTopFraction=0.2,           # Report stats on top 20%
              averagePotList=pots,              # Terms for stats and avg
              #averageCrossTerms=[rdcs,csas,rdcsCross,csasCross],    # Cross correlation terms
              averageSortPots=pots,             # Terms used to sort models
              ).run()
