Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Difface fixes #56

Merged
merged 5 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions crystal_diffusion/models/diffusion_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
from e3nn import o3
from e3nn.nn import Activation, BatchNorm
from e3nn.nn import Activation
from mace.modules import (EquivariantProductBasisBlock, InteractionBlock,
LinearNodeEmbeddingBlock, RadialEmbeddingBlock)
from mace.modules.utils import get_edge_vectors_and_lengths
Expand Down Expand Up @@ -146,9 +146,6 @@ def __init__(
non_linearity = Activation(irreps_in=diffusion_scalar_irreps_out, acts=[gate])
self.diffusion_scalar_embedding.append(non_linearity)

normalization = BatchNorm(diffusion_scalar_irreps_out)
self.diffusion_scalar_embedding.append(normalization)

linear = o3.Linear(irreps_in=diffusion_scalar_irreps_out,
irreps_out=diffusion_scalar_irreps_out,
biases=True)
Expand Down
3 changes: 1 addition & 2 deletions crystal_diffusion/models/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ class CosineAnnealingLRSchedulerParameters(SchedulerParameters):
eta_min: float = 0.0


def load_scheduler_dictionary(hyper_params: SchedulerParameters,
optimizer: optim.Optimizer) -> Dict[AnyStr, Union[optim.lr_scheduler, AnyStr]]:
def load_scheduler_dictionary(hyper_params: SchedulerParameters, optimizer: optim.Optimizer) -> Dict[AnyStr, Any]:
"""Instantiate the Scheduler.

Args:
Expand Down
2 changes: 1 addition & 1 deletion crystal_diffusion/utils/sample_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def read_from_pickle(path_to_pickle: str):
"""Read from pickle."""
with open(path_to_pickle, 'rb') as fd:
sample_trajectory = SampleTrajectory()
sample_trajectory.data = torch.load(fd)
sample_trajectory.data = torch.load(fd, map_location=torch.device('cpu'))
return sample_trajectory


Expand Down
2 changes: 1 addition & 1 deletion examples/mila_cluster/diffusion/config_diffusion_mace.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ model:
num_interactions: 2
hidden_irreps: 8x0e + 8x1o
mlp_irreps: 8x0e

number_of_mlp_layers: 0
avg_num_neighbors: 1
correlation: 3
gate: silu
Expand Down
85 changes: 85 additions & 0 deletions experiments/si_diffusion_1x1x1/config_diffusion_mace.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# general
exp_name: difface
run_name: run1
max_epoch: 25
log_every_n_steps: 1
gradient_clipping: 0.1

# set to null to avoid setting a seed (can speed up GPU computation, but
# results will not be reproducible)
seed: 1234

# data
data:
batch_size: 512
num_workers: 8
max_atom: 8

# architecture
spatial_dimension: 3
model:
score_network:
architecture: diffusion_mace
number_of_atoms: 8
r_max: 5.0
num_bessel: 8
num_polynomial_cutoff: 5
max_ell: 2
interaction_cls: RealAgnosticResidualInteractionBlock
interaction_cls_first: RealAgnosticInteractionBlock
num_interactions: 2
hidden_irreps: 128x0e + 128x1o + 128x2e
mlp_irreps: 128x0e
number_of_mlp_layers: 0
avg_num_neighbors: 1
correlation: 3
gate: silu
radial_MLP: [128, 128, 128]
radial_type: bessel
noise:
total_time_steps: 100
sigma_min: 0.001 # default value
sigma_max: 0.5 # default value'

# optimizer and scheduler
optimizer:
name: adamw
learning_rate: 0.001
weight_decay: 1.0e-8

scheduler:
name: ReduceLROnPlateau
factor: 0.1
patience: 20

# early stopping
early_stopping:
metric: validation_epoch_loss
mode: min
patience: 10

model_checkpoint:
monitor: validation_epoch_loss
mode: min

# Sampling from the generative model
diffusion_sampling:
noise:
total_time_steps: 100
sigma_min: 0.001 # default value
sigma_max: 0.5 # default value
sampling:
spatial_dimension: 3
number_of_corrector_steps: 1
number_of_atoms: 8
number_of_samples: 1000
sample_every_n_epochs: 5
cell_dimensions: [5.43, 5.43, 5.43]

# A callback to check the loss vs. sigma
loss_monitoring:
number_of_bins: 50
sample_every_n_epochs: 2

logging:
- comet
Loading