From 129bc8d0862d9322b28372a8627b27037a8c7ab3 Mon Sep 17 00:00:00 2001
From: scottcha
Date: Mon, 16 Sep 2024 12:17:14 -0600
Subject: [PATCH 1/9] Add new logic to enable stored checkpoint weights to be
 copied to new history dimensions

---
 aurora/model/aurora.py | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py
index 418f981..3695366 100644
--- a/aurora/model/aurora.py
+++ b/aurora/model/aurora.py
@@ -112,6 +112,7 @@ def __init__(
         self.patch_size = patch_size
         self.surf_stats = surf_stats or dict()
         self.autocast = autocast
+        self.max_history_size = max_history_size
 
         if self.surf_stats:
             warnings.warn(
@@ -269,22 +270,31 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
             d[k[4:]] = v
 
         # Convert the ID-based parametrisation to a name-based parametrisation.
-
         if "encoder.surf_token_embeds.weight" in d:
             weight = d["encoder.surf_token_embeds.weight"]
             del d["encoder.surf_token_embeds.weight"]
-
+            assert weight.shape[1] == 4 + 3
             for i, name in enumerate(("2t", "10u", "10v", "msl", "lsm", "z", "slt")):
-                d[f"encoder.surf_token_embeds.weights.{name}"] = weight[:, [i]]
-
+                # Initialize the new weight tensor with zeros
+                new_weight = torch.zeros((weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]))
+                # Copy the existing weights to the new tensor my duplicating the histories provided in to any new history dimensions
+                for j in range(new_weight.shape[2]):
+                    new_weight[:, :, j, :, :] = weight[:, [i], j % weight.shape[2], :, :]
+                d[f"encoder.surf_token_embeds.weights.{name}"] = new_weight
+
         if "encoder.atmos_token_embeds.weight" in d:
             weight = d["encoder.atmos_token_embeds.weight"]
            del d["encoder.atmos_token_embeds.weight"]
-
+            assert weight.shape[1] == 5
             for i, name in enumerate(("z", "u", "v", "t", "q")):
-                d[f"encoder.atmos_token_embeds.weights.{name}"] = weight[:, [i]]
+                # Initialize the new weight tensor with zeros
+                new_weight = torch.zeros((weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]))
+                # Copy the existing weights to the new tensor my duplicating the histories provided in to any new history dimensions
+                for j in range(new_weight.shape[2]):
+                    new_weight[:, :, j, :, :] = weight[:, [i], j % weight.shape[2], :, :]
+                d[f"encoder.atmos_token_embeds.weights.{name}"] = new_weight
 
         if "decoder.surf_head.weight" in d:
             weight = d["decoder.surf_head.weight"]
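Patch 1's modulo indexing cycles the checkpoint's existing history steps into every slot of the enlarged weight tensor. A minimal standalone sketch of that idea, with illustrative shapes (not part of the patch):

```python
import torch

max_history_size = 4                               # target history size of the new model
weight = torch.arange(2.0).reshape(1, 1, 2, 1, 1)  # checkpoint weight with history size 2

new_weight = torch.zeros(1, 1, max_history_size, 1, 1)
for j in range(max_history_size):
    # Cycle through the checkpoint's existing history steps: 0, 1, 0, 1, ...
    new_weight[:, :, j] = weight[:, :, j % weight.shape[2]]

print(new_weight.flatten().tolist())  # [0.0, 1.0, 0.0, 1.0]
```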
From 7127f0c7fa955592d85a1f4ff5bd9f6eb9e565e8 Mon Sep 17 00:00:00 2001
From: scottcha
Date: Tue, 17 Sep 2024 09:09:45 -0600
Subject: [PATCH 2/9] Refactor checkpoint adaptation logic to allow for more
 flexibility and different fn to adapt history

---
 aurora/model/aurora.py | 155 +++++++++++++++++++++++++++++------------
 1 file changed, 110 insertions(+), 45 deletions(-)

diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py
index 3695366..9a6a282 100644
--- a/aurora/model/aurora.py
+++ b/aurora/model/aurora.py
@@ -5,7 +5,7 @@
 import warnings
 from datetime import timedelta
 from functools import partial
-from typing import Optional
+from typing import Any, Callable, Optional
 
 import torch
 from huggingface_hub import hf_hub_download
@@ -233,7 +233,7 @@ def forward(self, batch: Batch) -> Batch:
 
         return pred
 
-    def load_checkpoint(self, repo: str, name: str, strict: bool = True) -> None:
+    def load_checkpoint(self, repo: str, name: str, strict: bool = True, adapt_fn: Callable[[Any], Any] = 'default', adapt_history: Callable[[Any], Any] = 'default') -> None:
         """Load a checkpoint from HuggingFace.
 
         Args:
@@ -242,65 +242,116 @@ def load_checkpoint(self, repo: str, name: str, strict: bool = True) -> None:
                 `checkpoint.cpkt`.
             strict (bool, optional): Error if the model parameters are not exactly equal to
                 the parameters in the checkpoint. Defaults to `True`.
+            adapt_fn (Callable[[Any], Any], optional): Function to adapt the checkpoint to the current model.
+                Defaults to `self.adapt_checkpoint`. Pass `None` to skip adaptation.
+            adapt_history (Callable[[Any], Any], optional): Function to fill the history of the checkpoint when
+                Model is larger than checkpoint. Defaults to `self.adapt_checkpoint_max_history_size`. Pass `None` to skip adaptation.
         """
+        if adapt_fn == 'default':
+            adapt_fn = self.adapt_checkpoint
+
+        if adapt_history == 'default':
+            adapt_history = self.adapt_checkpoint_max_history_size
+
         path = hf_hub_download(repo_id=repo, filename=name)
-        self.load_checkpoint_local(path, strict=strict)
+        self.load_checkpoint_local(path, strict=strict, adapt_fn=adapt_fn, adapt_history=adapt_history)
 
-    def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
+    def load_checkpoint_local(self, path: str, strict: bool = True, adapt_fn: Callable[[Any], Any] = 'default', adapt_history: Callable[[Any], Any] = 'default') -> None:
         """Load a checkpoint directly from a file.
 
         Args:
             path (str): Path to the checkpoint.
             strict (bool, optional): Error if the model parameters are not exactly equal to
                 the parameters in the checkpoint. Defaults to `True`.
+            adapt_fn (Callable[[Any], Any], optional): Function to adapt the checkpoint to the current model.
+                Defaults to `self.adapt_checkpoint`. Pass `None` to skip adaptation.
+            adapt_history (Callable[[Any], Any], optional): Function to fill the history of the checkpoint when
+                Model is larger than checkpoint. Defaults to `self.adapt_checkpoint_max_history_size`. Pass `None` to skip adaptation.
         """
+        if adapt_fn == 'default':
+            adapt_fn = self.adapt_checkpoint
+
+        if adapt_history == 'default':
+            adapt_history = self.adapt_checkpoint_max_history_size
+
         # Assume that all parameters are either on the CPU or on the GPU.
         device = next(self.parameters()).device
         d = torch.load(path, map_location=device, weights_only=True)
 
-        # You can safely ignore all cumbersome processing below. We modified the model after we
-        # trained it. The code below manually adapts the checkpoints, so the checkpoints are
-        # compatible with the new model.
+        # Adapt the checkpoint using the provided function, if not None
+        if adapt_fn is not None:
+            d = adapt_fn(d)
+
+        # Adapt the checkpoint history size using the provided function, if not None
+        if adapt_history is not None:
+            d = adapt_history(d)
 
-        # Remove possibly prefix from the keys.
-        for k, v in list(d.items()):
-            if k.startswith("net."):
-                del d[k]
-                d[k[4:]] = v
+        self.load_state_dict(d, strict=strict)
 
-        # Convert the ID-based parametrisation to a name-based parametrisation.
-        if "encoder.surf_token_embeds.weight" in d:
-            weight = d["encoder.surf_token_embeds.weight"]
-            del d["encoder.surf_token_embeds.weight"]
+    def adapt_checkpoint_max_history_size(self, checkpoint) -> Any:
+        """Adapt a checkpoint with smaller max_history_size to a model with a larger max_history_size
+        than the current model.
 
-            assert weight.shape[1] == 4 + 3
-            for i, name in enumerate(("2t", "10u", "10v", "msl", "lsm", "z", "slt")):
+        If a checkpoint was trained with a larger max_history_size than the current model, this function
+        will assert fail to prevent loading the checkpoint. This is to prevent loading a checkpoint
+        which will likely cause the checkpoint to degrade is performance.
+
+        This implementation copies weights from the checkpoint to the model, duplicating the weights.
+        """
+
+        #find all weights with prefix "encoder.surf_token_embeds.weights."
+        for name, weight in list(checkpoint.items()):
+            if name.startswith("encoder.surf_token_embeds.weights.") or name.startswith("encoder.atmos_token_embeds.weights."):
+                assert weight.shape[2] >= self.max_history_size, f"Cannot load checkpoint with max_history_size {weight.shape[2]} \
+                    into model with max_history_size {self.max_history_size} for weight {name}"
+
                 # Initialize the new weight tensor with zeros
                 new_weight = torch.zeros((weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]))
                 # Copy the existing weights to the new tensor my duplicating the histories provided in to any new history dimensions
                 for j in range(new_weight.shape[2]):
-                    new_weight[:, :, j, :, :] = weight[:, [i], j % weight.shape[2], :, :]
-                d[f"encoder.surf_token_embeds.weights.{name}"] = new_weight
+                    new_weight[:, :, j, :, :] = weight[:, :, j % weight.shape[2], :, :]
+                checkpoint[name] = new_weight
+
+        return checkpoint
+
+    @staticmethod
+    def _adapt_checkpoint_prefix(checkpoint) -> Any:
+        """Adapt a checkpoint with a different prefix to the current model.
+
+        If a checkpoint was trained with a different prefix than the current model, this function
+        will remove the prefix from the checkpoint so that it can be loaded.
+        """
+        # Remove possibly prefix from the keys.
+        for k, v in list(checkpoint.items()):
+            if k.startswith("net."):
+                del checkpoint[k]
+                checkpoint[k[4:]] = v
+        return checkpoint
+
+    def _adapt_checkpoint_parametrization(self, checkpoint) -> Any:
+        # Convert the ID-based parametrization to a name-based parametrization.
+ if "encoder.surf_token_embeds.weight" in checkpoint: + weight = checkpoint["encoder.surf_token_embeds.weight"] + del checkpoint["encoder.surf_token_embeds.weight"] + + assert weight.shape[1] == 4 + 3 + for i, name in enumerate(("2t", "10u", "10v", "msl", "lsm", "z", "slt")): + checkpoint[f"encoder.surf_token_embeds.weights.{name}"] = weight[:, [i]] - if "encoder.atmos_token_embeds.weight" in d: - weight = d["encoder.atmos_token_embeds.weight"] - del d["encoder.atmos_token_embeds.weight"] + if "encoder.atmos_token_embeds.weight" in checkpoint: + weight = checkpoint["encoder.atmos_token_embeds.weight"] + del checkpoint["encoder.atmos_token_embeds.weight"] assert weight.shape[1] == 5 for i, name in enumerate(("z", "u", "v", "t", "q")): - # Initialize the new weight tensor with zeros - new_weight = torch.zeros((weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4])) - # Copy the existing weights to the new tensor my duplicating the histories provided in to any new history dimensions - for j in range(new_weight.shape[2]): - new_weight[:, :, j, :, :] = weight[:, [i], j % weight.shape[2], :, :] - d[f"encoder.atmos_token_embeds.weights.{name}"] = new_weight + checkpoint[f"encoder.atmos_token_embeds.weights.{name}"] = weight[:, [i]] - if "decoder.surf_head.weight" in d: - weight = d["decoder.surf_head.weight"] - bias = d["decoder.surf_head.bias"] - del d["decoder.surf_head.weight"] - del d["decoder.surf_head.bias"] + if "decoder.surf_head.weight" in checkpoint: + weight = checkpoint["decoder.surf_head.weight"] + bias = checkpoint["decoder.surf_head.bias"] + del checkpoint["decoder.surf_head.weight"] + del checkpoint["decoder.surf_head.bias"] assert weight.shape[0] == 4 * self.patch_size**2 assert bias.shape[0] == 4 * self.patch_size**2 @@ -308,14 +359,14 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None: bias = bias.reshape(self.patch_size**2, 4) for i, name in enumerate(("2t", "10u", "10v", "msl")): - d[f"decoder.surf_heads.{name}.weight"] = weight[:, i] - d[f"decoder.surf_heads.{name}.bias"] = bias[:, i] + checkpoint[f"decoder.surf_heads.{name}.weight"] = weight[:, i] + checkpoint[f"decoder.surf_heads.{name}.bias"] = bias[:, i] - if "decoder.atmos_head.weight" in d: - weight = d["decoder.atmos_head.weight"] - bias = d["decoder.atmos_head.bias"] - del d["decoder.atmos_head.weight"] - del d["decoder.atmos_head.bias"] + if "decoder.atmos_head.weight" in checkpoint: + weight = checkpoint["decoder.atmos_head.weight"] + bias = checkpoint["decoder.atmos_head.bias"] + del checkpoint["decoder.atmos_head.weight"] + del checkpoint["decoder.atmos_head.bias"] assert weight.shape[0] == 5 * self.patch_size**2 assert bias.shape[0] == 5 * self.patch_size**2 @@ -323,11 +374,25 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None: bias = bias.reshape(self.patch_size**2, 5) for i, name in enumerate(("z", "u", "v", "t", "q")): - d[f"decoder.atmos_heads.{name}.weight"] = weight[:, i] - d[f"decoder.atmos_heads.{name}.bias"] = bias[:, i] - - self.load_state_dict(d, strict=strict) + checkpoint[f"decoder.atmos_heads.{name}.weight"] = weight[:, i] + checkpoint[f"decoder.atmos_heads.{name}.bias"] = bias[:, i] + return checkpoint + + def adapt_checkpoint(self, checkpoint) -> Any: + """Adapt a checkpoint to the current model. + + Current model has a different structure than the model that was used to train the checkpoint + that is being loaded. This function adapts the checkpoint to the current model so that it can + be loaded. 
+ """ + # You can safely ignore all cumbersome processing below. We modified the model after we + # trained it. The code below manually adapts the checkpoints, so the checkpoints are + # compatible with the new model. + checkpoint = Aurora._adapt_checkpoint_prefix(checkpoint) + checkpoint = self._adapt_checkpoint_parametrization(checkpoint) + return checkpoint + def configure_activation_checkpointing(self): """Configure activation checkpointing. From c9277a9847927ac0c5658dbd198b70919d431ae8 Mon Sep 17 00:00:00 2001 From: scottcha Date: Wed, 18 Sep 2024 12:04:42 -0600 Subject: [PATCH 3/9] refactor ability to adapt max_history_size from a checkpoint to its own method --- aurora/model/aurora.py | 177 +++++++++++----------------- tests/test_checkpoint_adaptation.py | 33 ++++++ 2 files changed, 102 insertions(+), 108 deletions(-) create mode 100644 tests/test_checkpoint_adaptation.py diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py index 614e7ef..c86bbd9 100644 --- a/aurora/model/aurora.py +++ b/aurora/model/aurora.py @@ -5,7 +5,7 @@ import warnings from datetime import timedelta from functools import partial -from typing import Any, Callable, Optional +from typing import Any, Optional import torch from huggingface_hub import hf_hub_download @@ -233,7 +233,7 @@ def forward(self, batch: Batch) -> Batch: return pred - def load_checkpoint(self, repo: str, name: str, strict: bool = True, adapt_fn: Callable[[Any], Any] = 'default', adapt_history: Callable[[Any], Any] = 'default') -> None: + def load_checkpoint(self, repo: str, name: str, strict: bool = True) -> None: """Load a checkpoint from HuggingFace. Args: @@ -242,116 +242,55 @@ def load_checkpoint(self, repo: str, name: str, strict: bool = True, adapt_fn: C `checkpoint.cpkt`. strict (bool, optional): Error if the model parameters are not exactly equal to the parameters in the checkpoint. Defaults to `True`. - adapt_fn (Callable[[Any], Any], optional): Function to adapt the checkpoint to the current model. - Defaults to `self.adapt_checkpoint`. Pass `None` to skip adaptation. - adapt_history (Callable[[Any], Any], optional): Function to fill the history of the checkpoint when - Model is larger than checkpoint. Defaults to `self.adapt_checkpoint_max_history_size`. Pass `None` to skip adaptation. """ - if adapt_fn == 'default': - adapt_fn = self.adapt_checkpoint - - if adapt_history == 'default': - adapt_history = self.adapt_checkpoint_max_history_size - path = hf_hub_download(repo_id=repo, filename=name) - self.load_checkpoint_local(path, strict=strict, adapt_fn=adapt_fn, adapt_history=adapt_history) + self.load_checkpoint_local(path, strict=strict) - def load_checkpoint_local(self, path: str, strict: bool = True, adapt_fn: Callable[[Any], Any] = 'default', adapt_history: Callable[[Any], Any] = 'default') -> None: + def load_checkpoint_local(self, path: str, strict: bool = True) -> None: """Load a checkpoint directly from a file. Args: path (str): Path to the checkpoint. strict (bool, optional): Error if the model parameters are not exactly equal to the parameters in the checkpoint. Defaults to `True`. - adapt_fn (Callable[[Any], Any], optional): Function to adapt the checkpoint to the current model. - Defaults to `self.adapt_checkpoint`. Pass `None` to skip adaptation. - adapt_history (Callable[[Any], Any], optional): Function to fill the history of the checkpoint when - Model is larger than checkpoint. Defaults to `self.adapt_checkpoint_max_history_size`. Pass `None` to skip adaptation. 
""" - if adapt_fn == 'default': - adapt_fn = self.adapt_checkpoint - - if adapt_history == 'default': - adapt_history = self.adapt_checkpoint_max_history_size - # Assume that all parameters are either on the CPU or on the GPU. device = next(self.parameters()).device d = torch.load(path, map_location=device, weights_only=True) - # Adapt the checkpoint using the provided function, if not None - if adapt_fn is not None: - d = adapt_fn(d) - - # Adapt the checkpoint history size using the provided function, if not None - if adapt_history is not None: - d = adapt_history(d) - - self.load_state_dict(d, strict=strict) + # You can safely ignore all cumbersome processing below. We modified the model after we + # trained it. The code below manually adapts the checkpoints, so the checkpoints are + # compatible with the new model. - def adapt_checkpoint_max_history_size(self, checkpoint) -> Any: - """Adapt a checkpoint with smaller max_history_size to a model with a larger max_history_size - than the current model. - - If a checkpoint was trained with a larger max_history_size than the current model, this function - will assert fail to prevent loading the checkpoint. This is to prevent loading a checkpoint - which will likely cause the checkpoint to degrade is performance. - - This implementation copies weights from the checkpoint to the model, duplicating the weights. - """ - - #find all weights with prefix "encoder.surf_token_embeds.weights." - for name, weight in list(checkpoint.items()): - if name.startswith("encoder.surf_token_embeds.weights.") or name.startswith("encoder.atmos_token_embeds.weights."): - assert weight.shape[2] >= self.max_history_size, f"Cannot load checkpoint with max_history_size {weight.shape[2]} \ - into model with max_history_size {self.max_history_size} for weight {name}" - - # Initialize the new weight tensor with zeros - new_weight = torch.zeros((weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4])) - # Copy the existing weights to the new tensor my duplicating the histories provided in to any new history dimensions - for j in range(new_weight.shape[2]): - new_weight[:, :, j, :, :] = weight[:, :, j % weight.shape[2], :, :] - checkpoint[name] = new_weight - - return checkpoint - - @staticmethod - def _adapt_checkpoint_prefix(checkpoint) -> Any: - """Adapt a checkpoint with a different prefix to the current model. - - If a checkpoint was trained with a different prefix than the current model, this function - will remove the prefix from the checkpoint so that it can be loaded. - """ # Remove possibly prefix from the keys. - for k, v in list(checkpoint.items()): + for k, v in list(d.items()): if k.startswith("net."): - del checkpoint[k] - checkpoint[k[4:]] = v - return checkpoint - - def _adapt_checkpoint_parametrization(self, checkpoint) -> Any: + del d[k] + d[k[4:]] = v + # Convert the ID-based parametrization to a name-based parametrization. 
- if "encoder.surf_token_embeds.weight" in checkpoint: - weight = checkpoint["encoder.surf_token_embeds.weight"] - del checkpoint["encoder.surf_token_embeds.weight"] + if "encoder.surf_token_embeds.weight" in d: + weight = d["encoder.surf_token_embeds.weight"] + del d["encoder.surf_token_embeds.weight"] assert weight.shape[1] == 4 + 3 for i, name in enumerate(("2t", "10u", "10v", "msl", "lsm", "z", "slt")): - checkpoint[f"encoder.surf_token_embeds.weights.{name}"] = weight[:, [i]] + d[f"encoder.surf_token_embeds.weights.{name}"] = weight[:, [i]] - if "encoder.atmos_token_embeds.weight" in checkpoint: - weight = checkpoint["encoder.atmos_token_embeds.weight"] - del checkpoint["encoder.atmos_token_embeds.weight"] + if "encoder.atmos_token_embeds.weight" in d: + weight = d["encoder.atmos_token_embeds.weight"] + del d["encoder.atmos_token_embeds.weight"] assert weight.shape[1] == 5 for i, name in enumerate(("z", "u", "v", "t", "q")): - checkpoint[f"encoder.atmos_token_embeds.weights.{name}"] = weight[:, [i]] + d[f"encoder.atmos_token_embeds.weights.{name}"] = weight[:, [i]] - if "decoder.surf_head.weight" in checkpoint: - weight = checkpoint["decoder.surf_head.weight"] - bias = checkpoint["decoder.surf_head.bias"] - del checkpoint["decoder.surf_head.weight"] - del checkpoint["decoder.surf_head.bias"] + if "decoder.surf_head.weight" in d: + weight = d["decoder.surf_head.weight"] + bias = d["decoder.surf_head.bias"] + del d["decoder.surf_head.weight"] + del d["decoder.surf_head.bias"] assert weight.shape[0] == 4 * self.patch_size**2 assert bias.shape[0] == 4 * self.patch_size**2 @@ -359,14 +298,14 @@ def _adapt_checkpoint_parametrization(self, checkpoint) -> Any: bias = bias.reshape(self.patch_size**2, 4) for i, name in enumerate(("2t", "10u", "10v", "msl")): - checkpoint[f"decoder.surf_heads.{name}.weight"] = weight[:, i] - checkpoint[f"decoder.surf_heads.{name}.bias"] = bias[:, i] + d[f"decoder.surf_heads.{name}.weight"] = weight[:, i] + d[f"decoder.surf_heads.{name}.bias"] = bias[:, i] - if "decoder.atmos_head.weight" in checkpoint: - weight = checkpoint["decoder.atmos_head.weight"] - bias = checkpoint["decoder.atmos_head.bias"] - del checkpoint["decoder.atmos_head.weight"] - del checkpoint["decoder.atmos_head.bias"] + if "decoder.atmos_head.weight" in d: + weight = d["decoder.atmos_head.weight"] + bias = d["decoder.atmos_head.bias"] + del d["decoder.atmos_head.weight"] + del d["decoder.atmos_head.bias"] assert weight.shape[0] == 5 * self.patch_size**2 assert bias.shape[0] == 5 * self.patch_size**2 @@ -374,24 +313,44 @@ def _adapt_checkpoint_parametrization(self, checkpoint) -> Any: bias = bias.reshape(self.patch_size**2, 5) for i, name in enumerate(("z", "u", "v", "t", "q")): - checkpoint[f"decoder.atmos_heads.{name}.weight"] = weight[:, i] - checkpoint[f"decoder.atmos_heads.{name}.bias"] = bias[:, i] - return checkpoint - - def adapt_checkpoint(self, checkpoint) -> Any: - """Adapt a checkpoint to the current model. 
+ d[f"decoder.atmos_heads.{name}.weight"] = weight[:, i] + d[f"decoder.atmos_heads.{name}.bias"] = bias[:, i] + + #check if history size is compatible and adjust weights if necessary + if self.max_history_size > d["encoder.surf_token_embeds.weights.2t"].shape[2]: + d = self.adapt_checkpoint_max_history_size(d) + elif self.max_history_size < d["encoder.surf_token_embeds.weights.2t"].shape[2]: + assert False, f"Cannot load checkpoint with max_history_size {d['encoder.surf_token_embeds.weights.2t'].shape[2]} \ + into model with max_history_size {self.max_history_size}" + + self.load_state_dict(d, strict=strict) + + def adapt_checkpoint_max_history_size(self, checkpoint) -> Any: + """Adapt a checkpoint with smaller max_history_size to a model with a larger max_history_size + than the current model. + + If a checkpoint was trained with a larger max_history_size than the current model, this function + will assert fail to prevent loading the checkpoint. This is to prevent loading a checkpoint + which will likely cause the checkpoint to degrade is performance. - Current model has a different structure than the model that was used to train the checkpoint - that is being loaded. This function adapts the checkpoint to the current model so that it can - be loaded. + This implementation copies weights from the checkpoint to the model, duplicating the weights. """ - # You can safely ignore all cumbersome processing below. We modified the model after we - # trained it. The code below manually adapts the checkpoints, so the checkpoints are - # compatible with the new model. - - checkpoint = Aurora._adapt_checkpoint_prefix(checkpoint) - checkpoint = self._adapt_checkpoint_parametrization(checkpoint) - return checkpoint + #find all weights with prefix "encoder.surf_token_embeds.weights." + for name, weight in list(checkpoint.items()): + if name.startswith("encoder.surf_token_embeds.weights.") or name.startswith("encoder.atmos_token_embeds.weights."): + #this shouldn't get called with current logic but leaving here for future proofing and in cases where its called + #outside current context + assert weight.shape[2] <= self.max_history_size, f"Cannot load checkpoint with max_history_size {weight.shape[2]} \ + into model with max_history_size {self.max_history_size} for weight {name}" + + # Initialize the new weight tensor with zeros + new_weight = torch.zeros((weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4])) + # Copy the existing weights to the new tensor my duplicating the histories provided in to any new history dimensions + for j in range(new_weight.shape[2]): + new_weight[:, :, j, :, :] = weight[:, :, j % weight.shape[2], :, :] + checkpoint[name] = new_weight + + return checkpoint def configure_activation_checkpointing(self): """Configure activation checkpointing. @@ -417,4 +376,6 @@ def configure_activation_checkpointing(self): patch_size=10, encoder_depths=(6, 8, 8), decoder_depths=(8, 8, 6), + # One particular static variable requires a different normalisation. 
+ surf_stats={"z": (-3.270407e03, 6.540335e04)}, ) diff --git a/tests/test_checkpoint_adaptation.py b/tests/test_checkpoint_adaptation.py new file mode 100644 index 0000000..ac3bfe6 --- /dev/null +++ b/tests/test_checkpoint_adaptation.py @@ -0,0 +1,33 @@ +import pytest +import torch +from unittest.mock import patch +from aurora.model.aurora import AuroraSmall + +@pytest.fixture +def model(request): + return AuroraSmall(max_history_size=request.param) + +@pytest.fixture +def checkpoint(): + return { + "encoder.surf_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4)), + "encoder.atmos_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4)) + } + +#check both history sizes which are divisible by 2 (original shape) and not +@pytest.mark.parametrize('model', [4, 5], indirect=True) +def test_adapt_checkpoint_max_history(model, checkpoint): + # checkpoint starts with history dim, shape[2], as size 2 + assert checkpoint["encoder.surf_token_embeds.weights.0"].shape[2] == 2 + adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint) + + for name, weight in adapted_checkpoint.items(): + assert weight.shape[2] == model.max_history_size + for j in range(weight.shape[2]): + assert torch.equal(weight[:, :, j, :, :], checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :]) + +#check that assert is thrown when trying to load a larger checkpoint to a smaller history size +@pytest.mark.parametrize('model', [1], indirect=True) +def test_adapt_checkpoint_max_history_fail(model, checkpoint): + with pytest.raises(AssertionError): + model.adapt_checkpoint_max_history_size(checkpoint) \ No newline at end of file From ac23718de30c697ce511a58acf3e9471b1b27188 Mon Sep 17 00:00:00 2001 From: scottcha Date: Wed, 18 Sep 2024 12:18:45 -0600 Subject: [PATCH 4/9] Add addiitonal test for multiple calls --- tests/test_checkpoint_adaptation.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/test_checkpoint_adaptation.py b/tests/test_checkpoint_adaptation.py index ac3bfe6..3599c13 100644 --- a/tests/test_checkpoint_adaptation.py +++ b/tests/test_checkpoint_adaptation.py @@ -30,4 +30,15 @@ def test_adapt_checkpoint_max_history(model, checkpoint): @pytest.mark.parametrize('model', [1], indirect=True) def test_adapt_checkpoint_max_history_fail(model, checkpoint): with pytest.raises(AssertionError): - model.adapt_checkpoint_max_history_size(checkpoint) \ No newline at end of file + model.adapt_checkpoint_max_history_size(checkpoint) + +#test adapting the checkpoint twice to ensure that the second time should not change the weights +@pytest.mark.parametrize('model', [4], indirect=True) +def test_adapt_checkpoint_max_history_twice(model, checkpoint): + adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint) + adapted_checkpoint = model.adapt_checkpoint_max_history_size(adapted_checkpoint) + + for name, weight in adapted_checkpoint.items(): + assert weight.shape[2] == model.max_history_size + for j in range(weight.shape[2]): + assert torch.equal(weight[:, :, j, :, :], checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :]) \ No newline at end of file From 9a78600626a15a7000bb8ff4ce6941ce116de2e6 Mon Sep 17 00:00:00 2001 From: scottcha Date: Wed, 18 Sep 2024 12:23:57 -0600 Subject: [PATCH 5/9] Add copyright to the new test file --- tests/test_checkpoint_adaptation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_checkpoint_adaptation.py b/tests/test_checkpoint_adaptation.py index 3599c13..364e1c2 100644 --- 
From ac23718de30c697ce511a58acf3e9471b1b27188 Mon Sep 17 00:00:00 2001
From: scottcha
Date: Wed, 18 Sep 2024 12:18:45 -0600
Subject: [PATCH 4/9] Add additional test for multiple calls

---
 tests/test_checkpoint_adaptation.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/tests/test_checkpoint_adaptation.py b/tests/test_checkpoint_adaptation.py
index ac3bfe6..3599c13 100644
--- a/tests/test_checkpoint_adaptation.py
+++ b/tests/test_checkpoint_adaptation.py
@@ -30,4 +30,15 @@ def test_adapt_checkpoint_max_history(model, checkpoint):
 @pytest.mark.parametrize('model', [1], indirect=True)
 def test_adapt_checkpoint_max_history_fail(model, checkpoint):
     with pytest.raises(AssertionError):
-        model.adapt_checkpoint_max_history_size(checkpoint)
\ No newline at end of file
+        model.adapt_checkpoint_max_history_size(checkpoint)
+
+#test adapting the checkpoint twice to ensure that the second time should not change the weights
+@pytest.mark.parametrize('model', [4], indirect=True)
+def test_adapt_checkpoint_max_history_twice(model, checkpoint):
+    adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
+    adapted_checkpoint = model.adapt_checkpoint_max_history_size(adapted_checkpoint)
+
+    for name, weight in adapted_checkpoint.items():
+        assert weight.shape[2] == model.max_history_size
+        for j in range(weight.shape[2]):
+            assert torch.equal(weight[:, :, j, :, :], checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :])
\ No newline at end of file

From 9a78600626a15a7000bb8ff4ce6941ce116de2e6 Mon Sep 17 00:00:00 2001
From: scottcha
Date: Wed, 18 Sep 2024 12:23:57 -0600
Subject: [PATCH 5/9] Add copyright to the new test file

---
 tests/test_checkpoint_adaptation.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/test_checkpoint_adaptation.py b/tests/test_checkpoint_adaptation.py
index 3599c13..364e1c2 100644
--- a/tests/test_checkpoint_adaptation.py
+++ b/tests/test_checkpoint_adaptation.py
@@ -1,3 +1,5 @@
+"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
+
 import pytest
 import torch
 from unittest.mock import patch

From c9b722e5b4d1a28bbe26141d8d8b1541d14fb8c2 Mon Sep 17 00:00:00 2001
From: scottcha
Date: Thu, 19 Sep 2024 16:05:13 -0600
Subject: [PATCH 6/9] fill with zeroes instead of previous weights; match
 previous weights device and dtype

---
 aurora/model/aurora.py              | 24 ++++++++++++++----------
 tests/test_checkpoint_adaptation.py | 13 ++++++++++---
 2 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py
index c86bbd9..8c940c0 100644
--- a/aurora/model/aurora.py
+++ b/aurora/model/aurora.py
@@ -333,25 +333,29 @@ def adapt_checkpoint_max_history_size(self, checkpoint) -> Any:
         will assert fail to prevent loading the checkpoint. This is to prevent loading a checkpoint
         which will likely cause the checkpoint to degrade is performance.
 
-        This implementation copies weights from the checkpoint to the model, duplicating the weights.
+        This implementation copies weights from the checkpoint to the model and fills 0 for the new history
+        width dimension.
         """
-        #find all weights with prefix "encoder.surf_token_embeds.weights."
+        # Find all weights with prefix "encoder.surf_token_embeds.weights."
         for name, weight in list(checkpoint.items()):
             if name.startswith("encoder.surf_token_embeds.weights.") or name.startswith("encoder.atmos_token_embeds.weights."):
-                #this shouldn't get called with current logic but leaving here for future proofing and in cases where its called
-                #outside current context
+                # This shouldn't get called with current logic but leaving here for future proofing and in cases where its called
+                # outside current context
                 assert weight.shape[2] <= self.max_history_size, f"Cannot load checkpoint with max_history_size {weight.shape[2]} \
                     into model with max_history_size {self.max_history_size} for weight {name}"
 
-                # Initialize the new weight tensor with zeros
-                new_weight = torch.zeros((weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]))
-                # Copy the existing weights to the new tensor my duplicating the histories provided in to any new history dimensions
+                # Initialize the new weight tensor
+                new_weight = torch.zeros((weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]),
+                                         device=weight.device, dtype=weight.dtype)
+
+                # Copy the existing weights to the new tensor by duplicating the histories provided into any new history dimensions
                 for j in range(new_weight.shape[2]):
-                    new_weight[:, :, j, :, :] = weight[:, :, j % weight.shape[2], :, :]
+                    if j < weight.shape[2]:
+                        # only fill existing weights, others are zeros
+                        new_weight[:, :, j, :, :] = weight[:, :, j, :, :]
                 checkpoint[name] = new_weight
-        return checkpoint
-
+        return checkpoint
+
     def configure_activation_checkpointing(self):
         """Configure activation checkpointing.
diff --git a/tests/test_checkpoint_adaptation.py b/tests/test_checkpoint_adaptation.py
index 364e1c2..e43c48f 100644
--- a/tests/test_checkpoint_adaptation.py
+++ b/tests/test_checkpoint_adaptation.py
@@ -26,7 +26,10 @@ def test_adapt_checkpoint_max_history(model, checkpoint):
     for name, weight in adapted_checkpoint.items():
         assert weight.shape[2] == model.max_history_size
         for j in range(weight.shape[2]):
-            assert torch.equal(weight[:, :, j, :, :], checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :])
+            if j >= checkpoint[name].shape[2]:
+                assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
+            else:
+                assert torch.equal(weight[:, :, j, :, :], checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :])
 
 #check that assert is thrown when trying to load a larger checkpoint to a smaller history size
 @pytest.mark.parametrize('model', [1], indirect=True)
@@ -37,10 +40,14 @@ def test_adapt_checkpoint_max_history_fail(model, checkpoint):
 #test adapting the checkpoint twice to ensure that the second time should not change the weights
 @pytest.mark.parametrize('model', [4], indirect=True)
 def test_adapt_checkpoint_max_history_twice(model, checkpoint):
-    adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
+    adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
     adapted_checkpoint = model.adapt_checkpoint_max_history_size(adapted_checkpoint)
-
+
     for name, weight in adapted_checkpoint.items():
         assert weight.shape[2] == model.max_history_size
         for j in range(weight.shape[2]):
-            assert torch.equal(weight[:, :, j, :, :], checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :])
\ No newline at end of file
+            if j >= checkpoint[name].shape[2]:
+                assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
+            else:
+                assert torch.equal(weight[:, :, j, :, :], checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :])
+
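Patch 6 changes the fill strategy: new history slots are now zero-initialised rather than filled with duplicated history steps, and the new tensor inherits the checkpoint's device and dtype. A small self-contained check of that behaviour (shapes are illustrative):

```python
import torch

weight = torch.rand(2, 1, 2, 4, 4)  # checkpoint tensor with history size 2
max_history_size = 4                # target history size of the model

new_weight = torch.zeros(
    (weight.shape[0], 1, max_history_size, weight.shape[3], weight.shape[4]),
    device=weight.device,
    dtype=weight.dtype,
)
for j in range(weight.shape[2]):    # only fill the existing history steps
    new_weight[:, :, j, :, :] = weight[:, :, j, :, :]

assert torch.equal(new_weight[:, :, : weight.shape[2]], weight)
assert torch.all(new_weight[:, :, weight.shape[2] :] == 0)  # new steps stay zero
```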
- surf_stats={"z": (-3.270407e03, 6.540335e04)}, ) From 70c5d2ce64ec3c48e8c7fc6d63c9a4dd67b3ae43 Mon Sep 17 00:00:00 2001 From: scottcha Date: Thu, 19 Sep 2024 16:22:19 -0600 Subject: [PATCH 8/9] simplify weight copying logic --- aurora/model/aurora.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py index 9eba914..efe5541 100644 --- a/aurora/model/aurora.py +++ b/aurora/model/aurora.py @@ -349,10 +349,9 @@ def adapt_checkpoint_max_history_size(self, checkpoint) -> Any: device=weight.device, dtype=weight.dtype) # Copy the existing weights to the new tensor by duplicating the histories provided into any new history dimensions - for j in range(new_weight.shape[2]): - if j < weight.shape[2]: - # only fill existing weights, others are zeros - new_weight[:, :, j, :, :] = weight[:, :, j, :, :] + for j in range(weight.shape[2]): + # only fill existing weights, others are zeros + new_weight[:, :, j, :, :] = weight[:, :, j, :, :] checkpoint[name] = new_weight return checkpoint From 872541b8c756cb9d6380893e441041f27d26d581 Mon Sep 17 00:00:00 2001 From: scottcha Date: Sun, 22 Sep 2024 08:22:42 -0600 Subject: [PATCH 9/9] Fix pre-commit issues --- aurora/model/aurora.py | 68 +++++++++++++++++------------ tests/test_checkpoint_adaptation.py | 42 +++++++++++------- 2 files changed, 65 insertions(+), 45 deletions(-) diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py index efe5541..1fe4220 100644 --- a/aurora/model/aurora.py +++ b/aurora/model/aurora.py @@ -273,18 +273,18 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None: if "encoder.surf_token_embeds.weight" in d: weight = d["encoder.surf_token_embeds.weight"] del d["encoder.surf_token_embeds.weight"] - + assert weight.shape[1] == 4 + 3 for i, name in enumerate(("2t", "10u", "10v", "msl", "lsm", "z", "slt")): d[f"encoder.surf_token_embeds.weights.{name}"] = weight[:, [i]] - + if "encoder.atmos_token_embeds.weight" in d: weight = d["encoder.atmos_token_embeds.weight"] del d["encoder.atmos_token_embeds.weight"] - + assert weight.shape[1] == 5 for i, name in enumerate(("z", "u", "v", "t", "q")): - d[f"encoder.atmos_token_embeds.weights.{name}"] = weight[:, [i]] + d[f"encoder.atmos_token_embeds.weights.{name}"] = weight[:, [i]] if "decoder.surf_head.weight" in d: weight = d["decoder.surf_head.weight"] @@ -316,45 +316,55 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None: d[f"decoder.atmos_heads.{name}.weight"] = weight[:, i] d[f"decoder.atmos_heads.{name}.bias"] = bias[:, i] - #check if history size is compatible and adjust weights if necessary + # check if history size is compatible and adjust weights if necessary if self.max_history_size > d["encoder.surf_token_embeds.weights.2t"].shape[2]: d = self.adapt_checkpoint_max_history_size(d) elif self.max_history_size < d["encoder.surf_token_embeds.weights.2t"].shape[2]: - assert False, f"Cannot load checkpoint with max_history_size {d['encoder.surf_token_embeds.weights.2t'].shape[2]} \ - into model with max_history_size {self.max_history_size}" - + raise AssertionError(f"Cannot load checkpoint with max_history_size \ + {d['encoder.surf_token_embeds.weights.2t'].shape[2]} \ + into model with max_history_size {self.max_history_size}") + self.load_state_dict(d, strict=strict) def adapt_checkpoint_max_history_size(self, checkpoint) -> Any: - """Adapt a checkpoint with smaller max_history_size to a model with a larger max_history_size - than the current model. 
- - If a checkpoint was trained with a larger max_history_size than the current model, this function - will assert fail to prevent loading the checkpoint. This is to prevent loading a checkpoint - which will likely cause the checkpoint to degrade is performance. - - This implementation copies weights from the checkpoint to the model and fills 0 for the new history - width dimension. + """Adapt a checkpoint with smaller max_history_size to a model with a larger + max_history_size than the current model. + + If a checkpoint was trained with a larger max_history_size than the current model, + this function will assert fail to prevent loading the checkpoint. This is to + prevent loading a checkpoint which will likely cause the checkpoint to degrade is + performance. + + This implementation copies weights from the checkpoint to the model and fills 0 + for the new history width dimension. """ - # Find all weights with prefix "encoder.surf_token_embeds.weights." + # Find all weights with prefix "encoder.surf_token_embeds.weights." for name, weight in list(checkpoint.items()): - if name.startswith("encoder.surf_token_embeds.weights.") or name.startswith("encoder.atmos_token_embeds.weights."): - # This shouldn't get called with current logic but leaving here for future proofing and in cases where its called - # outside current context - assert weight.shape[2] <= self.max_history_size, f"Cannot load checkpoint with max_history_size {weight.shape[2]} \ - into model with max_history_size {self.max_history_size} for weight {name}" - + if name.startswith("encoder.surf_token_embeds.weights.") or name.startswith( + "encoder.atmos_token_embeds.weights." + ): + # This shouldn't get called with current logic but leaving here for future proofing + # and in cases where its called outside current context + assert ( + weight.shape[2] <= self.max_history_size + ), f"Cannot load checkpoint with max_history_size {weight.shape[2]} \ + into model with max_history_size {self.max_history_size} for weight {name}" + # Initialize the new weight tensor - new_weight = torch.zeros((weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]), - device=weight.device, dtype=weight.dtype) - - # Copy the existing weights to the new tensor by duplicating the histories provided into any new history dimensions + new_weight = torch.zeros( + (weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]), + device=weight.device, + dtype=weight.dtype, + ) + + # Copy the existing weights to the new tensor by duplicating the histories provided + # into any new history dimensions for j in range(weight.shape[2]): # only fill existing weights, others are zeros new_weight[:, :, j, :, :] = weight[:, :, j, :, :] checkpoint[name] = new_weight return checkpoint - + def configure_activation_checkpointing(self): """Configure activation checkpointing. 
diff --git a/tests/test_checkpoint_adaptation.py b/tests/test_checkpoint_adaptation.py
index e43c48f..83793cc 100644
--- a/tests/test_checkpoint_adaptation.py
+++ b/tests/test_checkpoint_adaptation.py
@@ -2,52 +2,62 @@
 
 import pytest
 import torch
-from unittest.mock import patch
+
 from aurora.model.aurora import AuroraSmall
 
+
 @pytest.fixture
 def model(request):
     return AuroraSmall(max_history_size=request.param)
 
+
 @pytest.fixture
 def checkpoint():
     return {
         "encoder.surf_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4)),
-        "encoder.atmos_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4))
+        "encoder.atmos_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4)),
     }
 
-#check both history sizes which are divisible by 2 (original shape) and not
-@pytest.mark.parametrize('model', [4, 5], indirect=True)
+
+# check both history sizes, one divisible by 2 (the original history size) and one not
+@pytest.mark.parametrize("model", [4, 5], indirect=True)
 def test_adapt_checkpoint_max_history(model, checkpoint):
     # checkpoint starts with history dim, shape[2], as size 2
-    assert checkpoint["encoder.surf_token_embeds.weights.0"].shape[2] == 2 
+    assert checkpoint["encoder.surf_token_embeds.weights.0"].shape[2] == 2
     adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
-    
+
     for name, weight in adapted_checkpoint.items():
         assert weight.shape[2] == model.max_history_size
         for j in range(weight.shape[2]):
             if j >= checkpoint[name].shape[2]:
                 assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
             else:
-                assert torch.equal(weight[:, :, j, :, :], checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :])
+                assert torch.equal(
+                    weight[:, :, j, :, :],
+                    checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :],
+                )
 
-#check that assert is thrown when trying to load a larger checkpoint to a smaller history size
-@pytest.mark.parametrize('model', [1], indirect=True)
+
+# check that an assertion error is raised when loading a larger checkpoint into a smaller history size
+@pytest.mark.parametrize("model", [1], indirect=True)
 def test_adapt_checkpoint_max_history_fail(model, checkpoint):
     with pytest.raises(AssertionError):
         model.adapt_checkpoint_max_history_size(checkpoint)
-    
-#test adapting the checkpoint twice to ensure that the second time should not change the weights
-@pytest.mark.parametrize('model', [4], indirect=True)
+
+
+# test adapting the checkpoint twice to ensure that the second call does not change the weights
+@pytest.mark.parametrize("model", [4], indirect=True)
 def test_adapt_checkpoint_max_history_twice(model, checkpoint):
-    adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint) 
+    adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
     adapted_checkpoint = model.adapt_checkpoint_max_history_size(adapted_checkpoint)
-    
+
     for name, weight in adapted_checkpoint.items():
         assert weight.shape[2] == model.max_history_size
         for j in range(weight.shape[2]):
             if j >= checkpoint[name].shape[2]:
                 assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
             else:
-                assert torch.equal(weight[:, :, j, :, :], checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :])
+                assert torch.equal(
+                    weight[:, :, j, :, :],
+                    checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :],
+                )
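With the final series applied, the adaptation is idempotent, which is what `test_adapt_checkpoint_max_history_twice` pins down: a second call only re-copies already-adapted steps and leaves the zero-filled slots untouched. A short sketch of that property, mirroring the test fixtures above (assumes the patched class):

```python
import torch
from aurora.model.aurora import AuroraSmall

model = AuroraSmall(max_history_size=4)
ckpt = {"encoder.surf_token_embeds.weights.0": torch.rand(2, 1, 2, 4, 4)}

once = model.adapt_checkpoint_max_history_size({k: v.clone() for k, v in ckpt.items()})
twice = model.adapt_checkpoint_max_history_size({k: v.clone() for k, v in once.items()})

assert all(torch.equal(once[k], twice[k]) for k in once)  # adapting twice is a no-op
```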