diff --git a/python/rascaline-torch/rascaline/torch/__init__.py b/python/rascaline-torch/rascaline/torch/__init__.py index 177d0c688..244126619 100644 --- a/python/rascaline-torch/rascaline/torch/__init__.py +++ b/python/rascaline-torch/rascaline/torch/__init__.py @@ -32,7 +32,7 @@ SphericalExpansion, SphericalExpansionByPair, ) -from .system import System, systems_to_torch # noqa +from .system import System, metatensor_system_to_rascaline, systems_to_torch # noqa __all__ = [ diff --git a/python/rascaline-torch/rascaline/torch/calculator_base.py b/python/rascaline-torch/rascaline/torch/calculator_base.py index 3b74580bf..3c88d9aae 100644 --- a/python/rascaline-torch/rascaline/torch/calculator_base.py +++ b/python/rascaline-torch/rascaline/torch/calculator_base.py @@ -2,6 +2,7 @@ import torch from metatensor.torch import Labels, TensorMap +from metatensor.torch.atomistic import NeighborsListOptions from .system import System @@ -82,6 +83,18 @@ def cutoffs(self) -> List[float]: """all the radial cutoffs used by this calculator's neighbors lists""" return self._c.cutoffs + def requested_neighbors_lists(self) -> List[NeighborsListOptions]: + options = [] + for cutoff in self.cutoffs: + options.append( + NeighborsListOptions( + cutoff=cutoff, + full_list=False, + requestor="rascaline", + ) + ) + return options + def compute( self, systems: Union[System, List[System]], diff --git a/python/rascaline-torch/rascaline/torch/system.py b/python/rascaline-torch/rascaline/torch/system.py index a9efba500..cd8be627d 100644 --- a/python/rascaline-torch/rascaline/torch/system.py +++ b/python/rascaline-torch/rascaline/torch/system.py @@ -2,12 +2,14 @@ from typing import List, Optional, Sequence, overload import torch +from metatensor.torch.atomistic import System as MetatensorSystem import rascaline if os.environ.get("RASCALINE_IMPORT_FOR_SPHINX") is None: System = torch.classes.rascaline.System + metatensor_system_to_rascaline = torch.ops.rascaline.metatensor_system_to_rascaline else: # Documentation for the `System` class, only used when loading the code with sphinx class System: @@ -53,6 +55,16 @@ def cell(self) -> torch.Tensor: boundary conditions, or a matrix filled with ``0`` for non-periodic systems """ + def metatensor_system_to_rascaline(system: MetatensorSystem) -> List[System]: + """ + Convert metatensor :py:class:`metatensor.torch.atomistic.System` definition to + rascaline's :py:class:`rascaline.torch.System`. + + This returns a list for direct compatibility with the + :py:class:`rascaline.torch.CalculatorBase` API, but the list will always contain + a single system. + """ + @overload def systems_to_torch( diff --git a/python/rascaline-torch/tests/export.py b/python/rascaline-torch/tests/export.py new file mode 100644 index 000000000..4264d029a --- /dev/null +++ b/python/rascaline-torch/tests/export.py @@ -0,0 +1,96 @@ +import os +from typing import Dict, List + +import torch +from metatensor.torch import Labels, TensorBlock +from metatensor.torch.atomistic import ( + MetatensorAtomisticModule, + ModelCapabilities, + ModelRunOptions, + System, +) + +from rascaline.torch import SoapPowerSpectrum, metatensor_system_to_rascaline + + +HYPERS = { + "cutoff": 3.6, + "max_radial": 12, + "max_angular": 3, + "atomic_gaussian_width": 0.2, + "center_atom_weight": 1.0, + "radial_basis": {"Gto": {}}, + "cutoff_function": {"ShiftedCosine": {"width": 0.3}}, +} + + +class Model(torch.nn.Module): + def __init__(self, species: List[int]): + super().__init__() + self.calculator = SoapPowerSpectrum(**HYPERS) + self.species_neighbors = torch.IntTensor( + [(s1, s2) for s1 in species for s2 in species if s1 < s2] + ) + + n_max = HYPERS["max_radial"] + l_max = HYPERS["max_angular"] + in_features = (n_max * len(species)) ** 2 * l_max + self.linear = torch.nn.Linear(in_features=in_features, out_features=1) + + def forward( + self, system: System, run_options: ModelRunOptions + ) -> Dict[str, TensorBlock]: + if "energy" not in run_options.outputs: + return {} + + options = run_options.outputs["energy"] + + selected_atoms = run_options.selected_atoms + if selected_atoms is None: + selected_samples = None + else: + selected_tensor = torch.tensor(selected_atoms, dtype=torch.int32) + selected_samples = Labels("center", selected_tensor.reshape(-1, 1)) + + soap = self.calculator( + systems=metatensor_system_to_rascaline(system), + selected_samples=selected_samples, + ) + + soap = soap.keys_to_properties( + Labels(["species_neighbor_1", "species_neighbor_2"], self.species_neighbors) + ) + soap = soap.keys_to_samples("species_center") + + features = soap.block().values + samples = system.positions.samples + + if not options.per_atom: + features = soap.block().values.sum(dim=0, keepdim=True) + samples = Labels(["_"], torch.IntTensor([[0]])) + + return { + "energy": TensorBlock( + values=self.linear(features), + samples=samples, + components=[], + properties=Labels(["energy"], torch.IntTensor([[0]])), + ) + } + + +def test_export_as_metatensor_module(tmpdir): + model = Model(species=[1, 6, 8]) + model.eval() + + export = MetatensorAtomisticModule(model, ModelCapabilities()) + + # Check we are requesting the right set of neighbors + neighbors = export.requested_neighbors_lists() + assert len(neighbors) == 1 + assert neighbors[0].cutoff == HYPERS["cutoff"] + assert not neighbors[0].full_list + assert neighbors[0].requestors() == ["rascaline", "Model.calculator"] + + # check we can save the model + export.export(os.path.join(tmpdir, "model.pt")) diff --git a/rascaline-torch/include/rascaline/torch/system.hpp b/rascaline-torch/include/rascaline/torch/system.hpp index 22f810758..2628f11dd 100644 --- a/rascaline-torch/include/rascaline/torch/system.hpp +++ b/rascaline-torch/include/rascaline/torch/system.hpp @@ -7,6 +7,7 @@ #include #include +#include #include "rascaline/torch/exports.h" @@ -124,9 +125,6 @@ class RASCALINE_TORCH_EXPORT SystemHolder final: public rascaline::System, publi /// @private implementation of __str__ for TorchScript std::string str() const; - // TODO: convert from a Dict[str, TorchTensorMap] for the interface with LAMMPS - // static TorchSystem from_metatensor_dict(); - private: torch::Tensor species_; torch::Tensor positions_; @@ -143,6 +141,12 @@ class RASCALINE_TORCH_EXPORT SystemHolder final: public rascaline::System, publi double last_cutoff_ = -1.0; }; +/// Convert metatensor System definition to rascaline's +/// +/// This returns a vector for direct compatibility with the `Calculator` API, +/// but the vector will always contain a single system. +std::vector metatensor_system_to_rascaline(metatensor_torch::System system); + } #endif diff --git a/rascaline-torch/src/register.cpp b/rascaline-torch/src/register.cpp index 75cf845d4..1761d986a 100644 --- a/rascaline-torch/src/register.cpp +++ b/rascaline-torch/src/register.cpp @@ -71,4 +71,11 @@ TORCH_LIBRARY(rascaline, module) { ") -> __torch__.torch.classes.metatensor.TensorMap", register_autograd ); + + module.def( + "metatensor_system_to_rascaline(" + "__torch__.torch.classes.metatensor.System system" + ") -> __torch__.torch.classes.rascaline.System[]", + metatensor_system_to_rascaline + ); } diff --git a/rascaline-torch/src/system.cpp b/rascaline-torch/src/system.cpp index b1bd999a2..54e5c4b8e 100644 --- a/rascaline-torch/src/system.cpp +++ b/rascaline-torch/src/system.cpp @@ -145,3 +145,70 @@ std::string SystemHolder::str() const { return result.str(); } + + +std::vector rascaline_torch::metatensor_system_to_rascaline(metatensor_torch::System system) { + // TODO: check sample order + auto positions_samples = system->positions->samples(); + assert(positions_samples->names()[0] == "atom"); + assert(positions_samples->names()[1] == "species"); + + auto species = positions_samples->column("species").contiguous(); + assert(species.dtype() == torch::kInt32); + + auto positions = system->positions->values().reshape({-1, 3}); + assert(positions.dtype() == torch::kFloat64); + + auto cell = system->cell->values().reshape({3, 3}); + assert(cell.dtype() == torch::kFloat64); + + auto result = torch::make_intrusive( + species.to(torch::kCPU), + positions.to(torch::kCPU), + cell.to(torch::kCPU) + ); + + for (const auto& options: system->known_neighbors_lists()) { + for (const auto& requestor: options->requestors()) { + // only convert neighbors list requested by rascaline + if (requestor == "rascaline") { + auto neighbors = system->get_neighbors_list(options); + auto samples_values = neighbors->samples()->values().to(torch::kCPU); + auto samples = samples_values.accessor(); + + auto distances_tensor = neighbors->values().reshape({-1, 3}).to(torch::kCPU); + auto distances = distances_tensor.accessor(); + + auto n_pairs = samples.size(1); + + auto pairs = std::vector(); + pairs.reserve(static_cast(n_pairs)); + for (int64_t i=0; i(samples[i][0]); + pair.second = static_cast(samples[i][1]); + + pair.distance = std::sqrt(x*x + y*y + z*z); + pair.vector[0] = x; + pair.vector[1] = y; + pair.vector[2] = z; + + pair.cell_shift_indices[0] = samples[i][2]; + pair.cell_shift_indices[1] = samples[i][3]; + pair.cell_shift_indices[2] = samples[i][4]; + + pairs.emplace_back(pair); + } + + result->set_precomputed_pairs(options->cutoff(), std::move(pairs)); + continue; + } + } + } + + return {result}; +} diff --git a/rascaline-torch/tests/calculator.cpp b/rascaline-torch/tests/calculator.cpp index d841d603a..35053a645 100644 --- a/rascaline-torch/tests/calculator.cpp +++ b/rascaline-torch/tests/calculator.cpp @@ -300,5 +300,5 @@ TorchSystem test_system(bool positions_grad, bool cell_grad) { auto cell = 10 * torch::eye(3); cell.requires_grad_(cell_grad); - return torch::make_intrusive(species, positions, cell); + return torch::make_intrusive(species, positions, cell); } diff --git a/tox.ini b/tox.ini index 27993f6b8..4b6d2baa7 100644 --- a/tox.ini +++ b/tox.ini @@ -94,6 +94,10 @@ changedir = python/rascaline-torch commands = # install rascaline-torch pip install . {[testenv]build-single-wheel} --force-reinstall + + # TODO + pip install --no-deps metatensor-operations + # run the unit tests pytest {[testenv]test_options} --assert=plain {posargs}