Skip to content

Commit

Permalink
Make Python PS torch compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Aug 21, 2023
1 parent e800e96 commit 7f911d8
Show file tree
Hide file tree
Showing 11 changed files with 306 additions and 34 deletions.
1 change: 1 addition & 0 deletions docs/src/references/api/torch/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ of rascaline are documented below for an usage from Python:

systems
calculators
utils/index

--------------------------------------------------------------------------------

Expand Down
10 changes: 10 additions & 0 deletions docs/src/references/api/torch/utils/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Utils
=====

Utility functions and classes that extend the core usage of rascaline-torch


.. toctree::
:maxdepth: 1

power-spectrum
6 changes: 6 additions & 0 deletions docs/src/references/api/torch/utils/power-spectrum.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
PowerSpectrum
=============

.. autoclass:: rascaline.torch.utils.PowerSpectrum
:members:
:show-inheritance:
4 changes: 4 additions & 0 deletions python/rascaline-torch/rascaline/torch/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .power_spectrum import PowerSpectrum


__all__ = ["PowerSpectrum"]
18 changes: 18 additions & 0 deletions python/rascaline-torch/rascaline/torch/utils/_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from equistore.torch import Labels, TensorBlock, TensorMap
from torch.nn import Module as torch_nn_module

from rascaline.utils import _dispatch

from ..calculator_base import CalculatorModule as CalculatorBase
from ..system import System as IntoSystem


__all__ = [
"_dispatch",
"CalculatorBase",
"IntoSystem",
"Labels",
"torch_nn_module",
"TensorBlock",
"TensorMap",
]
20 changes: 20 additions & 0 deletions python/rascaline-torch/rascaline/torch/utils/power_spectrum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import importlib
import sys

import rascaline.utils.power_spectrum


# For details what is happening here take a look an `rascaline.torch.calculators`.
spec = importlib.util.spec_from_file_location(
# create a module with this name
"rascaline.torch.utils.power_spectrum",
# using the code from there
rascaline.utils.power_spectrum.__file__,
)
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)

# don't forget to also update `rascaline/utils/__init__.py` and
# `rascaline/torch/utils.__init__.py` when modifying this file
PowerSpectrum = module.PowerSpectrum
51 changes: 51 additions & 0 deletions python/rascaline-torch/tests/utils/power_spectrum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torch
from packaging import version

from rascaline.torch import System
from rascaline.torch.calculators import SphericalExpansion
from rascaline.torch.utils import PowerSpectrum


def system():
return System(
species=torch.tensor([1, 1, 8, 8]),
positions=torch.tensor([[0.0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]]),
cell=torch.tensor([[10, 0, 0], [0, 10, 0], [0, 0, 10]]),
)


def calculator():
return SphericalExpansion(
cutoff=5.0,
max_radial=6,
max_angular=4,
atomic_gaussian_width=0.3,
center_atom_weight=1.0,
radial_basis={
"Gto": {},
},
cutoff_function={
"ShiftedCosine": {"width": 0.5},
},
)


def check_operation(PowerSpectrum):
# this only runs basic checks functionality checks, and that the code produces
# output with the right type

descriptor = calculator().compute(system(), gradients=["positions"])

assert isinstance(descriptor, torch.ScriptObject)
if version.parse(torch.__version__) >= version.parse("2.1"):
assert descriptor._type().name() == "TensorMap"


def test_operation_as_python():
check_operation(PowerSpectrum)


def test_operation_as_torch_script():
scripted = torch.jit.script(calculator())

check_operation(scripted)
20 changes: 20 additions & 0 deletions python/rascaline/rascaline/utils/_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from equistore.core import Labels, TensorBlock, TensorMap

from ..calculator_base import CalculatorBase
from ..systems import IntoSystem
from . import _dispatch


# dummy object which is only relevant for torch
torch_nn_module = object


__all__ = [
"CalculatorBase",
"IntoSystem",
"Labels",
"TensorBlock",
"TensorMap",
"_dispatch",
"torch_nn_module",
]
103 changes: 103 additions & 0 deletions python/rascaline/rascaline/utils/_dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Helper functions to dispatch methods between numpy and torch.
The functions are similar to those in equistore-operations. Missing functions may
already exist there. Functions are ordered alphabetically.
"""

from typing import List

import numpy as np


try:
import torch
from torch import Tensor as TorchTensor
except ImportError:

class TorchTensor:
pass


UNKNOWN_ARRAY_TYPE = (
"unknown array type, only numpy arrays and torch tensors are supported"
)


def _check_all_np_ndarray(arrays):
for array in arrays:
if not isinstance(array, np.ndarray):
raise TypeError(
f"expected argument to be a np.ndarray, but got {type(array)}"
)


def _check_all_torch_tensor(arrays):
for array in arrays:
if not isinstance(array, TorchTensor):
raise TypeError(
f"expected argument to be a torch.Tensor, but got {type(array)}"
)


def list_to_array(array, data: List):
"""Create an object from data with the same type as ``array``."""
if isinstance(array, TorchTensor):
return torch.Tensor(data, device=array.device, dtype=array.dtype)
elif isinstance(array, np.ndarray):
return np.array(data, dtype=array.dtype)
else:
raise TypeError(UNKNOWN_ARRAY_TYPE)


def matmul(a, b, out=None):
"""Matrix product of two arrays."""
if isinstance(a, TorchTensor):
_check_all_torch_tensor([b])
return torch.matmul(a, b, out=out)
elif isinstance(a, np.ndarray):
_check_all_np_ndarray([b])
return np.matmul(a, b, out=out)
else:
raise TypeError(UNKNOWN_ARRAY_TYPE)


def unique(array, return_inverse=False, return_counts=False, axis=None):
"""Find the unique elements of an array."""
if isinstance(array, TorchTensor):
return torch.unique(
array, return_inverse=return_inverse, return_counts=return_counts, dim=axis
)
elif isinstance(array, np.ndarray):
return np.unique(
array, return_inverse=return_inverse, return_counts=return_counts, axis=axis
)
else:
raise TypeError(UNKNOWN_ARRAY_TYPE)


def zeros_like(array, shape=None, requires_grad=False):
"""
Create an array filled with zeros, with the given ``shape``, and similar
dtype, device and other options as ``array``.
If ``shape`` is :py:obj:`None`, the array shape is used instead.
``requires_grad`` is only used for torch tensors, and set the corresponding
value on the returned array.
This is the equivalent to ``np.zeros_like(array, shape=shape)``.
"""
if isinstance(array, TorchTensor):
if shape is None:
shape = array.size()

return torch.zeros(
shape,
dtype=array.dtype,
layout=array.layout,
device=array.device,
requires_grad=requires_grad,
)
elif isinstance(array, np.ndarray):
return np.zeros_like(array, shape=shape, subok=False)
else:
raise TypeError(UNKNOWN_ARRAY_TYPE)
Loading

0 comments on commit 7f911d8

Please sign in to comment.