diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py index b08c387..b56d7d5 100644 --- a/aurora/model/aurora.py +++ b/aurora/model/aurora.py @@ -225,10 +225,20 @@ def load_checkpoint(self, repo: str, name: str, strict: bool = True) -> None: strict (bool, optional): Error if the model parameters are not exactly equal to the parameters in the checkpoint. Defaults to `True`. """ + path = hf_hub_download(repo_id=repo, filename=name) + self.load_checkpoint_local(path) + + def load_checkpoint_local(self, path: str, strict: bool = True) -> None: + """Load a checkpoint directly from a file. + + Args: + path (str): Path to the checkpoint. + strict (bool, optional): Error if the model parameters are not exactly equal to the + parameters in the checkpoint. Defaults to `True`. + """ # Assume that all parameters are either on the CPU or on the GPU. device = next(self.parameters()).device - path = hf_hub_download(repo_id=repo, filename=name) d = torch.load(path, map_location=device, weights_only=True) # Rename keys to ensure compatibility.