#!/usr/bin/env python

import sys, getopt
import os.path
from time import perf_counter

from openff.toolkit.topology import Molecule
from openmmforcefields.generators import GAFFTemplateGenerator
from openmm.app import *
from openmm import *
from openmm.unit import *
from pdbfixer import PDBFixer
from openmmforcefields.generators import SystemGenerator

t_start = perf_counter()

protein_name = ''
ligand_name = ''
json_filename = ''
# The annealing parameters have no default: both must be given on the command
# line. They are validated right after parsing so that a missing value is
# reported before any expensive work (protein fixing, charge calculation)
# starts, instead of failing with a NameError in the annealing loop.
nwin = None
winsteps = None


def _usage():
    """Print the command-line synopsis and a description of every option."""
    script = os.path.basename(sys.argv[0]) or 'openmm-protein-ligand-complex-mmgbsa-sa.py'
    print(script + ' -l <name> -p <name> -j <name> -w <int> -n <int>')
    print('  -l, --ligand:   name of ligand file, sdf format, w/o extension')
    print('  -p, --protein:  name of protein file, pdb format, w/o extension')
    print('  -j, --json:     name of json file with ligand topology and charges')
    print('  -w, --nwin:     number of simulated annealing windows')
    print('  -n, --winsteps: MD steps per SA window')
    print('  -h, --help:     print this help and exit')


def _int_option(opt, arg):
    """Return the argument of integer option *opt* as a non-negative int.

    Prints an error and the usage text and exits with status 2 (the status
    also used for unknown options) when *arg* is not a non-negative integer.
    """
    try:
        value = int(arg)
    except ValueError:
        value = None
    if value is None or value < 0:
        print('# error: option', opt, 'expects a non-negative integer, got', repr(arg))
        _usage()
        sys.exit(2)
    return value


try:
    opts, args = getopt.getopt(
        sys.argv[1:],
        "hl:p:j:w:n:",
        ["help", "ligand=", "protein=", "json=", "winsteps=", "nwin="],
    )
except getopt.GetoptError as err:
    print('# error:', err)
    _usage()
    sys.exit(2)

for opt, arg in opts:
    if opt in ('-h', '--help'):
        _usage()
        sys.exit()
    elif opt in ('-l', '--ligand'):
        ligand_name = arg
    elif opt in ('-p', '--protein'):
        protein_name = arg
    elif opt in ('-j', '--json'):
        json_filename = arg
    # Short/long pairing kept as before: -w <-> --nwin, -n <-> --winsteps.
    elif opt in ('-w', '--nwin'):
        nwin = _int_option(opt, arg)
    elif opt in ('-n', '--winsteps'):
        winsteps = _int_option(opt, arg)

# Fail early (before PDBFixer / antechamber run) if a required option is absent.
missing = [
    flag
    for flag, given in (
        ('-l', bool(ligand_name)),
        ('-p', bool(protein_name)),
        ('-w', nwin is not None),
        ('-n', winsteps is not None),
    )
    if not given
]
if missing:
    print('# error: missing required option(s):', ' '.join(missing))
    _usage()
    sys.exit(2)

# Input files are the command-line base names plus their fixed extensions.
ligand_filename  = ligand_name + '.sdf'
protein_filename = protein_name + '.pdb'

# If the JSON file already exists it is searched for the ligand topology;
# otherwise (per the message below) charges are computed and the topology is
# written there. calc_charges records which of the two cases applies.
calc_charges = not os.path.isfile(json_filename)
if calc_charges:
    print('# ligand charges for ',ligand_filename,' will be calculated and topology saved in: ', json_filename)
else:
    print('# try to find topology for ',ligand_filename,' in ', json_filename)

for label, value in (
    ('# protein input file:   ', protein_filename),
    ('# ligand input file:    ', ligand_filename),
    ('# json file:            ', json_filename),
):
    print(label, value)

print('Reading',protein_filename ,'and Fixing protein')
sys.stdout.flush()

# Complete the input structure with PDBFixer: locate missing residues and
# atoms, add the missing atoms, then add hydrogens for pH 7.0.
fixer = PDBFixer(filename=protein_filename)
fixer.findMissingResidues()
fixer.findMissingAtoms()
fixer.addMissingAtoms()
fixer.addMissingHydrogens(7.0)

