Skip to content

Commit

Permalink
xtb/ase wrappers (#36)
Browse files Browse the repository at this point in the history
* add xtb wrapper

bugfix in xtb

bugfix in xtb

bugfix xtb

bugfix in xtb

...

...

...

...

...

...

...

xtb fixes

* XTB wrapper docs

* skip tests when not installed

* bugfix in xtb

* add xtb to test env

* unit conversion in xtb

* unit fixes in xtb

* err_handling in xtb

* ...

* try with higher temperature

* bugfix in xtb

* bohr as unit in xtb

* ...

* ase wrapper

* ase tests work

* better docs

* add ase to conda env

* added caching

* remove duplicate code

* fix warnings

* skip ase/xtb tests if not installed

* added forgotten super() call to ASE Bridge

* add optional deps to README

* unify interfaces to remove duplicate code

* remove unused import
  • Loading branch information
Olllom authored Feb 15, 2022
1 parent fcc50de commit ae1dc6a
Show file tree
Hide file tree
Showing 10 changed files with 537 additions and 60 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,16 @@ pytest
* [matplotlib](https://github.com/matplotlib/matplotlib)
* [pytest](https://github.com/pytest-dev/pytest) (for testing)
* [nflows](https://github.com/bayesiains/nflows) (for Neural Spline Flows)
* [OpenMM](https://github.com/openmm/openmm) (for molecular examples)
* [torchdiffeq](https://github.com/rtqichen/torchdiffeq) (for neural ODEs)
* [ANODE](https://github.com/amirgholami/anode) (for neural ODEs)
* [OpenMM](https://github.com/openmm/openmm) (for molecular mechanics energies)
* [ase](https://wiki.fysik.dtu.dk/ase/index.html) (for quantum and molecular mechanics energies through the atomic simulation environment)
* [xtb-python](https://xtb-python.readthedocs.io) (for semi-empirical GFN quantum energies)
* [netCDF4](https://unidata.github.io/netcdf4-python/) (for the `ReplayBufferReporter`)
* [jax](https://github.com/google/jax) (for smooth flows / implicit backprop)
* [jax2torch](https://github.com/lucidrains/jax2torch) (for smooth flows / implicit backprop)


***
## [License](#dependencies)
[MIT License](LICENSE)
3 changes: 2 additions & 1 deletion bgflow/distribution/energy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,5 @@
from .multi_double_well_potential import *
from .linlogcut import *
from .openmm import *

from .xtb import *
from .ase import *
98 changes: 98 additions & 0 deletions bgflow/distribution/energy/ase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Wrapper around ASE (atomic simulation environment)
"""
__all__ = ["ASEBridge", "ASEEnergy"]


import warnings
import torch
import numpy as np
from .base import _BridgeEnergy, _Bridge


class ASEBridge(_Bridge):
"""Wrapper around Atomic Simulation Environment.
Parameters
----------
atoms : ase.Atoms
An `Atoms` object that has a calculator attached to it.
temperature : float
Temperature in Kelvin.
err_handling : str
How to deal with exceptions inside ase. One of `["ignore", "warning", "error"]`
Notes
-----
Requires the ase package (installable with `conda install -c conda-forge ase`).
"""
def __init__(
self,
atoms,
temperature: float,
err_handling: str = "warning"
):
super().__init__()
assert hasattr(atoms, "calc")
self.atoms = atoms
self.temperature = temperature
self.err_handling = err_handling

@property
def n_atoms(self):
return len(self.atoms)

def _evaluate_single(
self,
positions: torch.Tensor,
evaluate_force=True,
evaluate_energy=True,
):
from ase.units import kB, nm
kbt = kB * self.temperature
energy, force = None, None
try:
self.atoms.positions = positions * nm
if evaluate_energy:
energy = self.atoms.get_potential_energy() / kbt
if evaluate_force:
force = self.atoms.get_forces() / (kbt / nm)
assert not np.isnan(energy)
assert not np.isnan(force).any()
except AssertionError as e:
force[np.isnan(force)] = 0.
energy = np.infty
if self.err_handling == "warning":
warnings.warn("Found nan in ase force or energy. Returning infinite energy and zero force.")
elif self.err_handling == "error":
raise e
return energy, force


class ASEEnergy(_BridgeEnergy):
"""Energy computation with calculators from the atomic simulation environment (ASE).
Various molecular simulation programs provide wrappers for ASE,
see https://wiki.fysik.dtu.dk/ase/ase/calculators/calculators.html
for a list of available calculators.
Examples
--------
Use the calculator from the xtb package to compute the energy of a water molecule with the GFN2-xTB method.
>>> from ase.build import molecule
>>> from xtb.ase.calculator import XTB
>>> water = molecule("H2O")
>>> water.calc = XTB()
>>> target = ASEEnergy(ASEBridge(water, 300.))
>>> pos = torch.tensor(0.1*water.positions, **ctx)
>>> energy = target.energy(pos)
Parameters
----------
ase_bridge : ASEBridge
The wrapper object.
two_event_dims : bool
Whether to use two event dimensions.
In this case, the energy call expects positions of shape (*batch_shape, n_atoms, 3).
Otherwise, it expects positions of shape (*batch_shape, n_atoms * 3).
"""
pass
117 changes: 115 additions & 2 deletions bgflow/distribution/energy/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@

__all__ = ["Energy"]


from typing import Union, Optional, Sequence
from collections.abc import Sequence as _Sequence
import warnings

import torch

__all__ = ["Energy"]
import numpy as np
from ...utils.types import assert_numpy


def _is_non_empty_sequence_of_integers(x):
Expand Down Expand Up @@ -208,3 +212,112 @@ def force(
if len(self._event_shapes) == 1:
forces = forces[0]
return forces


class _BridgeEnergyWrapper(torch.autograd.Function):
@staticmethod
def forward(ctx, input, bridge):
energy, force, *_ = bridge.evaluate(input)
ctx.save_for_backward(-force)
return energy

@staticmethod
def backward(ctx, grad_output):
neg_force, = ctx.saved_tensors
grad_input = grad_output * neg_force
return grad_input, None


_evaluate_bridge_energy = _BridgeEnergyWrapper.apply


class _Bridge:
_FLOATING_TYPE = np.float64
_SPATIAL_DIM = 3

def __init__(self):
self.last_energies = None
self.last_forces = None

def evaluate(
self,
positions: torch.Tensor,
*args,
evaluate_force: bool = True,
evaluate_energy: bool = True,
**kwargs
):
shape = positions.shape
assert shape[-2:] == (self.n_atoms, 3) or shape[-1] == self.n_atoms * 3
energy_shape = shape[:-2] if shape[-2:] == (self.n_atoms, 3) else shape[:-1]
# the stupid last dim
energy_shape = [*energy_shape, 1]
position_batch = assert_numpy(positions.reshape(-1, self.n_atoms, 3), arr_type=self._FLOATING_TYPE)

energy_batch = np.zeros(energy_shape, dtype=position_batch.dtype)
force_batch = np.zeros_like(position_batch)

for i, pos in enumerate(position_batch):
energy_batch[i], force_batch[i] = self._evaluate_single(
pos,
*args,
evaluate_energy=evaluate_energy,
evaluate_force=evaluate_force,
**kwargs
)

energies = torch.tensor(energy_batch.reshape(*energy_shape)).to(positions)
forces = torch.tensor(force_batch.reshape(*shape)).to(positions)

# store
self.last_energies = energies
self.last_forces = forces

return energies, forces

def _evaluate_single(
self,
positions: torch.Tensor,
*args,
evaluate_force=True,
evaluate_energy=True,
**kwargs
):
raise NotImplementedError

@property
def n_atoms(self):
raise NotImplementedError()


class _BridgeEnergy(Energy):

def __init__(self, bridge, two_event_dims=True):
event_shape = (bridge.n_atoms, 3) if two_event_dims else (bridge.n_atoms * 3, )
super().__init__(event_shape)
self._bridge = bridge
self._last_batch = None

@property
def last_batch(self):
return self._last_batch

@property
def bridge(self):
return self._bridge

def _energy(self, batch, no_grads=False):
# check if we have already computed this energy (hash of string representation should be sufficient)
if hash(str(batch)) == self._last_batch:
return self._bridge.last_energies
else:
self._last_batch = hash(str(batch))
return _evaluate_bridge_energy(batch, self._bridge)

def force(self, batch, temperature=None):
# check if we have already computed this energy
if hash(str(batch)) == self.last_batch:
return self.bridge.last_forces
else:
self._last_batch = hash(str(batch))
return self._bridge.evaluate(batch)[1]
75 changes: 22 additions & 53 deletions bgflow/distribution/energy/openmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,13 @@
import torch

from ...utils.types import assert_numpy
from .base import Energy
from .base import _BridgeEnergy, _Bridge

__all__ = ["OpenMMBridge", "OpenMMEnergy"]


_OPENMM_FLOATING_TYPE = np.float64
_SPATIAL_DIM = 3


class _OpenMMEnergyWrapper(torch.autograd.Function):

@staticmethod
def forward(ctx, input, openmm_energy_bridge):
energy, force, *_ = openmm_energy_bridge.evaluate(input)
ctx.save_for_backward(-force)
return energy

@staticmethod
def backward(ctx, grad_output):
neg_force, = ctx.saved_tensors
grad_input = grad_output * neg_force
return grad_input, None


_evaluate_openmm_energy = _OpenMMEnergyWrapper.apply
__all__ = ["OpenMMBridge", "OpenMMEnergy"]


class OpenMMBridge:
class OpenMMBridge(_Bridge):
"""Bridge object to evaluate energies in OpenMM.
Input positions are in nm, returned energies are dimensionless (units of kT), returned forces are in kT/nm.
Expand Down Expand Up @@ -87,8 +66,11 @@ def __init__(
self._n_simulation_steps = n_simulation_steps
self._unit_reciprocal = 1/(openmm_integrator.getTemperature() * unit.MOLAR_GAS_CONSTANT_R
).value_in_unit(unit.kilojoule_per_mole)
self.last_energies = None
self.last_forces = None
super().__init__()

@property
def n_atoms(self):
return self._openmm_system.getNumParticles()

@property
def integrator(self):
Expand Down Expand Up @@ -139,13 +121,13 @@ def evaluate(
"""

# make a list of positions
batch_array = assert_numpy(batch, arr_type=_OPENMM_FLOATING_TYPE)
batch_array = assert_numpy(batch, arr_type=self._FLOATING_TYPE)

# assert correct number of positions
assert batch_array.shape[1] == self._openmm_system.getNumParticles() * _SPATIAL_DIM
assert batch_array.shape[1] == self._openmm_system.getNumParticles() * self._SPATIAL_DIM

# reshape to (B, N, D)
batch_array = batch_array.reshape(batch.shape[0], -1, _SPATIAL_DIM)
batch_array = batch_array.reshape(batch.shape[0], -1, self._SPATIAL_DIM)
energies, forces, new_positions, log_path_probability_ratio = self.context_wrapper.evaluate(
batch_array,
evaluate_energy=evaluate_energy,
Expand All @@ -163,11 +145,11 @@ def evaluate(
# to PyTorch tensors
energies = torch.tensor(energies).to(batch).reshape(-1, 1) if evaluate_energy else None
forces = (
torch.tensor(forces).to(batch).reshape(batch.shape[0], self._openmm_system.getNumParticles()*_SPATIAL_DIM)
torch.tensor(forces).to(batch).reshape(batch.shape[0], self._openmm_system.getNumParticles()*self._SPATIAL_DIM)
if evaluate_force else None
)
new_positions = (
torch.tensor(new_positions).to(batch).reshape(batch.shape[0], self._openmm_system.getNumParticles()*_SPATIAL_DIM)
torch.tensor(new_positions).to(batch).reshape(batch.shape[0], self._openmm_system.getNumParticles()*self._SPATIAL_DIM)
if evaluate_positions else None
)
log_path_probability_ratio = (
Expand Down Expand Up @@ -527,25 +509,12 @@ def evaluate(
)


class OpenMMEnergy(Energy):

def __init__(self, dimension, openmm_energy_bridge):
super().__init__(dimension)
self._openmm_energy_bridge = openmm_energy_bridge
self._last_batch = None

def _energy(self, batch, no_grads=False):
# check if we have already computed this energy (hash of string representation should be sufficient)
if hash(str(batch)) == self._last_batch:
return self._openmm_energy_bridge.last_energies
else:
self._last_batch = hash(str(batch))
return _evaluate_openmm_energy(batch, self._openmm_energy_bridge)

def force(self, batch, temperature=None):
# check if we have already computed this energy
if hash(str(batch)) == self._last_batch:
return self._openmm_energy_bridge.last_forces
else:
self._last_batch = hash(str(batch))
return self._openmm_energy_bridge.evaluate(batch)[1]
class OpenMMEnergy(_BridgeEnergy):
def __init__(self, dimension=None, bridge=None, two_event_dims=False):
if dimension is not None:
warnings.warn(
"dimension argument in OpenMMEnergy is deprecated and will be ignored. "
"The dimension is directly inferred from the system.",
DeprecationWarning
)
super().__init__(bridge, two_event_dims=two_event_dims)
Loading

0 comments on commit ae1dc6a

Please sign in to comment.