diff --git a/crystal_diffusion/callbacks/callback_loader.py b/crystal_diffusion/callbacks/callback_loader.py
index 629cf026..bf7830ab 100644
--- a/crystal_diffusion/callbacks/callback_loader.py
+++ b/crystal_diffusion/callbacks/callback_loader.py
@@ -2,12 +2,15 @@
 from pytorch_lightning import Callback
 
+from crystal_diffusion.callbacks.sampling_callback import \
+    instantiate_diffusion_sampling_callback
 from crystal_diffusion.callbacks.standard_callbacks import (
     CustomProgressBar, instantiate_early_stopping_callback,
     instantiate_model_checkpoint_callbacks)
 
 OPTIONAL_CALLBACK_DICTIONARY = dict(early_stopping=instantiate_early_stopping_callback,
-                                    model_checkpoint=instantiate_model_checkpoint_callbacks)
+                                    model_checkpoint=instantiate_model_checkpoint_callbacks,
+                                    diffusion_sampling=instantiate_diffusion_sampling_callback)
 
 
 def create_all_callbacks(hyper_params: Dict[AnyStr, Any], output_directory: str, verbose: bool) -> Dict[str, Callback]:
diff --git a/crystal_diffusion/callbacks/callbacks.py b/crystal_diffusion/callbacks/callbacks.py
index 3cec517c..16694b6f 100644
--- a/crystal_diffusion/callbacks/callbacks.py
+++ b/crystal_diffusion/callbacks/callbacks.py
@@ -1,5 +1,3 @@
-import dataclasses
-
 import matplotlib.pyplot as plt
 import torch
 from pytorch_lightning import Callback
@@ -19,18 +17,6 @@
 TENSORBOARD_FIGSIZE = (0.65 * PLEASANT_FIG_SIZE[0],
                        0.65 * PLEASANT_FIG_SIZE[1])
 
 
-class HPLoggingCallback(Callback):
-    """This callback is responsible for logging hyperparameters."""
-
-    def on_train_start(self, trainer, pl_module):
-        """Log hyperparameters when training starts."""
-        assert hasattr(
-            pl_module, "hyper_params"
-        ), "The lightning module should have a hyper_params attribute for HP logging."
-        hp_dict = dataclasses.asdict(pl_module.hyper_params)
-        trainer.logger.log_hyperparams(hp_dict)
-
-
 class TensorBoardDebuggingLoggingCallback(Callback):
     """Base class to log debugging information for plotting on TensorBoard."""
