Skip to content

Commit

Permalink
feat: store formatted metadata in lightning checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
theissenhelen committed Aug 1, 2024
1 parent 999a224 commit 404a5e6
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/anemoi/training/diagnostics/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pytorch_lightning as pl
import torch
import torchinfo
from anemoi.utils.checkpoints import save_metadata
from omegaconf import DictConfig
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
Expand Down Expand Up @@ -742,12 +743,7 @@ def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: s

torch.save(model, inference_checkpoint_filepath)

with ZipFile(inference_checkpoint_filepath, "a") as zipf:
base = Path(inference_checkpoint_filepath).stem
zipf.writestr(
f"{base}/ai-models.json",
json.dumps(metadata),
)
save_metadata(inference_checkpoint_filepath, metadata)

model.config = save_config
model.metadata = save_metadata
Expand All @@ -758,6 +754,10 @@ def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: s

# saving checkpoint used for pytorch-lightning based training
trainer.save_checkpoint(lightning_checkpoint_filepath, self.save_weights_only)

# saving metadata for the checkpoint in same format as for inference
save_metadata(lightning_checkpoint_filepath, metadata)

self._last_global_step_saved = trainer.global_step
self._last_checkpoint_saved = lightning_checkpoint_filepath

Expand Down

0 comments on commit 404a5e6

Please sign in to comment.