diff --git a/python/rascaline-torch/tests/calculator.py b/python/rascaline-torch/tests/calculator.py index 28c2c6999..1c26be1bb 100644 --- a/python/rascaline-torch/tests/calculator.py +++ b/python/rascaline-torch/tests/calculator.py @@ -190,3 +190,4 @@ def forward( with tmpdir.as_cwd(): torch.jit.save(module, "test-save.torch") + module = torch.jit.load("test-save.torch") diff --git a/rascaline-torch/include/rascaline/torch/calculator.hpp b/rascaline-torch/include/rascaline/torch/calculator.hpp index bdc09bb1f..ee54929de 100644 --- a/rascaline-torch/include/rascaline/torch/calculator.hpp +++ b/rascaline-torch/include/rascaline/torch/calculator.hpp @@ -75,7 +75,8 @@ class RASCALINE_TORCH_EXPORT CalculatorHolder: public torch::CustomClassHolder { public: /// Create a new calculator with the given `name` and JSON `parameters` CalculatorHolder(std::string name, std::string parameters): - calculator_(std::move(name), std::move(parameters)) + c_name_(std::move(name)), + calculator_(c_name_, std::move(parameters)) {} /// Get the name of this calculator @@ -83,6 +84,11 @@ class RASCALINE_TORCH_EXPORT CalculatorHolder: public torch::CustomClassHolder { return calculator_.name(); } + /// Get the name used to register this calculator + std::string c_name() const { + return c_name_; + } + /// Get the parameters of this calculator std::string parameters() const { return calculator_.parameters(); @@ -100,6 +106,7 @@ class RASCALINE_TORCH_EXPORT CalculatorHolder: public torch::CustomClassHolder { ); private: + std::string c_name_; rascaline::Calculator calculator_; }; diff --git a/rascaline-torch/src/register.cpp b/rascaline-torch/src/register.cpp index 965b10865..75cf845d4 100644 --- a/rascaline-torch/src/register.cpp +++ b/rascaline-torch/src/register.cpp @@ -52,13 +52,13 @@ TORCH_LIBRARY(rascaline, module) { }) .def_pickle( // __getstate__ - [](const TorchCalculator& self) -> std::vector { - return {self->name(), self->parameters()}; + [](const TorchCalculator& self) -> std::tuple { + return {self->c_name(), self->parameters()}; }, // __setstate__ - [](std::vector state) -> TorchCalculator { + [](std::tuple state) -> TorchCalculator { return c10::make_intrusive( - state[0], state[1] + std::get<0>(state), std::get<1>(state) ); }) ;