Skip to content

Commit

Permalink
Merge pull request #23 from mila-iqia/sampling_callback
Browse files Browse the repository at this point in the history
Sampling callback
  • Loading branch information
rousseab authored Apr 17, 2024
2 parents 991246b + ffc1a26 commit f6694aa
Show file tree
Hide file tree
Showing 8 changed files with 232 additions and 34 deletions.
5 changes: 4 additions & 1 deletion crystal_diffusion/callbacks/callback_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

from pytorch_lightning import Callback

from crystal_diffusion.callbacks.sampling_callback import \
instantiate_diffusion_sampling_callback
from crystal_diffusion.callbacks.standard_callbacks import (
CustomProgressBar, instantiate_early_stopping_callback,
instantiate_model_checkpoint_callbacks)

OPTIONAL_CALLBACK_DICTIONARY = dict(early_stopping=instantiate_early_stopping_callback,
model_checkpoint=instantiate_model_checkpoint_callbacks)
model_checkpoint=instantiate_model_checkpoint_callbacks,
diffusion_sampling=instantiate_diffusion_sampling_callback)


def create_all_callbacks(hyper_params: Dict[AnyStr, Any], output_directory: str, verbose: bool) -> Dict[str, Callback]:
Expand Down
14 changes: 0 additions & 14 deletions crystal_diffusion/callbacks/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import dataclasses

import matplotlib.pyplot as plt
import torch
from pytorch_lightning import Callback
Expand All @@ -19,18 +17,6 @@
TENSORBOARD_FIGSIZE = (0.65 * PLEASANT_FIG_SIZE[0], 0.65 * PLEASANT_FIG_SIZE[1])


class HPLoggingCallback(Callback):
    """Callback that sends the module's hyper-parameters to the trainer's logger."""

    def on_train_start(self, trainer, pl_module):
        """Convert ``pl_module.hyper_params`` to a dictionary and log it when training starts.

        Args:
            trainer: object exposing ``logger.log_hyperparams``.
            pl_module: module that must carry a ``hyper_params`` attribute accepted by ``dataclasses.asdict``.
        """
        has_hyper_params = hasattr(pl_module, "hyper_params")
        assert has_hyper_params, "The lightning module should have a hyper_params attribute for HP logging."

        hyper_params_as_dict = dataclasses.asdict(pl_module.hyper_params)
        trainer.logger.log_hyperparams(hyper_params_as_dict)


class TensorBoardDebuggingLoggingCallback(Callback):
"""Base class to log debugging information for plotting on TensorBoard."""

Expand Down
129 changes: 129 additions & 0 deletions crystal_diffusion/callbacks/sampling_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import logging
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, AnyStr, Dict, List

import numpy as np
import torch
from matplotlib import pyplot as plt
from pytorch_lightning import Callback, LightningModule, Trainer

from crystal_diffusion.analysis import PLEASANT_FIG_SIZE
from crystal_diffusion.loggers.logger_loader import log_figure
from crystal_diffusion.oracle.lammps import get_energy_and_forces_from_lammps
from crystal_diffusion.samplers.predictor_corrector_position_sampler import \
AnnealedLangevinDynamicsSampler
from crystal_diffusion.samplers.variance_sampler import NoiseParameters

logger = logging.getLogger(__name__)


@dataclass(kw_only=True)
class SamplingParameters:
    """Hyper-parameters for diffusion sampling.

    All fields are keyword-only (kw_only=True), which is what allows fields without a
    default value to be declared after fields that have one.
    """
    # The dimension of Euclidean space where atoms live.
    spatial_dimension: int = 3
    # Number of corrector steps handed to the sampler.
    number_of_corrector_steps: int = 1
    # The number of atoms that must be generated in a sampled configuration.
    number_of_atoms: int
    # How many samples are drawn each time sampling is triggered.
    number_of_samples: int
    # Sampling is expensive; control frequency.
    sample_every_n_epochs: int = 1
    # Unit cell dimensions; the unit cell is assumed to be orthogonal.
    cell_dimensions: List[float]