# Put the "fixed-" prefix on the file name only, so that a protein given with
# a directory (-p data/prot) is written next to its input
# (data/fixed-prot.pdb) instead of into a non-existent "fixed-data/" dir.
protein_dir, protein_base = os.path.split(protein_filename)
fixed_protein_filename = os.path.join(protein_dir, 'fixed-' + protein_base)
# Close (and thereby flush) the file before it is re-read just below.
with open(fixed_protein_filename, 'w') as fixed_pdb:
    PDBFile.writeFile(fixer.topology, fixer.positions, fixed_pdb, keepIds=True)

print('# Reading',fixed_protein_filename)
protein = PDBFile(fixed_protein_filename)

print('# load molecule', flush=True)

# Undefined stereocenters in the SDF are tolerated rather than rejected.
molecule = Molecule.from_file(ligand_filename, allow_undefined_stereo=True)

print('# generate and register topology')
print('# If this takes a while the topology for the molecule in ',ligand_filename,' has not been found in ',json_filename,' and the charges need to be calculated with antechamber/sqm', flush=True)

# GAFF 2.11 templates for the ligand, cached in json_filename, are plugged
# into an Amber14 protein force field with GBn2 implicit solvent.
gaff = GAFFTemplateGenerator(cache=json_filename, molecules=molecule, forcefield='gaff-2.11')
forcefield = ForceField('amber14-all.xml', 'implicit/gbn2.xml', 'amber/tip3p_HFE_multivalent.xml')
forcefield.registerTemplateGenerator(gaff.generator)

print('# merge protein and molecule', flush=True)

# Append the ligand (first conformer from the SDF) to the fixed protein.
modeller = Modeller(protein.topology, protein.positions)
ligand_topology = molecule.to_topology().to_openmm()
ligand_positions = molecule.conformers[0].to_openmm()
modeller.add(ligand_topology, ligand_positions)

print('# define system', flush=True)

# Implicit-solvent system without a nonbonded cutoff; bonds to hydrogen are
# constrained and centre-of-mass motion is removed.
system = forcefield.createSystem(
    modeller.topology,
    soluteDielectric=1.0,
    solventDielectric=80.0,
    nonbondedMethod=NoCutoff,
    removeCMMotion=True,
    constraints=HBonds,
)

# Give every force its own group (group index == force index) so each energy
# term can be queried separately. This must happen before the Simulation,
# and therefore its Context, is created.
forces = system.getForces()
for group, force in enumerate(forces):
    force.setForceGroup(group)

# Langevin dynamics at 300 K, friction 1/ps, 2 fs time step.
integrator = LangevinMiddleIntegrator(300*kelvin, 1/picosecond, 0.002*picoseconds)
simulation = Simulation(modeller.topology, system, integrator)
simulation.context.setPositions(modeller.positions)

print('# Initial energy contribustions', flush=True)
for group, force in enumerate(forces):
    energy = simulation.context.getState(getEnergy=True, groups={group}).getPotentialEnergy()
    print(force.getName(), energy)

print('# optimize initial strucure', flush=True)
simulation.minimizeEnergy(maxIterations=1000)

print('# simulated annealing', flush=True)

# Linear cooling: window k runs `winsteps` MD steps at 300 - k*300/nwin K,
# i.e. from 300 K down to 300/nwin K (the last window does not reach 0 K).
# The clamp at 0 K is only a safeguard; it never triggers for nwin > 0.
for window in range(nwin):
    temperature = max(300 - window*(300.0/nwin), 0.0)
    integrator.setTemperature(temperature*kelvin)
    simulation.step(winsteps)

print('energy contribustions after SA')
sys.stdout.flush()
# Per-force energies; force i was assigned to force group i.
for i, f in enumerate(system.getForces()):
    state = simulation.context.getState(getEnergy=True, groups={i})
    print(f.getName(), state.getPotentialEnergy())

print('saving final structure')
sys.stdout.flush()

# Prefix only the file name with "opt-", so a protein given with a directory
# (-p data/prot) yields data/opt-prot.pdb rather than the non-existent
# directory "opt-data/".
protein_dir, protein_base = os.path.split(protein_filename)
opt_filename = os.path.join(protein_dir, 'opt-' + protein_base)
positions = simulation.context.getState(getPositions=True).getPositions()
# Use a context manager so the PDB is flushed and closed deterministically.
with open(opt_filename, 'w') as opt_pdb:
    PDBFile.writeFile(simulation.topology, positions, opt_pdb, keepIds=True)

t_end = perf_counter()

print('# Elapsed time/sec: ',t_end-t_start)
