Fixing mace with edge feats #61

Merged · 21 commits · Jun 26, 2024
Changes from 20 commits
Commits
264f5ea - mixing sigma with edge features and sigma to sigma+1 (sblackburn-mila, Jun 16, 2024)
589b13d - edge diffusion scalars do not match dimension with node diffusion scalar (sblackburn-mila, Jun 16, 2024)
882399c - gpu & tensor type issues (sblackburn-mila, Jun 16, 2024)
6437f27 - sequential and not module list + adding bn for regularization (sblackburn-mila, Jun 16, 2024)
0b3dd95 - mixing sigma with edge features and sigma to sigma+1 (sblackburn-mila, Jun 16, 2024)
fd6603e - edge diffusion scalars do not match dimension with node diffusion scalar (sblackburn-mila, Jun 16, 2024)
5397975 - gpu & tensor type issues (sblackburn-mila, Jun 16, 2024)
7b0906a - sequential and not module list + adding bn for regularization (sblackburn-mila, Jun 16, 2024)
e417c98 - edge_features in difface (sblackburn-mila, Jun 18, 2024)
d65a2fa - fixing device type conflict (sblackburn-mila, Jun 18, 2024)
61bcf4e - option to add tanh in difface & adding sigma to edge features (sblackburn-mila, Jun 25, 2024)
5ae43d6 - config files update (sblackburn-mila, Jun 25, 2024)
7e09e30 - fixing input_to_mace function (sblackburn-mila, Jun 25, 2024)
7fd4137 - augmented edge attributes should be in a if (sblackburn-mila, Jun 25, 2024)
5ab5925 - fixing broken test due to missing parameter (sblackburn-mila, Jun 25, 2024)
2e265f6 - juggling tensor size because mace and e3nn do not play well together (sblackburn-mila, Jun 25, 2024)
36012e4 - added docstring (sblackburn-mila, Jun 25, 2024)
f7db962 - flake8 and isort checks (sblackburn-mila, Jun 25, 2024)
1983ea0 - useless variable not removed (sblackburn-mila, Jun 25, 2024)
c04353c - unit tests for mace <-> e3nn reshapes (sblackburn-mila, Jun 25, 2024)
3712c57 - typo fixing from code review (sblackburn-mila, Jun 26, 2024)
62 changes: 54 additions & 8 deletions crystal_diffusion/models/diffusion_mace.py
@@ -2,13 +2,15 @@

import torch
from e3nn import o3
-from e3nn.nn import Activation, BatchNorm
+from e3nn.nn import Activation, BatchNorm, NormActivation
from mace.modules import (EquivariantProductBasisBlock, InteractionBlock,
LinearNodeEmbeddingBlock, RadialEmbeddingBlock)
from mace.modules.utils import get_edge_vectors_and_lengths
from torch_geometric.data import Data

-from crystal_diffusion.models.mace_utils import get_adj_matrix
+from crystal_diffusion.models.mace_utils import (get_adj_matrix,
+                                                 reshape_from_e3nn_to_mace,
+                                                 reshape_from_mace_to_e3nn)
from crystal_diffusion.namespace import (CARTESIAN_FORCES, NOISE,
NOISY_CARTESIAN_POSITIONS, UNIT_CELL)

@@ -41,9 +43,9 @@ def input_to_diffusion_mace(batch: Dict[AnyStr, torch.Tensor], radial_cutoff: fl

basis_vectors = batch[UNIT_CELL] # batch, spatial_dimension, spatial_dimension

-adj_matrix, shift_matrix, batch_tensor = get_adj_matrix(positions=cartesian_positions,
-                                                        basis_vectors=basis_vectors,
-                                                        radial_cutoff=radial_cutoff)
+adj_matrix, shift_matrix, batch_tensor, num_edges = get_adj_matrix(positions=cartesian_positions,
+                                                                   basis_vectors=basis_vectors,
+                                                                   radial_cutoff=radial_cutoff)

# node features are int corresponding to atom type
# TODO handle different atom types
@@ -53,8 +55,9 @@ def input_to_diffusion_mace(batch: Dict[AnyStr, torch.Tensor], radial_cutoff: fl
# We broadcast to each node to avoid complex broadcasting logic within the model itself.
# TODO: it might be better to define the noise as a 'global' graph attribute, and find 'the right way' of
# mixing it with bona fide node features within the model.
-noises = batch[NOISE]  # [batch_size, 1]
+noises = batch[NOISE] + 1  # [batch_size, 1] - add 1 to avoid getting a zero at sigma=0 (initialization issues)
 node_diffusion_scalars = noises.repeat_interleave(n_atom_per_graph, dim=0)  # [flat_batch_size, 1]
+edge_diffusion_scalars = noises.repeat_interleave(num_edges.long().to(device=noises.device), dim=0)

# [batchsize * natoms, spatial dimension]
flat_cartesian_positions = cartesian_positions.view(-1, spatial_dimension)
@@ -70,6 +73,7 @@ def input_to_diffusion_mace(batch: Dict[AnyStr, torch.Tensor], radial_cutoff: fl
graph_data = Data(edge_index=adj_matrix,
node_attrs=node_attrs.to(device),
node_diffusion_scalars=node_diffusion_scalars.to(device),
edge_diffusion_scalars=edge_diffusion_scalars.to(device),
positions=flat_cartesian_positions,
ptr=ptr.to(device),
batch=batch_tensor.to(device),
@@ -92,6 +96,8 @@ def __init__(
r_max: float,
num_bessel: int,
num_polynomial_cutoff: int,
num_edge_hidden_layers: int,
edge_hidden_irreps: o3.Irreps,
max_ell: int,
interaction_cls: Type[InteractionBlock],
interaction_cls_first: Type[InteractionBlock],
@@ -108,6 +114,7 @@ def __init__(
radial_type: Optional[str] = "bessel",
condition_embedding_size: int = 64, # dimension of the conditional variable embedding - assumed to be l=1 (odd)
use_batchnorm: bool = False,
tanh_after_interaction: bool = True,
):
"""Init method."""
assert num_elements == 1, "only a single element can be used at this time. Set 'num_elements' to 1."
@@ -183,6 +190,24 @@ def __init__(
)
edge_feats_irreps = o3.Irreps([(self.radial_embedding.out_dim, scalar_irrep)])

if num_edge_hidden_layers > 0:
self.edge_attribute_mixing = o3.FullyConnectedTensorProduct(irreps_in1=diffusion_scalar_irreps_out,
irreps_in2=edge_feats_irreps,
irreps_out=edge_hidden_irreps,
irrep_normalization='norm')
self.edge_hidden_layers = torch.nn.Sequential()
edge_non_linearity = Activation(irreps_in=edge_hidden_irreps, acts=[gate])
for i in range(num_edge_hidden_layers):
if i != 0:
self.edge_hidden_layers.append(edge_non_linearity)
edge_hidden_layer = o3.Linear(irreps_in=edge_hidden_irreps,
irreps_out=edge_hidden_irreps,
biases=False)
self.edge_hidden_layers.append(edge_hidden_layer)
else:
self.edge_attribute_mixing, self.edge_hidden_layers = None, None
edge_hidden_irreps = edge_feats_irreps

# The "spherical harmonics" correspond to Y_{lm} in the definition of A^{(1)}, eq. 9 of the PAPER.
sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
interaction_irreps = (sh_irreps * number_of_hidden_scalar_dimensions).sort()[0].simplify()
@@ -200,14 +225,19 @@ def __init__(
node_attrs_irreps=node_attr_irreps,
node_feats_irreps=node_feats_irreps,
edge_attrs_irreps=sh_irreps,
-edge_feats_irreps=edge_feats_irreps,
+edge_feats_irreps=edge_hidden_irreps,
target_irreps=interaction_irreps,
hidden_irreps=hidden_irreps,
avg_num_neighbors=avg_num_neighbors,
radial_MLP=radial_MLP,
)
self.interactions = torch.nn.ModuleList([inter])

if tanh_after_interaction:
self.interactions_tanh = torch.nn.ModuleList([NormActivation(inter.target_irreps, torch.tanh)])
else:
self.interactions_tanh = None

# 'sc' means 'self-connection', namely a 'residual-like' connection, h^{t+1} = m^{t} + (sc) x h^{(t)}
# Use the appropriate self connection at the first layer for proper E0
use_sc_first = False
@@ -244,14 +274,17 @@ def __init__(
node_attrs_irreps=node_attr_irreps,
node_feats_irreps=hidden_irreps,
edge_attrs_irreps=sh_irreps,
-edge_feats_irreps=edge_feats_irreps,
+edge_feats_irreps=edge_hidden_irreps,
target_irreps=interaction_irreps,
hidden_irreps=hidden_irreps_out,
avg_num_neighbors=avg_num_neighbors,
radial_MLP=radial_MLP,
)
self.interactions.append(inter)

if self.interactions_tanh is not None:
self.interactions_tanh.append(NormActivation(interaction_irreps, torch.tanh))

# prod compute h^{(t+1)} from A^{(t)} and h^{(t)}, computing B and the messages internally.
prod = EquivariantProductBasisBlock(
node_feats_irreps=interaction_irreps,
@@ -304,6 +337,10 @@ def forward(self, data: Dict[str, torch.Tensor], conditional: bool = False) -> t
)
edge_attrs = self.spherical_harmonics(vectors)
edge_feats = self.radial_embedding(lengths)
if self.edge_attribute_mixing is not None:
edge_diffusion_scalar_embeddings = self.diffusion_scalar_embedding(data["edge_diffusion_scalars"])
augmented_edge_attributes = self.edge_attribute_mixing(edge_diffusion_scalar_embeddings, edge_feats)
edge_feats = self.edge_hidden_layers(augmented_edge_attributes)

forces_embedding = self.condition_embedding_layer(data["forces"]) # 0e + 1o embedding

@@ -320,6 +357,15 @@ def forward(self, data: Dict[str, torch.Tensor], conditional: bool = False) -> t
edge_feats=edge_feats,
edge_index=data["edge_index"],
)

if self.interactions_tanh is not None:
# reshape from (node, channels, (l_max + 1)**2) to a (node, -1) tensor compatible with e3nn
node_feats = reshape_from_mace_to_e3nn(node_feats, self.interactions_tanh[i].irreps_in)
# apply non-linearity
node_feats = self.interactions_tanh[i](node_feats)
# reshape from e3nn shape to mace format (node, channels, (l_max+1)**2)
node_feats = reshape_from_e3nn_to_mace(node_feats, self.interactions_tanh[i].irreps_out)

node_feats = product(
node_feats=node_feats,
sc=sc,
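To make the new edge-feature path easier to follow outside the full model, here is a self-contained sketch of what the added code does: per-graph sigma is broadcast to one scalar per edge, embedded, mixed with the radial edge features through a fully connected tensor product, then passed through a small stack of equivariant linear layers. The irreps sizes, the silu non-linearity, and the standalone embedding layer are illustrative assumptions; the actual model takes these from its hyperparameters.

```python
import torch
from e3nn import o3
from e3nn.nn import Activation

# Per-graph sigma broadcast to one scalar per edge (as in input_to_diffusion_mace);
# the PR shifts sigma by 1 so that sigma = 0 does not zero out the embedding.
noises = torch.tensor([[0.1], [0.5]]) + 1.0                        # [batch_size, 1]
num_edges_per_graph = torch.tensor([4, 6])
edge_sigma = noises.repeat_interleave(num_edges_per_graph, dim=0)  # [10, 1]

# Illustrative irreps; the real sizes come from the model hyperparameters.
sigma_irreps = o3.Irreps("1x0e")
sigma_embedding_irreps = o3.Irreps("8x0e")
edge_feats_irreps = o3.Irreps("8x0e")
edge_hidden_irreps = o3.Irreps("16x0e")

# Stand-in for the model's diffusion scalar embedding.
sigma_embedding_layer = o3.Linear(sigma_irreps, sigma_embedding_irreps)

# Mix the embedded sigma into the radial edge features, as edge_attribute_mixing does.
edge_attribute_mixing = o3.FullyConnectedTensorProduct(
    irreps_in1=sigma_embedding_irreps,
    irreps_in2=edge_feats_irreps,
    irreps_out=edge_hidden_irreps,
    irrep_normalization='norm')

# Hidden layers with a scalar non-linearity between them, mirroring edge_hidden_layers.
num_edge_hidden_layers = 2
edge_hidden_layers = torch.nn.Sequential()
non_linearity = Activation(irreps_in=edge_hidden_irreps, acts=[torch.nn.functional.silu])
for i in range(num_edge_hidden_layers):
    if i != 0:
        edge_hidden_layers.append(non_linearity)
    edge_hidden_layers.append(o3.Linear(edge_hidden_irreps, edge_hidden_irreps, biases=False))

edge_feats = torch.randn(edge_sigma.shape[0], edge_feats_irreps.dim)
mixed = edge_attribute_mixing(sigma_embedding_layer(edge_sigma), edge_feats)
edge_feats = edge_hidden_layers(mixed)
print(edge_feats.shape)  # torch.Size([10, 16])
```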
62 changes: 57 additions & 5 deletions crystal_diffusion/models/mace_utils.py
@@ -14,7 +14,7 @@

def get_adj_matrix(positions: torch.Tensor,
basis_vectors: torch.Tensor,
-radial_cutoff: float = 4.0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+radial_cutoff: float = 4.0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Create the adjacency and shift matrices.

Args:
@@ -32,6 +32,7 @@ def get_adj_matrix(positions: torch.Tensor,
adjacency matrix: The (src, dst) node indices, as a [2, num_edge] tensor,
shift matrix: The lattice vector shifts between source and destination, as a [num_edge, 3] tensor
batch_indices: for each node, this indicates which batch item it originally belonged to.
number_of_edges: for each element in the batch, how many edges belong to it
"""
batch_size, number_of_atoms, spatial_dimensions = positions.shape

@@ -47,7 +48,9 @@ def get_adj_matrix(positions: torch.Tensor,
shifts = adjacency_info.shifts
batch_indices = adjacency_info.node_batch_indices

-return shifted_adjacency_matrix, shifts, batch_indices
+number_of_edges = adjacency_info.number_of_edges
+
+return shifted_adjacency_matrix, shifts, batch_indices, number_of_edges


def input_to_mace(x: Dict[AnyStr, torch.Tensor], radial_cutoff: float) -> Data:
@@ -65,9 +68,9 @@ def input_to_mace(x: Dict[AnyStr, torch.Tensor], radial_cutoff: float) -> Data:

batch_size, n_atom_per_graph, spatial_dimension = noisy_cartesian_positions.shape
device = noisy_cartesian_positions.device
-adj_matrix, shift_matrix, batch_tensor = get_adj_matrix(positions=noisy_cartesian_positions,
-                                                        basis_vectors=cell,
-                                                        radial_cutoff=radial_cutoff)
+adj_matrix, shift_matrix, batch_tensor, _ = get_adj_matrix(positions=noisy_cartesian_positions,
+                                                           basis_vectors=cell,
+                                                           radial_cutoff=radial_cutoff)
# node features are int corresponding to atom type
# TODO handle different atom types
node_attrs = torch.nn.functional.one_hot((torch.ones(batch_size * n_atom_per_graph) * 14).long(),
@@ -208,3 +211,52 @@ def get_normalized_irreps_permutation_indices(irreps: o3.Irreps) -> Tuple[o3.Irr
sorted_irreps = sorted_output.irreps.simplify()

return sorted_irreps, column_permutation_indices


def reshape_from_mace_to_e3nn(x: torch.Tensor, irreps: o3.Irreps) -> torch.Tensor:
"""Reshape a MACE input/output tensor to a e3nn.NormActivation compatible format.

MACE uses tensors in the 2D format (ignoring the nodes / batchsize):
---- l = 0 ----
---- l = 1 ----
---- l = 1 ----
---- l = 1 ----
...
And e3nn wants a tensor in the 1D format:
---- l = 0 ---- ---- l = 1 ---- ---- l = 2 ---- ...

Args:
    x: tensor used by MACE. Should be of size (number of nodes, number of channels, (ell_max + 1)^2)
Review comment from @rousseab (Collaborator), Jun 26, 2024:
"Typo: torch -> tensor
Also, the size tuple is missing a closing parenthesis."
irreps: o3 irreps matching the x tensor

Returns:
tensor of size (number of nodes, number of channels * (ell_max + 1)^2) usable by e3nn
"""
node = x.size(0)
x_ = []
for ell in range(irreps.lmax + 1):
# for example, for l=1, take indices 1, 2, 3 (in the last index) and flatten as a channel * 3 tensor
x_l = x[:, :, (ell ** 2):(ell + 1)**2].reshape(node, -1) # node, channel * (2l + 1)
x_.append(x_l)
# concatenate the flattened irrep tensors together
return torch.cat(x_, dim=-1)


def reshape_from_e3nn_to_mace(x: torch.Tensor, irreps: o3.Irreps) -> torch.Tensor:
"""Reshape a tensor in the e3nn.NormActivation format to a MACE format.

See reshape_from_mace_to_e3nn for an explanation of the formats.

Args:
    x: tensor in the flattened e3nn format. Should be of size (number of nodes, number of channels * (ell_max + 1)^2)
    irreps: o3 irreps matching the x tensor

Returns:
    tensor of size (number of nodes, number of channels, (ell_max + 1)^2) usable by MACE
"""
node = x.size(0)
x_ = []
for ell, s in enumerate(irreps.slices()):
x_l = x[:, s].reshape(node, -1, 2 * ell + 1)
x_.append(x_l)
return torch.cat(x_, dim=-1)
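The easiest way to see what these two helpers do is a tiny worked example. The sketch below (shapes chosen purely for illustration) shows how reshape_from_mace_to_e3nn reorders a (node, channel, m) MACE tensor into e3nn's flat per-irrep layout, with all l = 0 components first and the l = 1 vectors after:

```python
import torch
from e3nn import o3

from crystal_diffusion.models.mace_utils import reshape_from_mace_to_e3nn

# 1 node, 2 channels, lmax = 1: MACE stores (node, channel, m) with the
# (l, m) index flattened along the last axis: [l=0, l=1 m=-1, l=1 m=0, l=1 m=+1].
irreps = o3.Irreps("2x0e + 2x1o")
x_mace = torch.arange(8.).reshape(1, 2, 4)
# channel 0 -> [0, 1, 2, 3], channel 1 -> [4, 5, 6, 7]

x_e3nn = reshape_from_mace_to_e3nn(x_mace, irreps)
print(x_e3nn)
# tensor([[0., 4., 1., 2., 3., 5., 6., 7.]])
# all l = 0 components first (one per channel), then the two l = 1 vectors.
```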
@@ -25,6 +25,8 @@ class DiffusionMACEScoreNetworkParameters(ScoreNetworkParameters):
r_max: float = 5.0
num_bessel: int = 8
num_polynomial_cutoff: int = 5
num_edge_hidden_layers: int = 0 # number of layers mixing sigma into the edge features; 0 disables the mixing
edge_hidden_irreps: str = "16x0e"
max_ell: int = 2
interaction_cls: str = "RealAgnosticResidualInteractionBlock"
interaction_cls_first: str = "RealAgnosticInteractionBlock"
@@ -39,6 +41,7 @@
radial_type: str = "bessel" # type of radial basis functions - choices=["bessel", "gaussian", "chebyshev"]
condition_embedding_size: int = 64 # dimension of the conditional variable embedding - assumed to be l=1 (odd)
use_batchnorm: bool = False
tanh_after_interaction: bool = True # use a tanh non-linearity (based on irreps norm) in the message-passing


class DiffusionMACEScoreNetwork(ScoreNetwork):
@@ -60,6 +63,8 @@ def __init__(self, hyper_params: DiffusionMACEScoreNetworkParameters):
r_max=hyper_params.r_max,
num_bessel=hyper_params.num_bessel,
num_polynomial_cutoff=hyper_params.num_polynomial_cutoff,
num_edge_hidden_layers=hyper_params.num_edge_hidden_layers,
edge_hidden_irreps=o3.Irreps(hyper_params.edge_hidden_irreps),
max_ell=hyper_params.max_ell,
interaction_cls=interaction_classes[hyper_params.interaction_cls],
interaction_cls_first=interaction_classes[hyper_params.interaction_cls_first],
@@ -75,7 +80,8 @@ def __init__(self, hyper_params: DiffusionMACEScoreNetworkParameters):
radial_MLP=hyper_params.radial_MLP,
radial_type=hyper_params.radial_type,
condition_embedding_size=hyper_params.condition_embedding_size,
-use_batchnorm=hyper_params.use_batchnorm
+use_batchnorm=hyper_params.use_batchnorm,
+tanh_after_interaction=hyper_params.tanh_after_interaction
)

self._natoms = hyper_params.number_of_atoms
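Note that edge_hidden_irreps travels as a plain string in the dataclass so it can round-trip through YAML, and is only promoted to an o3.Irreps object when the network is built, as the constructor call above shows. A minimal illustration:

```python
from e3nn import o3

# The dataclass stores the irreps as a string (e.g. the "16x0e" default).
edge_hidden_irreps_str = "16x0e"

# Promoted to a real irreps object just before DiffusionMACE is instantiated.
edge_hidden_irreps = o3.Irreps(edge_hidden_irreps_str)
print(edge_hidden_irreps.dim)   # 16
print(edge_hidden_irreps.lmax)  # 0
```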
6 changes: 5 additions & 1 deletion examples/config_files/diffusion/config_diffusion_mace.yaml
@@ -27,6 +27,8 @@ model:
r_max: 5.0
num_bessel: 8
num_polynomial_cutoff: 5
num_edge_hidden_layers: 0
edge_hidden_irreps: 8x0e
max_ell: 2
interaction_cls: RealAgnosticResidualInteractionBlock
interaction_cls_first: RealAgnosticInteractionBlock
@@ -38,10 +40,12 @@ model:
correlation: 3
gate: silu
radial_MLP: [8, 8, 8]
-radial_type: bessel
+radial_type: gaussian
conditional_prob: 0.0
conditional_gamma: 2
condition_embedding_size: 64
use_batchnorm: False
tanh_after_interaction: True
noise:
total_time_steps: 100
sigma_min: 0.001 # default value
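A quick way to confirm that a config file picks up the new keys is to load it and inspect the model block. A minimal sketch, assuming PyYAML is available and that the keys sit directly under model: as the diff suggests:

```python
import yaml

with open("examples/config_files/diffusion/config_diffusion_mace.yaml") as f:
    config = yaml.safe_load(f)

model_config = config["model"]
print(model_config["num_edge_hidden_layers"])  # 0: sigma is not mixed into edge features
print(model_config["edge_hidden_irreps"])      # '8x0e'
print(model_config["tanh_after_interaction"])  # True
```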
@@ -25,8 +25,10 @@ model:
architecture: diffusion_mace
number_of_atoms: 8
r_max: 5.0
-num_bessel: 8
+num_bessel: 'orion~choices([128, 256, 512])'
num_polynomial_cutoff: 5
num_edge_hidden_layers: 0
edge_hidden_irreps: 8x0e
max_ell: 2
interaction_cls: RealAgnosticResidualInteractionBlock
interaction_cls_first: RealAgnosticInteractionBlock
@@ -38,10 +40,12 @@ model:
correlation: 3
gate: silu
radial_MLP: 'orion~choices([[8, 8, 8], [32, 32, 32], [64, 64]])'
-radial_type: bessel
+radial_type: 'orion~choices(["bessel", "gaussian"])'
conditional_prob: 'orion~choices([0.0, 0.25, 0.5, 0.75])'
conditional_gamma: 2
condition_embedding_size: 'orion~choices([32, 64])'
use_batchnorm: False
tanh_after_interaction: True
noise:
total_time_steps: 100
sigma_min: 0.001 # default value
5 changes: 4 additions & 1 deletion tests/models/test_diffusion_mace.py
@@ -134,6 +134,8 @@ def hyperparameters(self, r_max):
hps = dict(r_max=r_max,
num_bessel=8,
num_polynomial_cutoff=5,
num_edge_hidden_layers=0,
edge_hidden_irreps=o3.Irreps("8x0e"),
max_ell=2,
num_elements=1,
atomic_numbers=[14],
@@ -147,7 +149,8 @@ def hyperparameters(self, r_max):
correlation=2,
gate=gate_dict["silu"],
radial_MLP=[8, 8, 8],
radial_type="bessel")
radial_type="bessel",
)
return hps

@pytest.fixture()
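The new unit tests for the reshapes (commit c04353c) are not expanded in this view. A representative round-trip test might look like the following sketch; the test name and parametrization are illustrative, not the repository's actual tests:

```python
import pytest
import torch
from e3nn import o3

from crystal_diffusion.models.mace_utils import (reshape_from_e3nn_to_mace,
                                                 reshape_from_mace_to_e3nn)


@pytest.mark.parametrize("lmax", [0, 1, 2])
@pytest.mark.parametrize("channels", [1, 8])
def test_mace_e3nn_reshape_round_trip(lmax, channels):
    # One irrep entry per l with `channels` copies each, matching how the model
    # builds its interaction irreps from spherical harmonics.
    irreps = (o3.Irreps.spherical_harmonics(lmax) * channels).sort()[0].simplify()
    x_mace = torch.randn(3, channels, (lmax + 1) ** 2)

    x_e3nn = reshape_from_mace_to_e3nn(x_mace, irreps)
    assert x_e3nn.shape == (3, channels * (lmax + 1) ** 2)

    x_round_trip = reshape_from_e3nn_to_mace(x_e3nn, irreps)
    assert torch.equal(x_round_trip, x_mace)
```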