This repository has been archived by the owner on Apr 24, 2024. It is now read-only.

Export TorchScript for Ridge #50

Open · wants to merge 5 commits into main from export-torchscript

Conversation

@agoscinski (Collaborator) commented on Jun 21, 2023:

While adding TorchScript support we ran into issues with the equistore.core.{TensorMap,TensorBlock,Labels} type hints. For a model to support TorchScript, we needed to switch to the equistore.torch.{TensorMap,TensorBlock,Labels} types, which made the type hints complicated. We therefore concluded that we should add a function that produces something torch-jit-compilable. The torch_geometric package offers a jittable function that makes the current model compilable. We could have done something similar and dynamically changed the type hints to the equistore.torch types, but that would be somewhat opaque to the user. So we decided to add an export function that produces a compilable torch module, defined in a separate class, from the current model parameters. We need to reimplement the forward function for this separate class each time, but that is relatively little code duplication in exchange for more readable code.
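For illustration, a rough usage sketch of this export pattern; the method name export_torch_module is taken from the test snippet further down in this conversation, while the import path and the X/y tensor maps are placeholders, not part of this PR:

import torch

from equisolve.numpy.models import Ridge  # assumed import path

# X, y: metatensor-core TensorMaps holding the training data (placeholders)
clf = Ridge()
clf.fit(X, y)

# export a standalone torch module built from the fitted weights ...
module = clf.export_torch_module()

# ... which can then be compiled with TorchScript, independently of the core classes
scripted = torch.jit.script(module)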

Notes

  • I bumped the equistore version to one that includes PR metatensor/metatensor#263 (Initial TorchScript version of the core classes). But we will have to keep updating the version until equistore.operations is fully functional; I think we can wait with merging until then.
  • I added the equistore-torch dependency in a requirements file, because I had problems using the #subdirectory tag in tox.ini. Maybe this can be made to work, but I was not able to.

We might want to remove the inheritance scheme we initially tried, which was meant to add TorchScript support to classes without rewriting their internal logic.


📚 Documentation preview 📚: https://equisolve--50.org.readthedocs.build/en/50/

* in nn module use HAS_TORCH from equisolve/__init__.py
* mv HAS_METATENSOR_TORCH to equisolve/__init__.py
* add function refresh_global_flags that allows refreshing the global flags
tensor maps are now picklable so we don't need to do it anymore
this commit prepares the Ridge class to use an export function to obtain
TorchScriptable modules
* add from_weight constructor to equisolve Linear torch module
* add helper function core_tensor_map_to_torch and transpose_tensor_map
@agoscinski force-pushed the export-torchscript branch 2 times, most recently from 109d40e to 1331a29 on October 2, 2023 12:01
@agoscinski marked this pull request as ready for review on October 2, 2023 12:01
@PicoCentauri (Collaborator) left a comment:

Thanks Alex, this is a very useful and important addition to export the models!

One general question, besides the minor comments within the code: we set HAS_METATENSOR_TORCH to True whenever torch is available in the current environment. Shouldn't we allow the user to set this behavior?
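One possible direction, just as a sketch (the environment variable name here is made up and not part of this PR): keep the automatic detection, but let users explicitly opt out.

import os

# hypothetical user override on top of the automatic detection
if os.environ.get("EQUISOLVE_DISABLE_METATENSOR_TORCH"):
    HAS_METATENSOR_TORCH = False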

global HAS_METATENSOR_TORCH

try:
    import torch  # noqa: F401
Collaborator:

Isn't there maybe a more clever way to check this with importlib? Then we would not have to silence the linter. A quick search found this: https://stackoverflow.com/questions/14050281/how-to-check-if-a-python-module-exists-without-importing-it
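A minimal sketch of what the importlib-based check from that answer would look like here:

import importlib.util

# True if torch can be located on the search path, without importing it,
# so no "# noqa: F401" suppression is needed
HAS_METATENSOR_TORCH = importlib.util.find_spec("torch") is not None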

@@ -307,8 +306,7 @@ def fit(

  weights_blocks.append(weight_block)

- # convert weights to a dictionary allowing pickle dump of an instance
- self._weights = tensor_map_to_dict(TensorMap(X.keys, weights_blocks))
+ self._weights = TensorMap(X.keys, weights_blocks)
Collaborator:

Nice that this workaround is gone now!

Could we then also remove the function tensor_map_to_dict?

Comment on lines +106 to +108
module = clf.export_torch_module()
y_pred_torch_module = module.forward(core_tensor_map_to_torch(X))
metatensor.torch.allclose_raise(y_pred_torch, y_pred_torch_module)
Collaborator:

We should check that this is actually scriptable! At least since that is what we promise.
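For example, the test could additionally assert scriptability along these lines (a sketch that reuses the names from the snippet above):

import torch

# scripting the exported module should not raise if it is TorchScript compatible
scripted_module = torch.jit.script(module)
y_pred_scripted = scripted_module(core_tensor_map_to_torch(X))
metatensor.torch.allclose_raise(y_pred_torch, y_pred_scripted)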

Comment on lines +81 to +102
def core_tensor_map_to_torch(core_tensor: TensorMap, device=None, dtype=None):
    """Transforms a tensor map from metatensor-core to metatensor-torch

    :param core_tensor:
        tensor map from metatensor-core

    :param device:
        :py:class:`torch.device` of values in the resulting tensor map

    :param dtype:
        :py:class:`torch.dtype` of values in the resulting tensor map

    :returns torch_tensor:
        tensor map from metatensor-torch
    """
    from metatensor.torch import TensorMap as TorchTensorMap

    torch_blocks = []
    for _, core_block in core_tensor.items():
        torch_blocks.append(core_tensor_block_to_torch(core_block, device, dtype))
    torch_keys = core_labels_to_torch(core_tensor.keys)
    return TorchTensorMap(torch_keys, torch_blocks)
Collaborator:

This seems to me like functions that should live upstream in metatensor directly. Pinging @Luthaf here for thoughts.
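As a side note, the two helpers referenced in the excerpt (core_tensor_block_to_torch and core_labels_to_torch) are not shown in this hunk; a minimal sketch of what the Labels conversion could look like, assuming the Labels(names, values) constructors of metatensor-core and metatensor-torch:

import torch

from metatensor import Labels
from metatensor.torch import Labels as TorchLabels


def core_labels_to_torch(core_labels: Labels) -> TorchLabels:
    # label entries are integer indices; metatensor-torch stores them as a
    # 2D torch tensor alongside the dimension names
    return TorchLabels(
        names=list(core_labels.names),
        values=torch.tensor(core_labels.values),
    )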

)


def transpose_tensor_map(tensor: TensorMap):
Collaborator:

I think this doesn't work if there are components, right?

Also, this should be in metatensor operations, even if we do not support components for now...
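To make the limitation concrete, here is a sketch (not the implementation in this PR) of a samples/properties transpose that only covers blocks without components:

from metatensor import TensorBlock, TensorMap


def transpose_without_components(tensor: TensorMap) -> TensorMap:
    blocks = []
    for _, block in tensor.items():
        # with components, the values array has extra dimensions between the
        # samples and properties axes, so a plain 2D transpose is not defined
        assert len(block.components) == 0
        blocks.append(
            TensorBlock(
                values=block.values.T.copy(),
                samples=block.properties,
                components=[],
                properties=block.samples,
            )
        )
    return TensorMap(tensor.keys, blocks)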
