Skip to content

Commit

Permalink
New callback to show scores along a path.
Browse files Browse the repository at this point in the history
  • Loading branch information
rousseab committed Dec 26, 2024
1 parent d0d0b23 commit f335322
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
instantiate_loss_monitoring_callback
from diffusion_for_multi_scale_molecular_dynamics.callbacks.sampling_visualization_callback import \
instantiate_sampling_visualization_callback
from diffusion_for_multi_scale_molecular_dynamics.callbacks.score_viewer_callback import \
instantiate_score_viewer_callback
from diffusion_for_multi_scale_molecular_dynamics.callbacks.standard_callbacks import (
CustomProgressBar, instantiate_early_stopping_callback,
instantiate_model_checkpoint_callbacks)
Expand All @@ -16,6 +18,7 @@
model_checkpoint=instantiate_model_checkpoint_callbacks,
sampling_visualization=instantiate_sampling_visualization_callback,
loss_monitoring=instantiate_loss_monitoring_callback,
score_viewer=instantiate_score_viewer_callback
)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from dataclasses import dataclass
from typing import Any, AnyStr, Dict

from matplotlib import pyplot as plt
from pytorch_lightning import Callback, LightningModule, Trainer

from diffusion_for_multi_scale_molecular_dynamics.analysis.score_viewer import (
ScoreViewer, ScoreViewerParameters)
from diffusion_for_multi_scale_molecular_dynamics.loggers.logger_loader import \
log_figure
from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import \
AnalyticalScoreNetworkParameters


@dataclass(kw_only=True)
class ScoreViewerCallbackParameters:
"""Parameters to decide what to plot and write to disk."""

record_every_n_epochs: int = 1

score_viewer_parameters: ScoreViewerParameters
analytical_score_network_parameters: AnalyticalScoreNetworkParameters


def instantiate_score_viewer_callback(
callback_params: Dict[AnyStr, Any], output_directory: str, verbose: bool
) -> Dict[str, Callback]:
"""Instantiate the Diffusion Sampling callback."""
analytical_score_network_parameters = (
AnalyticalScoreNetworkParameters(**callback_params['analytical_score_network']))

score_viewer_parameters = ScoreViewerParameters(**callback_params['score_viewer_parameters'])

score_viewer_callback_parameters = ScoreViewerCallbackParameters(
record_every_n_epochs=callback_params['record_every_n_epochs'],
score_viewer_parameters=score_viewer_parameters,
analytical_score_network_parameters=analytical_score_network_parameters)

callback = ScoreViewerCallback(
score_viewer_callback_parameters, output_directory
)

return dict(score_viewer=callback)


class ScoreViewerCallback(Callback):
"""Score Viewer Callback."""

def __init__(self, score_viewer_callback_parameters: ScoreViewerCallbackParameters, output_directory: str):
"""Init method."""
self.record_every_n_epochs = score_viewer_callback_parameters.record_every_n_epochs
self.score_viewer = ScoreViewer(
score_viewer_parameters=score_viewer_callback_parameters.score_viewer_parameters,
analytical_score_network_parameters=score_viewer_callback_parameters.analytical_score_network_parameters)

def _compute_results_at_this_epoch(self, current_epoch: int) -> bool:
"""Check if results should be computed at this epoch."""
return current_epoch % self.record_every_n_epochs == 0

def on_validation_end(self, trainer: Trainer, pl_model: LightningModule) -> None:
"""On validation epoch end."""
if not self._compute_results_at_this_epoch(trainer.current_epoch):
return

figure = self.score_viewer.create_figure(score_network=pl_model.axl_network)
figure.suptitle(f"Epoch {trainer.current_epoch}, Step {trainer.global_step}")
# Set the DPI so we can actually see something in the logger window.
figure.set_dpi(100)
figure.tight_layout()

for pl_logger in trainer.loggers:
log_figure(
figure=figure,
global_step=trainer.current_epoch,
dataset="validation",
pl_logger=pl_logger,
name="projected_normalized_scores",
)
plt.close(figure)

0 comments on commit f335322

Please sign in to comment.