Skip to content
This repository has been archived by the owner on Apr 24, 2024. It is now read-only.

Commit

Permalink
add export_torch_module to linear module and tests
Browse files Browse the repository at this point in the history
* add from_weight constructor to equisolve Linear torch module
* add helper function core_tensor_map_to_torch and transpose_tensor_map
  • Loading branch information
agoscinski committed Oct 2, 2023
1 parent 75a997c commit 1331a29
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 2 deletions.
28 changes: 28 additions & 0 deletions src/equisolve/nn/module_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,34 @@ def from_module(
module = torch.nn.Linear(in_features, out_features, bias, device, dtype)
return ModuleTensorMap.from_module(in_keys, module, many_to_one, out_tensor)

@classmethod
def from_weights(cls, weights: TensorMap, bias: Optional[TensorMap] = None):
"""
:param weights:
The weight tensor map from which we create the linear modules. The
properties of the tensor map describe the input dimension and the samples
describe the output dimension.
:param bias:
The weight tensor map from which we create the linear layers.
"""
module_map = ModuleDict()
for key, weights_block in weights.items():
module_key = ModuleTensorMap.module_key(key)
module = torch.nn.Linear(
len(weights_block.samples),
len(weights_block.properties),
bias=False,
device=weights_block.values.device,
dtype=weights_block.values.dtype,
)
module.weight = torch.nn.Parameter(weights_block.values.T)
if bias is not None:
module.bias = torch.nn.Parameter(bias.block(key).values)
module_map[module_key] = module

return ModuleTensorMap(module_map, weights)

def forward(self, tensor: TensorMap) -> TensorMap:
# added to appear in doc, :inherited-members: is not compatible with torch
return super().forward(tensor)
31 changes: 30 additions & 1 deletion src/equisolve/numpy/models/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from ...module import _Estimator
from ...utils.metrics import rmse
from ..utils import array_from_block
from ..utils import array_from_block, core_tensor_map_to_torch, transpose_tensor_map


class Ridge(_Estimator):
Expand Down Expand Up @@ -350,3 +350,32 @@ def score(self, X: TensorMap, y: TensorMap, parameter_key: str) -> float:
"""
y_pred = self.predict(X)
return rmse(y, y_pred, parameter_key)

def export_torch_module(self, device=None, dtype=None):
"""
Export existing weights to a child class :py:class:`torch.nn.Module` so it can
:py:mod:`torch.jit` utils can be applied.
:param device:
:py:class:`torch.device` of values in the resulting module
:param dtye:
:py:class:`torch.dtype` of the values in the resulting module
:returns linear:
a :py:class:`equisolve.nn.Linear`
"""
from ... import HAS_METATENSOR_TORCH

if not HAS_METATENSOR_TORCH:
raise ImportError(
"To export your model to TorchScript torch needs to be installed. "
"Please install torch, then reimport equisolve or "
"use equisolve.refresh_global_flags()."
)
from ...nn import Linear

torch_weights = core_tensor_map_to_torch(
transpose_tensor_map(self.weights), device, dtype
)
return Linear.from_weights(torch_weights)
88 changes: 87 additions & 1 deletion src/equisolve/numpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import metatensor
import numpy as np
from metatensor import TensorBlock, TensorMap
from metatensor import Labels, TensorBlock, TensorMap


def array_from_block(block: TensorBlock) -> np.ndarray:
Expand Down Expand Up @@ -76,3 +76,89 @@ def dict_to_tensor_map(tensor_map_dict: dict):
tmp_filename = tempfile.mktemp() + ".npz"
np.savez(tmp_filename, **tensor_map_dict)
return metatensor.load(tmp_filename)


def core_tensor_map_to_torch(core_tensor: TensorMap, device=None, dtype=None):
"""Transforms a tensor map from metatensor-core to metatensor-torch
:param core_tensor:
tensor map from metatensor-core
:param device:
:py:class:`torch.device` of values in the resulting tensor map
:param dtye:
:py:class:`torch.dtype` of values in the resulting tensor map
:returns torch_tensor:
tensor map from metatensor-torch
"""
from metatensor.torch import TensorMap as TorchTensorMap

torch_blocks = []
for _, core_block in core_tensor.items():
torch_blocks.append(core_tensor_block_to_torch(core_block, device, dtype))
torch_keys = core_labels_to_torch(core_tensor.keys)
return TorchTensorMap(torch_keys, torch_blocks)


def core_tensor_block_to_torch(core_block: TensorBlock, device=None, dtype=None):
"""Transforms a tensor block from metatensor-core to metatensor-torch
:param core_block:
tensor block from metatensor-core
:param device:
:py:class:`torch.device` of values in the resulting block and labels
:param dtye:
:py:class:`torch.dtype` of values in the resulting block and labels
:returns torch_block:
tensor block from metatensor-torch
"""
import torch
from metatensor.torch import TensorBlock as TorchTensorBlock

return TorchTensorBlock(
values=torch.tensor(core_block.values, device=device, dtype=dtype),
samples=core_labels_to_torch(core_block.samples, device=device),
components=[
core_labels_to_torch(component, device=device)
for component in core_block.components
],
properties=core_labels_to_torch(core_block.properties, device=device),
)


def core_labels_to_torch(core_labels: Labels, device=None):
"""Transforms labels from metatensor-core to metatensor-torch
:param core_block:
tensor block from metatensor-core
:param device:
:py:class:`torch.device` of values in the resulting labels
:returns torch_block:
labels from metatensor-torch
"""
import torch
from metatensor.torch import Labels as TorchLabels

return TorchLabels(
core_labels.names, torch.tensor(core_labels.values, device=device)
)


def transpose_tensor_map(tensor: TensorMap):
blocks = []
for block in tensor.blocks():
block = TensorBlock(
values=block.values.T,
samples=block.properties,
components=block.components,
properties=block.samples,
)
blocks.append(block)
return TensorMap(tensor.keys, blocks)
28 changes: 28 additions & 0 deletions tests/equisolve_tests/numpy/models/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from metatensor import Labels, TensorBlock, TensorMap
from numpy.testing import assert_allclose, assert_equal

from equisolve import HAS_METATENSOR_TORCH
from equisolve.numpy.models import Ridge
from equisolve.numpy.utils import core_tensor_map_to_torch

from ..utilities import tensor_to_tensormap

Expand Down Expand Up @@ -79,6 +81,32 @@ def equisolve_solver_from_numpy_arrays(
clf.fit(X=X, y=y, alpha=alpha, sample_weight=sw, solver=solver)
return clf

@pytest.mark.skipif(
not (HAS_METATENSOR_TORCH), reason="requires metatensor-torch to be run"
)
def test_export_torch_module(self):
"""Test if ridge is working and all shapes are converted correctly.
Test is performed for two blocks.
"""

num_targets = 50
num_properties = 5

# Create input values
X_arr = self.rng.random([2, num_targets, num_properties])
y_arr = self.rng.random([2, num_targets, 1])

X = tensor_to_tensormap(X_arr)
y = tensor_to_tensormap(y_arr)

clf = Ridge()
clf.fit(X=X, y=y)
y_pred_torch = core_tensor_map_to_torch(clf.predict(X))

module = clf.export_torch_module()
y_pred_torch_module = module.forward(core_tensor_map_to_torch(X))
metatensor.torch.allclose_raise(y_pred_torch, y_pred_torch_module)

num_properties = np.array([91])
num_targets = np.array([1000])
means = np.array([-0.5, 0, 0.1])
Expand Down

0 comments on commit 1331a29

Please sign in to comment.