def instantiate_diffusion_sampling_callback(callback_params: Dict[AnyStr, Any],
                                            output_directory: str,
                                            verbose: bool) -> Dict[str, Callback]:
    """Build the diffusion sampling callback from its configuration.

    Args:
        callback_params: configuration with a 'noise' entry (keyword arguments for NoiseParameters)
            and a 'sampling' entry (keyword arguments for SamplingParameters).
        output_directory: base directory; the callback is given its 'energy_samples' sub-directory,
            which is created here if it does not exist.
        verbose: not used by this function.

    Returns:
        a dictionary with the single key 'diffusion_sampling' mapped to the new callback.
    """
    # Parameters are built first so that an invalid configuration fails before any directory is created.
    noise = NoiseParameters(**callback_params['noise'])
    sampling = SamplingParameters(**callback_params['sampling'])

    energy_samples_directory = os.path.join(output_directory, 'energy_samples')
    Path(energy_samples_directory).mkdir(parents=True, exist_ok=True)

    callback = DiffusionSamplingCallback(noise_parameters=noise,
                                         sampling_parameters=sampling,
                                         output_directory=energy_samples_directory)
    return {'diffusion_sampling': callback}


class DiffusionSamplingCallback(Callback):
    """Callback class to periodically generate samples and log their energies."""

    def __init__(self, noise_parameters: NoiseParameters,
                 sampling_parameters: SamplingParameters,
                 output_directory: str):
        """Init method.

        Args:
            noise_parameters: noise schedule parameters forwarded to the sampler.
            sampling_parameters: parameters controlling what is sampled and how often.
            output_directory: directory where the sampled energies are written.
        """
        super().__init__()
        self.noise_parameters = noise_parameters
        self.sampling_parameters = sampling_parameters
        self.output_directory = output_directory

    def _draw_sample_of_relative_positions(self, pl_model: LightningModule) -> np.ndarray:
        """Draw a sample from the generative model.

        Args:
            pl_model: the lightning module; its 'sigma_normalized_score_network' attribute drives the sampler.

        Returns:
            the samples as a numpy array on the CPU.
            NOTE(review): presumably of shape [number_of_samples, number_of_atoms, spatial_dimension] — confirm
            against the sampler.
        """
        logger.info("Creating sampler")
        sigma_normalized_score_network = pl_model.sigma_normalized_score_network

        sampler_parameters = dict(noise_parameters=self.noise_parameters,
                                  number_of_corrector_steps=self.sampling_parameters.number_of_corrector_steps,
                                  number_of_atoms=self.sampling_parameters.number_of_atoms,
                                  spatial_dimension=self.sampling_parameters.spatial_dimension)

        pc_sampler = AnnealedLangevinDynamicsSampler(sigma_normalized_score_network=sigma_normalized_score_network,
                                                     **sampler_parameters)
        logger.info("Draw samples")
        samples = pc_sampler.sample(self.sampling_parameters.number_of_samples)

        batch_relative_positions = samples.cpu().numpy()
        return batch_relative_positions

    @staticmethod
    def _plot_energy_histogram(sample_energies: np.ndarray) -> plt.Figure:
        """Generate a plot of the energy samples.

        Args:
            sample_energies: the energies to histogram.

        Returns:
            a new matplotlib figure; the caller is responsible for closing it.
        """
        fig = plt.figure(figsize=PLEASANT_FIG_SIZE)

        fig.suptitle('Sampling Energy Distributions')

        ax1 = fig.add_subplot(111)

        ax1.hist(sample_energies, density=True, bins=50, histtype="stepfilled", alpha=0.25, color='green')
        ax1.set_xlabel('Energy (eV)')
        ax1.set_ylabel('Density')
        fig.tight_layout()
        return fig

    def _compute_lammps_energies(self, batch_relative_positions: np.ndarray) -> np.ndarray:
        """Compute energies from samples.

        Args:
            batch_relative_positions: sampled relative positions; the last axis must match the number
                of cell dimensions since it is multiplied by the diagonal box matrix.

        Returns:
            a one-dimensional array with one energy per sample.
        """
        # The cell is orthogonal, so the box matrix is diagonal and relative positions scale axis by axis.
        box = np.diag(self.sampling_parameters.cell_dimensions)
        batch_positions = np.dot(batch_relative_positions, box)
        # Every atom is given the same type (1).
        atom_types = np.ones(self.sampling_parameters.number_of_atoms, dtype=int)

        list_energy = []

        logger.info("Compute energy from Oracle")

        # A single scratch directory is shared by all oracle calls and removed on exit.
        with tempfile.TemporaryDirectory() as tmp_work_dir:
            for positions in batch_positions:
                # Only the energy is needed here; the forces are discarded.
                energy, _forces = get_energy_and_forces_from_lammps(positions,
                                                                    box,
                                                                    atom_types,
                                                                    tmp_work_dir=tmp_work_dir)
                list_energy.append(energy)

        return np.array(list_energy)

    def on_validation_epoch_end(self, trainer: Trainer, pl_model: LightningModule) -> None:
        """On validation epoch end.

        Every 'sample_every_n_epochs' epochs: draw samples, compute their energies, save the energies
        to '<output_directory>/energies_sample_epoch=<epoch>.pt' and log an energy histogram to every logger.

        Args:
            trainer: the trainer; provides the current epoch, the global step and the loggers.
            pl_model: the lightning module that is being validated.
        """
        if trainer.current_epoch % self.sampling_parameters.sample_every_n_epochs != 0:
            return

        batch_relative_positions = self._draw_sample_of_relative_positions(pl_model)
        sample_energies = self._compute_lammps_energies(batch_relative_positions)

        output_path = os.path.join(self.output_directory, f"energies_sample_epoch={trainer.current_epoch}.pt")
        torch.save(torch.from_numpy(sample_energies), output_path)

        fig = self._plot_energy_histogram(sample_energies)
        try:
            for pl_logger in trainer.loggers:
                log_figure(figure=fig, global_step=trainer.global_step, pl_logger=pl_logger)
        finally:
            # Figures created through pyplot stay registered until explicitly closed. Without this,
            # one figure would accumulate per sampling epoch whenever no logger closes it.
            # Closing an already closed figure is harmless.
            plt.close(fig)
