From 264f5ea2b65c2282552fbca8523a62bce85a2df6 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Sun, 16 Jun 2024 14:47:06 -0400 Subject: [PATCH 01/20] mixing sigma with edge features and sigma to sigma+1 --- crystal_diffusion/models/diffusion_mace.py | 24 ++++++++++++++++--- .../diffusion_mace_score_network.py | 4 ++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/crystal_diffusion/models/diffusion_mace.py b/crystal_diffusion/models/diffusion_mace.py index a7a691fe..e233b697 100644 --- a/crystal_diffusion/models/diffusion_mace.py +++ b/crystal_diffusion/models/diffusion_mace.py @@ -53,7 +53,7 @@ def input_to_diffusion_mace(batch: Dict[AnyStr, torch.Tensor], radial_cutoff: fl # We broadcast to each node to avoid complex broadcasting logic within the model itself. # TODO: it might be better to define the noise as a 'global' graph attribute, and find 'the right way' of # mixing it with bona fide node features within the model. - noises = batch[NOISE] # [batch_size, 1] + noises = batch[NOISE] + 1 # [batch_size, 1] - add 1 to avoid getting a zero at sigma=0 (initialization issues) node_diffusion_scalars = noises.repeat_interleave(n_atom_per_graph, dim=0) # [flat_batch_size, 1] # [batchsize * natoms, spatial dimension] @@ -92,6 +92,8 @@ def __init__( r_max: float, num_bessel: int, num_polynomial_cutoff: int, + num_edge_hidden_layers: int, + edge_hidden_irreps: o3.Irreps, max_ell: int, interaction_cls: Type[InteractionBlock], interaction_cls_first: Type[InteractionBlock], @@ -182,6 +184,20 @@ def __init__( ) edge_feats_irreps = o3.Irreps([(self.radial_embedding.out_dim, scalar_irrep)]) + self.edge_attribute_mixing = o3.FullyConnectedTensorProduct(irreps_in1=diffusion_scalar_irreps_out, + irreps_in2=edge_feats_irreps, + irreps_out=edge_hidden_irreps, + irrep_normalization='norm') + self.edge_hidden_layers = torch.nn.ModuleList([]) + edge_non_linearity = Activation(irreps_in=edge_hidden_irreps, acts=[gate]) + for i in range(num_edge_hidden_layers): + if i != 0: + self.edge_hidden_layers.append(edge_non_linearity) + edge_hidden_layer = o3.Linear(irreps_in=edge_hidden_irreps, + irreps_out=edge_hidden_irreps, + biases=False) + self.edge_hidden_layers.append(edge_hidden_layer) + # The "spherical harmonics" correspond to Y_{lm} in the definition of A^{(1)}, eq. 9 of the PAPER. sh_irreps = o3.Irreps.spherical_harmonics(max_ell) interaction_irreps = (sh_irreps * number_of_hidden_scalar_dimensions).sort()[0].simplify() @@ -199,7 +215,7 @@ def __init__( node_attrs_irreps=node_attr_irreps, node_feats_irreps=node_feats_irreps, edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, + edge_feats_irreps=edge_hidden_irreps, target_irreps=interaction_irreps, hidden_irreps=hidden_irreps, avg_num_neighbors=avg_num_neighbors, @@ -236,7 +252,7 @@ def __init__( node_attrs_irreps=node_attr_irreps, node_feats_irreps=hidden_irreps, edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, + edge_feats_irreps=edge_hidden_irreps, target_irreps=interaction_irreps, hidden_irreps=hidden_irreps_out, avg_num_neighbors=avg_num_neighbors, @@ -292,6 +308,8 @@ def forward(self, data: Dict[str, torch.Tensor], conditional: bool = False) -> t ) edge_attrs = self.spherical_harmonics(vectors) edge_feats = self.radial_embedding(lengths) + augmented_edge_attributes = self.edge_attribute_mixing(diffusion_scalar_embeddings, edge_feats) + edge_feats = self.edge_hidden_layers(augmented_edge_attributes) forces_embedding = self.condition_embedding_layer(data["forces"]) # 0e + 1o embedding diff --git a/crystal_diffusion/models/score_networks/diffusion_mace_score_network.py b/crystal_diffusion/models/score_networks/diffusion_mace_score_network.py index c8fc2330..570971ad 100644 --- a/crystal_diffusion/models/score_networks/diffusion_mace_score_network.py +++ b/crystal_diffusion/models/score_networks/diffusion_mace_score_network.py @@ -25,6 +25,8 @@ class DiffusionMACEScoreNetworkParameters(ScoreNetworkParameters): r_max: float = 5.0 num_bessel: int = 8 num_polynomial_cutoff: int = 5 + num_edge_hidden_layers: int = 1 + edge_hidden_irreps: str = "16x0e" max_ell: int = 2 interaction_cls: str = "RealAgnosticResidualInteractionBlock" interaction_cls_first: str = "RealAgnosticInteractionBlock" @@ -59,6 +61,8 @@ def __init__(self, hyper_params: DiffusionMACEScoreNetworkParameters): r_max=hyper_params.r_max, num_bessel=hyper_params.num_bessel, num_polynomial_cutoff=hyper_params.num_polynomial_cutoff, + num_edge_hidden_layers=hyper_params.num_edge_hidden_layers, + edge_hidden_irreps=o3.Irreps(hyper_params.edge_hidden_irreps), max_ell=hyper_params.max_ell, interaction_cls=interaction_classes[hyper_params.interaction_cls], interaction_cls_first=interaction_classes[hyper_params.interaction_cls_first], From 589b13d9504ca4f3fd47de512f3017a733ca13d2 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Sun, 16 Jun 2024 15:20:07 -0400 Subject: [PATCH 02/20] edge diffusion scalars do not match dimension with node diffusion scalar --- crystal_diffusion/models/diffusion_mace.py | 11 +++++++---- crystal_diffusion/models/mace_utils.py | 7 +++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/crystal_diffusion/models/diffusion_mace.py b/crystal_diffusion/models/diffusion_mace.py index e233b697..d1cf4d34 100644 --- a/crystal_diffusion/models/diffusion_mace.py +++ b/crystal_diffusion/models/diffusion_mace.py @@ -41,9 +41,9 @@ def input_to_diffusion_mace(batch: Dict[AnyStr, torch.Tensor], radial_cutoff: fl basis_vectors = batch[UNIT_CELL] # batch, spatial_dimension, spatial_dimension - adj_matrix, shift_matrix, batch_tensor = get_adj_matrix(positions=cartesian_positions, - basis_vectors=basis_vectors, - radial_cutoff=radial_cutoff) + adj_matrix, shift_matrix, batch_tensor, num_edges = get_adj_matrix(positions=cartesian_positions, + basis_vectors=basis_vectors, + radial_cutoff=radial_cutoff) # node features are int corresponding to atom type # TODO handle different atom types @@ -55,6 +55,7 @@ def input_to_diffusion_mace(batch: Dict[AnyStr, torch.Tensor], radial_cutoff: fl # mixing it with bona fide node features within the model. noises = batch[NOISE] + 1 # [batch_size, 1] - add 1 to avoid getting a zero at sigma=0 (initialization issues) node_diffusion_scalars = noises.repeat_interleave(n_atom_per_graph, dim=0) # [flat_batch_size, 1] + edge_diffusion_scalars = noises.repeat_interleave(num_edges, dim=0) # [batchsize * natoms, spatial dimension] flat_cartesian_positions = cartesian_positions.view(-1, spatial_dimension) @@ -70,6 +71,7 @@ def input_to_diffusion_mace(batch: Dict[AnyStr, torch.Tensor], radial_cutoff: fl graph_data = Data(edge_index=adj_matrix, node_attrs=node_attrs.to(device), node_diffusion_scalars=node_diffusion_scalars.to(device), + edge_diffusion_scalars=edge_diffusion_scalars.to(device), positions=flat_cartesian_positions, ptr=ptr.to(device), batch=batch_tensor.to(device), @@ -308,7 +310,8 @@ def forward(self, data: Dict[str, torch.Tensor], conditional: bool = False) -> t ) edge_attrs = self.spherical_harmonics(vectors) edge_feats = self.radial_embedding(lengths) - augmented_edge_attributes = self.edge_attribute_mixing(diffusion_scalar_embeddings, edge_feats) + edge_diffusion_scalar_embeddings = self.diffusion_scalar_embedding(data["edge_diffusion_scalars"]) + augmented_edge_attributes = self.edge_attribute_mixing(edge_diffusion_scalar_embeddings, edge_feats) edge_feats = self.edge_hidden_layers(augmented_edge_attributes) forces_embedding = self.condition_embedding_layer(data["forces"]) # 0e + 1o embedding diff --git a/crystal_diffusion/models/mace_utils.py b/crystal_diffusion/models/mace_utils.py index 8c3b9eb7..fc69f9b3 100644 --- a/crystal_diffusion/models/mace_utils.py +++ b/crystal_diffusion/models/mace_utils.py @@ -14,7 +14,7 @@ def get_adj_matrix(positions: torch.Tensor, basis_vectors: torch.Tensor, - radial_cutoff: float = 4.0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + radial_cutoff: float = 4.0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Create the adjacency and shift matrices. Args: @@ -32,6 +32,7 @@ def get_adj_matrix(positions: torch.Tensor, adjacency matrix: The (src, dst) node indices, as a [2, num_edge] tensor, shift matrix: The lattice vector shifts between source and destination, as a [num_edge, 3] tensor batch_indices: for each node, this indicates which batch item it originally belonged to. + number_of_edges: for each element in the batch, how many edges belong to it """ batch_size, number_of_atoms, spatial_dimensions = positions.shape @@ -47,7 +48,9 @@ def get_adj_matrix(positions: torch.Tensor, shifts = adjacency_info.shifts batch_indices = adjacency_info.node_batch_indices - return shifted_adjacency_matrix, shifts, batch_indices + number_of_edges = adjacency_info.number_of_edges + + return shifted_adjacency_matrix, shifts, batch_indices, number_of_edges def input_to_mace(x: Dict[AnyStr, torch.Tensor], radial_cutoff: float) -> Data: From 882399cec3863e8c22d3466ea0deb6b9a4558b92 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Sun, 16 Jun 2024 15:32:47 -0400 Subject: [PATCH 03/20] gpu & tensor type issues --- crystal_diffusion/models/diffusion_mace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crystal_diffusion/models/diffusion_mace.py b/crystal_diffusion/models/diffusion_mace.py index d1cf4d34..deae94af 100644 --- a/crystal_diffusion/models/diffusion_mace.py +++ b/crystal_diffusion/models/diffusion_mace.py @@ -55,7 +55,7 @@ def input_to_diffusion_mace(batch: Dict[AnyStr, torch.Tensor], radial_cutoff: fl # mixing it with bona fide node features within the model. noises = batch[NOISE] + 1 # [batch_size, 1] - add 1 to avoid getting a zero at sigma=0 (initialization issues) node_diffusion_scalars = noises.repeat_interleave(n_atom_per_graph, dim=0) # [flat_batch_size, 1] - edge_diffusion_scalars = noises.repeat_interleave(num_edges, dim=0) + edge_diffusion_scalars = noises.repeat_interleave(num_edges.long().to(noises), dim=0) # [batchsize * natoms, spatial dimension] flat_cartesian_positions = cartesian_positions.view(-1, spatial_dimension) From 6437f27db0f0a4e1b2787e1d9799965aba0392ea Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Sun, 16 Jun 2024 15:45:13 -0400 Subject: [PATCH 04/20] sequential and not module list + adding bn for regularization --- crystal_diffusion/models/diffusion_mace.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/crystal_diffusion/models/diffusion_mace.py b/crystal_diffusion/models/diffusion_mace.py index deae94af..e073454c 100644 --- a/crystal_diffusion/models/diffusion_mace.py +++ b/crystal_diffusion/models/diffusion_mace.py @@ -2,7 +2,7 @@ import torch from e3nn import o3 -from e3nn.nn import Activation +from e3nn.nn import Activation, BatchNorm from mace.modules import (EquivariantProductBasisBlock, InteractionBlock, LinearNodeEmbeddingBlock, RadialEmbeddingBlock) from mace.modules.utils import get_edge_vectors_and_lengths @@ -190,7 +190,7 @@ def __init__( irreps_in2=edge_feats_irreps, irreps_out=edge_hidden_irreps, irrep_normalization='norm') - self.edge_hidden_layers = torch.nn.ModuleList([]) + self.edge_hidden_layers = torch.nn.Sequential() edge_non_linearity = Activation(irreps_in=edge_hidden_irreps, acts=[gate]) for i in range(num_edge_hidden_layers): if i != 0: @@ -199,6 +199,8 @@ def __init__( irreps_out=edge_hidden_irreps, biases=False) self.edge_hidden_layers.append(edge_hidden_layer) + bn = BatchNorm(edge_hidden_irreps) + self.edge_hidden_layers.append(bn) # The "spherical harmonics" correspond to Y_{lm} in the definition of A^{(1)}, eq. 9 of the PAPER. sh_irreps = o3.Irreps.spherical_harmonics(max_ell) From 0b3dd95f54ec384b0226c62e539fa3e210ee6383 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Sun, 16 Jun 2024 14:47:06 -0400 Subject: [PATCH 05/20] mixing sigma with edge features and sigma to sigma+1 --- crystal_diffusion/models/diffusion_mace.py | 24 ++++++++++++++++--- .../diffusion_mace_score_network.py | 4 ++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/crystal_diffusion/models/diffusion_mace.py b/crystal_diffusion/models/diffusion_mace.py index 8f9ce1c5..87efedc2 100644 --- a/crystal_diffusion/models/diffusion_mace.py +++ b/crystal_diffusion/models/diffusion_mace.py @@ -53,7 +53,7 @@ def input_to_diffusion_mace(batch: Dict[AnyStr, torch.Tensor], radial_cutoff: fl # We broadcast to each node to avoid complex broadcasting logic within the model itself. # TODO: it might be better to define the noise as a 'global' graph attribute, and find 'the right way' of # mixing it with bona fide node features within the model. - noises = batch[NOISE] # [batch_size, 1] + noises = batch[NOISE] + 1 # [batch_size, 1] - add 1 to avoid getting a zero at sigma=0 (initialization issues) node_diffusion_scalars = noises.repeat_interleave(n_atom_per_graph, dim=0) # [flat_batch_size, 1] # [batchsize * natoms, spatial dimension] @@ -92,6 +92,8 @@ def __init__( r_max: float, num_bessel: int, num_polynomial_cutoff: int, + num_edge_hidden_layers: int, + edge_hidden_irreps: o3.Irreps, max_ell: int, interaction_cls: Type[InteractionBlock], interaction_cls_first: Type[InteractionBlock], @@ -183,6 +185,20 @@ def __init__( ) edge_feats_irreps = o3.Irreps([(self.radial_embedding.out_dim, scalar_irrep)]) + self.edge_attribute_mixing = o3.FullyConnectedTensorProduct(irreps_in1=diffusion_scalar_irreps_out, + irreps_in2=edge_feats_irreps, + irreps_out=edge_hidden_irreps, + irrep_normalization='norm') + self.edge_hidden_layers = torch.nn.ModuleList([]) + edge_non_linearity = Activation(irreps_in=edge_hidden_irreps, acts=[gate]) + for i in range(num_edge_hidden_layers): + if i != 0: + self.edge_hidden_layers.append(edge_non_linearity) + edge_hidden_layer = o3.Linear(irreps_in=edge_hidden_irreps, + irreps_out=edge_hidden_irreps, + biases=False) + self.edge_hidden_layers.append(edge_hidden_layer) + # The "spherical harmonics" correspond to Y_{lm} in the definition of A^{(1)}, eq. 9 of the PAPER. sh_irreps = o3.Irreps.spherical_harmonics(max_ell) interaction_irreps = (sh_irreps * number_of_hidden_scalar_dimensions).sort()[0].simplify() @@ -200,7 +216,7 @@ def __init__( node_attrs_irreps=node_attr_irreps, node_feats_irreps=node_feats_irreps, edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, + edge_feats_irreps=edge_hidden_irreps, target_irreps=interaction_irreps, hidden_irreps=hidden_irreps, avg_num_neighbors=avg_num_neighbors, @@ -244,7 +260,7 @@ def __init__( node_attrs_irreps=node_attr_irreps, node_feats_irreps=hidden_irreps, edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, + edge_feats_irreps=edge_hidden_irreps, target_irreps=interaction_irreps, hidden_irreps=hidden_irreps_out, avg_num_neighbors=avg_num_neighbors, @@ -304,6 +320,8 @@ def forward(self, data: Dict[str, torch.Tensor], conditional: bool = False) -> t ) edge_attrs = self.spherical_harmonics(vectors) edge_feats = self.radial_embedding(lengths) + augmented_edge_attributes = self.edge_attribute_mixing(diffusion_scalar_embeddings, edge_feats) + edge_feats = self.edge_hidden_layers(augmented_edge_attributes) forces_embedding = self.condition_embedding_layer(data["forces"]) # 0e + 1o embedding diff --git a/crystal_diffusion/models/score_networks/diffusion_mace_score_network.py b/crystal_diffusion/models/score_networks/diffusion_mace_score_network.py index 5841e330..e3223ea3 100644 --- a/crystal_diffusion/models/score_networks/diffusion_mace_score_network.py +++ b/crystal_diffusion/models/score_networks/diffusion_mace_score_network.py @@ -25,6 +25,8 @@ class DiffusionMACEScoreNetworkParameters(ScoreNetworkParameters): r_max: float = 5.0 num_bessel: int = 8 num_polynomial_cutoff: int = 5 + num_edge_hidden_layers: int = 1 + edge_hidden_irreps: str = "16x0e" max_ell: int = 2 interaction_cls: str = "RealAgnosticResidualInteractionBlock" interaction_cls_first: str = "RealAgnosticInteractionBlock" @@ -60,6 +62,8 @@ def __init__(self, hyper_params: DiffusionMACEScoreNetworkParameters): r_max=hyper_params.r_max, num_bessel=hyper_params.num_bessel, num_polynomial_cutoff=hyper_params.num_polynomial_cutoff, + num_edge_hidden_layers=hyper_params.num_edge_hidden_layers, + edge_hidden_irreps=o3.Irreps(hyper_params.edge_hidden_irreps), max_ell=hyper_params.max_ell, interaction_cls=interaction_classes[hyper_params.interaction_cls], interaction_cls_first=interaction_classes[hyper_params.interaction_cls_first], From fd6603e91d167fc00e4acd3518a3bbf59d2df430 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Sun, 16 Jun 2024 15:20:07 -0400 Subject: [PATCH 06/20] edge diffusion scalars do not match dimension with node diffusion scalar --- crystal_diffusion/models/diffusion_mace.py | 11 +++++++---- crystal_diffusion/models/mace_utils.py | 7 +++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/crystal_diffusion/models/diffusion_mace.py b/crystal_diffusion/models/diffusion_mace.py index 87efedc2..3804460b 100644 --- a/crystal_diffusion/models/diffusion_mace.py +++ b/crystal_diffusion/models/diffusion_mace.py @@ -41,9 +41,9 @@ def input_to_diffusion_mace(batch: Dict[AnyStr, torch.Tensor], radial_cutoff: fl basis_vectors = batch[UNIT_CELL] # batch, spatial_dimension, spatial_dimension - adj_matrix, shift_matrix, batch_tensor = get_adj_matrix(positions=cartesian_positions, - basis_vectors=basis_vectors, - radial_cutoff=radial_cutoff) + adj_matrix, shift_matrix, batch_tensor, num_edges = get_adj_matrix(positions=cartesian_positions, + basis_vectors=basis_vectors, + radial_cutoff=radial_cutoff) # node features are int corresponding to atom type # TODO handle different atom types @@ -55,6 +55,7 @@ def input_to_diffusion_mace(batch: Dict[AnyStr, torch.Tensor], radial_cutoff: fl # mixing it with bona fide node features within the model. noises = batch[NOISE] + 1 # [batch_size, 1] - add 1 to avoid getting a zero at sigma=0 (initialization issues) node_diffusion_scalars = noises.repeat_interleave(n_atom_per_graph, dim=0) # [flat_batch_size, 1] + edge_diffusion_scalars = noises.repeat_interleave(num_edges, dim=0) # [batchsize * natoms, spatial dimension] flat_cartesian_positions = cartesian_positions.view(-1, spatial_dimension) @@ -70,6 +71,7 @@ def input_to_diffusion_mace(batch: Dict[AnyStr, torch.Tensor], radial_cutoff: fl graph_data = Data(edge_index=adj_matrix, node_attrs=node_attrs.to(device), node_diffusion_scalars=node_diffusion_scalars.to(device), + edge_diffusion_scalars=edge_diffusion_scalars.to(device), positions=flat_cartesian_positions, ptr=ptr.to(device), batch=batch_tensor.to(device), @@ -320,7 +322,8 @@ def forward(self, data: Dict[str, torch.Tensor], conditional: bool = False) -> t ) edge_attrs = self.spherical_harmonics(vectors) edge_feats = self.radial_embedding(lengths) - augmented_edge_attributes = self.edge_attribute_mixing(diffusion_scalar_embeddings, edge_feats) + edge_diffusion_scalar_embeddings = self.diffusion_scalar_embedding(data["edge_diffusion_scalars"]) + augmented_edge_attributes = self.edge_attribute_mixing(edge_diffusion_scalar_embeddings, edge_feats) edge_feats = self.edge_hidden_layers(augmented_edge_attributes) forces_embedding = self.condition_embedding_layer(data["forces"]) # 0e + 1o embedding diff --git a/crystal_diffusion/models/mace_utils.py b/crystal_diffusion/models/mace_utils.py index 8c3b9eb7..fc69f9b3 100644 --- a/crystal_diffusion/models/mace_utils.py +++ b/crystal_diffusion/models/mace_utils.py @@ -14,7 +14,7 @@ def get_adj_matrix(positions: torch.Tensor, basis_vectors: torch.Tensor, - radial_cutoff: float = 4.0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + radial_cutoff: float = 4.0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Create the adjacency and shift matrices. Args: @@ -32,6 +32,7 @@ def get_adj_matrix(positions: torch.Tensor, adjacency matrix: The (src, dst) node indices, as a [2, num_edge] tensor, shift matrix: The lattice vector shifts between source and destination, as a [num_edge, 3] tensor batch_indices: for each node, this indicates which batch item it originally belonged to. + number_of_edges: for each element in the batch, how many edges belong to it """ batch_size, number_of_atoms, spatial_dimensions = positions.shape @@ -47,7 +48,9 @@ def get_adj_matrix(positions: torch.Tensor, shifts = adjacency_info.shifts batch_indices = adjacency_info.node_batch_indices - return shifted_adjacency_matrix, shifts, batch_indices + number_of_edges = adjacency_info.number_of_edges + + return shifted_adjacency_matrix, shifts, batch_indices, number_of_edges def input_to_mace(x: Dict[AnyStr, torch.Tensor], radial_cutoff: float) -> Data: From 53979756638fe6ba8774c64332a45c852cb79c71 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Sun, 16 Jun 2024 15:32:47 -0400 Subject: [PATCH 07/20] gpu & tensor type issues --- crystal_diffusion/models/diffusion_mace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crystal_diffusion/models/diffusion_mace.py b/crystal_diffusion/models/diffusion_mace.py index 3804460b..93f6a0f6 100644 --- a/crystal_diffusion/models/diffusion_mace.py +++ b/crystal_diffusion/models/diffusion_mace.py @@ -55,7 +55,7 @@ def input_to_diffusion_mace(batch: Dict[AnyStr, torch.Tensor], radial_cutoff: fl # mixing it with bona fide node features within the model. noises = batch[NOISE] + 1 # [batch_size, 1] - add 1 to avoid getting a zero at sigma=0 (initialization issues) node_diffusion_scalars = noises.repeat_interleave(n_atom_per_graph, dim=0) # [flat_batch_size, 1] - edge_diffusion_scalars = noises.repeat_interleave(num_edges, dim=0) + edge_diffusion_scalars = noises.repeat_interleave(num_edges.long().to(noises), dim=0) # [batchsize * natoms, spatial dimension] flat_cartesian_positions = cartesian_positions.view(-1, spatial_dimension) From 7b0906a4e734d60eb073fa2c49c246fb36bc87b5 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Sun, 16 Jun 2024 15:45:13 -0400 Subject: [PATCH 08/20] sequential and not module list + adding bn for regularization --- crystal_diffusion/models/diffusion_mace.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/crystal_diffusion/models/diffusion_mace.py b/crystal_diffusion/models/diffusion_mace.py index 93f6a0f6..44625041 100644 --- a/crystal_diffusion/models/diffusion_mace.py +++ b/crystal_diffusion/models/diffusion_mace.py @@ -191,7 +191,7 @@ def __init__( irreps_in2=edge_feats_irreps, irreps_out=edge_hidden_irreps, irrep_normalization='norm') - self.edge_hidden_layers = torch.nn.ModuleList([]) + self.edge_hidden_layers = torch.nn.Sequential() edge_non_linearity = Activation(irreps_in=edge_hidden_irreps, acts=[gate]) for i in range(num_edge_hidden_layers): if i != 0: @@ -200,6 +200,8 @@ def __init__( irreps_out=edge_hidden_irreps, biases=False) self.edge_hidden_layers.append(edge_hidden_layer) + bn = BatchNorm(edge_hidden_irreps) + self.edge_hidden_layers.append(bn) # The "spherical harmonics" correspond to Y_{lm} in the definition of A^{(1)}, eq. 9 of the PAPER. sh_irreps = o3.Irreps.spherical_harmonics(max_ell) From e417c985f6bacbf8b6b2674ac5eebf33dbe6efe6 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Tue, 18 Jun 2024 14:24:59 -0400 Subject: [PATCH 09/20] edge_features in difface --- crystal_diffusion/models/diffusion_mace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crystal_diffusion/models/diffusion_mace.py b/crystal_diffusion/models/diffusion_mace.py index 44625041..61876019 100644 --- a/crystal_diffusion/models/diffusion_mace.py +++ b/crystal_diffusion/models/diffusion_mace.py @@ -55,7 +55,7 @@ def input_to_diffusion_mace(batch: Dict[AnyStr, torch.Tensor], radial_cutoff: fl # mixing it with bona fide node features within the model. noises = batch[NOISE] + 1 # [batch_size, 1] - add 1 to avoid getting a zero at sigma=0 (initialization issues) node_diffusion_scalars = noises.repeat_interleave(n_atom_per_graph, dim=0) # [flat_batch_size, 1] - edge_diffusion_scalars = noises.repeat_interleave(num_edges.long().to(noises), dim=0) + edge_diffusion_scalars = noises.repeat_interleave(num_edges.to(noises).long(), dim=0) # [batchsize * natoms, spatial dimension] flat_cartesian_positions = cartesian_positions.view(-1, spatial_dimension) From 61bcf4e01cfce4e0c842b598e60df9fb37ed3e58 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Tue, 25 Jun 2024 10:48:15 -0400 Subject: [PATCH 10/20] option to add tanh in difface & adding sigma to edge features --- crystal_diffusion/models/diffusion_mace.py | 48 ++++++++++++------- .../diffusion_mace_score_network.py | 6 ++- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/crystal_diffusion/models/diffusion_mace.py b/crystal_diffusion/models/diffusion_mace.py index 2fdf2134..ca4142d1 100644 --- a/crystal_diffusion/models/diffusion_mace.py +++ b/crystal_diffusion/models/diffusion_mace.py @@ -2,7 +2,7 @@ import torch from e3nn import o3 -from e3nn.nn import Activation, BatchNorm +from e3nn.nn import Activation, BatchNorm, NormActivation from mace.modules import (EquivariantProductBasisBlock, InteractionBlock, LinearNodeEmbeddingBlock, RadialEmbeddingBlock) from mace.modules.utils import get_edge_vectors_and_lengths @@ -112,6 +112,7 @@ def __init__( radial_type: Optional[str] = "bessel", condition_embedding_size: int = 64, # dimension of the conditional variable embedding - assumed to be l=1 (odd) use_batchnorm: bool = False, + tanh_after_interaction: bool = True, ): """Init method.""" assert num_elements == 1, "only a single element can be used at this time. Set 'num_elements' to 1." @@ -187,21 +188,23 @@ def __init__( ) edge_feats_irreps = o3.Irreps([(self.radial_embedding.out_dim, scalar_irrep)]) - self.edge_attribute_mixing = o3.FullyConnectedTensorProduct(irreps_in1=diffusion_scalar_irreps_out, - irreps_in2=edge_feats_irreps, - irreps_out=edge_hidden_irreps, - irrep_normalization='norm') - self.edge_hidden_layers = torch.nn.Sequential() - edge_non_linearity = Activation(irreps_in=edge_hidden_irreps, acts=[gate]) - for i in range(num_edge_hidden_layers): - if i != 0: - self.edge_hidden_layers.append(edge_non_linearity) - edge_hidden_layer = o3.Linear(irreps_in=edge_hidden_irreps, - irreps_out=edge_hidden_irreps, - biases=False) - self.edge_hidden_layers.append(edge_hidden_layer) - bn = BatchNorm(edge_hidden_irreps) - self.edge_hidden_layers.append(bn) + if num_edge_hidden_layers > 0: + self.edge_attribute_mixing = o3.FullyConnectedTensorProduct(irreps_in1=diffusion_scalar_irreps_out, + irreps_in2=edge_feats_irreps, + irreps_out=edge_hidden_irreps, + irrep_normalization='norm') + self.edge_hidden_layers = torch.nn.Sequential() + edge_non_linearity = Activation(irreps_in=edge_hidden_irreps, acts=[gate]) + for i in range(num_edge_hidden_layers): + if i != 0: + self.edge_hidden_layers.append(edge_non_linearity) + edge_hidden_layer = o3.Linear(irreps_in=edge_hidden_irreps, + irreps_out=edge_hidden_irreps, + biases=False) + self.edge_hidden_layers.append(edge_hidden_layer) + else: + self.edge_attribute_mixing, self.edge_hidden_layers = None, None + edge_hidden_irreps = edge_feats_irreps # The "spherical harmonics" correspond to Y_{lm} in the definition of A^{(1)}, eq. 9 of the PAPER. sh_irreps = o3.Irreps.spherical_harmonics(max_ell) @@ -228,6 +231,11 @@ def __init__( ) self.interactions = torch.nn.ModuleList([inter]) + if tanh_after_interaction: + self.interactions_tanh = torch.nn.ModuleList([NormActivation(inter.target_irreps, torch.tanh)]) + else: + self.interactions_tanh = None + # 'sc' means 'self-connection', namely a 'residual-like' connection, h^{t+1} = m^{t} + (sc) x h^{(t)} # Use the appropriate self connection at the first layer for proper E0 use_sc_first = False @@ -272,6 +280,9 @@ def __init__( ) self.interactions.append(inter) + if self.interactions_tanh is not None: + self.interactions_tanh.append(NormActivation(interaction_irreps, torch.tanh)) + # prod compute h^{(t+1)} from A^{(t)} and h^{(t)}, computing B and the messages internally. prod = EquivariantProductBasisBlock( node_feats_irreps=interaction_irreps, @@ -343,6 +354,11 @@ def forward(self, data: Dict[str, torch.Tensor], conditional: bool = False) -> t edge_feats=edge_feats, edge_index=data["edge_index"], ) + + if self.interactions_tanh is not None: + batch_size, nf_irreps = node_feats.size(0), node_feats.size(-1) # reshaping for e3nn implementation + node_feats = self.interactions_tanh[i](node_feats.view(batch_size, -1)).view(batch_size, -1, nf_irreps) + node_feats = product( node_feats=node_feats, sc=sc, diff --git a/crystal_diffusion/models/score_networks/diffusion_mace_score_network.py b/crystal_diffusion/models/score_networks/diffusion_mace_score_network.py index e3223ea3..defcbacd 100644 --- a/crystal_diffusion/models/score_networks/diffusion_mace_score_network.py +++ b/crystal_diffusion/models/score_networks/diffusion_mace_score_network.py @@ -25,7 +25,7 @@ class DiffusionMACEScoreNetworkParameters(ScoreNetworkParameters): r_max: float = 5.0 num_bessel: int = 8 num_polynomial_cutoff: int = 5 - num_edge_hidden_layers: int = 1 + num_edge_hidden_layers: int = 0 # layers mixing sigma in edge features. Set to 0 to not add sigma in edges edge_hidden_irreps: str = "16x0e" max_ell: int = 2 interaction_cls: str = "RealAgnosticResidualInteractionBlock" @@ -41,6 +41,7 @@ class DiffusionMACEScoreNetworkParameters(ScoreNetworkParameters): radial_type: str = "bessel" # type of radial basis functions - choices=["bessel", "gaussian", "chebyshev"] condition_embedding_size: int = 64 # dimension of the conditional variable embedding - assumed to be l=1 (odd) use_batchnorm: bool = False + tanh_after_interaction: bool = True # use a tanh non-linearity (based on irreps norm) in the message-passing class DiffusionMACEScoreNetwork(ScoreNetwork): @@ -79,7 +80,8 @@ def __init__(self, hyper_params: DiffusionMACEScoreNetworkParameters): radial_MLP=hyper_params.radial_MLP, radial_type=hyper_params.radial_type, condition_embedding_size=hyper_params.condition_embedding_size, - use_batchnorm=hyper_params.use_batchnorm + use_batchnorm=hyper_params.use_batchnorm, + tanh_after_interaction=hyper_params.tanh_after_interaction ) self._natoms = hyper_params.number_of_atoms From 5ae43d69bad598999ec394e6a7879b16c765c653 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Tue, 25 Jun 2024 10:57:23 -0400 Subject: [PATCH 11/20] config files update --- .../config_files/diffusion/config_diffusion_mace.yaml | 6 +++++- .../diffusion/config_diffusion_mace_orion.yaml | 8 ++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/examples/config_files/diffusion/config_diffusion_mace.yaml b/examples/config_files/diffusion/config_diffusion_mace.yaml index ab81edb1..92e3f784 100644 --- a/examples/config_files/diffusion/config_diffusion_mace.yaml +++ b/examples/config_files/diffusion/config_diffusion_mace.yaml @@ -27,6 +27,8 @@ model: r_max: 5.0 num_bessel: 8 num_polynomial_cutoff: 5 + num_edge_hidden_layers: 0 + edge_hidden_irreps: 8x0e max_ell: 2 interaction_cls: RealAgnosticResidualInteractionBlock interaction_cls_first: RealAgnosticInteractionBlock @@ -38,10 +40,12 @@ model: correlation: 3 gate: silu radial_MLP: [8, 8, 8] - radial_type: bessel + radial_type: gaussian conditional_prob: 0.0 conditional_gamma: 2 condition_embedding_size: 64 + use_batchnorm: False + tanh_after_interaction: True noise: total_time_steps: 100 sigma_min: 0.001 # default value diff --git a/examples/config_files/diffusion/config_diffusion_mace_orion.yaml b/examples/config_files/diffusion/config_diffusion_mace_orion.yaml index c5b0ca6e..a1ec43c0 100644 --- a/examples/config_files/diffusion/config_diffusion_mace_orion.yaml +++ b/examples/config_files/diffusion/config_diffusion_mace_orion.yaml @@ -25,8 +25,10 @@ model: architecture: diffusion_mace number_of_atoms: 8 r_max: 5.0 - num_bessel: 8 + num_bessel: 'orion~choices([128, 256, 512])' num_polynomial_cutoff: 5 + num_edge_hidden_layers: 0 + edge_hidden_irreps: 8x0e max_ell: 2 interaction_cls: RealAgnosticResidualInteractionBlock interaction_cls_first: RealAgnosticInteractionBlock @@ -38,10 +40,12 @@ model: correlation: 3 gate: silu radial_MLP: 'orion~choices([[8, 8, 8], [32, 32, 32], [64, 64]])' - radial_type: bessel + radial_type: 'orion~choices(["bessel", "gaussian"])' conditional_prob: 'orion~choices([0.0, 0.25, 0.5, 0.75])' conditional_gamma: 2 condition_embedding_size: 'orion~choices([32, 64])' + use_batchnorm: False + tanh_after_interaction: True noise: total_time_steps: 100 sigma_min: 0.001 # default value From 7e09e30cd4e6b7b794f14295cec7b9a205107d8c Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Tue, 25 Jun 2024 11:17:49 -0400 Subject: [PATCH 12/20] fixing input_to_mace function --- crystal_diffusion/models/mace_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crystal_diffusion/models/mace_utils.py b/crystal_diffusion/models/mace_utils.py index fc69f9b3..76b095dd 100644 --- a/crystal_diffusion/models/mace_utils.py +++ b/crystal_diffusion/models/mace_utils.py @@ -68,9 +68,9 @@ def input_to_mace(x: Dict[AnyStr, torch.Tensor], radial_cutoff: float) -> Data: batch_size, n_atom_per_graph, spatial_dimension = noisy_cartesian_positions.shape device = noisy_cartesian_positions.device - adj_matrix, shift_matrix, batch_tensor = get_adj_matrix(positions=noisy_cartesian_positions, - basis_vectors=cell, - radial_cutoff=radial_cutoff) + adj_matrix, shift_matrix, batch_tensor, _ = get_adj_matrix(positions=noisy_cartesian_positions, + basis_vectors=cell, + radial_cutoff=radial_cutoff) # node features are int corresponding to atom type # TODO handle different atom types node_attrs = torch.nn.functional.one_hot((torch.ones(batch_size * n_atom_per_graph) * 14).long(), From 7fd4137c1b448b984480cb4171419a7d8cf06835 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Tue, 25 Jun 2024 11:38:08 -0400 Subject: [PATCH 13/20] augmented edge attributes should be in a if --- crystal_diffusion/models/diffusion_mace.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/crystal_diffusion/models/diffusion_mace.py b/crystal_diffusion/models/diffusion_mace.py index ca4142d1..fa3294ae 100644 --- a/crystal_diffusion/models/diffusion_mace.py +++ b/crystal_diffusion/models/diffusion_mace.py @@ -335,9 +335,10 @@ def forward(self, data: Dict[str, torch.Tensor], conditional: bool = False) -> t ) edge_attrs = self.spherical_harmonics(vectors) edge_feats = self.radial_embedding(lengths) - edge_diffusion_scalar_embeddings = self.diffusion_scalar_embedding(data["edge_diffusion_scalars"]) - augmented_edge_attributes = self.edge_attribute_mixing(edge_diffusion_scalar_embeddings, edge_feats) - edge_feats = self.edge_hidden_layers(augmented_edge_attributes) + if self.edge_attribute_mixing is not None: + edge_diffusion_scalar_embeddings = self.diffusion_scalar_embedding(data["edge_diffusion_scalars"]) + augmented_edge_attributes = self.edge_attribute_mixing(edge_diffusion_scalar_embeddings, edge_feats) + edge_feats = self.edge_hidden_layers(augmented_edge_attributes) forces_embedding = self.condition_embedding_layer(data["forces"]) # 0e + 1o embedding From 5ab5925d01c8f00128e2a36a2f01963597f27385 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Tue, 25 Jun 2024 11:55:48 -0400 Subject: [PATCH 14/20] fixing broken test due to missing parameter --- tests/models/test_diffusion_mace.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/models/test_diffusion_mace.py b/tests/models/test_diffusion_mace.py index f6e7ef46..da2f0313 100644 --- a/tests/models/test_diffusion_mace.py +++ b/tests/models/test_diffusion_mace.py @@ -134,6 +134,8 @@ def hyperparameters(self, r_max): hps = dict(r_max=r_max, num_bessel=8, num_polynomial_cutoff=5, + num_edge_hidden_layers=0, + edge_hidden_irreps=o3.Irreps("8x0e"), max_ell=2, num_elements=1, atomic_numbers=[14], @@ -147,7 +149,8 @@ def hyperparameters(self, r_max): correlation=2, gate=gate_dict["silu"], radial_MLP=[8, 8, 8], - radial_type="bessel") + radial_type="bessel", + ) return hps @pytest.fixture() From 2e265f630d9ad5bcb5784502389c7312977a55fb Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Tue, 25 Jun 2024 14:20:54 -0400 Subject: [PATCH 15/20] juggling tensor size because mace and e3nn do not play well together --- crystal_diffusion/models/diffusion_mace.py | 9 +++++++-- crystal_diffusion/models/mace_utils.py | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/crystal_diffusion/models/diffusion_mace.py b/crystal_diffusion/models/diffusion_mace.py index fa3294ae..7fd36173 100644 --- a/crystal_diffusion/models/diffusion_mace.py +++ b/crystal_diffusion/models/diffusion_mace.py @@ -8,7 +8,7 @@ from mace.modules.utils import get_edge_vectors_and_lengths from torch_geometric.data import Data -from crystal_diffusion.models.mace_utils import get_adj_matrix +from crystal_diffusion.models.mace_utils import get_adj_matrix, reshape_from_mace_to_e3nn, reshape_from_e3nn_to_mace from crystal_diffusion.namespace import (CARTESIAN_FORCES, NOISE, NOISY_CARTESIAN_POSITIONS, UNIT_CELL) @@ -358,7 +358,12 @@ def forward(self, data: Dict[str, torch.Tensor], conditional: bool = False) -> t if self.interactions_tanh is not None: batch_size, nf_irreps = node_feats.size(0), node_feats.size(-1) # reshaping for e3nn implementation - node_feats = self.interactions_tanh[i](node_feats.view(batch_size, -1)).view(batch_size, -1, nf_irreps) + # reshape from (node, channels, (l_max + 1)**2) to a (node, -1) tensor compatible with e3nn + node_feats = reshape_from_mace_to_e3nn(node_feats, self.interactions_tanh[i].irreps_in) + # apply non-linearity + node_feats = self.interactions_tanh[i](node_feats) + # reshape from e3nn shape to mace format (node, channels, (l_max+1)**2) + node_feats = reshape_from_e3nn_to_mace(node_feats, self.interactions_tanh[i].irreps_out) node_feats = product( node_feats=node_feats, diff --git a/crystal_diffusion/models/mace_utils.py b/crystal_diffusion/models/mace_utils.py index 76b095dd..a21fd62a 100644 --- a/crystal_diffusion/models/mace_utils.py +++ b/crystal_diffusion/models/mace_utils.py @@ -211,3 +211,22 @@ def get_normalized_irreps_permutation_indices(irreps: o3.Irreps) -> Tuple[o3.Irr sorted_irreps = sorted_output.irreps.simplify() return sorted_irreps, column_permutation_indices + + +def reshape_from_mace_to_e3nn(x: torch.Tensor, irreps: o3.Irreps): + node = x.size(0) + # x : node, channel, irreps index + x_ = [] + for ell in range(irreps.lmax + 1): + x_l = x[:, :, (ell ** 2):(ell + 1)**2].reshape(node, -1) # node, channel * (2l + 1) + x_.append(x_l) + return torch.cat(x_, dim=-1) + + +def reshape_from_e3nn_to_mace(x, irreps): + node = x.size(0) + x_ = [] + for ell, s in enumerate(irreps.slices()): + x_l = x[:, s].reshape(node, -1, 2 * ell + 1) + x_.append(x_l) + return torch.cat(x_, dim=-1) \ No newline at end of file From 36012e4444b0994e7d8440df57f7f68213e6030a Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Tue, 25 Jun 2024 14:28:18 -0400 Subject: [PATCH 16/20] added docstring --- crystal_diffusion/models/mace_utils.py | 36 +++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/crystal_diffusion/models/mace_utils.py b/crystal_diffusion/models/mace_utils.py index a21fd62a..e675811d 100644 --- a/crystal_diffusion/models/mace_utils.py +++ b/crystal_diffusion/models/mace_utils.py @@ -213,17 +213,47 @@ def get_normalized_irreps_permutation_indices(irreps: o3.Irreps) -> Tuple[o3.Irr return sorted_irreps, column_permutation_indices -def reshape_from_mace_to_e3nn(x: torch.Tensor, irreps: o3.Irreps): +def reshape_from_mace_to_e3nn(x: torch.Tensor, irreps: o3.Irreps) -> torch.Tensor: + """Reshape a MACE input/output tensor to a e3nn.NormActivation compatible format. + + MACE uses tensors in the 2D format (ignoring the nodes / batchsize): + ---- l = 0 ---- + ---- l = 1 ---- + ---- l = 1 ---- + ---- l = 1 ---- + ... + And e3nn wants a tensor in the 1D format: + ---- l = 0 ---- ---- l= 1 ---- ---- l=2 ---- ... + + Args: + x: torch used by MACE. Should be of size (number of nodes, number of channels, (ell_max + 1)^2 + irreps: o3 irreps matching the x tensor + + Returns: + tensor of size (number of nodes, number of channels * (ell_max + 1)^2) usable by e3nn + """ node = x.size(0) - # x : node, channel, irreps index x_ = [] for ell in range(irreps.lmax + 1): + # for example, for l=1, take indices 1, 2, 3 (in the last index) and flatten as a channel * 3 tensor x_l = x[:, :, (ell ** 2):(ell + 1)**2].reshape(node, -1) # node, channel * (2l + 1) x_.append(x_l) + # stack the flatten irrep tensors together return torch.cat(x_, dim=-1) -def reshape_from_e3nn_to_mace(x, irreps): +def reshape_from_e3nn_to_mace(x: torch.Tensor, irreps: o3.Irreps) -> torch.Tensor: + """Reshape a tensor in the e3nn.NormActivation format to a MACE format. + + See reshape_from_mace_to_e3nn for an explanation of the formats + + Args: + x: torch used by MACE. Should be of size (number of nodes, number of channels, (ell_max + 1)^2 + irreps: o3 irreps matching the x tensor + + Returns: + tensor of size (number of nodes, number of channels * (ell_max + 1)^2) usable by e3nn + """ node = x.size(0) x_ = [] for ell, s in enumerate(irreps.slices()): From f7db962de25155bc7741f0d1de80fa7930e3043b Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Tue, 25 Jun 2024 14:31:15 -0400 Subject: [PATCH 17/20] flake8 and isort checks --- crystal_diffusion/models/diffusion_mace.py | 4 +++- crystal_diffusion/models/mace_utils.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/crystal_diffusion/models/diffusion_mace.py b/crystal_diffusion/models/diffusion_mace.py index 7fd36173..8abd732c 100644 --- a/crystal_diffusion/models/diffusion_mace.py +++ b/crystal_diffusion/models/diffusion_mace.py @@ -8,7 +8,9 @@ from mace.modules.utils import get_edge_vectors_and_lengths from torch_geometric.data import Data -from crystal_diffusion.models.mace_utils import get_adj_matrix, reshape_from_mace_to_e3nn, reshape_from_e3nn_to_mace +from crystal_diffusion.models.mace_utils import (get_adj_matrix, + reshape_from_e3nn_to_mace, + reshape_from_mace_to_e3nn) from crystal_diffusion.namespace import (CARTESIAN_FORCES, NOISE, NOISY_CARTESIAN_POSITIONS, UNIT_CELL) diff --git a/crystal_diffusion/models/mace_utils.py b/crystal_diffusion/models/mace_utils.py index e675811d..beb84a16 100644 --- a/crystal_diffusion/models/mace_utils.py +++ b/crystal_diffusion/models/mace_utils.py @@ -259,4 +259,4 @@ def reshape_from_e3nn_to_mace(x: torch.Tensor, irreps: o3.Irreps) -> torch.Tenso for ell, s in enumerate(irreps.slices()): x_l = x[:, s].reshape(node, -1, 2 * ell + 1) x_.append(x_l) - return torch.cat(x_, dim=-1) \ No newline at end of file + return torch.cat(x_, dim=-1) From 1983ea092ac12248f01da83b7dc618a9804f67ba Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Tue, 25 Jun 2024 14:33:55 -0400 Subject: [PATCH 18/20] useless variable not removed --- crystal_diffusion/models/diffusion_mace.py | 1 - 1 file changed, 1 deletion(-) diff --git a/crystal_diffusion/models/diffusion_mace.py b/crystal_diffusion/models/diffusion_mace.py index 8abd732c..6b1f732e 100644 --- a/crystal_diffusion/models/diffusion_mace.py +++ b/crystal_diffusion/models/diffusion_mace.py @@ -359,7 +359,6 @@ def forward(self, data: Dict[str, torch.Tensor], conditional: bool = False) -> t ) if self.interactions_tanh is not None: - batch_size, nf_irreps = node_feats.size(0), node_feats.size(-1) # reshaping for e3nn implementation # reshape from (node, channels, (l_max + 1)**2) to a (node, -1) tensor compatible with e3nn node_feats = reshape_from_mace_to_e3nn(node_feats, self.interactions_tanh[i].irreps_in) # apply non-linearity From c04353c950c2f63b0c296b62530561a4898a63a8 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Tue, 25 Jun 2024 15:06:08 -0400 Subject: [PATCH 19/20] unit tests for mace <-> e3nn reshapes --- tests/models/test_mace_utils.py | 55 ++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/tests/models/test_mace_utils.py b/tests/models/test_mace_utils.py index ce628f9b..ecf5613a 100644 --- a/tests/models/test_mace_utils.py +++ b/tests/models/test_mace_utils.py @@ -11,7 +11,7 @@ from crystal_diffusion.models.mace_utils import ( get_normalized_irreps_permutation_indices, get_pretrained_mace, - input_to_mace) + input_to_mace, reshape_from_e3nn_to_mace, reshape_from_mace_to_e3nn) from crystal_diffusion.namespace import NOISY_CARTESIAN_POSITIONS, UNIT_CELL from crystal_diffusion.utils.basis_transformations import \ get_positions_from_coordinates @@ -246,3 +246,56 @@ def test_download_pretrained_mace_invalid_model_name(self, mock_model_savedir): with pytest.raises(AssertionError) as e: get_pretrained_mace("invalid_name", mock_model_savedir) assert "Model name should be small, medium or large. Got invalid_name" in str(e.value) + + +class TestReshapes: + @pytest.fixture + def num_nodes(self): + return 5 + + @pytest.fixture + def num_channels(self): + return 3 + + @pytest.fixture + def ell_max(self): + return 2 + + @pytest.fixture + def irrep(self, num_channels, ell_max): + irrep_str = f"{num_channels}x0e" + for i in range(1, ell_max + 1): + parity = "e" if i % 2 == 0 else "o" + irrep_str += f"+ {num_channels}x{i}{parity}" + return o3.Irreps(irrep_str) + + @pytest.fixture + def mace_format_tensor(self, num_nodes, num_channels, ell_max): + return torch.rand(num_nodes, num_channels, (ell_max + 1) ** 2) + + @pytest.fixture + def e3nn_format_tensor(self, num_nodes, num_channels, ell_max): + return torch.rand(num_nodes, num_channels * (ell_max + 1) ** 2) + + def test_reshape_from_mace_to_e3nn(self, mace_format_tensor, irrep, ell_max, num_channels): + converted_tensor = reshape_from_mace_to_e3nn(mace_format_tensor, irrep) + # mace_format_tensor: (node, channel, (lmax + 1) ** 2) + # converted: (node, channel * (lmax +1)**2) + # check for each ell that the values match + for ell in range(ell_max + 1): + start_idx = ell ** 2 + end_idx = (ell + 1) ** 2 + expected_values = mace_format_tensor[:, :, start_idx:end_idx].reshape(-1, num_channels * (2 * ell + 1)) + assert torch.allclose(expected_values, + converted_tensor[:, num_channels * start_idx: num_channels * end_idx]) + + def test_reshape_from_e3nn_to_mace(self, e3nn_format_tensor, irrep, ell_max, num_channels): + converted_tensor = reshape_from_e3nn_to_mace(e3nn_format_tensor, irrep) + # e3nn_format_tensor: (node, channel * (lmax +1)**2) + # converted: (node, channel, (lmax + 1) ** 2) + for ell in range(ell_max + 1): + start_idx = num_channels * (ell ** 2) + end_idx = num_channels * ((ell + 1) ** 2) + expected_values = e3nn_format_tensor[:, start_idx:end_idx].reshape(-1, num_channels, 2 * ell + 1) + assert torch.allclose(expected_values, + converted_tensor[:, :, ell ** 2:(ell + 1) ** 2]) From 3712c574cb6006f8a6bffe3361c0e1f3e850ac8f Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Wed, 26 Jun 2024 10:41:57 -0400 Subject: [PATCH 20/20] typo fixing from code review --- crystal_diffusion/models/mace_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crystal_diffusion/models/mace_utils.py b/crystal_diffusion/models/mace_utils.py index beb84a16..467f2142 100644 --- a/crystal_diffusion/models/mace_utils.py +++ b/crystal_diffusion/models/mace_utils.py @@ -226,7 +226,7 @@ def reshape_from_mace_to_e3nn(x: torch.Tensor, irreps: o3.Irreps) -> torch.Tenso ---- l = 0 ---- ---- l= 1 ---- ---- l=2 ---- ... Args: - x: torch used by MACE. Should be of size (number of nodes, number of channels, (ell_max + 1)^2 + x: tensor used by MACE. Should be of size (number of nodes, number of channels, (ell_max + 1)^2) irreps: o3 irreps matching the x tensor Returns: