diff --git a/src/metatrain/utils/composition.py b/src/metatrain/utils/composition.py
index 272377c5..6a188bfc 100644
--- a/src/metatrain/utils/composition.py
+++ b/src/metatrain/utils/composition.py
@@ -73,10 +73,22 @@ def train_model(
         if fixed_weights is None:
             fixed_weights = {}
 
-        missing_types = sorted(set(get_atomic_types(datasets)) - set(self.atomic_types))
+        # Compute both type sets once; each check below reuses them.
+        dataset_types = set(get_atomic_types(datasets))
+        known_types = set(self.atomic_types)
+
+        additional_types = sorted(dataset_types - known_types)
+        if additional_types:
+            raise ValueError(
+                "Provided `datasets` contains unknown "
+                f"atomic types {additional_types}. "
+                f"Known types from initialization are {self.atomic_types}."
+            )
+
+        missing_types = sorted(known_types - dataset_types)
         if missing_types:
             raise ValueError(
-                f"Provided `datasets` contains unknown atomic types {missing_types}. "
-                f"Known types from initilaization are {self.atomic_types}."
+                f"Provided `datasets` do not contain atomic types {missing_types}. "
+                f"Known types from initialization are {self.atomic_types}."
             )
 
diff --git a/tests/utils/test_composition.py b/tests/utils/test_composition.py
index 5bb48be0..e9472197 100644
--- a/tests/utils/test_composition.py
+++ b/tests/utils/test_composition.py
@@ -1,6 +1,7 @@
 from pathlib import Path
 
 import metatensor.torch
+import pytest
 import torch
 from metatensor.torch import Labels, TensorBlock, TensorMap
 from metatensor.torch.atomistic import ModelOutput, System
@@ -239,3 +240,106 @@ def test_remove_composition():
     # In QM9 the composition contribution is very large: the standard deviation
     # of the energies is reduced by a factor of over 100 upon removing the composition
     assert std_after < 100.0 * std_before
+
+
+def test_composition_model_missing_types():
+    """
+    Test the error when there are too many or too few types in the dataset
+    compared to those declared at initialization.
+    """
+
+    # Here we use three synthetic structures:
+    # - O atom, with an energy of 1.0
+    # - H2O molecule, with an energy of 5.0
+    # - H4O2 molecule, with an energy of 10.0
+    # The expected composition weights are 2.0 for H and 1.0 for O.
+
+    systems = [
+        System(
+            positions=torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float64),
+            types=torch.tensor([8]),
+            cell=torch.eye(3, dtype=torch.float64),
+        ),
+        System(
+            positions=torch.tensor(
+                [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=torch.float64
+            ),
+            types=torch.tensor([1, 1, 8]),
+            cell=torch.eye(3, dtype=torch.float64),
+        ),
+        System(
+            positions=torch.tensor(
+                [
+                    [0.0, 0.0, 0.0],
+                    [1.0, 0.0, 0.0],
+                    [0.0, 1.0, 0.0],
+                    [0.0, 0.0, 1.0],
+                    [1.0, 0.0, 1.0],
+                    [0.0, 1.0, 1.0],
+                ],
+                dtype=torch.float64,
+            ),
+            types=torch.tensor([1, 1, 8, 1, 1, 8]),
+            cell=torch.eye(3, dtype=torch.float64),
+        ),
+    ]
+    energies = [1.0, 5.0, 10.0]
+    energies = [
+        TensorMap(
+            keys=Labels(names=["_"], values=torch.tensor([[0]])),
+            blocks=[
+                TensorBlock(
+                    values=torch.tensor([[e]], dtype=torch.float64),
+                    samples=Labels(names=["system"], values=torch.tensor([[i]])),
+                    components=[],
+                    properties=Labels(names=["energy"], values=torch.tensor([[0]])),
+                )
+            ],
+        )
+        for i, e in enumerate(energies)
+    ]
+    dataset = Dataset({"system": systems, "energy": energies})
+
+    # Case 1: the model only declares type 1, but the dataset also contains type 8.
+    composition_model = CompositionModel(
+        model_hypers={},
+        dataset_info=DatasetInfo(
+            length_unit="angstrom",
+            atomic_types=[1],
+            targets=TargetInfoDict(
+                {
+                    "energy": TargetInfo(
+                        quantity="energy",
+                        per_atom=False,
+                    )
+                }
+            ),
+        ),
+    )
+    with pytest.raises(
+        ValueError,
+        match="unknown atomic types",
+    ):
+        composition_model.train_model(dataset)
+
+    # Case 2: the model declares type 100, which never appears in the dataset.
+    composition_model = CompositionModel(
+        model_hypers={},
+        dataset_info=DatasetInfo(
+            length_unit="angstrom",
+            atomic_types=[1, 8, 100],
+            targets=TargetInfoDict(
+                {
+                    "energy": TargetInfo(
+                        quantity="energy",
+                        per_atom=False,
+                    )
+                }
+            ),
+        ),
+    )
+    with pytest.raises(
+        ValueError,
+        match="do not contain atomic types",
+    ):
+        composition_model.train_model(dataset)