Merge pull request #12 from mila-iqia/sampling
Sampling
rousseab authored Apr 2, 2024
2 parents 46dd060 + 51839ad · commit b2b94d4
Showing 9 changed files with 696 additions and 144 deletions.
36 changes: 17 additions & 19 deletions crystal_diffusion/analysis/exploding_variance_analysis.py
@@ -20,27 +20,25 @@
 noise_parameters = NoiseParameters(total_time_steps=1000)
 variance_sampler = ExplodingVarianceSampler(noise_parameters=noise_parameters)

-noise = variance_sampler.get_all_noise()
+noise, langevin_dynamics = variance_sampler.get_all_sampling_parameters()

 fig1 = plt.figure(figsize=PLEASANT_FIG_SIZE)
 fig1.suptitle("Noise Schedule")

-ax1 = fig1.add_subplot(221)
-ax2 = fig1.add_subplot(223)
-ax3 = fig1.add_subplot(122)
+ax1 = fig1.add_subplot(121)
+ax2 = fig1.add_subplot(122)

-ax1.plot(noise.time, noise.sigma, '-', c='k', lw=2)
-ax2.plot(noise.time[1:], noise.g[1:], '-', c='k', lw=2)
+ax1.plot(noise.time, noise.sigma, '-', c='b', lw=4, label='$\\sigma(t)$')
+ax1.plot(noise.time, noise.g, '-', c='g', lw=4, label="$g(t)$")

-ax1.set_ylabel('$\\sigma(t)$')
-ax2.set_ylabel('$g(t)$')
+shifted_time = torch.cat([torch.tensor([0]), noise.time[:-1]])
+ax1.plot(shifted_time, langevin_dynamics.epsilon, '-', c='r', lw=4, label="$\\epsilon(t)$")
+ax1.legend(loc=0)

-for ax in [ax1, ax2]:
-    ax.set_xlabel('time')
-    ax.set_xlim([-0.01, 1.01])
+ax1.set_xlabel('time')
+ax1.set_xlim([-0.01, 1.01])

-ax1.set_title("$\\sigma$ schedule")
-ax2.set_title("g schedule")
+ax1.set_title("$\\sigma, g, \\epsilon$ schedules")

 relative_positions = torch.linspace(0, 1, 101)[:-1]

@@ -55,13 +53,13 @@
     target_sigma_normalized_scores = get_sigma_normalized_score(relative_positions,
                                                                 torch.ones_like(relative_positions) * sigma,
                                                                 kmax=kmax)
-    ax3.plot(relative_positions, target_sigma_normalized_scores, label=f"t = {t:3.2f}")
+    ax2.plot(relative_positions, target_sigma_normalized_scores, label=f"t = {t:3.2f}")

-ax3.set_title("Target Normalized Score")
-ax3.set_xlabel("relative position, u")
-ax3.set_ylabel("$\\sigma(t) \\times S(u, t)$")
-ax3.legend(loc=0)
-ax3.set_xlim([-0.01, 1.01])
+ax2.set_title("Target Normalized Score")
+ax2.set_xlabel("relative position, u")
+ax2.set_ylabel("$\\sigma(t) \\times S(u, t)$")
+ax2.legend(loc=0)
+ax2.set_xlim([-0.01, 1.01])

 fig1.tight_layout()
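
For reviewers who want to try the new API, here is a minimal sketch of the changed calling convention. It assumes, as the script above suggests, that the returned noise schedule exposes `time`, `sigma` and `g` fields, that the new Langevin-dynamics tuple exposes an `epsilon` field, and that `ExplodingVarianceSampler` lives in the same module as `NoiseParameters`; the actual return types are defined in the sampler code, which is not part of this diff.

from crystal_diffusion.samplers.variance_sampler import (
    ExplodingVarianceSampler, NoiseParameters)

# Build the sampler exactly as the analysis script above does.
noise_parameters = NoiseParameters(total_time_steps=1000)
variance_sampler = ExplodingVarianceSampler(noise_parameters=noise_parameters)

# Old call, removed by this PR:
#     noise = variance_sampler.get_all_noise()
# New call returns the noise schedule and the Langevin corrector parameters together:
noise, langevin_dynamics = variance_sampler.get_all_sampling_parameters()

print(noise.time.shape, noise.sigma.shape, noise.g.shape)  # one entry per time step
print(langevin_dynamics.epsilon.shape)  # corrector step sizes, plotted above against shifted time

If the sampler follows annealed Langevin dynamics as in Song & Ermon (2019), the corrector step size would typically be $\epsilon_i = \epsilon \, \sigma_i^2 / \sigma_L^2$ with $\sigma_L$ the smallest noise scale, which would explain plotting $\epsilon(t)$ on the same axis as $\sigma(t)$ and $g(t)$; the exact definition, however, sits in the sampler implementation rather than in this file.
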
119 changes: 117 additions & 2 deletions crystal_diffusion/models/callbacks.py
@@ -5,7 +5,18 @@
 from pytorch_lightning import Callback
 from pytorch_lightning.loggers import TensorBoardLogger

-from crystal_diffusion.analysis import PLEASANT_FIG_SIZE
+from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH
+from crystal_diffusion.models.score_network import MLPScoreNetworkParameters
+from crystal_diffusion.samplers.noisy_position_sampler import \
+    map_positions_to_unit_cell
+from crystal_diffusion.samplers.predictor_corrector_position_sampler import \
+    AnnealedLangevinDynamicsSampler
+from crystal_diffusion.samplers.variance_sampler import NoiseParameters
+from crystal_diffusion.score.wrapped_gaussian_score import \
+    get_sigma_normalized_score
+
+plt.style.use(PLOT_STYLE_PATH)
+TENSORBOARD_FIGSIZE = (0.65 * PLEASANT_FIG_SIZE[0], 0.65 * PLEASANT_FIG_SIZE[1])


 class HPLoggingCallback(Callback):
@@ -94,10 +105,114 @@ def log_artifact(self, pl_module, tbx_logger):
             list_sigmas.append(output["sigmas"].flatten())
         list_xt = torch.cat(list_xt)
         list_sigmas = torch.cat(list_sigmas)
-        fig = plt.figure(figsize=PLEASANT_FIG_SIZE)
+        fig = plt.figure(figsize=TENSORBOARD_FIGSIZE)
         ax = fig.add_subplot(111)
         ax.set_title(f"Position Samples: global step = {pl_module.global_step}")
         ax.set_ylabel("$\\sigma$")
         ax.set_xlabel("position samples $x(t)$")
         ax.plot(list_xt, list_sigmas, "bo")
         ax.set_xlim([-0.05, 1.05])
         fig.tight_layout()
         tbx_logger.add_figure("train/samples", fig, global_step=pl_module.global_step)
+
+
+class TensorboardGeneratedSamplesLoggingCallback(TensorBoardDebuggingLoggingCallback):
+    """This callback will log an image of a histogram of generated samples on tensorboard."""
+
+    def __init__(self, noise_parameters: NoiseParameters,
+                 number_of_corrector_steps: int,
+                 score_network_parameters: MLPScoreNetworkParameters, number_of_samples: int):
+        """Init method."""
+        super().__init__()
+        self.noise_parameters = noise_parameters
+        self.number_of_corrector_steps = number_of_corrector_steps
+        self.score_network_parameters = score_network_parameters
+        self.number_of_atoms = score_network_parameters.number_of_atoms
+        self.spatial_dimension = score_network_parameters.spatial_dimension
+        self.number_of_samples = number_of_samples
+
+    def log_artifact(self, pl_module, tbx_logger):
+        """Create artifact and log to tensorboard."""
+        sigma_normalized_score_network = pl_module.sigma_normalized_score_network
+        pc_sampler = AnnealedLangevinDynamicsSampler(noise_parameters=self.noise_parameters,
+                                                     number_of_corrector_steps=self.number_of_corrector_steps,
+                                                     number_of_atoms=self.number_of_atoms,
+                                                     spatial_dimension=self.spatial_dimension,
+                                                     sigma_normalized_score_network=sigma_normalized_score_network)
+
+        samples = pc_sampler.sample(self.number_of_samples).flatten()
+
+        fig = plt.figure(figsize=TENSORBOARD_FIGSIZE)
+        ax = fig.add_subplot(111)
+        ax.set_title(f"Generated Samples: global step = {pl_module.global_step}")
+        ax.set_xlabel('$x$')
+        ax.hist(samples, bins=101, range=(0, 1), label=f'{self.number_of_samples} samples')
+        ax.set_title("Samples Count")
+        ax.set_xlim([-0.05, 1.05])
+        ax.set_ylim([0., self.number_of_samples])
+        fig.tight_layout()
+        tbx_logger.add_figure("train/generated_samples", fig, global_step=pl_module.global_step)
+
+
+class TensorboardScoreAndErrorLoggingCallback(TensorBoardDebuggingLoggingCallback):
+    """This callback will log histograms of the labels, predictions and errors on tensorboard."""
+
+    def __init__(self, x0: float):
+        """Init method."""
+        super().__init__()
+        self.x0 = x0
+
+    def log_artifact(self, pl_module, tbx_logger):
+        """Create artifact and log to tensorboard."""
+        fig = plt.figure(figsize=TENSORBOARD_FIGSIZE)
+        fig.suptitle("Scores within 2 $\\sigma$ of Data")
+        ax1 = fig.add_subplot(121)
+        ax2 = fig.add_subplot(122)
+        ax1.set_ylabel('$\\sigma \\times S_{\\theta}(x, t)$')
+        ax2.set_ylabel('$\\sigma \\times S_{\\theta}(x, t) - \\sigma \\nabla \\log P(x | 0)$')
+        for ax in [ax1, ax2]:
+            ax.set_xlabel('$x$')
+
+        list_x = torch.linspace(0, 1, 1001)[:-1]
+
+        times = torch.tensor([0.25, 0.75, 1.])
+        sigmas = pl_module.variance_sampler._create_sigma_array(pl_module.variance_sampler.noise_parameters, times)
+
+        with torch.no_grad():
+            for time, sigma in zip(times, sigmas):
+                times = time * torch.ones(1000).reshape(-1, 1)
+
+                sigma_normalized_kernel = get_sigma_normalized_score(map_positions_to_unit_cell(list_x - self.x0),
+                                                                     sigma * torch.ones_like(list_x),
+                                                                     kmax=4)
+                predicted_normalized_scores = pl_module._get_predicted_normalized_score(list_x.reshape(-1, 1, 1),
+                                                                                        times).flatten()
+
+                error = predicted_normalized_scores - sigma_normalized_kernel
+
+                # only plot the errors in the sampling region! These regions might be disconnected, let's make
+                # sure the continuous lines make sense.
+                mask1 = torch.abs(list_x - self.x0) < 2 * sigma
+                mask2 = torch.abs(1. - list_x + self.x0) < 2 * sigma
+
+                lines = ax1.plot(list_x[mask1], predicted_normalized_scores[mask1], lw=1, label='Prediction')
+                color = lines[0].get_color()
+                ax1.plot(list_x[mask2], predicted_normalized_scores[mask2], lw=1, color=color, label='_none_')
+
+                ax1.plot(list_x[mask1], sigma_normalized_kernel[mask1], '--', lw=2, color=color, label='Target')
+                ax1.plot(list_x[mask2], sigma_normalized_kernel[mask2], '--', lw=2, color=color, label='_none_')
+
+                ax2.plot(list_x[mask1], error[mask1], '-', color=color,
+                         label=f't = {time:4.3f}, $\\sigma$ = {sigma:4.3f}')
+                ax2.plot(list_x[mask2], error[mask2], '-', color=color, label='_none_')
+
+        for ax in [ax1, ax2]:
+            ax.set_xlim([-0.05, 1.05])
+            ax.legend(loc=3, prop={'size': 6})
+
+        ax1.set_ylim([-3., 3.])
+        ax2.set_ylim([-1., 1.])
+
+        fig.tight_layout()
+
+        tbx_logger.add_figure("train/scores", fig, global_step=pl_module.global_step)
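
Two notes for reviewers. First, on the target used in TensorboardScoreAndErrorLoggingCallback: judging from the module name wrapped_gaussian_score and the y-axis label above, the reference is presumably the sigma-normalized score of a Gaussian wrapped onto the periodic unit cell, with the lattice sum truncated at the kmax argument (kmax=4 above):

$$P(x \mid x_0; \sigma) = \sum_{k=-k_{\max}}^{k_{\max}} \frac{1}{\sqrt{2\pi}\,\sigma} \exp\!\left(-\frac{(x - x_0 + k)^2}{2\sigma^2}\right), \qquad \text{target} = \sigma \, \nabla_x \log P(x \mid x_0; \sigma).$$

The 2$\sigma$ masks then select the neighbourhood of $x_0$ on both sides of the periodic boundary, which is why two line segments are drawn per curve.

Second, a sketch of how the new callbacks might be wired into training. The constructor arguments mirror the signatures above; the concrete values, and any MLPScoreNetworkParameters fields not read in this diff, are placeholders:

import pytorch_lightning as pl

from crystal_diffusion.models.callbacks import (
    TensorboardGeneratedSamplesLoggingCallback,
    TensorboardScoreAndErrorLoggingCallback)
from crystal_diffusion.models.score_network import MLPScoreNetworkParameters
from crystal_diffusion.samplers.variance_sampler import NoiseParameters

# Placeholder hyperparameters; real values come from the experiment config.
score_network_parameters = MLPScoreNetworkParameters(number_of_atoms=2,
                                                     hidden_dimensions=[64, 64])

callbacks = [
    TensorboardGeneratedSamplesLoggingCallback(
        noise_parameters=NoiseParameters(total_time_steps=1000),
        number_of_corrector_steps=1,
        score_network_parameters=score_network_parameters,
        number_of_samples=1000),
    TensorboardScoreAndErrorLoggingCallback(x0=0.5),  # x0: the data point whose score is inspected
]

trainer = pl.Trainer(callbacks=callbacks)  # then: trainer.fit(pl_module, datamodule)
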
6 changes: 3 additions & 3 deletions crystal_diffusion/models/score_network.py
@@ -17,7 +17,7 @@ class BaseScoreNetworkParameters:
     spatial_dimension: int = 3  # the dimension of Euclidean space where atoms live.


-class BaseScoreNetwork(torch.nn.Module):
+class ScoreNetwork(torch.nn.Module):
     """Base score network.

     This base class defines the interface that all score networks should have
@@ -32,7 +32,7 @@ def __init__(self, hyper_params: BaseScoreNetworkParameters):
         Args:
             hyper_params : hyperparameters from the config file.
         """
-        super(BaseScoreNetwork, self).__init__()
+        super(ScoreNetwork, self).__init__()
         self._hyper_params = hyper_params
         self.spatial_dimension = hyper_params.spatial_dimension

@@ -124,7 +124,7 @@ class MLPScoreNetworkParameters(BaseScoreNetworkParameters):
     hidden_dimensions: List[int]  # dimensions of the hidden layers. Length of array determines number of layers.


-class MLPScoreNetwork(BaseScoreNetwork):
+class MLPScoreNetwork(ScoreNetwork):
     """Simple Model Class.

     Inherits from the given framework's model class. This is a simple MLP model.
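
The BaseScoreNetwork to ScoreNetwork rename is API-breaking for any downstream code that imported or subclassed the old name. A minimal migration sketch, assuming only the constructor visible in this diff (the base class's other methods are not shown here):

from crystal_diffusion.models.score_network import (BaseScoreNetworkParameters,
                                                    ScoreNetwork)


class MyScoreNetwork(ScoreNetwork):  # before this PR: MyScoreNetwork(BaseScoreNetwork)
    """Hypothetical subclass illustrating the rename."""

    def __init__(self, hyper_params: BaseScoreNetworkParameters):
        super(MyScoreNetwork, self).__init__(hyper_params)

Note that the parameter dataclasses keep their old names (BaseScoreNetworkParameters, MLPScoreNetworkParameters); only the network classes were renamed.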