Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clebsch gordan submodule - implementation of TorchScript interface #269

Merged
merged 24 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
3f8369f
initalize clebsch_gordan submodule in rascaline.torch
agoscinski Dec 29, 2023
e67d917
checkpoint all-deps and all-deps-torch tests passing
agoscinski Feb 13, 2024
5caaec8
all-deps all-deps-torch pass
agoscinski Feb 14, 2024
155f36c
change ClebschGordanReal to TensorMap
agoscinski Feb 14, 2024
3c3961b
adding for torch backend
agoscinski Feb 15, 2024
358d170
fixing TorchScript
agoscinski Feb 15, 2024
f94b8c4
fix dispatch and refactor tests
agoscinski Feb 15, 2024
2d621b6
remove _dispatch.max_axis not needed
agoscinski Feb 15, 2024
25f7ad3
add tests for properties of DensityCorrelations
agoscinski Feb 15, 2024
39aee61
simplify _parse_selected_keys, now it does not need to be scritable
agoscinski Feb 15, 2024
f6d88ef
Make CG cache contiguous, fix some
jwa7 Feb 15, 2024
3513558
Remove comment block
jwa7 Feb 15, 2024
1504c19
Update python/rascaline/rascaline/utils/clebsch_gordan/_clebsch_gorda…
agoscinski Feb 16, 2024
d024c7f
Update python/rascaline/rascaline/utils/clebsch_gordan/correlate_dens…
agoscinski Feb 16, 2024
5a69107
Test save/load for checking contiguous. Clean up. Docstring arg.
jwa7 Feb 16, 2024
3fb4ef7
Merge branch 'master' into cg-torchscript
jwa7 Feb 16, 2024
a45b73c
Partial resolution of review comments
jwa7 Feb 17, 2024
7c14c22
Get rid of __all__
jwa7 Feb 19, 2024
be331f1
Fix Python import
Luthaf Feb 19, 2024
993f005
Add CG to API docs
Luthaf Feb 19, 2024
f4337f4
linter noqa
jwa7 Feb 19, 2024
96cee2a
Make the CG cache a function not a class. Fix the mops CG cache and i…
jwa7 Feb 19, 2024
7c77adc
Review round 2
jwa7 Feb 21, 2024
44ded8c
Final review comment
jwa7 Feb 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/src/references/api/python/utils/clebsch-gordan.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Clebsch-Gordan products
=======================

.. autoclass:: rascaline.utils.DensityCorrelations
:members:
1 change: 1 addition & 0 deletions docs/src/references/api/python/utils/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ Utility functions and classes that extend the core usage of rascaline.
radial-basis
power-spectrum
splines
clebsch-gordan
5 changes: 5 additions & 0 deletions docs/src/references/api/torch/utils/clebsch-gordan.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Clebsch-Gordan products
=======================

.. autoclass:: rascaline.torch.utils.DensityCorrelations
:members:
1 change: 1 addition & 0 deletions docs/src/references/api/torch/utils/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ Utility functions and classes that extend the core usage of rascaline-torch
:maxdepth: 1

power-spectrum
clebsch-gordan
8 changes: 4 additions & 4 deletions python/rascaline-torch/rascaline/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@

_load_library()

from . import utils # noqa
from .calculator_base import CalculatorModule, register_autograd # noqa
from . import utils # noqa: E402, F401
from .calculator_base import CalculatorModule, register_autograd # noqa: E402, F401

