Skip to content

Commit

Permalink
WIP: Use rascaline inside metatensor's atomistic models
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed Nov 3, 2023
1 parent 848a023 commit 5ce8b88
Show file tree
Hide file tree
Showing 9 changed files with 208 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/rascaline-torch/rascaline/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
SphericalExpansion,
SphericalExpansionByPair,
)
from .system import System, systems_to_torch # noqa
from .system import System, metatensor_system_to_rascaline, systems_to_torch # noqa


__all__ = [
Expand Down
13 changes: 13 additions & 0 deletions python/rascaline-torch/rascaline/torch/calculator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch
from metatensor.torch import Labels, TensorMap
from metatensor.torch.atomistic import NeighborsListOptions

from .system import System

Expand Down Expand Up @@ -82,6 +83,18 @@ def cutoffs(self) -> List[float]:
"""all the radial cutoffs used by this calculator's neighbors lists"""
return self._c.cutoffs

def requested_neighbors_lists(self) -> List[NeighborsListOptions]:
options = []
for cutoff in self.cutoffs:
options.append(
NeighborsListOptions(
cutoff=cutoff,
full_list=False,
requestor="rascaline",
)
)
return options

def compute(
self,
systems: Union[System, List[System]],
Expand Down
12 changes: 12 additions & 0 deletions python/rascaline-torch/rascaline/torch/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
from typing import List, Optional, Sequence, overload

import torch
from metatensor.torch.atomistic import System as MetatensorSystem

import rascaline


if os.environ.get("RASCALINE_IMPORT_FOR_SPHINX") is None:
System = torch.classes.rascaline.System
metatensor_system_to_rascaline = torch.ops.rascaline.metatensor_system_to_rascaline
else:
# Documentation for the `System` class, only used when loading the code with sphinx
class System:
Expand Down Expand Up @@ -53,6 +55,16 @@ def cell(self) -> torch.Tensor:
boundary conditions, or a matrix filled with ``0`` for non-periodic systems
"""

def metatensor_system_to_rascaline(system: MetatensorSystem) -> List[System]:
"""
Convert metatensor :py:class:`metatensor.torch.atomistic.System` definition to
rascaline's :py:class:`rascaline.torch.System`.
This returns a list for direct compatibility with the
:py:class:`rascaline.torch.CalculatorBase` API, but the list will always contain
a single system.
"""


@overload
def systems_to_torch(
Expand Down
96 changes: 96 additions & 0 deletions python/rascaline-torch/tests/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import os
from typing import Dict, List

import torch
from metatensor.torch import Labels, TensorBlock
from metatensor.torch.atomistic import (
MetatensorAtomisticModule,
ModelCapabilities,
ModelRunOptions,
System,
)

from rascaline.torch import SoapPowerSpectrum, metatensor_system_to_rascaline


HYPERS = {
"cutoff": 3.6,
"max_radial": 12,
"max_angular": 3,
"atomic_gaussian_width": 0.2,
"center_atom_weight": 1.0,
"radial_basis": {"Gto": {}},
"cutoff_function": {"ShiftedCosine": {"width": 0.3}},
}


class Model(torch.nn.Module):
def __init__(self, species: List[int]):
super().__init__()
self.calculator = SoapPowerSpectrum(**HYPERS)
self.species_neighbors = torch.IntTensor(
[(s1, s2) for s1 in species for s2 in species if s1 < s2]
)

n_max = HYPERS["max_radial"]
l_max = HYPERS["max_angular"]
in_features = (n_max * len(species)) ** 2 * l_max
self.linear = torch.nn.Linear(in_features=in_features, out_features=1)

def forward(
self, system: System, run_options: ModelRunOptions
) -> Dict[str, TensorBlock]:
if "energy" not in run_options.outputs:
return {}

options = run_options.outputs["energy"]

selected_atoms = run_options.selected_atoms
if selected_atoms is None:
selected_samples = None
else:
selected_tensor = torch.tensor(selected_atoms, dtype=torch.int32)
selected_samples = Labels("center", selected_tensor.reshape(-1, 1))

soap = self.calculator(
systems=metatensor_system_to_rascaline(system),
selected_samples=selected_samples,
)

soap = soap.keys_to_properties(
Labels(["species_neighbor_1", "species_neighbor_2"], self.species_neighbors)
)
soap = soap.keys_to_samples("species_center")

features = soap.block().values
samples = system.positions.samples

if not options.per_atom:
features = soap.block().values.sum(dim=0, keepdim=True)
samples = Labels(["_"], torch.IntTensor([[0]]))

return {
"energy": TensorBlock(
values=self.linear(features),
samples=samples,
components=[],
properties=Labels(["energy"], torch.IntTensor([[0]])),
)
}


def test_export_as_metatensor_module(tmpdir):
model = Model(species=[1, 6, 8])
model.eval()

export = MetatensorAtomisticModule(model, ModelCapabilities())

# Check we are requesting the right set of neighbors
neighbors = export.requested_neighbors_lists()
assert len(neighbors) == 1
assert neighbors[0].cutoff == HYPERS["cutoff"]
assert not neighbors[0].full_list
assert neighbors[0].requestors() == ["rascaline", "Model.calculator"]

# check we can save the model
export.export(os.path.join(tmpdir, "model.pt"))
10 changes: 7 additions & 3 deletions rascaline-torch/include/rascaline/torch/system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <torch/script.h>

#include <rascaline.hpp>
#include <metatensor/torch/atomistic.hpp>

#include "rascaline/torch/exports.h"

Expand Down Expand Up @@ -124,9 +125,6 @@ class RASCALINE_TORCH_EXPORT SystemHolder final: public rascaline::System, publi
/// @private implementation of __str__ for TorchScript
std::string str() const;

// TODO: convert from a Dict[str, TorchTensorMap] for the interface with LAMMPS
// static TorchSystem from_metatensor_dict();

private:
torch::Tensor species_;
torch::Tensor positions_;
Expand All @@ -143,6 +141,12 @@ class RASCALINE_TORCH_EXPORT SystemHolder final: public rascaline::System, publi
double last_cutoff_ = -1.0;
};

/// Convert metatensor System definition to rascaline's
///
/// This returns a vector for direct compatibility with the `Calculator` API,
/// but the vector will always contain a single system.
std::vector<TorchSystem> metatensor_system_to_rascaline(metatensor_torch::System system);

}

#endif
7 changes: 7 additions & 0 deletions rascaline-torch/src/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,11 @@ TORCH_LIBRARY(rascaline, module) {
") -> __torch__.torch.classes.metatensor.TensorMap",
register_autograd
);

module.def(
"metatensor_system_to_rascaline("
"__torch__.torch.classes.metatensor.System system"
") -> __torch__.torch.classes.rascaline.System[]",
metatensor_system_to_rascaline
);
}
67 changes: 67 additions & 0 deletions rascaline-torch/src/system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,70 @@ std::string SystemHolder::str() const {

return result.str();
}


std::vector<TorchSystem> rascaline_torch::metatensor_system_to_rascaline(metatensor_torch::System system) {
// TODO: check sample order
auto positions_samples = system->positions->samples();
assert(positions_samples->names()[0] == "atom");
assert(positions_samples->names()[1] == "species");

auto species = positions_samples->column("species").contiguous();
assert(species.dtype() == torch::kInt32);

auto positions = system->positions->values().reshape({-1, 3});
assert(positions.dtype() == torch::kFloat64);

auto cell = system->cell->values().reshape({3, 3});
assert(cell.dtype() == torch::kFloat64);

auto result = torch::make_intrusive<rascaline_torch::SystemHolder>(
species.to(torch::kCPU),
positions.to(torch::kCPU),
cell.to(torch::kCPU)
);

for (const auto& options: system->known_neighbors_lists()) {
for (const auto& requestor: options->requestors()) {
// only convert neighbors list requested by rascaline
if (requestor == "rascaline") {
auto neighbors = system->get_neighbors_list(options);
auto samples_values = neighbors->samples()->values().to(torch::kCPU);
auto samples = samples_values.accessor<int32_t, 2>();

auto distances_tensor = neighbors->values().reshape({-1, 3}).to(torch::kCPU);
auto distances = distances_tensor.accessor<double, 2>();

auto n_pairs = samples.size(1);

auto pairs = std::vector<rascal_pair_t>();
pairs.reserve(static_cast<size_t>(n_pairs));
for (int64_t i=0; i<n_pairs; i++) {
auto x = distances[i][0];
auto y = distances[i][1];
auto z = distances[i][2];

auto pair = rascal_pair_t {};
pair.first = static_cast<uintptr_t>(samples[i][0]);
pair.second = static_cast<uintptr_t>(samples[i][1]);

pair.distance = std::sqrt(x*x + y*y + z*z);
pair.vector[0] = x;
pair.vector[1] = y;
pair.vector[2] = z;

pair.cell_shift_indices[0] = samples[i][2];
pair.cell_shift_indices[1] = samples[i][3];
pair.cell_shift_indices[2] = samples[i][4];

pairs.emplace_back(pair);
}

result->set_precomputed_pairs(options->cutoff(), std::move(pairs));
continue;
}
}
}

return {result};
}
2 changes: 1 addition & 1 deletion rascaline-torch/tests/calculator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,5 +300,5 @@ TorchSystem test_system(bool positions_grad, bool cell_grad) {
auto cell = 10 * torch::eye(3);
cell.requires_grad_(cell_grad);

return torch::make_intrusive<SystemHolder>(species, positions, cell);
return torch::make_intrusive<rascaline_torch::SystemHolder>(species, positions, cell);
}
4 changes: 4 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ changedir = python/rascaline-torch
commands =
# install rascaline-torch
pip install . {[testenv]build-single-wheel} --force-reinstall

# TODO
pip install --no-deps metatensor-operations

# run the unit tests
pytest {[testenv]test_options} --assert=plain {posargs}

Expand Down

0 comments on commit 5ce8b88

Please sign in to comment.