diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py
index b701f3e..9fae270 100644
--- a/aurora/model/aurora.py
+++ b/aurora/model/aurora.py
@@ -204,7 +204,6 @@ def forward(self, batch: Batch) -> Batch:
         )
 
         # Remove batch and history dimension from static variables.
-        B, T = next(iter(batch.surf_vars.values()))[0]
         pred = dataclasses.replace(
             pred,
             static_vars={k: v[0, 0] for k, v in batch.static_vars.items()},
@@ -231,10 +230,20 @@ def load_checkpoint(self, repo: str, name: str, strict: bool = True) -> None:
             strict (bool, optional): Error if the model parameters are not exactly equal to the
                 parameters in the checkpoint. Defaults to `True`.
         """
+        path = hf_hub_download(repo_id=repo, filename=name)
+        self.load_checkpoint_local(path, strict=strict)
+
+    def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
+        """Load a checkpoint directly from a file.
+
+        Args:
+            path (str): Path to the checkpoint.
+            strict (bool, optional): Error if the model parameters are not exactly equal to the
+                parameters in the checkpoint. Defaults to `True`.
+        """
         # Assume that all parameters are either on the CPU or on the GPU.
         device = next(self.parameters()).device
 
-        path = hf_hub_download(repo_id=repo, filename=name)
         d = torch.load(path, map_location=device, weights_only=True)
 
         # You can safely ignore all cumbersome processing below. We modified the model after we
diff --git a/docs/finetuning.md b/docs/finetuning.md
index 02ac314..8bdda7b 100644
--- a/docs/finetuning.md
+++ b/docs/finetuning.md
@@ -69,7 +69,7 @@ scales["new_atmos_var"] = 1.0
 ## Other Model Extensions
 
 It is possible to extend to model in any way you like.
-If you do this, you will likely you add or remove parameters.
+If you do this, you will likely add or remove parameters.
 Then `Aurora.load_checkpoint` will error, because the existing checkpoint now mismatches with the
 model's parameters.
 Simply set `Aurora.load_checkpoint(..., strict=False)` to ignore the mismatches:
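
For reference, a minimal usage sketch of the `load_checkpoint_local` method introduced by this patch, alongside the existing `load_checkpoint`. The import path, repository name, and checkpoint filenames below are illustrative assumptions and are not taken from the patch itself:

```python
from aurora import Aurora  # assumed import path; adjust to your installation

model = Aurora()

# Existing behaviour: download a named checkpoint from a Hugging Face repository
# (the repository and filename here are placeholders).
model.load_checkpoint("some-user/some-repo", "checkpoint.ckpt")

# New in this patch: load a checkpoint that is already on disk.
model.load_checkpoint_local("/path/to/checkpoint.ckpt")

# As described in docs/finetuning.md, pass strict=False when the model has been
# extended and its parameters no longer exactly match the checkpoint.
model.load_checkpoint_local("/path/to/checkpoint.ckpt", strict=False)
```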