diff --git a/crystal_diffusion/analysis/exploding_variance_analysis.py b/crystal_diffusion/analysis/exploding_variance_analysis.py index ed19ae26..e57b7264 100644 --- a/crystal_diffusion/analysis/exploding_variance_analysis.py +++ b/crystal_diffusion/analysis/exploding_variance_analysis.py @@ -20,27 +20,25 @@ noise_parameters = NoiseParameters(total_time_steps=1000) variance_sampler = ExplodingVarianceSampler(noise_parameters=noise_parameters) - noise = variance_sampler.get_all_noise() + noise, langevin_dynamics = variance_sampler.get_all_sampling_parameters() fig1 = plt.figure(figsize=PLEASANT_FIG_SIZE) fig1.suptitle("Noise Schedule") - ax1 = fig1.add_subplot(221) - ax2 = fig1.add_subplot(223) - ax3 = fig1.add_subplot(122) + ax1 = fig1.add_subplot(121) + ax2 = fig1.add_subplot(122) - ax1.plot(noise.time, noise.sigma, '-', c='k', lw=2) - ax2.plot(noise.time[1:], noise.g[1:], '-', c='k', lw=2) + ax1.plot(noise.time, noise.sigma, '-', c='b', lw=4, label='$\\sigma(t)$') + ax1.plot(noise.time, noise.g, '-', c='g', lw=4, label="$g(t)$") - ax1.set_ylabel('$\\sigma(t)$') - ax2.set_ylabel('$g(t)$') + shifted_time = torch.cat([torch.tensor([0]), noise.time[:-1]]) + ax1.plot(shifted_time, langevin_dynamics.epsilon, '-', c='r', lw=4, label="$\\epsilon(t)$") + ax1.legend(loc=0) - for ax in [ax1, ax2]: - ax.set_xlabel('time') - ax.set_xlim([-0.01, 1.01]) + ax1.set_xlabel('time') + ax1.set_xlim([-0.01, 1.01]) - ax1.set_title("$\\sigma$ schedule") - ax2.set_title("g schedule") + ax1.set_title("$\\sigma, g, \\epsilon$ schedules") relative_positions = torch.linspace(0, 1, 101)[:-1] @@ -55,13 +53,13 @@ target_sigma_normalized_scores = get_sigma_normalized_score(relative_positions, torch.ones_like(relative_positions) * sigma, kmax=kmax) - ax3.plot(relative_positions, target_sigma_normalized_scores, label=f"t = {t:3.2f}") + ax2.plot(relative_positions, target_sigma_normalized_scores, label=f"t = {t:3.2f}") - ax3.set_title("Target Normalized Score") - ax3.set_xlabel("relative position, u") - ax3.set_ylabel("$\\sigma(t) \\times S(u, t)$") - ax3.legend(loc=0) - ax3.set_xlim([-0.01, 1.01]) + ax2.set_title("Target Normalized Score") + ax2.set_xlabel("relative position, u") + ax2.set_ylabel("$\\sigma(t) \\times S(u, t)$") + ax2.legend(loc=0) + ax2.set_xlim([-0.01, 1.01]) fig1.tight_layout() diff --git a/crystal_diffusion/models/callbacks.py b/crystal_diffusion/models/callbacks.py index c3adbe4d..3cec517c 100644 --- a/crystal_diffusion/models/callbacks.py +++ b/crystal_diffusion/models/callbacks.py @@ -5,7 +5,18 @@ from pytorch_lightning import Callback from pytorch_lightning.loggers import TensorBoardLogger -from crystal_diffusion.analysis import PLEASANT_FIG_SIZE +from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH +from crystal_diffusion.models.score_network import MLPScoreNetworkParameters +from crystal_diffusion.samplers.noisy_position_sampler import \ + map_positions_to_unit_cell +from crystal_diffusion.samplers.predictor_corrector_position_sampler import \ + AnnealedLangevinDynamicsSampler +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.score.wrapped_gaussian_score import \ + get_sigma_normalized_score + +plt.style.use(PLOT_STYLE_PATH) +TENSORBOARD_FIGSIZE = (0.65 * PLEASANT_FIG_SIZE[0], 0.65 * PLEASANT_FIG_SIZE[1]) class HPLoggingCallback(Callback): @@ -94,10 +105,114 @@ def log_artifact(self, pl_module, tbx_logger): list_sigmas.append(output["sigmas"].flatten()) list_xt = torch.cat(list_xt) list_sigmas = torch.cat(list_sigmas) - fig = 
plt.figure(figsize=PLEASANT_FIG_SIZE) + fig = plt.figure(figsize=TENSORBOARD_FIGSIZE) ax = fig.add_subplot(111) ax.set_title(f"Position Samples: global step = {pl_module.global_step}") ax.set_ylabel("$\\sigma$") ax.set_xlabel("position samples $x(t)$") ax.plot(list_xt, list_sigmas, "bo") + ax.set_xlim([-0.05, 1.05]) + fig.tight_layout() tbx_logger.add_figure("train/samples", fig, global_step=pl_module.global_step) + + +class TensorboardGeneratedSamplesLoggingCallback(TensorBoardDebuggingLoggingCallback): + """This callback will log an image of a histogram of generated samples on tensorboard.""" + + def __init__(self, noise_parameters: NoiseParameters, + number_of_corrector_steps: int, + score_network_parameters: MLPScoreNetworkParameters, number_of_samples: int): + """Init method.""" + super().__init__() + self.noise_parameters = noise_parameters + self.number_of_corrector_steps = number_of_corrector_steps + self.score_network_parameters = score_network_parameters + self.number_of_atoms = score_network_parameters.number_of_atoms + self.spatial_dimension = score_network_parameters.spatial_dimension + self.number_of_samples = number_of_samples + + def log_artifact(self, pl_module, tbx_logger): + """Create artifact and log to tensorboard.""" + sigma_normalized_score_network = pl_module.sigma_normalized_score_network + pc_sampler = AnnealedLangevinDynamicsSampler(noise_parameters=self.noise_parameters, + number_of_corrector_steps=self.number_of_corrector_steps, + number_of_atoms=self.number_of_atoms, + spatial_dimension=self.spatial_dimension, + sigma_normalized_score_network=sigma_normalized_score_network) + + samples = pc_sampler.sample(self.number_of_samples).flatten() + + fig = plt.figure(figsize=TENSORBOARD_FIGSIZE) + ax = fig.add_subplot(111) + ax.set_title(f"Generated Samples: global step = {pl_module.global_step}") + ax.set_xlabel('$x$') + ax.hist(samples, bins=101, range=(0, 1), label=f'{self.number_of_samples} samples') + ax.set_title("Samples Count") + ax.set_xlim([-0.05, 1.05]) + ax.set_ylim([0., self.number_of_samples]) + fig.tight_layout() + tbx_logger.add_figure("train/generated_samples", fig, global_step=pl_module.global_step) + + +class TensorboardScoreAndErrorLoggingCallback(TensorBoardDebuggingLoggingCallback): + """This callback will log plots of the predicted and target normalized scores, and of their errors, on tensorboard.""" + + def __init__(self, x0: float): + """Init method.""" + super().__init__() + self.x0 = x0 + + def log_artifact(self, pl_module, tbx_logger): + """Create artifact and log to tensorboard.""" + fig = plt.figure(figsize=TENSORBOARD_FIGSIZE) + fig.suptitle("Scores within 2 $\\sigma$ of Data") + ax1 = fig.add_subplot(121) + ax2 = fig.add_subplot(122) + ax1.set_ylabel('$\\sigma \\times S_{\\theta}(x, t)$') + ax2.set_ylabel('$\\sigma \\times S_{\\theta}(x, t) - \\sigma \\nabla \\log P(x | 0)$') + for ax in [ax1, ax2]: + ax.set_xlabel('$x$') + + list_x = torch.linspace(0, 1, 1001)[:-1] + + times = torch.tensor([0.25, 0.75, 1.]) + sigmas = pl_module.variance_sampler._create_sigma_array(pl_module.variance_sampler.noise_parameters, times) + + with torch.no_grad(): + for time, sigma in zip(times, sigmas): + times = time * torch.ones(1000).reshape(-1, 1) + + sigma_normalized_kernel = get_sigma_normalized_score(map_positions_to_unit_cell(list_x - self.x0), + sigma * torch.ones_like(list_x), + kmax=4) + predicted_normalized_scores = pl_module._get_predicted_normalized_score(list_x.reshape(-1, 1, 1), + times).flatten() + + error = predicted_normalized_scores - sigma_normalized_kernel + 
# only plot the errors in the sampling region! These regions might be disconnected, let's make + # sure the continuous lines make sense. + mask1 = torch.abs(list_x - self.x0) < 2 * sigma + mask2 = torch.abs(1. - list_x + self.x0) < 2 * sigma + + lines = ax1.plot(list_x[mask1], predicted_normalized_scores[mask1], lw=1, label='Prediction') + color = lines[0].get_color() + ax1.plot(list_x[mask2], predicted_normalized_scores[mask2], lw=1, color=color, label='_none_') + + ax1.plot(list_x[mask1], sigma_normalized_kernel[mask1], '--', lw=2, color=color, label='Target') + ax1.plot(list_x[mask2], sigma_normalized_kernel[mask2], '--', lw=2, color=color, label='_none_') + + ax2.plot(list_x[mask1], error[mask1], '-', color=color, + label=f't = {time:4.3f}, $\\sigma$ = {sigma:4.3f}') + ax2.plot(list_x[mask2], error[mask2], '-', color=color, label='_none_') + + for ax in [ax1, ax2]: + ax.set_xlim([-0.05, 1.05]) + ax.legend(loc=3, prop={'size': 6}) + + ax1.set_ylim([-3., 3.]) + ax2.set_ylim([-1., 1.]) + + fig.tight_layout() + + tbx_logger.add_figure("train/scores", fig, global_step=pl_module.global_step) diff --git a/crystal_diffusion/models/score_network.py b/crystal_diffusion/models/score_network.py index 42cf8156..8fae4e23 100644 --- a/crystal_diffusion/models/score_network.py +++ b/crystal_diffusion/models/score_network.py @@ -17,7 +17,7 @@ class BaseScoreNetworkParameters: spatial_dimension: int = 3 # the dimension of Euclidean space where atoms live. -class BaseScoreNetwork(torch.nn.Module): +class ScoreNetwork(torch.nn.Module): """Base score network. This base class defines the interface that all score networks should have @@ -32,7 +32,7 @@ def __init__(self, hyper_params: BaseScoreNetworkParameters): Args: hyper_params : hyperparameters from the config file. """ - super(BaseScoreNetwork, self).__init__() + super(ScoreNetwork, self).__init__() self._hyper_params = hyper_params self.spatial_dimension = hyper_params.spatial_dimension @@ -124,7 +124,7 @@ class MLPScoreNetworkParameters(BaseScoreNetworkParameters): hidden_dimensions: List[int] # dimensions of the hidden layers. Length of array determines number of layers. -class MLPScoreNetwork(BaseScoreNetwork): +class MLPScoreNetwork(ScoreNetwork): """Simple Model Class. Inherits from the given framework's model class. This is a simple MLP model. 
diff --git a/crystal_diffusion/samplers/predictor_corrector_position_sampler.py b/crystal_diffusion/samplers/predictor_corrector_position_sampler.py new file mode 100644 index 00000000..b549a9a8 --- /dev/null +++ b/crystal_diffusion/samplers/predictor_corrector_position_sampler.py @@ -0,0 +1,196 @@ +from abc import ABC, abstractmethod + +import torch + +from crystal_diffusion.models.score_network import ScoreNetwork +from crystal_diffusion.samplers.noisy_position_sampler import \ map_positions_to_unit_cell +from crystal_diffusion.samplers.variance_sampler import ( + ExplodingVarianceSampler, NoiseParameters) + + +class PredictorCorrectorPositionSampler(ABC): + """This defines the interface for position samplers.""" + + def __init__(self, number_of_discretization_steps: int, number_of_corrector_steps: int, **kwargs): + """Init method.""" + assert number_of_discretization_steps > 0, "The number of discretization steps should be larger than zero" + assert number_of_corrector_steps >= 0, "The number of corrector steps should be non-negative" + + self.number_of_discretization_steps = number_of_discretization_steps + self.number_of_corrector_steps = number_of_corrector_steps + + def sample(self, number_of_samples: int) -> torch.Tensor: + """Sample. + + This method draws a sample using the PC (predictor-corrector) sampler algorithm. + + Args: + number_of_samples : number of samples to draw. + + Returns: + position samples: relative position samples. + """ + x_ip1 = map_positions_to_unit_cell(self.initialize(number_of_samples)) + + for i in range(self.number_of_discretization_steps - 1, -1, -1): + x_i = map_positions_to_unit_cell(self.predictor_step(x_ip1, i + 1)) + for j in range(self.number_of_corrector_steps): + x_i = map_positions_to_unit_cell(self.corrector_step(x_i, i)) + x_ip1 = x_i + + return x_i + + @abstractmethod + def initialize(self, number_of_samples: int): + """This method must initialize the samples from the fully noised distribution.""" + pass + + @abstractmethod + def predictor_step(self, x_ip1: torch.Tensor, ip1: int) -> torch.Tensor: + """Predictor step. + + It is assumed that there are N predictor steps, with index "i" running from N-1 to 0. + + Args: + x_ip1 : sampled relative positions at step "i + 1". + ip1 : index "i + 1" + + Returns: + x_i : sampled relative positions after the predictor step. + """ + pass + + @abstractmethod + def corrector_step(self, x_i: torch.Tensor, i: int) -> torch.Tensor: + """Corrector step. + + It is assumed that there are N predictor steps, with index "i" running from N-1 to 0. + For each value of "i", there are M corrector steps. + Args: + x_i : sampled relative positions at step "i". + i : index "i" OF THE PREDICTOR STEP. + + Returns: + x_i_out : sampled relative positions after the corrector step. + """ + pass + + +class AnnealedLangevinDynamicsSampler(PredictorCorrectorPositionSampler): + """Annealed Langevin Dynamics Sampler. 
+ + This class implements the annealed Langevin Dynamics sampling of + Song & Ermon 2019, namely: + "Generative Modeling by Estimating Gradients of the Data Distribution" + """ + + def __init__(self, + noise_parameters: NoiseParameters, + number_of_corrector_steps: int, + number_of_atoms: int, + spatial_dimension: int, + sigma_normalized_score_network: ScoreNetwork, + ): + """Init method.""" + super().__init__(number_of_discretization_steps=noise_parameters.total_time_steps, + number_of_corrector_steps=number_of_corrector_steps) + self.noise_parameters = noise_parameters + sampler = ExplodingVarianceSampler(noise_parameters) + self.noise, self.langevin_dynamics = sampler.get_all_sampling_parameters() + self.number_of_atoms = number_of_atoms + self.spatial_dimension = spatial_dimension + self.sigma_normalized_score_network = sigma_normalized_score_network + + def initialize(self, number_of_samples: int): + """This method must initialize the samples from the fully noised distribution.""" + return torch.rand(number_of_samples, self.number_of_atoms, self.spatial_dimension) + + def _draw_gaussian_sample(self, number_of_samples): + return torch.randn(number_of_samples, self.number_of_atoms, self.spatial_dimension) + + def _get_sigma_normalized_scores(self, x: torch.Tensor, time: float) -> torch.Tensor: + """Get sigma normalized scores. + + Args: + x : relative positions, of shape [number_of_samples, number_of_atoms, spatial_dimension] + time : time at which to evaluate the score + + Returns: + sigma normalized score: sigma x Score(x, t). + """ + pos_key = self.sigma_normalized_score_network.position_key + time_key = self.sigma_normalized_score_network.timestep_key + + number_of_samples = x.shape[0] + + time_tensor = time * torch.ones(number_of_samples, 1) + augmented_batch = {pos_key: x, time_key: time_tensor} + with torch.no_grad(): + predicted_normalized_scores = self.sigma_normalized_score_network(augmented_batch) + + return predicted_normalized_scores + + def predictor_step(self, x_i: torch.Tensor, index_i: int) -> torch.Tensor: + """Predictor step. + + Args: + x_i : sampled relative positions, at time step i. + index_i : index of the time step. + + Returns: + x_im1 : sampled relative positions, at time step i - 1. + """ + assert 1 <= index_i <= self.number_of_discretization_steps, \ + "The predictor step can only be invoked for index_i between 1 and the total number of discretization steps." + + number_of_samples = x_i.shape[0] + z = self._draw_gaussian_sample(number_of_samples) + + idx = index_i - 1 # python starts indices at zero + t_i = self.noise.time[idx] + g_i = self.noise.g[idx] + g2_i = self.noise.g_squared[idx] + sigma_i = self.noise.sigma[idx] + + sigma_score_i = self._get_sigma_normalized_scores(x_i, t_i) + + x_im1 = x_i + g2_i / sigma_i * sigma_score_i + g_i * z + + return x_im1 + + def corrector_step(self, x_i: torch.Tensor, index_i: int) -> torch.Tensor: + """Corrector Step. + + Args: + x_i : sampled relative positions, at time step i. + index_i : index of the time step. + + Returns: + corrected x_i : sampled relative positions, after corrector step. 
+ """ + assert 0 <= index_i <= self.number_of_discretization_steps - 1, \ + ("The corrector step can only be invoked for index_i between 0 and " + "the total number of discretization steps minus 1.") + + number_of_samples = x_i.shape[0] + z = self._draw_gaussian_sample(number_of_samples) + + # The Langevin dynamics array are indexed with [0,..., N-1] + eps_i = self.langevin_dynamics.epsilon[index_i] + sqrt_2eps_i = self.langevin_dynamics.sqrt_2_epsilon[index_i] + + if index_i == 0: + # TODO: we are extrapolating here; the score network will never have seen this time step... + sigma_i = self.noise_parameters.sigma_min + t_i = 0. + else: + idx = index_i - 1 # python starts indices at zero + sigma_i = self.noise.sigma[idx] + t_i = self.noise.time[idx] + + sigma_score_i = self._get_sigma_normalized_scores(x_i, t_i) + + corrected_x_i = x_i + eps_i / sigma_i * sigma_score_i + sqrt_2eps_i * z + + return corrected_x_i diff --git a/crystal_diffusion/samplers/variance_sampler.py b/crystal_diffusion/samplers/variance_sampler.py index cb449362..7591ad84 100644 --- a/crystal_diffusion/samplers/variance_sampler.py +++ b/crystal_diffusion/samplers/variance_sampler.py @@ -5,13 +5,17 @@ import torch Noise = namedtuple("Noise", ["time", "sigma", "sigma_squared", "g", "g_squared"]) +LangevinDynamics = namedtuple("LangevinDynamics", ["epsilon", "sqrt_2_epsilon"]) @dataclass class NoiseParameters: - """Variance parameters.""" - + """Noise schedule parameters.""" total_time_steps: int + time_delta: float = 1e-5 # the time schedule will cover the range [time_delta, 1] + # As discussed in Appendix C of "SCORE-BASED GENERATIVE MODELING THROUGH STOCHASTIC DIFFERENTIAL EQUATIONS", + # the time t = 0 is problematic. + # Default values come from the paper: # "Torsional Diffusion for Molecular Conformer Generation", # The original values in the paper are @@ -21,15 +25,48 @@ class NoiseParameters: sigma_min: float = 0.005 sigma_max: float = 0.5 + # Default value comes from "Generative Modeling by Estimating Gradients of the Data Distribution" + corrector_step_epsilon: float = 2e-5 + class ExplodingVarianceSampler: """Exploding Variance Sampler. - This class is responsible for creating the all the quantities - needed for noise generation. + This class is responsible for creating all the quantities needed + for noise generation for training and sampling. This implementation will use "exponential diffusion" as discussed in - the paper "Torsional Diffusion for Molecular Conformer Generation". + the following papers (no one paper presents everything clearly) + - [1] "Torsional Diffusion for Molecular Conformer Generation". + - [2] "SCORE-BASED GENERATIVE MODELING THROUGH STOCHASTIC DIFFERENTIAL EQUATIONS" + - [3] "Generative Modeling by Estimating Gradients of the Data Distribution" + + The following quantities are defined: + - total number of times steps, N + + - time steps: + t in [delta, 1], on a discretized grid, t_i for i = 1, ..., N. + We avoid t = 0 because sigma(t) is poorly defined there. See Appendix C of [2]. + + - sigma and sigma^2: + standard deviation, following the "exploding variance scheme", + sigma(t) = sigma_min^(1-t) sigma_max^t. + sigma_i for i = 1, ..., N. + + - g and g^2: + g is the diffusion coefficient that appears in the stochastic differential equation (SDE). + g^2(t) = d sigma^2(t)/ dt for the exploding variance scheme. This becomes discretized as + g^2_i = sigma^2_{i} - sigma^2_{i-1} for i = 1, ..., N. + + --> The papers never clearly state what to do for sigma_{i=0}. 
+ We CHOOSE sigma_{i=0} = sigma_min = sigma(t=0) + + - eps and sqrt_2_eps: + This is for Langevin dynamics within a corrector step. Following [3], we define + + eps_i = 0.5 epsilon_step * sigma^2_i / sigma^2_1 for i = 0, ..., N-1. + + --> Careful! eps_0 is needed for the corrector steps. """ def __init__(self, noise_parameters: NoiseParameters): @@ -39,16 +76,23 @@ def __init__(self, noise_parameters: NoiseParameters): noise_parameters: parameters that define the noise schedule. """ self.noise_parameters = noise_parameters - self._time_array = torch.linspace(0, 1, noise_parameters.total_time_steps) + self._time_array = self._get_time_array(noise_parameters) self._sigma_array = self._create_sigma_array(noise_parameters, self._time_array) self._sigma_squared_array = self._sigma_array**2 - self._g_squared_array = self._create_g_squared_array(self._sigma_squared_array) + self._g_squared_array = self._create_g_squared_array(noise_parameters, self._sigma_squared_array) self._g_array = torch.sqrt(self._g_squared_array) + self._epsilon_array = self._create_epsilon_array(noise_parameters, self._sigma_squared_array) + self._sqrt_two_epsilon_array = torch.sqrt(2. * self._epsilon_array) + self._maximum_random_index = noise_parameters.total_time_steps - 1 - self._minimum_random_index = 1 # we don't want to randomly sample "0". + self._minimum_random_index = 0 + + @staticmethod + def _get_time_array(noise_parameters: NoiseParameters) -> torch.Tensor: + return torch.linspace(noise_parameters.time_delta, 1., noise_parameters.total_time_steps) @staticmethod def _create_sigma_array( @@ -61,17 +105,30 @@ def _create_sigma_array( return sigma @staticmethod - def _create_g_squared_array(sigma_squared_array: torch.Tensor) -> torch.Tensor: - nan_tensor = torch.tensor([float("nan")]) + def _create_g_squared_array(noise_parameters: NoiseParameters, sigma_squared_array: torch.Tensor) -> torch.Tensor: + # g^2_{i} = sigma^2_{i} - sigma^2_{i-1}. For the first element (i=1), we set sigma_{0} = sigma_min. + sigma_min = noise_parameters.sigma_min + zeroth_value_tensor = torch.tensor([sigma_squared_array[0] - sigma_min**2]) + return torch.cat( + [zeroth_value_tensor, sigma_squared_array[1:] - sigma_squared_array[:-1]] + ) + + @staticmethod + def _create_epsilon_array(noise_parameters: NoiseParameters, sigma_squared_array: torch.Tensor) -> torch.Tensor: + + sigma_squared_0 = noise_parameters.sigma_min**2 + sigma_squared_1 = sigma_squared_array[0] + eps = noise_parameters.corrector_step_epsilon + + zeroth_value_tensor = torch.tensor([0.5 * eps * sigma_squared_0 / sigma_squared_1]) return torch.cat( - [nan_tensor, sigma_squared_array[1:] - sigma_squared_array[:-1]] + [zeroth_value_tensor, 0.5 * eps * sigma_squared_array[:-1] / sigma_squared_1] ) def _get_random_time_step_indices(self, shape: Tuple[int]) -> torch.Tensor: """Random time step indices. Generate random indices that correspond to valid time steps. - This sampling avoids index "0", which corresponds to time "0". Args: shape: shape of the random index array. @@ -118,19 +175,24 @@ def get_random_noise_sample(self, batch_size: int) -> Noise: g_squared=gs_squared, ) - def get_all_noise(self) -> Noise: - """Get all noise. + def get_all_sampling_parameters(self) -> Tuple[Noise, LangevinDynamics]: + """Get all sampling parameters. - All the internal noise parameter arrays, passed as a Noise object. + All the internal noise parameter arrays and Langevin dynamics arrays. 
Returns: all_noise: a collection of all the noise parameters (t, sigma, sigma^2, g, g^2) for all indices. The arrays are all of dimension [total_time_steps]. + langevin_dynamics: a collection of all the Langevin dynamics parameters (epsilon, sqrt{2epsilon}) + needed to apply a Langevin dynamics corrector step. """ - return Noise( + noise = Noise( time=self._time_array, sigma=self._sigma_array, sigma_squared=self._sigma_squared_array, g=self._g_array, - g_squared=self._g_squared_array, - ) + g_squared=self._g_squared_array) + langevin_dynamics = LangevinDynamics(epsilon=self._epsilon_array, + sqrt_2_epsilon=self._sqrt_two_epsilon_array) + + return noise, langevin_dynamics diff --git a/sanity_checks/overfit_fake_data.py b/sanity_checks/overfit_fake_data.py index 3754331c..5025f6dc 100644 --- a/sanity_checks/overfit_fake_data.py +++ b/sanity_checks/overfit_fake_data.py @@ -1,14 +1,14 @@ """Overfit fake data. A simple sanity check experiment to check the learning behavior of the position diffusion model. -The training data is taken to be a large batch of identical configurations composed of one atom in 1D at the origin. +The training data is taken to be a large batch of identical configurations composed of one atom in 1D at x0. This highly artificial case is useful to sanity check that the code behaves as expected: - the loss should converge towards zero - the trained score network should reproduce the perturbation kernel, at least in the regions where it is sampled. + - the generated samples should be tightly clustered around x0. """ import os -import matplotlib.pyplot as plt import pytorch_lightning import torch from pytorch_lightning import Trainer @@ -16,33 +16,31 @@ from pytorch_lightning.loggers import TensorBoardLogger from torch.utils.data import DataLoader -from crystal_diffusion import ANALYSIS_RESULTS_DIR -from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH from crystal_diffusion.models.callbacks import ( - HPLoggingCallback, TensorboardHistogramLoggingCallback, - TensorboardSamplesLoggingCallback) + HPLoggingCallback, TensorboardGeneratedSamplesLoggingCallback, + TensorboardHistogramLoggingCallback, TensorboardSamplesLoggingCallback, + TensorboardScoreAndErrorLoggingCallback) from crystal_diffusion.models.optimizer import (OptimizerParameters, ValidOptimizerNames) from crystal_diffusion.models.position_diffusion_lightning_model import ( PositionDiffusionLightningModel, PositionDiffusionParameters) from crystal_diffusion.models.score_network import MLPScoreNetworkParameters from crystal_diffusion.samplers.variance_sampler import NoiseParameters -from crystal_diffusion.score.wrapped_gaussian_score import \ get_sigma_normalized_score from sanity_checks import SANITY_CHECK_FOLDER -plt.style.use(PLOT_STYLE_PATH) - batch_size = 4096 number_of_atoms = 1 spatial_dimension = 1 total_time_steps = 100 +number_of_corrector_steps = 1 + +x0 = 0.5 sigma_min = 0.005 sigma_max = 0.5 -lr = 0.01 -max_epochs = 3000 +lr = 0.001 +max_epochs = 2000 hidden_dimensions = [64, 128, 256] @@ -63,13 +61,21 @@ noise_parameters=noise_parameters, ) +generated_samples_callback = ( + TensorboardGeneratedSamplesLoggingCallback(noise_parameters=noise_parameters, + number_of_corrector_steps=number_of_corrector_steps, + score_network_parameters=score_network_parameters, + number_of_samples=1024)) + +score_error_callback = TensorboardScoreAndErrorLoggingCallback(x0=x0) + tbx_logger = TensorBoardLogger(save_dir=os.path.join(SANITY_CHECK_FOLDER, "tensorboard"), name="overfit_fake_data") if __name__ 
== '__main__': pytorch_lightning.seed_everything(123) - all_positions = torch.zeros(batch_size, number_of_atoms, spatial_dimension) + all_positions = x0 * torch.ones(batch_size, number_of_atoms, spatial_dimension) data = [dict(relative_positions=configuration) for configuration in all_positions] train_dataloader = DataLoader(data, batch_size=batch_size) @@ -78,58 +84,11 @@ trainer = Trainer(accelerator='cpu', max_epochs=max_epochs, logger=tbx_logger, - log_every_n_steps=50, + log_every_n_steps=25, callbacks=[HPLoggingCallback(), + generated_samples_callback, + score_error_callback, TensorboardHistogramLoggingCallback(), TensorboardSamplesLoggingCallback(), LearningRateMonitor(logging_interval='step')]) trainer.fit(lightning_model, train_dataloaders=train_dataloader) - - fig = plt.figure(figsize=PLEASANT_FIG_SIZE) - fig.suptitle("Predictions, Targets and Errors within 2 $\\sigma$ of Data Point") - ax1 = fig.add_subplot(121) - ax2 = fig.add_subplot(122) - ax1.set_ylabel('$\\sigma \\times S_{\\theta}(x, t)$') - ax2.set_ylabel('$\\sigma \\times S_{\\theta}(x, t) - \\sigma \\nabla \\log P(x | 0)$') - ax1.set_xlabel('$x$') - ax2.set_xlabel('$x$') - - list_x = torch.linspace(0, 1, 1001)[:-1] - - times = torch.tensor([0.25, 0.75, 1.]) - sigmas = lightning_model.variance_sampler._create_sigma_array(lightning_model.variance_sampler.noise_parameters, - times) - - with torch.no_grad(): - for time, sigma in zip(times, sigmas): - times = time * torch.ones(1000).reshape(-1, 1) - - sigma_normalized_kernel = get_sigma_normalized_score(list_x, - sigma * torch.ones_like(list_x), - kmax=4) - predicted_normalized_scores = lightning_model._get_predicted_normalized_score(list_x.reshape(-1, 1, 1), - times).flatten() - - error = predicted_normalized_scores - sigma_normalized_kernel - - # only plot the errors in the sampling region! - m1 = list_x < 2 * sigma - m2 = 1. 
- list_x < 2 * sigma - - lines = ax1.plot(list_x[m1], predicted_normalized_scores[m1], lw=1, label='PREDICTION') - color = lines[0].get_color() - ax1.plot(list_x[m2], predicted_normalized_scores[m2], lw=1, color=color, label='_none_') - - ax1.plot(list_x[m1], sigma_normalized_kernel[m1], '--', lw=1, color=color, label='TARGET') - ax1.plot(list_x[m2], sigma_normalized_kernel[m2], '--', lw=1, color=color, label='_none_') - - ax2.plot(list_x[m1], error[m1], '-', color=color, - label=f't = {time:4.3f}, $\\sigma$ = {sigma:4.3f}') - ax2.plot(list_x[m2], error[m2], '-', color=color, label='_none_') - - for ax in [ax1, ax2]: - ax.set_xlim([-0.05, 1.05]) - ax.legend(loc=0) - - fig.tight_layout() - fig.savefig(ANALYSIS_RESULTS_DIR.joinpath("overfit_fake_data_trained_model_errors.png")) diff --git a/tests/models/test_score_network.py b/tests/models/test_score_network.py index dbbbdb07..0fd1e7da 100644 --- a/tests/models/test_score_network.py +++ b/tests/models/test_score_network.py @@ -1,10 +1,10 @@ import pytest import torch -from crystal_diffusion.models.score_network import (BaseScoreNetwork, - BaseScoreNetworkParameters, +from crystal_diffusion.models.score_network import (BaseScoreNetworkParameters, MLPScoreNetwork, - MLPScoreNetworkParameters) + MLPScoreNetworkParameters, + ScoreNetwork) @pytest.mark.parametrize("spatial_dimension", [2, 3]) @@ -16,14 +16,14 @@ def set_random_seed(self): @pytest.fixture() def base_score_network(self, spatial_dimension): - return BaseScoreNetwork(BaseScoreNetworkParameters(spatial_dimension=spatial_dimension)) + return ScoreNetwork(BaseScoreNetworkParameters(spatial_dimension=spatial_dimension)) @pytest.fixture() def good_batch(self, spatial_dimension): batch_size = 16 positions = torch.rand(batch_size, 8, spatial_dimension) times = torch.rand(batch_size, 1) - return {BaseScoreNetwork.position_key: positions, BaseScoreNetwork.timestep_key: times} + return {ScoreNetwork.position_key: positions, ScoreNetwork.timestep_key: times} @pytest.fixture() def bad_batch(self, good_batch, problem): @@ -32,33 +32,33 @@ def bad_batch(self, good_batch, problem): match problem: case "position_name": - bad_batch['bad_position_name'] = bad_batch[BaseScoreNetwork.position_key] - del bad_batch[BaseScoreNetwork.position_key] + bad_batch['bad_position_name'] = bad_batch[ScoreNetwork.position_key] + del bad_batch[ScoreNetwork.position_key] case "position_shape": - shape = bad_batch[BaseScoreNetwork.position_key].shape - bad_batch[BaseScoreNetwork.position_key] = \ - bad_batch[BaseScoreNetwork.position_key].reshape(shape[0], shape[1] // 2, shape[2] * 2) + shape = bad_batch[ScoreNetwork.position_key].shape + bad_batch[ScoreNetwork.position_key] = \ + bad_batch[ScoreNetwork.position_key].reshape(shape[0], shape[1] // 2, shape[2] * 2) case "position_range1": - bad_batch[BaseScoreNetwork.position_key][0, 0, 0] = 1.01 + bad_batch[ScoreNetwork.position_key][0, 0, 0] = 1.01 case "position_range2": - bad_batch[BaseScoreNetwork.position_key][1, 0, 0] = -0.01 + bad_batch[ScoreNetwork.position_key][1, 0, 0] = -0.01 case "time_name": - bad_batch['bad_time_name'] = bad_batch[BaseScoreNetwork.timestep_key] - del bad_batch[BaseScoreNetwork.timestep_key] + bad_batch['bad_time_name'] = bad_batch[ScoreNetwork.timestep_key] + del bad_batch[ScoreNetwork.timestep_key] case "time_shape": shape = bad_batch['time'].shape - bad_batch[BaseScoreNetwork.timestep_key] = ( - bad_batch[BaseScoreNetwork.timestep_key].reshape(shape[0] // 2, shape[1] * 2)) + bad_batch[ScoreNetwork.timestep_key] = ( + 
bad_batch[ScoreNetwork.timestep_key].reshape(shape[0] // 2, shape[1] * 2)) case "time_range1": - bad_batch[BaseScoreNetwork.timestep_key][5, 0] = 2.00 + bad_batch[ScoreNetwork.timestep_key][5, 0] = 2.00 case "time_range2": - bad_batch[BaseScoreNetwork.timestep_key][0, 0] = -0.05 + bad_batch[ScoreNetwork.timestep_key][0, 0] = -0.05 return bad_batch @@ -93,13 +93,13 @@ def expected_score_shape(self, batch_size, number_of_atoms, spatial_dimension): def good_batch(self, batch_size, number_of_atoms, spatial_dimension): positions = torch.rand(batch_size, number_of_atoms, spatial_dimension) times = torch.rand(batch_size, 1) - return {BaseScoreNetwork.position_key: positions, BaseScoreNetwork.timestep_key: times} + return {ScoreNetwork.position_key: positions, ScoreNetwork.timestep_key: times} @pytest.fixture() def bad_batch(self, batch_size, number_of_atoms, spatial_dimension): positions = torch.rand(batch_size, number_of_atoms // 2, spatial_dimension) times = torch.rand(batch_size, 1) - return {BaseScoreNetwork.position_key: positions, BaseScoreNetwork.timestep_key: times} + return {ScoreNetwork.position_key: positions, ScoreNetwork.timestep_key: times} @pytest.fixture() def score_network(self, number_of_atoms, spatial_dimension, hidden_dimensions): diff --git a/tests/samplers/test_predictor_corrector_position_sampler.py b/tests/samplers/test_predictor_corrector_position_sampler.py new file mode 100644 index 00000000..c92ce924 --- /dev/null +++ b/tests/samplers/test_predictor_corrector_position_sampler.py @@ -0,0 +1,194 @@ +import pytest +import torch + +from crystal_diffusion.models.score_network import (MLPScoreNetwork, + MLPScoreNetworkParameters) +from crystal_diffusion.samplers.noisy_position_sampler import \ + map_positions_to_unit_cell +from crystal_diffusion.samplers.predictor_corrector_position_sampler import ( + AnnealedLangevinDynamicsSampler, PredictorCorrectorPositionSampler) +from crystal_diffusion.samplers.variance_sampler import ( + ExplodingVarianceSampler, NoiseParameters) + + +class FakePCSampler(PredictorCorrectorPositionSampler): + """A dummy PC sampler for the purpose of testing.""" + + def __init__( + self, + number_of_discretization_steps: int, + number_of_corrector_steps: int, + initial_sample: torch.Tensor, + ): + super().__init__(number_of_discretization_steps, number_of_corrector_steps) + self.initial_sample = initial_sample + + def initialize(self, number_of_samples: int): + return self.initial_sample + + def predictor_step(self, x_ip1: torch.Tensor, ip1: int) -> torch.Tensor: + return 1.2 * x_ip1 + 3.4 + ip1 / 111.0 + + def corrector_step(self, x_i: torch.Tensor, i: int) -> torch.Tensor: + return 0.56 * x_i + 7.89 + i / 117.0 + + +@pytest.mark.parametrize("number_of_samples", [4]) +@pytest.mark.parametrize("number_of_atoms", [8]) +@pytest.mark.parametrize("spatial_dimension", [2, 3]) +@pytest.mark.parametrize("number_of_discretization_steps", [1, 5, 10]) +@pytest.mark.parametrize("number_of_corrector_steps", [0, 1, 2]) +class TestPredictorCorrectorPositionSampler: + @pytest.fixture(scope="class", autouse=True) + def set_random_seed(self): + torch.manual_seed(1234567) + + @pytest.fixture + def initial_sample(self, number_of_samples, number_of_atoms, spatial_dimension): + return torch.rand(number_of_samples, number_of_atoms, spatial_dimension) + + @pytest.fixture + def sampler( + self, number_of_discretization_steps, number_of_corrector_steps, initial_sample + ): + sampler = FakePCSampler( + number_of_discretization_steps, number_of_corrector_steps, initial_sample + ) 
+ return sampler + + @pytest.fixture + def expected_samples( + self, + sampler, + initial_sample, + number_of_discretization_steps, + number_of_corrector_steps, + ): + list_i = list(range(number_of_discretization_steps)) + list_i.reverse() + list_j = list(range(number_of_corrector_steps)) + + noisy_sample = map_positions_to_unit_cell(initial_sample) + x_ip1 = noisy_sample + for i in list_i: + xi = map_positions_to_unit_cell(sampler.predictor_step(x_ip1, i + 1)) + for _ in list_j: + xi = map_positions_to_unit_cell(sampler.corrector_step(xi, i)) + x_ip1 = xi + return xi + + def test_sample(self, sampler, number_of_samples, expected_samples): + computed_samples = sampler.sample(number_of_samples) + torch.testing.assert_allclose(expected_samples, computed_samples) + + +@pytest.mark.parametrize("total_time_steps", [1, 5, 10]) +@pytest.mark.parametrize("number_of_corrector_steps", [0, 1, 2]) +@pytest.mark.parametrize("spatial_dimension", [2, 3]) +@pytest.mark.parametrize("hidden_dimensions", [[8, 16, 32]]) +@pytest.mark.parametrize("number_of_atoms", [8]) +@pytest.mark.parametrize("time_delta", [0.1]) +@pytest.mark.parametrize("sigma_min", [0.15]) +@pytest.mark.parametrize("corrector_step_epsilon", [0.25]) +@pytest.mark.parametrize("number_of_samples", [8]) +class TestAnnealedLangevinDynamics: + @pytest.fixture() + def sigma_normalized_score_network( + self, number_of_atoms, spatial_dimension, hidden_dimensions + ): + hyper_params = MLPScoreNetworkParameters( + spatial_dimension=spatial_dimension, + number_of_atoms=number_of_atoms, + hidden_dimensions=hidden_dimensions, + ) + return MLPScoreNetwork(hyper_params) + + @pytest.fixture() + def noise_parameters(self, total_time_steps, time_delta, sigma_min, corrector_step_epsilon): + noise_parameters = NoiseParameters(total_time_steps=total_time_steps, + time_delta=time_delta, + sigma_min=sigma_min, + corrector_step_epsilon=corrector_step_epsilon) + return noise_parameters + + @pytest.fixture() + def pc_sampler(self, noise_parameters, + number_of_corrector_steps, + number_of_atoms, + spatial_dimension, + sigma_normalized_score_network): + sampler = AnnealedLangevinDynamicsSampler(noise_parameters=noise_parameters, + number_of_corrector_steps=number_of_corrector_steps, + number_of_atoms=number_of_atoms, + spatial_dimension=spatial_dimension, + sigma_normalized_score_network=sigma_normalized_score_network) + + return sampler + + def test_smoke_sample(self, pc_sampler, number_of_samples): + # Just a smoke test that we can sample without crashing. 
+ pc_sampler.sample(number_of_samples) + + @pytest.fixture() + def x_i(self, number_of_samples, number_of_atoms, spatial_dimension): + return map_positions_to_unit_cell(torch.rand(number_of_samples, number_of_atoms, spatial_dimension)) + + def test_predictor_step(self, mocker, pc_sampler, noise_parameters, x_i, total_time_steps, number_of_samples): + + sampler = ExplodingVarianceSampler(noise_parameters) + noise, _ = sampler.get_all_sampling_parameters() + sigma_min = noise_parameters.sigma_min + list_sigma = noise.sigma + list_time = noise.time + + z = pc_sampler._draw_gaussian_sample(number_of_samples) + mocker.patch.object(pc_sampler, "_draw_gaussian_sample", return_value=z) + + for index_i in range(1, total_time_steps + 1): + computed_sample = pc_sampler.predictor_step(x_i, index_i) + + sigma_i = list_sigma[index_i - 1] + t_i = list_time[index_i - 1] + if index_i == 1: + sigma_im1 = sigma_min + else: + sigma_im1 = list_sigma[index_i - 2] + + g2 = sigma_i**2 - sigma_im1**2 + + s_i = pc_sampler._get_sigma_normalized_scores(x_i, t_i) / sigma_i + + expected_sample = x_i + g2 * s_i + torch.sqrt(g2) * z + + torch.testing.assert_allclose(computed_sample, expected_sample) + + def test_corrector_step(self, mocker, pc_sampler, noise_parameters, x_i, total_time_steps, number_of_samples): + + sampler = ExplodingVarianceSampler(noise_parameters) + noise, _ = sampler.get_all_sampling_parameters() + sigma_min = noise_parameters.sigma_min + epsilon = noise_parameters.corrector_step_epsilon + list_sigma = noise.sigma + list_time = noise.time + sigma_1 = list_sigma[0] + + z = pc_sampler._draw_gaussian_sample(number_of_samples) + mocker.patch.object(pc_sampler, "_draw_gaussian_sample", return_value=z) + + for index_i in range(0, total_time_steps): + computed_sample = pc_sampler.corrector_step(x_i, index_i) + + if index_i == 0: + sigma_i = sigma_min + t_i = 0. + else: + sigma_i = list_sigma[index_i - 1] + t_i = list_time[index_i - 1] + + eps_i = 0.5 * epsilon * sigma_i**2 / sigma_1**2 + + s_i = pc_sampler._get_sigma_normalized_scores(x_i, t_i) / sigma_i + + expected_sample = x_i + eps_i * s_i + torch.sqrt(2. * eps_i) * z + + torch.testing.assert_allclose(computed_sample, expected_sample) diff --git a/tests/samplers/test_variance_sampler.py b/tests/samplers/test_variance_sampler.py index 0f997ef3..64fd04d3 100644 --- a/tests/samplers/test_variance_sampler.py +++ b/tests/samplers/test_variance_sampler.py @@ -6,18 +6,28 @@ @pytest.mark.parametrize("total_time_steps", [3, 10, 17]) +@pytest.mark.parametrize("time_delta", [1e-5, 0.1]) +@pytest.mark.parametrize("sigma_min", [0.005, 0.1]) +@pytest.mark.parametrize("corrector_step_epsilon", [2e-5, 0.1]) class TestExplodingVarianceSampler: @pytest.fixture() - def noise_parameters(self, total_time_steps): - return NoiseParameters(total_time_steps=total_time_steps) + def noise_parameters(self, total_time_steps, time_delta, sigma_min, corrector_step_epsilon): + return NoiseParameters(total_time_steps=total_time_steps, + time_delta=time_delta, + sigma_min=sigma_min, + corrector_step_epsilon=corrector_step_epsilon) @pytest.fixture() def variance_sampler(self, noise_parameters): return ExplodingVarianceSampler(noise_parameters=noise_parameters) @pytest.fixture() - def expected_times(self, total_time_steps): - times = torch.linspace(0.0, 1.0, total_time_steps) + def expected_times(self, total_time_steps, time_delta): + times = [] + for i in range(total_time_steps): + t = i / (total_time_steps - 1) * (1. 
- time_delta) + time_delta + times.append(t) + times = torch.tensor(times) return times @pytest.fixture() @@ -28,21 +38,37 @@ def expected_sigmas(self, expected_times, noise_parameters): sigmas = smin ** (1.0 - expected_times) * smax**expected_times return sigmas + @pytest.fixture() + def expected_epsilons(self, expected_sigmas, noise_parameters): + smin = noise_parameters.sigma_min + eps = noise_parameters.corrector_step_epsilon + + s1 = expected_sigmas[0] + + epsilons = [0.5 * eps * smin**2 / s1**2] + for i in range(len(expected_sigmas) - 1): + si = expected_sigmas[i] + epsilons.append(0.5 * eps * si**2 / s1**2) + + return torch.tensor(epsilons) + @pytest.fixture() def indices(self, time_sampler, shape): return time_sampler.get_random_time_step_indices(shape) def test_time_array(self, variance_sampler, expected_times): - torch.testing.assert_allclose(variance_sampler._time_array, expected_times) + torch.testing.assert_close(variance_sampler._time_array, expected_times) def test_sigma_and_sigma_squared_arrays(self, variance_sampler, expected_sigmas): torch.testing.assert_allclose(variance_sampler._sigma_array, expected_sigmas) torch.testing.assert_allclose(variance_sampler._sigma_squared_array, expected_sigmas**2) - def test_g_and_g_square_array(self, variance_sampler, expected_sigmas): + def test_g_and_g_square_array(self, variance_sampler, expected_sigmas, sigma_min): expected_sigmas_square = expected_sigmas**2 - expected_g_squared_array = [float("nan")] + sigma1 = torch.sqrt(expected_sigmas_square[0]) + + expected_g_squared_array = [sigma1**2 - sigma_min**2] for sigma2_t, sigma2_tm1 in zip( expected_sigmas_square[1:], expected_sigmas_square[:-1] ): @@ -52,15 +78,16 @@ def test_g_and_g_square_array(self, variance_sampler, expected_sigmas): expected_g_squared_array = torch.tensor(expected_g_squared_array) expected_g_array = torch.sqrt(expected_g_squared_array) - assert torch.isnan(variance_sampler._g_array[0]) - assert torch.isnan(variance_sampler._g_squared_array[0]) - torch.testing.assert_allclose(variance_sampler._g_array[1:], expected_g_array[1:]) - torch.testing.assert_allclose(variance_sampler._g_squared_array[1:], expected_g_squared_array[1:]) + torch.testing.assert_allclose(variance_sampler._g_array, expected_g_array) + torch.testing.assert_allclose(variance_sampler._g_squared_array, expected_g_squared_array) + + def test_epsilon_arrays(self, variance_sampler, expected_epsilons): + torch.testing.assert_allclose(variance_sampler._epsilon_array, expected_epsilons) + torch.testing.assert_allclose(variance_sampler._sqrt_two_epsilon_array, torch.sqrt(2. * expected_epsilons)) def test_get_random_time_step_indices(self, variance_sampler, total_time_steps): - # Check that we never sample zero. 
random_indices = variance_sampler._get_random_time_step_indices(shape=(1000,)) - assert torch.all(random_indices > 0) + assert torch.all(random_indices >= 0) assert torch.all(random_indices < total_time_steps) @pytest.mark.parametrize("batch_size", [1, 10, 100]) @@ -90,12 +117,13 @@ def test_get_random_noise_parameter_sample( torch.testing.assert_allclose(noise_sample.g, expected_gs) torch.testing.assert_allclose(noise_sample.g_squared, expected_gs_squared) - def test_get_all_noise(self, variance_sampler): - noise = variance_sampler.get_all_noise() + def test_get_all_sampling_parameters(self, variance_sampler): + noise, langevin_dynamics = variance_sampler.get_all_sampling_parameters() torch.testing.assert_allclose(noise.time, variance_sampler._time_array) torch.testing.assert_allclose(noise.sigma, variance_sampler._sigma_array) torch.testing.assert_allclose(noise.sigma_squared, variance_sampler._sigma_squared_array) - assert torch.isnan(noise.g[0]) - assert torch.isnan(noise.g_squared[0]) - torch.testing.assert_allclose(noise.g[1:], variance_sampler._g_array[1:]) - torch.testing.assert_allclose(noise.g_squared[1:], variance_sampler._g_squared_array[1:]) + torch.testing.assert_allclose(noise.g, variance_sampler._g_array) + torch.testing.assert_allclose(noise.g_squared, variance_sampler._g_squared_array) + + torch.testing.assert_allclose(langevin_dynamics.epsilon, variance_sampler._epsilon_array) + torch.testing.assert_allclose(langevin_dynamics.sqrt_2_epsilon, variance_sampler._sqrt_two_epsilon_array)
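For reviewers, a minimal usage sketch of the sampling API introduced in this diff. It only wires together classes that appear above (NoiseParameters, ExplodingVarianceSampler, MLPScoreNetwork, AnnealedLangevinDynamicsSampler); the hyperparameter values are illustrative assumptions, not values prescribed by the patch.

from crystal_diffusion.models.score_network import MLPScoreNetwork, MLPScoreNetworkParameters
from crystal_diffusion.samplers.predictor_corrector_position_sampler import AnnealedLangevinDynamicsSampler
from crystal_diffusion.samplers.variance_sampler import ExplodingVarianceSampler, NoiseParameters

# Noise schedule shared by training and sampling (illustrative values).
noise_parameters = NoiseParameters(total_time_steps=100, sigma_min=0.005, sigma_max=0.5)

# The new API returns both the noise schedule and the Langevin dynamics corrector parameters.
noise, langevin_dynamics = ExplodingVarianceSampler(noise_parameters=noise_parameters).get_all_sampling_parameters()

# A small, untrained score network, just to exercise the sampler end to end.
score_network_parameters = MLPScoreNetworkParameters(number_of_atoms=1, spatial_dimension=1, hidden_dimensions=[16, 32])
sigma_normalized_score_network = MLPScoreNetwork(score_network_parameters)

pc_sampler = AnnealedLangevinDynamicsSampler(noise_parameters=noise_parameters,
                                             number_of_corrector_steps=1,
                                             number_of_atoms=1,
                                             spatial_dimension=1,
                                             sigma_normalized_score_network=sigma_normalized_score_network)

# Relative position samples of shape [number_of_samples, number_of_atoms, spatial_dimension], mapped to the unit cell.
samples = pc_sampler.sample(number_of_samples=16)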