Add QoL helpers for building 'closures' #63

Merged: 11 commits, May 5, 2024
4 changes: 2 additions & 2 deletions descent/optim/__init__.py
@@ -1,5 +1,5 @@
"""Custom parameter optimizers."""

from descent.optim._lm import LevenbergMarquardtConfig, levenberg_marquardt
from descent.optim._lm import ClosureFn, LevenbergMarquardtConfig, levenberg_marquardt

__all__ = ["LevenbergMarquardtConfig", "levenberg_marquardt"]
__all__ = ["ClosureFn", "LevenbergMarquardtConfig", "levenberg_marquardt"]
4 changes: 4 additions & 0 deletions descent/targets/__init__.py
@@ -1 +1,5 @@
"""Targets to train / assess models to / against."""

from descent.targets._targets import combine_closures

__all__ = ["combine_closures"]
79 changes: 79 additions & 0 deletions descent/targets/_targets.py
@@ -0,0 +1,79 @@
import logging

import torch

import descent.optim

_LOGGER = logging.getLogger(__name__)


def combine_closures(
closures: dict[str, descent.optim.ClosureFn],
weights: dict[str, float] | None = None,
verbose: bool = False,
) -> descent.optim.ClosureFn:
"""Combine multiple closures into a single closure.

Args:
closures: A dictionary of closure functions.
weights: Optional dictionary of weights for each closure function.
verbose: Whether to log the loss of each closure function.

Returns:
A combined closure function.
"""

weights = weights if weights is not None else {name: 1.0 for name in closures}

if len(closures) == 0:
raise NotImplementedError("At least one closure function is required.")

if {*closures} != {*weights}:
raise ValueError("The closures and weights must have the same keys.")

def combined_closure_fn(
x: torch.Tensor, compute_gradient: bool, compute_hessian: bool
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:

loss = []
grad = None if not compute_gradient else []
hess = None if not compute_hessian else []

verbose_rows = []

for name, closure_fn in closures.items():

local_loss, local_grad, local_hess = closure_fn(
x, compute_gradient, compute_hessian
)

loss.append(weights[name] * local_loss)

if compute_gradient:
grad.append(weights[name] * local_grad)
if compute_hessian:
hess.append(weights[name] * local_hess)

if verbose:
verbose_rows.append(
{"target": name, "loss": float(f"{local_loss:.5f}")}
)

loss = sum(loss[1:], loss[0])

if compute_gradient:
grad = sum(grad[1:], grad[0])
if compute_hessian:
hess = sum(hess[1:], hess[0])

if verbose:
import pandas

_LOGGER.info(
"loss breakdown:\n"
+ pandas.DataFrame(verbose_rows).to_string(index=False)
)

return loss.detach(), grad, hess

return combined_closure_fn
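
For illustration (not part of the diff), a minimal sketch of how `combine_closures` might be used. Here `dimers_closure`, `thermo_closure`, and `x` (the flat parameter tensor) are hypothetical stand-ins, e.g. closures built with the `default_closure` helpers added below:

import descent.targets

# hypothetical closures, e.g. built with the default_closure helpers below
closures = {"dimers": dimers_closure, "thermo": thermo_closure}

closure = descent.targets.combine_closures(
    closures, weights={"dimers": 1.0, "thermo": 100.0}, verbose=True
)

# the combined closure has the same (x, compute_gradient, compute_hessian)
# signature as the closures it wraps
loss, grad, hess = closure(x, compute_gradient=True, compute_hessian=True)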
28 changes: 28 additions & 0 deletions descent/targets/dimers.py
@@ -11,7 +11,9 @@
 import torch
 import tqdm

+import descent.train
 import descent.utils.dataset
+import descent.utils.loss
 import descent.utils.molecule
 import descent.utils.reporting

@@ -277,6 +279,32 @@ def predict(
return torch.cat(reference), torch.cat(predicted)


def default_closure(
trainable: descent.train.Trainable,
topologies: dict[str, smee.TensorTopology],
dataset: datasets.Dataset,
):
"""Return a default closure function for training against dimer energies.

Args:
trainable: The wrapper around trainable parameters.
topologies: The topologies of the molecules present in the dataset, with keys
of mapped SMILES patterns.
dataset: The dataset to train against.

Returns:
The default closure function.
"""

def loss_fn(_x: torch.Tensor) -> torch.Tensor:
y_ref, y_pred = descent.targets.dimers.predict(
dataset, trainable.to_force_field(_x), topologies
)
return ((y_pred - y_ref) ** 2).sum()

return descent.utils.loss.to_closure(loss_fn)
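
A hedged sketch (not part of the diff) of wiring this helper into the optimizer. `trainable`, `topologies`, `dataset`, and `x0` are assumed to already exist, and the argument order of `levenberg_marquardt` is an assumption rather than something this diff confirms:

import descent.optim
import descent.targets.dimers

closure_fn = descent.targets.dimers.default_closure(trainable, topologies, dataset)

# x0 is assumed to be the flat tensor of trainable parameter values exposed
# by the descent.train.Trainable wrapper
x_final = descent.optim.levenberg_marquardt(
    x0, descent.optim.LevenbergMarquardtConfig(), closure_fn
)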


def _plot_energies(energies: dict[str, torch.Tensor]) -> str:
from matplotlib import pyplot

98 changes: 77 additions & 21 deletions descent/targets/thermo.py
@@ -16,9 +16,12 @@
 import smee.mm
 import smee.utils
 import torch
-from rdkit import Chem

+import descent.optim
+import descent.train
 import descent.utils.dataset
+import descent.utils.loss
+import descent.utils.molecule

_LOGGER = logging.getLogger(__name__)

@@ -138,24 +141,6 @@ class _Observables(typing.NamedTuple):
_SystemDict = dict[SimulationKey, smee.TensorSystem]


-def _map_smiles(smiles: str) -> str:
-    """Add atom mapping to a SMILES string if it is not already present."""
-    params = Chem.SmilesParserParams()
-    params.removeHs = False
-
-    mol = Chem.AddHs(Chem.MolFromSmiles(smiles, params))
-
-    map_idxs = sorted(atom.GetAtomMapNum() for atom in mol.GetAtoms())
-
-    if map_idxs == list(range(1, len(map_idxs) + 1)):
-        return smiles
-
-    for i, atom in enumerate(mol.GetAtoms()):
-        atom.SetAtomMapNum(i + 1)
-
-    return Chem.MolToSmiles(mol)


def create_dataset(*rows: DataEntry) -> datasets.Dataset:
"""Create a dataset from a list of existing data points.

@@ -167,12 +152,12 @@ def create_dataset(*rows: DataEntry) -> datasets.Dataset:
"""

     for row in rows:
-        row["smiles_a"] = _map_smiles(row["smiles_a"])
+        row["smiles_a"] = descent.utils.molecule.map_smiles(row["smiles_a"])

         if row["smiles_b"] is None:
             continue

-        row["smiles_b"] = _map_smiles(row["smiles_b"])
+        row["smiles_b"] = descent.utils.molecule.map_smiles(row["smiles_b"])

# TODO: validate rows
table = pyarrow.Table.from_pylist([*rows], schema=DATA_SCHEMA)
@@ -582,6 +567,7 @@ def predict(
     output_dir: pathlib.Path,
     cached_dir: pathlib.Path | None = None,
     per_type_scales: dict[DataType, float] | None = None,
+    verbose: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Predict the properties in a dataset using molecular simulation, or by reweighting
previous simulation data.
@@ -596,6 +582,7 @@
             from.
         per_type_scales: The scale factor to apply to each data type. A default of 1.0
             will be used for any data type not specified.
+        verbose: Whether to log additional information.
     """

entries: list[DataEntry] = [*descent.utils.dataset.iter_dataset(dataset)]
@@ -616,6 +603,8 @@
     reference = []
     reference_std = []

+    verbose_rows = []
+
     per_type_scales = per_type_scales if per_type_scales is not None else {}

for entry, keys in zip(entries, entry_to_simulation):
@@ -631,10 +620,77 @@
torch.nan if entry["std"] is None else entry["std"] * abs(type_scale)
)

        if verbose:
            std_ref = "" if entry["std"] is None else " ± {float(entry['std']):.3f}"

            verbose_rows.append(
                {
                    "type": f'{entry["type"]} [{entry["units"]}]',
                    "smiles_a": descent.utils.molecule.unmap_smiles(entry["smiles_a"]),
                    "smiles_b": (
                        ""
                        if entry["smiles_b"] is None
                        else descent.utils.molecule.unmap_smiles(entry["smiles_b"])
                    ),
                    "pred": f"{float(value):.3f} ± {float(std):.3f}",
                    "ref": f"{float(entry['value']):.3f}{std_ref}",
                }
            )

Review comment (Collaborator), on the `std_ref` line above. Suggested change:
-            std_ref = "" if entry["std"] is None else " ± {float(entry['std']):.3f}"
+            std_ref = "" if entry["std"] is None else f" ± {float(entry['std']):.3f}"

Reply (Owner, Author): good catch, should be fixed in the latest commit!

Review comment (jthorton, Collaborator, Mar 20, 2024), on the `ref` line above: if entry["std"]
is not None this prints the literal string " ± {float(entry['std']):.3f}", since the f-string
prefix is missing.

if verbose:
import pandas

_LOGGER.info(f"predicted {len(entries)} properties")
_LOGGER.info("\n" + pandas.DataFrame(verbose_rows).to_string(index=False))

predicted = torch.stack(predicted)
predicted_std = torch.stack(predicted_std)

reference = smee.utils.tensor_like(reference, predicted)
reference_std = smee.utils.tensor_like(reference_std, predicted_std)

return reference, reference_std, predicted, predicted_std


def default_closure(
trainable: descent.train.Trainable,
topologies: dict[str, smee.TensorTopology],
dataset: datasets.Dataset,
scales: dict[DataType, float],
verbose: bool = False,
) -> descent.optim.ClosureFn:
"""Return a default closure function for training against thermodynamic
properties.

Args:
trainable: The wrapper around trainable parameters.
topologies: The topologies of the molecules present in the dataset, with keys
of mapped SMILES patterns.
dataset: The dataset to train against.
scales: The scale factor to apply to each data type.
verbose: Whether to log additional information about predictions.

Returns:
The default closure function.
"""

def closure_fn(
x: torch.Tensor,
compute_gradient: bool,
compute_hessian: bool,
):
force_field = trainable.to_force_field(x)

y_ref, y_ref_std, y_pred, y_pred_std = descent.targets.thermo.predict(
dataset, force_field, topologies, pathlib.Path.cwd(), None, scales, verbose
)
loss, gradient, hessian = ((y_pred - y_ref) ** 2).sum(), None, None

if compute_hessian:
hessian = descent.utils.loss.approximate_hessian(x, y_pred)
if compute_gradient:
gradient = torch.autograd.grad(loss, x, retain_graph=True)[0].detach()

return loss.detach(), gradient, hessian

return closure_fn
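
Putting the pieces together, a hedged end-to-end sketch (not part of the diff). `trainable`, `topologies`, `dimer_data`, and `thermo_data` are assumed to already exist, and the "density" data type and its scale factor are illustrative:

import descent.targets
import descent.targets.dimers
import descent.targets.thermo

closure = descent.targets.combine_closures(
    {
        "dimers": descent.targets.dimers.default_closure(
            trainable, topologies, dimer_data
        ),
        "thermo": descent.targets.thermo.default_closure(
            trainable, topologies, thermo_data, scales={"density": 100.0}, verbose=True
        ),
    },
    weights={"dimers": 1.0, "thermo": 1.0},
)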
22 changes: 21 additions & 1 deletion descent/tests/utils/test_molecule.py
@@ -1,7 +1,7 @@
import pytest
from rdkit import Chem

-from descent.utils.molecule import mol_to_smiles
+from descent.utils.molecule import map_smiles, mol_to_smiles, unmap_smiles


@pytest.mark.parametrize(
@@ -16,3 +16,23 @@ def test_mol_to_smiles(input_smiles, expected_smiles, canonical):
actual_smiles = mol_to_smiles(mol, canonical)

assert actual_smiles == expected_smiles


def test_unmap_smiles():
smiles = "[H:1][C:4]([H:2])([O:3][H:5])[H:6]"
unmapped_smiles = unmap_smiles(smiles)

assert unmapped_smiles == "CO"


@pytest.mark.parametrize(
"input_smiles, expected_smiles",
[
("[H:1][C:4]([H:2])([O:3][H:5])[H:6]", "[H:1][C:4]([H:2])([O:3][H:5])[H:6]"),
("CO", "[C:1]([O:2][H:6])([H:3])([H:4])[H:5]"),
],
)
def test_map_smiles(input_smiles, expected_smiles):
mapped_smiles = map_smiles(input_smiles)

assert mapped_smiles == expected_smiles