#
# PROTOCOL FOR CONJOINED RIGID BODY/TORSION ANGLE DYNAMICS FOR DOCKING
# PROTEIN-PROTEIN COMPLEXES BASED ON HIGHLY AMBIGUOUS DISTANCE RESTRAINTS DERIVED
# FROM CHEMICAL SHIFT MAPPING AND DIPOLAR COUPLINGS.
#
# Clore, G.M. & Schwieters, C.D. (2003) Docking of protein-protein complexes
# on the basis of highly ambiguous distance restraints derived from 1HN/15N chemical shift
# mapping and backbone 15N-1H residual dipolar couplings using conjoined
# rigid body/torsion angle dynamics.  J. Am. Chem. Soc. 125, 2902-2912
#

# Interpreter setup and imports for the docking protocol.
import sys
# Send error output to the same stream as normal output.
sys.stderr = sys.stdout
from os import environ as env
from selectTools import groupRigidSideChains
from selectLogic import select as selLogic

# NOTE(review): `xplor` is never imported in this file; it is presumably
# made available by the Xplor-NIH Python interpreter itself -- confirm.
xplor.parseArguments() # check for typos on the command-line

# math.pow rebinds the builtin pow() for the rest of this script.
from math import pow
from ivm import IVM
from xplor import command
from xplor import select
from xplorPot import XplorPot


# Convenience aliases. `write`, `simulation` and `sim` are not referenced in
# the rest of the visible script, which uses xplor.simulation directly.
write = sys.stdout.write
import simulation
sim = xplor.simulation


#!----------------------------------------------------------------------
#! read in the PSF file and initial structure

# Read the parameter sets and the topologies of IIA, HPr and the axis
# residue (resid 500); messages/echo are switched off around the reads.
_quiet = "set mess=off echo=off end"
for _cmd in (_quiet,
             "param @TOPPAR:parallhdg_new.pro @TOPPAR:par_axis_3.pro end",
             _quiet,
             "structure @IIA.psf @hpr_cmplx.psf @TOPPAR:axis_500.psf end"):
    command(_cmd)

# Remove terminal oxygens/hydrogens and a fixed list of side-chain protons.
_deleteCmd = """
delete select
     (name ot* or name ht*
or   ((resid 105 or resid 112 or resid 165) and name he)
or   ((resid 57 or resid 74 or resid 142) and name hd2#)
or   (resid 75 and name hd1) or (resid 90 and name he2)
or   (resid 315 and name hd1) or (resid 376 and name hd1)
     )
end
"""
command(_deleteCmd)

# Load starting coordinates, save them into the comparison set, and copy
# them into the reference coordinates as well.
for _pdb in ("IIA_altB.pdb", "hpr_1.pdb", "TOPPAR:axis_xyzo_3.pdb"):
    command("coor @" + _pdb)
command("coor copy end")

for _axis in "xyz":
    command("vector do (ref%s=%s) (all)" % (_axis, _axis))



#!----------------------------------------------------------------------
#! set the weights for the experimental energy terms
    
# Initial weights and set-up of the experimental / knowledge-based terms.
# NOTE(review): knoe, kramascale and kcoll are re-initialised per structure in
# the structure loop further down; asym and kcdi are not referenced anywhere
# else in this script.
knoe  = 0.01             # noes force constant
asym  = 0.1              # slope of NOE potential
kcdi  = 10.0             # torsion angles
kramascale = 0.002       # rama
kcoll=0.01               # collapse force constant

command (""" eval ($krama =1.0)""")  # This is read into the forces file

# Collapse term over IIA (resid 19:168) and HPr (resid 301:385); the two
# assign values are presumably force constant and target radius -- confirm.
command (""" collapse
  assign (resid 19:168 or resid 301:385 ) 1.0 17.5
  scale %f
end""" %kcoll)


# Torsion-angle database (QUARTS) terms; the forces table reads $krama.
command ("""
rama
nres=10000
set message off echo off end
@QUARTS:2D_quarts_new.tbl
@QUARTS:3D_quarts_new.tbl
@QUARTS:forces_torsion_prot_quarts_intra.tbl  ! this is where krama is needed
scale %f
end
set message off echo off end
@QUARTS:setup_quarts_torsions_intra_2D3D.tbl
""" % kramascale)



#! Read experimental restraints

# Ambiguous distance restraints, read into a single class "all".
command("""
noe
  reset
  nres = 30000             
  ceiling=100.
  class all 
  @shifts_noe_newx.tbl
end""")

# Square-well potential with summed averaging, scaled by knoe.
command("""
noe
  ceiling 1000
  averaging  all sum
  potential  all square   !soft
  scale      all %f
  sqconstant all 1.0
  sqexponent all 2
end""" % knoe)

ksani = 0.001

# Residual dipolar couplings: class JNHe holds the IIA couplings and class
# JNHh the HPr couplings, with identical force and coefficients.
# NOTE(review): the *_free.tbl sets are read into the same classes as the
# *_work.tbl sets, so they contribute to the restraint energy as well --
# confirm this is intended if they are meant as cross-validation data.
command("""
  sani
  nres=4000
  class JNHe
  force %f
  potential harmonic
  coeff 0.0 -14.9 0.20
  @iia_dipo_work.tbl
  @iia_dipo_free.tbl

 class JNHh
  force %f
  potential harmonic
  coeff 0.0 -14.9 0.20
  @hpr_dipo_norm_work.tbl
  @hpr_dipo_norm_free.tbl
end""" % (ksani,ksani))

#
# Dipolar-coupling agreement of the starting structures in their X-ray
# orientation.  For each class keep the rms, the violations and an R value
# (rms scaled by 100/19.13).
_saniStats = {}
for _cls in ("JNHe", "JNHh"):
    _rms, _viol = command("sani print threshold=0.0 class %s end" % _cls,
                          ("result", "violations"))
    _saniStats[_cls] = (_rms, _viol, float(_rms) * 100 / 19.13)

rms_sani_JNHe, viol_sani_JNHe, R_sani_JNHe = _saniStats["JNHe"]
rms_sani_JNHh, viol_sani_JNHh, R_sani_JNHh = _saniStats["JNHh"]



rcon = 0.003
cool_steps = 24000
init_t = 1500.01

# Atomic masses: protein atoms get 30, while the orientational-tensor
# parameter atoms (resid 500) get 3750 so that their moment of inertia is
# balanced against that of the real protein atoms.
for selection, mass in (("not resid 500", 30.), ("resid 500", 3750)):
    for atomIndex in select(selection):
        xplor.simulation.setAtomMass(atomIndex, mass)

# Uniform friction coefficient on every atom.
for atomIndex in select("all"):
    xplor.simulation.setAtomFric(atomIndex, 100.)

#
#
#{* Generate Structures 1 -> x where x is set by totStructs *}
#


# Total number of structures to generate across all processes.
totStructs = 1

# This process's index and the number of parallel Xplor-NIH processes.
# Default to a single serial process when the variables are not set.
proc = int(env.get("XPLOR_PROCESS", 0))
numProcs = int(env.get("XPLOR_NUM_PROCESSES", 1))

# Structures [start, stop) belong to this process.  Floor division keeps the
# bounds integral (valid range() arguments) regardless of the division
# semantics in effect; for these non-negative ints it equals Python 2 '/'.
start = (proc * totStructs) // numProcs
stop = ((proc + 1) * totStructs) // numProcs


import traceback


def _set_nonbonded(rcon, repel):
    """Set the repulsive vdW weight (rcon) and radius scale factor (repel).

    cutnb=4.5 and nbxmod=4 are re-stated with the values this protocol uses
    throughout; only rcon and repel change during annealing.
    """
    command("parameter nbonds cutnb=4.5 rcon=%f nbxmod=4 repel=%f end end"
            % (rcon, repel))


def _set_sani_force(force):
    """Apply the same force constant to both SANI classes (JNHe, JNHh)."""
    for saniClass in ("JNHe", "JNHh"):
        command("sani class %s force %f end" % (saniClass, force))


def _run_segment(integrator):
    """Run one dynamics segment, reporting rather than propagating failures.

    A failed segment is logged with its traceback and the annealing schedule
    continues.  Only Exception subclasses are caught, so KeyboardInterrupt
    and SystemExit still stop the script.
    """
    try:
        integrator.run()
    except Exception:
        print('during ivm.run() : caught exception: ')
        traceback.print_exc()


