Add a general torch CompositionModel #280
Merged
Changes from all commits (27 commits):

- 95c9153 Add a general torch CompositionModel (PicoCentauri)
- 4791d2e Finish composition model (frostedoyster)
- e7ceecf Merge branch 'main' into composition-model (frostedoyster)
- 7c9ac4a Integrate with SOAP-BPNN (frostedoyster)
- eb5515e Also use the new CompositionModel in GAP (frostedoyster)
- fc71849 Add test for `remove_composition` (frostedoyster)
- 9a7718b Exclude `mtt::aux::` quantities from composition models (frostedoyster)
- 043effe Remove composition from original SOAP-BPNN (frostedoyster)
- bb756d7 Fix bug (frostedoyster)
- f5b3527 Update metatensor (frostedoyster)
- e297fad `._module` -> `.module` (frostedoyster)
- b6caa7b Update dataset (frostedoyster)
- 32c24be Fix alchemical model (frostedoyster)
- 0b35fed Add tests for errors (frostedoyster)
- 661ebe5 Only warn if atomic types are present in the validation dataset but n… (frostedoyster)
- a2445d9 Merge branch 'update-metatensor' into composition-model (frostedoyster)
- b69b189 Fix test (frostedoyster)
- 021b7c2 Merge branch 'main' into composition-model (frostedoyster)
- 7a406d1 Debugg (frostedoyster)
- ab41dc2 Merge branch 'main' into composition-model (Luthaf)
- d29920c Do not import metatensor operation on the top level (Luthaf)
- b1fd48d Selected atoms for composition model (frostedoyster)
- 92b4735 Merge branch 'composition-model' of https://github.com/lab-cosmo/meta… (frostedoyster)
- fe8e536 Test selected atoms (frostedoyster)
- 269c93c Merge branch 'main' into composition-model (frostedoyster)
- 112ed89 More testing (frostedoyster)
- e285cbf Merge branch 'composition-model' of https://github.com/lab-cosmo/meta… (frostedoyster)
src/metatrain/experimental/alchemical_model/utils/composition.py (new file, 69 additions, 0 deletions)

@@ -0,0 +1,69 @@
from typing import List, Tuple, Union

import torch

from ....utils.data.dataset import Dataset, get_atomic_types


def calculate_composition_weights(
    datasets: Union[Dataset, List[Dataset]], property: str
) -> Tuple[torch.Tensor, List[int]]:
    """Calculate the composition weights for a dataset.

    It assumes per-system properties.

    :param datasets: Dataset (or list of datasets) to calculate the
        composition weights for.
    :returns: Composition weights for the dataset, as well as the
        list of species that the weights correspond to.
    """
    if not isinstance(datasets, list):
        datasets = [datasets]

    # Note: `atomic_types` are sorted, and the composition weights are sorted as
    # well, because the species are sorted in the composition features.
    atomic_types = sorted(get_atomic_types(datasets))

    targets = torch.stack(
        [sample[property].block().values for dataset in datasets for sample in dataset]
    )
    targets = targets.squeeze(dim=(1, 2))  # remove component and property dimensions

    total_num_structures = sum(len(dataset) for dataset in datasets)
    dtype = datasets[0][0]["system"].positions.dtype
    composition_features = torch.empty(
        (total_num_structures, len(atomic_types)), dtype=dtype
    )
    structure_index = 0
    for dataset in datasets:
        for sample in dataset:
            structure = sample["system"]
            for j, s in enumerate(atomic_types):
                composition_features[structure_index, j] = torch.sum(
                    structure.types == s
                )
            structure_index += 1

    regularizer = 1e-20
    while regularizer:
        if regularizer > 1e5:
            raise RuntimeError(
                "Failed to solve the linear system to calculate the "
                "composition weights. The dataset is probably too small "
                "or ill-conditioned."
            )
        try:
            solution = torch.linalg.solve(
                composition_features.T @ composition_features
                + regularizer
                * torch.eye(
                    composition_features.shape[1],
                    dtype=composition_features.dtype,
                    device=composition_features.device,
                ),
                composition_features.T @ targets,
            )
            break
        except torch._C._LinAlgError:
            regularizer *= 10.0

    return solution, atomic_types
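The function above solves a ridge-regularized least-squares problem: per-system targets are modeled as a linear function of how many atoms of each species a structure contains. A minimal self-contained sketch of that solve, using hypothetical per-species energies and a hand-built composition matrix (none of these numbers come from the PR):

```python
import torch

# Hypothetical per-species baseline energies we want to recover
true_weights = torch.tensor([-13.6, -75.0], dtype=torch.float64)

# Composition features: rows are structures, columns count atoms of each
# (sorted) atomic type, e.g. [n_H, n_O] for H2O, H2, O2
X = torch.tensor([[2.0, 1.0], [2.0, 0.0], [0.0, 2.0]], dtype=torch.float64)
y = X @ true_weights  # per-system targets are linear in composition

# Same regularized normal-equations solve as in calculate_composition_weights
regularizer = 1e-20
solution = torch.linalg.solve(
    X.T @ X + regularizer * torch.eye(X.shape[1], dtype=X.dtype),
    X.T @ y,
)
# solution is approximately [-13.6, -75.0]
```

With noise-free synthetic targets the solve recovers the weights exactly (up to the tiny regularizer); on real data the escalating-regularizer loop guards against ill-conditioned composition matrices.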
@@ -16,7 +16,7 @@

 from metatrain.utils.data.dataset import DatasetInfo

-from ...utils.composition import apply_composition_contribution
+from ...utils.composition import CompositionModel
 from ...utils.dtype import dtype_to_str
 from ...utils.export import export
@@ -123,14 +123,6 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
             unit="unitless", per_atom=True
         )

-        # creates a composition weight tensor that can be directly indexed by species,
-        # this can be left as a tensor of zero or set from the outside using
-        # set_composition_weights (recommended for better accuracy)
-        n_outputs = len(self.outputs)
-        self.register_buffer(
-            "composition_weights",
-            torch.zeros((n_outputs, max(self.atomic_types) + 1)),
-        )
         # buffers cannot be indexed by strings (torchscript), so we create a single
         # tensor for all output. Due to this, we need to slice the tensor when we use
         # it and use the output name to select the correct slice via a dictionary
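The removed comments describe a TorchScript workaround worth spelling out: buffers cannot be indexed by string keys, so the old code packed the weights of all outputs into one 2-D buffer indexed by atomic number, with a name-to-row dictionary on the side. A minimal sketch of that pattern (the output names and weight values here are hypothetical):

```python
import torch

# TorchScript cannot index buffers by string, hence one 2D tensor for all
# outputs plus a name -> row mapping
output_names = ["energy", "mtt::dipole"]  # hypothetical output names
output_to_index = {name: i for i, name in enumerate(output_names)}

# columns are indexed directly by atomic number, as in the removed buffer
atomic_types = [1, 6, 8]  # H, C, O
composition_weights = torch.zeros(len(output_names), max(atomic_types) + 1)

# set weights for one output; species not in the dataset keep a weight of zero
composition_weights[output_to_index["energy"]][atomic_types] = torch.tensor(
    [-13.6, -1030.0, -2040.0]  # hypothetical per-species energies
)
```

The PR deletes this machinery from SOAP-BPNN because the standalone CompositionModel now owns the weights.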
@@ -195,6 +187,11 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
             }
         )

+        self.composition_model = CompositionModel(
+            model_hypers={},
+            dataset_info=dataset_info,
+        )
+
     def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn":
         # merge old and new dataset info
         merged_info = self.dataset_info.union(dataset_info)
@@ -261,12 +258,7 @@ def forward(
         atomic_energies: Dict[str, TensorMap] = {}
         for output_name, output_layer in self.last_layers.items():
             if output_name in outputs:
-                atomic_energies[output_name] = apply_composition_contribution(
-                    output_layer(last_layer_features),
-                    self.composition_weights[  # type: ignore
-                        self.output_to_index[output_name]
-                    ],
-                )
+                atomic_energies[output_name] = output_layer(last_layer_features)

         # Sum the atomic energies coming from the BPNN to get the total energy
         for output_name, atomic_energy in atomic_energies.items():
@@ -281,6 +273,19 @@ def forward(
                 atomic_energy, ["atom", "center_type"]
             )

+        if not self.training:
+            # at evaluation, we also add the composition contributions
+            composition_contributions = self.composition_model(
+                systems, outputs, selected_atoms
+            )
+            for name in return_dict:
+                if name.startswith("mtt::aux::"):
+                    continue  # skip auxiliary outputs (not targets)
+                return_dict[name] = metatensor.torch.add(
+                    return_dict[name],
+                    composition_contributions[name],
+                )
+
         return return_dict

     @classmethod

Review comment on lines +276 to +277: NICE!
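The eval-time logic above adds composition contributions only to real targets, skipping any output whose name carries the `mtt::aux::` prefix. A minimal sketch of that filtering with plain tensors standing in for the TensorMaps (the target names and values are hypothetical):

```python
import torch

# Hypothetical model outputs keyed by target name
return_dict = {
    "energy": torch.tensor([1.0, 2.0]),
    "mtt::aux::last_layer_features": torch.tensor([0.5, 0.5]),
}
# Hypothetical composition contributions for the real target
composition_contributions = {"energy": torch.tensor([-10.0, -20.0])}

# Mirror the eval-time logic: add composition only to real targets,
# leaving auxiliary outputs untouched
for name in return_dict:
    if name.startswith("mtt::aux::"):
        continue  # skip auxiliary outputs (not targets)
    return_dict[name] = return_dict[name] + composition_contributions[name]
# return_dict["energy"] is now tensor([-9., -18.])
```

Guarding with `if not self.training` means the network itself is trained on composition-subtracted targets, and the baseline is only re-added at inference.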
@@ -303,6 +308,11 @@ def export(self) -> MetatensorAtomisticModel:
         if dtype not in self.__supported_dtypes__:
             raise ValueError(f"unsupported dtype {self.dtype} for SoapBpnn")

+        # Make sure the model is all in the same dtype
+        # For example, at this point, the composition model within the SOAP-BPNN is
+        # still float64
+        self.to(dtype)
+
         capabilities = ModelCapabilities(
             outputs=self.outputs,
             atomic_types=self.atomic_types,
@@ -314,21 +324,6 @@ def export(self) -> MetatensorAtomisticModel:

         return export(model=self, model_capabilities=capabilities)

-    def set_composition_weights(
-        self,
-        output_name: str,
-        input_composition_weights: torch.Tensor,
-        atomic_types: List[int],
-    ) -> None:
-        """Set the composition weights for a given output."""
-        # all species that are not present retain their weight of zero
-        self.composition_weights[self.output_to_index[output_name]][  # type: ignore
-            atomic_types
-        ] = input_composition_weights.to(
-            dtype=self.composition_weights.dtype,  # type: ignore
-            device=self.composition_weights.device,  # type: ignore
-        )
-
     def add_output(self, output_name: str) -> None:
         """Add a new output to the self."""
         # add a new row to the composition weights tensor
Review thread:

> I think this is not needed anymore?

> Unfortunately this needs to be kept for the alchemical model (notice that it changed directories), which works in a way that doesn't allow me to change things without changing the alchemical model code itself.

> Okay!