Skip to content

Commit

Permalink
Fix solutes with v-sites
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd committed Nov 22, 2024
1 parent 697a106 commit abffa10
Show file tree
Hide file tree
Showing 7 changed files with 280 additions and 68 deletions.
152 changes: 138 additions & 14 deletions absolv/runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Run calculations defined by a config."""

import collections
import functools
import multiprocessing
import pathlib
Expand All @@ -17,6 +18,7 @@
import openff.toolkit
import openff.utilities
import openmm
import openmm.app
import openmm.unit
import pymbar
import tqdm
Expand All @@ -35,12 +37,128 @@ class PreparedSystem(typing.NamedTuple):
system: openmm.System
"""The alchemically modified OpenMM system."""

topology: openff.toolkit.Topology
"""The OpenFF topology with any box vectors set."""
topology: openmm.app.Topology
"""The OpenMM topology with any box vectors set."""
coords: openmm.unit.Quantity
"""The coordinates of the system."""


def _rebuild_topology(
orig_top: openff.toolkit.Topology,
orig_coords: openmm.unit.Quantity,
system: openmm.System,
) -> tuple[openmm.app.Topology, openmm.unit.Quantity]:
"""Rebuild the topology to also include virtual sites."""
atom_idx_to_residue_idx = {}
atom_idx = 0

for residue_idx, molecule in enumerate(orig_top.molecules):
for _ in molecule.atoms:
atom_idx_to_residue_idx[atom_idx] = residue_idx
atom_idx += 1

particle_idx_to_atom_idx = {}
atom_idx = 0

for particle_idx in range(system.getNumParticles()):
if system.isVirtualSite(particle_idx):
continue

particle_idx_to_atom_idx[particle_idx] = atom_idx
atom_idx += 1

atoms_off = [*orig_top.atoms]
particles = []

for particle_idx in range(system.getNumParticles()):
if system.isVirtualSite(particle_idx):
v_site = system.getVirtualSite(particle_idx)

parent_idxs = {
particle_idx_to_atom_idx[v_site.getParticle(i)]
for i in range(v_site.getNumParticles())
}
parent_residue = atom_idx_to_residue_idx[next(iter(parent_idxs))]

particles.append((-1, parent_residue))
continue

atom_idx = particle_idx_to_atom_idx[particle_idx]
residue_idx = atom_idx_to_residue_idx[atom_idx]

particles.append((atoms_off[atom_idx].atomic_number, residue_idx))

topology = openmm.app.Topology()

if orig_top.box_vectors is not None:
topology.setPeriodicBoxVectors(orig_top.box_vectors.to_openmm())

chain = topology.addChain()

atom_counts_per_residue = collections.defaultdict(
lambda: collections.defaultdict(int)
)

last_residue_idx = -1
residue = None

for atomic_num, residue_idx in particles:
if residue_idx != last_residue_idx:
last_residue_idx = residue_idx
residue = topology.addResidue("UNK", chain)

element = (
None if atomic_num < 0 else openmm.app.Element.getByAtomicNumber(atomic_num)
)
symbol = "X" if element is None else element.symbol

atom_counts_per_residue[residue_idx][atomic_num] += 1
topology.addAtom(
f"{symbol}{atom_counts_per_residue[residue_idx][atomic_num]}".ljust(3, "x"),
element,
residue,
)

_rename_residues(topology)

coords_with_v_sites = []

for particle_idx in range(system.getNumParticles()):
if particle_idx in particle_idx_to_atom_idx:
coords_i = orig_coords[particle_idx_to_atom_idx[particle_idx]]
coords_with_v_sites.append(coords_i.value_in_unit(openmm.unit.angstrom))
else:
coords_with_v_sites.append(numpy.zeros((1, 3)))

coords_with_v_sites = numpy.vstack(coords_with_v_sites) * openmm.unit.angstrom

if len(orig_coords) != len(coords_with_v_sites):
context = openmm.Context(system, openmm.VerletIntegrator(1.0))
context.setPositions(coords_with_v_sites)
context.computeVirtualSites()
coords_with_v_sites = context.getState(getPositions=True).getPositions(
asNumpy=True
)

return topology, coords_with_v_sites


def _rename_residues(topology: openmm.app.Topology):
"""Attempts to assign standard residue names to known residues"""

for residue in topology.residues():
symbols = sorted(
(
atom.element.symbol
for atom in residue.atoms()
if atom.element is not None
)
)

if symbols == ["H", "H", "O"]:
residue.name = "HOH"


def _setup_solvent(
solvent_idx: typing.Literal["solvent-a", "solvent-b"],
components: list[tuple[str, int]],
Expand All @@ -67,19 +185,25 @@ def _setup_solvent(

is_vacuum = n_solvent_molecules == 0

topology, coords = absolv.setup.setup_system(components)
topology.box_vectors = None if is_vacuum else topology.box_vectors
topology_off, coords = absolv.setup.setup_system(components)
topology_off.box_vectors = None if is_vacuum else topology_off.box_vectors

if isinstance(force_field, openff.toolkit.ForceField):
original_system = force_field.create_openmm_system(topology_off)
else:
original_system: openmm.System = force_field(topology_off, coords, solvent_idx)

topology, coords = _rebuild_topology(topology_off, coords, original_system)

atom_indices = absolv.utils.topology.topology_to_atom_indices(topology)
atom_indices = [
{atom.index for atom in residue.atoms()}
for chain in topology.chains()
for residue in chain.residues()
]

alchemical_indices = atom_indices[:n_solute_molecules]
persistent_indices = atom_indices[n_solute_molecules:]

if isinstance(force_field, openff.toolkit.ForceField):
original_system = force_field.create_openmm_system(topology)
else:
original_system: openmm.System = force_field(topology, coords, solvent_idx)

alchemical_system = absolv.fep.apply_fep(
original_system,
alchemical_indices,
Expand Down Expand Up @@ -196,7 +320,7 @@ def _run_eq_phase(
"""
platform = (
femto.md.constants.OpenMMPlatform.REFERENCE
if prepared_system.topology.box_vectors is None
if prepared_system.topology.getPeriodicBoxVectors() is None
else platform
)

