-
Notifications
You must be signed in to change notification settings - Fork 1
Export TorchScript for Ridge #50
base: main
Are you sure you want to change the base?
Changes from all commits
c0c7c19
177b4ef
29799cb
75a997c
1331a29
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,13 +13,12 @@ | |
import scipy.linalg | ||
from metatensor import Labels, TensorBlock, TensorMap | ||
|
||
from ... import HAS_TORCH | ||
from ...module import NumpyModule, _Estimator | ||
from ...module import _Estimator | ||
from ...utils.metrics import rmse | ||
from ..utils import array_from_block, dict_to_tensor_map, tensor_map_to_dict | ||
from ..utils import array_from_block, core_tensor_map_to_torch, transpose_tensor_map | ||
|
||
|
||
class _Ridge(_Estimator): | ||
class Ridge(_Estimator): | ||
r"""Linear least squares with l2 regularization for :class:`metatensor.Tensormap`'s. | ||
|
||
Weights :math:`w` are calculated according to | ||
|
@@ -307,8 +306,7 @@ def fit( | |
|
||
weights_blocks.append(weight_block) | ||
|
||
# convert weights to a dictionary allowing pickle dump of an instance | ||
self._weights = tensor_map_to_dict(TensorMap(X.keys, weights_blocks)) | ||
self._weights = TensorMap(X.keys, weights_blocks) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice, that this workaround is gone now! We could also remove the function |
||
|
||
return self | ||
|
||
|
@@ -319,7 +317,7 @@ def weights(self) -> TensorMap: | |
if self._weights is None: | ||
raise ValueError("No weights. Call fit method first.") | ||
|
||
return dict_to_tensor_map(self._weights) | ||
return self._weights | ||
|
||
def predict(self, X: TensorMap) -> TensorMap: | ||
""" | ||
|
@@ -353,21 +351,31 @@ def score(self, X: TensorMap, y: TensorMap, parameter_key: str) -> float: | |
y_pred = self.predict(X) | ||
return rmse(y, y_pred, parameter_key) | ||
|
||
def export_torch_module(self, device=None, dtype=None): | ||
""" | ||
Export existing weights to a child class :py:class:`torch.nn.Module` so it can | ||
:py:mod:`torch.jit` utils can be applied. | ||
|
||
class NumpyRidge(_Ridge, NumpyModule): | ||
def __init__(self) -> None: | ||
NumpyModule.__init__(self) | ||
_Ridge.__init__(self) | ||
:param device: | ||
:py:class:`torch.device` of values in the resulting module | ||
|
||
:param dtye: | ||
:py:class:`torch.dtype` of the values in the resulting module | ||
|
||
if HAS_TORCH: | ||
import torch | ||
:returns linear: | ||
a :py:class:`equisolve.nn.Linear` | ||
""" | ||
from ... import HAS_METATENSOR_TORCH | ||
|
||
class TorchRidge(_Ridge, torch.nn.Module): | ||
def __init__(self) -> None: | ||
torch.nn.Module.__init__(self) | ||
_Ridge.__init__(self) | ||
if not HAS_METATENSOR_TORCH: | ||
raise ImportError( | ||
"To export your model to TorchScript torch needs to be installed. " | ||
"Please install torch, then reimport equisolve or " | ||
"use equisolve.refresh_global_flags()." | ||
) | ||
from ...nn import Linear | ||
|
||
Ridge = TorchRidge | ||
else: | ||
Ridge = NumpyRidge | ||
torch_weights = core_tensor_map_to_torch( | ||
transpose_tensor_map(self.weights), device, dtype | ||
) | ||
return Linear.from_weights(torch_weights) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,7 @@ | |
|
||
import metatensor | ||
import numpy as np | ||
from metatensor import TensorBlock, TensorMap | ||
from metatensor import Labels, TensorBlock, TensorMap | ||
|
||
|
||
def array_from_block(block: TensorBlock) -> np.ndarray: | ||
|
@@ -76,3 +76,89 @@ def dict_to_tensor_map(tensor_map_dict: dict): | |
tmp_filename = tempfile.mktemp() + ".npz" | ||
np.savez(tmp_filename, **tensor_map_dict) | ||
return metatensor.load(tmp_filename) | ||
|
||
|
||
def core_tensor_map_to_torch(core_tensor: TensorMap, device=None, dtype=None): | ||
"""Transforms a tensor map from metatensor-core to metatensor-torch | ||
|
||
:param core_tensor: | ||
tensor map from metatensor-core | ||
|
||
:param device: | ||
:py:class:`torch.device` of values in the resulting tensor map | ||
|
||
:param dtye: | ||
:py:class:`torch.dtype` of values in the resulting tensor map | ||
|
||
:returns torch_tensor: | ||
tensor map from metatensor-torch | ||
""" | ||
from metatensor.torch import TensorMap as TorchTensorMap | ||
|
||
torch_blocks = [] | ||
for _, core_block in core_tensor.items(): | ||
torch_blocks.append(core_tensor_block_to_torch(core_block, device, dtype)) | ||
torch_keys = core_labels_to_torch(core_tensor.keys) | ||
return TorchTensorMap(torch_keys, torch_blocks) | ||
Comment on lines
+81
to
+102
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to me like functions that should live upstream metatensor directly. Pinging @Luthaf here for thoughts. |
||
|
||
|
||
def core_tensor_block_to_torch(core_block: TensorBlock, device=None, dtype=None): | ||
"""Transforms a tensor block from metatensor-core to metatensor-torch | ||
|
||
:param core_block: | ||
tensor block from metatensor-core | ||
|
||
:param device: | ||
:py:class:`torch.device` of values in the resulting block and labels | ||
|
||
:param dtye: | ||
:py:class:`torch.dtype` of values in the resulting block and labels | ||
|
||
:returns torch_block: | ||
tensor block from metatensor-torch | ||
""" | ||
import torch | ||
from metatensor.torch import TensorBlock as TorchTensorBlock | ||
|
||
return TorchTensorBlock( | ||
values=torch.tensor(core_block.values, device=device, dtype=dtype), | ||
samples=core_labels_to_torch(core_block.samples, device=device), | ||
components=[ | ||
core_labels_to_torch(component, device=device) | ||
for component in core_block.components | ||
], | ||
properties=core_labels_to_torch(core_block.properties, device=device), | ||
) | ||
|
||
|
||
def core_labels_to_torch(core_labels: Labels, device=None): | ||
"""Transforms labels from metatensor-core to metatensor-torch | ||
|
||
:param core_block: | ||
tensor block from metatensor-core | ||
|
||
:param device: | ||
:py:class:`torch.device` of values in the resulting labels | ||
|
||
:returns torch_block: | ||
labels from metatensor-torch | ||
""" | ||
import torch | ||
from metatensor.torch import Labels as TorchLabels | ||
|
||
return TorchLabels( | ||
core_labels.names, torch.tensor(core_labels.values, device=device) | ||
) | ||
|
||
|
||
def transpose_tensor_map(tensor: TensorMap): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this doesn't work if there are components right? Also this should be in metatensor operations. Even though if we do not support components for now... |
||
blocks = [] | ||
for block in tensor.blocks(): | ||
block = TensorBlock( | ||
values=block.values.T, | ||
samples=block.properties, | ||
components=block.components, | ||
properties=block.samples, | ||
) | ||
blocks.append(block) | ||
return TensorMap(tensor.keys, blocks) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,9 @@ | |
from metatensor import Labels, TensorBlock, TensorMap | ||
from numpy.testing import assert_allclose, assert_equal | ||
|
||
from equisolve import HAS_METATENSOR_TORCH | ||
from equisolve.numpy.models import Ridge | ||
from equisolve.numpy.utils import core_tensor_map_to_torch | ||
|
||
from ..utilities import tensor_to_tensormap | ||
|
||
|
@@ -79,6 +81,32 @@ def equisolve_solver_from_numpy_arrays( | |
clf.fit(X=X, y=y, alpha=alpha, sample_weight=sw, solver=solver) | ||
return clf | ||
|
||
@pytest.mark.skipif( | ||
not (HAS_METATENSOR_TORCH), reason="requires metatensor-torch to be run" | ||
) | ||
def test_export_torch_module(self): | ||
"""Test if ridge is working and all shapes are converted correctly. | ||
Test is performed for two blocks. | ||
""" | ||
|
||
num_targets = 50 | ||
num_properties = 5 | ||
|
||
# Create input values | ||
X_arr = self.rng.random([2, num_targets, num_properties]) | ||
y_arr = self.rng.random([2, num_targets, 1]) | ||
|
||
X = tensor_to_tensormap(X_arr) | ||
y = tensor_to_tensormap(y_arr) | ||
|
||
clf = Ridge() | ||
clf.fit(X=X, y=y) | ||
y_pred_torch = core_tensor_map_to_torch(clf.predict(X)) | ||
|
||
module = clf.export_torch_module() | ||
y_pred_torch_module = module.forward(core_tensor_map_to_torch(X)) | ||
metatensor.torch.allclose_raise(y_pred_torch, y_pred_torch_module) | ||
Comment on lines
+106
to
+108
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should check that this is scriptable! At least since we promise this. |
||
|
||
num_properties = np.array([91]) | ||
num_targets = np.array([1000]) | ||
means = np.array([-0.5, 0, 0.1]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't there maybe a more clever way to check this with importlib. Then we do not have to escape the linting. A quick search found this: https://stackoverflow.com/questions/14050281/how-to-check-if-a-python-module-exists-without-importing-it