diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py
index 1fe4220..17d7e39 100644
--- a/aurora/model/aurora.py
+++ b/aurora/model/aurora.py
@@ -5,7 +5,7 @@
 import warnings
 from datetime import timedelta
 from functools import partial
-from typing import Any, Optional
+from typing import Optional
 
 import torch
 from huggingface_hub import hf_hub_download
@@ -93,7 +93,9 @@ def __init__(
                 separate parameter.
             perceiver_ln_eps (float, optional): Epsilon in the perceiver layer norm. layers. Used to
                 stabilise the model.
-            max_history_size (int, optional): Maximum number of history steps.
+            max_history_size (int, optional): Maximum number of history steps. You can load
+                checkpoints with a smaller `max_history_size`, but you cannot load checkpoints
+                with a larger `max_history_size`.
             use_lora (bool, optional): Use LoRA adaptation.
             lora_steps (int, optional): Use different LoRA adaptation for the first so-many roll-out
                 steps.
@@ -316,54 +318,54 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
             d[f"decoder.atmos_heads.{name}.weight"] = weight[:, i]
             d[f"decoder.atmos_heads.{name}.bias"] = bias[:, i]
 
-        # check if history size is compatible and adjust weights if necessary
-        if self.max_history_size > d["encoder.surf_token_embeds.weights.2t"].shape[2]:
-            d = self.adapt_checkpoint_max_history_size(d)
-        elif self.max_history_size < d["encoder.surf_token_embeds.weights.2t"].shape[2]:
-            raise AssertionError(f"Cannot load checkpoint with max_history_size \
-                {d['encoder.surf_token_embeds.weights.2t'].shape[2]} \
-                into model with max_history_size {self.max_history_size}")
+        # Check if the history size is compatible and adjust weights if necessary.
+        current_history_size = d["encoder.surf_token_embeds.weights.2t"].shape[2]
+        if self.max_history_size > current_history_size:
+            self.adapt_checkpoint_max_history_size(d)
+        elif self.max_history_size < current_history_size:
+            raise AssertionError(
+                f"Cannot load checkpoint with `max_history_size` {current_history_size} "
+                f"into model with `max_history_size` {self.max_history_size}."
+            )
 
         self.load_state_dict(d, strict=strict)
 
-    def adapt_checkpoint_max_history_size(self, checkpoint) -> Any:
-        """Adapt a checkpoint with smaller max_history_size to a model with a larger
-        max_history_size than the current model.
+    def adapt_checkpoint_max_history_size(self, checkpoint: dict[str, torch.Tensor]) -> None:
+        """Adapt a checkpoint with smaller `max_history_size` to a model with a larger
+        `max_history_size` than the current model.
 
-        If a checkpoint was trained with a larger max_history_size than the current model,
+        If a checkpoint was trained with a larger `max_history_size` than the current model,
         this function will assert fail to prevent loading the checkpoint. This is to
         prevent loading a checkpoint which will likely cause the checkpoint to degrade is
         performance.
 
-        This implementation copies weights from the checkpoint to the model and fills 0
-        for the new history width dimension.
+        This implementation copies weights from the checkpoint to the model and fills zeros
+        for the new history width dimension. It mutates `checkpoint`.
         """
-        # Find all weights with prefix "encoder.surf_token_embeds.weights."
         for name, weight in list(checkpoint.items()):
-            if name.startswith("encoder.surf_token_embeds.weights.") or name.startswith(
-                "encoder.atmos_token_embeds.weights."
-            ):
+            # We only need to adapt the patch embedding in the encoder.
+            enc_surf_embedding = name.startswith("encoder.surf_token_embeds.weights.")
+            enc_atmos_embedding = name.startswith("encoder.atmos_token_embeds.weights.")
+            if enc_surf_embedding or enc_atmos_embedding:
                 # This shouldn't get called with current logic but leaving here for future proofing
-                # and in cases where its called outside current context
-                assert (
-                    weight.shape[2] <= self.max_history_size
-                ), f"Cannot load checkpoint with max_history_size {weight.shape[2]} \
-                    into model with max_history_size {self.max_history_size} for weight {name}"
-
-                # Initialize the new weight tensor
+                # and in cases where its called outside current context.
+                if not (weight.shape[2] <= self.max_history_size):
+                    raise AssertionError(
+                        f"Cannot load checkpoint with `max_history_size` {weight.shape[2]} "
+                        f"into model with `max_history_size` {self.max_history_size}."
+                    )
+
+                # Initialize the new weight tensor.
                 new_weight = torch.zeros(
                     (weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]),
                     device=weight.device,
                     dtype=weight.dtype,
                 )
-
                 # Copy the existing weights to the new tensor by duplicating the histories provided
-                # into any new history dimensions
-                for j in range(weight.shape[2]):
-                    # only fill existing weights, others are zeros
-                    new_weight[:, :, j, :, :] = weight[:, :, j, :, :]
+                # into any new history dimensions. The rest remains at zero.
+                new_weight[:, :, : weight.shape[2]] = weight
+
                 checkpoint[name] = new_weight
-        return checkpoint
 
     def configure_activation_checkpointing(self):
         """Configure activation checkpointing.
diff --git a/tests/test_checkpoint_adaptation.py b/tests/test_checkpoint_adaptation.py
index 83793cc..371656f 100644
--- a/tests/test_checkpoint_adaptation.py
+++ b/tests/test_checkpoint_adaptation.py
@@ -1,5 +1,6 @@
 """Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
 
+import numpy as np
 import pytest
 import torch
 
@@ -19,45 +20,41 @@ def checkpoint():
     }
 
 
-# check both history sizes which are divisible by 2 (original shape) and not
+# Check both history sizes which are divisible by 2 (original shape) and not.
@pytest.mark.parametrize("model", [4, 5], indirect=True)
 def test_adapt_checkpoint_max_history(model, checkpoint):
-    # checkpoint starts with history dim, shape[2], as size 2
+    # Checkpoint starts with history dim., `shape[2]`, equal to 2.
     assert checkpoint["encoder.surf_token_embeds.weights.0"].shape[2] == 2
-    adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
+    model.adapt_checkpoint_max_history_size(checkpoint)
 
-    for name, weight in adapted_checkpoint.items():
+    for name, weight in checkpoint.items():
         assert weight.shape[2] == model.max_history_size
         for j in range(weight.shape[2]):
             if j >= checkpoint[name].shape[2]:
-                assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
+                np.testing.assert_allclose(weight[:, :, j, :, :], 0 * weight[:, :, j, :, :])
             else:
-                assert torch.equal(
-                    weight[:, :, j, :, :],
-                    checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :],
-                )
+                np.testing.assert_allclose(weight[:, :, j, :, :], checkpoint[name][:, :, j, :, :])
 
 
-# check that assert is thrown when trying to load a larger checkpoint to a smaller history size
 @pytest.mark.parametrize("model", [1], indirect=True)
 def test_adapt_checkpoint_max_history_fail(model, checkpoint):
+    """Check that an assertion error is thrown when trying to load a larger checkpoint to a
+    smaller history size."""
     with pytest.raises(AssertionError):
         model.adapt_checkpoint_max_history_size(checkpoint)
 
 
-# test adapting the checkpoint twice to ensure that the second time should not change the weights
 @pytest.mark.parametrize("model", [4], indirect=True)
 def test_adapt_checkpoint_max_history_twice(model, checkpoint):
-    adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
-    adapted_checkpoint = model.adapt_checkpoint_max_history_size(adapted_checkpoint)
+    """Test adapting the checkpoint twice to ensure that the second time should not change the
+    weights."""
+    model.adapt_checkpoint_max_history_size(checkpoint)
+    model.adapt_checkpoint_max_history_size(checkpoint)
 
-    for name, weight in adapted_checkpoint.items():
+    for name, weight in checkpoint.items():
         assert weight.shape[2] == model.max_history_size
         for j in range(weight.shape[2]):
             if j >= checkpoint[name].shape[2]:
-                assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
+                np.testing.assert_allclose(weight[:, :, j, :, :], 0 * weight[:, :, j, :, :])
             else:
-                assert torch.equal(
-                    weight[:, :, j, :, :],
-                    checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :],
-                )
+                np.testing.assert_allclose(weight[:, :, j, :, :], checkpoint[name][:, :, j, :, :])
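For context, here is a minimal sketch of how the adapted loading path in this diff would be exercised from user code. The top-level `Aurora` import and the checkpoint path are assumptions for illustration and do not appear in the diff; `max_history_size`, `load_checkpoint_local`, and `adapt_checkpoint_max_history_size` are taken from the code above.

```python
from aurora import Aurora  # Assumed import; the class is defined in aurora/model/aurora.py.

# Assumption: the checkpoint on disk was trained with a history size of 2. Constructing the
# model with a larger `max_history_size` is supported: `load_checkpoint_local` detects the
# smaller history dimension in the checkpoint and calls `adapt_checkpoint_max_history_size`,
# which zero-fills the extra history slots of the encoder patch embeddings in place.
model = Aurora(max_history_size=4)
model.load_checkpoint_local("/path/to/checkpoint.ckpt", strict=True)  # Placeholder path.

# The reverse direction is rejected: loading a checkpoint whose history dimension is larger
# than the model's `max_history_size` raises an AssertionError in `load_checkpoint_local`.
```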