This repository has been archived by the owner on Apr 24, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
14a3401
commit 0a22e04
Showing
8 changed files
with
902 additions
and
251 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
""" | ||
Computing a SoR Kernel Model | ||
============================ | ||
.. start-body | ||
In this tutorial we calculate a kernel model using subset of regressor (SoR) | ||
Kernel model. | ||
""" | ||
|
||
import ase.io | ||
import metatensor | ||
from rascaline import SoapPowerSpectrum | ||
|
||
from equisolve.numpy.models import SorKernelRidge | ||
from equisolve.numpy.sample_selection import FPS | ||
from equisolve.utils import ase_to_tensormap | ||
|
||
|
||
frames = ase.io.read("dataset.xyz", ":20") | ||
y = ase_to_tensormap(frames, energy="energy") | ||
n_to_select = 100 | ||
degree = 3 | ||
|
||
HYPER_PARAMETERS = { | ||
"cutoff": 5.0, | ||
"max_radial": 6, | ||
"max_angular": 4, | ||
"atomic_gaussian_width": 0.3, | ||
"center_atom_weight": 1.0, | ||
"radial_basis": { | ||
"Gto": {}, | ||
}, | ||
"cutoff_function": { | ||
"ShiftedCosine": {"width": 0.5}, | ||
}, | ||
} | ||
|
||
calculator = SoapPowerSpectrum(**HYPER_PARAMETERS) | ||
|
||
descriptor = calculator.compute(frames, gradients=[]) | ||
|
||
descriptor = descriptor.keys_to_samples("species_center") | ||
descriptor = descriptor.keys_to_properties(["species_neighbor_1", "species_neighbor_2"]) | ||
|
||
pseudo_points = FPS(n_to_select=n_to_select).fit_transform(descriptor) | ||
|
||
clf = SorKernelRidge() | ||
clf.fit( | ||
descriptor, | ||
pseudo_points, | ||
y, | ||
kernel_type="polynomial", | ||
kernel_kwargs={"degree": 3, "aggregate_names": ["center", "species_center"]}, | ||
) | ||
y_pred = clf.predict(descriptor) | ||
|
||
print( | ||
"MAE:", | ||
metatensor.mean_over_samples( | ||
metatensor.abs(metatensor.subtract(y_pred, y)), "structure" | ||
)[0].values[0, 0], | ||
) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from ._aggregate_kernel import AggregateKernel # noqa: F401 | ||
from ._aggregate_kernel import AggregateLinear, AggregatePolynomial | ||
|
||
|
||
__all__ = ["AggregateLinear", "AggregatePolynomial"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
from typing import List, Tuple, Union | ||
|
||
import metatensor | ||
from metatensor import TensorMap | ||
|
||
|
||
try: | ||
import torch | ||
|
||
HAS_TORCH = True | ||
TorchModule = torch.nn.Module | ||
except ImportError: | ||
HAS_TORCH = False | ||
import abc | ||
|
||
# TODO move to more module.py | ||
class Module(metaclass=abc.ABCMeta): | ||
@abc.abstractmethod | ||
def forward(self, *args, **kwargs): | ||
pass | ||
|
||
def __call__(self, *args, **kwargs): | ||
return self.forward(*args, **kwargs) | ||
|
||
@abc.abstractmethod | ||
def export_torch(self): | ||
pass | ||
|
||
|
||
class AggregateKernel(Module): | ||
""" | ||
A kernel that aggregates values in a kernel over :param aggregate_names: using | ||
a aggregaten function given by :param aggregate_type: | ||
:param aggregate_names: | ||
:param aggregate_type: | ||
""" | ||
|
||
def __init__( | ||
self, | ||
aggregate_names: Union[str, List[str]] = "aggregate", | ||
aggregate_type: str = "sum", | ||
structurewise_aggregate: bool = False, | ||
): | ||
valid_aggregate_types = ["sum", "mean"] | ||
if aggregate_type not in valid_aggregate_types: | ||
raise ValueError( | ||
f"Given aggregate_type {aggregate_type!r} but only " | ||
f"{aggregate_type!r} are supported." | ||
) | ||
if structurewise_aggregate: | ||
raise NotImplementedError( | ||
"structurewise aggregation has not been implemented." | ||
) | ||
|
||
self._aggregate_names = aggregate_names | ||
self._aggregate_type = aggregate_type | ||
self._structurewise_aggregate = structurewise_aggregate | ||
|
||
def aggregate_features(self, tensor: TensorMap) -> TensorMap: | ||
if self._aggregate_type == "sum": | ||
return metatensor.sum_over_samples( | ||
tensor, samples_names=self._aggregate_names | ||
) | ||
elif self._aggregate_type == "mean": | ||
return metatensor.mean_over_samples( | ||
tensor, samples_names=self._aggregate_names | ||
) | ||
else: | ||
raise NotImplementedError( | ||
f"aggregate_type {self._aggregate_type!r} has not been implemented." | ||
) | ||
|
||
def aggregate_kernel( | ||
self, kernel: TensorMap, are_pseudo_points: Tuple[bool, bool] = (False, False) | ||
) -> TensorMap: | ||
if self._aggregate_type == "sum": | ||
if not are_pseudo_points[0]: | ||
kernel = metatensor.sum_over_samples(kernel, self._aggregate_names) | ||
if not are_pseudo_points[1]: | ||
# TODO {sum,mean}_over_properties does not exist | ||
raise NotImplementedError( | ||
"properties dimenson cannot be aggregated for the moment" | ||
) | ||
kernel = metatensor.sum_over_properties(kernel, self._aggregate_names) | ||
return kernel | ||
elif self._aggregate_type == "mean": | ||
if not are_pseudo_points[0]: | ||
kernel = metatensor.mean_over_samples(kernel, self._aggregate_names) | ||
if not are_pseudo_points[1]: | ||
# TODO {sum,mean}_over_properties does not exist | ||
raise NotImplementedError( | ||
"properties dimenson cannot be aggregated for the moment" | ||
) | ||
kernel = metatensor.mean_over_properties(kernel, self._aggregate_names) | ||
return kernel | ||
else: | ||
raise NotImplementedError( | ||
f"aggregate_type {self._aggregate_type!r} has not been implemented." | ||
) | ||
|
||
def forward( | ||
self, | ||
tensor1: TensorMap, | ||
tensor2: TensorMap, | ||
are_pseudo_points: Tuple[bool, bool] = (False, False), | ||
) -> TensorMap: | ||
return self.aggregate_kernel( | ||
self.compute_kernel(tensor1, tensor2), are_pseudo_points | ||
) | ||
|
||
def compute_kernel(self, tensor1: TensorMap, tensor2: TensorMap) -> TensorMap: | ||
raise NotImplementedError("compute_kernel needs to be implemented.") | ||
|
||
|
||
class AggregateLinear(AggregateKernel): | ||
def __init__( | ||
self, | ||
aggregate_names: Union[str, List[str]] = "aggregate", | ||
aggregate_type: str = "sum", | ||
structurewise_aggregate: bool = False, | ||
): | ||
super().__init__(aggregate_names, aggregate_type, structurewise_aggregate) | ||
|
||
def forward( | ||
self, | ||
tensor1: TensorMap, | ||
tensor2: TensorMap, | ||
are_pseudo_points: Tuple[bool, bool] = (False, False), | ||
) -> TensorMap: | ||
# we overwrite default behavior because for linear kernels we can do it more | ||
# memory efficient | ||
if not are_pseudo_points[0]: | ||
tensor1 = self.aggregate_features(tensor1) | ||
if not are_pseudo_points[1]: | ||
tensor2 = self.aggregate_features(tensor2) | ||
return self.compute_kernel(tensor1, tensor2) | ||
|
||
def compute_kernel(self, tensor1: TensorMap, tensor2: TensorMap) -> TensorMap: | ||
return metatensor.dot(tensor1, tensor2) | ||
|
||
def export_torch(self): | ||
raise NotImplementedError("export_torch has not been implemented") | ||
# idea is to do something in the lines of | ||
# return euqisolve.torch.kernels.AggregateLinear( | ||
# self._aggregate_names, | ||
# self._aggregate_type) | ||
|
||
|
||
class AggregatePolynomial(AggregateKernel): | ||
def __init__( | ||
self, | ||
aggregate_names: Union[str, List[str]] = "aggregate", | ||
aggregate_type: str = "sum", | ||
structurewise_aggregate: bool = False, | ||
degree: int = 2, | ||
): | ||
super().__init__(aggregate_names, aggregate_type, structurewise_aggregate) | ||
self._degree = 2 | ||
|
||
def compute_kernel(self, tensor1: TensorMap, tensor2: TensorMap): | ||
return metatensor.pow(metatensor.dot(tensor1, tensor2), self._degree) | ||
|
||
def export_torch(self): | ||
raise NotImplementedError("export_torch has not been implemented") | ||
# idea is to do something in the lines of | ||
# return euqisolve.torch.kernels.AggregatePolynomial( | ||
# self._aggregate_names, | ||
# self._aggregate_type, | ||
# self._degree) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.