
Create a torch.nn.Module for the composition weights #269

Closed
PicoCentauri opened this issue Jun 21, 2024 · 1 comment · Fixed by #280
Labels
Discussion Issues to be discussed by the contributors Infrastructure: Miscellaneous General infrastructure issues Priority: Medium Important issues to address after high priority.

Comments

@PicoCentauri
Contributor

Composition weights are used by almost every architecture to subtract the energy contribution arising from the chemical composition of the dataset, making it easier for the actual architecture to learn the targets. In metatrain, there is a utility function to compute these composition weights:

https://github.com/lab-cosmo/metatrain/blob/40d4d6a30d9add0fa4d1e9d2ce9e2635423fcc48/src/metatrain/utils/composition.py#L9
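As a rough illustration of what such a utility does (a hedged sketch, not the actual metatrain code): the composition weights are essentially a least-squares fit of per-structure energies against per-structure atomic counts.

```python
import torch

def composition_weights(counts: torch.Tensor, energies: torch.Tensor) -> torch.Tensor:
    """Sketch of a composition-weight fit.

    counts: (n_structures, n_atomic_types) matrix of atom counts per structure
    energies: (n_structures,) target energies
    """
    # find w minimizing ||counts @ w - energies||
    return torch.linalg.lstsq(counts, energies.reshape(-1, 1)).solution.squeeze(1)

# two toy structures: H2O (2 H, 1 O) and H2O2 (2 H, 2 O)
counts = torch.tensor([[2.0, 1.0], [2.0, 2.0]])
energies = torch.tensor([-76.0, -151.0])
weights = composition_weights(counts, energies)

# the residual energies are what the architecture actually has to learn
residual = energies - counts @ weights
```

After subtracting the composition energy, the residual is close to zero for this toy data, which is exactly why the targets become easier to learn.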

However, so far each architecture has to apply these composition weights on its own by looping over the Systems and creating the output TensorBlocks. See for example the GAP architecture:

https://github.com/lab-cosmo/metatrain/blob/40d4d6a30d9add0fa4d1e9d2ce9e2635423fcc48/src/metatrain/experimental/gap/model.py#L242
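To illustrate the boilerplate each architecture currently repeats, here is a hedged, self-contained sketch of the manual pattern (plain tensors of atomic numbers stand in for Systems; the names are illustrative, not the metatrain API):

```python
import torch

# one fitted weight per atomic type (H, O), e.g. from a least-squares fit
atomic_types = [1, 8]
weights = torch.tensor([-0.5, -75.0])

# stand-ins for Systems: each "system" is just its list of atomic numbers
systems = [torch.tensor([1, 1, 8]), torch.tensor([1, 1, 8, 8])]

# the per-architecture loop this issue wants to factor out
composition_energy = torch.zeros(len(systems))
for i, types in enumerate(systems):
    for j, atomic_type in enumerate(atomic_types):
        composition_energy[i] += weights[j] * (types == atomic_type).sum()
# each architecture then adds composition_energy back onto its raw predictions
```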

I think it would be useful to create a torch.nn.Module for the CompositionEnergy to make this essential part of an architecture easier and reusable for developers. Even though we could, I wouldn't expose this as a public architecture to users right now. The idea of a Module is also in line with the short-range Module discussed in #265, where we concluded to also create a torch.nn.Module.

My idea for the design below is basically copied from the metatensor atomistic tutorial.

from typing import Dict, List, Optional, Union

import torch

from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.atomistic import ModelOutput, System

# plus an import of metatrain's Dataset for the compute_weights() signature


class CompositionEnergy(torch.nn.Module):

    def __init__(self, atomic_types):
        super().__init__()
        self.atomic_types = atomic_types

    def compute_weights(
        self, datasets: Union[Dataset, List[Dataset]], property: str
    ) -> None:
        """Calculate the composition weights for a dataset."""
        # basically the code from
        # `metatrain.utils.composition.calculate_composition_weights()`
        ...

    def forward(
        self,
        systems: List[System],
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels] = None,
    ) -> Dict[str, TensorMap]:
        # if the model user did not request an energy calculation, we have nothing to do
        if "energy" not in outputs:
            return {}

        # we don't want to worry about selected_atoms yet
        if selected_atoms is not None:
            raise NotImplementedError("selected_atoms is not implemented")

        if outputs["energy"].per_atom:
            raise NotImplementedError("per atom energy is not implemented")

        # compute the energy for each system by adding together the energy for each atom
        energy = torch.zeros((len(systems), 1), dtype=systems[0].positions.dtype)
        for i, system in enumerate(systems):
            energy[i] += 0.0  # do actual calculations here

        # Add metadata to the output
        block = TensorBlock(
            values=energy,
            samples=Labels("system", torch.arange(len(systems)).reshape(-1, 1)),
            components=[],
            properties=Labels("energy", torch.tensor([[0]])),
        )
        return {
            "energy": TensorMap(keys=Labels("_", torch.tensor([[0]])), blocks=[block])
        }

Once we have this Module, an architecture only has to call the compute_weights function once to store the weights and can then use the forward function to apply the composition energies.
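For illustration, a minimal, self-contained stand-in of this workflow (plain tensors replace the metatrain Datasets and metatensor Systems, and everything besides the compute_weights/forward pattern is an assumption for the sake of the example):

```python
from typing import List

import torch

class CompositionEnergy(torch.nn.Module):
    """Toy stand-in for the proposed Module, operating on plain tensors."""

    def __init__(self, atomic_types: List[int]):
        super().__init__()
        self.atomic_types = atomic_types
        # store the weights on the Module so they move with .to()/state_dict()
        self.register_buffer("weights", torch.zeros(len(atomic_types)))

    def compute_weights(self, counts: torch.Tensor, energies: torch.Tensor) -> None:
        # least-squares fit, standing in for calculate_composition_weights()
        self.weights = torch.linalg.lstsq(
            counts, energies.reshape(-1, 1)
        ).solution.squeeze(1)

    def forward(self, systems: List[torch.Tensor]) -> torch.Tensor:
        # each "system" is just a tensor of atomic numbers in this sketch
        energy = torch.zeros(len(systems))
        for i, types in enumerate(systems):
            for j, atomic_type in enumerate(self.atomic_types):
                energy[i] += self.weights[j] * (types == atomic_type).sum()
        return energy

module = CompositionEnergy(atomic_types=[1, 8])
module.compute_weights(
    counts=torch.tensor([[2.0, 1.0], [2.0, 2.0]]),  # H2O, H2O2
    energies=torch.tensor([-76.0, -151.0]),
)
pred = module(systems=[torch.tensor([1, 1, 8])])  # composition energy of H2O
```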

@PicoCentauri PicoCentauri added Priority: Medium Important issues to address after high priority. Discussion Issues to be discussed by the contributors Infrastructure: Miscellaneous General infrastructure issues labels Jun 21, 2024
@frostedoyster
Collaborator

frostedoyster commented Jun 21, 2024

I like it. It would be cool if, for such a lightweight component of a model, we could avoid recomputing the labels at each iteration (but this is an implementation detail)
