diff --git a/crystal_diffusion/analysis/analytic_score/perfect_score_ode_sampling.py b/crystal_diffusion/analysis/analytic_score/perfect_score_ode_sampling.py new file mode 100644 index 00000000..eeb000fc --- /dev/null +++ b/crystal_diffusion/analysis/analytic_score/perfect_score_ode_sampling.py @@ -0,0 +1,175 @@ +"""Perfect Score ODE sampling. + +This little ad hoc experiment explores sampling with an ODE solver, using the 'analytic' score. +It works very well! +""" + +import logging + +import matplotlib.pyplot as plt +import torch +from einops import einops + +from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH +from crystal_diffusion.analysis.analytic_score.utils import (get_exact_samples, + get_unit_cells) +from crystal_diffusion.generators.ode_position_generator import \ + ExplodingVarianceODEPositionGenerator +from crystal_diffusion.models.score_networks.analytical_score_network import ( + AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters) +from crystal_diffusion.samplers.variance_sampler import NoiseParameters + +logger = logging.getLogger(__name__) + +plt.style.use(PLOT_STYLE_PATH) + +if torch.cuda.is_available(): + device = torch.device('cuda') +else: + device = torch.device('cpu') + +spatial_dimension = 3 +number_of_atoms = 2 +kmax = 1 +spring_constant = 1000. +batch_size = 1000 +total_time_steps = 41 +if __name__ == '__main__': + + noise_parameters = NoiseParameters(total_time_steps=total_time_steps, + sigma_min=0.001, + sigma_max=0.5) + + equilibrium_relative_coordinates = torch.stack([0.25 * torch.ones(spatial_dimension), + 0.75 * torch.ones(spatial_dimension)]).to(device) + inverse_covariance = torch.zeros(number_of_atoms, spatial_dimension, number_of_atoms, spatial_dimension).to(device) + for atom_i in range(number_of_atoms): + for alpha in range(spatial_dimension): + inverse_covariance[atom_i, alpha, atom_i, alpha] = spring_constant + + score_network_parameters = AnalyticalScoreNetworkParameters( + number_of_atoms=number_of_atoms, + spatial_dimension=spatial_dimension, + kmax=kmax, + equilibrium_relative_coordinates=equilibrium_relative_coordinates, + inverse_covariance=inverse_covariance) + + sigma_normalized_score_network = AnalyticalScoreNetwork(score_network_parameters) + + position_generator = ExplodingVarianceODEPositionGenerator(noise_parameters, + number_of_atoms, + spatial_dimension, + sigma_normalized_score_network, + record_samples=True) + + times = torch.linspace(0, 1, 1001) + sigmas = position_generator._get_exploding_variance_sigma(times) + ode_prefactor = position_generator._get_ode_prefactor(sigmas) + + fig0 = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig0.suptitle('ODE parameters') + + ax1 = fig0.add_subplot(121) + ax2 = fig0.add_subplot(122) + ax1.set_title('$\\sigma$ Parameter') + ax2.set_title('$\\gamma$ Parameter') + ax1.plot(times, sigmas, '-') + ax2.plot(times, ode_prefactor, '-') + + ax1.set_ylabel('$\\sigma(t)$') + ax2.set_ylabel('$\\gamma(t)$') + for ax in [ax1, ax2]: + ax.set_xlabel('Diffusion Time') + ax.set_xlim([-0.01, 1.01]) + + fig0.tight_layout() + plt.show() + + unit_cell = get_unit_cells(acell=1., spatial_dimension=spatial_dimension, number_of_samples=batch_size).to(device) + relative_coordinates = position_generator.sample(number_of_samples=batch_size, device=device, unit_cell=unit_cell) + + batch_times = position_generator.sample_trajectory_recorder.data['time'][0] + batch_relative_coordinates = position_generator.sample_trajectory_recorder.data['relative_coordinates'][0] + batch_flat_relative_coordinates = 
einops.rearrange(batch_relative_coordinates, "b t n d -> b t (n d)") + + fig1 = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig1.suptitle('ODE Trajectories') + + ax = fig1.add_subplot(111) + ax.set_xlabel('Diffusion Time') + ax.set_ylabel('Raw Relative Coordinate') + ax.yaxis.tick_right() + ax.spines['top'].set_visible(True) + ax.spines['right'].set_visible(True) + ax.spines['bottom'].set_visible(True) + ax.spines['left'].set_visible(True) + + time = batch_times[0] # all time arrays are the same + for flat_relative_coordinates in batch_flat_relative_coordinates[::20]: + for i in range(number_of_atoms * spatial_dimension): + coordinate = flat_relative_coordinates[:, i] + ax.plot(time.cpu(), coordinate.cpu(), '-', color='b', alpha=0.05) + + ax.set_xlim([1.01, -0.01]) + plt.show() + + exact_samples = get_exact_samples(equilibrium_relative_coordinates, + inverse_covariance, + batch_size).cpu() + + fig2 = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig2.suptitle('Comparing ODE and Expected Marginal Distributions') + ax1 = fig2.add_subplot(131, aspect='equal') + ax2 = fig2.add_subplot(132, aspect='equal') + ax3 = fig2.add_subplot(133, aspect='equal') + + xs = einops.rearrange(relative_coordinates, 'b n d -> (b n) d').cpu() + zs = einops.rearrange(exact_samples, 'b n d -> (b n) d').cpu() + ax1.set_title('XY Projection') + ax1.plot(xs[:, 0], xs[:, 1], 'ro', alpha=0.5, mew=0, label='ODE Solver') + ax1.plot(zs[:, 0], zs[:, 1], 'go', alpha=0.05, mew=0, label='Exact Samples') + + ax2.set_title('XZ Projection') + ax2.plot(xs[:, 0], xs[:, 2], 'ro', alpha=0.5, mew=0, label='ODE Solver') + ax2.plot(zs[:, 0], zs[:, 2], 'go', alpha=0.05, mew=0, label='Exact Samples') + + ax3.set_title('YZ Projection') + ax3.plot(xs[:, 1], xs[:, 2], 'ro', alpha=0.5, mew=0, label='ODE Solver') + ax3.plot(zs[:, 1], zs[:, 2], 'go', alpha=0.05, mew=0, label='Exact Samples') + + for ax in [ax1, ax2, ax3]: + ax.set_xlim(-0.01, 1.01) + ax.set_ylim(-0.01, 1.01) + ax.vlines(x=[0, 1], ymin=0, ymax=1, color='k', lw=2) + ax.hlines(y=[0, 1], xmin=0, xmax=1, color='k', lw=2) + + ax1.legend(loc=0) + fig2.tight_layout() + plt.show() + + fig3 = plt.figure(figsize=PLEASANT_FIG_SIZE) + ax1 = fig3.add_subplot(131) + ax2 = fig3.add_subplot(132) + ax3 = fig3.add_subplot(133) + fig3.suptitle("Marginal Distributions of t=0 Samples") + + common_params = dict(histtype='stepfilled', alpha=0.5, bins=50) + + ax1.hist(xs[:, 0], **common_params, facecolor='r', label='ODE solver') + ax2.hist(xs[:, 1], **common_params, facecolor='r', label='ODE solver') + ax3.hist(xs[:, 2], **common_params, facecolor='r', label='ODE solver') + + ax1.hist(zs[:, 0], **common_params, facecolor='g', label='Exact') + ax2.hist(zs[:, 1], **common_params, facecolor='g', label='Exact') + ax3.hist(zs[:, 2], **common_params, facecolor='g', label='Exact') + + ax1.set_xlabel('X') + ax2.set_xlabel('Y') + ax3.set_xlabel('Z') + + for ax in [ax1, ax2, ax3]: + ax.set_xlim(-0.01, 1.01) + + ax1.legend(loc=0) + fig3.tight_layout() + plt.show() diff --git a/crystal_diffusion/analysis/analytic_score/sample_quality_with_energy_analysis.py b/crystal_diffusion/analysis/analytic_score/sample_quality_with_energy_analysis.py index 666d066a..f8a0444b 100644 --- a/crystal_diffusion/analysis/analytic_score/sample_quality_with_energy_analysis.py +++ b/crystal_diffusion/analysis/analytic_score/sample_quality_with_energy_analysis.py @@ -12,10 +12,10 @@ from crystal_diffusion.analysis.analytic_score.utils import ( get_exact_samples, get_random_equilibrium_relative_coordinates, get_random_inverse_covariance,
get_samples_harmonic_energy, get_unit_cells) +from crystal_diffusion.generators.predictor_corrector_position_generator import \ + AnnealedLangevinDynamicsGenerator from crystal_diffusion.models.score_networks.analytical_score_network import ( AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters) -from crystal_diffusion.samplers.predictor_corrector_position_sampler import \ - AnnealedLangevinDynamicsSampler from crystal_diffusion.samplers.variance_sampler import NoiseParameters logger = logging.getLogger(__name__) @@ -64,8 +64,8 @@ record_samples=False, positions_require_grad=False) - pc_sampler = AnnealedLangevinDynamicsSampler(sigma_normalized_score_network=sigma_normalized_score_network, - **sampler_parameters) + pc_sampler = AnnealedLangevinDynamicsGenerator(sigma_normalized_score_network=sigma_normalized_score_network, + **sampler_parameters) unit_cell = get_unit_cells(acell=1., spatial_dimension=spatial_dimension, number_of_samples=number_of_samples) diff --git a/crystal_diffusion/analysis/analytic_score/sample_visualization_analysis.py b/crystal_diffusion/analysis/analytic_score/sample_visualization_analysis.py index 5265ca9d..373d371c 100644 --- a/crystal_diffusion/analysis/analytic_score/sample_visualization_analysis.py +++ b/crystal_diffusion/analysis/analytic_score/sample_visualization_analysis.py @@ -9,10 +9,10 @@ from crystal_diffusion.analysis.analytic_score.utils import ( get_exact_samples, get_random_equilibrium_relative_coordinates, get_random_inverse_covariance, get_unit_cells) +from crystal_diffusion.generators.predictor_corrector_position_generator import \ + AnnealedLangevinDynamicsGenerator from crystal_diffusion.models.score_networks.analytical_score_network import ( AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters) -from crystal_diffusion.samplers.predictor_corrector_position_sampler import \ - AnnealedLangevinDynamicsSampler from crystal_diffusion.samplers.variance_sampler import NoiseParameters logger = logging.getLogger(__name__) @@ -60,8 +60,8 @@ record_samples=False, positions_require_grad=False) - pc_sampler = AnnealedLangevinDynamicsSampler(sigma_normalized_score_network=sigma_normalized_score_network, - **sampler_parameters) + pc_sampler = AnnealedLangevinDynamicsGenerator(sigma_normalized_score_network=sigma_normalized_score_network, + **sampler_parameters) unit_cell = get_unit_cells(acell=1., spatial_dimension=spatial_dimension, number_of_samples=number_of_samples) diff --git a/crystal_diffusion/analysis/analytic_score/utils.py b/crystal_diffusion/analysis/analytic_score/utils.py index e72977a8..073a6be0 100644 --- a/crystal_diffusion/analysis/analytic_score/utils.py +++ b/crystal_diffusion/analysis/analytic_score/utils.py @@ -40,6 +40,7 @@ def get_random_inverse_covariance(spring_constant_scale: float, number_of_atoms: def get_exact_samples(equilibrium_relative_coordinates: torch.Tensor, inverse_covariance: torch.Tensor, number_of_samples: int) -> torch.Tensor: """Sample the exact harmonic energy.""" + device = equilibrium_relative_coordinates.device natom, spatial_dimension, _, _ = inverse_covariance.shape flat_dim = natom * spatial_dimension @@ -53,7 +54,7 @@ def get_exact_samples(equilibrium_relative_coordinates: torch.Tensor, inverse_co sigmas = 1. 
/ torch.sqrt(eigenvalues) - z_scores = torch.randn(number_of_samples, flat_dim) + z_scores = torch.randn(number_of_samples, flat_dim).to(device) sigma_z_scores = z_scores * sigmas.unsqueeze(0) diff --git a/crystal_diffusion/analysis/diffusion_sample_position_analysis.py b/crystal_diffusion/analysis/diffusion_sample_position_analysis.py index 6cf68c1c..0afacbc6 100644 --- a/crystal_diffusion/analysis/diffusion_sample_position_analysis.py +++ b/crystal_diffusion/analysis/diffusion_sample_position_analysis.py @@ -9,7 +9,8 @@ from crystal_diffusion import TOP_DIR from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH -from crystal_diffusion.utils.sample_trajectory import SampleTrajectory +from crystal_diffusion.utils.sample_trajectory import \ + PredictorCorrectorSampleTrajectory plt.style.use(PLOT_STYLE_PATH) @@ -34,10 +35,10 @@ energies = torch.load(energy_sample_directory / f"energies_sample_epoch={epoch}.pt") pickle_path = trajectory_data_directory / f"diffusion_position_sample_epoch={epoch}.pt" - sample_trajectory = SampleTrajectory.read_from_pickle(pickle_path) + sample_trajectory = PredictorCorrectorSampleTrajectory.read_from_pickle(pickle_path) pickle_path = trajectory_data_directory / f"diffusion_position_sample_epoch={epoch}.pt" - sample_trajectory = SampleTrajectory.read_from_pickle(pickle_path) + sample_trajectory = PredictorCorrectorSampleTrajectory.read_from_pickle(pickle_path) list_predictor_coordinates = sample_trajectory.data['predictor_x_i'] float_datatype = list_predictor_coordinates[0].dtype diff --git a/crystal_diffusion/analysis/positions_to_cif_files.py b/crystal_diffusion/analysis/positions_to_cif_files.py index 12c88bf2..994c7d9d 100644 --- a/crystal_diffusion/analysis/positions_to_cif_files.py +++ b/crystal_diffusion/analysis/positions_to_cif_files.py @@ -7,7 +7,8 @@ from pymatgen.core import Lattice, Structure from crystal_diffusion import TOP_DIR -from crystal_diffusion.utils.sample_trajectory import SampleTrajectory +from crystal_diffusion.utils.sample_trajectory import \ + PredictorCorrectorSampleTrajectory # Hard coding some paths to local results. Modify as needed... 
epoch = 35 @@ -25,10 +26,10 @@ if __name__ == '__main__': pickle_path = trajectory_data_directory / f"diffusion_position_sample_epoch={epoch}.pt" - sample_trajectory = SampleTrajectory.read_from_pickle(pickle_path) + sample_trajectory = PredictorCorrectorSampleTrajectory.read_from_pickle(pickle_path) pickle_path = trajectory_data_directory / f"diffusion_position_sample_epoch={epoch}.pt" - sample_trajectory = SampleTrajectory.read_from_pickle(pickle_path) + sample_trajectory = PredictorCorrectorSampleTrajectory.read_from_pickle(pickle_path) basis_vectors = sample_trajectory.data['unit_cell'][sample_idx].numpy() lattice = Lattice(matrix=basis_vectors, pbc=(True, True, True)) diff --git a/crystal_diffusion/callbacks/sampling_callback.py b/crystal_diffusion/callbacks/sampling_callback.py index 380806a8..087cbce2 100644 --- a/crystal_diffusion/callbacks/sampling_callback.py +++ b/crystal_diffusion/callbacks/sampling_callback.py @@ -12,10 +12,13 @@ from pytorch_lightning import Callback, LightningModule, Trainer from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH +from crystal_diffusion.generators.ode_position_generator import \ + ExplodingVarianceODEPositionGenerator +from crystal_diffusion.generators.position_generator import PositionGenerator +from crystal_diffusion.generators.predictor_corrector_position_generator import \ + AnnealedLangevinDynamicsGenerator from crystal_diffusion.loggers.logger_loader import log_figure from crystal_diffusion.oracle.lammps import get_energy_and_forces_from_lammps -from crystal_diffusion.samplers.predictor_corrector_position_sampler import \ - AnnealedLangevinDynamicsSampler from crystal_diffusion.samplers.variance_sampler import NoiseParameters from crystal_diffusion.utils.basis_transformations import \ get_positions_from_coordinates @@ -29,6 +32,7 @@ @dataclass(kw_only=True) class SamplingParameters: """Hyper-parameters for diffusion sampling.""" + algorithm: str spatial_dimension: int = 3 # the dimension of Euclidean space where atoms live. number_of_corrector_steps: int = 1 number_of_atoms: int # the number of atoms that must be generated in a sampled configuration. @@ -40,17 +44,45 @@ class SamplingParameters: record_samples: bool = False # should the predictor and corrector steps be recorded to a file +@dataclass(kw_only=True) +class PredictorCorrectorSamplingParameters(SamplingParameters): + """Hyper-parameters for diffusion sampling with the predictor-corrector algorithm.""" + algorithm: str = 'predictor_corrector' + number_of_corrector_steps: int = 1 + + +@dataclass(kw_only=True) +class ODESamplingParameters(SamplingParameters): + """Hyper-parameters for diffusion sampling with the ode algorithm.""" + algorithm: str = 'ode' + + def instantiate_diffusion_sampling_callback(callback_params: Dict[AnyStr, Any], output_directory: str, verbose: bool) -> Dict[str, Callback]: """Instantiate the Diffusion Sampling callback.""" noise_parameters = NoiseParameters(**callback_params['noise']) - sampling_parameters = SamplingParameters(**callback_params['sampling']) - - diffusion_sampling_callback = DiffusionSamplingCallback(noise_parameters=noise_parameters, - sampling_parameters=sampling_parameters, - output_directory=output_directory) + sampling_parameter_dictionary = callback_params['sampling'] + assert 'algorithm' in sampling_parameter_dictionary, "The sampling parameters must select an algorithm." + algorithm = sampling_parameter_dictionary['algorithm'] + + assert algorithm in ['ode', 'predictor_corrector'], \ + "Unknown algorithm. 
Possible choices are 'ode' and 'predictor_corrector'" + + if algorithm == 'predictor_corrector': + sampling_parameters = PredictorCorrectorSamplingParameters(**sampling_parameter_dictionary) + diffusion_sampling_callback = ( + PredictorCorrectorDiffusionSamplingCallback(noise_parameters=noise_parameters, + sampling_parameters=sampling_parameters, + output_directory=output_directory)) + elif algorithm == 'ode': + sampling_parameters = ODESamplingParameters(**sampling_parameter_dictionary) + diffusion_sampling_callback = ODEDiffusionSamplingCallback(noise_parameters=noise_parameters, + sampling_parameters=sampling_parameters, + output_directory=output_directory) + else: + raise NotImplementedError("algorithm is not implemented") return dict(diffusion_sampling=diffusion_sampling_callback) @@ -132,26 +164,17 @@ def _initialize_validation_energies_array(self): # data does not change, we will avoid having this in memory at all times. self.validation_energies = np.array([]) - def _create_sampler(self, pl_model: LightningModule) -> Tuple[AnnealedLangevinDynamicsSampler, torch.Tensor]: + def _create_generator(self, pl_model: LightningModule) -> PositionGenerator: """Draw a sample from the generative model.""" - logger.info("Creating sampler") - sigma_normalized_score_network = pl_model.sigma_normalized_score_network - - sampler_parameters = dict(noise_parameters=self.noise_parameters, - number_of_corrector_steps=self.sampling_parameters.number_of_corrector_steps, - number_of_atoms=self.sampling_parameters.number_of_atoms, - spatial_dimension=self.sampling_parameters.spatial_dimension, - record_samples=self.sampling_parameters.record_samples, - positions_require_grad=pl_model.grads_are_needed_in_inference) + raise NotImplementedError("This method must be implemented in a child class") - pc_sampler = AnnealedLangevinDynamicsSampler(sigma_normalized_score_network=sigma_normalized_score_network, - **sampler_parameters) + def _create_unit_cell(self, pl_model) -> torch.Tensor: + """Create the batch of unit cells needed by the generative model.""" # TODO we will have to sample unit cell dimensions at some points instead of working with fixed size unit_cell = (self._get_orthogonal_unit_cell(batch_size=self.sampling_parameters.number_of_samples, cell_dimensions=self.sampling_parameters.cell_dimensions) .to(pl_model.device)) - - return pc_sampler, unit_cell + return unit_cell @staticmethod def _plot_energy_histogram(sample_energies: np.ndarray, validation_dataset_energies: np.array, @@ -220,7 +243,8 @@ def sample_and_evaluate_energy(self, pl_model: LightningModule, current_epoch: i Returns: array with energy of each sample from LAMMPS """ - pc_sampler, unit_cell = self._create_sampler(pl_model) + generator = self._create_generator(pl_model) + unit_cell = self._create_unit_cell(pl_model) logger.info("Draw samples") @@ -232,17 +256,17 @@ def sample_and_evaluate_energy(self, pl_model: LightningModule, current_epoch: i for n in range(0, self.sampling_parameters.number_of_samples, self.sampling_parameters.sample_batchsize): unit_cell_ = unit_cell[n:min(n + self.sampling_parameters.sample_batchsize, self.sampling_parameters.number_of_samples)] - samples = pc_sampler.sample(min(self.sampling_parameters.number_of_samples - n, - self.sampling_parameters.sample_batchsize), - device=pl_model.device, - unit_cell=unit_cell_) + samples = generator.sample(min(self.sampling_parameters.number_of_samples - n, + self.sampling_parameters.sample_batchsize), + device=pl_model.device, + unit_cell=unit_cell_) if 
self.sampling_parameters.record_samples: sample_output_path = os.path.join(self.position_sample_output_directory, f"diffusion_position_sample_epoch={current_epoch}" + f"_steps={n}.pt") # write trajectories to disk and reset to save memory - pc_sampler.sample_trajectory_recorder.write_to_pickle(sample_output_path) - pc_sampler.sample_trajectory_recorder.reset() + generator.sample_trajectory_recorder.write_to_pickle(sample_output_path) + generator.sample_trajectory_recorder.reset() batch_relative_coordinates = samples.detach().cpu() sample_energies += [self._compute_oracle_energies(batch_relative_coordinates)] @@ -280,3 +304,43 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_model: LightningModule) - log_figure(figure=fig, global_step=trainer.global_step, pl_logger=pl_logger) self._initialize_validation_energies_array() + + +class PredictorCorrectorDiffusionSamplingCallback(DiffusionSamplingCallback): + """Callback class to periodically generate samples with the predictor-corrector algorithm and log their energies.""" + + def _create_generator(self, pl_model: LightningModule) -> AnnealedLangevinDynamicsGenerator: + """Create the annealed Langevin dynamics position generator.""" + logger.info("Creating generator") + sigma_normalized_score_network = pl_model.sigma_normalized_score_network + + sampler_parameters = dict(noise_parameters=self.noise_parameters, + number_of_corrector_steps=self.sampling_parameters.number_of_corrector_steps, + number_of_atoms=self.sampling_parameters.number_of_atoms, + spatial_dimension=self.sampling_parameters.spatial_dimension, + record_samples=self.sampling_parameters.record_samples, + positions_require_grad=pl_model.grads_are_needed_in_inference) + + generator = AnnealedLangevinDynamicsGenerator(sigma_normalized_score_network=sigma_normalized_score_network, + **sampler_parameters) + + return generator + + +class ODEDiffusionSamplingCallback(DiffusionSamplingCallback): + """Callback class to periodically generate samples with the ODE algorithm and log their energies.""" + + def _create_generator(self, pl_model: LightningModule) -> ExplodingVarianceODEPositionGenerator: + """Create the ODE position generator.""" + logger.info("Creating generator") + sigma_normalized_score_network = pl_model.sigma_normalized_score_network + + sampler_parameters = dict(noise_parameters=self.noise_parameters, + number_of_atoms=self.sampling_parameters.number_of_atoms, + spatial_dimension=self.sampling_parameters.spatial_dimension, + record_samples=self.sampling_parameters.record_samples) + + generator = ExplodingVarianceODEPositionGenerator(sigma_normalized_score_network=sigma_normalized_score_network, + **sampler_parameters) + + return generator diff --git a/crystal_diffusion/generators/__init__.py b/crystal_diffusion/generators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/crystal_diffusion/generators/ode_position_generator.py b/crystal_diffusion/generators/ode_position_generator.py new file mode 100644 index 00000000..6391d990 --- /dev/null +++ b/crystal_diffusion/generators/ode_position_generator.py @@ -0,0 +1,210 @@ +import logging +from typing import Callable + +import einops +import torch +import torchode as to + +from crystal_diffusion.generators.position_generator import PositionGenerator +from crystal_diffusion.models.score_networks.score_network import ScoreNetwork +from crystal_diffusion.namespace import (CARTESIAN_FORCES, NOISE, + NOISY_RELATIVE_COORDINATES, TIME, + UNIT_CELL) +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.utils.basis_transformations import \ + 
map_relative_coordinates_to_unit_cell +from crystal_diffusion.utils.sample_trajectory import (NoOpODESampleTrajectory, + ODESampleTrajectory) + +logger = logging.getLogger(__name__) + + +class ExplodingVarianceODEPositionGenerator(PositionGenerator): + """Exploding Variance ODE Position Generator. + + This class generates position samples by solving an ordinary differential equation (ODE). + It assumes that the diffusion noise is parameterized in the 'Exploding Variance' scheme. + """ + + def __init__(self, + noise_parameters: NoiseParameters, + number_of_atoms: int, + spatial_dimension: int, + sigma_normalized_score_network: ScoreNetwork, + record_samples: bool = False, + ): + """Init method. + + Args: + noise_parameters : the diffusion noise parameters. + number_of_atoms : the number of atoms to sample. + spatial_dimension : the dimension of space. + sigma_normalized_score_network : the score network to use for drawing samples. + record_samples : should samples be recorded. + """ + self.noise_parameters = noise_parameters + assert self.noise_parameters.total_time_steps >= 2, \ + "There must be at least two time steps in the noise parameters to define the limits t0 and tf." + self.number_of_atoms = number_of_atoms + self.spatial_dimension = spatial_dimension + + self.sigma_normalized_score_network = sigma_normalized_score_network + + self.t0 = 0.0 # The "initial diffusion time", corresponding to the physical distribution. + self.tf = 1.0 # The "final diffusion time", corresponding to the uniform distribution. + + self.record_samples = record_samples + + if record_samples: + self.sample_trajectory_recorder = ODESampleTrajectory() + else: + self.sample_trajectory_recorder = NoOpODESampleTrajectory() + + def _get_exploding_variance_sigma(self, times): + """Get Exploding Variance Sigma. + + In the 'exploding variance' scheme, the noise is defined by + + sigma(t) = sigma_min^{1 - t} x sigma_max^{t} + + Args: + times : diffusion time + + Returns: + sigmas: value of the noise parameter. + """ + sigmas = self.noise_parameters.sigma_min ** (1.0 - times) * self.noise_parameters.sigma_max ** times + return sigmas + + def _get_ode_prefactor(self, sigmas): + """Get ODE prefactor. + + The ODE is given by + dx = [-1/2 g(t)^2 x Score] dt + with + g(t)^2 = d sigma(t)^2 / dt + + We can rearrange the ODE to: + + dx = -[1/2 g(t)^2 / sigma] x sigma Score + --------v----------- + Prefactor. + + The prefactor is then given by + + Prefactor = d sigma(t) / dt + + and, for sigma(t) = sigma_min^{1 - t} x sigma_max^{t}, this reduces to + + Prefactor = log(sigma_max / sigma_min) x sigma(t) + + Args: + sigmas : the values of the noise parameters. + + Returns: + ode prefactor: the prefactor in the ODE. + """ + log_ratio = torch.log(torch.tensor(self.noise_parameters.sigma_max / self.noise_parameters.sigma_min)) + ode_prefactor = log_ratio * sigmas + return ode_prefactor + + def generate_ode_term(self, unit_cell: torch.Tensor) -> Callable: + """Generate the ode_term needed to compute the ODE solution.""" + + def ode_term(times: torch.Tensor, flat_relative_coordinates: torch.Tensor) -> torch.Tensor: + """ODE term. + + This function is in the format required by the ODE solver. + + The ODE solver expects the features to be two-dimensional, i.e., [batch, feature size]. + + Args: + times : ODE times, dimension [batch_size] + flat_relative_coordinates : features for every time step, dimension [batch_size, number of features]. + + Returns: + rhs: the right-hand-side of the corresponding ODE. 
+ """ + sigmas = self._get_exploding_variance_sigma(times) + ode_prefactor = self._get_ode_prefactor(sigmas) + + relative_coordinates = einops.rearrange(flat_relative_coordinates, + "batch (natom space) -> batch natom space", + natom=self.number_of_atoms, + space=self.spatial_dimension) + + batch = {NOISY_RELATIVE_COORDINATES: map_relative_coordinates_to_unit_cell(relative_coordinates), + NOISE: sigmas.unsqueeze(-1), + TIME: times.unsqueeze(-1), + UNIT_CELL: unit_cell, + CARTESIAN_FORCES: torch.zeros_like(relative_coordinates) # TODO: handle forces correctly. + } + + # Shape [batch_size, number of atoms, spatial dimension] + sigma_normalized_scores = self.sigma_normalized_score_network(batch) + flat_sigma_normalized_scores = einops.rearrange(sigma_normalized_scores, + "batch natom space -> batch (natom space)") + + return -ode_prefactor.unsqueeze(-1) * flat_sigma_normalized_scores + + return ode_term + + def sample(self, number_of_samples: int, device: torch.device, unit_cell: torch.Tensor) -> torch.Tensor: + """Sample. + + This method draws a position sample. + + Args: + number_of_samples : number of samples to draw. + device: device to use (cpu, cuda, etc.). Should match the PL model location. + unit_cell: unit cell definition in Angstrom. + Tensor of dimensions [number_of_samples, spatial_dimension, spatial_dimension] + + Returns: + samples: relative coordinates samples. + """ + ode_term = self.generate_ode_term(unit_cell) + + initial_relative_coordinates = ( + map_relative_coordinates_to_unit_cell(self.initialize(number_of_samples)).to(device)) + + y0 = einops.rearrange(initial_relative_coordinates, 'batch natom space -> batch (natom space)') + + evaluation_times = torch.linspace(self.tf, self.t0, self.noise_parameters.total_time_steps).to(device) + + t_eval = einops.repeat(evaluation_times, 't -> batch t', batch=number_of_samples) + + term = to.ODETerm(ode_term) + step_method = to.Dopri5(term=term) + # TODO: parameterize the tolerances + step_size_controller = to.IntegralController(atol=1e-3, rtol=1e-3, term=term) + solver = to.AutoDiffAdjoint(step_method, step_size_controller) + jit_solver = torch.compile(solver) + + logger.info("Starting ODE solver...") + sol = jit_solver.solve(to.InitialValueProblem(y0=y0, t_eval=t_eval)) + logger.info("ODE solver Finished.") + + if self.record_samples: + # Only do these operations if they are required! + self.sample_trajectory_recorder.record_unit_cell(unit_cell) + record_relative_coordinates = einops.rearrange(sol.ys, + 'batch times (natom space) -> batch times natom space', + natom=self.number_of_atoms, + space=self.spatial_dimension) + self.sample_trajectory_recorder.record_ode_solution(times=sol.ts, + relative_coordinates=record_relative_coordinates, + stats=sol.stats, + status=sol.status) + + # sol.ys has dimensions [number of samples, number of times, number of features] + # only the final time (ie, t0) is the real sample. 
+ flat_relative_coordinates = sol.ys[:, -1, :] + + relative_coordinates = einops.rearrange(flat_relative_coordinates, + 'batch (natom space) -> batch natom space', + natom=self.number_of_atoms, + space=self.spatial_dimension) + + return map_relative_coordinates_to_unit_cell(relative_coordinates) + + def initialize(self, number_of_samples: int): + """This method must initialize the samples from the fully noised distribution.""" + relative_coordinates = torch.rand(number_of_samples, self.number_of_atoms, self.spatial_dimension) + return relative_coordinates diff --git a/crystal_diffusion/generators/position_generator.py b/crystal_diffusion/generators/position_generator.py new file mode 100644 index 00000000..3d0d45f6 --- /dev/null +++ b/crystal_diffusion/generators/position_generator.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod + +import torch + + +class PositionGenerator(ABC): + """This defines the interface for position generators.""" + + @abstractmethod + def sample(self, number_of_samples: int, device: torch.device, unit_cell: torch.Tensor) -> torch.Tensor: + """Sample. + + This method draws a position sample. + + Args: + number_of_samples : number of samples to draw. + device: device to use (cpu, cuda, etc.). Should match the PL model location. + unit_cell: unit cell definition in Angstrom. + Tensor of dimensions [number_of_samples, spatial_dimension, spatial_dimension] + + Returns: + samples: relative coordinates samples. + """ + pass + + @abstractmethod + def initialize(self, number_of_samples: int): + """This method must initialize the samples from the fully noised distribution.""" + pass diff --git a/crystal_diffusion/samplers/predictor_corrector_position_sampler.py b/crystal_diffusion/generators/predictor_corrector_position_generator.py similarity index 92% rename from crystal_diffusion/samplers/predictor_corrector_position_sampler.py rename to crystal_diffusion/generators/predictor_corrector_position_generator.py index f9f83ecc..6b05e85f 100644 --- a/crystal_diffusion/samplers/predictor_corrector_position_sampler.py +++ b/crystal_diffusion/generators/predictor_corrector_position_generator.py @@ -1,9 +1,10 @@ import logging -from abc import ABC, abstractmethod +from abc import abstractmethod import torch from tqdm import tqdm +from crystal_diffusion.generators.position_generator import PositionGenerator from crystal_diffusion.models.score_networks.score_network import ScoreNetwork from crystal_diffusion.namespace import (CARTESIAN_FORCES, NOISE, NOISY_RELATIVE_COORDINATES, TIME, @@ -12,14 +13,14 @@ ExplodingVarianceSampler, NoiseParameters) from crystal_diffusion.utils.basis_transformations import \ map_relative_coordinates_to_unit_cell -from crystal_diffusion.utils.sample_trajectory import (NoOpSampleTrajectory, - SampleTrajectory) +from crystal_diffusion.utils.sample_trajectory import ( + NoOpPredictorCorrectorSampleTrajectory, PredictorCorrectorSampleTrajectory) logger = logging.getLogger(__name__) -class PredictorCorrectorPositionSampler(ABC): - """This defines the interface for position samplers.""" +class PredictorCorrectorPositionGenerator(PositionGenerator): + """This defines the interface for predictor-corrector position generators.""" def __init__(self, number_of_discretization_steps: int, number_of_corrector_steps: int, spatial_dimension: int, **kwargs): @@ -34,7 +35,7 @@ def __init__(self, number_of_discretization_steps: int, number_of_corrector_step def sample(self, number_of_samples: int, device: torch.device, unit_cell: torch.Tensor) -> torch.Tensor: 
"""Sample. - This method draws a sample using the PR sampler algorithm. + This method draws a sample using the PC sampler algorithm. Args: number_of_samples : number of samples to draw. @@ -59,11 +60,6 @@ def sample(self, number_of_samples: int, device: torch.device, unit_cell: torch. x_ip1 = x_i return x_i - @abstractmethod - def initialize(self, number_of_samples: int): - """This method must initialize the samples from the fully noised distribution.""" - pass - @abstractmethod def predictor_step(self, x_ip1: torch.Tensor, ip1: int, unit_cell: torch.Tensor, cartesian_forces: torch.Tensor ) -> torch.Tensor: @@ -101,10 +97,10 @@ def corrector_step(self, x_i: torch.Tensor, i: int, unit_cell: torch.Tensor, car pass -class AnnealedLangevinDynamicsSampler(PredictorCorrectorPositionSampler): - """Annealed Langevin Dynamics Sampler. +class AnnealedLangevinDynamicsGenerator(PredictorCorrectorPositionGenerator): + """Annealed Langevin Dynamics Generator. - This class implements the annealed Langevin Dynamics sampling of + This class implements the annealed Langevin Dynamics generation of position samples, following Song & Ermon 2019, namely: "Generative Modeling by Estimating Gradients of the Data Distribution" """ @@ -130,9 +126,9 @@ def __init__(self, self.sigma_normalized_score_network = sigma_normalized_score_network if record_samples: - self.sample_trajectory_recorder = SampleTrajectory() + self.sample_trajectory_recorder = PredictorCorrectorSampleTrajectory() else: - self.sample_trajectory_recorder = NoOpSampleTrajectory() + self.sample_trajectory_recorder = NoOpPredictorCorrectorSampleTrajectory() def initialize(self, number_of_samples: int): """This method must initialize the samples from the fully noised distribution.""" diff --git a/crystal_diffusion/models/score_networks/analytical_score_network.py b/crystal_diffusion/models/score_networks/analytical_score_network.py index ea17a041..e1b37164 100644 --- a/crystal_diffusion/models/score_networks/analytical_score_network.py +++ b/crystal_diffusion/models/score_networks/analytical_score_network.py @@ -54,6 +54,8 @@ def __init__(self, hyper_params: AnalyticalScoreNetworkParameters): self.flat_dim = self.natoms * self.spatial_dimension + self.device = hyper_params.equilibrium_relative_coordinates.device + assert hyper_params.equilibrium_relative_coordinates.shape == (self.natoms, self.spatial_dimension), \ "equilibrium relative coordinates have the wrong shape" @@ -69,10 +71,10 @@ def __init__(self, hyper_params: AnalyticalScoreNetworkParameters): 'permutation natom d -> permutation (natom d)') # shape: [ (2 kmax + 1)^flat_dim, flat_dim] - self.translations_k = self._get_all_translations(self.kmax, self.flat_dim) + self.translations_k = self._get_all_translations(self.kmax, self.flat_dim).to(self.device) # shape [ (2 kmax + 1)^flat_dim x natom!, flat_dim] - self.all_offsets = self._get_all_flat_offsets(self.permutations_x0, self.translations_k) + self.all_offsets = self._get_all_flat_offsets(self.permutations_x0, self.translations_k).to(self.device) self.beta_phi_matrix = einops.rearrange(hyper_params.inverse_covariance, "natom1 d1 natom2 d2 -> (natom1 d1) (natom2 d2)") diff --git a/crystal_diffusion/utils/sample_trajectory.py b/crystal_diffusion/utils/sample_trajectory.py index 29bb5a26..4bfa507a 100644 --- a/crystal_diffusion/utils/sample_trajectory.py +++ b/crystal_diffusion/utils/sample_trajectory.py @@ -1,4 +1,5 @@ from collections import defaultdict +from typing import Dict import torch @@ -22,6 +23,59 @@ def record_unit_cell(self, 
unit_cell: torch.Tensor): """Record unit cell.""" self.data['unit_cell'] = unit_cell.detach().cpu() + def write_to_pickle(self, path_to_pickle: str): + """Write data to pickle file.""" + with open(path_to_pickle, 'wb') as fd: + torch.save(self.data, fd) + + +class ODESampleTrajectory(SampleTrajectory): + """ODE Sample Trajectory. + + This class aims to record all details of the ODE diffusion sampling process. The goal is to produce + an artifact that can then be analyzed off-line. + """ + + def record_ode_solution(self, times: torch.Tensor, relative_coordinates: torch.Tensor, + stats: Dict, status: torch.Tensor): + """Record ODE solution information.""" + self.data['time'].append(times) + self.data['stats'].append(stats) + self.data['status'].append(status) + self.data['relative_coordinates'].append(relative_coordinates) + + @staticmethod + def read_from_pickle(path_to_pickle: str): + """Read from pickle.""" + with open(path_to_pickle, 'rb') as fd: + sample_trajectory = ODESampleTrajectory() + sample_trajectory.data = torch.load(fd, map_location=torch.device('cpu')) + return sample_trajectory + + +class NoOpODESampleTrajectory(ODESampleTrajectory): + """A sample trajectory object that performs no operation.""" + + def record_unit_cell(self, unit_cell: torch.Tensor): + """No Op.""" + return + + def record_ode_solution(self, times: torch.Tensor, relative_coordinates: torch.Tensor, + stats: Dict, status: torch.Tensor): + """No Op.""" + return + + def write_to_pickle(self, path_to_pickle: str): + """No Op.""" + return + + +class PredictorCorrectorSampleTrajectory(SampleTrajectory): + """Predictor Corrector Sample Trajectory. + + This class aims to record all details of the predictor-corrector diffusion sampling process. The goal is to produce + an artifact that can then be analyzed off-line. 
+ """ def record_predictor_step(self, i_index: int, time: float, sigma: float, x_i: torch.Tensor, x_im1: torch.Tensor, scores: torch.Tensor): """Record predictor step.""" @@ -42,21 +96,16 @@ def record_corrector_step(self, i_index: int, time: float, sigma: float, self.data['corrector_corrected_x_i'].append(corrected_x_i.detach().cpu()) self.data['corrector_scores'].append(scores.detach().cpu()) - def write_to_pickle(self, path_to_pickle: str): - """Write data to pickle file.""" - with open(path_to_pickle, 'wb') as fd: - torch.save(self.data, fd) - @staticmethod def read_from_pickle(path_to_pickle: str): """Read from pickle.""" with open(path_to_pickle, 'rb') as fd: - sample_trajectory = SampleTrajectory() + sample_trajectory = PredictorCorrectorSampleTrajectory() sample_trajectory.data = torch.load(fd, map_location=torch.device('cpu')) return sample_trajectory -class NoOpSampleTrajectory(SampleTrajectory): +class NoOpPredictorCorrectorSampleTrajectory(PredictorCorrectorSampleTrajectory): """A sample trajectory object that performs no operation.""" def record_unit_cell(self, unit_cell: torch.Tensor): diff --git a/examples/local/diffusion/config_diffusion_mace.yaml b/examples/local/diffusion/config_diffusion_mace.yaml index bfc67af5..a9c4df79 100644 --- a/examples/local/diffusion/config_diffusion_mace.yaml +++ b/examples/local/diffusion/config_diffusion_mace.yaml @@ -70,11 +70,12 @@ diffusion_sampling: sigma_min: 0.001 # default value sigma_max: 0.5 # default value sampling: + algorithm: ode spatial_dimension: 3 - number_of_corrector_steps: 1 number_of_atoms: 8 number_of_samples: 16 - sample_every_n_epochs: 1 + sample_every_n_epochs: 5 + record_samples: True cell_dimensions: [5.43, 5.43, 5.43] logging: diff --git a/examples/local/diffusion/config_diffusion_mlp.yaml b/examples/local/diffusion/config_diffusion_mlp.yaml index 6815e1b5..983c9be3 100644 --- a/examples/local/diffusion/config_diffusion_mlp.yaml +++ b/examples/local/diffusion/config_diffusion_mlp.yaml @@ -1,7 +1,7 @@ # general exp_name: mlp_example -run_name: run_debug_delete_me -max_epoch: 10 +run_name: run2 +max_epoch: 500 log_every_n_steps: 1 gradient_clipping: 0 accumulate_grad_batches: 1 # make this number of forward passes before doing a backprop step @@ -55,7 +55,7 @@ model_checkpoint: # A callback to check the loss vs. 
sigma loss_monitoring: number_of_bins: 50 - sample_every_n_epochs: 2 + sample_every_n_epochs: 25 # Sampling from the generative model diffusion_sampling: @@ -64,15 +64,16 @@ diffusion_sampling: sigma_min: 0.001 # default value sigma_max: 0.5 # default value sampling: + algorithm: ode spatial_dimension: 3 - number_of_corrector_steps: 1 number_of_atoms: 8 number_of_samples: 16 - sample_batchsize: None - sample_every_n_epochs: 2 + sample_batchsize: 16 + sample_every_n_epochs: 25 + record_samples: True cell_dimensions: [5.43, 5.43, 5.43] logging: - - comet -#- tensorboard +# - comet +- tensorboard #- csv diff --git a/examples/local/diffusion/config_mace_equivariant_head.yaml b/examples/local/diffusion/config_mace_equivariant_head.yaml index 6a3bf416..f0485c94 100644 --- a/examples/local/diffusion/config_mace_equivariant_head.yaml +++ b/examples/local/diffusion/config_mace_equivariant_head.yaml @@ -73,6 +73,7 @@ diffusion_sampling: sigma_min: 0.005 # default value sigma_max: 0.5 # default value sampling: + algorithm: predictor_corrector spatial_dimension: 3 number_of_corrector_steps: 1 number_of_atoms: 8 diff --git a/examples/local/diffusion/config_mace_mlp_head.yaml b/examples/local/diffusion/config_mace_mlp_head.yaml index cfcbb93a..415f994f 100644 --- a/examples/local/diffusion/config_mace_mlp_head.yaml +++ b/examples/local/diffusion/config_mace_mlp_head.yaml @@ -74,6 +74,7 @@ diffusion_sampling: sigma_min: 0.005 # default value sigma_max: 0.5 # default value sampling: + algorithm: predictor_corrector spatial_dimension: 3 number_of_corrector_steps: 1 number_of_atoms: 8 diff --git a/experiment_analysis/sampling_analysis/diffusion_mace_ode_sampling_analysis.py b/experiment_analysis/sampling_analysis/diffusion_mace_ode_sampling_analysis.py new file mode 100644 index 00000000..6765f399 --- /dev/null +++ b/experiment_analysis/sampling_analysis/diffusion_mace_ode_sampling_analysis.py @@ -0,0 +1,214 @@ +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import torch +from einops import einops + +from crystal_diffusion import DATA_DIR +from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH +from crystal_diffusion.data.diffusion.data_loader import ( + LammpsForDiffusionDataModule, LammpsLoaderParameters) +from crystal_diffusion.models.mace_utils import get_adj_matrix +from crystal_diffusion.utils.basis_transformations import \ + map_relative_coordinates_to_unit_cell +from experiment_analysis import EXPERIMENT_ANALYSIS_DIR + + +def get_interatomic_distances(cartesian_positions: torch.Tensor, + basis_vectors: torch.Tensor, + radial_cutoff: float = 5.0): + """Get interatomic distances.""" + shifted_adjacency_matrix, shifts, batch_indices = get_adj_matrix(positions=cartesian_positions, + basis_vectors=basis_vectors, + radial_cutoff=radial_cutoff) + + flat_positions = einops.rearrange(cartesian_positions, "b n d -> (b n) d") + + displacements = flat_positions[shifted_adjacency_matrix[1]] - flat_positions[shifted_adjacency_matrix[0]] + shifts + interatomic_distances = torch.linalg.norm(displacements, dim=1) + return interatomic_distances + + +logger = logging.getLogger(__name__) + +plt.style.use(PLOT_STYLE_PATH) + + +# Some hardcoded paths and parameters. Change as needed! 
+epoch = 30 +base_data_dir = Path("/Users/bruno/courtois/difface_ode/run7") +position_samples_dir = base_data_dir / "diffusion_position_samples" +energy_samples_dir = base_data_dir / "energy_samples" + + +dataset_name = "si_diffusion_1x1x1" +lammps_run_dir = str(DATA_DIR / dataset_name) +processed_dataset_dir = str(DATA_DIR / dataset_name / 'processed') +data_params = LammpsLoaderParameters(batch_size=64, max_atom=8) +cache_dir = str(EXPERIMENT_ANALYSIS_DIR / "cache" / dataset_name) + + +if __name__ == '__main__': + + datamodule = LammpsForDiffusionDataModule( + lammps_run_dir=lammps_run_dir, + processed_dataset_dir=processed_dataset_dir, + hyper_params=data_params, + working_cache_dir=cache_dir, + ) + + datamodule.setup() + + train_dataset = datamodule.train_dataset + batch = train_dataset[:1000] + + positions_data = torch.load(position_samples_dir / f"diffusion_position_sample_epoch={epoch}_steps=0.pt", + map_location=torch.device('cpu')) + + unit_cell = positions_data['unit_cell'] + + batch_times = positions_data['time'][0] + batch_noisy_relative_coordinates = positions_data['relative_coordinates'][0] + number_of_atoms, spatial_dimension = batch_noisy_relative_coordinates.shape[-2:] + + batch_flat_noisy_relative_coordinates = einops.rearrange(batch_noisy_relative_coordinates, + "b t n d -> b t (n d)") + + fig1 = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig1.suptitle(f'ODE Trajectories: Sample at Epoch {epoch}') + + ax = fig1.add_subplot(111) + ax.set_xlabel('Diffusion Time') + ax.set_ylabel('Raw Relative Coordinate') + ax.yaxis.tick_right() + ax.spines['top'].set_visible(True) + ax.spines['right'].set_visible(True) + ax.spines['bottom'].set_visible(True) + ax.spines['left'].set_visible(True) + + time = batch_times[0] # all time arrays are the same + for flat_relative_coordinates in batch_flat_noisy_relative_coordinates: + for i in range(number_of_atoms * spatial_dimension): + coordinate = flat_relative_coordinates[:, i] + ax.plot(time.cpu(), coordinate.cpu(), '-', color='b', alpha=0.05) + + ax.set_xlim([1.01, -0.01]) + ax.set_ylim([-2.0, 2.0]) + plt.show() + + training_relative_coordinates = batch['relative_coordinates'] + training_center_of_mass = training_relative_coordinates.mean(dim=1).mean(dim=0) + + raw_sample_relative_coordinates = map_relative_coordinates_to_unit_cell(batch_noisy_relative_coordinates[:, -1]) + raw_sample_centers_of_mass = raw_sample_relative_coordinates.mean(dim=1) + + zero_centered_sample_relative_coordinates = (raw_sample_relative_coordinates + - raw_sample_centers_of_mass.unsqueeze(1)) + sample_relative_coordinates = (zero_centered_sample_relative_coordinates + + training_center_of_mass.unsqueeze(0).unsqueeze(0)) + + sample_relative_coordinates = map_relative_coordinates_to_unit_cell(sample_relative_coordinates) + + fig2 = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig2.suptitle(f'ODE Marginal Distributions, Sample at Epoch {epoch}') + ax1 = fig2.add_subplot(131, aspect='equal') + ax2 = fig2.add_subplot(132, aspect='equal') + ax3 = fig2.add_subplot(133, aspect='equal') + + xs = einops.rearrange(sample_relative_coordinates, 'b n d -> (b n) d') + zs = einops.rearrange(training_relative_coordinates, 'b n d -> (b n) d') + ax1.set_title('XY Projection') + ax1.plot(xs[:, 0], xs[:, 1], 'ro', alpha=0.5, mew=0, label='ODE Solver Samples') + ax1.plot(zs[:, 0], zs[:, 1], 'go', alpha=0.05, mew=0, label='Training Data') + + ax2.set_title('XZ Projection') + ax2.plot(xs[:, 0], xs[:, 2], 'ro', alpha=0.5, mew=0, label='ODE Solver Samples') + ax2.plot(zs[:, 0], zs[:, 2], 'go', 
alpha=0.05, mew=0, label='Training Data') + + ax3.set_title('YZ Projection') + ax3.plot(xs[:, 1], xs[:, 2], 'ro', alpha=0.5, mew=0, label='ODE Solver Samples') + ax3.plot(zs[:, 1], zs[:, 2], 'go', alpha=0.05, mew=0, label='Training Data') + + for ax in [ax1, ax2, ax3]: + ax.set_xlim(-0.01, 1.01) + ax.set_ylim(-0.01, 1.01) + ax.vlines(x=[0, 1], ymin=0, ymax=1, color='k', lw=2) + ax.hlines(y=[0, 1], xmin=0, xmax=1, color='k', lw=2) + + ax2.legend(loc='lower center', bbox_to_anchor=(0.5, -0.5), ncol=2, fancybox=True, shadow=True) + fig2.tight_layout() + plt.show() + + fig3 = plt.figure(figsize=PLEASANT_FIG_SIZE) + ax1 = fig3.add_subplot(131) + ax2 = fig3.add_subplot(132) + ax3 = fig3.add_subplot(133) + fig3.suptitle(f"Marginal Distributions of t=0 Samples, Sample at Epoch {epoch}") + + common_params = dict(histtype='stepfilled', alpha=0.5, bins=50) + + ax1.hist(xs[:, 0], **common_params, facecolor='r', label='ODE solver') + ax2.hist(xs[:, 1], **common_params, facecolor='r', label='ODE solver') + ax3.hist(xs[:, 2], **common_params, facecolor='r', label='ODE solver') + + ax1.hist(zs[:, 0], **common_params, facecolor='g', label='Training Data') + ax2.hist(zs[:, 1], **common_params, facecolor='g', label='Training Data') + ax3.hist(zs[:, 2], **common_params, facecolor='g', label='Training Data') + + ax1.set_xlabel('X') + ax2.set_xlabel('Y') + ax3.set_xlabel('Z') + + for ax in [ax1, ax2, ax3]: + ax.set_xlim(-0.01, 1.01) + ax.set_yscale('log') + + ax1.legend(loc=0) + fig3.tight_layout() + plt.show() + + radial_cutoff = 5.4 + training_cartesian_positions = batch['cartesian_positions'] + basis_vectors = torch.diag_embed(batch['box']) + training_interatomic_distances = get_interatomic_distances(training_cartesian_positions, + basis_vectors, + radial_cutoff=radial_cutoff) + + sample_relative_coordinates = map_relative_coordinates_to_unit_cell(batch_noisy_relative_coordinates[:, -1]) + sample_cartesian_positions = torch.bmm(sample_relative_coordinates, unit_cell) + sample_interatomic_distances = get_interatomic_distances(sample_cartesian_positions, + unit_cell, + radial_cutoff=radial_cutoff) + + fig4 = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig4.suptitle(f'Interatomic Distance Distribution: Sample at Epoch {epoch}') + + ax1 = fig4.add_subplot(121) + ax2 = fig4.add_subplot(122) + + ax1.set_title('Training vs. 
Samples') + ax2.set_title('Intermediate Diffusion') + + common_params = dict(histtype='stepfilled', alpha=0.5, bins=75) + ax1.hist(training_interatomic_distances, **common_params, facecolor='g', label='Training Data') + ax1.hist(sample_interatomic_distances, **common_params, facecolor='r', label='ODE Sample, t = 0') + + for time_idx, color in zip([0, len(time) // 2 + 1, -1], ['blue', 'yellow', 'red']): + sample_relative_coordinates = map_relative_coordinates_to_unit_cell( + batch_noisy_relative_coordinates[:, time_idx]) + sample_cartesian_positions = torch.bmm(sample_relative_coordinates, unit_cell) + sample_interatomic_distances = get_interatomic_distances(sample_cartesian_positions, + unit_cell, + radial_cutoff=radial_cutoff) + ax2.hist(sample_interatomic_distances, **common_params, facecolor=color, + label=f'Noisy Sample t = {time[time_idx]:2.1f}') + + for ax in [ax1, ax2]: + ax.set_xlabel('Distance (Angstrom)') + ax.set_ylabel('Count') + ax.set_xlim([-0.01, radial_cutoff]) + ax.legend(loc=0) + ax.set_yscale('log') + fig4.tight_layout() + plt.show() diff --git a/experiment_analysis/sampling_analysis/diffusion_mace_ode_trajctory_analysis.py b/experiment_analysis/sampling_analysis/diffusion_mace_ode_trajctory_analysis.py new file mode 100644 index 00000000..6237081c --- /dev/null +++ b/experiment_analysis/sampling_analysis/diffusion_mace_ode_trajctory_analysis.py @@ -0,0 +1,63 @@ +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import torch + +from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH + +logger = logging.getLogger(__name__) + +plt.style.use(PLOT_STYLE_PATH) + + +# Some hardcoded paths and parameters. Change as needed! +epoch = 30 +base_data_dir = Path("/Users/bruno/courtois/difface_ode/run7") +position_samples_dir = base_data_dir / "diffusion_position_samples" +energy_samples_dir = base_data_dir / "energy_samples" +energy_data_directory = base_data_dir / "energy_samples" + + +if __name__ == '__main__': + energies = torch.load(energy_data_directory / f"energies_sample_epoch={epoch}.pt") + positions_data = torch.load(position_samples_dir / f"diffusion_position_sample_epoch={epoch}_steps=0.pt", + map_location=torch.device('cpu')) + + unit_cell = positions_data['unit_cell'] + + batch_times = positions_data['time'][0] + batch_noisy_relative_coordinates = positions_data['relative_coordinates'][0] + number_of_atoms, spatial_dimension = batch_noisy_relative_coordinates.shape[-2:] + + idx = energies.argmax() + relative_coordinates = batch_noisy_relative_coordinates[idx] + + fig1 = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig1.suptitle(f'ODE trajectory: Sample {idx} at Epoch {epoch} - Energy = {energies.max():4.2f}') + + ax1 = fig1.add_subplot(131) + ax2 = fig1.add_subplot(132) + ax3 = fig1.add_subplot(133) + + time = batch_times[0] # all time arrays are the same + + for atom_idx in range(number_of_atoms): + ax1.plot(time, relative_coordinates[:, atom_idx, 0], '-', alpha=0.5) + ax2.plot(time, relative_coordinates[:, atom_idx, 1], '-', alpha=0.5) + ax3.plot(time, relative_coordinates[:, atom_idx, 2], '-', alpha=0.5) + + for ax in [ax1, ax2, ax3]: + ax.set_xlabel('Diffusion Time') + ax.set_ylabel('Raw Relative Coordinate') + ax.yaxis.tick_right() + ax.spines['top'].set_visible(True) + ax.spines['right'].set_visible(True) + ax.spines['bottom'].set_visible(True) + ax.spines['left'].set_visible(True) + ax.set_xlim([1.01, -0.01]) + ax1.set_ylabel('X') + ax2.set_ylabel('Y') + ax3.set_ylabel('Z') + fig1.tight_layout() + plt.show() diff --git 
a/experiment_analysis/sampling_analysis/ode_sample_positions_to_cif_files.py b/experiment_analysis/sampling_analysis/ode_sample_positions_to_cif_files.py new file mode 100644 index 00000000..2ef2f0b8 --- /dev/null +++ b/experiment_analysis/sampling_analysis/ode_sample_positions_to_cif_files.py @@ -0,0 +1,50 @@ +"""Position to cif files for the ODE sampler. + +A simple script to extract the diffusion positions from a pickle on disk and output +in cif format for visualization. +""" +from pathlib import Path + +import torch +from pymatgen.core import Lattice, Structure + +from crystal_diffusion.utils.sample_trajectory import ODESampleTrajectory + +# Hard coding some paths to local results. Modify as needed... +epoch = 15 + +base_data_dir = Path("/Users/bruno/courtois/difface_ode/run1") +trajectory_data_directory = base_data_dir / "diffusion_position_samples" +energy_data_directory = base_data_dir / "energy_samples" +output_top_dir = trajectory_data_directory.parent / "visualization" + + +if __name__ == '__main__': + energies = torch.load(energy_data_directory / f"energies_sample_epoch={epoch}.pt") + + sample_idx = energies.argmax() + output_dir = output_top_dir / f"visualise_sampling_trajectory_epoch_{epoch}_sample_{sample_idx}" + output_dir.mkdir(exist_ok=True, parents=True) + + pickle_path = trajectory_data_directory / f"diffusion_position_sample_epoch={epoch}_steps=0.pt" + sample_trajectory = ODESampleTrajectory.read_from_pickle(pickle_path) + + basis_vectors = sample_trajectory.data['unit_cell'][sample_idx].numpy() + lattice = Lattice(matrix=basis_vectors, pbc=(True, True, True)) + + # Shape [batch, time, number of atoms, space dimension] + batch_noisy_relative_coordinates = sample_trajectory.data['relative_coordinates'][0] + + noisy_relative_coordinates = batch_noisy_relative_coordinates[sample_idx].numpy() + + for idx, coordinates in enumerate(noisy_relative_coordinates): + number_of_atoms = coordinates.shape[0] + species = number_of_atoms * ['Si'] + + structure = Structure(lattice=lattice, + species=species, + coords=coordinates, + coords_are_cartesian=False) + + file_path = str(output_dir / f"diffusion_positions_{idx}.cif") + structure.to_file(file_path) diff --git a/experiment_analysis/sampling_analysis/sampling_si_diffusion.py b/experiment_analysis/sampling_analysis/sampling_si_diffusion.py index 882a3ade..89e47e8b 100644 --- a/experiment_analysis/sampling_analysis/sampling_si_diffusion.py +++ b/experiment_analysis/sampling_analysis/sampling_si_diffusion.py @@ -15,10 +15,10 @@ from yaml import load from crystal_diffusion import DATA_DIR, TOP_DIR +from crystal_diffusion.generators.predictor_corrector_position_generator import \ + AnnealedLangevinDynamicsGenerator from crystal_diffusion.models.model_loader import load_diffusion_model from crystal_diffusion.oracle.lammps import get_energy_and_forces_from_lammps -from crystal_diffusion.samplers.predictor_corrector_position_sampler import \ - AnnealedLangevinDynamicsSampler from crystal_diffusion.samplers.variance_sampler import NoiseParameters from crystal_diffusion.utils.logging_utils import setup_analysis_logger @@ -72,11 +72,11 @@ sigma_normalized_score_network = pl_model.sigma_normalized_score_network logger.info("Creating sampler") - pc_sampler = AnnealedLangevinDynamicsSampler(noise_parameters=noise_parameters, - number_of_corrector_steps=number_of_corrector_steps, - number_of_atoms=number_of_atoms, - spatial_dimension=score_network_parameters.spatial_dimension, - sigma_normalized_score_network=sigma_normalized_score_network) + 
pc_sampler = AnnealedLangevinDynamicsGenerator(noise_parameters=noise_parameters, + number_of_corrector_steps=number_of_corrector_steps, + number_of_atoms=number_of_atoms, + spatial_dimension=score_network_parameters.spatial_dimension, + sigma_normalized_score_network=sigma_normalized_score_network) logger.info("Draw samples") samples = pc_sampler.sample(number_of_samples) diff --git a/experiments/si_diffusion_1x1x1/config_diffusion_mace.yaml b/experiments/si_diffusion_1x1x1/config_diffusion_mace.yaml index 6acefc6e..c6b79aa1 100644 --- a/experiments/si_diffusion_1x1x1/config_diffusion_mace.yaml +++ b/experiments/si_diffusion_1x1x1/config_diffusion_mace.yaml @@ -1,5 +1,5 @@ # general -exp_name: difface +exp_name: difface_ode run_name: run1 max_epoch: 25 log_every_n_steps: 1 @@ -11,7 +11,7 @@ seed: 1234 # data data: - batch_size: 512 + batch_size: 1024 num_workers: 8 max_atom: 8 @@ -28,13 +28,13 @@ model: interaction_cls: RealAgnosticResidualInteractionBlock interaction_cls_first: RealAgnosticInteractionBlock num_interactions: 2 - hidden_irreps: 128x0e + 128x1o + 128x2e - mlp_irreps: 128x0e - number_of_mlp_layers: 0 + hidden_irreps: 64x0e + 64x1o + 64x2e + mlp_irreps: 64x0e + number_of_mlp_layers: 3 avg_num_neighbors: 1 correlation: 3 gate: silu - radial_MLP: [128, 128, 128] + radial_MLP: [64, 64, 64] radial_type: bessel noise: total_time_steps: 100 @@ -69,17 +69,18 @@ diffusion_sampling: sigma_min: 0.001 # default value sigma_max: 0.5 # default value sampling: + algorithm: ode spatial_dimension: 3 - number_of_corrector_steps: 1 number_of_atoms: 8 number_of_samples: 1000 sample_every_n_epochs: 5 + record_samples: True cell_dimensions: [5.43, 5.43, 5.43] # A callback to check the loss vs. sigma loss_monitoring: number_of_bins: 50 - sample_every_n_epochs: 2 + sample_every_n_epochs: 5 logging: - comet \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 4ac1d936..089d8635 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,3 +35,4 @@ deepdiff==7.0.1 pykeops==2.2.3 comet_ml einops==0.8.0 +torchode==0.2.0 \ No newline at end of file diff --git a/sanity_checks/sanity_check_callbacks.py b/sanity_checks/sanity_check_callbacks.py index 91daf418..fe4489ce 100644 --- a/sanity_checks/sanity_check_callbacks.py +++ b/sanity_checks/sanity_check_callbacks.py @@ -4,11 +4,11 @@ from pytorch_lightning.loggers import TensorBoardLogger from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH +from crystal_diffusion.generators.predictor_corrector_position_generator import \ + AnnealedLangevinDynamicsGenerator from crystal_diffusion.models.score_networks.mlp_score_network import \ MLPScoreNetworkParameters from crystal_diffusion.namespace import NOISY_RELATIVE_COORDINATES -from crystal_diffusion.samplers.predictor_corrector_position_sampler import \ - AnnealedLangevinDynamicsSampler from crystal_diffusion.samplers.variance_sampler import NoiseParameters from crystal_diffusion.score.wrapped_gaussian_score import \ get_sigma_normalized_score @@ -73,13 +73,13 @@ def __init__(self, noise_parameters: NoiseParameters, def log_artifact(self, pl_module, tbx_logger): """Create artifact and log to tensorboard.""" sigma_normalized_score_network = pl_module.sigma_normalized_score_network - pc_sampler = AnnealedLangevinDynamicsSampler(noise_parameters=self.noise_parameters, - number_of_corrector_steps=self.number_of_corrector_steps, - number_of_atoms=self.number_of_atoms, - spatial_dimension=self.spatial_dimension, - 
sigma_normalized_score_network=sigma_normalized_score_network) + pc_generator = AnnealedLangevinDynamicsGenerator(noise_parameters=self.noise_parameters, + number_of_corrector_steps=self.number_of_corrector_steps, + number_of_atoms=self.number_of_atoms, + spatial_dimension=self.spatial_dimension, + sigma_normalized_score_network=sigma_normalized_score_network) - samples = pc_sampler.sample(self.number_of_samples).flatten() + samples = pc_generator.sample(self.number_of_samples).flatten() fig = plt.figure(figsize=PLEASANT_FIG_SIZE) ax = fig.add_subplot(111) diff --git a/tests/callbacks/test_sampling_callback.py b/tests/callbacks/test_sampling_callback.py index 15b6a8c7..091917ef 100644 --- a/tests/callbacks/test_sampling_callback.py +++ b/tests/callbacks/test_sampling_callback.py @@ -5,7 +5,8 @@ from pytorch_lightning import LightningModule from crystal_diffusion.callbacks.sampling_callback import ( - DiffusionSamplingCallback, SamplingParameters) + DiffusionSamplingCallback, ODESamplingParameters, + PredictorCorrectorSamplingParameters) from crystal_diffusion.samplers.variance_sampler import NoiseParameters @@ -17,16 +18,20 @@ @pytest.mark.parametrize("unit_cell_size", [10]) @pytest.mark.parametrize("lammps_energy", [2]) @pytest.mark.parametrize("spatial_dimension", [3]) -@pytest.mark.parametrize("number_of_corrector_steps", [1]) @pytest.mark.parametrize("number_of_atoms", [4]) @pytest.mark.parametrize("sample_batchsize", [None, 8, 4]) @pytest.mark.parametrize("record_samples", [True, False]) +@pytest.mark.parametrize("algorithm, number_of_corrector_steps", [('predictor_corrector', 1), ('ode', None)]) class TestSamplingCallback: @pytest.fixture() - def mock_create_sampler(self, number_of_samples): - pc_sampler = MagicMock() + def mock_create_generator(self): + generator = MagicMock() + return generator + + @pytest.fixture() + def mock_create_create_unit_cell(self, number_of_samples): unit_cell = np.arange(number_of_samples) # Dummy unit cell - return pc_sampler, unit_cell + return unit_cell @pytest.fixture() def mock_compute_lammps_energies(self, lammps_energy): @@ -41,29 +46,44 @@ def noise_parameters(self, total_time_steps, time_delta, sigma_min, corrector_st return noise_parameters @pytest.fixture() - def sampling_parameters(self, spatial_dimension, number_of_corrector_steps, number_of_atoms, number_of_samples, - sample_batchsize, unit_cell_size, record_samples): - sampling_parameters = SamplingParameters(spatial_dimension=spatial_dimension, - number_of_corrector_steps=number_of_corrector_steps, - number_of_atoms=number_of_atoms, - number_of_samples=number_of_samples, - sample_batchsize=sample_batchsize, - cell_dimensions=[unit_cell_size for _ in range(spatial_dimension)], - record_samples=record_samples) + def sampling_parameters(self, algorithm, spatial_dimension, number_of_corrector_steps, + number_of_atoms, number_of_samples, sample_batchsize, unit_cell_size, record_samples): + if algorithm == 'predictor_corrector': + sampling_parameters = ( + PredictorCorrectorSamplingParameters(spatial_dimension=spatial_dimension, + number_of_corrector_steps=number_of_corrector_steps, + number_of_atoms=number_of_atoms, + number_of_samples=number_of_samples, + sample_batchsize=sample_batchsize, + cell_dimensions=[unit_cell_size for _ in range(spatial_dimension)], + record_samples=record_samples)) + elif algorithm == 'ode': + sampling_parameters = ( + ODESamplingParameters(spatial_dimension=spatial_dimension, + number_of_atoms=number_of_atoms, + number_of_samples=number_of_samples, + 
sample_batchsize=sample_batchsize, + cell_dimensions=[unit_cell_size for _ in range(spatial_dimension)], + record_samples=record_samples)) + + else: + raise NotImplementedError + return sampling_parameters @pytest.fixture() def pl_model(self): return MagicMock(spec=LightningModule) - def test_sample_and_evaluate_energy(self, mocker, mock_compute_lammps_energies, mock_create_sampler, - noise_parameters, sampling_parameters, pl_model, sample_batchsize, - number_of_samples, tmpdir): + def test_sample_and_evaluate_energy(self, mocker, mock_compute_lammps_energies, mock_create_generator, + mock_create_create_unit_cell, noise_parameters, sampling_parameters, + pl_model, sample_batchsize, number_of_samples, tmpdir): sampling_cb = DiffusionSamplingCallback( noise_parameters=noise_parameters, sampling_parameters=sampling_parameters, output_directory=tmpdir) - mocker.patch.object(sampling_cb, "_create_sampler", return_value=mock_create_sampler) + mocker.patch.object(sampling_cb, "_create_generator", return_value=mock_create_generator) + mocker.patch.object(sampling_cb, "_create_unit_cell", return_value=mock_create_create_unit_cell) mocker.patch.object(sampling_cb, "_compute_oracle_energies", return_value=mock_compute_lammps_energies) sample_energies = sampling_cb.sample_and_evaluate_energy(pl_model) diff --git a/tests/conftest.py b/tests/conftest.py index 16a8b103..4316aec4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,34 @@ get_configuration_runs, write_to_yaml) +def pytest_addoption(parser): + parser.addoption( + "--quick", action="store_true", default=False, help="skip slow tests" + ) + parser.addoption( + "--slow", action="store_true", default=False, help="only perform slow tests" + ) + + +def pytest_configure(config): + config.addinivalue_line("markers", "slow: mark test as slow to run") + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--quick"): + # --quick given in cli: skip slow tests + skip = pytest.mark.skip(reason="--quick option must be absent to run") + for item in items: + if "slow" in item.keywords: + item.add_marker(skip) + elif config.getoption("--slow"): + # --slow given in cli: only do the slow tests + skip = pytest.mark.skip(reason="--slow option must be present to run") + for item in items: + if "slow" not in item.keywords: + item.add_marker(skip) + + @pytest.fixture def basis_vectors(batch_size): # orthogonal boxes with dimensions between 5 and 10. 
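The conftest.py hooks above split the test suite on a new slow marker: pytest --quick skips every test carrying the marker, pytest --slow skips every test without it, and a plain pytest run still collects both. A minimal sketch of how a test opts in (a hypothetical test file, not part of this diff):

import pytest


@pytest.mark.slow
def test_expensive_end_to_end():
    ...  # runs under plain pytest and pytest --slow; skipped by pytest --quick


def test_cheap_unit():
    ...  # runs under plain pytest and pytest --quick; skipped by pytest --slow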
diff --git a/tests/generators/__init__.py b/tests/generators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/generators/test_ode_position_generator.py b/tests/generators/test_ode_position_generator.py new file mode 100644 index 00000000..14fe26f9 --- /dev/null +++ b/tests/generators/test_ode_position_generator.py @@ -0,0 +1,75 @@ +from typing import AnyStr, Dict + +import pytest +import torch + +from crystal_diffusion.generators.ode_position_generator import \ + ExplodingVarianceODEPositionGenerator +from crystal_diffusion.models.score_networks.score_network import ( + ScoreNetwork, ScoreNetworkParameters) +from crystal_diffusion.namespace import NOISY_RELATIVE_COORDINATES +from crystal_diffusion.samplers.variance_sampler import ( + ExplodingVarianceSampler, NoiseParameters) + + +class FakeScoreNetwork(ScoreNetwork): + """A fake, smooth score network for the ODE solver.""" + + def _forward_unchecked(self, batch: Dict[AnyStr, torch.Tensor], conditional: bool = False) -> torch.Tensor: + return batch[NOISY_RELATIVE_COORDINATES] + + +@pytest.mark.parametrize("total_time_steps", [2, 5, 10]) +@pytest.mark.parametrize("spatial_dimension", [2, 3]) +@pytest.mark.parametrize("number_of_atoms", [8]) +@pytest.mark.parametrize("sigma_min", [0.15]) +@pytest.mark.parametrize("record_samples", [False, True]) +class TestExplodingVarianceODEPositionGenerator: + @pytest.fixture() + def sigma_normalized_score_network(self, spatial_dimension): + return FakeScoreNetwork(ScoreNetworkParameters(architecture='dummy', spatial_dimension=spatial_dimension)) + + @pytest.fixture() + def noise_parameters(self, total_time_steps, sigma_min): + return NoiseParameters(total_time_steps=total_time_steps, time_delta=0., sigma_min=sigma_min) + + @pytest.fixture() + def ode_generator(self, noise_parameters, number_of_atoms, spatial_dimension, + sigma_normalized_score_network, record_samples): + generator = ExplodingVarianceODEPositionGenerator(noise_parameters=noise_parameters, + number_of_atoms=number_of_atoms, + spatial_dimension=spatial_dimension, + sigma_normalized_score_network=sigma_normalized_score_network, + record_samples=record_samples) + + return generator + + @pytest.fixture() + def unit_cell_sample(self, spatial_dimension, number_of_samples): + unit_cell_size = 10. + return torch.diag(torch.Tensor([unit_cell_size] * spatial_dimension)).repeat(number_of_samples, 1, 1) + + def test_get_exploding_variance_sigma(self, ode_generator, noise_parameters): + times = ExplodingVarianceSampler._get_time_array(noise_parameters) + expected_sigmas = ExplodingVarianceSampler._create_sigma_array(noise_parameters, times) + computed_sigmas = ode_generator._get_exploding_variance_sigma(times) + torch.testing.assert_close(expected_sigmas, computed_sigmas) + + def test_get_ode_prefactor(self, ode_generator, noise_parameters): + times = ExplodingVarianceSampler._get_time_array(noise_parameters) + sigmas = ode_generator._get_exploding_variance_sigma(times) + + sig_ratio = torch.tensor(noise_parameters.sigma_max / noise_parameters.sigma_min) + expected_ode_prefactor = torch.log(sig_ratio) * sigmas + computed_ode_prefactor = ode_generator._get_ode_prefactor(sigmas) + torch.testing.assert_close(expected_ode_prefactor, computed_ode_prefactor) + + @pytest.mark.parametrize("number_of_samples", [8]) + def test_smoke_sample(self, ode_generator, number_of_samples, number_of_atoms, spatial_dimension, unit_cell_sample): + # Just a smoke test that we can sample without crashing. 
+ relative_coordinates = ode_generator.sample(number_of_samples, torch.device('cpu'), unit_cell_sample) + + assert relative_coordinates.shape == (number_of_samples, number_of_atoms, spatial_dimension) + + assert relative_coordinates.min() >= 0. + assert relative_coordinates.max() < 1. diff --git a/tests/samplers/test_predictor_corrector_position_sampler.py b/tests/generators/test_predictor_corrector_position_generator.py similarity index 70% rename from tests/samplers/test_predictor_corrector_position_sampler.py rename to tests/generators/test_predictor_corrector_position_generator.py index 549316dd..32f93f34 100644 --- a/tests/samplers/test_predictor_corrector_position_sampler.py +++ b/tests/generators/test_predictor_corrector_position_generator.py @@ -1,18 +1,18 @@ import pytest import torch +from crystal_diffusion.generators.predictor_corrector_position_generator import ( + AnnealedLangevinDynamicsGenerator, PredictorCorrectorPositionGenerator) from crystal_diffusion.models.score_networks.mlp_score_network import ( MLPScoreNetwork, MLPScoreNetworkParameters) -from crystal_diffusion.samplers.predictor_corrector_position_sampler import ( - AnnealedLangevinDynamicsSampler, PredictorCorrectorPositionSampler) from crystal_diffusion.samplers.variance_sampler import ( ExplodingVarianceSampler, NoiseParameters) from crystal_diffusion.utils.basis_transformations import \ map_relative_coordinates_to_unit_cell -class FakePCSampler(PredictorCorrectorPositionSampler): - """A dummy PC sampler for the purpose of testing.""" +class FakePCGenerator(PredictorCorrectorPositionGenerator): + """A dummy PC generator for the purpose of testing.""" def __init__( self, @@ -42,7 +42,7 @@ def corrector_step(self, x_i: torch.Tensor, i: int, unit_cell: torch.Tensor, for @pytest.mark.parametrize("number_of_discretization_steps", [1, 5, 10]) @pytest.mark.parametrize("number_of_corrector_steps", [0, 1, 2]) @pytest.mark.parametrize("unit_cell_size", [10]) -class TestPredictorCorrectorPositionSampler: +class TestPredictorCorrectorPositionGenerator: @pytest.fixture(scope="class", autouse=True) def set_random_seed(self): torch.manual_seed(1234567) @@ -52,13 +52,13 @@ def initial_sample(self, number_of_samples, number_of_atoms, spatial_dimension): return torch.rand(number_of_samples, number_of_atoms, spatial_dimension) @pytest.fixture - def sampler( + def generator( self, number_of_discretization_steps, number_of_corrector_steps, spatial_dimension, initial_sample ): - sampler = FakePCSampler( + generator = FakePCGenerator( number_of_discretization_steps, number_of_corrector_steps, spatial_dimension, initial_sample ) - return sampler + return generator @pytest.fixture() def unit_cell_sample(self, unit_cell_size, spatial_dimension, number_of_samples): @@ -67,7 +67,7 @@ def unit_cell_sample(self, unit_cell_size, spatial_dimension, number_of_samples) @pytest.fixture def expected_samples( self, - sampler, + generator, initial_sample, number_of_discretization_steps, number_of_corrector_steps, @@ -80,16 +80,16 @@ def expected_samples( noisy_sample = map_relative_coordinates_to_unit_cell(initial_sample) x_ip1 = noisy_sample for i in list_i: - xi = map_relative_coordinates_to_unit_cell(sampler.predictor_step(x_ip1, i + 1, unit_cell_sample, - torch.zeros_like(x_ip1))) + xi = map_relative_coordinates_to_unit_cell(generator.predictor_step(x_ip1, i + 1, unit_cell_sample, + torch.zeros_like(x_ip1))) for _ in list_j: - xi = map_relative_coordinates_to_unit_cell(sampler.corrector_step(xi, i, unit_cell_sample, - torch.zeros_like(xi))) + xi 
= map_relative_coordinates_to_unit_cell(generator.corrector_step(xi, i, unit_cell_sample, + torch.zeros_like(xi))) x_ip1 = xi return xi - def test_sample(self, sampler, number_of_samples, expected_samples, unit_cell_sample): - computed_samples = sampler.sample(number_of_samples, torch.device('cpu'), unit_cell_sample) + def test_sample(self, generator, number_of_samples, expected_samples, unit_cell_sample): + computed_samples = generator.sample(number_of_samples, torch.device('cpu'), unit_cell_sample) torch.testing.assert_close(expected_samples, computed_samples) @@ -126,32 +126,32 @@ def noise_parameters(self, total_time_steps, time_delta, sigma_min, corrector_st return noise_parameters @pytest.fixture() - def pc_sampler(self, noise_parameters, - number_of_corrector_steps, - number_of_atoms, - spatial_dimension, - sigma_normalized_score_network): - sampler = AnnealedLangevinDynamicsSampler(noise_parameters=noise_parameters, - number_of_corrector_steps=number_of_corrector_steps, - number_of_atoms=number_of_atoms, - spatial_dimension=spatial_dimension, - sigma_normalized_score_network=sigma_normalized_score_network) - - return sampler + def pc_generator(self, noise_parameters, + number_of_corrector_steps, + number_of_atoms, + spatial_dimension, + sigma_normalized_score_network): + generator = AnnealedLangevinDynamicsGenerator(noise_parameters=noise_parameters, + number_of_corrector_steps=number_of_corrector_steps, + number_of_atoms=number_of_atoms, + spatial_dimension=spatial_dimension, + sigma_normalized_score_network=sigma_normalized_score_network) + + return generator @pytest.fixture() def unit_cell_sample(self, unit_cell_size, spatial_dimension, number_of_samples): return torch.diag(torch.Tensor([unit_cell_size] * spatial_dimension)).repeat(number_of_samples, 1, 1) - def test_smoke_sample(self, pc_sampler, number_of_samples, unit_cell_sample): + def test_smoke_sample(self, pc_generator, number_of_samples, unit_cell_sample): # Just a smoke test that we can sample without crashing. 
- pc_sampler.sample(number_of_samples, torch.device('cpu'), unit_cell_sample) + pc_generator.sample(number_of_samples, torch.device('cpu'), unit_cell_sample) @pytest.fixture() def x_i(self, number_of_samples, number_of_atoms, spatial_dimension): return map_relative_coordinates_to_unit_cell(torch.rand(number_of_samples, number_of_atoms, spatial_dimension)) - def test_predictor_step(self, mocker, pc_sampler, noise_parameters, x_i, total_time_steps, number_of_samples, + def test_predictor_step(self, mocker, pc_generator, noise_parameters, x_i, total_time_steps, number_of_samples, unit_cell_sample): sampler = ExplodingVarianceSampler(noise_parameters) @@ -161,11 +161,11 @@ def test_predictor_step(self, mocker, pc_sampler, noise_parameters, x_i, total_t list_time = noise.time forces = torch.zeros_like(x_i) - z = pc_sampler._draw_gaussian_sample(number_of_samples) - mocker.patch.object(pc_sampler, "_draw_gaussian_sample", return_value=z) + z = pc_generator._draw_gaussian_sample(number_of_samples) + mocker.patch.object(pc_generator, "_draw_gaussian_sample", return_value=z) for index_i in range(1, total_time_steps + 1): - computed_sample = pc_sampler.predictor_step(x_i, index_i, unit_cell_sample, forces) + computed_sample = pc_generator.predictor_step(x_i, index_i, unit_cell_sample, forces) sigma_i = list_sigma[index_i - 1] t_i = list_time[index_i - 1] @@ -176,13 +176,13 @@ def test_predictor_step(self, mocker, pc_sampler, noise_parameters, x_i, total_t g2 = sigma_i**2 - sigma_im1**2 - s_i = pc_sampler._get_sigma_normalized_scores(x_i, t_i, sigma_i, unit_cell_sample, forces) / sigma_i + s_i = pc_generator._get_sigma_normalized_scores(x_i, t_i, sigma_i, unit_cell_sample, forces) / sigma_i expected_sample = x_i + g2 * s_i + torch.sqrt(g2) * z torch.testing.assert_close(computed_sample, expected_sample) - def test_corrector_step(self, mocker, pc_sampler, noise_parameters, x_i, total_time_steps, number_of_samples, + def test_corrector_step(self, mocker, pc_generator, noise_parameters, x_i, total_time_steps, number_of_samples, unit_cell_sample): sampler = ExplodingVarianceSampler(noise_parameters) @@ -194,11 +194,11 @@ def test_corrector_step(self, mocker, pc_sampler, noise_parameters, x_i, total_t sigma_1 = list_sigma[0] forces = torch.zeros_like(x_i) - z = pc_sampler._draw_gaussian_sample(number_of_samples) - mocker.patch.object(pc_sampler, "_draw_gaussian_sample", return_value=z) + z = pc_generator._draw_gaussian_sample(number_of_samples) + mocker.patch.object(pc_generator, "_draw_gaussian_sample", return_value=z) for index_i in range(0, total_time_steps): - computed_sample = pc_sampler.corrector_step(x_i, index_i, unit_cell_sample, forces) + computed_sample = pc_generator.corrector_step(x_i, index_i, unit_cell_sample, forces) if index_i == 0: sigma_i = sigma_min @@ -209,7 +209,7 @@ def test_corrector_step(self, mocker, pc_sampler, noise_parameters, x_i, total_t eps_i = 0.5 * epsilon * sigma_i**2 / sigma_1**2 - s_i = pc_sampler._get_sigma_normalized_scores(x_i, t_i, sigma_i, unit_cell_sample, forces) / sigma_i + s_i = pc_generator._get_sigma_normalized_scores(x_i, t_i, sigma_i, unit_cell_sample, forces) / sigma_i expected_sample = x_i + eps_i * s_i + torch.sqrt(2. 
* eps_i) * z diff --git a/tests/test_train_diffusion.py b/tests/test_train_diffusion.py index a34a80f2..c2c7f001 100644 --- a/tests/test_train_diffusion.py +++ b/tests/test_train_diffusion.py @@ -83,7 +83,8 @@ def get_score_network(architecture: str, head_name: Union[str, None], number_of_ return score_network -def get_config(number_of_atoms: int, max_epoch: int, architecture: str, head_name: Union[str, None]): +def get_config(number_of_atoms: int, max_epoch: int, architecture: str, head_name: Union[str, None], + sampling_algorithm: str): data_config = dict(batch_size=4, num_workers=0, max_atom=number_of_atoms) model_config = dict(score_network=get_score_network(architecture, head_name, number_of_atoms), @@ -92,13 +93,15 @@ def get_config(number_of_atoms: int, max_epoch: int, architecture: str, head_nam optimizer_config = dict(name='adam', learning_rate=0.001) scheduler_config = dict(name='ReduceLROnPlateau', factor=0.6, patience=2) - sampling_dict = dict(spatial_dimension=3, - number_of_corrector_steps=1, + sampling_dict = dict(algorithm=sampling_algorithm, + spatial_dimension=3, number_of_atoms=number_of_atoms, number_of_samples=4, sample_every_n_epochs=1, record_samples=True, cell_dimensions=[10., 10., 10.]) + if sampling_algorithm == 'predictor_corrector': + sampling_dict["number_of_corrector_steps"] = 1 early_stopping_config = dict(metric='validation_epoch_loss', mode='min', patience=max_epoch) model_checkpoint_config = dict(monitor='validation_epoch_loss', mode='min') @@ -120,6 +123,7 @@ def get_config(number_of_atoms: int, max_epoch: int, architecture: str, head_nam return config +@pytest.mark.parametrize("sampling_algorithm", ["ode", "predictor_corrector"]) @pytest.mark.parametrize("architecture, head_name", [('diffusion_mace', None), ('mlp', None), @@ -131,8 +135,9 @@ def max_epoch(self): return 5 @pytest.fixture() - def config(self, number_of_atoms, max_epoch, architecture, head_name): - return get_config(number_of_atoms, max_epoch=max_epoch, architecture=architecture, head_name=head_name) + def config(self, number_of_atoms, max_epoch, architecture, head_name, sampling_algorithm): + return get_config(number_of_atoms, max_epoch=max_epoch, + architecture=architecture, head_name=head_name, sampling_algorithm=sampling_algorithm) @pytest.fixture() def all_paths(self, paths, tmpdir, config): @@ -164,6 +169,7 @@ def args(self, all_paths): return input_args + @pytest.mark.slow def test_checkpoint_callback(self, args, all_paths, max_epoch): train_diffusion.main(args) best_model_path = os.path.join(all_paths['output'], BEST_MODEL_NAME) @@ -184,6 +190,7 @@ def test_checkpoint_callback(self, args, all_paths, max_epoch): model_epoch = int(match_object.group('epoch')) assert model_epoch == max_epoch - 1 # the epoch counter starts at zero! 
+ @pytest.mark.slow def test_restart(self, args, all_paths, max_epoch, mocker): last_model_path = os.path.join(all_paths['output'], LAST_MODEL_NAME) diff --git a/tests/utils/test_sample_trajectory.py b/tests/utils/test_sample_trajectory.py index 096d00af..c9195e55 100644 --- a/tests/utils/test_sample_trajectory.py +++ b/tests/utils/test_sample_trajectory.py @@ -3,7 +3,8 @@ import pytest import torch -from crystal_diffusion.utils.sample_trajectory import SampleTrajectory +from crystal_diffusion.utils.sample_trajectory import \ + PredictorCorrectorSampleTrajectory @pytest.fixture(autouse=True, scope='module') @@ -98,7 +99,7 @@ def list_corrected_x_i(number_of_predictor_steps, number_of_corrector_steps, bat @pytest.fixture(scope='module') def sample_trajectory(number_of_corrector_steps, list_i_indices, list_times, list_sigmas, basis_vectors, list_x_i, list_x_im1, predictor_scores, list_x_i_corr, list_corrected_x_i, corrector_scores): - sample_trajectory = SampleTrajectory() + sample_trajectory = PredictorCorrectorSampleTrajectory() sample_trajectory.record_unit_cell(basis_vectors) total_corrector_index = 0 @@ -156,7 +157,7 @@ def test_load_from_pickle(sample_trajectory, tmp_path): pickle_path = str(tmp_path / 'test_pickle_path_to_load.pkl') sample_trajectory.write_to_pickle(pickle_path) - loaded_sample_trajectory = SampleTrajectory.read_from_pickle(pickle_path) + loaded_sample_trajectory = PredictorCorrectorSampleTrajectory.read_from_pickle(pickle_path) assert set(sample_trajectory.data.keys()) == set(loaded_sample_trajectory.data.keys()) @@ -169,7 +170,7 @@ def test_load_from_pickle(sample_trajectory, tmp_path): def test_reset(sample_trajectory, tmp_path): pickle_path = str(tmp_path / 'test_pickle_path_reset.pkl') sample_trajectory.write_to_pickle(pickle_path) - loaded_sample_trajectory = SampleTrajectory.read_from_pickle(pickle_path) + loaded_sample_trajectory = PredictorCorrectorSampleTrajectory.read_from_pickle(pickle_path) assert len(loaded_sample_trajectory.data.keys()) != 0 loaded_sample_trajectory.reset()
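The renamed PredictorCorrectorSampleTrajectory keeps the pickle round trip these tests exercise. A minimal sketch of that workflow, using only the methods visible above (record_unit_cell, write_to_pickle, read_from_pickle, reset) and assuming reset() empties the recorded data, as its name and the test suggest:

import torch

from crystal_diffusion.utils.sample_trajectory import PredictorCorrectorSampleTrajectory

trajectory = PredictorCorrectorSampleTrajectory()
# The recorded basis vectors are batched; the [batch, dim, dim] shape here is an assumption.
trajectory.record_unit_cell(torch.eye(3).unsqueeze(0))
# ... record predictor and corrector steps here during sampling ...

trajectory.write_to_pickle("trajectory.pkl")

# Reading the pickle back recovers the same recorded keys.
loaded = PredictorCorrectorSampleTrajectory.read_from_pickle("trajectory.pkl")
assert set(loaded.data.keys()) == set(trajectory.data.keys())

# reset() is assumed to drop the recorded data.
loaded.reset()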