Skip to content

Commit

Permalink
Fix serialization of calculators in torch
Browse files Browse the repository at this point in the history
We need to store the registration name, not the self-reported calculator name
  • Loading branch information
Luthaf committed Nov 1, 2023
1 parent 30592db commit fca103e
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
1 change: 1 addition & 0 deletions python/rascaline-torch/tests/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,4 @@ def forward(

with tmpdir.as_cwd():
torch.jit.save(module, "test-save.torch")
module = torch.jit.load("test-save.torch")
9 changes: 8 additions & 1 deletion rascaline-torch/include/rascaline/torch/calculator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,20 @@ class RASCALINE_TORCH_EXPORT CalculatorHolder: public torch::CustomClassHolder {
public:
/// Create a new calculator with the given `name` and JSON `parameters`
CalculatorHolder(std::string name, std::string parameters):
calculator_(std::move(name), std::move(parameters))
c_name_(std::move(name)),
calculator_(c_name_, std::move(parameters))
{}

/// Get the name of this calculator
std::string name() const {
return calculator_.name();
}

/// Get the name used to register this calculator
std::string c_name() const {
return c_name_;
}

/// Get the parameters of this calculator
std::string parameters() const {
return calculator_.parameters();
Expand All @@ -100,6 +106,7 @@ class RASCALINE_TORCH_EXPORT CalculatorHolder: public torch::CustomClassHolder {
);

private:
std::string c_name_;
rascaline::Calculator calculator_;
};

Expand Down
8 changes: 4 additions & 4 deletions rascaline-torch/src/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ TORCH_LIBRARY(rascaline, module) {
})
.def_pickle(
// __getstate__
[](const TorchCalculator& self) -> std::vector<std::string> {
return {self->name(), self->parameters()};
[](const TorchCalculator& self) -> std::tuple<std::string, std::string> {
return {self->c_name(), self->parameters()};
},
// __setstate__
[](std::vector<std::string> state) -> TorchCalculator {
[](std::tuple<std::string, std::string> state) -> TorchCalculator {
return c10::make_intrusive<CalculatorHolder>(
state[0], state[1]
std::get<0>(state), std::get<1>(state)
);
})
;
Expand Down

0 comments on commit fca103e

Please sign in to comment.