From 404a5e6f452b76205c57b14f0bb3b2123e2e2116 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Thu, 1 Aug 2024 09:29:25 +0100 Subject: [PATCH] feat: store formatted metadata in lightning checkpoint --- .../training/diagnostics/callbacks/__init__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index f3a82400..2a22ff1e 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -21,6 +21,7 @@ import pytorch_lightning as pl import torch import torchinfo +from anemoi.utils.checkpoints import save_metadata from omegaconf import DictConfig from pytorch_lightning.callbacks import Callback from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint @@ -742,12 +743,7 @@ def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: s torch.save(model, inference_checkpoint_filepath) - with ZipFile(inference_checkpoint_filepath, "a") as zipf: - base = Path(inference_checkpoint_filepath).stem - zipf.writestr( - f"{base}/ai-models.json", - json.dumps(metadata), - ) + save_metadata(inference_checkpoint_filepath, metadata) model.config = save_config model.metadata = save_metadata @@ -758,6 +754,10 @@ def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: s # saving checkpoint used for pytorch-lightning based training trainer.save_checkpoint(lightning_checkpoint_filepath, self.save_weights_only) + + # saving metadata for the checkpoint in same format as for inference + save_metadata(lightning_checkpoint_filepath, metadata) + self._last_global_step_saved = trainer.global_step self._last_checkpoint_saved = lightning_checkpoint_filepath