
Add new logic to enable stored checkpoint weights to be copied to new history dimensions #36

Merged
merged 11 commits into from
Sep 23, 2024
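In practice, this change lets a checkpoint trained with a small history window be loaded into a model configured with a larger one: the stored history weights are copied cyclically into the new history slots. A minimal usage sketch, assuming a locally stored checkpoint (the path is illustrative; AuroraSmall and load_checkpoint_local are the names used in this diff):

from aurora.model.aurora import AuroraSmall

# Configure the model with a larger history window than the checkpoint was trained with.
model = AuroraSmall(max_history_size=4)

# load_checkpoint_local adapts the stored weights to the larger history dimension
# before handing them to load_state_dict.
model.load_checkpoint_local("/path/to/checkpoint.ckpt")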
50 changes: 43 additions & 7 deletions aurora/model/aurora.py
@@ -5,7 +5,7 @@
import warnings
from datetime import timedelta
from functools import partial
from typing import Optional
from typing import Any, Optional

import torch
from huggingface_hub import hf_hub_download
@@ -112,6 +112,7 @@ def __init__(
self.patch_size = patch_size
self.surf_stats = surf_stats or dict()
self.autocast = autocast
self.max_history_size = max_history_size

if self.surf_stats:
warnings.warn(
@@ -268,23 +269,22 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
del d[k]
d[k[4:]] = v

# Convert the ID-based parametrisation to a name-based parametrisation.
if "encoder.surf_token_embeds.weight" in d:
weight = d["encoder.surf_token_embeds.weight"]
del d["encoder.surf_token_embeds.weight"]

assert weight.shape[1] == 4 + 3
for i, name in enumerate(("2t", "10u", "10v", "msl", "lsm", "z", "slt")):
d[f"encoder.surf_token_embeds.weights.{name}"] = weight[:, [i]]

if "encoder.atmos_token_embeds.weight" in d:
weight = d["encoder.atmos_token_embeds.weight"]
del d["encoder.atmos_token_embeds.weight"]

assert weight.shape[1] == 5
for i, name in enumerate(("z", "u", "v", "t", "q")):
d[f"encoder.atmos_token_embeds.weights.{name}"] = weight[:, [i]]
d[f"encoder.atmos_token_embeds.weights.{name}"] = weight[:, [i]]

if "decoder.surf_head.weight" in d:
weight = d["decoder.surf_head.weight"]
@@ -316,8 +316,42 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
d[f"decoder.atmos_heads.{name}.weight"] = weight[:, i]
d[f"decoder.atmos_heads.{name}.bias"] = bias[:, i]

# Check whether the stored history size is compatible with the model and adapt
# the weights if necessary.
if self.max_history_size > d["encoder.surf_token_embeds.weights.2t"].shape[2]:
    d = self.adapt_checkpoint_max_history_size(d)
elif self.max_history_size < d["encoder.surf_token_embeds.weights.2t"].shape[2]:
    raise AssertionError(
        f"Cannot load checkpoint with max_history_size "
        f"{d['encoder.surf_token_embeds.weights.2t'].shape[2]} "
        f"into model with max_history_size {self.max_history_size}."
    )

self.load_state_dict(d, strict=strict)

def adapt_checkpoint_max_history_size(self, checkpoint) -> Any:
"""Adapt a checkpoint with smaller max_history_size to a model with a larger max_history_size
than the current model.

If a checkpoint was trained with a larger max_history_size than the current model, this function
will assert fail to prevent loading the checkpoint. This is to prevent loading a checkpoint
which will likely cause the checkpoint to degrade is performance.

This implementation copies weights from the checkpoint to the model, duplicating the weights.
"""
# Make a shallow copy of the checkpoint so that the caller's dictionary is not
# modified in place.
checkpoint = dict(checkpoint)
# Find all weights with the prefix "encoder.surf_token_embeds.weights." or
# "encoder.atmos_token_embeds.weights.".
for name, weight in list(checkpoint.items()):
    if name.startswith(
        ("encoder.surf_token_embeds.weights.", "encoder.atmos_token_embeds.weights.")
    ):
        # This check shouldn't trigger with the current logic, but it is kept for
        # future-proofing and for cases where this method is called outside the
        # current context.
        assert weight.shape[2] <= self.max_history_size, (
            f"Cannot load checkpoint with max_history_size {weight.shape[2]} "
            f"into model with max_history_size {self.max_history_size} "
            f"for weight {name}."
        )

# Initialise the new weight tensor with zeros, matching the dtype and device of
# the stored weight.
new_weight = torch.zeros(
    (weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]),
    dtype=weight.dtype,
    device=weight.device,
)
# Copy the existing weights into the new tensor by cyclically duplicating the
# provided histories into any new history dimensions.
for j in range(new_weight.shape[2]):
    new_weight[:, :, j, :, :] = weight[:, :, j % weight.shape[2], :, :]
checkpoint[name] = new_weight

return checkpoint

def configure_activation_checkpointing(self):
"""Configure activation checkpointing.

@@ -342,4 +376,6 @@ def configure_activation_checkpointing(self):
patch_size=10,
encoder_depths=(6, 8, 8),
decoder_depths=(8, 8, 6),
# One particular static variable requires a different normalisation.
surf_stats={"z": (-3.270407e03, 6.540335e04)},
)
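The heart of the adaptation is the cyclic copy along the history dimension (dim 2). The following self-contained sketch replays the same indexing rule on toy tensors, outside the model class (an illustration, not code from this PR):

import torch

max_history_size = 5
# A toy embedding weight whose history dimension (dim 2) has size 2.
weight = torch.rand((2, 1, 2, 4, 4))

new_weight = torch.zeros(
    (weight.shape[0], 1, max_history_size, weight.shape[3], weight.shape[4]),
    dtype=weight.dtype,
)
for j in range(max_history_size):
    # History slot j is filled from slot j % 2 of the original weight.
    new_weight[:, :, j, :, :] = weight[:, :, j % weight.shape[2], :, :]

# Slots 0, 2, 4 repeat the first original slice; slots 1 and 3 repeat the second.
assert torch.equal(new_weight[:, :, 4], weight[:, :, 0])
assert torch.equal(new_weight[:, :, 3], weight[:, :, 1])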
46 changes: 46 additions & 0 deletions tests/test_checkpoint_adaptation.py
@@ -0,0 +1,46 @@
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

import pytest
import torch

from aurora.model.aurora import AuroraSmall

@pytest.fixture
def model(request):
return AuroraSmall(max_history_size=request.param)

@pytest.fixture
def checkpoint():
return {
"encoder.surf_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4)),
"encoder.atmos_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4))
}

# Check history sizes that are divisible by 2 (the original history dimension) and ones that are not.
@pytest.mark.parametrize('model', [4, 5], indirect=True)
def test_adapt_checkpoint_max_history(model, checkpoint):
# The checkpoint starts with a history dimension (shape[2]) of size 2.
assert checkpoint["encoder.surf_token_embeds.weights.0"].shape[2] == 2
adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)

for name, weight in adapted_checkpoint.items():
assert weight.shape[2] == model.max_history_size
for j in range(weight.shape[2]):
assert torch.equal(weight[:, :, j, :, :], checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :])

# Check that an assertion error is raised when trying to load a checkpoint with a larger
# history size into a model with a smaller max_history_size.
@pytest.mark.parametrize('model', [1], indirect=True)
def test_adapt_checkpoint_max_history_fail(model, checkpoint):
with pytest.raises(AssertionError):
model.adapt_checkpoint_max_history_size(checkpoint)

# Adapt the checkpoint twice to ensure that the second adaptation does not change the weights.
@pytest.mark.parametrize('model', [4], indirect=True)
def test_adapt_checkpoint_max_history_twice(model, checkpoint):
adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
adapted_checkpoint = model.adapt_checkpoint_max_history_size(adapted_checkpoint)

for name, weight in adapted_checkpoint.items():
assert weight.shape[2] == model.max_history_size
for j in range(weight.shape[2]):
assert torch.equal(weight[:, :, j, :, :], checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :])
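Assuming the repository layout above, the new tests can be run on their own with:

pytest tests/test_checkpoint_adaptation.py -v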