From 1ed1d6073f788fbd93a3699a6551457cc9fe8d56 Mon Sep 17 00:00:00 2001 From: "Matthew W. Thompson" Date: Mon, 4 Mar 2024 13:54:29 -0600 Subject: [PATCH] ENH: Store cosmetic attributes --- .../unit_tests/components/test_potentials.py | 25 +++++- .../_tests/unit_tests/smirnoff/test_create.py | 42 +++++++++ openff/interchange/models.py | 6 +- openff/interchange/smirnoff/_base.py | 18 +++- openff/interchange/smirnoff/_valence.py | 87 ++++++++++++++++--- 5 files changed, 162 insertions(+), 16 deletions(-) diff --git a/openff/interchange/_tests/unit_tests/components/test_potentials.py b/openff/interchange/_tests/unit_tests/components/test_potentials.py index 274e43000..3e2fd0d3e 100644 --- a/openff/interchange/_tests/unit_tests/components/test_potentials.py +++ b/openff/interchange/_tests/unit_tests/components/test_potentials.py @@ -1,12 +1,13 @@ import pytest +from openff.toolkit import Molecule, unit from openff.toolkit.typing.engines.smirnoff.parameters import BondHandler -from openff.units import unit from openff.interchange.components.potentials import ( Collection, Potential, WrappedPotential, ) +from openff.interchange.smirnoff._valence import SMIRNOFFBondCollection class TestWrappedPotential: @@ -40,6 +41,28 @@ def test_interpolated_potentials(self): assert simple.parameters == pot2.parameters +class TestCosmeticAttributes: + def test_potential_with_cosmetic_attributes(self): + cosmetic = BondHandler.BondType( + smirks="[*:1]~[*:2]", + id="cos1", + k="430.0 * kilocalories_per_mole/angstrom**2", + length="1.33 * angstrom", + foo="bar", + allow_cosmetic_attributes=True, + ) + + handler = BondHandler(version=0.4) + handler.add_parameter(parameter=cosmetic) + + collection = SMIRNOFFBondCollection() + + collection.store_matches( + parameter_handler=handler, + topology=Molecule.from_smiles("O").to_topology(), + ) + + class TestCollectionSubclassing: def test_dummy_collection(self): handler = Collection( diff --git a/openff/interchange/_tests/unit_tests/smirnoff/test_create.py b/openff/interchange/_tests/unit_tests/smirnoff/test_create.py index e57e3f5b7..f97f1fa2f 100644 --- a/openff/interchange/_tests/unit_tests/smirnoff/test_create.py +++ b/openff/interchange/_tests/unit_tests/smirnoff/test_create.py @@ -64,6 +64,48 @@ def test_infer_positions(self, sage, ethanol): 3, ) + def test_cosmetic_attributes(self): + from openff.toolkit._tests.test_forcefield import xml_ff_w_cosmetic_elements + + force_field = ForceField( + "openff-2.1.0.offxml", + xml_ff_w_cosmetic_elements.replace( + 'Bonds version="0.3"', + 'Bonds version="0.4"', + ), + allow_cosmetic_attributes=True, + ) + + bonds = SMIRNOFFBondCollection() + + bonds.store_matches( + parameter_handler=force_field["Bonds"], + topology=Molecule.from_smiles("CC").to_topology(), + ) + + for key in bonds.potentials: + if key.id == "[#6X4:1]-[#6X4:2]": + assert key.cosmetic_attributes == { + "parameters": "k, length", + "parameterize_eval": "blah=blah2", + } + + def test_all_cosmetic(self, sage, basic_top): + for handler in sage.registered_parameter_handlers: + for parameter in sage[handler].parameters: + parameter._cosmetic_attribs = ["fOO"] + parameter._fOO = "bAR" + parameter.fOO = "bAR" + + out = sage.create_interchange(basic_top) + + for collection in out.collections: + if collection == "Electrostatics": + continue + + for potential_key in out[collection].potentials: + assert potential_key.cosmetic_attributes["fOO"] == "bAR" + @pytest.mark.slow() class TestUnassignedParameters: diff --git a/openff/interchange/models.py b/openff/interchange/models.py index e1df0cf78..16438c37f 100644 --- a/openff/interchange/models.py +++ b/openff/interchange/models.py @@ -1,7 +1,7 @@ """Custom Pydantic models.""" import abc -from typing import Literal, Optional +from typing import Any, Literal, Optional from openff.models.models import DefaultModel @@ -289,6 +289,10 @@ class PotentialKey(DefaultModel): None, description="The 'type' of virtual site (i.e. `BondCharge`) this parameter is associated with.", ) + cosmetic_attributes: dict[str, Any] = Field( + dict(), + description="A dictionary of cosmetic attributes associated with this potential key.", + ) def __hash__(self) -> int: return hash((self.id, self.mult, self.associated_handler, self.bond_order)) diff --git a/openff/interchange/smirnoff/_base.py b/openff/interchange/smirnoff/_base.py index bb72ce3e8..28966cf5f 100644 --- a/openff/interchange/smirnoff/_base.py +++ b/openff/interchange/smirnoff/_base.py @@ -78,6 +78,7 @@ def collection_loader(data: str) -> dict: else: topology_key = LibraryChargeTopologyKey.parse_raw(key_) + # TODO: Not obvious if cosmetic attributes survive here potential_key = PotentialKey(**val_) key_map[topology_key] = potential_key @@ -247,13 +248,28 @@ def store_matches( Union[TopologyKey, LibraryChargeTopologyKey], PotentialKey, ] = dict() + matches = parameter_handler.find_matches(topology) + for key, val in matches.items(): + parameter: ParameterHandler.ParameterType = val.parameter_type + + cosmetic_attributes = { + cosmetic_attribute: getattr( + parameter, + f"_{cosmetic_attribute}", + ) + for cosmetic_attribute in parameter._cosmetic_attribs + } + topology_key = TopologyKey(atom_indices=key) + potential_key = PotentialKey( - id=val.parameter_type.smirks, + id=parameter.smirks, associated_handler=parameter_handler.TAGNAME, + cosmetic_attributes=cosmetic_attributes, ) + self.key_map[topology_key] = potential_key if self.__class__.__name__ in [ diff --git a/openff/interchange/smirnoff/_valence.py b/openff/interchange/smirnoff/_valence.py index 73a6ea776..f33952aa4 100644 --- a/openff/interchange/smirnoff/_valence.py +++ b/openff/interchange/smirnoff/_valence.py @@ -141,8 +141,17 @@ def store_matches( self.key_map: dict[BondKey, PotentialKey] = dict() # type: ignore[assignment] matches = parameter_handler.find_matches(topology) for key, val in matches.items(): - param = val.parameter_type - if param.k_bondorder or param.length_bondorder: + parameter: BondHandler.BondType = val.parameter_type + + cosmetic_attributes = { + cosmetic_attribute: getattr( + parameter, + f"_{cosmetic_attribute}", + ) + for cosmetic_attribute in parameter._cosmetic_attribs + } + + if parameter.k_bondorder or parameter.length_bondorder: bond = topology.get_bond_between(*key) fractional_bond_order = bond.fractional_bond_order if not fractional_bond_order: @@ -158,10 +167,12 @@ def store_matches( ) potential_key = PotentialKey( - id=val.parameter_type.smirks, + id=parameter.smirks, associated_handler=parameter_handler.TAGNAME, bond_order=fractional_bond_order, + cosmetic_attributes=cosmetic_attributes, ) + self.key_map[topology_key] = potential_key valence_terms = self.valence_terms(topology) @@ -359,16 +370,31 @@ def store_constraints( for key, match in constraint_matches.items(): topology_key = BondKey(atom_indices=key) - smirks = match.parameter_type.smirks - distance = match.parameter_type.distance + + parameter = match.parameter_type + + smirks = parameter.smirks + distance = parameter.distance + cosmetic_attributes = { + cosmetic_attribute: getattr( + parameter, + f"_{cosmetic_attribute}", + ) + for cosmetic_attribute in parameter._cosmetic_attribs + } + if distance is not None: # This constraint parameter is fully specified potential_key = PotentialKey( id=smirks, associated_handler="Constraints", + cosmetic_attributes=cosmetic_attributes, ) + self.key_map[topology_key] = potential_key - distance = match.parameter_type.distance + + distance = parameter.distance + else: # This constraint parameter depends on the BondHandler ... if bond_handler is None: @@ -377,15 +403,20 @@ def store_constraints( "specified, and no corresponding bond parameters were found. The distance " "of this constraint is not specified.", ) + # ... so use the same PotentialKey instance as the BondHandler to look up the distance potential_key = bonds.key_map[topology_key] # type: ignore[union-attr] + self.key_map[topology_key] = potential_key + distance = bonds.potentials[potential_key].parameters["length"] # type: ignore[union-attr] + potential = Potential( parameters={ "distance": distance, }, ) + self.potentials[potential_key] = potential @@ -470,11 +501,22 @@ def store_matches( self.key_map: dict[ProperTorsionKey, PotentialKey] = dict() # type: ignore[assignment] matches = parameter_handler.find_matches(topology) for key, val in matches.items(): - param = val.parameter_type - n_terms = len(val.parameter_type.phase) + parameter: ProperTorsionHandler.ProperTorsionType = val.parameter_type + + n_terms = len(parameter.phase) + + cosmetic_attributes = { + cosmetic_attribute: getattr( + parameter, + f"_{cosmetic_attribute}", + ) + for cosmetic_attribute in parameter._cosmetic_attribs + } + for n in range(n_terms): - smirks = param.smirks - if param.k_bondorder: + smirks = parameter.smirks + + if parameter.k_bondorder: # The relevant bond order is that of the _central_ bond in the torsion bond = topology.get_bond_between(key[1], key[2]) fractional_bond_order = bond.fractional_bond_order @@ -484,17 +526,21 @@ def store_matches( ) else: fractional_bond_order = None + topology_key = ProperTorsionKey( atom_indices=key, mult=n, bond_order=fractional_bond_order, ) + potential_key = PotentialKey( id=smirks, mult=n, associated_handler="ProperTorsions", bond_order=fractional_bond_order, + cosmetic_attributes=cosmetic_attributes, ) + self.key_map[topology_key] = potential_key _check_all_valence_terms_assigned( @@ -513,7 +559,7 @@ def store_potentials(self, parameter_handler: ProperTorsionHandler) -> None: smirks = potential_key.id n = potential_key.mult parameter = parameter_handler.parameters[smirks] - # n_terms = len(parameter.k) + if topology_key.bond_order: bond_order = topology_key.bond_order data = parameter.k_bondorder[n] @@ -634,9 +680,22 @@ def store_matches( (1, 3), ], ) - n_terms = len(val.parameter_type.k) + + parameter: ImproperTorsionHandler.ImproperTorsionType = val.parameter_type + + n_terms = len(parameter.phase) + + cosmetic_attributes = { + cosmetic_attribute: getattr( + parameter, + f"_{cosmetic_attribute}", + ) + for cosmetic_attribute in parameter._cosmetic_attribs + } + for n in range(n_terms): - smirks = val.parameter_type.smirks + smirks = parameter.smirks + non_central_indices = [key[0], key[2], key[3]] for permuted_key in [ @@ -655,7 +714,9 @@ def store_matches( id=smirks, mult=n, associated_handler="ImproperTorsions", + cosmetic_attributes=cosmetic_attributes, ) + self.key_map[topology_key] = potential_key def store_potentials(self, parameter_handler: ImproperTorsionHandler) -> None: