diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py
index 272a529..1fe4220 100644
--- a/aurora/model/aurora.py
+++ b/aurora/model/aurora.py
@@ -5,7 +5,7 @@
 import warnings
 from datetime import timedelta
 from functools import partial
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 from huggingface_hub import hf_hub_download
@@ -112,6 +112,7 @@ def __init__(
         self.patch_size = patch_size
         self.surf_stats = surf_stats or dict()
         self.autocast = autocast
+        self.max_history_size = max_history_size
 
         if self.surf_stats:
             warnings.warn(
@@ -268,8 +269,7 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
                 del d[k]
                 d[k[4:]] = v
 
-        # Convert the ID-based parametrisation to a name-based parametrisation.
-
+        # Convert the ID-based parametrization to a name-based parametrization.
         if "encoder.surf_token_embeds.weight" in d:
             weight = d["encoder.surf_token_embeds.weight"]
             del d["encoder.surf_token_embeds.weight"]
@@ -316,8 +316,55 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
                 d[f"decoder.atmos_heads.{name}.weight"] = weight[:, i]
                 d[f"decoder.atmos_heads.{name}.bias"] = bias[:, i]
 
+        # Check whether the history size is compatible and adjust the weights if necessary.
+        if self.max_history_size > d["encoder.surf_token_embeds.weights.2t"].shape[2]:
+            d = self.adapt_checkpoint_max_history_size(d)
+        elif self.max_history_size < d["encoder.surf_token_embeds.weights.2t"].shape[2]:
+            raise AssertionError(f"Cannot load checkpoint with max_history_size \
+                {d['encoder.surf_token_embeds.weights.2t'].shape[2]} \
+                into model with max_history_size {self.max_history_size}")
+
         self.load_state_dict(d, strict=strict)
 
+    def adapt_checkpoint_max_history_size(self, checkpoint) -> Any:
+        """Adapt a checkpoint trained with a smaller max_history_size to the larger
+        max_history_size of the current model. The checkpoint is modified in place.
+
+        If the checkpoint was instead trained with a larger max_history_size than the
+        current model, this function fails with an assertion error to prevent the
+        checkpoint from being loaded, since loading it would likely degrade the
+        model's performance.
+
+        This implementation copies the existing weights from the checkpoint and fills
+        the new history dimensions with zeros.
+        """
+        # Find the patch embedding weights for both the surface and atmospheric variables.
+        for name, weight in list(checkpoint.items()):
+            if name.startswith("encoder.surf_token_embeds.weights.") or name.startswith(
+                "encoder.atmos_token_embeds.weights."
+            ):
+                # This check should not trigger with the current logic, but it is kept
+                # for future-proofing and for cases where this method is called directly.
+                assert (
+                    weight.shape[2] <= self.max_history_size
+                ), f"Cannot load checkpoint with max_history_size {weight.shape[2]} \
+                    into model with max_history_size {self.max_history_size} for weight {name}"
+
+                # Initialize the new, larger weight tensor with zeros.
+                new_weight = torch.zeros(
+                    (weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]),
+                    device=weight.device,
+                    dtype=weight.dtype,
+                )
+
+                # Copy the existing weights into the new tensor. Any additional history
+                # dimensions remain zero.
+                for j in range(weight.shape[2]):
+                    # Only fill the history slots present in the checkpoint.
+                    new_weight[:, :, j, :, :] = weight[:, :, j, :, :]
+                checkpoint[name] = new_weight
+        return checkpoint
+
     def configure_activation_checkpointing(self):
         """Configure activation checkpointing.
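For context, this is how the change surfaces to users through load_checkpoint() / load_checkpoint_local(): constructing a model with a larger max_history_size than the checkpoint was trained with triggers the adaptation, while a smaller one raises an AssertionError. A minimal usage sketch follows; it is not part of this diff, and the repository name, checkpoint name, and the assumption that the released checkpoints store two history steps are taken from the project README rather than from the code above.

from aurora import Aurora

# Assumed: the released checkpoint was trained with two history steps, so a model
# built with max_history_size=4 gets its patch embedding weights zero-padded along
# the history dimension when the checkpoint is loaded.
model = Aurora(max_history_size=4)
model.load_checkpoint("microsoft/aurora", "aurora-0.25-finetuned.ckpt")

# Requesting a smaller history than the checkpoint provides is rejected:
#     Aurora(max_history_size=1).load_checkpoint(...)  # raises AssertionError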
diff --git a/tests/test_checkpoint_adaptation.py b/tests/test_checkpoint_adaptation.py
new file mode 100644
index 0000000..83793cc
--- /dev/null
+++ b/tests/test_checkpoint_adaptation.py
@@ -0,0 +1,63 @@
+"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
+
+from copy import deepcopy
+
+import pytest
+import torch
+
+from aurora.model.aurora import AuroraSmall
+
+
+@pytest.fixture
+def model(request):
+    return AuroraSmall(max_history_size=request.param)
+
+
+@pytest.fixture
+def checkpoint():
+    return {
+        "encoder.surf_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4)),
+        "encoder.atmos_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4)),
+    }
+
+
+# Check history sizes that both are and are not divisible by the original history size of 2.
+@pytest.mark.parametrize("model", [4, 5], indirect=True)
+def test_adapt_checkpoint_max_history(model, checkpoint):
+    # The checkpoint starts with a history dimension, shape[2], of size 2.
+    assert checkpoint["encoder.surf_token_embeds.weights.0"].shape[2] == 2
+    # Deep-copy the checkpoint because the adaptation modifies the dictionary in place.
+    adapted_checkpoint = model.adapt_checkpoint_max_history_size(deepcopy(checkpoint))
+
+    for name, weight in adapted_checkpoint.items():
+        assert weight.shape[2] == model.max_history_size
+        for j in range(weight.shape[2]):
+            if j >= checkpoint[name].shape[2]:
+                # The new history slots must be filled with zeros.
+                assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
+            else:
+                # The original history slots must be preserved.
+                assert torch.equal(weight[:, :, j, :, :], checkpoint[name][:, :, j, :, :])
+
+
+# Check that an assertion error is raised when a checkpoint with a larger history size
+# is loaded into a model with a smaller max_history_size.
+@pytest.mark.parametrize("model", [1], indirect=True)
+def test_adapt_checkpoint_max_history_fail(model, checkpoint):
+    with pytest.raises(AssertionError):
+        model.adapt_checkpoint_max_history_size(checkpoint)
+
+
+# Adapt the checkpoint twice to ensure that the second adaptation does not change the weights.
+@pytest.mark.parametrize("model", [4], indirect=True)
+def test_adapt_checkpoint_max_history_twice(model, checkpoint):
+    adapted_checkpoint = model.adapt_checkpoint_max_history_size(deepcopy(checkpoint))
+    adapted_checkpoint = model.adapt_checkpoint_max_history_size(adapted_checkpoint)
+
+    for name, weight in adapted_checkpoint.items():
+        assert weight.shape[2] == model.max_history_size
+        for j in range(weight.shape[2]):
+            if j >= checkpoint[name].shape[2]:
+                assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
+            else:
+                assert torch.equal(weight[:, :, j, :, :], checkpoint[name][:, :, j, :, :])
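The behaviour these tests assert can also be checked interactively by calling the new method directly on a toy state dict. This is a minimal sketch, not part of the diff; the key names and tensor shapes simply mirror the test fixture above, and AuroraSmall is only used to obtain a model with the requested max_history_size.

import torch

from aurora.model.aurora import AuroraSmall

model = AuroraSmall(max_history_size=4)

# Toy weights with a history dimension (dim 2) of size 2, as in the test fixture.
checkpoint = {
    "encoder.surf_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4)),
    "encoder.atmos_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4)),
}

adapted = model.adapt_checkpoint_max_history_size(checkpoint)
weight = adapted["encoder.surf_token_embeds.weights.0"]
print(weight.shape)                         # torch.Size([2, 1, 4, 4, 4])
print(weight[:, :, 2:].abs().sum().item())  # 0.0: the new history slots are zero-filled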