Add QoL helpers for building 'closures' (#63)
SimonBoothroyd authored May 5, 2024
1 parent cd94f82 commit e5a7f1c
Showing 12 changed files with 382 additions and 55 deletions.
9 changes: 4 additions & 5 deletions .github/workflows/docs.yaml
@@ -2,10 +2,8 @@ name: Publish Documentation

on:
push:
branches:
- main
tags:
- '*'
branches: ["main"]
tags: ["*"]

jobs:
deploy-docs:
@@ -43,11 +41,12 @@ jobs:
git config --global --add safe.directory "$PWD"
git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }}
git pull origin gh-pages
git fetch --all --prune
make env
sed -i 's/# extensions/extensions/' mkdocs.yml
make docs-insiders INSIDER_DOCS_TOKEN="${INSIDER_DOCS_TOKEN}"
make docs-deploy VERSION="$VERSION"
4 changes: 2 additions & 2 deletions descent/optim/__init__.py
@@ -1,5 +1,5 @@
"""Custom parameter optimizers."""

from descent.optim._lm import LevenbergMarquardtConfig, levenberg_marquardt
from descent.optim._lm import ClosureFn, LevenbergMarquardtConfig, levenberg_marquardt

__all__ = ["LevenbergMarquardtConfig", "levenberg_marquardt"]
__all__ = ["ClosureFn", "LevenbergMarquardtConfig", "levenberg_marquardt"]
29 changes: 29 additions & 0 deletions descent/targets/dimers.py
@@ -12,12 +12,15 @@
import tqdm

import descent.utils.dataset
import descent.utils.loss
import descent.utils.molecule
import descent.utils.reporting

if typing.TYPE_CHECKING:
import pandas

import descent.train


EnergyFn = typing.Callable[
["pandas.DataFrame", tuple[str, ...], torch.Tensor], torch.Tensor
@@ -277,6 +280,32 @@ def predict(
return torch.cat(reference), torch.cat(predicted)


def default_closure(
trainable: "descent.train.Trainable",
topologies: dict[str, smee.TensorTopology],
dataset: datasets.Dataset,
):
"""Return a default closure function for training against dimer energies.
Args:
trainable: The wrapper around trainable parameters.
topologies: The topologies of the molecules present in the dataset, with keys
of mapped SMILES patterns.
dataset: The dataset to train against.
Returns:
The default closure function.
"""

def loss_fn(_x: torch.Tensor) -> torch.Tensor:
y_ref, y_pred = descent.targets.dimers.predict(
dataset, trainable.to_force_field(_x), topologies
)
return ((y_pred - y_ref) ** 2).sum()

return descent.utils.loss.to_closure(loss_fn)


def _plot_energies(energies: dict[str, torch.Tensor]) -> str:
from matplotlib import pyplot

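The `default_closure` helper above builds a plain scalar loss function of the parameters and hands it to `descent.utils.loss.to_closure`, which turns it into a closure of the shape sketched earlier. Below is a hedged, self-contained sketch of what such a wrapper might do; this is assumed behaviour for illustration, not the body of `descent.utils.loss.to_closure`.

import torch


def to_closure_sketch(loss_fn):
    # Wrap a scalar loss of the flat parameter tensor into the
    # (loss, gradient, hessian) closure shape; assumed behaviour only.
    def closure_fn(x, compute_gradient, compute_hessian):
        loss = loss_fn(x)
        gradient, hessian = None, None
        if compute_gradient:
            (gradient,) = torch.autograd.grad(loss, x, retain_graph=True)
            gradient = gradient.detach()
        if compute_hessian:
            hessian = torch.autograd.functional.hessian(loss_fn, x)
        return loss.detach(), gradient, hessian

    return closure_fn


x = torch.tensor([1.0, 2.0], requires_grad=True)
closure = to_closure_sketch(lambda v: (v**2).sum())
loss, grad, hess = closure(x, compute_gradient=True, compute_hessian=True)  # grad: [2., 4.]
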
107 changes: 86 additions & 21 deletions descent/targets/thermo.py
@@ -16,9 +16,15 @@
import smee.mm
import smee.utils
import torch
from rdkit import Chem

import descent.optim
import descent.utils.dataset
import descent.utils.loss
import descent.utils.molecule

if typing.TYPE_CHECKING:
import descent.train


_LOGGER = logging.getLogger(__name__)

@@ -138,24 +144,6 @@ class _Observables(typing.NamedTuple):
_SystemDict = dict[SimulationKey, smee.TensorSystem]


def _map_smiles(smiles: str) -> str:
"""Add atom mapping to a SMILES string if it is not already present."""
params = Chem.SmilesParserParams()
params.removeHs = False

mol = Chem.AddHs(Chem.MolFromSmiles(smiles, params))

map_idxs = sorted(atom.GetAtomMapNum() for atom in mol.GetAtoms())

if map_idxs == list(range(1, len(map_idxs) + 1)):
return smiles

for i, atom in enumerate(mol.GetAtoms()):
atom.SetAtomMapNum(i + 1)

return Chem.MolToSmiles(mol)


def create_dataset(*rows: DataEntry) -> datasets.Dataset:
"""Create a dataset from a list of existing data points.
@@ -167,12 +155,12 @@ def create_dataset(*rows: DataEntry) -> datasets.Dataset:
"""

for row in rows:
row["smiles_a"] = _map_smiles(row["smiles_a"])
row["smiles_a"] = descent.utils.molecule.map_smiles(row["smiles_a"])

if row["smiles_b"] is None:
continue

row["smiles_b"] = _map_smiles(row["smiles_b"])
row["smiles_b"] = descent.utils.molecule.map_smiles(row["smiles_b"])

# TODO: validate rows
table = pyarrow.Table.from_pylist([*rows], schema=DATA_SCHEMA)
@@ -582,6 +570,7 @@ def predict(
output_dir: pathlib.Path,
cached_dir: pathlib.Path | None = None,
per_type_scales: dict[DataType, float] | None = None,
verbose: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Predict the properties in a dataset using molecular simulation, or by reweighting
previous simulation data.
@@ -596,6 +585,7 @@
from.
per_type_scales: The scale factor to apply to each data type. A default of 1.0
will be used for any data type not specified.
verbose: Whether to log additional information.
"""

entries: list[DataEntry] = [*descent.utils.dataset.iter_dataset(dataset)]
@@ -616,6 +606,8 @@
reference = []
reference_std = []

verbose_rows = []

per_type_scales = per_type_scales if per_type_scales is not None else {}

for entry, keys in zip(entries, entry_to_simulation):
@@ -631,10 +623,83 @@
torch.nan if entry["std"] is None else entry["std"] * abs(type_scale)
)

if verbose:
std_ref = "" if entry["std"] is None else f" ± {float(entry['std']):.3f}"

verbose_rows.append(
{
"type": f'{entry["type"]} [{entry["units"]}]',
"smiles_a": descent.utils.molecule.unmap_smiles(entry["smiles_a"]),
"smiles_b": (
""
if entry["smiles_b"] is None
else descent.utils.molecule.unmap_smiles(entry["smiles_b"])
),
"pred": f"{float(value):.3f} ± {float(std):.3f}",
"ref": f"{float(entry['value']):.3f}{std_ref}",
}
)

if verbose:
import pandas

_LOGGER.info(f"predicted {len(entries)} properties")
_LOGGER.info("\n" + pandas.DataFrame(verbose_rows).to_string(index=False))

predicted = torch.stack(predicted)
predicted_std = torch.stack(predicted_std)

reference = smee.utils.tensor_like(reference, predicted)
reference_std = smee.utils.tensor_like(reference_std, predicted_std)

return reference, reference_std, predicted, predicted_std


def default_closure(
trainable: "descent.train.Trainable",
topologies: dict[str, smee.TensorTopology],
dataset: datasets.Dataset,
per_type_scales: dict[DataType, float] | None = None,
verbose: bool = False,
) -> descent.optim.ClosureFn:
"""Return a default closure function for training against thermodynamic
properties.
Args:
trainable: The wrapper around trainable parameters.
topologies: The topologies of the molecules present in the dataset, with keys
of mapped SMILES patterns.
dataset: The dataset to train against.
per_type_scales: The scale factor to apply to each data type.
verbose: Whether to log additional information about predictions.
Returns:
The default closure function.
"""

def closure_fn(
x: torch.Tensor,
compute_gradient: bool,
compute_hessian: bool,
):
force_field = trainable.to_force_field(x)

y_ref, _, y_pred, _ = descent.targets.thermo.predict(
dataset,
force_field,
topologies,
pathlib.Path.cwd(),
None,
per_type_scales,
verbose,
)
loss, gradient, hessian = ((y_pred - y_ref) ** 2).sum(), None, None

if compute_hessian:
hessian = descent.utils.loss.approximate_hessian(x, y_pred)
if compute_gradient:
gradient = torch.autograd.grad(loss, x, retain_graph=True)[0].detach()

return loss.detach(), gradient, hessian

return closure_fn
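
The thermo closure above calls `descent.utils.loss.approximate_hessian(x, y_pred)` rather than differentiating twice through the simulations. A plausible reading, stated here as an assumption rather than the library's actual implementation, is a Gauss-Newton approximation H ≈ 2 JᵀJ built from the Jacobian of the predictions with respect to the parameters.

import torch


def approximate_hessian_sketch(x: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    # Gauss-Newton style approximation for a sum-of-squares loss:
    # H ≈ 2 JᵀJ with J = d y_pred / d x (second-order terms dropped).
    jacobian = torch.stack(
        [torch.autograd.grad(y, x, retain_graph=True)[0] for y in y_pred]
    )
    return 2.0 * jacobian.T @ jacobian


x = torch.tensor([2.0], requires_grad=True)
y_pred = torch.tensor([3.0, 4.0]) * x  # mirrors the mock used in the tests below
hessian = approximate_hessian_sketch(x, y_pred)  # shape (1, 1), matching the test assertion
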
31 changes: 31 additions & 0 deletions descent/tests/targets/test_dimers.py
@@ -11,6 +11,7 @@
compute_dimer_energy,
create_dataset,
create_from_des,
default_closure,
extract_smiles,
predict,
report,
@@ -184,6 +185,36 @@ def test_predict(mock_dimer, mocker):
)


def test_default_closure(mock_dimer, mocker):
dataset = create_dataset([mock_dimer])

mock_x = torch.tensor([2.0], requires_grad=True)

mock_y_pred = torch.tensor([3.0, 4.0]) * mock_x
mock_y_ref = torch.Tensor([-1.23, 4.56])

mocker.patch(
"descent.targets.dimers.predict",
autospec=True,
return_value=(mock_y_ref, mock_y_pred),
)
mock_topologies = {
mock_dimer["smiles_a"]: mocker.MagicMock(),
mock_dimer["smiles_b"]: mocker.MagicMock(),
}
mock_trainable = mocker.MagicMock()

closure_fn = default_closure(mock_trainable, mock_topologies, dataset)

expected_loss = (mock_y_pred - mock_y_ref).pow(2).sum()

loss, grad, hess = closure_fn(mock_x, compute_gradient=True, compute_hessian=True)

assert torch.isclose(loss, expected_loss)
assert grad.shape == mock_x.shape
assert hess.shape == (1, 1)


def test_report(tmp_cwd, mock_dimer, mocker):
dataset = create_dataset([mock_dimer])

52 changes: 35 additions & 17 deletions descent/tests/targets/test_thermo.py
@@ -11,12 +11,12 @@
SimulationKey,
_compute_observables,
_convert_entry_to_system,
_map_smiles,
_Observables,
_plan_simulations,
_predict,
_simulate,
create_dataset,
default_closure,
default_config,
extract_smiles,
predict,
@@ -91,21 +91,6 @@ def mock_hmix() -> DataEntry:
}


@pytest.mark.parametrize(
"smiles, expected",
[
("C", "[C:1]([H:2])([H:3])([H:4])[H:5]"),
("[CH4:1]", "[C:1]([H:2])([H:3])([H:4])[H:5]"),
("[Cl:1][H:2]", "[Cl:1][H:2]"),
("[Cl:2][H:1]", "[Cl:2][H:1]"),
("[Cl:2][H:2]", "[Cl:1][H:2]"),
("[Cl:1][H]", "[Cl:1][H:2]"),
],
)
def test_map_smiles(smiles, expected):
assert _map_smiles(smiles) == expected


def test_create_dataset(mock_density_pure, mock_density_binary):
expected_entries = [mock_density_pure, mock_density_binary]

@@ -533,7 +518,13 @@ def test_predict(tmp_cwd, mock_density_pure, mocker):
mock_scale = 3.0

y_ref, y_ref_std, y_pred, y_pred_std = predict(
dataset, mock_ff, mock_topologies, tmp_cwd, None, {"density": mock_scale}
dataset,
mock_ff,
mock_topologies,
tmp_cwd,
None,
{"density": mock_scale},
verbose=True,
)

mock_compute.assert_called_once_with(
@@ -565,3 +556,30 @@
assert torch.allclose(y_pred, expected_y_pred)
assert y_pred_std.shape == expected_y_pred_std.shape
assert torch.allclose(y_pred_std, expected_y_pred_std)


def test_default_closure(tmp_cwd, mock_density_pure, mocker):
dataset = create_dataset(mock_density_pure)

mock_x = torch.tensor([2.0], requires_grad=True)

mock_y_pred = torch.tensor([3.0, 4.0]) * mock_x
mock_y_ref = torch.Tensor([-1.23, 4.56])

mocker.patch(
"descent.targets.thermo.predict",
autospec=True,
return_value=(mock_y_ref, None, mock_y_pred, None),
)
mock_topologies = {mock_density_pure["smiles_a"]: mocker.MagicMock()}
mock_trainable = mocker.MagicMock()

closure_fn = default_closure(mock_trainable, mock_topologies, dataset, None)

expected_loss = (mock_y_pred - mock_y_ref).pow(2).sum()

loss, grad, hess = closure_fn(mock_x, compute_gradient=True, compute_hessian=True)

assert torch.isclose(loss, expected_loss)
assert grad.shape == mock_x.shape
assert hess.shape == (1, 1)