diff --git a/dmff/admp/qeq.py b/dmff/admp/qeq.py index 7fcc2790b..b263d2b5f 100644 --- a/dmff/admp/qeq.py +++ b/dmff/admp/qeq.py @@ -24,11 +24,13 @@ JAXOPT_OLD = False except ImportError: JAXOPT_OLD = True - print( + import warnings + warnings.warn( "jaxopt is too old. The QEQ potential function cannot be jitted. Please update jaxopt to the latest version for speed concern." ) except ImportError: - print("jaxopt not found, QEQ cannot be used.") + import warnings + warnings.warn("jaxopt not found, QEQ cannot be used.") import jax from jax.scipy.special import erf, erfc diff --git a/dmff/api/graph.py b/dmff/api/graph.py index f718971af..5e4085581 100644 --- a/dmff/api/graph.py +++ b/dmff/api/graph.py @@ -104,7 +104,7 @@ def graph2top(graph: nx.Graph) -> app.Topology: return top -def top2rdmol(top: app.Topology, indices: List[int]) -> Chem.rdchem.Mol: +def top2rdmol(top: app.Topology, indices: List[int]): rdmol = Chem.Mol() emol = Chem.EditableMol(rdmol) idx2ridx = {} diff --git a/dmff/api/topology.py b/dmff/api/topology.py index a24992f83..971741258 100644 --- a/dmff/api/topology.py +++ b/dmff/api/topology.py @@ -346,7 +346,7 @@ def vsites(self): return iter(self._vsites) @classmethod - def regularize_aromaticity(cls, mol: Chem.Mol) -> bool: + def regularize_aromaticity(cls, mol) -> bool: """ Regularize Aromaticity for a rdkit.Mol object. Rings with exocyclic double bonds will not be set aromatic. """ @@ -681,7 +681,7 @@ def decomptop(top) -> List[List[int]]: return indices -def top2rdmol(top, indices) -> Chem.rdchem.Mol: +def top2rdmol(top, indices): rdmol = Chem.Mol() emol = Chem.EditableMol(rdmol) idx2ridx = {} diff --git a/dmff/generators/admp.py b/dmff/generators/admp.py index 7012dd619..6e688f0ad 100644 --- a/dmff/generators/admp.py +++ b/dmff/generators/admp.py @@ -962,15 +962,9 @@ def __init__(self, ffinfo: dict, paramset: ParamSet): self.multipole_types.append(attribs[self.key_type]) # record local coords as kz kx ky # None if not occur - kx = attribs.get("kx", None) - ky = attribs.get("ky", None) - kz = attribs.get("kz", None) - if kx == "-1": - kx = None - if ky == "-1": - ky = None - if kz == "-1": - kz = None + kx = attribs.get("kx", "") + ky = attribs.get("ky", "") + kz = attribs.get("kz", "") kxs.append(kx) kys.append(ky) kzs.append(kz) @@ -1124,6 +1118,51 @@ def _find_polarize_key_index(self, atype: str): if i == atype: return n return None + + @staticmethod + def setAxisType(kIndices): + # setting up axis_indices and axis_type + ZThenX = 0 + Bisector = 1 + ZBisect = 2 + ThreeFold = 3 + ZOnly = 4 # typo fix + NoAxisType = 5 + LastAxisTypeIndex = 6 + + # set axis type + + ky = kIndices[2] + kyNegative = False + if ky.startswith('-'): + ky = ky[1:] + kyNegative = True + + kx = kIndices[1] + kxNegative = False + if kx.startswith('-'): + kx = kx[1:] + kxNegative = True + + kz = kIndices[0] + kzNegative = False + if kz.startswith('-'): + kz = kz[1:] + kzNegative = True + + axisType = ZThenX + if (not kz): + axisType = NoAxisType + if (kz and not kx): + axisType = ZOnly + if (kz and kzNegative or kx and kxNegative): + axisType = Bisector + if (kx and kxNegative and ky and kyNegative): + axisType = ZBisect + if (kz and kzNegative and kx and kxNegative and ky and kyNegative): + axisType = ThreeFold + + return axisType, [kz, kx, ky] def createPotential( self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, **kwargs @@ -1198,25 +1237,7 @@ def createPotential( kz = self.kStrings["kz"][i_type] kx = self.kStrings["kx"][i_type] ky = self.kStrings["ky"][i_type] - axisType = ZThenX # Z, X, -1 - if kz is None: - axisType = NoAxisType # -1, -1, -1 - if kz is not None and kx is None: - axisType = ZOnly # Z, -1, -1 - if (kz is not None and kz[0] == "-") or ( - kx is not None and kx[0] == "-" - ): - axisType = Bisector # Z, X, -1 - if (kx is not None and kx[0] == "-") and ( - ky is not None and ky[0] == "-" - ): - axisType = ZBisect # Z, Y, X - if ( - (kz is not None and kz[0] == "-") - and (kx is not None and kx[0] == "-") - and (ky is not None and ky[0] == "-") - ): - axisType = ThreeFold # Z, Y, X + axisType, (kz, kx, ky) = self.setAxisType([kz, kx, ky]) zaxis = -1 xaxis = -1 @@ -1228,56 +1249,42 @@ def createPotential( if axisType == ZThenX: if kz is None or kx is None: raise DMFFException("ZThenX axis requires both kz and kx!") - kz_real = kz[1:] if kz[0] == "-" else kz - kx_real = kx[1:] if kx[0] == "-" else kx - # 1-2, 1-2 - if neighbors.shape[0] > 1: - # find zaxis - for i_neighbor in neighbors: - if kz_real == atoms[i_neighbor].meta[self.key_type]: - zaxis = i_neighbor - break - if zaxis < 0: - continue - # find xaxis - for i_neighbor in neighbors: - if i_neighbor == zaxis: - continue - if kx_real == atoms[i_neighbor].meta[self.key_type]: - xaxis = i_neighbor - break - if xaxis < 0: - continue - # 1-2, 1-3 - elif neighbors.shape[0] == 1: - if kz_real == atoms[neighbors[0]].meta[self.key_type]: - zaxis = neighbors[0] - else: + # find zaxis + for i_neighbor in neighbors: + if kz == atoms[i_neighbor].meta[self.key_type]: + zaxis = i_neighbor + break + if zaxis < 0: + continue + # find xaxis on 1-2 pairs + for i_neighbor in neighbors: + if i_neighbor == zaxis: continue - # find xaxis + if kx == atoms[i_neighbor].meta[self.key_type]: + xaxis = i_neighbor + break + if xaxis < 0: + # find xaxis on 1-3 pairs neighbors2 = np.where(covalent_map[zaxis] == 1)[0] for j_neighbor in neighbors2: if j_neighbor == i_atom: continue - if kx_real == atoms[j_neighbor].meta[self.key_type]: + if kx == atoms[j_neighbor].meta[self.key_type]: xaxis = j_neighbor break - if xaxis < 0: - continue + if xaxis < 0: + continue elif axisType == ZOnly: - kz_real = kz[1:] if kz[0] == "-" else kz # 1-2 only for i_neighbor in neighbors: - if kz_real == atoms[i_neighbor].meta[self.key_type]: + if kz == atoms[i_neighbor].meta[self.key_type]: zaxis = i_neighbor break if zaxis < 0: continue elif axisType == Bisector: - kz_real = kz[1:] if kz[0] == "-" else kz - kx_real = kx[1:] if kx[0] == "-" else kx for i_neighbor in neighbors: - if kz_real == atoms[i_neighbor].meta[self.key_type]: + if kz == atoms[i_neighbor].meta[self.key_type]: zaxis = i_neighbor break if zaxis < 0: @@ -1285,17 +1292,14 @@ def createPotential( for j_neighbor in neighbors: if j_neighbor == zaxis: continue - if kx_real == atoms[j_neighbor].meta[self.key_type]: + if kx == atoms[j_neighbor].meta[self.key_type]: xaxis = j_neighbor break if xaxis < 0: continue elif axisType in [ZBisect, ThreeFold]: - kz_real = kz[1:] if kz[0] == "-" else kz - kx_real = kx[1:] if kx[0] == "-" else kx - ky_real = ky[1:] if ky[0] == "-" else ky for i_neighbor in neighbors: - if kz_real == atoms[i_neighbor].meta[self.key_type]: + if kz == atoms[i_neighbor].meta[self.key_type]: zaxis = i_neighbor break if zaxis < 0: @@ -1303,7 +1307,7 @@ def createPotential( for j_neighbor in neighbors: if j_neighbor == zaxis: continue - if kx_real == atoms[j_neighbor].meta[self.key_type]: + if kx == atoms[j_neighbor].meta[self.key_type]: xaxis = j_neighbor break if xaxis < 0: @@ -1311,7 +1315,7 @@ def createPotential( for k_neighbor in neighbors: if k_neighbor == zaxis or k_neighbor == xaxis: continue - if ky_real == atoms[k_neighbor].meta[self.key_type]: + if ky == atoms[k_neighbor].meta[self.key_type]: yaxis = k_neighbor break if yaxis < 0: diff --git a/dmff/mbar.py b/dmff/mbar.py index 0e5c027dd..4e55e7957 100644 --- a/dmff/mbar.py +++ b/dmff/mbar.py @@ -1,5 +1,9 @@ import numpy as np -import mdtraj as md +try: + import mdtraj as md +except ImportError: + import warnings + warnings.warn("MDTraj not installed. MBAREstimator is not available.") try: from pymbar import MBAR diff --git a/dmff/operators/parmed.py b/dmff/operators/parmed.py index 1aa336485..273ae5d4d 100644 --- a/dmff/operators/parmed.py +++ b/dmff/operators/parmed.py @@ -50,7 +50,7 @@ def renderLennardJonesXML(self, filename): f.write("\n") @classmethod - def overwriteLennardJones(cls, top: parmed.gromacs.GromacsTopologyFile, ffinfo: dict): + def overwriteLennardJones(cls, top, ffinfo: dict): nodes = [n for n in ffinfo["Forces"]["LennardJonesForce"]["node"]] prm = {} for node in nodes: diff --git a/dmff/sgnn/graph.py b/dmff/sgnn/graph.py index 164a41f3f..a4a6da5e9 100755 --- a/dmff/sgnn/graph.py +++ b/dmff/sgnn/graph.py @@ -1,11 +1,13 @@ -#!/usr/bin/env python import sys from functools import partial from itertools import permutations, product import jax.numpy as jnp # import MDAnalysis as mda -import mdtraj as md +try: + import mdtraj as md +except ImportError: + pass import numpy as np from ..admp.pairwise import distribute_scalar, distribute_v3 from ..admp.spatial import pbc_shift diff --git a/docs/user_guide/installation.md b/docs/user_guide/installation.md index 19f0f31b7..6b158e701 100644 --- a/docs/user_guide/installation.md +++ b/docs/user_guide/installation.md @@ -6,18 +6,15 @@ conda create -n dmff python=3.9 --yes ``` + Install [jax](https://github.com/google/jax) (select the correct cuda version, see more details in the Jax installation guide): ```bash -pip install "jaxlib[cuda11_cudnn805]==0.3.15" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -pip install jax==0.3.17 +# CPU version +pip install "jax[cpu]==0.4.19" +# GPU version +pip install "jax[cuda11_local]==0.4.19" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` -+ Install [jax-md](https://github.com/google/jax-md): -```bash -pip install jax-md==0.2.0 -``` -+ Install [mdtraj](https://github.com/mdtraj/mdtraj), [optax](https://github.com/deepmind/optax) and [pymbar](https://github.com/choderalab/pymbar): ++ Install [mdtraj](https://github.com/mdtraj/mdtraj), [optax](https://github.com/deepmind/optax), [jaxopt](https://github.com/google/jaxopt) and [pymbar](https://github.com/choderalab/pymbar): ```bash conda install -c conda-forge mdtraj==1.9.7 -pip install optax==0.1.3 -pip install pymbar==4.0.1 +pip install optax==0.1.3 pymbar==4.0.1 jaxopt==0.8.1 ``` + Install [OpenMM](https://openmm.org/): ```bash diff --git a/setup.py b/setup.py index 963f58ac9..57fd41d1b 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,8 @@ "freud-analysis", "networkx>=3.0", "optax>=0.1.4", + "jaxopt>=0.8.0", + "pymbar>=4.0.0", "tqdm" ] diff --git a/tests/test_dimer/test_run_dimer_energy.py b/tests/test_dimer/test_run_dimer_energy.py index ce2787c51..8519da784 100644 --- a/tests/test_dimer/test_run_dimer_energy.py +++ b/tests/test_dimer/test_run_dimer_energy.py @@ -8,7 +8,6 @@ import jax import openmm.app as app import openmm.unit as unit -from rdkit import Chem import numpy as np @@ -45,6 +44,7 @@ def test_load_protein(): def test_build_dimer(): + from rdkit import Chem ffinfo = load_xml() mol_top = DMFFTopology(from_sdf="tests/data/dimer/ligand.mol") @@ -88,6 +88,7 @@ def test_build_dimer(): print(energy_func.getPotentialFunc()(pos, box, pairs, hamilt.paramset)) def test_dimer_coul(): + from rdkit import Chem ffinfo = load_xml() print("------- Jax Ener Func --------") @@ -214,6 +215,7 @@ def test_dimer_coul(): def test_dimer_lj(): + from rdkit import Chem ffinfo = load_xml() paramset = ParamSet() lj_gen = LennardJonesGenerator(ffinfo, paramset) @@ -358,6 +360,7 @@ def test_dimer_lj(): def test_hamiltonian(): + from rdkit import Chem hamilt = Hamiltonian("tests/data/dimer/forcefield.xml", "tests/data/dimer/gaff2.xml", "tests/data/dimer/amber14_prot.xml") smarts_type = SMARTSATypeOperator(hamilt.ffinfo) smarts_vsite = SMARTSVSiteOperator(hamilt.ffinfo) @@ -434,6 +437,7 @@ def test_hamiltonian(): hamilt.renderXML("test.xml") def test_optax(): + from rdkit import Chem hamilt = Hamiltonian("tests/data/dimer/forcefield.xml", "tests/data/dimer/gaff2.xml", "tests/data/dimer/amber14_prot.xml") hamilt.renderXML("tests/data/dimer/test-init.xml") smarts_type = SMARTSATypeOperator(hamilt.ffinfo) diff --git a/tests/test_mbar/test_mbar.py b/tests/test_mbar/test_mbar.py index 8ac1b56d3..2afa6f49b 100644 --- a/tests/test_mbar/test_mbar.py +++ b/tests/test_mbar/test_mbar.py @@ -8,8 +8,16 @@ import openmm as mm import numpy as np import numpy.testing as npt -import mdtraj as md -from pymbar import MBAR +try: + import mdtraj as md +except ImportError as e: + import warnings + warnings.warn(f"mdtraj not found. Tests about MBAR would fail.") +try: + from pymbar import MBAR +except ImportError as e: + import warnings + warnings.warn(f"pymbar not found. Tests about MBAR would fail.") from dmff import Hamiltonian, NeighborListFreud from tqdm import tqdm diff --git a/tests/test_mbar/test_state.py b/tests/test_mbar/test_state.py index b2282c661..a605337d4 100644 --- a/tests/test_mbar/test_state.py +++ b/tests/test_mbar/test_state.py @@ -9,7 +9,11 @@ import numpy as np import numpy.testing as npt import mdtraj as md -from pymbar import MBAR +try: + from pymbar import MBAR +except ImportError as e: + import warnings + warnings.warn(f"pymbar not found. Tests about MBAR would fail.") from dmff import Hamiltonian, NeighborListFreud