Commit

Merge pull request #39 from microsoft/wesselb/adjust-docs-and-minor-changes

Minor style changes and docstring adaptations
a-lucic authored Sep 23, 2024
2 parents c694fca + d1afde8 commit c0311d2
Showing 2 changed files with 50 additions and 51 deletions.
66 changes: 34 additions & 32 deletions aurora/model/aurora.py
@@ -5,7 +5,7 @@
import warnings
from datetime import timedelta
from functools import partial
from typing import Any, Optional
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
@@ -93,7 +93,9 @@ def __init__(
separate parameter.
perceiver_ln_eps (float, optional): Epsilon in the perceiver layer norm layers. Used
to stabilise the model.
max_history_size (int, optional): Maximum number of history steps.
max_history_size (int, optional): Maximum number of history steps. You can load
checkpoints with a smaller `max_history_size`, but you cannot load checkpoints
with a larger `max_history_size`.
use_lora (bool, optional): Use LoRA adaptation.
lora_steps (int, optional): Use different LoRA adaptation for the first so-many roll-out
steps.
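
The new docstring above states the loading rule for `max_history_size`. As a quick illustration, here is a minimal sketch of that rule, using only the constructor argument and `load_checkpoint_local` shown in this file; the local checkpoint path is a hypothetical placeholder and the checkpoint's own history size is assumed to be smaller than 4, neither is taken from this diff:

from aurora import Aurora

# A checkpoint trained with a smaller history size loads fine: it is zero-padded
# along the history dimension (see `adapt_checkpoint_max_history_size` below).
model = Aurora(max_history_size=4)
model.load_checkpoint_local("/path/to/aurora-checkpoint.ckpt")  # hypothetical path

# A checkpoint trained with a larger history size raises an AssertionError instead
# of silently degrading performance.
small_model = Aurora(max_history_size=1)
small_model.load_checkpoint_local("/path/to/aurora-checkpoint.ckpt")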
@@ -316,54 +318,54 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
d[f"decoder.atmos_heads.{name}.weight"] = weight[:, i]
d[f"decoder.atmos_heads.{name}.bias"] = bias[:, i]

# check if history size is compatible and adjust weights if necessary
if self.max_history_size > d["encoder.surf_token_embeds.weights.2t"].shape[2]:
d = self.adapt_checkpoint_max_history_size(d)
elif self.max_history_size < d["encoder.surf_token_embeds.weights.2t"].shape[2]:
raise AssertionError(f"Cannot load checkpoint with max_history_size \
{d['encoder.surf_token_embeds.weights.2t'].shape[2]} \
into model with max_history_size {self.max_history_size}")
# Check if the history size is compatible and adjust weights if necessary.
current_history_size = d["encoder.surf_token_embeds.weights.2t"].shape[2]
if self.max_history_size > current_history_size:
self.adapt_checkpoint_max_history_size(d)
elif self.max_history_size < current_history_size:
raise AssertionError(
f"Cannot load checkpoint with `max_history_size` {current_history_size} "
f"into model with `max_history_size` {self.max_history_size}."
)

self.load_state_dict(d, strict=strict)

def adapt_checkpoint_max_history_size(self, checkpoint) -> Any:
"""Adapt a checkpoint with smaller max_history_size to a model with a larger
max_history_size than the current model.
def adapt_checkpoint_max_history_size(self, checkpoint: dict[str, torch.Tensor]) -> None:
"""Adapt a checkpoint with smaller `max_history_size` to a model with a larger
`max_history_size` than the current model.
If a checkpoint was trained with a larger max_history_size than the current model,
If a checkpoint was trained with a larger `max_history_size` than the current model,
this function will raise an AssertionError to prevent loading the checkpoint. This is
to prevent loading a checkpoint which would likely degrade the model's performance.
This implementation copies weights from the checkpoint to the model and fills 0
for the new history width dimension.
This implementation copies weights from the checkpoint to the model and fills zeros
for the new history width dimension. It mutates `checkpoint`.
"""
# Find all weights with prefix "encoder.surf_token_embeds.weights."
for name, weight in list(checkpoint.items()):
if name.startswith("encoder.surf_token_embeds.weights.") or name.startswith(
"encoder.atmos_token_embeds.weights."
):
# We only need to adapt the patch embedding in the encoder.
enc_surf_embedding = name.startswith("encoder.surf_token_embeds.weights.")
enc_atmos_embedding = name.startswith("encoder.atmos_token_embeds.weights.")
if enc_surf_embedding or enc_atmos_embedding:
# This shouldn't get called with current logic but leaving here for future proofing
# and in cases where it's called outside the current context
assert (
weight.shape[2] <= self.max_history_size
), f"Cannot load checkpoint with max_history_size {weight.shape[2]} \
into model with max_history_size {self.max_history_size} for weight {name}"

# Initialize the new weight tensor
# and in cases where it's called outside the current context.
if not (weight.shape[2] <= self.max_history_size):
raise AssertionError(
f"Cannot load checkpoint with `max_history_size` {weight.shape[2]} "
f"into model with `max_history_size` {self.max_history_size}."
)

# Initialize the new weight tensor.
new_weight = torch.zeros(
(weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]),
device=weight.device,
dtype=weight.dtype,
)

# Copy the existing weights to the new tensor by duplicating the histories provided
# into any new history dimensions
for j in range(weight.shape[2]):
# only fill existing weights, others are zeros
new_weight[:, :, j, :, :] = weight[:, :, j, :, :]
# into any new history dimensions. The rest remains at zero.
new_weight[:, :, : weight.shape[2]] = weight

checkpoint[name] = new_weight
return checkpoint

def configure_activation_checkpointing(self):
"""Configure activation checkpointing.
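
To make the adaptation above concrete, here is a self-contained sketch of the zero-fill copy that the new implementation performs, run on a dummy patch-embedding weight (the 128-channel embedding dimension and 4x4 patch size are made up for illustration; only the history axis, dim 2, matters):

import torch

max_history_size = 4                    # target model setting
weight = torch.randn(128, 1, 2, 4, 4)   # dummy checkpoint weight with history size 2

new_weight = torch.zeros(
    (weight.shape[0], 1, max_history_size, weight.shape[3], weight.shape[4]),
    dtype=weight.dtype,
    device=weight.device,
)
new_weight[:, :, : weight.shape[2]] = weight  # copy the existing histories

# The first two history slots match the checkpoint; the new slots stay at zero.
assert torch.equal(new_weight[:, :, :2], weight)
assert torch.equal(new_weight[:, :, 2:], torch.zeros_like(new_weight[:, :, 2:]))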
35 changes: 16 additions & 19 deletions tests/test_checkpoint_adaptation.py
@@ -1,5 +1,6 @@
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

import numpy as np
import pytest
import torch

@@ -19,45 +20,41 @@ def checkpoint():
}


# check both history sizes which are divisible by 2 (original shape) and not
# Check both history sizes which are divisible by 2 (original shape) and not.
@pytest.mark.parametrize("model", [4, 5], indirect=True)
def test_adapt_checkpoint_max_history(model, checkpoint):
# checkpoint starts with history dim, shape[2], as size 2
# Checkpoint starts with history dim., `shape[2]`, equal to 2.
assert checkpoint["encoder.surf_token_embeds.weights.0"].shape[2] == 2
adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
model.adapt_checkpoint_max_history_size(checkpoint)

for name, weight in adapted_checkpoint.items():
for name, weight in checkpoint.items():
assert weight.shape[2] == model.max_history_size
for j in range(weight.shape[2]):
if j >= checkpoint[name].shape[2]:
assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
np.testing.assert_allclose(weight[:, :, j, :, :], 0 * weight[:, :, j, :, :])
else:
assert torch.equal(
weight[:, :, j, :, :],
checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :],
)
np.testing.assert_allclose(weight[:, :, j, :, :], checkpoint[name][:, :, j, :, :])


# check that assert is thrown when trying to load a larger checkpoint to a smaller history size
@pytest.mark.parametrize("model", [1], indirect=True)
def test_adapt_checkpoint_max_history_fail(model, checkpoint):
"""Check that an assertion error is thrown when trying to load a larger checkpoint to a
smaller history size."""
with pytest.raises(AssertionError):
model.adapt_checkpoint_max_history_size(checkpoint)


# test adapting the checkpoint twice to ensure that the second time should not change the weights
@pytest.mark.parametrize("model", [4], indirect=True)
def test_adapt_checkpoint_max_history_twice(model, checkpoint):
adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
adapted_checkpoint = model.adapt_checkpoint_max_history_size(adapted_checkpoint)
"""Test adapting the checkpoint twice to ensure that the second time should not change the
weights."""
model.adapt_checkpoint_max_history_size(checkpoint)
model.adapt_checkpoint_max_history_size(checkpoint)

for name, weight in adapted_checkpoint.items():
for name, weight in checkpoint.items():
assert weight.shape[2] == model.max_history_size
for j in range(weight.shape[2]):
if j >= checkpoint[name].shape[2]:
assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
np.testing.assert_allclose(weight[:, :, j, :, :], 0 * weight[:, :, j, :, :])
else:
assert torch.equal(
weight[:, :, j, :, :],
checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :],
)
np.testing.assert_allclose(weight[:, :, j, :, :], checkpoint[name][:, :, j, :, :])

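Both tests parametrize a `model` fixture indirectly, so the values 4, 5, and 1 become the model's `max_history_size`. The fixture itself is not part of this diff; a hypothetical version, for readers who want to run the tests in isolation, might look like the following (every constructor argument here is an assumption):

import pytest

from aurora import Aurora


@pytest.fixture
def model(request):
    # `request.param` receives the value passed via `parametrize(..., indirect=True)`.
    return Aurora(max_history_size=request.param)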