Sampling #12

Merged (22 commits, Apr 2, 2024)
crystal_diffusion/analysis/exploding_variance_analysis.py (17 additions, 19 deletions)

@@ -20,27 +20,25 @@
 noise_parameters = NoiseParameters(total_time_steps=1000)
 variance_sampler = ExplodingVarianceSampler(noise_parameters=noise_parameters)

-noise = variance_sampler.get_all_noise()
+noise, langevin_dynamics = variance_sampler.get_all_sampling_parameters()

 fig1 = plt.figure(figsize=PLEASANT_FIG_SIZE)
 fig1.suptitle("Noise Schedule")

-ax1 = fig1.add_subplot(221)
-ax2 = fig1.add_subplot(223)
-ax3 = fig1.add_subplot(122)
+ax1 = fig1.add_subplot(121)
+ax2 = fig1.add_subplot(122)

-ax1.plot(noise.time, noise.sigma, '-', c='k', lw=2)
-ax2.plot(noise.time[1:], noise.g[1:], '-', c='k', lw=2)
+ax1.plot(noise.time, noise.sigma, '-', c='b', lw=4, label='$\\sigma(t)$')
+ax1.plot(noise.time, noise.g, '-', c='g', lw=4, label="$g(t)$")

-ax1.set_ylabel('$\\sigma(t)$')
-ax2.set_ylabel('$g(t)$')
+shifted_time = torch.cat([torch.tensor([0]), noise.time[:-1]])
+ax1.plot(shifted_time, langevin_dynamics.epsilon, '-', c='r', lw=4, label="$\\epsilon(t)$")
+ax1.legend(loc=0)

-for ax in [ax1, ax2]:
-    ax.set_xlabel('time')
-    ax.set_xlim([-0.01, 1.01])
+ax1.set_xlabel('time')
+ax1.set_xlim([-0.01, 1.01])

-ax1.set_title("$\\sigma$ schedule")
-ax2.set_title("g schedule")
+ax1.set_title("$\\sigma, g, \\epsilon$ schedules")

 relative_positions = torch.linspace(0, 1, 101)[:-1]

@@ -55,13 +53,13 @@
     target_sigma_normalized_scores = get_sigma_normalized_score(relative_positions,
                                                                 torch.ones_like(relative_positions) * sigma,
                                                                 kmax=kmax)
-    ax3.plot(relative_positions, target_sigma_normalized_scores, label=f"t = {t:3.2f}")
+    ax2.plot(relative_positions, target_sigma_normalized_scores, label=f"t = {t:3.2f}")

-ax3.set_title("Target Normalized Score")
-ax3.set_xlabel("relative position, u")
-ax3.set_ylabel("$\\sigma(t) \\times S(u, t)$")
-ax3.legend(loc=0)
-ax3.set_xlim([-0.01, 1.01])
+ax2.set_title("Target Normalized Score")
+ax2.set_xlabel("relative position, u")
+ax2.set_ylabel("$\\sigma(t) \\times S(u, t)$")
+ax2.legend(loc=0)
+ax2.set_xlim([-0.01, 1.01])

 fig1.tight_layout()

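For reviewers who want to poke at the new sampler API by hand, a minimal sketch follows. The get_all_sampling_parameters() call and the noise.time / noise.sigma / noise.g / langevin_dynamics.epsilon attributes are all taken from the diff above; the import path for ExplodingVarianceSampler is an assumption (it is presumed to sit next to NoiseParameters in crystal_diffusion.samplers.variance_sampler).

    from crystal_diffusion.samplers.variance_sampler import (ExplodingVarianceSampler,
                                                             NoiseParameters)

    # A short schedule keeps the printout readable; the analysis script uses 1000 steps.
    noise_parameters = NoiseParameters(total_time_steps=10)
    variance_sampler = ExplodingVarianceSampler(noise_parameters=noise_parameters)

    # One call now returns both the noise schedule and the Langevin dynamics
    # parameters, replacing get_all_noise(), which returned only the former.
    noise, langevin_dynamics = variance_sampler.get_all_sampling_parameters()

    print(noise.time.shape, noise.sigma.shape, noise.g.shape)  # quantities plotted above
    print(langevin_dynamics.epsilon.shape)                     # the new epsilon(t) schedule
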
crystal_diffusion/models/callbacks.py (117 additions, 2 deletions)

@@ -5,7 +5,18 @@
 from pytorch_lightning import Callback
 from pytorch_lightning.loggers import TensorBoardLogger

-from crystal_diffusion.analysis import PLEASANT_FIG_SIZE
+from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH
+from crystal_diffusion.models.score_network import MLPScoreNetworkParameters
+from crystal_diffusion.samplers.noisy_position_sampler import \
+    map_positions_to_unit_cell
+from crystal_diffusion.samplers.predictor_corrector_position_sampler import \
+    AnnealedLangevinDynamicsSampler
+from crystal_diffusion.samplers.variance_sampler import NoiseParameters
+from crystal_diffusion.score.wrapped_gaussian_score import \
+    get_sigma_normalized_score
+
+plt.style.use(PLOT_STYLE_PATH)
+TENSORBOARD_FIGSIZE = (0.65 * PLEASANT_FIG_SIZE[0], 0.65 * PLEASANT_FIG_SIZE[1])


 class HPLoggingCallback(Callback):
@@ -94,10 +105,114 @@ def log_artifact(self, pl_module, tbx_logger):
             list_sigmas.append(output["sigmas"].flatten())
         list_xt = torch.cat(list_xt)
         list_sigmas = torch.cat(list_sigmas)
-        fig = plt.figure(figsize=PLEASANT_FIG_SIZE)
+        fig = plt.figure(figsize=TENSORBOARD_FIGSIZE)
         ax = fig.add_subplot(111)
         ax.set_title(f"Position Samples: global step = {pl_module.global_step}")
         ax.set_ylabel("$\\sigma$")
         ax.set_xlabel("position samples $x(t)$")
         ax.plot(list_xt, list_sigmas, "bo")
         ax.set_xlim([-0.05, 1.05])
         fig.tight_layout()
         tbx_logger.add_figure("train/samples", fig, global_step=pl_module.global_step)


+class TensorboardGeneratedSamplesLoggingCallback(TensorBoardDebuggingLoggingCallback):
+    """This callback will log a histogram of generated samples on tensorboard."""
+
+    def __init__(self, noise_parameters: NoiseParameters,
+                 number_of_corrector_steps: int,
+                 score_network_parameters: MLPScoreNetworkParameters, number_of_samples: int):
+        """Init method."""
+        super().__init__()
+        self.noise_parameters = noise_parameters
+        self.number_of_corrector_steps = number_of_corrector_steps
+        self.score_network_parameters = score_network_parameters
+        self.number_of_atoms = score_network_parameters.number_of_atoms
+        self.spatial_dimension = score_network_parameters.spatial_dimension
+        self.number_of_samples = number_of_samples
+
+    def log_artifact(self, pl_module, tbx_logger):
+        """Create artifact and log to tensorboard."""
+        sigma_normalized_score_network = pl_module.sigma_normalized_score_network
+        pc_sampler = AnnealedLangevinDynamicsSampler(noise_parameters=self.noise_parameters,
+                                                     number_of_corrector_steps=self.number_of_corrector_steps,
+                                                     number_of_atoms=self.number_of_atoms,
+                                                     spatial_dimension=self.spatial_dimension,
+                                                     sigma_normalized_score_network=sigma_normalized_score_network)
+
+        samples = pc_sampler.sample(self.number_of_samples).flatten()
+
+        fig = plt.figure(figsize=TENSORBOARD_FIGSIZE)
+        ax = fig.add_subplot(111)
+        ax.set_title(f"Generated Samples: global step = {pl_module.global_step}")
+        ax.set_xlabel('$x$')
+        ax.hist(samples, bins=101, range=(0, 1), label=f'{self.number_of_samples} samples')
+        ax.set_ylabel("Samples Count")
+        ax.set_xlim([-0.05, 1.05])
+        ax.set_ylim([0., self.number_of_samples])
+        fig.tight_layout()
+        tbx_logger.add_figure("train/generated_samples", fig, global_step=pl_module.global_step)

+class TensorboardScoreAndErrorLoggingCallback(TensorBoardDebuggingLoggingCallback):
+    """This callback will log histograms of the labels, predictions and errors on tensorboard."""
+
+    def __init__(self, x0: float):
+        """Init method."""
+        super().__init__()
+        self.x0 = x0
+
+    def log_artifact(self, pl_module, tbx_logger):
+        """Create artifact and log to tensorboard."""
+        fig = plt.figure(figsize=TENSORBOARD_FIGSIZE)
+        fig.suptitle("Scores within 2 $\\sigma$ of Data")
+        ax1 = fig.add_subplot(121)
+        ax2 = fig.add_subplot(122)
+        ax1.set_ylabel('$\\sigma \\times S_{\\theta}(x, t)$')
+        ax2.set_ylabel('$\\sigma \\times S_{\\theta}(x, t) - \\sigma \\nabla \\log P(x | 0)$')
+        for ax in [ax1, ax2]:
+            ax.set_xlabel('$x$')
+
+        list_x = torch.linspace(0, 1, 1001)[:-1]
+
+        times = torch.tensor([0.25, 0.75, 1.])
+        sigmas = pl_module.variance_sampler._create_sigma_array(pl_module.variance_sampler.noise_parameters, times)
+
+        with torch.no_grad():
+            for time, sigma in zip(times, sigmas):
+                time_array = time * torch.ones(1000).reshape(-1, 1)
+
+                sigma_normalized_kernel = get_sigma_normalized_score(map_positions_to_unit_cell(list_x - self.x0),
+                                                                     sigma * torch.ones_like(list_x),
+                                                                     kmax=4)
+                predicted_normalized_scores = pl_module._get_predicted_normalized_score(list_x.reshape(-1, 1, 1),
+                                                                                        time_array).flatten()
+
+                error = predicted_normalized_scores - sigma_normalized_kernel
+
+                # Only plot the errors in the sampling region. These regions might be
+                # disconnected; plot each piece separately so the continuous lines make sense.
+                mask1 = torch.abs(list_x - self.x0) < 2 * sigma
+                mask2 = torch.abs(1. - list_x + self.x0) < 2 * sigma
+
+                lines = ax1.plot(list_x[mask1], predicted_normalized_scores[mask1], lw=1, label='Prediction')
+                color = lines[0].get_color()
+                ax1.plot(list_x[mask2], predicted_normalized_scores[mask2], lw=1, color=color, label='_none_')
+
+                ax1.plot(list_x[mask1], sigma_normalized_kernel[mask1], '--', lw=2, color=color, label='Target')
+                ax1.plot(list_x[mask2], sigma_normalized_kernel[mask2], '--', lw=2, color=color, label='_none_')
+
+                ax2.plot(list_x[mask1], error[mask1], '-', color=color,
+                         label=f't = {time:4.3f}, $\\sigma$ = {sigma:4.3f}')
+                ax2.plot(list_x[mask2], error[mask2], '-', color=color, label='_none_')
+
+        for ax in [ax1, ax2]:
+            ax.set_xlim([-0.05, 1.05])
+            ax.legend(loc=3, prop={'size': 6})
+
+        ax1.set_ylim([-3., 3.])
+        ax2.set_ylim([-1., 1.])
+
+        fig.tight_layout()
+
+        tbx_logger.add_figure("train/scores", fig, global_step=pl_module.global_step)
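
A sketch of how the two new callbacks might be wired into a Lightning run. The constructor signatures come from the diff above; the hyperparameter values are placeholders, number_of_atoms is assumed to be a field of MLPScoreNetworkParameters (the callback reads score_network_parameters.number_of_atoms), and Trainer(callbacks=..., logger=...) is standard pytorch_lightning usage.

    import pytorch_lightning as pl
    from pytorch_lightning.loggers import TensorBoardLogger

    from crystal_diffusion.models.callbacks import (
        TensorboardGeneratedSamplesLoggingCallback, TensorboardScoreAndErrorLoggingCallback)
    from crystal_diffusion.models.score_network import MLPScoreNetworkParameters
    from crystal_diffusion.samplers.variance_sampler import NoiseParameters

    # Placeholder hyperparameters; see the hedge in the paragraph above.
    score_network_parameters = MLPScoreNetworkParameters(number_of_atoms=2,
                                                         hidden_dimensions=[64, 64])

    callbacks = [
        TensorboardGeneratedSamplesLoggingCallback(
            noise_parameters=NoiseParameters(total_time_steps=100),
            number_of_corrector_steps=1,
            score_network_parameters=score_network_parameters,
            number_of_samples=1000),
        TensorboardScoreAndErrorLoggingCallback(x0=0.5),  # x0: placeholder data position
    ]

    trainer = pl.Trainer(logger=TensorBoardLogger(save_dir="lightning_logs"),
                         callbacks=callbacks)
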
crystal_diffusion/models/score_network.py (3 additions, 3 deletions)

@@ -17,7 +17,7 @@ class BaseScoreNetworkParameters:
     spatial_dimension: int = 3  # the dimension of Euclidean space where atoms live.


-class BaseScoreNetwork(torch.nn.Module):
+class ScoreNetwork(torch.nn.Module):
     """Base score network.

     This base class defines the interface that all score networks should have
@@ -32,7 +32,7 @@ def __init__(self, hyper_params: BaseScoreNetworkParameters):
         Args:
             hyper_params : hyperparameters from the config file.
         """
-        super(BaseScoreNetwork, self).__init__()
+        super(ScoreNetwork, self).__init__()
         self._hyper_params = hyper_params
         self.spatial_dimension = hyper_params.spatial_dimension

Expand Down Expand Up @@ -124,7 +124,7 @@ class MLPScoreNetworkParameters(BaseScoreNetworkParameters):
hidden_dimensions: List[int] # dimensions of the hidden layers. Length of array determines number of layers.


class MLPScoreNetwork(BaseScoreNetwork):
class MLPScoreNetwork(ScoreNetwork):
"""Simple Model Class.

Inherits from the given framework's model class. This is a simple MLP model.
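
The net effect of the last two hunks for downstream code is just the parent-class name: subclasses now inherit from ScoreNetwork instead of BaseScoreNetwork. A minimal sketch follows; MyToyScoreNetwork is invented for illustration, and since the base class's forward-interface methods are not shown in this diff, only construction is illustrated.

    import torch

    from crystal_diffusion.models.score_network import (BaseScoreNetworkParameters,
                                                        ScoreNetwork)


    class MyToyScoreNetwork(ScoreNetwork):  # hypothetical; was BaseScoreNetwork before this PR
        """Minimal illustration of subclassing the renamed base class."""

        def __init__(self, hyper_params: BaseScoreNetworkParameters):
            super(MyToyScoreNetwork, self).__init__(hyper_params)
            # self.spatial_dimension is set by the base class __init__, as shown above.
            self.linear = torch.nn.Linear(self.spatial_dimension, self.spatial_dimension)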