Skip to content

Commit

Permalink
Fix the local frame axis setting in ADMP to be consistent with MPID
Browse files Browse the repository at this point in the history
  • Loading branch information
WangXinyan940 committed Oct 31, 2023
1 parent f92d04a commit d51be82
Show file tree
Hide file tree
Showing 12 changed files with 118 additions and 91 deletions.
6 changes: 4 additions & 2 deletions dmff/admp/qeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
JAXOPT_OLD = False
except ImportError:
JAXOPT_OLD = True
print(
import warnings
warnings.warn(
"jaxopt is too old. The QEQ potential function cannot be jitted. Please update jaxopt to the latest version for speed concern."
)
except ImportError:
print("jaxopt not found, QEQ cannot be used.")
import warnings
warnings.warn("jaxopt not found, QEQ cannot be used.")
import jax

from jax.scipy.special import erf, erfc
Expand Down
2 changes: 1 addition & 1 deletion dmff/api/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def graph2top(graph: nx.Graph) -> app.Topology:
return top


def top2rdmol(top: app.Topology, indices: List[int]) -> Chem.rdchem.Mol:
def top2rdmol(top: app.Topology, indices: List[int]):
rdmol = Chem.Mol()
emol = Chem.EditableMol(rdmol)
idx2ridx = {}
Expand Down
4 changes: 2 additions & 2 deletions dmff/api/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def vsites(self):
return iter(self._vsites)

@classmethod
def regularize_aromaticity(cls, mol: Chem.Mol) -> bool:
def regularize_aromaticity(cls, mol) -> bool:
"""
Regularize Aromaticity for a rdkit.Mol object. Rings with exocyclic double bonds will not be set aromatic.
"""
Expand Down Expand Up @@ -681,7 +681,7 @@ def decomptop(top) -> List[List[int]]:
return indices


def top2rdmol(top, indices) -> Chem.rdchem.Mol:
def top2rdmol(top, indices):
rdmol = Chem.Mol()
emol = Chem.EditableMol(rdmol)
idx2ridx = {}
Expand Down
142 changes: 73 additions & 69 deletions dmff/generators/admp.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,15 +962,9 @@ def __init__(self, ffinfo: dict, paramset: ParamSet):
self.multipole_types.append(attribs[self.key_type])
# record local coords as kz kx ky
# None if not occur
kx = attribs.get("kx", None)
ky = attribs.get("ky", None)
kz = attribs.get("kz", None)
if kx == "-1":
kx = None
if ky == "-1":
ky = None
if kz == "-1":
kz = None
kx = attribs.get("kx", "")
ky = attribs.get("ky", "")
kz = attribs.get("kz", "")
kxs.append(kx)
kys.append(ky)
kzs.append(kz)
Expand Down Expand Up @@ -1124,6 +1118,51 @@ def _find_polarize_key_index(self, atype: str):
if i == atype:
return n
return None

@staticmethod
def setAxisType(kIndices):
# setting up axis_indices and axis_type
ZThenX = 0
Bisector = 1
ZBisect = 2
ThreeFold = 3
ZOnly = 4 # typo fix
NoAxisType = 5
LastAxisTypeIndex = 6

# set axis type

ky = kIndices[2]
kyNegative = False
if ky.startswith('-'):
ky = ky[1:]
kyNegative = True

kx = kIndices[1]
kxNegative = False
if kx.startswith('-'):
kx = kx[1:]
kxNegative = True

kz = kIndices[0]
kzNegative = False
if kz.startswith('-'):
kz = kz[1:]
kzNegative = True

axisType = ZThenX
if (not kz):
axisType = NoAxisType
if (kz and not kx):
axisType = ZOnly
if (kz and kzNegative or kx and kxNegative):
axisType = Bisector
if (kx and kxNegative and ky and kyNegative):
axisType = ZBisect
if (kz and kzNegative and kx and kxNegative and ky and kyNegative):
axisType = ThreeFold

return axisType, [kz, kx, ky]

def createPotential(
self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, **kwargs
Expand Down Expand Up @@ -1198,25 +1237,7 @@ def createPotential(
kz = self.kStrings["kz"][i_type]
kx = self.kStrings["kx"][i_type]
ky = self.kStrings["ky"][i_type]
axisType = ZThenX # Z, X, -1
if kz is None:
axisType = NoAxisType # -1, -1, -1
if kz is not None and kx is None:
axisType = ZOnly # Z, -1, -1
if (kz is not None and kz[0] == "-") or (
kx is not None and kx[0] == "-"
):
axisType = Bisector # Z, X, -1
if (kx is not None and kx[0] == "-") and (
ky is not None and ky[0] == "-"
):
axisType = ZBisect # Z, Y, X
if (
(kz is not None and kz[0] == "-")
and (kx is not None and kx[0] == "-")
and (ky is not None and ky[0] == "-")
):
axisType = ThreeFold # Z, Y, X
axisType, (kz, kx, ky) = self.setAxisType([kz, kx, ky])

zaxis = -1
xaxis = -1
Expand All @@ -1228,90 +1249,73 @@ def createPotential(
if axisType == ZThenX:
if kz is None or kx is None:
raise DMFFException("ZThenX axis requires both kz and kx!")
kz_real = kz[1:] if kz[0] == "-" else kz
kx_real = kx[1:] if kx[0] == "-" else kx
# 1-2, 1-2
if neighbors.shape[0] > 1:
# find zaxis
for i_neighbor in neighbors:
if kz_real == atoms[i_neighbor].meta[self.key_type]:
zaxis = i_neighbor
break
if zaxis < 0:
continue
# find xaxis
for i_neighbor in neighbors:
if i_neighbor == zaxis:
continue
if kx_real == atoms[i_neighbor].meta[self.key_type]:
xaxis = i_neighbor
break
if xaxis < 0:
continue
# 1-2, 1-3
elif neighbors.shape[0] == 1:
if kz_real == atoms[neighbors[0]].meta[self.key_type]:
zaxis = neighbors[0]
else:
# find zaxis
for i_neighbor in neighbors:
if kz == atoms[i_neighbor].meta[self.key_type]:
zaxis = i_neighbor
break
if zaxis < 0:
continue
# find xaxis on 1-2 pairs
for i_neighbor in neighbors:
if i_neighbor == zaxis:
continue
# find xaxis
if kx == atoms[i_neighbor].meta[self.key_type]:
xaxis = i_neighbor
break
if xaxis < 0:
# find xaxis on 1-3 pairs
neighbors2 = np.where(covalent_map[zaxis] == 1)[0]
for j_neighbor in neighbors2:
if j_neighbor == i_atom:
continue
if kx_real == atoms[j_neighbor].meta[self.key_type]:
if kx == atoms[j_neighbor].meta[self.key_type]:
xaxis = j_neighbor
break
if xaxis < 0:
continue
if xaxis < 0:
continue
elif axisType == ZOnly:
kz_real = kz[1:] if kz[0] == "-" else kz
# 1-2 only
for i_neighbor in neighbors:
if kz_real == atoms[i_neighbor].meta[self.key_type]:
if kz == atoms[i_neighbor].meta[self.key_type]:
zaxis = i_neighbor
break
if zaxis < 0:
continue
elif axisType == Bisector:
kz_real = kz[1:] if kz[0] == "-" else kz
kx_real = kx[1:] if kx[0] == "-" else kx
for i_neighbor in neighbors:
if kz_real == atoms[i_neighbor].meta[self.key_type]:
if kz == atoms[i_neighbor].meta[self.key_type]:
zaxis = i_neighbor
break
if zaxis < 0:
continue
for j_neighbor in neighbors:
if j_neighbor == zaxis:
continue
if kx_real == atoms[j_neighbor].meta[self.key_type]:
if kx == atoms[j_neighbor].meta[self.key_type]:
xaxis = j_neighbor
break
if xaxis < 0:
continue
elif axisType in [ZBisect, ThreeFold]:
kz_real = kz[1:] if kz[0] == "-" else kz
kx_real = kx[1:] if kx[0] == "-" else kx
ky_real = ky[1:] if ky[0] == "-" else ky
for i_neighbor in neighbors:
if kz_real == atoms[i_neighbor].meta[self.key_type]:
if kz == atoms[i_neighbor].meta[self.key_type]:
zaxis = i_neighbor
break
if zaxis < 0:
continue
for j_neighbor in neighbors:
if j_neighbor == zaxis:
continue
if kx_real == atoms[j_neighbor].meta[self.key_type]:
if kx == atoms[j_neighbor].meta[self.key_type]:
xaxis = j_neighbor
break
if xaxis < 0:
continue
for k_neighbor in neighbors:
if k_neighbor == zaxis or k_neighbor == xaxis:
continue
if ky_real == atoms[k_neighbor].meta[self.key_type]:
if ky == atoms[k_neighbor].meta[self.key_type]:
yaxis = k_neighbor
break
if yaxis < 0:
Expand Down
6 changes: 5 additions & 1 deletion dmff/mbar.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import numpy as np
import mdtraj as md
try:
import mdtraj as md
except ImportError:
import warnings
warnings.warn("MDTraj not installed. MBAREstimator is not available.")

try:
from pymbar import MBAR
Expand Down
2 changes: 1 addition & 1 deletion dmff/operators/parmed.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def renderLennardJonesXML(self, filename):
f.write("</ForceField>\n")

@classmethod
def overwriteLennardJones(cls, top: parmed.gromacs.GromacsTopologyFile, ffinfo: dict):
def overwriteLennardJones(cls, top, ffinfo: dict):
nodes = [n for n in ffinfo["Forces"]["LennardJonesForce"]["node"]]
prm = {}
for node in nodes:
Expand Down
6 changes: 4 additions & 2 deletions dmff/sgnn/graph.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#!/usr/bin/env python
import sys
from functools import partial
from itertools import permutations, product

import jax.numpy as jnp
# import MDAnalysis as mda
import mdtraj as md
try:
import mdtraj as md
except ImportError:
pass
import numpy as np
from ..admp.pairwise import distribute_scalar, distribute_v3
from ..admp.spatial import pbc_shift
Expand Down
15 changes: 6 additions & 9 deletions docs/user_guide/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,15 @@ conda create -n dmff python=3.9 --yes
```
+ Install [jax](https://github.com/google/jax) (select the correct cuda version, see more details in the Jax installation guide):
```bash
pip install "jaxlib[cuda11_cudnn805]==0.3.15" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install jax==0.3.17
# CPU version
pip install "jax[cpu]==0.4.19"
# GPU version
pip install "jax[cuda11_local]==0.4.19" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
+ Install [jax-md](https://github.com/google/jax-md):
```bash
pip install jax-md==0.2.0
```
+ Install [mdtraj](https://github.com/mdtraj/mdtraj), [optax](https://github.com/deepmind/optax) and [pymbar](https://github.com/choderalab/pymbar):
+ Install [mdtraj](https://github.com/mdtraj/mdtraj), [optax](https://github.com/deepmind/optax), [jaxopt](https://github.com/google/jaxopt) and [pymbar](https://github.com/choderalab/pymbar):
```bash
conda install -c conda-forge mdtraj==1.9.7
pip install optax==0.1.3
pip install pymbar==4.0.1
pip install optax==0.1.3 pymbar==4.0.1 jaxopt==0.8.1
```
+ Install [OpenMM](https://openmm.org/):
```bash
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
"freud-analysis",
"networkx>=3.0",
"optax>=0.1.4",
"jaxopt>=0.8.0",
"pymbar>=4.0.0",
"tqdm"
]

Expand Down
6 changes: 5 additions & 1 deletion tests/test_dimer/test_run_dimer_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import jax
import openmm.app as app
import openmm.unit as unit
from rdkit import Chem
import numpy as np


Expand Down Expand Up @@ -45,6 +44,7 @@ def test_load_protein():


def test_build_dimer():
from rdkit import Chem
ffinfo = load_xml()

mol_top = DMFFTopology(from_sdf="tests/data/dimer/ligand.mol")
Expand Down Expand Up @@ -88,6 +88,7 @@ def test_build_dimer():
print(energy_func.getPotentialFunc()(pos, box, pairs, hamilt.paramset))

def test_dimer_coul():
from rdkit import Chem
ffinfo = load_xml()

print("------- Jax Ener Func --------")
Expand Down Expand Up @@ -214,6 +215,7 @@ def test_dimer_coul():


def test_dimer_lj():
from rdkit import Chem
ffinfo = load_xml()
paramset = ParamSet()
lj_gen = LennardJonesGenerator(ffinfo, paramset)
Expand Down Expand Up @@ -358,6 +360,7 @@ def test_dimer_lj():


def test_hamiltonian():
from rdkit import Chem
hamilt = Hamiltonian("tests/data/dimer/forcefield.xml", "tests/data/dimer/gaff2.xml", "tests/data/dimer/amber14_prot.xml")
smarts_type = SMARTSATypeOperator(hamilt.ffinfo)
smarts_vsite = SMARTSVSiteOperator(hamilt.ffinfo)
Expand Down Expand Up @@ -434,6 +437,7 @@ def test_hamiltonian():
hamilt.renderXML("test.xml")

def test_optax():
from rdkit import Chem
hamilt = Hamiltonian("tests/data/dimer/forcefield.xml", "tests/data/dimer/gaff2.xml", "tests/data/dimer/amber14_prot.xml")
hamilt.renderXML("tests/data/dimer/test-init.xml")
smarts_type = SMARTSATypeOperator(hamilt.ffinfo)
Expand Down
12 changes: 10 additions & 2 deletions tests/test_mbar/test_mbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,16 @@
import openmm as mm
import numpy as np
import numpy.testing as npt
import mdtraj as md
from pymbar import MBAR
try:
import mdtraj as md
except ImportError as e:
import warnings
warnings.warn(f"mdtraj not found. Tests about MBAR would fail.")
try:
from pymbar import MBAR
except ImportError as e:
import warnings
warnings.warn(f"pymbar not found. Tests about MBAR would fail.")
from dmff import Hamiltonian, NeighborListFreud
from tqdm import tqdm

Expand Down
Loading

0 comments on commit d51be82

Please sign in to comment.