Expand Down Expand Up @@ -312,7 +436,7 @@ def _run_phase_end_states(
):
platform = (
femto.md.constants.OpenMMPlatform.REFERENCE
if prepared_system.topology.box_vectors is None
if prepared_system.topology.getPeriodicBoxVectors() is None
else platform
)

Expand Down Expand Up @@ -363,11 +487,11 @@ def _run_switching(
):
platform = (
femto.md.constants.OpenMMPlatform.REFERENCE
if prepared_system.topology.box_vectors is None
if prepared_system.topology.getPeriodicBoxVectors() is None
else platform
)

mdtraj_topology = mdtraj.Topology.from_openmm(prepared_system.topology.to_openmm())
mdtraj_topology = mdtraj.Topology.from_openmm(prepared_system.topology)

trajectory_0 = mdtraj.load_dcd(str(output_dir / "state-0.dcd"), mdtraj_topology)
trajectory_1 = mdtraj.load_dcd(str(output_dir / "state-1.dcd"), mdtraj_topology)
Expand Down
125 changes: 123 additions & 2 deletions absolv/tests/test_runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import femto.md.constants
import numpy
import openff.toolkit
import openff.units
import openmm.unit
import pytest

Expand Down Expand Up @@ -58,6 +60,125 @@
)


def test_rebuild_topology():
ff = openff.toolkit.ForceField("tip4p_fb.offxml", "openff-2.0.0.offxml")

v_site_handler = ff.get_parameter_handler("VirtualSites")
v_site_handler.add_parameter(
{
"type": "DivalentLonePair",
"match": "once",
"smirks": "[*:2][#7:1][*:3]",
"distance": 0.4 * openff.units.unit.angstrom,
"epsilon": 0.0 * openff.units.unit.kilojoule_per_mole,
"sigma": 0.1 * openff.units.unit.nanometer,
"outOfPlaneAngle": 0.0 * openff.units.unit.degree,
"charge_increment1": 0.0 * openff.units.unit.elementary_charge,
"charge_increment2": 0.0 * openff.units.unit.elementary_charge,
"charge_increment3": 0.0 * openff.units.unit.elementary_charge,
}
)

solute = openff.toolkit.Molecule.from_smiles("c1ccncc1")
solute.generate_conformers(n_conformers=1)
solvent = openff.toolkit.Molecule.from_smiles("O")
solvent.generate_conformers(n_conformers=1)

orig_coords = (
numpy.vstack(
[
solute.conformers[0].m_as("angstrom"),
solvent.conformers[0].m_as("angstrom") + numpy.array([10.0, 0.0, 0.0]),
solvent.conformers[0].m_as("angstrom") + numpy.array([20.0, 0.0, 0.0]),
]
)
* openmm.unit.angstrom
)

expected_box_vectors = numpy.eye(3) * 30.0

orig_top = openff.toolkit.topology.Topology.from_molecules(
[solute, solvent, solvent]
)
orig_top.box_vectors = expected_box_vectors * openmm.unit.angstrom

system = ff.create_openmm_system(orig_top)