def _sani_report(saniClass):
    """Return (rms, violations, R) for one SANI class, R = rms*100/19.13."""
    rms, viol = command("sani print threshold=0.0 class %s end" % saniClass,
                        ("result", "violations"))
    return rms, viol, float(rms) * 100 / 19.13


for count in range(start, stop):

    command("set seed %d end" % count)

    # Restore the saved (comp) coordinates and draw fresh velocities from a
    # Maxwell distribution at init_t.
    for axis in ("x", "y", "z"):
        command("vector do (%s=%scomp) (all)" % (axis, axis))
        command("vector do (v%s = maxwell(%f)) (all)" % (axis, init_t))

    # Schedule end points shared by both annealing stages.
    ini_rad = 0.9       # vdW radius scale factor (repel)
    fin_rad = 0.75
    ini_con = 0.004     # vdW weight (rcon)
    fin_con = 1.0
    ini_noe = 0.01
    fin_noe = 30.0
    ini_rama = 0.002
    fin_rama = 1.0
    ini_coll = 0.01
    fin_coll = 100.0
    ini_sani = 0.001
    fin_sani = 0.01

    # Start every structure from the same NOE/SANI weights the first one
    # sees (set during the energy-term setup); otherwise later structures
    # would inherit the previous structure's final weights during the
    # rigid-body minimization below.
    command("noe scale all %f end" % ini_noe)
    _set_sani_force(ini_sani)

    command("""
    parameters
    nbonds
    atom
    nbxmod 4
    wmin  =   0.01  ! warning off
    cutnb =   4.5   ! nonbonded cutoff
    tolerance 0.5
    repel=    0.8   ! scale factor for vdW radii = 1 ( L-J radii)
    rexp   =  2     ! exponents in (r^irex - R0^irex)^rexp
    irex   =  2
    rcon =4     ! actually set the vdW weight
    end
    end""")

    command("energy end")
    (rms_noe, violations_noe) = command("print threshold=0. noe",
                                        ("result", "violations"))

    # ---- Stage 1: rigid-body minimization of HPr against a fixed IIA ----
    command("constraints inter (resid 19:168) (resid 301:385) end")

    m = IVM(xplor.simulation)        # ``m'' for minimization
    m.setStepType('powell')
    m.setVerbose(m.verbose() | m.printNodeDef)
    m.setNumSteps(1000)
    m.setDEpred(1)
    m.setETolerance(1e-7)
    m.setPrintInterval(1)
    m.fix(select('resid 19:168'))
    m.fix(select('resid 500'))
    groupList = m.groupList()
    groupList.append(select('resid 301:385'))
    m.setGroupList(groupList)

    m.potList().removeAll()
    m.potList().add(XplorPot("SANI"))
    m.run()    # dipolar couplings only

    m.potList().removeAll()
    m.potList().add(XplorPot("NOE"))
    m.potList().add(XplorPot("VDW"))
    m.run()    # NOEs and van der Waals only

    # Regroup for later use: HPr, IIA (fixed as well as grouped) and the
    # axis residue, which may only rotate.
    m.setGroupList([])
    m.fix(select('resid 19:168'))
    groupList = m.groupList()
    groupList.append(select('resid 301:385'))
    groupList.append(select('resid 19:168'))
    groupList.append(select('resid 500'))
    m.setGroupList(groupList)
    m.setHingeList([('rotate', select('resid 500')),])
    m.setETolerance(1e-8)
    m.setDEpred(1e-4)

    # ---- Stage 2: rigid-body annealing (NOE + VDW only) ----
    ivm = IVM()
    ivm.setPrintInterval(50)
    ivm.fix(select('resid 19:168'))
    ivm.setGroupList([])
    groupList = ivm.groupList()
    groupList.append(select('resid 500'))
    groupList.append(select('resid 301:385'))
    groupList.append(select('resid 19:168'))  # both fixed and grouped
    ivm.setGroupList(groupList)
    ivm.setHingeList([('rotate', select('resid 500')),])
    ivm.setResponseTime(20)
    ivm.setResetCMInterval(30000)
    ivm.setStepType("PC6")
    ivm.setStepsize(1.5e-2)
    ivm.setAdjustStepsize(1)

    final_t = 500   # K
    tempstep = 25   # K
    ncycle = int(init_t - final_t) // tempstep
    # Rates use ncycle*1.0001 so that after exactly ncycle updates each
    # constant stops just short of its final value.
    ncycle2 = ncycle * 1.0001

    bath = init_t
    k_vdw = ini_con
    k_vdwfact = pow(fin_con / ini_con, 1 / ncycle2)
    radius = ini_rad
    radfact = pow(fin_rad / ini_rad, 1 / ncycle2)
    knoe = ini_noe
    noe_fac = pow(fin_noe / ini_noe, 1 / ncycle2)

    ivm.potList().removeAll()
    ivm.potList().add(XplorPot("NOE"))
    ivm.potList().add(XplorPot("VDW"))

    # The bath is initialised once above and lowered every cycle, from
    # init_t - tempstep down to final_t.
    for i_cool in range(ncycle):
        bath -= tempstep
        k_vdw = min(fin_con, k_vdw * k_vdwfact)
        radius = max(fin_rad, radius * radfact)
        knoe *= noe_fac

        _set_nonbonded(k_vdw, radius)
        command("noe scale all %f end" % knoe)  # only NOEs and vdw here

        ivm.setBathTemp(bath)
        ivm.setETolerance(bath / 1000)
        ivm.setAdjustStepsize(1)
        ivm.setFinalTime(60)
        ivm.setStepType("PC6")
        _run_segment(ivm)

    # ---- Stage 3: annealing with interfacial side chains free ----
    bbStr = "(name N or name C or name CA or name O or name H or name HA)"

    IIASel = select('resid 19:168')
    IIASideChainSel = select("not " + bbStr + " and " +
        """(resid 38:42 or resid 45:48
        or resid 69 or resid 71:72 or resid 79
        or resid 80 or resid 86:88 or resid 90
        or resid 94 or resid 96:97 or resid 99
        or resid 109 or resid 141 or resid 143:144)""")

    hprSel = select('resid 301:385')
    hprSideChainSel = select("not " + bbStr + " and " +
        """(resid 312 or resid 315:317 or resid 321:322
        or resid 327 or resid 345:353
        or resid 356)""")
    print("hpr sc: " + repr(len(hprSideChainSel)))

    # selLogic resolves the names in the expression through vars().
    fullGroupList = list(groupRigidSideChains(
        selLogic("IIASideChainSel or hprSideChainSel", vars())))
    fullGroupList.append(selLogic("hprSel and not hprSideChainSel", vars()))
    fullGroupList.append(select('resid 500'))

    command("constraints inter (not resid 500) (not resid 500) end")
    ivm.reset()
    fullGroupList.append(selLogic("IIASel and not IIASideChainSel", vars()))
    ivm.setGroupList(fullGroupList)
    ivm.setHingeList([('rotate', select('resid 500')),])
    ivm.autoTorsion()

    sc_ini_con = 0.1
    sc_ini_rad = 0.78
    final_t = 25    # K
    tempstep = 25   # K
    ncycle = int(init_t - final_t) // tempstep
    ncycle2 = ncycle * 1.0001
    nstep = int(cool_steps * 4.0 / ncycle)

    bath = init_t
    k_vdw = sc_ini_con
    radius = sc_ini_rad
    # NOTE(review): knoe restarts at 1.0 but noe_fac is derived from
    # ini_noe (0.01) -> fin_noe (30), so after ncycle updates knoe reaches
    # roughly 1.0 * 3000 rather than fin_noe.  Kept as in the original
    # protocol -- confirm whether knoe should start at ini_noe.
    knoe = 1.0
    ksani = ini_sani
    kcoll = ini_coll
    kramascale = ini_rama

    k_vdwfact = pow(fin_con / sc_ini_con, 1 / ncycle2)
    radfact = pow(fin_rad / sc_ini_rad, 1 / ncycle2)
    noe_fac = pow(fin_noe / ini_noe, 1 / ncycle2)
    rama_fac = pow(fin_rama / ini_rama, 1 / ncycle2)
    coll_fac = pow(fin_coll / ini_coll, 1 / ncycle2)
    sani_fac = pow(fin_sani / ini_sani, 1 / ncycle2)

    _set_nonbonded(k_vdw, radius)

    ivm.potList().removeAll()
    ivm.potList().add(XplorPot("NOE"))
    ivm.potList().add(XplorPot("SANI"))
    ivm.potList().add(XplorPot("VDW"))
    ivm.potList().add(XplorPot("RAMA"))
    ivm.potList().add(XplorPot("COLL"))

    command("noe scale all %f end" % knoe)
    command("rama scale %f end" % kramascale)
    command("collapse scale %f end" % kcoll)
    _set_sani_force(ksani)

    ivm.setNumSteps(1000)
    ivm.setAdjustStepsize(1)
    ivm.setStepType("PC6")
    ivm.setFinalTime(nstep * 0.002)
    ivm.setBathTemp(bath)
    ivm.setETolerance(bath / 1000)
    ivm.run()                         # initial run at init_t
    print(ivm.info())

    for i_cool in range(ncycle):
        bath -= tempstep
        k_vdw = min(fin_con, k_vdw * k_vdwfact)
        radius = max(fin_rad, radius * radfact)
        knoe *= noe_fac
        kcoll *= coll_fac
        kramascale *= rama_fac
        ksani *= sani_fac

        _set_nonbonded(k_vdw, radius)
        command("noe scale all %f end" % knoe)
        command("rama scale %f end" % kramascale)
        command("collapse scale %f end" % kcoll)
        _set_sani_force(ksani)

        ivm.setBathTemp(bath)
        ivm.setETolerance(bath / 1000)
        ivm.setAdjustStepsize(1)
        ivm.setFinalTime(nstep * 0.002)
        ivm.setStepType("PC6")
        _run_segment(ivm)

    # ---- Final minimization with the interfacial side chains free ----
    m.reset()
    m.setGroupList(fullGroupList)
    m.setHingeList([('rotate', select('resid 500')),])
    m.autoTorsion()
    command("parameter nbonds rcon 3 repel 0.78 end end")
    m.setPotList(ivm.potList())       # same potentials as stage 3
    m.run()

    command("noe scale all 60.0 end")
    _set_sani_force(0.2)
    m.setNumSteps(20000)
    m.run()

    # ---- Analysis ----
    # Match the xplor energy flags to the potentials used above.
    command("flags exclude * include vdw noe sani coll rama end")
    command("ener end")
    (rms_noe, violations_noe) = command("print threshold=0.5 noe",
                                        ("result", "violations"))
    rms_bonds = command("print thres=0.05 bonds", "result")
    rms_angles = command("print thres=5. angles", "result")
    rms_impropers = command("print thres=5. impropers", "result")

    rms_sani_JNHe, viol_sani_JNHe, R_sani_JNHe = _sani_report("JNHe")
    rms_sani_JNHh, viol_sani_JNHh, R_sani_JNHh = _sani_report("JNHh")

    # Weighted overall R (weights 118 and 75, total 193); presumably the
    # number of couplings per protein -- confirm.
    R_sani_JNH = (float(rms_sani_JNHe)*100*118/19.13 +
                  float(rms_sani_JNHh)*100*75/19.13)/193

    command("""
    remarks ===============================================================
    remarks repel=0.78
    remarks      overall vdw,noe, sani, coll, rama
    remarks energies: $ener $vdw, $noe, $sani, $coll, $rama
    remarks ===============================================================
    remarks            bonds,angles,impropers,noe
    remarks  bonds etc: """ +
            repr(rms_bonds) + ' ' + repr(rms_angles) + ' ' +
            repr(rms_impropers) + ' ' + repr(rms_noe) + """
    remarks ===============================================================
    remarks                noe
    remarks violations :   """ + repr(violations_noe) + """
    remarks ===============================================================
    remarks  Rms sani iia: """ + repr(rms_sani_JNHe) + """
    remarks  R factor sani iia: """ + repr(R_sani_JNHe) + """
    remarks  viol sani iia: """ + repr(viol_sani_JNHe) + """
    remarks  Rms sani hpr: """ + repr(rms_sani_JNHh) + """
    remarks  R factor sani hpr: """ + repr(R_sani_JNHh) + """
    remarks  viol sani hpr: """ + repr(viol_sani_JNHh) + """
    remarks  overall R sani: """ + repr(R_sani_JNH) + """
    remarks ===============================================================""")

    command("write coor output=test1yy_%d.min end" % count)