# don't forget to also update `rascaline/__init__.py` and
# `rascaline/torch/calculators.py` when modifying this file
from .calculators import ( # noqa
from .calculators import ( # noqa: E402, F401
AtomicComposition,
LodeSphericalExpansion,
NeighborList,
Expand All @@ -24,7 +24,7 @@
SphericalExpansion,
SphericalExpansionByPair,
)
from .system import systems_to_torch # noqa
from .system import systems_to_torch # noqa: E402, F401


__all__ = [
Expand Down
88 changes: 88 additions & 0 deletions python/rascaline-torch/rascaline/torch/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import importlib
import os
import sys
from typing import Any

import torch
from metatensor.torch import Labels, LabelsEntry, TensorBlock, TensorMap

import rascaline.utils

from .calculator_base import CalculatorModule
from .system import System


_HERE = os.path.dirname(__file__)


# For details what is happening here take a look an `rascaline.torch.calculators`.

# create the `_backend` module as an empty module
spec = importlib.util.spec_from_loader(
"rascaline.torch.utils._backend",
loader=None,
)
module = importlib.util.module_from_spec(spec)
# This module only exposes a handful of things, defined here. Any changes here MUST also
# be made to the `metatensor/operations/_classes.py` file, which is used in non
# TorchScript mode.
module.__dict__["Labels"] = Labels
module.__dict__["TensorBlock"] = TensorBlock
module.__dict__["TensorMap"] = TensorMap
module.__dict__["LabelsEntry"] = LabelsEntry
module.__dict__["torch_jit_is_scripting"] = torch.jit.is_scripting
module.__dict__["torch_jit_annotate"] = torch.jit.annotate
module.__dict__["torch_jit_export"] = torch.jit.export
module.__dict__["TorchTensor"] = torch.Tensor
module.__dict__["TorchModule"] = torch.nn.Module
module.__dict__["TorchScriptClass"] = torch.ScriptClass
module.__dict__["Array"] = torch.Tensor
module.__dict__["CalculatorBase"] = CalculatorModule
module.__dict__["IntoSystem"] = System


def is_labels(obj: Any):
return isinstance(obj, Labels)


if os.environ.get("RASCALINE_IMPORT_FOR_SPHINX") is None:
is_labels = torch.jit.script(is_labels)

module.__dict__["is_labels"] = is_labels


def check_isinstance(obj, ty):
if isinstance(ty, torch.ScriptClass):
# This branch is taken when `ty` is a custom class (TensorMap, …). since `ty` is
# an instance of `torch.ScriptClass` and not a class itself, there is no way to
# check if obj is an "instance" of this class, so we always return True and hope
# for the best. Most errors should be caught by the TorchScript compiler anyway.
return True
else:
assert isinstance(ty, type)
return isinstance(obj, ty)


# register the module in sys.modules, so future import find it directly
sys.modules[spec.name] = module

# create a module named `rascaline.torch.utils` using code from
# `rascaline.utils`
spec = importlib.util.spec_from_file_location(
"rascaline.torch.utils", rascaline.utils.__file__
)

module = importlib.util.module_from_spec(spec)


cmake_prefix_path = os.path.realpath(os.path.join(_HERE, "..", "lib", "cmake"))
"""
Path containing the CMake configuration files for the underlying C library
"""

module.__dict__["cmake_prefix_path"] = cmake_prefix_path

# override `rascaline.torch.utils` (the module associated with the current file)
# with the newly created module
sys.modules[spec.name] = module
spec.loader.exec_module(module)
13 changes: 0 additions & 13 deletions python/rascaline-torch/rascaline/torch/utils/__init__.py

This file was deleted.

99 changes: 0 additions & 99 deletions python/rascaline-torch/rascaline/torch/utils/power_spectrum.py

This file was deleted.

113 changes: 113 additions & 0 deletions python/rascaline-torch/tests/utils/correlate_density.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
import io
import os
from typing import Any, List

import ase.io
import metatensor.torch
import pytest
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap # noqa

import rascaline.torch
from rascaline.torch.utils.clebsch_gordan.correlate_density import DensityCorrelations


DATA_ROOT = os.path.join(os.path.dirname(__file__), "data")


def is_tensor_map(obj: Any):
return isinstance(obj, TensorMap)


is_tensor_map = torch.jit.script(is_tensor_map)
jwa7 marked this conversation as resolved.
Show resolved Hide resolved

SPHERICAL_EXPANSION_HYPERS = {
"cutoff": 2.5,
"max_radial": 3,
"max_angular": 3,
"atomic_gaussian_width": 0.2,
"radial_basis": {"Gto": {}},
"cutoff_function": {"ShiftedCosine": {"width": 0.5}},
"center_atom_weight": 1.0,
}

SELECTED_KEYS = Labels(
names=["spherical_harmonics_l"], values=torch.tensor([1, 3]).reshape(-1, 1)
)


def h2o_isolated():
return ase.io.read(os.path.join(DATA_ROOT, "h2o_isolated.xyz"), ":")


def spherical_expansion(frames: List[ase.Atoms]):
"""Returns a rascaline SphericalExpansion"""
calculator = rascaline.torch.SphericalExpansion(**SPHERICAL_EXPANSION_HYPERS)
return calculator.compute(rascaline.torch.systems_to_torch(frames))


# copy of def test_correlate_density_angular_selection(
@pytest.mark.parametrize("selected_keys", [None, SELECTED_KEYS])
@pytest.mark.parametrize("skip_redundant", [True, False])
def test_torch_script_correlate_density_angular_selection(
selected_keys: Labels,
skip_redundant: bool,
):
"""
Tests that the correct angular channels are output based on the specified
``selected_keys``.
"""
frames = h2o_isolated()
nu_1 = spherical_expansion(frames)
correlation_order = 2
corr_calculator = DensityCorrelations(
max_angular=SPHERICAL_EXPANSION_HYPERS["max_angular"] * correlation_order,
correlation_order=correlation_order,
angular_cutoff=None,
selected_keys=selected_keys,
skip_redundant=skip_redundant,
)

ref_nu_2 = corr_calculator.compute(nu_1)
scripted_corr_calculator = torch.jit.script(corr_calculator)

# Test compute
scripted_nu_2 = scripted_corr_calculator.compute(nu_1)
assert metatensor.torch.equal_metadata(scripted_nu_2, ref_nu_2)
assert metatensor.torch.allclose(scripted_nu_2, ref_nu_2)

# Test compute_metadata
scripted_nu_2 = scripted_corr_calculator.compute_metadata(nu_1)
assert metatensor.torch.equal_metadata(scripted_nu_2, ref_nu_2)


def test_jit_save_load():
corr_calculator = DensityCorrelations(
max_angular=2,
correlation_order=2,
angular_cutoff=1,
)
scripted_correlate_density = torch.jit.script(corr_calculator)
with io.BytesIO() as buffer:
torch.jit.save(scripted_correlate_density, buffer)
buffer.seek(0)
torch.jit.load(buffer)
buffer.close()
jwa7 marked this conversation as resolved.
Show resolved Hide resolved


def test_save_load():
"""Tests for saving and loading with cg_backend="python-dense",
which makes the DensityCorrelations object non-scriptable due to
a non-contiguous CG cache."""
corr_calculator = DensityCorrelations(
max_angular=2,
correlation_order=2,
angular_cutoff=1,
cg_backend="python-dense",
)
with io.BytesIO() as buffer:
torch.save(corr_calculator, buffer)
buffer.seek(0)
torch.load(buffer)
buffer.close()
jwa7 marked this conversation as resolved.
Show resolved Hide resolved
jwa7 marked this conversation as resolved.
Show resolved Hide resolved
File renamed without changes.
2 changes: 1 addition & 1 deletion python/rascaline/rascaline/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

from .clebsch_gordan import correlate_density, correlate_density_metadata # noqa
from .clebsch_gordan import DensityCorrelations # noqa
from .power_spectrum import PowerSpectrum # noqa
from .splines import ( # noqa
AtomicDensityBase,
Expand Down
Loading
Loading