diff --git a/absolv/runner.py b/absolv/runner.py index a6a6dad..0e0c255 100644 --- a/absolv/runner.py +++ b/absolv/runner.py @@ -1,5 +1,6 @@ """Run calculations defined by a config.""" +import collections import functools import multiprocessing import pathlib @@ -17,6 +18,7 @@ import openff.toolkit import openff.utilities import openmm +import openmm.app import openmm.unit import pymbar import tqdm @@ -35,12 +37,128 @@ class PreparedSystem(typing.NamedTuple): system: openmm.System """The alchemically modified OpenMM system.""" - topology: openff.toolkit.Topology - """The OpenFF topology with any box vectors set.""" + topology: openmm.app.Topology + """The OpenMM topology with any box vectors set.""" coords: openmm.unit.Quantity """The coordinates of the system.""" +def _rebuild_topology( + orig_top: openff.toolkit.Topology, + orig_coords: openmm.unit.Quantity, + system: openmm.System, +) -> tuple[openmm.app.Topology, openmm.unit.Quantity]: + """Rebuild the topology to also include virtual sites.""" + atom_idx_to_residue_idx = {} + atom_idx = 0 + + for residue_idx, molecule in enumerate(orig_top.molecules): + for _ in molecule.atoms: + atom_idx_to_residue_idx[atom_idx] = residue_idx + atom_idx += 1 + + particle_idx_to_atom_idx = {} + atom_idx = 0 + + for particle_idx in range(system.getNumParticles()): + if system.isVirtualSite(particle_idx): + continue + + particle_idx_to_atom_idx[particle_idx] = atom_idx + atom_idx += 1 + + atoms_off = [*orig_top.atoms] + particles = [] + + for particle_idx in range(system.getNumParticles()): + if system.isVirtualSite(particle_idx): + v_site = system.getVirtualSite(particle_idx) + + parent_idxs = { + particle_idx_to_atom_idx[v_site.getParticle(i)] + for i in range(v_site.getNumParticles()) + } + parent_residue = atom_idx_to_residue_idx[next(iter(parent_idxs))] + + particles.append((-1, parent_residue)) + continue + + atom_idx = particle_idx_to_atom_idx[particle_idx] + residue_idx = atom_idx_to_residue_idx[atom_idx] + + particles.append((atoms_off[atom_idx].atomic_number, residue_idx)) + + topology = openmm.app.Topology() + + if orig_top.box_vectors is not None: + topology.setPeriodicBoxVectors(orig_top.box_vectors.to_openmm()) + + chain = topology.addChain() + + atom_counts_per_residue = collections.defaultdict( + lambda: collections.defaultdict(int) + ) + + last_residue_idx = -1 + residue = None + + for atomic_num, residue_idx in particles: + if residue_idx != last_residue_idx: + last_residue_idx = residue_idx + residue = topology.addResidue("UNK", chain) + + element = ( + None if atomic_num < 0 else openmm.app.Element.getByAtomicNumber(atomic_num) + ) + symbol = "X" if element is None else element.symbol + + atom_counts_per_residue[residue_idx][atomic_num] += 1 + topology.addAtom( + f"{symbol}{atom_counts_per_residue[residue_idx][atomic_num]}".ljust(3, "x"), + element, + residue, + ) + + _rename_residues(topology) + + coords_with_v_sites = [] + + for particle_idx in range(system.getNumParticles()): + if particle_idx in particle_idx_to_atom_idx: + coords_i = orig_coords[particle_idx_to_atom_idx[particle_idx]] + coords_with_v_sites.append(coords_i.value_in_unit(openmm.unit.angstrom)) + else: + coords_with_v_sites.append(numpy.zeros((1, 3))) + + coords_with_v_sites = numpy.vstack(coords_with_v_sites) * openmm.unit.angstrom + + if len(orig_coords) != len(coords_with_v_sites): + context = openmm.Context(system, openmm.VerletIntegrator(1.0)) + context.setPositions(coords_with_v_sites) + context.computeVirtualSites() + coords_with_v_sites = context.getState(getPositions=True).getPositions( + asNumpy=True + ) + + return topology, coords_with_v_sites + + +def _rename_residues(topology: openmm.app.Topology): + """Attempts to assign standard residue names to known residues""" + + for residue in topology.residues(): + symbols = sorted( + ( + atom.element.symbol + for atom in residue.atoms() + if atom.element is not None + ) + ) + + if symbols == ["H", "H", "O"]: + residue.name = "HOH" + + def _setup_solvent( solvent_idx: typing.Literal["solvent-a", "solvent-b"], components: list[tuple[str, int]], @@ -67,19 +185,25 @@ def _setup_solvent( is_vacuum = n_solvent_molecules == 0 - topology, coords = absolv.setup.setup_system(components) - topology.box_vectors = None if is_vacuum else topology.box_vectors + topology_off, coords = absolv.setup.setup_system(components) + topology_off.box_vectors = None if is_vacuum else topology_off.box_vectors + + if isinstance(force_field, openff.toolkit.ForceField): + original_system = force_field.create_openmm_system(topology_off) + else: + original_system: openmm.System = force_field(topology_off, coords, solvent_idx) + + topology, coords = _rebuild_topology(topology_off, coords, original_system) - atom_indices = absolv.utils.topology.topology_to_atom_indices(topology) + atom_indices = [ + {atom.index for atom in residue.atoms()} + for chain in topology.chains() + for residue in chain.residues() + ] alchemical_indices = atom_indices[:n_solute_molecules] persistent_indices = atom_indices[n_solute_molecules:] - if isinstance(force_field, openff.toolkit.ForceField): - original_system = force_field.create_openmm_system(topology) - else: - original_system: openmm.System = force_field(topology, coords, solvent_idx) - alchemical_system = absolv.fep.apply_fep( original_system, alchemical_indices, @@ -196,7 +320,7 @@ def _run_eq_phase( """ platform = ( femto.md.constants.OpenMMPlatform.REFERENCE - if prepared_system.topology.box_vectors is None + if prepared_system.topology.getPeriodicBoxVectors() is None else platform ) @@ -312,7 +436,7 @@ def _run_phase_end_states( ): platform = ( femto.md.constants.OpenMMPlatform.REFERENCE - if prepared_system.topology.box_vectors is None + if prepared_system.topology.getPeriodicBoxVectors() is None else platform ) @@ -363,11 +487,11 @@ def _run_switching( ): platform = ( femto.md.constants.OpenMMPlatform.REFERENCE - if prepared_system.topology.box_vectors is None + if prepared_system.topology.getPeriodicBoxVectors() is None else platform ) - mdtraj_topology = mdtraj.Topology.from_openmm(prepared_system.topology.to_openmm()) + mdtraj_topology = mdtraj.Topology.from_openmm(prepared_system.topology) trajectory_0 = mdtraj.load_dcd(str(output_dir / "state-0.dcd"), mdtraj_topology) trajectory_1 = mdtraj.load_dcd(str(output_dir / "state-1.dcd"), mdtraj_topology) diff --git a/absolv/tests/test_runner.py b/absolv/tests/test_runner.py index 7c30193..1c6b8d7 100644 --- a/absolv/tests/test_runner.py +++ b/absolv/tests/test_runner.py @@ -1,5 +1,7 @@ import femto.md.constants +import numpy import openff.toolkit +import openff.units import openmm.unit import pytest @@ -58,6 +60,125 @@ ) +def test_rebuild_topology(): + ff = openff.toolkit.ForceField("tip4p_fb.offxml", "openff-2.0.0.offxml") + + v_site_handler = ff.get_parameter_handler("VirtualSites") + v_site_handler.add_parameter( + { + "type": "DivalentLonePair", + "match": "once", + "smirks": "[*:2][#7:1][*:3]", + "distance": 0.4 * openff.units.unit.angstrom, + "epsilon": 0.0 * openff.units.unit.kilojoule_per_mole, + "sigma": 0.1 * openff.units.unit.nanometer, + "outOfPlaneAngle": 0.0 * openff.units.unit.degree, + "charge_increment1": 0.0 * openff.units.unit.elementary_charge, + "charge_increment2": 0.0 * openff.units.unit.elementary_charge, + "charge_increment3": 0.0 * openff.units.unit.elementary_charge, + } + ) + + solute = openff.toolkit.Molecule.from_smiles("c1ccncc1") + solute.generate_conformers(n_conformers=1) + solvent = openff.toolkit.Molecule.from_smiles("O") + solvent.generate_conformers(n_conformers=1) + + orig_coords = ( + numpy.vstack( + [ + solute.conformers[0].m_as("angstrom"), + solvent.conformers[0].m_as("angstrom") + numpy.array([10.0, 0.0, 0.0]), + solvent.conformers[0].m_as("angstrom") + numpy.array([20.0, 0.0, 0.0]), + ] + ) + * openmm.unit.angstrom + ) + + expected_box_vectors = numpy.eye(3) * 30.0 + + orig_top = openff.toolkit.topology.Topology.from_molecules( + [solute, solvent, solvent] + ) + orig_top.box_vectors = expected_box_vectors * openmm.unit.angstrom + + system = ff.create_openmm_system(orig_top) + + n_v_sites = sum( + 1 for i in range(system.getNumParticles()) if system.isVirtualSite(i) + ) + assert n_v_sites == 3 + + top, coords = absolv.runner._rebuild_topology(orig_top, orig_coords, system) + + found_atoms = [ + ( + atom.name, + atom.element.symbol if atom.element is not None else None, + atom.residue.index, + atom.residue.name, + ) + for atom in top.atoms() + ] + expected_atoms = [ + ("C1x", "C", 0, "UNK"), + ("C2x", "C", 0, "UNK"), + ("C3x", "C", 0, "UNK"), + ("N1x", "N", 0, "UNK"), + ("C4x", "C", 0, "UNK"), + ("C5x", "C", 0, "UNK"), + ("H1x", "H", 0, "UNK"), + ("H2x", "H", 0, "UNK"), + ("H3x", "H", 0, "UNK"), + ("H4x", "H", 0, "UNK"), + ("H5x", "H", 0, "UNK"), + ("O1x", "O", 1, "HOH"), + ("H1x", "H", 1, "HOH"), + ("H2x", "H", 1, "HOH"), + ("O1x", "O", 2, "HOH"), + ("H1x", "H", 2, "HOH"), + ("H2x", "H", 2, "HOH"), + ("X1x", None, 3, "UNK"), + ("X1x", None, 4, "UNK"), + ("X1x", None, 5, "UNK"), + ] + + assert found_atoms == expected_atoms + + expected_coords = numpy.array( + [ + [0.00241, 0.10097, -0.05663], + [-0.11673, 0.03377, -0.03801], + [-0.11801, -0.08778, 0.03043], + [-0.00041, -0.1375, 0.07759], + [0.112, -0.068, 0.05657], + [0.12301, 0.0524, -0.00964], + [4e-05, 0.19505, -0.11016], + [-0.2116, 0.07095, -0.0744], + [-0.20828, -0.14529, 0.04827], + [0.19995, -0.11721, 0.09863], + [0.21761, 0.10263, -0.02266], + [0.99992, 0.03664, 0.0], + [0.91877, -0.01835, 0.0], + [1.08131, -0.01829, 0.0], + [1.99992, 0.03664, 0.0], + [1.91877, -0.01835, 0.0], + [2.08131, -0.01829, 0.0], + [0.0011, -0.1722, 0.09743], + [0.99994, 0.02611, 0.0], + [1.99994, 0.02611, 0.0], + ] + ) # manually visually inspected + + assert coords.shape == expected_coords.shape + assert numpy.allclose(coords, expected_coords, atol=1.0e-5) + + box_vectors = top.getPeriodicBoxVectors().value_in_unit(openmm.unit.angstrom) + box_vectors = numpy.array(box_vectors) + + assert numpy.allclose(box_vectors, expected_box_vectors) + + def test_setup_fn(): system = absolv.config.System( solutes={"[Na+]": 1, "[Cl-]": 1}, solvent_a=None, solvent_b={"O": 1} @@ -70,8 +191,8 @@ def test_setup_fn(): assert prepared_system_a.system.getNumParticles() == 2 assert prepared_system_b.system.getNumParticles() == 5 - assert prepared_system_a.topology.box_vectors is None - assert prepared_system_b.topology.box_vectors is not None + assert prepared_system_a.topology.getPeriodicBoxVectors() is None + assert prepared_system_b.topology.getPeriodicBoxVectors() is not None @pytest.mark.parametrize( diff --git a/absolv/tests/utils/test_openmm.py b/absolv/tests/utils/test_openmm.py index 20a0152..7ee7723 100644 --- a/absolv/tests/utils/test_openmm.py +++ b/absolv/tests/utils/test_openmm.py @@ -45,7 +45,7 @@ def test_create_simulation(): simulation = create_simulation( system, - topology, + topology.to_openmm(), expected_coords, integrator, femto.md.constants.OpenMMPlatform.REFERENCE, diff --git a/absolv/tests/utils/test_topology.py b/absolv/tests/utils/test_topology.py index facc168..d26457e 100644 --- a/absolv/tests/utils/test_topology.py +++ b/absolv/tests/utils/test_topology.py @@ -1,7 +1,7 @@ import openff.toolkit import pytest -from absolv.utils.topology import topology_to_atom_indices, topology_to_components +from absolv.utils.topology import topology_to_components @pytest.mark.parametrize("n_counts", [[3, 1, 2], [3, 1, 1]]) @@ -19,22 +19,3 @@ def test_topology_to_components(n_counts): ("[H][C]([H])([H])[H]", n_counts[1]), ("[H][O][H]", n_counts[2]), ] - - -def test_topology_to_atom_indices(): - topology = openff.toolkit.Topology.from_molecules( - [openff.toolkit.Molecule.from_smiles("O")] * 1 - + [openff.toolkit.Molecule.from_smiles("C")] * 2 - + [openff.toolkit.Molecule.from_smiles("O")] * 3 - ) - - atom_indices = topology_to_atom_indices(topology) - - assert atom_indices == [ - {0, 1, 2}, - {3, 4, 5, 6, 7}, - {8, 9, 10, 11, 12}, - {13, 14, 15}, - {16, 17, 18}, - {19, 20, 21}, - ] diff --git a/absolv/utils/openmm.py b/absolv/utils/openmm.py index 916e3ca..4c06a91 100644 --- a/absolv/utils/openmm.py +++ b/absolv/utils/openmm.py @@ -1,4 +1,5 @@ """Utilities to manipulate OpenMM objects.""" + import typing import femto.md.constants @@ -46,7 +47,7 @@ def add_barostat( def create_simulation( system: openmm.System, - topology: openff.toolkit.Topology, + topology: openmm.app.Topology, coords: openmm.unit.Quantity, integrator: openmm.Integrator, platform: femto.md.constants.OpenMMPlatform, @@ -69,17 +70,20 @@ def create_simulation( ) platform = openmm.Platform.getPlatformByName(platform) - if topology.box_vectors is not None: - system.setDefaultPeriodicBoxVectors(*topology.box_vectors.to_openmm()) + is_periodic = topology.getPeriodicBoxVectors() is not None + + if is_periodic: + system.setDefaultPeriodicBoxVectors(*topology.getPeriodicBoxVectors()) simulation = openmm.app.Simulation( - topology.to_openmm(), system, integrator, platform, platform_properties + topology, system, integrator, platform, platform_properties ) - if topology.box_vectors is not None: - simulation.context.setPeriodicBoxVectors(*topology.box_vectors.to_openmm()) + if is_periodic: + simulation.context.setPeriodicBoxVectors(*topology.getPeriodicBoxVectors()) simulation.context.setPositions(coords) + simulation.context.computeVirtualSites() simulation.context.setVelocitiesToTemperature(integrator.getTemperature()) return simulation @@ -192,4 +196,6 @@ def extract_frame(trajectory: mdtraj.Trajectory, idx: int) -> openmm.State: if trajectory.unitcell_vectors is not None: context.setPeriodicBoxVectors(*trajectory.openmm_boxes(idx)) + context.computeVirtualSites() + return context.getState(getPositions=True) diff --git a/absolv/utils/topology.py b/absolv/utils/topology.py index 8ba8ce6..5377133 100644 --- a/absolv/utils/topology.py +++ b/absolv/utils/topology.py @@ -1,4 +1,5 @@ """Utilities for manipulating OpenFF topology objects.""" + import openff.toolkit @@ -41,24 +42,3 @@ def topology_to_components(topology: openff.toolkit.Topology) -> list[tuple[str, components.append((current_smiles, current_count)) return components - - -def topology_to_atom_indices(topology: openff.toolkit.Topology) -> list[set[int]]: - """A helper method for extracting the sets of atom indices associated with each - molecule in a topology. - - Args: - topology: The topology to extract the atom indices from. - - Returns: - The set of atoms indices associated with each molecule in the topology. - """ - - atom_indices: list[set[int]] = [] - current_atom_idx = 0 - - for molecule in topology.molecules: - atom_indices.append({i + current_atom_idx for i in range(molecule.n_atoms)}) - current_atom_idx += molecule.n_atoms - - return atom_indices diff --git a/regression/run.py b/regression/run.py index 76a9960..35e1fae 100644 --- a/regression/run.py +++ b/regression/run.py @@ -1,8 +1,8 @@ +import datetime import logging import pathlib import tempfile import urllib.request -import datetime import click import femto.md.config @@ -16,8 +16,8 @@ from rdkit import Chem import absolv.config -import absolv.utils.openmm import absolv.runner +import absolv.utils.openmm DEFAULT_TEMPERATURE = 298.15 * openmm.unit.kelvin DEFAULT_PRESSURE = 1.0 * openmm.unit.atmosphere @@ -170,11 +170,11 @@ def run_replica( femto.md.system.apply_hmr( prepared_system_a.system, - parmed.openmm.load_topology(prepared_system_a.topology.to_openmm()), + parmed.openmm.load_topology(prepared_system_a.topology), ) femto.md.system.apply_hmr( prepared_system_b.system, - parmed.openmm.load_topology(prepared_system_a.topology.to_openmm()), + parmed.openmm.load_topology(prepared_system_a.topology), ) if method == "neq":