-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e800e96
commit 7f911d8
Showing
11 changed files
with
306 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
Utils | ||
===== | ||
|
||
Utility functions and classes that extend the core usage of rascaline-torch | ||
|
||
|
||
.. toctree:: | ||
:maxdepth: 1 | ||
|
||
power-spectrum |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
PowerSpectrum | ||
============= | ||
|
||
.. autoclass:: rascaline.torch.utils.PowerSpectrum | ||
:members: | ||
:show-inheritance: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .power_spectrum import PowerSpectrum | ||
|
||
|
||
__all__ = ["PowerSpectrum"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from equistore.torch import Labels, TensorBlock, TensorMap | ||
from torch.nn import Module as torch_nn_module | ||
|
||
from rascaline.utils import _dispatch | ||
|
||
from ..calculator_base import CalculatorModule as CalculatorBase | ||
from ..system import System as IntoSystem | ||
|
||
|
||
__all__ = [ | ||
"_dispatch", | ||
"CalculatorBase", | ||
"IntoSystem", | ||
"Labels", | ||
"torch_nn_module", | ||
"TensorBlock", | ||
"TensorMap", | ||
] |
20 changes: 20 additions & 0 deletions
20
python/rascaline-torch/rascaline/torch/utils/power_spectrum.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import importlib | ||
import sys | ||
|
||
import rascaline.utils.power_spectrum | ||
|
||
|
||
# For details what is happening here take a look an `rascaline.torch.calculators`. | ||
spec = importlib.util.spec_from_file_location( | ||
# create a module with this name | ||
"rascaline.torch.utils.power_spectrum", | ||
# using the code from there | ||
rascaline.utils.power_spectrum.__file__, | ||
) | ||
module = importlib.util.module_from_spec(spec) | ||
sys.modules[spec.name] = module | ||
spec.loader.exec_module(module) | ||
|
||
# don't forget to also update `rascaline/utils/__init__.py` and | ||
# `rascaline/torch/utils.__init__.py` when modifying this file | ||
PowerSpectrum = module.PowerSpectrum |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import torch | ||
from packaging import version | ||
|
||
from rascaline.torch import System | ||
from rascaline.torch.calculators import SphericalExpansion | ||
from rascaline.torch.utils import PowerSpectrum | ||
|
||
|
||
def system(): | ||
return System( | ||
species=torch.tensor([1, 1, 8, 8]), | ||
positions=torch.tensor([[0.0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]]), | ||
cell=torch.tensor([[10, 0, 0], [0, 10, 0], [0, 0, 10]]), | ||
) | ||
|
||
|
||
def calculator(): | ||
return SphericalExpansion( | ||
cutoff=5.0, | ||
max_radial=6, | ||
max_angular=4, | ||
atomic_gaussian_width=0.3, | ||
center_atom_weight=1.0, | ||
radial_basis={ | ||
"Gto": {}, | ||
}, | ||
cutoff_function={ | ||
"ShiftedCosine": {"width": 0.5}, | ||
}, | ||
) | ||
|
||
|
||
def check_operation(PowerSpectrum): | ||
# this only runs basic checks functionality checks, and that the code produces | ||
# output with the right type | ||
|
||
descriptor = calculator().compute(system(), gradients=["positions"]) | ||
|
||
assert isinstance(descriptor, torch.ScriptObject) | ||
if version.parse(torch.__version__) >= version.parse("2.1"): | ||
assert descriptor._type().name() == "TensorMap" | ||
|
||
|
||
def test_operation_as_python(): | ||
check_operation(PowerSpectrum) | ||
|
||
|
||
def test_operation_as_torch_script(): | ||
scripted = torch.jit.script(calculator()) | ||
|
||
check_operation(scripted) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from equistore.core import Labels, TensorBlock, TensorMap | ||
|
||
from ..calculator_base import CalculatorBase | ||
from ..systems import IntoSystem | ||
from . import _dispatch | ||
|
||
|
||
# dummy object which is only relevant for torch | ||
torch_nn_module = object | ||
|
||
|
||
__all__ = [ | ||
"CalculatorBase", | ||
"IntoSystem", | ||
"Labels", | ||
"TensorBlock", | ||
"TensorMap", | ||
"_dispatch", | ||
"torch_nn_module", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
"""Helper functions to dispatch methods between numpy and torch. | ||
The functions are similar to those in equistore-operations. Missing functions may | ||
already exist there. Functions are ordered alphabetically. | ||
""" | ||
|
||
from typing import List | ||
|
||
import numpy as np | ||
|
||
|
||
try: | ||
import torch | ||
from torch import Tensor as TorchTensor | ||
except ImportError: | ||
|
||
class TorchTensor: | ||
pass | ||
|
||
|
||
UNKNOWN_ARRAY_TYPE = ( | ||
"unknown array type, only numpy arrays and torch tensors are supported" | ||
) | ||
|
||
|
||
def _check_all_np_ndarray(arrays): | ||
for array in arrays: | ||
if not isinstance(array, np.ndarray): | ||
raise TypeError( | ||
f"expected argument to be a np.ndarray, but got {type(array)}" | ||
) | ||
|
||
|
||
def _check_all_torch_tensor(arrays): | ||
for array in arrays: | ||
if not isinstance(array, TorchTensor): | ||
raise TypeError( | ||
f"expected argument to be a torch.Tensor, but got {type(array)}" | ||
) | ||
|
||
|
||
def list_to_array(array, data: List): | ||
"""Create an object from data with the same type as ``array``.""" | ||
if isinstance(array, TorchTensor): | ||
return torch.Tensor(data, device=array.device, dtype=array.dtype) | ||
elif isinstance(array, np.ndarray): | ||
return np.array(data, dtype=array.dtype) | ||
else: | ||
raise TypeError(UNKNOWN_ARRAY_TYPE) | ||
|
||
|
||
def matmul(a, b, out=None): | ||
"""Matrix product of two arrays.""" | ||
if isinstance(a, TorchTensor): | ||
_check_all_torch_tensor([b]) | ||
return torch.matmul(a, b, out=out) | ||
elif isinstance(a, np.ndarray): | ||
_check_all_np_ndarray([b]) | ||
return np.matmul(a, b, out=out) | ||
else: | ||
raise TypeError(UNKNOWN_ARRAY_TYPE) | ||
|
||
|
||
def unique(array, return_inverse=False, return_counts=False, axis=None): | ||
"""Find the unique elements of an array.""" | ||
if isinstance(array, TorchTensor): | ||
return torch.unique( | ||
array, return_inverse=return_inverse, return_counts=return_counts, dim=axis | ||
) | ||
elif isinstance(array, np.ndarray): | ||
return np.unique( | ||
array, return_inverse=return_inverse, return_counts=return_counts, axis=axis | ||
) | ||
else: | ||
raise TypeError(UNKNOWN_ARRAY_TYPE) | ||
|
||
|
||
def zeros_like(array, shape=None, requires_grad=False): | ||
""" | ||
Create an array filled with zeros, with the given ``shape``, and similar | ||
dtype, device and other options as ``array``. | ||
If ``shape`` is :py:obj:`None`, the array shape is used instead. | ||
``requires_grad`` is only used for torch tensors, and set the corresponding | ||
value on the returned array. | ||
This is the equivalent to ``np.zeros_like(array, shape=shape)``. | ||
""" | ||
if isinstance(array, TorchTensor): | ||
if shape is None: | ||
shape = array.size() | ||
|
||
return torch.zeros( | ||
shape, | ||
dtype=array.dtype, | ||
layout=array.layout, | ||
device=array.device, | ||
requires_grad=requires_grad, | ||
) | ||
elif isinstance(array, np.ndarray): | ||
return np.zeros_like(array, shape=shape, subok=False) | ||
else: | ||
raise TypeError(UNKNOWN_ARRAY_TYPE) |
Oops, something went wrong.