This repository has been archived by the owner on Apr 24, 2024. It is now read-only.

Export TorchScript for Ridge #50

Open: wants to merge 5 commits into base: main
37 changes: 29 additions & 8 deletions src/equisolve/__init__.py
@@ -9,11 +9,32 @@
__version__ = "0.0.0-dev"
__authors__ = "the equisolve development team"

-# For a global consistent state of the package, we try to load here torch,
-# since torch is an optional dependency
-try:
-    import torch  # noqa: F401
-
-    HAS_TORCH = True
-except ImportError:
-    HAS_TORCH = False

def refresh_global_flags():
    """
    Refreshes all global flags set on import of the library. This function might be
    useful in an interactive session where some of the optional dependencies
    (torch, metatensor-torch) were installed after importing the library.
    """
    global HAS_TORCH
    global HAS_METATENSOR_TORCH

    try:
        import torch  # noqa: F401
Collaborator: Isn't there maybe a more clever way to check this with importlib?
Then we would not have to escape the linting. A quick search found this:
https://stackoverflow.com/questions/14050281/how-to-check-if-a-python-module-exists-without-importing-it
(A sketch follows after this function.)


        HAS_TORCH = True
    except ImportError:
        HAS_TORCH = False

    try:
        from metatensor.torch import Labels, TensorBlock, TensorMap  # noqa: F401

        HAS_METATENSOR_TORCH = True
    except ImportError:
        from metatensor import Labels, TensorBlock, TensorMap  # noqa: F401

        HAS_METATENSOR_TORCH = False
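
As a side note on the review comment above: a minimal sketch of the importlib-based check the linked answer describes (illustrative only, not part of this diff). One caveat: `find_spec` only checks that a module can be located; unlike the `try`/`except` pattern, it does not catch a module that is installed but fails to import.

# Illustrative sketch: detect optional dependencies without importing them,
# so no `noqa: F401` escape is needed.
import importlib.util

HAS_TORCH = importlib.util.find_spec("torch") is not None
HAS_METATENSOR_TORCH = importlib.util.find_spec("metatensor.torch") is not None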


# For a globally consistent state of the package, we set the global flags once here
refresh_global_flags()
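
To illustrate why the refresh exists, a hypothetical interactive session (assuming torch was absent when equisolve was first imported and is installed afterwards):

# Hypothetical session: torch is installed only after the first import.
import equisolve

print(equisolve.HAS_TORCH)  # False, torch was missing when the flags were set

# ... run `pip install torch` in the same environment ...

equisolve.refresh_global_flags()
print(equisolve.HAS_TORCH)  # True, the flag now reflects the new install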
6 changes: 1 addition & 5 deletions src/equisolve/nn/__init__.py
@@ -1,9 +1,5 @@
-try:
-    import torch  # noqa: F401
-
-    HAS_TORCH = True
-except ImportError:
-    HAS_TORCH = False
from .. import HAS_TORCH

if HAS_TORCH:
    from .module_tensor import Linear, ModuleTensorMap  # noqa: F401
60 changes: 50 additions & 10 deletions src/equisolve/nn/module_tensor.py
@@ -1,14 +1,13 @@
-try:
-    from metatensor.torch import Labels, LabelsEntry, TensorBlock, TensorMap
-
-    HAS_METATENSOR_TORCH = True
-except ImportError:
-    from metatensor import Labels, LabelsEntry, TensorBlock, TensorMap
-
-    HAS_METATENSOR_TORCH = False
from .. import HAS_METATENSOR_TORCH

if HAS_METATENSOR_TORCH:
    from metatensor.torch import Labels, LabelsEntry, TensorBlock, TensorMap
else:
    from metatensor import Labels, LabelsEntry, TensorBlock, TensorMap

from copy import deepcopy
-from typing import List, Optional
from typing import List, Optional, Union

import torch
from torch.nn import Module, ModuleDict
@@ -168,23 +167,36 @@ class Linear(ModuleTensorMap):
    properties, the labels of the properties cannot be preserved.

    :param bias:
-        See :py:class:`torch.nn.Linear`
        See :py:class:`torch.nn.Linear` for a bool input. The bias can also be
        tuned individually for each TensorMap key by passing a TensorMap with
        one boolean value per key.
    """

    def __init__(
        self,
        in_tensor: TensorMap,
        out_tensor: TensorMap,
-        bias: bool = True,
        bias: Union[bool, TensorMap] = True,
    ):
        if isinstance(bias, bool):
            blocks = [
                TensorBlock(
                    values=torch.tensor(bias).reshape(1, 1),
                    samples=Labels.range("_", 1),
                    components=[],
                    properties=Labels.range("_", 1),
                )
                for _ in in_tensor.keys
            ]
            bias = TensorMap(keys=in_tensor.keys, blocks=blocks)
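
For illustration (not part of the diff), a per-key bias can be passed by mirroring the block layout the bool branch above constructs; `in_tensor` and `out_tensor` are assumed to be TensorMaps with two keys:

# Hypothetical sketch: enable the bias for the first key only.
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap

bias_blocks = [
    TensorBlock(
        values=torch.tensor(flag).reshape(1, 1),
        samples=Labels.range("_", 1),
        components=[],
        properties=Labels.range("_", 1),
    )
    for flag in [True, False]  # one bool per key (assumed: two keys)
]
per_key_bias = TensorMap(keys=in_tensor.keys, blocks=bias_blocks)
linear = Linear(in_tensor, out_tensor, bias=per_key_bias)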
        module_map = ModuleDict()
        for key, in_block in in_tensor.items():
            module_key = ModuleTensorMap.module_key(key)
            out_block = out_tensor.block(key)
            module = torch.nn.Linear(
                len(in_block.properties),
                len(out_block.properties),
-                bias,
                bias.block(key).values.flatten()[0],
                in_block.values.device,
                in_block.values.dtype,
@@ -230,6 +242,34 @@ def from_module(
        module = torch.nn.Linear(in_features, out_features, bias, device, dtype)
        return ModuleTensorMap.from_module(in_keys, module, many_to_one, out_tensor)

    @classmethod
    def from_weights(cls, weights: TensorMap, bias: Optional[TensorMap] = None):
        """
        :param weights:
            The weight tensor map from which we create the linear modules. The
            samples of the tensor map describe the input dimension and the
            properties describe the output dimension.

        :param bias:
            The bias tensor map from which we create the bias terms of the
            linear layers.
        """
        module_map = ModuleDict()
        for key, weights_block in weights.items():
            module_key = ModuleTensorMap.module_key(key)
            module = torch.nn.Linear(
                len(weights_block.samples),
                len(weights_block.properties),
                bias=False,
                device=weights_block.values.device,
                dtype=weights_block.values.dtype,
            )
            module.weight = torch.nn.Parameter(weights_block.values.T)
            if bias is not None:
                module.bias = torch.nn.Parameter(bias.block(key).values)
            module_map[module_key] = module

        return ModuleTensorMap(module_map, weights)

    def forward(self, tensor: TensorMap) -> TensorMap:
        # added to appear in the docs; :inherited-members: is not compatible
        # with torch
        return super().forward(tensor)
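
Illustrative usage of the new classmethod (not part of the diff), assuming `weights` is a metatensor-torch TensorMap whose block values have the shape (input dimension, output dimension):

# Hypothetical sketch: wrap an existing weights TensorMap into a Linear.
linear = Linear.from_weights(weights)
prediction = linear(x)  # x: a TensorMap with the same keys as `weights`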
48 changes: 28 additions & 20 deletions src/equisolve/numpy/models/linear_model.py
@@ -13,13 +13,12 @@
import scipy.linalg
from metatensor import Labels, TensorBlock, TensorMap

-from ... import HAS_TORCH
-from ...module import NumpyModule, _Estimator
from ...module import _Estimator
from ...utils.metrics import rmse
-from ..utils import array_from_block, dict_to_tensor_map, tensor_map_to_dict
from ..utils import array_from_block, core_tensor_map_to_torch, transpose_tensor_map


-class _Ridge(_Estimator):
class Ridge(_Estimator):
r"""Linear least squares with l2 regularization for :class:`metatensor.Tensormap`'s.

Weights :math:`w` are calculated according to
@@ -307,8 +306,7 @@ def fit(

            weights_blocks.append(weight_block)

-        # convert weights to a dictionary allowing pickle dump of an instance
-        self._weights = tensor_map_to_dict(TensorMap(X.keys, weights_blocks))
        self._weights = TensorMap(X.keys, weights_blocks)
Collaborator: Nice that this workaround is gone now! We could also remove the
function tensor_map_to_dict?
        return self

@@ -319,7 +317,7 @@ def weights(self) -> TensorMap:
        if self._weights is None:
            raise ValueError("No weights. Call fit method first.")

-        return dict_to_tensor_map(self._weights)
        return self._weights

    def predict(self, X: TensorMap) -> TensorMap:
        """
@@ -353,21 +351,31 @@ def score(self, X: TensorMap, y: TensorMap, parameter_key: str) -> float:
        y_pred = self.predict(X)
        return rmse(y, y_pred, parameter_key)

    def export_torch_module(self, device=None, dtype=None):
        """
        Export the existing weights to a child class of :py:class:`torch.nn.Module`
        so that :py:mod:`torch.jit` utilities can be applied.

        :param device:
            :py:class:`torch.device` of the values in the resulting module

        :param dtype:
            :py:class:`torch.dtype` of the values in the resulting module

        :returns linear:
            a :py:class:`equisolve.nn.Linear`
        """
        from ... import HAS_METATENSOR_TORCH

        if not HAS_METATENSOR_TORCH:
            raise ImportError(
                "To export your model to TorchScript, torch and metatensor-torch "
                "need to be installed. Please install them, then reimport "
                "equisolve or use equisolve.refresh_global_flags()."
            )
        from ...nn import Linear

        torch_weights = core_tensor_map_to_torch(
            transpose_tensor_map(self.weights), device, dtype
        )
        return Linear.from_weights(torch_weights)


-class NumpyRidge(_Ridge, NumpyModule):
-    def __init__(self) -> None:
-        NumpyModule.__init__(self)
-        _Ridge.__init__(self)
-
-
-if HAS_TORCH:
-    import torch
-
-    class TorchRidge(_Ridge, torch.nn.Module):
-        def __init__(self) -> None:
-            torch.nn.Module.__init__(self)
-            _Ridge.__init__(self)
-
-    Ridge = TorchRidge
-else:
-    Ridge = NumpyRidge
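
Putting the pieces together, a hedged end-to-end sketch of the new export path (assuming `X` and `y` are metatensor-core TensorMaps and metatensor-torch is installed):

# Hypothetical sketch of the export path added in this PR.
import torch

clf = Ridge()
clf.fit(X=X, y=y)

module = clf.export_torch_module()   # an equisolve.nn.Linear
scripted = torch.jit.script(module)  # TorchScript compilation, the goal of this PR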
88 changes: 87 additions & 1 deletion src/equisolve/numpy/utils.py
@@ -11,7 +11,7 @@

import metatensor
import numpy as np
-from metatensor import TensorBlock, TensorMap
from metatensor import Labels, TensorBlock, TensorMap


def array_from_block(block: TensorBlock) -> np.ndarray:
@@ -76,3 +76,89 @@ def dict_to_tensor_map(tensor_map_dict: dict):
    tmp_filename = tempfile.mktemp() + ".npz"
    np.savez(tmp_filename, **tensor_map_dict)
    return metatensor.load(tmp_filename)


def core_tensor_map_to_torch(core_tensor: TensorMap, device=None, dtype=None):
    """Transforms a tensor map from metatensor-core to metatensor-torch.

    :param core_tensor:
        tensor map from metatensor-core

    :param device:
        :py:class:`torch.device` of the values in the resulting tensor map

    :param dtype:
        :py:class:`torch.dtype` of the values in the resulting tensor map

    :returns torch_tensor:
        tensor map from metatensor-torch
    """
    from metatensor.torch import TensorMap as TorchTensorMap

    torch_blocks = []
    for _, core_block in core_tensor.items():
        torch_blocks.append(core_tensor_block_to_torch(core_block, device, dtype))
    torch_keys = core_labels_to_torch(core_tensor.keys)
    return TorchTensorMap(torch_keys, torch_blocks)
Collaborator (on lines +81 to +102): These seem to me like functions that
should live upstream in metatensor directly. Pinging @Luthaf here for thoughts.


def core_tensor_block_to_torch(core_block: TensorBlock, device=None, dtype=None):
    """Transforms a tensor block from metatensor-core to metatensor-torch.

    :param core_block:
        tensor block from metatensor-core

    :param device:
        :py:class:`torch.device` of the values in the resulting block and labels

    :param dtype:
        :py:class:`torch.dtype` of the values in the resulting block and labels

    :returns torch_block:
        tensor block from metatensor-torch
    """
    import torch
    from metatensor.torch import TensorBlock as TorchTensorBlock

    return TorchTensorBlock(
        values=torch.tensor(core_block.values, device=device, dtype=dtype),
        samples=core_labels_to_torch(core_block.samples, device=device),
        components=[
            core_labels_to_torch(component, device=device)
            for component in core_block.components
        ],
        properties=core_labels_to_torch(core_block.properties, device=device),
    )


def core_labels_to_torch(core_labels: Labels, device=None):
    """Transforms labels from metatensor-core to metatensor-torch.

    :param core_labels:
        labels from metatensor-core

    :param device:
        :py:class:`torch.device` of the values in the resulting labels

    :returns torch_labels:
        labels from metatensor-torch
    """
    import torch
    from metatensor.torch import Labels as TorchLabels

    return TorchLabels(
        core_labels.names, torch.tensor(core_labels.values, device=device)
    )


def transpose_tensor_map(tensor: TensorMap):
Collaborator: I think this doesn't work if there are components, right?
Also, this should be in the metatensor operations, even if we do not support
components for now. (A guard sketch follows after the function.)

    blocks = []
    for block in tensor.blocks():
        block = TensorBlock(
            values=block.values.T,
            samples=block.properties,
            components=block.components,
            properties=block.samples,
        )
        blocks.append(block)
    return TensorMap(tensor.keys, blocks)
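
On the review comment above: `values.T` reverses all axes, so a block with components would not be transposed samples-for-properties. A hedged guard, assuming components remain unsupported for now (illustrative, not part of the diff):

# Hypothetical guard: fail loudly on blocks with components instead of
# silently transposing them incorrectly.
def transpose_tensor_map_checked(tensor: TensorMap) -> TensorMap:
    for block in tensor.blocks():
        if len(block.components) > 0:
            raise ValueError(
                "transpose_tensor_map does not support blocks with components"
            )
    return transpose_tensor_map(tensor)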
28 changes: 28 additions & 0 deletions tests/equisolve_tests/numpy/models/linear_model.py
@@ -11,7 +11,9 @@
from metatensor import Labels, TensorBlock, TensorMap
from numpy.testing import assert_allclose, assert_equal

from equisolve import HAS_METATENSOR_TORCH
from equisolve.numpy.models import Ridge
from equisolve.numpy.utils import core_tensor_map_to_torch

from ..utilities import tensor_to_tensormap

@@ -79,6 +81,32 @@ def equisolve_solver_from_numpy_arrays(
        clf.fit(X=X, y=y, alpha=alpha, sample_weight=sw, solver=solver)
        return clf

    @pytest.mark.skipif(
        not HAS_METATENSOR_TORCH, reason="requires metatensor-torch to be run"
    )
    def test_export_torch_module(self):
        """Test that Ridge works and that all shapes are converted correctly.

        The test is performed for two blocks.
        """
        num_targets = 50
        num_properties = 5

        # Create input values
        X_arr = self.rng.random([2, num_targets, num_properties])
        y_arr = self.rng.random([2, num_targets, 1])

        X = tensor_to_tensormap(X_arr)
        y = tensor_to_tensormap(y_arr)

        clf = Ridge()
        clf.fit(X=X, y=y)
        y_pred_torch = core_tensor_map_to_torch(clf.predict(X))

        module = clf.export_torch_module()
        y_pred_torch_module = module.forward(core_tensor_map_to_torch(X))
        metatensor.torch.allclose_raise(y_pred_torch, y_pred_torch_module)
Collaborator (on lines +106 to +108): We should check that this is scriptable!
After all, that is what we promise. (A sketch of such a check follows.)
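A hedged sketch of the scriptability check the comment asks for, reusing the names from the test above (assuming `torch` is importable in this test module):

# Hypothetical follow-up assertion: the exported module must survive
# TorchScript compilation and still reproduce the predictions.
import torch

scripted = torch.jit.script(module)
y_pred_scripted = scripted.forward(core_tensor_map_to_torch(X))
metatensor.torch.allclose_raise(y_pred_torch, y_pred_scripted)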
        num_properties = np.array([91])
        num_targets = np.array([1000])
        means = np.array([-0.5, 0, 0.1])