Skip to content

Commit

Permalink
Merge pull request #57 from mila-iqia/conditional_mace
Browse files Browse the repository at this point in the history
Conditional mace
  • Loading branch information
sblackburn86 authored Jun 12, 2024
2 parents 0b31dc2 + 183eb07 commit f34687d
Show file tree
Hide file tree
Showing 16 changed files with 71 additions and 336 deletions.
39 changes: 33 additions & 6 deletions crystal_diffusion/models/diffusion_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from torch_geometric.data import Data

from crystal_diffusion.models.mace_utils import get_adj_matrix
from crystal_diffusion.namespace import (NOISE, NOISY_CARTESIAN_POSITIONS,
UNIT_CELL)
from crystal_diffusion.namespace import (CARTESIAN_FORCES, NOISE,
NOISY_CARTESIAN_POSITIONS, UNIT_CELL)


class LinearVectorReadoutBlock(torch.nn.Module):
Expand Down Expand Up @@ -64,14 +64,18 @@ def input_to_diffusion_mace(batch: Dict[AnyStr, torch.Tensor], radial_cutoff: fl

flat_basis_vectors = basis_vectors.view(-1, spatial_dimension) # batch * spatial_dimension, spatial_dimension
# create the pytorch-geometric graph

forces = batch[CARTESIAN_FORCES].view(-1, spatial_dimension) # batch * n_atom_per_graph, spatial dimension

graph_data = Data(edge_index=adj_matrix,
node_attrs=node_attrs.to(device),
node_diffusion_scalars=node_diffusion_scalars.to(device),
positions=flat_cartesian_positions,
ptr=ptr.to(device),
batch=batch_tensor.to(device),
shifts=shift_matrix,
cell=flat_basis_vectors
cell=flat_basis_vectors,
forces=forces,
)
return graph_data

Expand Down Expand Up @@ -102,6 +106,7 @@ def __init__(
gate: Optional[Callable],
radial_MLP: List[int],
radial_type: Optional[str] = "bessel",
condition_embedding_size: int = 64 # dimension of the conditional variable embedding - assumed to be l=1 (odd)
):
"""Init method."""
assert num_elements == 1, "only a single element can be used at this time. Set 'num_elements' to 1."
Expand Down Expand Up @@ -142,8 +147,8 @@ def __init__(
irreps_out=diffusion_scalar_irreps_out,
biases=True)
self.diffusion_scalar_embedding.append(linear)
non_linearity = Activation(irreps_in=diffusion_scalar_irreps_out, acts=[gate])
for _ in range(number_of_mlp_layers):
non_linearity = Activation(irreps_in=diffusion_scalar_irreps_out, acts=[gate])
self.diffusion_scalar_embedding.append(non_linearity)

linear = o3.Linear(irreps_in=diffusion_scalar_irreps_out,
Expand Down Expand Up @@ -252,7 +257,24 @@ def __init__(
# the output is a single vector.
self.vector_readout = LinearVectorReadoutBlock(irreps_in=hidden_irreps_out)

def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
# Apply a MLP with a bias on the forces as a conditional feature. This would be a 1o irrep
forces_irreps_in = o3.Irreps("1x1o")
forces_irreps_embedding = o3.Irreps(f"{condition_embedding_size}x1o")
self.condition_embedding_layer = o3.Linear(irreps_in=forces_irreps_in,
irreps_out=forces_irreps_embedding,
biases=False) # can't have biases with 1o irreps

# conditional layers for the forces as a conditional feature to guide the diffusion
self.conditional_layers = torch.nn.ModuleList([])
for _ in range(num_interactions):
cond_layer = o3.Linear(
irreps_in=forces_irreps_embedding,
irreps_out=hidden_irreps_out,
biases=False
)
self.conditional_layers.append(cond_layer)

def forward(self, data: Dict[str, torch.Tensor], conditional: bool = False) -> torch.Tensor:
"""Forward method."""
# Setup

Expand All @@ -271,7 +293,9 @@ def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
edge_attrs = self.spherical_harmonics(vectors)
edge_feats = self.radial_embedding(lengths)

for interaction, product in zip(self.interactions, self.products):
forces_embedding = self.condition_embedding_layer(data["forces"]) # 0e + 1o embedding

for interaction, product, cond_layer in zip(self.interactions, self.products, self.conditional_layers):
node_feats, sc = interaction(
node_attrs=augmented_node_attributes,
node_feats=node_feats,
Expand All @@ -284,6 +308,9 @@ def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
sc=sc,
node_attrs=augmented_node_attributes,
)
if conditional: # modify the node features to account for the conditional features i.e. forces
force_embed = cond_layer(forces_embedding)
node_feats += force_embed

# Outputs
vectors_output = self.vector_readout(node_feats)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class DiffusionMACEScoreNetworkParameters(ScoreNetworkParameters):
gate: str = "silu" # non linearity for last readout - choices: ["silu", "tanh", "abs", "None"]
radial_MLP: List[int] = field(default_factory=lambda: [64, 64, 64]) # "width of the radial MLP"
radial_type: str = "bessel" # type of radial basis functions - choices=["bessel", "gaussian", "chebyshev"]
condition_embedding_size: int = 64 # dimension of the conditional variable embedding - assumed to be l=1 (odd)


class DiffusionMACEScoreNetwork(ScoreNetwork):
Expand Down Expand Up @@ -71,7 +72,8 @@ def __init__(self, hyper_params: DiffusionMACEScoreNetworkParameters):
correlation=hyper_params.correlation,
gate=gate_dict[hyper_params.gate],
radial_MLP=hyper_params.radial_MLP,
radial_type=hyper_params.radial_type
radial_type=hyper_params.radial_type,
condition_embedding_size=hyper_params.condition_embedding_size,
)

self._natoms = hyper_params.number_of_atoms
Expand Down Expand Up @@ -100,15 +102,14 @@ def _forward_unchecked(self, batch: Dict[AnyStr, torch.Tensor], conditional: boo
Returns:
output : the scores computed by the model as a [batch_size, n_atom, spatial_dimension] tensor.
"""
del conditional # TODO do something with forces when conditional
relative_coordinates = batch[NOISY_RELATIVE_COORDINATES]
batch_size, number_of_atoms, spatial_dimension = relative_coordinates.shape

basis_vectors = batch[UNIT_CELL]
batch[NOISY_CARTESIAN_POSITIONS] = get_positions_from_coordinates(relative_coordinates, basis_vectors)
graph_input = input_to_diffusion_mace(batch, radial_cutoff=self.r_max)

flat_cartesian_scores = self.diffusion_mace_network(graph_input)
flat_cartesian_scores = self.diffusion_mace_network(graph_input, conditional)
cartesian_scores = flat_cartesian_scores.reshape(batch_size, number_of_atoms, spatial_dimension)

reciprocal_basis_vectors_as_columns = get_reciprocal_basis_vectors(basis_vectors)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ model:
gate: silu
radial_MLP: [8, 8, 8]
radial_type: bessel
conditional_prob: 0.0
conditional_gamma: 2
condition_embedding_size: 64
noise:
total_time_steps: 100
sigma_min: 0.001 # default value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ run_name: run_debug_delete_me
max_epoch: 10
log_every_n_steps: 1
gradient_clipping: 0.1
accumulate_grad_batches: 1 # make this number of forward passes before doing a backprop step

# set to null to avoid setting a seed (can speed up GPU computation, but
# results will not be reproducible)
Expand All @@ -28,14 +29,17 @@ model:
interaction_cls: RealAgnosticResidualInteractionBlock
interaction_cls_first: RealAgnosticInteractionBlock
num_interactions: 2
hidden_irreps: 8x0e + 8x1o
mlp_irreps: 8x0e
hidden_irreps: 'orion~choices(["8x0e + 8x1o", "16x0e + 16x1o + 16x2e", "32x0e + 32x1o + 32x2e + 32x3o"])'
mlp_irreps: 'orion~choices(["8x0e", "32x0e"])'
number_of_mlp_layers: 0
avg_num_neighbors: 1
correlation: 3
gate: silu
radial_MLP: [8, 8, 8]
radial_MLP: 'orion~choices([[8, 8, 8], [32, 32, 32], [64, 64]])'
radial_type: bessel
conditional_prob: 'orion~choices([0.0, 0.25, 0.5, 0.75])'
conditional_gamma: 2
condition_embedding_size: 'orion~choices([32, 64])'
noise:
total_time_steps: 100
sigma_min: 0.001 # default value
Expand All @@ -44,7 +48,7 @@ model:
# optimizer and scheduler
optimizer:
name: adamw
learning_rate: 0.001
learning_rate: 'orion~loguniform(1e-6, 1e-3)'
weight_decay: 1.0e-6

scheduler:
Expand Down Expand Up @@ -77,6 +81,6 @@ diffusion_sampling:
cell_dimensions: [5.43, 5.43, 5.43]

logging:
- csv
- tensorboard
# - csv
# - tensorboard
- comet
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ spatial_dimension: 3
model:
score_network:
architecture: mlp
conditional_prob: 0.0
conditional_gamma: 2
number_of_atoms: 8
n_hidden_dimensions: 2
hidden_dimensions_size: 64
conditional_prob: 0.0
conditional_gamma: 2
condition_embedding_size: 64
noise:
total_time_steps: 100
sigma_min: 0.005 # default value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ run_name: run_debug_delete_me
max_epoch: 10
log_every_n_steps: 1
gradient_clipping: 0
accumulate_grad_batches: 1 # make this number of forward passes before doing a backprop step

# set to null to avoid setting a seed (can speed up GPU computation, but
# results will not be reproducible)
Expand All @@ -20,11 +21,12 @@ spatial_dimension: 3
model:
score_network:
architecture: mlp
conditional_prob: 0.0
conditional_gamma: 2
number_of_atoms: 8
n_hidden_dimensions: 2
hidden_dimensions_size: 64
n_hidden_dimensions: 'orion~choices([1, 2, 3, 4])'
hidden_dimensions_size: 'orion~choices([16, 32, 64])'
conditional_prob: 'orion~choices([0.0, 0.25, 0.5])'
conditional_gamma: 2
condition_embedding_size: 'orion~choices([32, 64])'
noise:
total_time_steps: 100
sigma_min: 0.005 # default value
Expand All @@ -33,7 +35,7 @@ model:
# optimizer and scheduler
optimizer:
name: adamw
learning_rate: 0.001
learning_rate: 'orion~loguniform(1e-6, 1e-3)'
weight_decay: 1.0e-6

scheduler:
Expand Down
2 changes: 1 addition & 1 deletion examples/local/diffusion/run_diffusion.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This example assumes that the dataset 'si_diffusion_small' is present locally in the DATA folder.
# It is also assumed that the user has a Comet account for logging experiments.

CONFIG=config_diffusion_mace.yaml
CONFIG=../../config_files/diffusion/config_diffusion_mace.yaml
DATA_DIR=../../../data/si_diffusion_1x1x1
PROCESSED_DATA=${DATA_DIR}/processed
DATA_WORK_DIR=./tmp_work_dir/
Expand Down
45 changes: 0 additions & 45 deletions examples/local_orion/diffusion/config_diffusion.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion examples/local_orion/diffusion/run_orion.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ export ORION_DB_ADDRESS='orion_db.pkl'
export ORION_DB_TYPE='pickleddb'

ROOT_DIR=../../../
CONFIG=config_diffusion.yaml
CONFIG=../../config_files/diffusion/config_diffusion_mlp_orion.yaml
DATA_DIR=${ROOT_DIR}/data/si_diffusion_small
PROCESSED_DATA=${DATA_DIR}/processed
DATA_WORK_DIR=./tmp_work_dir/
Expand Down
86 changes: 0 additions & 86 deletions examples/mila_cluster/diffusion/config_mace_equivariant_head.yaml

This file was deleted.

Loading

0 comments on commit f34687d

Please sign in to comment.