22 changes: 20 additions & 2 deletions crystal_diffusion/loggers/logger_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import orion
import yaml
from matplotlib import pyplot as plt
from pytorch_lightning.loggers import (CometLogger, CSVLogger, Logger,
TensorBoardLogger)

Expand Down Expand Up @@ -52,11 +53,11 @@ def create_all_loggers(hyper_params: Dict[AnyStr, Any], output_directory: str) -
match logger_name:
case "csv":
logger = CSVLogger(save_dir=output_directory,
name=full_run_name)
name="csv_logs")
case "tensorboard":
logger = TensorBoardLogger(save_dir=output_directory,
default_hp_metric=False,
name=full_run_name,
name="tensorboard_logs",
version=0, # Necessary to resume tensorboard logging
)
case "comet":
Expand Down Expand Up @@ -131,3 +132,20 @@ def read_and_validate_comet_experiment_key(full_run_name: str, output_directory:
experiment_key = data[full_run_name]

return experiment_key


def log_figure(figure: plt.Figure, global_step: int, pl_logger: Logger) -> None:
    """Log a matplotlib figure to a logger that supports figures.

    Comet and TensorBoard loggers are handled; any other logger type is ignored.

    Args:
        figure : a matplotlib figure.
        global_step: current step index; only forwarded to TensorBoard.
        pl_logger : a pytorch lightning Logger.

    Returns:
        No return
    """
    # isinstance (rather than an exact type comparison) so that subclasses of the supported
    # loggers are handled as well.
    if isinstance(pl_logger, CometLogger):
        # The figure is passed by keyword: the first positional parameter of Comet's log_figure is the
        # figure *name*, so a positional figure would be taken as the name and the global pyplot
        # figure would be logged instead.
        pl_logger.experiment.log_figure(figure=figure)
    elif isinstance(pl_logger, TensorBoardLogger):
        pl_logger.experiment.add_figure("train/samples", figure, global_step=global_step)
38 changes: 32 additions & 6 deletions crystal_diffusion/models/position_diffusion_lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ def configure_optimizers(self):
"""
return load_optimizer(self.hyper_params.optimizer_parameters, self)

@staticmethod
def _get_batch_size(batch: torch.Tensor) -> int:
"""Get batch size.
Args:
batch : a dictionary that should contain a data sample.
Returns:
batch_size: the size of the batch.
"""
# The relative positions have dimensions [batch_size, number_of_atoms, spatial_dimension].
assert "relative_positions" in batch, "The field 'relative_positions' is missing from the input."
batch_size = batch["relative_positions"].shape[0]
return batch_size

def _generic_step(
self,
batch: typing.Any,
Expand Down Expand Up @@ -112,7 +127,7 @@ def _generic_step(
f"the shape of the relative_positions array should be [batch_size, number_of_atoms, spatial_dimensions]. "
f"Got shape = {shape}."
)
batch_size = shape[0]
batch_size = self._get_batch_size(batch)

noise_sample = self.variance_sampler.get_random_noise_sample(batch_size)

Expand Down Expand Up @@ -205,21 +220,32 @@ def training_step(self, batch, batch_idx):
"""Runs a prediction step for training, returning the loss."""
output = self._generic_step(batch, batch_idx)
loss = output["loss"]
self.log("train_loss", loss, prog_bar=True)
self.log("epoch", self.current_epoch)
self.log("step", self.global_step)

batch_size = self._get_batch_size(batch)

# The 'train_step_loss' is only logged on_step, meaning it is a value for each batch
self.log("train_step_loss", loss, on_step=True, on_epoch=False, prog_bar=True)

# The 'train_epoch_loss' is aggregated (batch_size weighted average) and logged once per epoch.
self.log("train_epoch_loss", loss, batch_size=batch_size, on_step=False, on_epoch=True)
return output

def validation_step(self, batch, batch_idx):
    """Runs a prediction step for validation, logging the loss.

    Args:
        batch: a data batch; it must contain the 'relative_positions' field so its size can be read.
        batch_idx: index of the batch, forwarded to the generic step.

    Returns:
        the output of the generic step, with the loss under the 'loss' key.
    """
    output = self._generic_step(batch, batch_idx)
    loss = output["loss"]
    # NOTE(review): this 'val_loss' call looks like the pre-change line of this diff hunk — confirm
    # whether it still exists in the final file.
    self.log("val_loss", loss, prog_bar=True)
    batch_size = self._get_batch_size(batch)

    # The 'validation_epoch_loss' is aggregated (batch_size weighted average) and logged once per epoch.
    self.log("validation_epoch_loss", loss,
             batch_size=batch_size, on_step=False, on_epoch=True, prog_bar=True)
    return output

def test_step(self, batch, batch_idx):
    """Runs a prediction step for testing, logging the loss.

    Args:
        batch: a data batch; it must contain the 'relative_positions' field so its size can be read.
        batch_idx: index of the batch, forwarded to the generic step.

    Returns:
        the output of the generic step, with the loss under the 'loss' key.
    """
    output = self._generic_step(batch, batch_idx)
    loss = output["loss"]
    # NOTE(review): this 'test_loss' call looks like the pre-change line of this diff hunk — confirm
    # whether it still exists in the final file.
    self.log("test_loss", loss)
    batch_size = self._get_batch_size(batch)
    # The 'test_epoch_loss' is aggregated (batch_size weighted average) and logged once per epoch.
    self.log("test_epoch_loss", loss, batch_size=batch_size, on_step=False, on_epoch=True)
    return output
25 changes: 20 additions & 5 deletions examples/local/diffusion/config_diffusion.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
exp_name: example_experiment
run_name: run_debug_delete_me
max_epoch: 10
log_every_n_steps: 1
log_every_n_steps: 100

# set to null to avoid setting a seed (can speed up GPU computation, but
# results will not be reproducible)
seed: 1234

# data
data:
batch_size: 32
num_workers: 4
batch_size: 128
num_workers: 0
max_atom: 64

# architecture
Expand All @@ -31,14 +31,29 @@ optimizer:

# early stopping
early_stopping:
metric: val_loss
metric: validation_epoch_loss
mode: min
patience: 100

model_checkpoint:
monitor: val_loss
monitor: validation_epoch_loss
mode: min

# Sampling from the generative model
diffusion_sampling:
noise:
total_time_steps: 10
sigma_min: 0.005 # default value
sigma_max: 0.5 # default value
sampling:
spatial_dimension: 3
number_of_corrector_steps: 1
number_of_atoms: 64
number_of_samples: 128
sample_every_n_epochs: 1
cell_dimensions: [10.86, 10.86, 10.86]

logging:
- csv
- tensorboard
- comet
20 changes: 16 additions & 4 deletions sanity_checks/overfit_fake_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@
- the trained score network should reproduce the perturbation kernel, at least in the regions where it is sampled.
- the generated samples should be tightly clustered around x0.
"""
import dataclasses
import os

import pytorch_lightning
import torch
from pytorch_lightning import Trainer
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader

from crystal_diffusion.callbacks.callbacks import (
HPLoggingCallback, TensorboardGeneratedSamplesLoggingCallback,
TensorboardGeneratedSamplesLoggingCallback,
TensorboardHistogramLoggingCallback, TensorboardSamplesLoggingCallback,
TensorboardScoreAndErrorLoggingCallback)
from crystal_diffusion.models.optimizer import (OptimizerParameters,
Expand All @@ -28,6 +29,19 @@
from crystal_diffusion.samplers.variance_sampler import NoiseParameters
from sanity_checks import SANITY_CHECK_FOLDER


class HPLoggingCallback(Callback):
    """Callback that sends the module's hyper-parameters to the trainer's logger."""

    def on_train_start(self, trainer, pl_module):
        """Convert ``pl_module.hyper_params`` to a dictionary and log it when training starts.

        Args:
            trainer: object exposing ``logger.log_hyperparams``.
            pl_module: module that must carry a ``hyper_params`` attribute accepted by ``dataclasses.asdict``.
        """
        has_hyper_params = hasattr(pl_module, "hyper_params")
        assert has_hyper_params, "The lightning module should have a hyper_params attribute for HP logging."

        hyper_params_as_dict = dataclasses.asdict(pl_module.hyper_params)
        trainer.logger.log_hyperparams(hyper_params_as_dict)


batch_size = 4096
number_of_atoms = 1
spatial_dimension = 1
Expand Down Expand Up @@ -69,11 +83,9 @@

score_error_callback = TensorboardScoreAndErrorLoggingCallback(x0=x0)


tbx_logger = TensorBoardLogger(save_dir=os.path.join(SANITY_CHECK_FOLDER, "tensorboard"), name="overfit_fake_data")

if __name__ == '__main__':

pytorch_lightning.seed_everything(123)
all_positions = x0 * torch.ones(batch_size, number_of_atoms, spatial_dimension)
data = [dict(relative_positions=configuration) for configuration in all_positions]
Expand Down
Loading

0 comments on commit f6694aa

Please sign in to comment.