diff --git a/crystal_diffusion/callbacks/sampling_callback.py b/crystal_diffusion/callbacks/sampling_callback.py
new file mode 100644
index 00000000..74b611ae
--- /dev/null
+++ b/crystal_diffusion/callbacks/sampling_callback.py
@@ -0,0 +1,129 @@
+import logging
+import os
+import tempfile
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, AnyStr, Dict, List
+
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from pytorch_lightning import Callback, LightningModule, Trainer
+
+from crystal_diffusion.analysis import PLEASANT_FIG_SIZE
+from crystal_diffusion.loggers.logger_loader import log_figure
+from crystal_diffusion.oracle.lammps import get_energy_and_forces_from_lammps
+from crystal_diffusion.samplers.predictor_corrector_position_sampler import \
+    AnnealedLangevinDynamicsSampler
+from crystal_diffusion.samplers.variance_sampler import NoiseParameters
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(kw_only=True)
+class SamplingParameters:
+    """Hyper-parameters for diffusion sampling."""
+    spatial_dimension: int = 3  # the dimension of the Euclidean space where atoms live.
+    number_of_corrector_steps: int = 1
+    number_of_atoms: int  # the number of atoms that must be generated in a sampled configuration.
+    number_of_samples: int
+    sample_every_n_epochs: int = 1  # Sampling is expensive; this controls how often it is done.
+    cell_dimensions: List[float]  # unit cell dimensions; the unit cell is assumed to be orthogonal.
+ + +def instantiate_diffusion_sampling_callback(callback_params: Dict[AnyStr, Any], + output_directory: str, + verbose: bool) -> Dict[str, Callback]: + """Instantiate the Diffusion Sampling callback.""" + noise_parameters = NoiseParameters(**callback_params['noise']) + sampling_parameters = SamplingParameters(**callback_params['sampling']) + + sample_output_directory = os.path.join(output_directory, 'energy_samples') + Path(sample_output_directory).mkdir(parents=True, exist_ok=True) + + diffusion_sampling_callback = DiffusionSamplingCallback(noise_parameters=noise_parameters, + sampling_parameters=sampling_parameters, + output_directory=sample_output_directory) + + return dict(diffusion_sampling=diffusion_sampling_callback) + + +class DiffusionSamplingCallback(Callback): + """Callback class to periodically generate samples and log their energies.""" + + def __init__(self, noise_parameters: NoiseParameters, + sampling_parameters: SamplingParameters, + output_directory: str): + """Init method.""" + self.noise_parameters = noise_parameters + self.sampling_parameters = sampling_parameters + self.output_directory = output_directory + + def _draw_sample_of_relative_positions(self, pl_model: LightningModule) -> np.ndarray: + """Draw a sample from the generative model.""" + logger.info("Creating sampler") + sigma_normalized_score_network = pl_model.sigma_normalized_score_network + + sampler_parameters = dict(noise_parameters=self.noise_parameters, + number_of_corrector_steps=self.sampling_parameters.number_of_corrector_steps, + number_of_atoms=self.sampling_parameters.number_of_atoms, + spatial_dimension=self.sampling_parameters.spatial_dimension) + + pc_sampler = AnnealedLangevinDynamicsSampler(sigma_normalized_score_network=sigma_normalized_score_network, + **sampler_parameters) + logger.info("Draw samples") + samples = pc_sampler.sample(self.sampling_parameters.number_of_samples) + + batch_relative_positions = samples.cpu().numpy() + return batch_relative_positions + + @staticmethod + def _plot_energy_histogram(sample_energies: np.ndarray) -> plt.figure: + """Generate a plot of the energy samples.""" + fig = plt.figure(figsize=PLEASANT_FIG_SIZE) + + fig.suptitle('Sampling Energy Distributions') + + ax1 = fig.add_subplot(111) + + ax1.hist(sample_energies, density=True, bins=50, histtype="stepfilled", alpha=0.25, color='green') + ax1.set_xlabel('Energy (eV)') + ax1.set_ylabel('Density') + fig.tight_layout() + return fig + + def _compute_lammps_energies(self, batch_relative_positions: np.ndarray) -> np.ndarray: + """Compute energies from samples.""" + box = np.diag(self.sampling_parameters.cell_dimensions) + batch_positions = np.dot(batch_relative_positions, box) + atom_types = np.ones(self.sampling_parameters.number_of_atoms, dtype=int) + + list_energy = [] + + logger.info("Compute energy from Oracle") + + with tempfile.TemporaryDirectory() as tmp_work_dir: + for idx, positions in enumerate(batch_positions): + energy, forces = get_energy_and_forces_from_lammps(positions, + box, + atom_types, + tmp_work_dir=tmp_work_dir) + list_energy.append(energy) + + return np.array(list_energy) + + def on_validation_epoch_end(self, trainer: Trainer, pl_model: LightningModule) -> None: + """On validation epoch end.""" + if trainer.current_epoch % self.sampling_parameters.sample_every_n_epochs != 0: + return + + batch_relative_positions = self._draw_sample_of_relative_positions(pl_model) + sample_energies = self._compute_lammps_energies(batch_relative_positions) + + output_path = 
os.path.join(self.output_directory, f"energies_sample_epoch={trainer.current_epoch}.pt") + torch.save(torch.from_numpy(sample_energies), output_path) + + fig = self._plot_energy_histogram(sample_energies) + + for pl_logger in trainer.loggers: + log_figure(figure=fig, global_step=trainer.global_step, pl_logger=pl_logger) diff --git a/crystal_diffusion/loggers/logger_loader.py b/crystal_diffusion/loggers/logger_loader.py index 25903854..d1f6b514 100644 --- a/crystal_diffusion/loggers/logger_loader.py +++ b/crystal_diffusion/loggers/logger_loader.py @@ -4,6 +4,7 @@ import orion import yaml +from matplotlib import pyplot as plt from pytorch_lightning.loggers import (CometLogger, CSVLogger, Logger, TensorBoardLogger) @@ -52,11 +53,11 @@ def create_all_loggers(hyper_params: Dict[AnyStr, Any], output_directory: str) - match logger_name: case "csv": logger = CSVLogger(save_dir=output_directory, - name=full_run_name) + name="csv_logs") case "tensorboard": logger = TensorBoardLogger(save_dir=output_directory, default_hp_metric=False, - name=full_run_name, + name="tensorboard_logs", version=0, # Necessary to resume tensorboard logging ) case "comet": @@ -131,3 +132,20 @@ def read_and_validate_comet_experiment_key(full_run_name: str, output_directory: experiment_key = data[full_run_name] return experiment_key + + +def log_figure(figure: plt.figure, global_step: int, pl_logger: Logger) -> None: + """Log figure. + + Args: + figure : a matplotlib figure. + global_step: current step index. + pl_logger : a pytorch lightning Logger. + + Returns: + No return + """ + if type(pl_logger) is CometLogger: + pl_logger.experiment.log_figure(figure) + elif type(pl_logger) is TensorBoardLogger: + pl_logger.experiment.add_figure("train/samples", figure, global_step=global_step) diff --git a/crystal_diffusion/models/position_diffusion_lightning_model.py b/crystal_diffusion/models/position_diffusion_lightning_model.py index 3ca35895..17427265 100644 --- a/crystal_diffusion/models/position_diffusion_lightning_model.py +++ b/crystal_diffusion/models/position_diffusion_lightning_model.py @@ -69,6 +69,21 @@ def configure_optimizers(self): """ return load_optimizer(self.hyper_params.optimizer_parameters, self) + @staticmethod + def _get_batch_size(batch: torch.Tensor) -> int: + """Get batch size. + + Args: + batch : a dictionary that should contain a data sample. + + Returns: + batch_size: the size of the batch. + """ + # The relative positions have dimensions [batch_size, number_of_atoms, spatial_dimension]. + assert "relative_positions" in batch, "The field 'relative_positions' is missing from the input." + batch_size = batch["relative_positions"].shape[0] + return batch_size + def _generic_step( self, batch: typing.Any, @@ -112,7 +127,7 @@ def _generic_step( f"the shape of the relative_positions array should be [batch_size, number_of_atoms, spatial_dimensions]. " f"Got shape = {shape}." 
) - batch_size = shape[0] + batch_size = self._get_batch_size(batch) noise_sample = self.variance_sampler.get_random_noise_sample(batch_size) @@ -205,21 +220,32 @@ def training_step(self, batch, batch_idx): """Runs a prediction step for training, returning the loss.""" output = self._generic_step(batch, batch_idx) loss = output["loss"] - self.log("train_loss", loss, prog_bar=True) - self.log("epoch", self.current_epoch) - self.log("step", self.global_step) + + batch_size = self._get_batch_size(batch) + + # The 'train_step_loss' is only logged on_step, meaning it is a value for each batch + self.log("train_step_loss", loss, on_step=True, on_epoch=False, prog_bar=True) + + # The 'train_epoch_loss' is aggregated (batch_size weighted average) and logged once per epoch. + self.log("train_epoch_loss", loss, batch_size=batch_size, on_step=False, on_epoch=True) return output def validation_step(self, batch, batch_idx): """Runs a prediction step for validation, logging the loss.""" output = self._generic_step(batch, batch_idx) loss = output["loss"] - self.log("val_loss", loss, prog_bar=True) + batch_size = self._get_batch_size(batch) + + # The 'validation_epoch_loss' is aggregated (batch_size weighted average) and logged once per epoch. + self.log("validation_epoch_loss", loss, + batch_size=batch_size, on_step=False, on_epoch=True, prog_bar=True) return output def test_step(self, batch, batch_idx): """Runs a prediction step for testing, logging the loss.""" output = self._generic_step(batch, batch_idx) loss = output["loss"] - self.log("test_loss", loss) + batch_size = self._get_batch_size(batch) + # The 'test_epoch_loss' is aggregated (batch_size weighted average) and logged once per epoch. + self.log("test_epoch_loss", loss, batch_size=batch_size, on_step=False, on_epoch=True) return output diff --git a/examples/local/diffusion/config_diffusion.yaml b/examples/local/diffusion/config_diffusion.yaml index 88c6721b..60c9d47b 100644 --- a/examples/local/diffusion/config_diffusion.yaml +++ b/examples/local/diffusion/config_diffusion.yaml @@ -2,7 +2,7 @@ exp_name: example_experiment run_name: run_debug_delete_me max_epoch: 10 -log_every_n_steps: 1 +log_every_n_steps: 100 # set to null to avoid setting a seed (can speed up GPU computation, but # results will not be reproducible) @@ -10,8 +10,8 @@ seed: 1234 # data data: - batch_size: 32 - num_workers: 4 + batch_size: 128 + num_workers: 0 max_atom: 64 # architecture @@ -31,14 +31,29 @@ optimizer: # early stopping early_stopping: - metric: val_loss + metric: validation_epoch_loss mode: min patience: 100 model_checkpoint: - monitor: val_loss + monitor: validation_epoch_loss mode: min +# Sampling from the generative model +diffusion_sampling: + noise: + total_time_steps: 10 + sigma_min: 0.005 # default value + sigma_max: 0.5 # default value + sampling: + spatial_dimension: 3 + number_of_corrector_steps: 1 + number_of_atoms: 64 + number_of_samples: 128 + sample_every_n_epochs: 1 + cell_dimensions: [10.86, 10.86, 10.86] + logging: + - csv - tensorboard - comet diff --git a/sanity_checks/overfit_fake_data.py b/sanity_checks/overfit_fake_data.py index 52e39404..65424f7e 100644 --- a/sanity_checks/overfit_fake_data.py +++ b/sanity_checks/overfit_fake_data.py @@ -7,17 +7,18 @@ - the trained score network should reproduce the perturbation kernel, at least in the regions where it is sampled. - the generated samples should be tightly clustered around x0. 
""" +import dataclasses import os import pytorch_lightning import torch -from pytorch_lightning import Trainer +from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import LearningRateMonitor from pytorch_lightning.loggers import TensorBoardLogger from torch.utils.data import DataLoader from crystal_diffusion.callbacks.callbacks import ( - HPLoggingCallback, TensorboardGeneratedSamplesLoggingCallback, + TensorboardGeneratedSamplesLoggingCallback, TensorboardHistogramLoggingCallback, TensorboardSamplesLoggingCallback, TensorboardScoreAndErrorLoggingCallback) from crystal_diffusion.models.optimizer import (OptimizerParameters, @@ -28,6 +29,19 @@ from crystal_diffusion.samplers.variance_sampler import NoiseParameters from sanity_checks import SANITY_CHECK_FOLDER + +class HPLoggingCallback(Callback): + """This callback is responsible for logging hyperparameters.""" + + def on_train_start(self, trainer, pl_module): + """Log hyperparameters when training starts.""" + assert hasattr( + pl_module, "hyper_params" + ), "The lightning module should have a hyper_params attribute for HP logging." + hp_dict = dataclasses.asdict(pl_module.hyper_params) + trainer.logger.log_hyperparams(hp_dict) + + batch_size = 4096 number_of_atoms = 1 spatial_dimension = 1 @@ -69,11 +83,9 @@ score_error_callback = TensorboardScoreAndErrorLoggingCallback(x0=x0) - tbx_logger = TensorBoardLogger(save_dir=os.path.join(SANITY_CHECK_FOLDER, "tensorboard"), name="overfit_fake_data") if __name__ == '__main__': - pytorch_lightning.seed_everything(123) all_positions = x0 * torch.ones(batch_size, number_of_atoms, spatial_dimension) data = [dict(relative_positions=configuration) for configuration in all_positions] diff --git a/tests/test_train_diffusion.py b/tests/test_train_diffusion.py index 5e644bc9..2d286c75 100644 --- a/tests/test_train_diffusion.py +++ b/tests/test_train_diffusion.py @@ -68,8 +68,16 @@ def get_config(number_of_atoms: int, max_epoch: int): optimizer_config = dict(name='adam', learning_rate=0.001) - early_stopping_config = dict(metric='val_loss', mode='min', patience=max_epoch) - model_checkpoint_config = dict(monitor='val_loss', mode='min') + sampling_dict = {'spatial_dimension': 3, + 'number_of_corrector_steps': 1, + 'number_of_atoms': number_of_atoms, + 'number_of_samples': 4, + 'sample_every_n_epochs': 1, + 'cell_dimensions': [1.23, 4.56, 7.89]} + + early_stopping_config = dict(metric='validation_epoch_loss', mode='min', patience=max_epoch) + model_checkpoint_config = dict(monitor='validation_epoch_loss', mode='min') + diffusion_sampling_config = dict(noise={'total_time_steps': 10}, sampling=sampling_dict) config = dict(max_epoch=max_epoch, exp_name='smoke_test', @@ -80,6 +88,7 @@ def get_config(number_of_atoms: int, max_epoch: int): optimizer=optimizer_config, early_stopping=early_stopping_config, model_checkpoint=model_checkpoint_config, + diffusion_sampling=diffusion_sampling_config, logging=['csv', 'tensorboard']) return config