Export TorchScript for Ridge #50
base: main
Conversation
* in nn module use HAS_TORCH from equisolve/__init__.py
* mv HAS_METATENSOR_TORCH to equisolve/__init__.py
* add function refresh_global_flags that allows refreshing the global flags
tensor maps are now picklable so we don't need to do this anymore
Force-pushed from 109d40e to 56f856e
This commit prepares the Ridge class to use an export function to obtain TorchScriptable modules.
* add from_weight constructor to equisolve Linear torch module
* add helper functions core_tensor_map_to_torch and transpose_tensor_map
Force-pushed from 109d40e to 1331a29
Thanks Alex, this is a very useful and important addition to export the models!
One general question besides the minor comments within the code: we set HAS_METATENSOR_TORCH to True whenever torch is available in the current environment. Shouldn't we allow the user to set this behavior?
global HAS_METATENSOR_TORCH

try:
    import torch  # noqa: F401
Isn't there maybe a more clever way to check this with importlib? Then we would not have to escape the linting. A quick search found this: https://stackoverflow.com/questions/14050281/how-to-check-if-a-python-module-exists-without-importing-it
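For illustration, a minimal sketch of how this could look with importlib.util.find_spec (hypothetical, not the PR's code; refresh_global_flags is the name from the commit message above):

import importlib.util

# evaluate the availability flags without importing the packages,
# so no `# noqa: F401` suppression is needed
HAS_TORCH = importlib.util.find_spec("torch") is not None
# note: find_spec on a dotted name still imports the parent package
HAS_METATENSOR_TORCH = importlib.util.find_spec("metatensor.torch") is not None


def refresh_global_flags():
    # re-evaluate the flags, e.g. after torch was installed at runtime
    global HAS_TORCH, HAS_METATENSOR_TORCH
    HAS_TORCH = importlib.util.find_spec("torch") is not None
    HAS_METATENSOR_TORCH = importlib.util.find_spec("metatensor.torch") is not None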
@@ -307,8 +306,7 @@ def fit(
         weights_blocks.append(weight_block)

-        # convert weights to a dictionary allowing pickle dump of an instance
-        self._weights = tensor_map_to_dict(TensorMap(X.keys, weights_blocks))
+        self._weights = TensorMap(X.keys, weights_blocks)
Nice that this workaround is gone now!
Could we also remove the function tensor_map_to_dict?
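For context, this is the behavior the change relies on (a minimal sketch assuming the metatensor Python API; the tiny tensor map is made up for illustration):

import pickle

import numpy as np
from metatensor import Labels, TensorBlock, TensorMap

# build a minimal single-block tensor map
block = TensorBlock(
    values=np.ones((1, 1)),
    samples=Labels(["sample"], np.array([[0]], dtype=np.int32)),
    components=[],
    properties=Labels(["property"], np.array([[0]], dtype=np.int32)),
)
tensor = TensorMap(Labels(["key"], np.array([[0]], dtype=np.int32)), [block])

# tensor maps now round-trip through pickle directly, so storing the
# weights as a dict via tensor_map_to_dict is no longer needed
restored = pickle.loads(pickle.dumps(tensor))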
module = clf.export_torch_module()
y_pred_torch_module = module.forward(core_tensor_map_to_torch(X))
metatensor.torch.allclose_raise(y_pred_torch, y_pred_torch_module)
We should check that this is scriptable! At least since we promise this.
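A sketch of what such a check could look like in the test (hypothetical; it assumes module, X, and y_pred_torch from the snippet above):

import torch

# torch.jit.script raises if anything in the module is not
# TorchScript-compatible, so this doubles as the scriptability check
scripted_module = torch.jit.script(module)

y_pred_scripted = scripted_module.forward(core_tensor_map_to_torch(X))
metatensor.torch.allclose_raise(y_pred_torch, y_pred_scripted)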
def core_tensor_map_to_torch(core_tensor: TensorMap, device=None, dtype=None):
    """Transforms a tensor map from metatensor-core to metatensor-torch.

    :param core_tensor:
        tensor map from metatensor-core
    :param device:
        :py:class:`torch.device` of values in the resulting tensor map
    :param dtype:
        :py:class:`torch.dtype` of values in the resulting tensor map
    :returns torch_tensor:
        tensor map from metatensor-torch
    """
    from metatensor.torch import TensorMap as TorchTensorMap

    torch_blocks = []
    for _, core_block in core_tensor.items():
        torch_blocks.append(core_tensor_block_to_torch(core_block, device, dtype))
    torch_keys = core_labels_to_torch(core_tensor.keys)
    return TorchTensorMap(torch_keys, torch_blocks)
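The helpers core_tensor_block_to_torch and core_labels_to_torch are used above but not shown in this hunk. As an illustration, the labels conversion could look roughly like this (hypothetical sketch, not necessarily the PR's actual implementation):

import numpy as np
import torch
from metatensor.torch import Labels as TorchLabels


def core_labels_to_torch(core_labels):
    # copy the integer label values into a torch tensor and rebuild the
    # Labels object with the metatensor-torch class
    values = torch.tensor(np.asarray(core_labels.values), dtype=torch.int32)
    return TorchLabels(list(core_labels.names), values)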
These seem to me like functions that should live upstream in metatensor directly. Pinging @Luthaf here for thoughts.
    )


def transpose_tensor_map(tensor: TensorMap):
I think this doesn't work if there are components, right?
Also, this should be in the metatensor operations, even if we do not support components for now...
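For reference, a components-free transpose could look roughly like this (hypothetical sketch under that restriction, not the PR's implementation):

from metatensor import TensorBlock, TensorMap


def transpose_tensor_map(tensor: TensorMap) -> TensorMap:
    blocks = []
    for block in tensor.blocks():
        # swapping samples and properties only makes sense for 2D values,
        # i.e. for blocks without components
        assert len(block.components) == 0, "components are not supported"
        blocks.append(
            TensorBlock(
                values=block.values.T,
                samples=block.properties,
                components=[],
                properties=block.samples,
            )
        )
    return TensorMap(tensor.keys, blocks)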
To support TorchScript we ran into some issues with the equistore.core.{TensorMap,TensorBlock,Labels} type hints. For a model supporting TorchScript, we needed to switch to the types of equistore.torch.{TensorMap,TensorBlock,Labels}, which made the type hints complicated. Therefore we came to the conclusion to add a function that produces something torch-jit-compilable.

The torch_geometric package offers a function jittable that makes the current model compilable. We could have done something similar and dynamically changed the type hints to equistore.torch types, but that would be a bit intransparent to the user. So we decided to add an export function that produces a compilable torch module, defined in a separate class, from the current model parameters. We need to reimplement the forward function for this separate class each time, but that is a relatively small amount of code duplication in exchange for more readable code.
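The resulting workflow looks roughly like this (a sketch assembled from the names used in this PR; the Ridge fit call is schematic):

import torch

clf = Ridge()
clf.fit(X, y)

# build a separate, TorchScript-compatible module from the fitted weights
module = clf.export_torch_module()

# the exported module operates on metatensor-torch tensor maps
y_pred = module.forward(core_tensor_map_to_torch(X))

# and it can be compiled and saved with TorchScript
scripted = torch.jit.script(module)
scripted.save("ridge.pt")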
Notes
* #subdirectory tag in the tox.ini: maybe this can be made to work, but I was not able to.
* We might want to remove the inheritance scheme we initially tried for supporting TorchScript for classes without rewriting the internal logic.
📚 Documentation preview 📚: https://equisolve--50.org.readthedocs.build/en/50/