From ae1dc6a9b3bcd7f0382c98614a572b2ae1b3a540 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20Kr=C3=A4mer?= <30830507+Olllom@users.noreply.github.com> Date: Wed, 16 Feb 2022 00:33:49 +0100 Subject: [PATCH] xtb/ase wrappers (#36) * add xtb wrapper bugfix in xtb bugfix in xtb bugfix xtb bugfix in xtb ... ... ... ... ... ... ... xtb fixes * XTB wrapper docs * skip tests when not installed * bugfix in xtb * add xtb to test env * unit conversion in xtb * unit fixes in xtb * err_handling in xtb * ... * try with higher temperature * bugfix in xtb * bohr as unit in xtb * ... * ase wrapper * ase tests work * better docs * add ase to conda env * added caching * remove duplicate code * fix warnings * skip ase/xtb tests if not installed * added forgotten super() call to ASE Bridge * add optional deps to README * unify interfaces to remove duplicate code * remove unused import --- README.md | 5 +- bgflow/distribution/energy/__init__.py | 3 +- bgflow/distribution/energy/ase.py | 98 ++++++++++++ bgflow/distribution/energy/base.py | 117 ++++++++++++++- bgflow/distribution/energy/openmm.py | 75 +++------- bgflow/distribution/energy/xtb.py | 181 +++++++++++++++++++++++ devtools/conda-env.yml | 2 + tests/distribution/energy/test_ase.py | 39 +++++ tests/distribution/energy/test_openmm.py | 6 +- tests/distribution/energy/test_xtb.py | 71 +++++++++ 10 files changed, 537 insertions(+), 60 deletions(-) create mode 100644 bgflow/distribution/energy/ase.py create mode 100644 bgflow/distribution/energy/xtb.py create mode 100644 tests/distribution/energy/test_ase.py create mode 100644 tests/distribution/energy/test_xtb.py diff --git a/README.md b/README.md index ee35e6b9..0385282d 100644 --- a/README.md +++ b/README.md @@ -136,13 +136,16 @@ pytest * [matplotlib](https://github.com/matplotlib/matplotlib) * [pytest](https://github.com/pytest-dev/pytest) (for testing) * [nflows](https://github.com/bayesiains/nflows) (for Neural Spline Flows) - * [OpenMM](https://github.com/openmm/openmm) (for molecular examples) * [torchdiffeq](https://github.com/rtqichen/torchdiffeq) (for neural ODEs) * [ANODE](https://github.com/amirgholami/anode) (for neural ODEs) + * [OpenMM](https://github.com/openmm/openmm) (for molecular mechanics energies) + * [ase](https://wiki.fysik.dtu.dk/ase/index.html) (for quantum and molecular mechanics energies through the atomic simulation environment) + * [xtb-python](https://xtb-python.readthedocs.io) (for semi-empirical GFN quantum energies) * [netCDF4](https://unidata.github.io/netcdf4-python/) (for the `ReplayBufferReporter`) * [jax](https://github.com/google/jax) (for smooth flows / implicit backprop) * [jax2torch](https://github.com/lucidrains/jax2torch) (for smooth flows / implicit backprop) + *** ## [License](#dependencies) [MIT License](LICENSE) diff --git a/bgflow/distribution/energy/__init__.py b/bgflow/distribution/energy/__init__.py index 9fceccbd..fa0deca6 100644 --- a/bgflow/distribution/energy/__init__.py +++ b/bgflow/distribution/energy/__init__.py @@ -75,4 +75,5 @@ from .multi_double_well_potential import * from .linlogcut import * from .openmm import * - +from .xtb import * +from .ase import * diff --git a/bgflow/distribution/energy/ase.py b/bgflow/distribution/energy/ase.py new file mode 100644 index 00000000..e5b06b7d --- /dev/null +++ b/bgflow/distribution/energy/ase.py @@ -0,0 +1,98 @@ +"""Wrapper around ASE (atomic simulation environment) +""" +__all__ = ["ASEBridge", "ASEEnergy"] + + +import warnings +import torch +import numpy as np +from .base import _BridgeEnergy, _Bridge + + +class ASEBridge(_Bridge): + """Wrapper around Atomic Simulation Environment. + + Parameters + ---------- + atoms : ase.Atoms + An `Atoms` object that has a calculator attached to it. + temperature : float + Temperature in Kelvin. + err_handling : str + How to deal with exceptions inside ase. One of `["ignore", "warning", "error"]` + + Notes + ----- + Requires the ase package (installable with `conda install -c conda-forge ase`). + + """ + def __init__( + self, + atoms, + temperature: float, + err_handling: str = "warning" + ): + super().__init__() + assert hasattr(atoms, "calc") + self.atoms = atoms + self.temperature = temperature + self.err_handling = err_handling + + @property + def n_atoms(self): + return len(self.atoms) + + def _evaluate_single( + self, + positions: torch.Tensor, + evaluate_force=True, + evaluate_energy=True, + ): + from ase.units import kB, nm + kbt = kB * self.temperature + energy, force = None, None + try: + self.atoms.positions = positions * nm + if evaluate_energy: + energy = self.atoms.get_potential_energy() / kbt + if evaluate_force: + force = self.atoms.get_forces() / (kbt / nm) + assert not np.isnan(energy) + assert not np.isnan(force).any() + except AssertionError as e: + force[np.isnan(force)] = 0. + energy = np.infty + if self.err_handling == "warning": + warnings.warn("Found nan in ase force or energy. Returning infinite energy and zero force.") + elif self.err_handling == "error": + raise e + return energy, force + + +class ASEEnergy(_BridgeEnergy): + """Energy computation with calculators from the atomic simulation environment (ASE). + Various molecular simulation programs provide wrappers for ASE, + see https://wiki.fysik.dtu.dk/ase/ase/calculators/calculators.html + for a list of available calculators. + + Examples + -------- + Use the calculator from the xtb package to compute the energy of a water molecule with the GFN2-xTB method. + >>> from ase.build import molecule + >>> from xtb.ase.calculator import XTB + >>> water = molecule("H2O") + >>> water.calc = XTB() + >>> target = ASEEnergy(ASEBridge(water, 300.)) + >>> pos = torch.tensor(0.1*water.positions, **ctx) + >>> energy = target.energy(pos) + + Parameters + ---------- + ase_bridge : ASEBridge + The wrapper object. + two_event_dims : bool + Whether to use two event dimensions. + In this case, the energy call expects positions of shape (*batch_shape, n_atoms, 3). + Otherwise, it expects positions of shape (*batch_shape, n_atoms * 3). + """ + pass diff --git a/bgflow/distribution/energy/base.py b/bgflow/distribution/energy/base.py index a126402e..c8c2cb56 100644 --- a/bgflow/distribution/energy/base.py +++ b/bgflow/distribution/energy/base.py @@ -1,10 +1,14 @@ + +__all__ = ["Energy"] + + from typing import Union, Optional, Sequence from collections.abc import Sequence as _Sequence import warnings import torch - -__all__ = ["Energy"] +import numpy as np +from ...utils.types import assert_numpy def _is_non_empty_sequence_of_integers(x): @@ -208,3 +212,112 @@ def force( if len(self._event_shapes) == 1: forces = forces[0] return forces + + +class _BridgeEnergyWrapper(torch.autograd.Function): + @staticmethod + def forward(ctx, input, bridge): + energy, force, *_ = bridge.evaluate(input) + ctx.save_for_backward(-force) + return energy + + @staticmethod + def backward(ctx, grad_output): + neg_force, = ctx.saved_tensors + grad_input = grad_output * neg_force + return grad_input, None + + +_evaluate_bridge_energy = _BridgeEnergyWrapper.apply + + +class _Bridge: + _FLOATING_TYPE = np.float64 + _SPATIAL_DIM = 3 + + def __init__(self): + self.last_energies = None + self.last_forces = None + + def evaluate( + self, + positions: torch.Tensor, + *args, + evaluate_force: bool = True, + evaluate_energy: bool = True, + **kwargs + ): + shape = positions.shape + assert shape[-2:] == (self.n_atoms, 3) or shape[-1] == self.n_atoms * 3 + energy_shape = shape[:-2] if shape[-2:] == (self.n_atoms, 3) else shape[:-1] + # the stupid last dim + energy_shape = [*energy_shape, 1] + position_batch = assert_numpy(positions.reshape(-1, self.n_atoms, 3), arr_type=self._FLOATING_TYPE) + + energy_batch = np.zeros(energy_shape, dtype=position_batch.dtype) + force_batch = np.zeros_like(position_batch) + + for i, pos in enumerate(position_batch): + energy_batch[i], force_batch[i] = self._evaluate_single( + pos, + *args, + evaluate_energy=evaluate_energy, + evaluate_force=evaluate_force, + **kwargs + ) + + energies = torch.tensor(energy_batch.reshape(*energy_shape)).to(positions) + forces = torch.tensor(force_batch.reshape(*shape)).to(positions) + + # store + self.last_energies = energies + self.last_forces = forces + + return energies, forces + + def _evaluate_single( + self, + positions: torch.Tensor, + *args, + evaluate_force=True, + evaluate_energy=True, + **kwargs + ): + raise NotImplementedError + + @property + def n_atoms(self): + raise NotImplementedError() + + +class _BridgeEnergy(Energy): + + def __init__(self, bridge, two_event_dims=True): + event_shape = (bridge.n_atoms, 3) if two_event_dims else (bridge.n_atoms * 3, ) + super().__init__(event_shape) + self._bridge = bridge + self._last_batch = None + + @property + def last_batch(self): + return self._last_batch + + @property + def bridge(self): + return self._bridge + + def _energy(self, batch, no_grads=False): + # check if we have already computed this energy (hash of string representation should be sufficient) + if hash(str(batch)) == self._last_batch: + return self._bridge.last_energies + else: + self._last_batch = hash(str(batch)) + return _evaluate_bridge_energy(batch, self._bridge) + + def force(self, batch, temperature=None): + # check if we have already computed this energy + if hash(str(batch)) == self.last_batch: + return self.bridge.last_forces + else: + self._last_batch = hash(str(batch)) + return self._bridge.evaluate(batch)[1] diff --git a/bgflow/distribution/energy/openmm.py b/bgflow/distribution/energy/openmm.py index 098c6d3d..2cde37d3 100644 --- a/bgflow/distribution/energy/openmm.py +++ b/bgflow/distribution/energy/openmm.py @@ -9,34 +9,13 @@ import torch from ...utils.types import assert_numpy -from .base import Energy +from .base import _BridgeEnergy, _Bridge -__all__ = ["OpenMMBridge", "OpenMMEnergy"] - - -_OPENMM_FLOATING_TYPE = np.float64 -_SPATIAL_DIM = 3 - - -class _OpenMMEnergyWrapper(torch.autograd.Function): - - @staticmethod - def forward(ctx, input, openmm_energy_bridge): - energy, force, *_ = openmm_energy_bridge.evaluate(input) - ctx.save_for_backward(-force) - return energy - @staticmethod - def backward(ctx, grad_output): - neg_force, = ctx.saved_tensors - grad_input = grad_output * neg_force - return grad_input, None - - -_evaluate_openmm_energy = _OpenMMEnergyWrapper.apply +__all__ = ["OpenMMBridge", "OpenMMEnergy"] -class OpenMMBridge: +class OpenMMBridge(_Bridge): """Bridge object to evaluate energies in OpenMM. Input positions are in nm, returned energies are dimensionless (units of kT), returned forces are in kT/nm. @@ -87,8 +66,11 @@ def __init__( self._n_simulation_steps = n_simulation_steps self._unit_reciprocal = 1/(openmm_integrator.getTemperature() * unit.MOLAR_GAS_CONSTANT_R ).value_in_unit(unit.kilojoule_per_mole) - self.last_energies = None - self.last_forces = None + super().__init__() + + @property + def n_atoms(self): + return self._openmm_system.getNumParticles() @property def integrator(self): @@ -139,13 +121,13 @@ def evaluate( """ # make a list of positions - batch_array = assert_numpy(batch, arr_type=_OPENMM_FLOATING_TYPE) + batch_array = assert_numpy(batch, arr_type=self._FLOATING_TYPE) # assert correct number of positions - assert batch_array.shape[1] == self._openmm_system.getNumParticles() * _SPATIAL_DIM + assert batch_array.shape[1] == self._openmm_system.getNumParticles() * self._SPATIAL_DIM # reshape to (B, N, D) - batch_array = batch_array.reshape(batch.shape[0], -1, _SPATIAL_DIM) + batch_array = batch_array.reshape(batch.shape[0], -1, self._SPATIAL_DIM) energies, forces, new_positions, log_path_probability_ratio = self.context_wrapper.evaluate( batch_array, evaluate_energy=evaluate_energy, @@ -163,11 +145,11 @@ def evaluate( # to PyTorch tensors energies = torch.tensor(energies).to(batch).reshape(-1, 1) if evaluate_energy else None forces = ( - torch.tensor(forces).to(batch).reshape(batch.shape[0], self._openmm_system.getNumParticles()*_SPATIAL_DIM) + torch.tensor(forces).to(batch).reshape(batch.shape[0], self._openmm_system.getNumParticles()*self._SPATIAL_DIM) if evaluate_force else None ) new_positions = ( - torch.tensor(new_positions).to(batch).reshape(batch.shape[0], self._openmm_system.getNumParticles()*_SPATIAL_DIM) + torch.tensor(new_positions).to(batch).reshape(batch.shape[0], self._openmm_system.getNumParticles()*self._SPATIAL_DIM) if evaluate_positions else None ) log_path_probability_ratio = ( @@ -527,25 +509,12 @@ def evaluate( ) -class OpenMMEnergy(Energy): - - def __init__(self, dimension, openmm_energy_bridge): - super().__init__(dimension) - self._openmm_energy_bridge = openmm_energy_bridge - self._last_batch = None - - def _energy(self, batch, no_grads=False): - # check if we have already computed this energy (hash of string representation should be sufficient) - if hash(str(batch)) == self._last_batch: - return self._openmm_energy_bridge.last_energies - else: - self._last_batch = hash(str(batch)) - return _evaluate_openmm_energy(batch, self._openmm_energy_bridge) - - def force(self, batch, temperature=None): - # check if we have already computed this energy - if hash(str(batch)) == self._last_batch: - return self._openmm_energy_bridge.last_forces - else: - self._last_batch = hash(str(batch)) - return self._openmm_energy_bridge.evaluate(batch)[1] +class OpenMMEnergy(_BridgeEnergy): + def __init__(self, dimension=None, bridge=None, two_event_dims=False): + if dimension is not None: + warnings.warn( + "dimension argument in OpenMMEnergy is deprecated and will be ignored. " + "The dimension is directly inferred from the system.", + DeprecationWarning + ) + super().__init__(bridge, two_event_dims=two_event_dims) diff --git a/bgflow/distribution/energy/xtb.py b/bgflow/distribution/energy/xtb.py new file mode 100644 index 00000000..1b05db93 --- /dev/null +++ b/bgflow/distribution/energy/xtb.py @@ -0,0 +1,181 @@ +"""Wrapper for semi-empirical QM energies with XTB. +""" + +__all__ = ["XTBEnergy", "XTBBridge"] + + +import warnings +import torch +import numpy as np +from .base import _BridgeEnergy, _Bridge + + +class XTBBridge(_Bridge): + """Wrapper around XTB for semi-empirical QM energy calculations. + + Parameters + ---------- + numbers : np.ndarray + Atomic numbers + temperature : float + Temperature in Kelvin. + method : str + The semi-empirical method that is used to compute energies. + solvent : str + The solvent. If empty string, perform a vacuum calculation. + verbosity : int + 0 (muted), 1 (minimal), 2 (full) + err_handling : str + How to deal with exceptions inside XTB. One of `["ignore", "warning", "error"]` + + Attributes + ---------- + n_atoms : int + The number of atoms in this molecules. + available_solvents : List[str] + The solvent models that are available for computations in xtb. + available_methods : List[str] + The semiempirical methods that are available for computations in xtb. + + Examples + -------- + Setting up an XTB energy for a small peptide from bgmol + >>> from bgmol.systems import MiniPeptide + >>> from bgflow import XTBEnergy, XTBBridge + >>> import numpy as np + >>> import torch + >>> system = MiniPeptide("G") + >>> numbers = np.array([atom.element.number for atom in system.mdtraj_topology.atoms]) + >>> target = XTBEnergy(XTBBridge(numbers=numbers, temperature=300, solvent="water")) + >>> xyz = torch.tensor(system.positions) + >>> energy = target.energy(xyz) + + Notes + ----- + Requires the xtb-python program (installable with `conda install -c conda-forge xtb-python`). + + """ + def __init__( + self, + numbers: np.ndarray, + temperature: float, + method: str = "GFN2-xTB", + solvent: str = "", + verbosity: int = 0, + err_handling: str = "warning" + ): + self.numbers = numbers + self.temperature = temperature + self.method = method + self.solvent = solvent + self.verbosity = verbosity + self.err_handling = err_handling + super().__init__() + + @property + def n_atoms(self): + return len(self.numbers) + + @property + def available_solvents(self): + from xtb.utils import _solvents + return list(_solvents.keys()) + + @property + def available_methods(self): + from xtb.utils import _methods + return list(_methods.keys()) + + def _evaluate_single( + self, + positions: torch.Tensor, + evaluate_force=True, + evaluate_energy=True, + ): + from xtb.interface import Calculator, XTBException + from xtb.utils import get_method, get_solvent + positions = _nm2bohr(positions) + energy, force = None, None + try: + calc = Calculator(get_method(self.method), self.numbers, positions) + calc.set_solvent(get_solvent(self.solvent)) + calc.set_verbosity(self.verbosity) + calc.set_electronic_temperature(self.temperature) + try: + res = calc.singlepoint() + except XTBException: + # Try with higher temperature + calc.set_electronic_temperature(10 * self.temperature) + res = calc.singlepoint() + calc.set_electronic_temperature(self.temperature) + res = calc.singlepoint(res) + if evaluate_energy: + energy = _hartree2kbt(res.get_energy(), self.temperature) + if evaluate_force: + force = _hartree_per_bohr2kbt_per_nm( + -res.get_gradient(), + self.temperature + ) + assert not np.isnan(energy) + assert not np.isnan(force).any() + except XTBException as e: + if self.err_handling == "error": + raise e + elif self.err_handling == "warning": + warnings.warn( + f"Caught exception in xtb. " + f"Returning infinite energy and zero force. " + f"Original exception: {e}" + ) + force = np.zeros_like(positions) + energy = np.infty + elif self.err_handling == "ignore": + force = np.zeros_like(positions) + energy = np.infty + except AssertionError: + force[np.isnan(force)] = 0. + energy = np.infty + if self.err_handling in ["error", "warning"]: + warnings.warn("Found nan in xtb force or energy. Returning infinite energy and zero force.") + + return energy, force + + +class XTBEnergy(_BridgeEnergy): + """Semi-empirical energy computation with XTB. + + Parameters + ---------- + xtb_bridge : XTBBridge + The wrapper object. + two_event_dims : bool + Whether to use two event dimensions. + In this case, the energy call expects positions of shape (*batch_shape, n_atoms, 3). + Otherwise, it expects positions of shape (*batch_shape, n_atoms * 3). + """ + pass + + +_BOLTZMANN_CONSTANT_HE = 3.1668115634556076e-06 # in hartree / kelvin +_BOHR_RADIUS = 0.0529177210903 # nm + + +def _bohr2nm(x): + return x * _BOHR_RADIUS + + +def _nm2bohr(x): + return x / _BOHR_RADIUS + + +def _per_bohr2per_nm(x): + return _nm2bohr(x) + + +def _hartree2kbt(x, temperature): + kbt = _BOLTZMANN_CONSTANT_HE * temperature + return x / kbt + + +def _hartree_per_bohr2kbt_per_nm(x, temperature): + return _per_bohr2per_nm(_hartree2kbt(x, temperature)) diff --git a/devtools/conda-env.yml b/devtools/conda-env.yml index 0cd9c450..0a682693 100644 --- a/devtools/conda-env.yml +++ b/devtools/conda-env.yml @@ -11,6 +11,8 @@ dependencies: - numpy - openmm + - xtb-python + - ase - openmmtools - pytorch - jax diff --git a/tests/distribution/energy/test_ase.py b/tests/distribution/energy/test_ase.py new file mode 100644 index 00000000..fb8a54e2 --- /dev/null +++ b/tests/distribution/energy/test_ase.py @@ -0,0 +1,39 @@ + +import pytest +import torch +from bgflow import ASEBridge, ASEEnergy, XTBBridge, XTBEnergy + + +try: + import ase + import xtb + ase_and_xtb_imported = True +except ImportError: + ase_and_xtb_imported = False + +pytestmark = pytest.mark.skipif(not ase_and_xtb_imported, reason="Tests require ASE and XTB") + + +def test_ase_energy(ctx): + from ase.build import molecule + from xtb.ase.calculator import XTB + water = molecule("H2O") + water.calc = XTB() + target = ASEEnergy(ASEBridge(water, 300.)) + pos = torch.tensor(0.1*water.positions, **ctx) + e = target.energy(pos) + f = target.force(pos) + + +def test_ase_vs_xtb(ctx): + # to make sure that unit conversion is the same, etc. + from ase.build import molecule + from xtb.ase.calculator import XTB + water = molecule("H2O") + water.calc = XTB() + target1 = ASEEnergy(ASEBridge(water, 300.)) + target2 = XTBEnergy(XTBBridge(water.numbers, 300.)) + pos = torch.tensor(0.1 * water.positions[None, ...], **ctx) + assert torch.allclose(target1.energy(pos), target2.energy(pos)) + assert torch.allclose(target1.force(pos), target2.force(pos), atol=1e-6) + diff --git a/tests/distribution/energy/test_openmm.py b/tests/distribution/energy/test_openmm.py index d8bb3af0..7fa39dad 100644 --- a/tests/distribution/energy/test_openmm.py +++ b/tests/distribution/energy/test_openmm.py @@ -122,7 +122,7 @@ def test_openmm_bridge_evaluate_openmmtools_testsystem( def test_openmm_bridge_cache(): """Test if hashing and caching works.""" bridge = OneParticleTestBridge() - omm_energy = OpenMMEnergy(3, bridge) + omm_energy = OpenMMEnergy(bridge=bridge) omm_energy._energy(torch.tensor([[0.1, 0.0, 0.0]] * 2)) hash1 = omm_energy._last_batch omm_energy._energy(torch.tensor([[0.2, 0.0, 0.0]] * 2)) @@ -132,9 +132,9 @@ def test_openmm_bridge_cache(): omm_energy._energy(torch.tensor([[0.1, 0.0, 0.0]] * 2)) # test if forces are in the same memory location for same input batch - force_address = hex(id(omm_energy._openmm_energy_bridge.last_forces)) + force_address = hex(id(omm_energy.bridge.last_forces)) force = ( - omm_energy._openmm_energy_bridge.last_forces + omm_energy.bridge.last_forces ) # retain a pointer to last forces so that memory is not freed assert ( hex(id(omm_energy.force(torch.tensor([[0.1, 0.0, 0.0]] * 2)))) == force_address diff --git a/tests/distribution/energy/test_xtb.py b/tests/distribution/energy/test_xtb.py new file mode 100644 index 00000000..3a03fb20 --- /dev/null +++ b/tests/distribution/energy/test_xtb.py @@ -0,0 +1,71 @@ +import pytest +import torch +import numpy as np +from bgflow import XTBEnergy, XTBBridge + +try: + import xtb + xtb_imported = True +except ImportError: + xtb_imported = False + +pytestmark = pytest.mark.skipif(not xtb_imported, reason="Test requires XTB") + + +@pytest.mark.parametrize("pos_shape", [(1, 3, 3), (1, 9)]) +def test_xtb_water(pos_shape, ctx): + unit = pytest.importorskip("openmm.unit") + temperature = 300 + numbers = np.array([8, 1, 1]) + positions = torch.tensor([ + [0.00000000000000, 0.00000000000000, -0.73578586109551], + [1.44183152868459, 0.00000000000000, 0.36789293054775], + [-1.44183152868459, 0.00000000000000, 0.36789293054775]], + **ctx + ) + positions = (positions * unit.bohr).value_in_unit(unit.nanometer) + target = XTBEnergy( + XTBBridge(numbers=numbers, temperature=temperature), + two_event_dims=(pos_shape == (1, 3, 3)) + ) + energy = target.energy(positions.reshape(pos_shape)) + force = target.force(positions.reshape(pos_shape)) + assert energy.shape == (1, 1) + assert force.shape == pos_shape + + kbt = unit.BOLTZMANN_CONSTANT_kB * temperature * unit.kelvin + expected_energy = torch.tensor(-5.070451354836705, **ctx) * unit.hartree / kbt + expected_force = - torch.tensor([ + [6.24500451e-17, - 3.47909735e-17, - 5.07156941e-03], + [-1.24839222e-03, 2.43536791e-17, 2.53578470e-03], + [1.24839222e-03, 1.04372944e-17, 2.53578470e-03], + ], **ctx) * unit.hartree/unit.bohr/(kbt/unit.nanometer) + assert torch.allclose(energy.flatten(), expected_energy.flatten(), atol=1e-5) + assert torch.allclose(force.flatten(), expected_force.flatten(), atol=1e-5) + + +def _eval_invalid(ctx, err_handling): + pos = torch.zeros(1, 3, 3, **ctx) + target = XTBEnergy( + XTBBridge(numbers=np.array([8, 1, 1]), temperature=300, err_handling=err_handling) + ) + return target.energy(pos), target.force(pos) + + +def test_xtb_error(ctx): + from xtb.interface import XTBException + with pytest.raises(XTBException): + _eval_invalid(ctx, err_handling="error") + + +def test_xtb_warning(ctx): + with pytest.warns(UserWarning, match="Caught exception in xtb"): + e, f = _eval_invalid(ctx, err_handling="warning") + assert torch.isinf(e).all() + assert torch.allclose(f, torch.zeros_like(f)) + + +def test_xtb_ignore(ctx): + e, f = _eval_invalid(ctx, err_handling="ignore") + assert torch.isinf(e).all() + assert torch.allclose(f, torch.zeros_like(f))