n_v_sites = sum(
1 for i in range(system.getNumParticles()) if system.isVirtualSite(i)
)
assert n_v_sites == 3

top, coords = absolv.runner._rebuild_topology(orig_top, orig_coords, system)

found_atoms = [
(
atom.name,
atom.element.symbol if atom.element is not None else None,
atom.residue.index,
atom.residue.name,
)
for atom in top.atoms()
]
expected_atoms = [
("C1x", "C", 0, "UNK"),
("C2x", "C", 0, "UNK"),
("C3x", "C", 0, "UNK"),
("N1x", "N", 0, "UNK"),
("C4x", "C", 0, "UNK"),
("C5x", "C", 0, "UNK"),
("H1x", "H", 0, "UNK"),
("H2x", "H", 0, "UNK"),
("H3x", "H", 0, "UNK"),
("H4x", "H", 0, "UNK"),
("H5x", "H", 0, "UNK"),
("O1x", "O", 1, "HOH"),
("H1x", "H", 1, "HOH"),
("H2x", "H", 1, "HOH"),
("O1x", "O", 2, "HOH"),
("H1x", "H", 2, "HOH"),
("H2x", "H", 2, "HOH"),
("X1x", None, 3, "UNK"),
("X1x", None, 4, "UNK"),
("X1x", None, 5, "UNK"),
]

assert found_atoms == expected_atoms

expected_coords = numpy.array(
[
[0.00241, 0.10097, -0.05663],
[-0.11673, 0.03377, -0.03801],
[-0.11801, -0.08778, 0.03043],
[-0.00041, -0.1375, 0.07759],
[0.112, -0.068, 0.05657],
[0.12301, 0.0524, -0.00964],
[4e-05, 0.19505, -0.11016],
[-0.2116, 0.07095, -0.0744],
[-0.20828, -0.14529, 0.04827],
[0.19995, -0.11721, 0.09863],
[0.21761, 0.10263, -0.02266],
[0.99992, 0.03664, 0.0],
[0.91877, -0.01835, 0.0],
[1.08131, -0.01829, 0.0],
[1.99992, 0.03664, 0.0],
[1.91877, -0.01835, 0.0],
[2.08131, -0.01829, 0.0],
[0.0011, -0.1722, 0.09743],
[0.99994, 0.02611, 0.0],
[1.99994, 0.02611, 0.0],
]
) # manually visually inspected

assert coords.shape == expected_coords.shape
assert numpy.allclose(coords, expected_coords, atol=1.0e-5)

box_vectors = top.getPeriodicBoxVectors().value_in_unit(openmm.unit.angstrom)
box_vectors = numpy.array(box_vectors)

assert numpy.allclose(box_vectors, expected_box_vectors)


def test_setup_fn():
system = absolv.config.System(
solutes={"[Na+]": 1, "[Cl-]": 1}, solvent_a=None, solvent_b={"O": 1}
Expand All @@ -70,8 +191,8 @@ def test_setup_fn():
assert prepared_system_a.system.getNumParticles() == 2
assert prepared_system_b.system.getNumParticles() == 5

assert prepared_system_a.topology.box_vectors is None
assert prepared_system_b.topology.box_vectors is not None
assert prepared_system_a.topology.getPeriodicBoxVectors() is None
assert prepared_system_b.topology.getPeriodicBoxVectors() is not None


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion absolv/tests/utils/test_openmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_create_simulation():

simulation = create_simulation(
system,
topology,
topology.to_openmm(),
expected_coords,
integrator,
femto.md.constants.OpenMMPlatform.REFERENCE,
Expand Down
21 changes: 1 addition & 20 deletions absolv/tests/utils/test_topology.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import openff.toolkit
import pytest

from absolv.utils.topology import topology_to_atom_indices, topology_to_components
from absolv.utils.topology import topology_to_components


@pytest.mark.parametrize("n_counts", [[3, 1, 2], [3, 1, 1]])
Expand All @@ -19,22 +19,3 @@ def test_topology_to_components(n_counts):
("[H][C]([H])([H])[H]", n_counts[1]),
("[H][O][H]", n_counts[2]),
]


def test_topology_to_atom_indices():
topology = openff.toolkit.Topology.from_molecules(
[openff.toolkit.Molecule.from_smiles("O")] * 1
+ [openff.toolkit.Molecule.from_smiles("C")] * 2
+ [openff.toolkit.Molecule.from_smiles("O")] * 3
)

atom_indices = topology_to_atom_indices(topology)

assert atom_indices == [
{0, 1, 2},
{3, 4, 5, 6, 7},
{8, 9, 10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
]
Loading

0 comments on commit abffa10

Please sign in to comment.