
Commit

Merge branch 'main' into wesselb/checkpointing
wesselb authored Sep 11, 2024
2 parents 6233272 + b65b87d commit 0657507
Showing 2 changed files with 12 additions and 3 deletions.
13 changes: 11 additions & 2 deletions aurora/model/aurora.py
@@ -204,7 +204,6 @@ def forward(self, batch: Batch) -> Batch:
        )

        # Remove batch and history dimension from static variables.
        B, T = next(iter(batch.surf_vars.values()))[0]
        pred = dataclasses.replace(
            pred,
            static_vars={k: v[0, 0] for k, v in batch.static_vars.items()},
@@ -231,10 +230,20 @@ def load_checkpoint(self, repo: str, name: str, strict: bool = True) -> None:
            strict (bool, optional): Error if the model parameters are not exactly equal to the
                parameters in the checkpoint. Defaults to `True`.
        """
        path = hf_hub_download(repo_id=repo, filename=name)
        self.load_checkpoint_local(path, strict=strict)

    def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
        """Load a checkpoint directly from a file.

        Args:
            path (str): Path to the checkpoint.
            strict (bool, optional): Error if the model parameters are not exactly equal to the
                parameters in the checkpoint. Defaults to `True`.
        """
        # Assume that all parameters are either on the CPU or on the GPU.
        device = next(self.parameters()).device

        path = hf_hub_download(repo_id=repo, filename=name)
        d = torch.load(path, map_location=device, weights_only=True)

        # You can safely ignore all cumbersome processing below. We modified the model after we
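For context, a minimal usage sketch of the two loading paths after this change; the repository ID, checkpoint filename, and local path below are illustrative assumptions, not part of the diff:

```python
from aurora import Aurora

model = Aurora()

# Download the checkpoint from the Hugging Face Hub and load it (names are illustrative).
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")

# Or load a checkpoint that is already on disk via the new method.
model.load_checkpoint_local("/path/to/aurora-0.25-pretrained.ckpt")
```

The split keeps the Hugging Face download in `load_checkpoint`, while `load_checkpoint_local` handles any checkpoint file already on disk.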
2 changes: 1 addition & 1 deletion docs/finetuning.md
@@ -69,7 +69,7 @@ scales["new_atmos_var"] = 1.0
## Other Model Extensions

It is possible to extend the model in any way you like.
If you do this, you will likely you add or remove parameters.
If you do this, you will likely add or remove parameters.
Then `Aurora.load_checkpoint` will error,
because the existing checkpoint now mismatches with the model's parameters.
Simply set `Aurora.load_checkpoint(..., strict=False)` to ignore the mismatches:
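A minimal sketch of what this means in practice, assuming a hypothetical subclass that adds one new parameter; the class name, added parameter, repository ID, and checkpoint filename are illustrative:

```python
import torch
from aurora import Aurora


class ExtendedAurora(Aurora):
    """Aurora with one extra (hypothetical) parameter not present in the released checkpoint."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # New parameter that the existing checkpoint knows nothing about.
        self.extra_scale = torch.nn.Parameter(torch.ones(1))


model = ExtendedAurora()
# strict=False tolerates parameters that are missing from the checkpoint or new to the model.
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False)
```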
