Rework ID-based parametrisation #27

Merged 4 commits on Sep 11, 2024
63 changes: 56 additions & 7 deletions aurora/model/aurora.py
@@ -52,14 +52,11 @@ def __init__(

Args:
surf_vars (tuple[str, ...], optional): All surface-level variables supported by the
model. The model is sensitive to the order of `surf_vars`! Currently, adding
one more variable here causes the model to incorrectly load the static variables.
It is possible to hack around this. We are working on a more principled fix. Please
open an issue if this is a problem for you.
model.
static_vars (tuple[str, ...], optional): All static variables supported by the
model. The model is sensitive to the order of `static_vars`!
model.
atmos_vars (tuple[str, ...], optional): All atmospheric variables supported by the
model. The model is sensitive to the order of `atmos-vars`!
model.
window_size (tuple[int, int, int], optional): Vertical height, height, and width of the
window of the underlying Swin transformer.
encoder_depths (tuple[int, ...], optional): Number of blocks in each encoder layer.
@@ -231,12 +228,64 @@ def load_checkpoint(self, repo: str, name: str, strict: bool = True) -> None:
path = hf_hub_download(repo_id=repo, filename=name)
d = torch.load(path, map_location=device, weights_only=True)

# Rename keys to ensure compatibility.
# You can safely ignore all cumbersome processing below. We modified the model after we
# trained it. The code below manually adapts the checkpoints, so the checkpoints are
# compatible with the new model.

# Remove a possible prefix from the keys.
for k, v in list(d.items()):
if k.startswith("net."):
del d[k]
d[k[4:]] = v

# Convert the ID-based parametrisation to a name-based parametrisation.

if "encoder.surf_token_embeds.weight" in d:
weight = d["encoder.surf_token_embeds.weight"]
del d["encoder.surf_token_embeds.weight"]

assert weight.shape[1] == 4 + 3
for i, name in enumerate(("2t", "10u", "10v", "msl", "lsm", "z", "slt")):
d[f"encoder.surf_token_embeds.weights.{name}"] = weight[:, [i]]

if "encoder.atmos_token_embeds.weight" in d:
weight = d["encoder.atmos_token_embeds.weight"]
del d["encoder.atmos_token_embeds.weight"]

assert weight.shape[1] == 5
for i, name in enumerate(("z", "u", "v", "t", "q")):
d[f"encoder.atmos_token_embeds.weights.{name}"] = weight[:, [i]]

if "decoder.surf_head.weight" in d:
weight = d["decoder.surf_head.weight"]
bias = d["decoder.surf_head.bias"]
del d["decoder.surf_head.weight"]
del d["decoder.surf_head.bias"]

assert weight.shape[0] == 4 * self.patch_size**2
assert bias.shape[0] == 4 * self.patch_size**2
weight = weight.reshape(self.patch_size**2, 4, -1)
bias = bias.reshape(self.patch_size**2, 4)

for i, name in enumerate(("2t", "10u", "10v", "msl")):
d[f"decoder.surf_heads.{name}.weight"] = weight[:, i]
d[f"decoder.surf_heads.{name}.bias"] = bias[:, i]

if "decoder.atmos_head.weight" in d:
weight = d["decoder.atmos_head.weight"]
bias = d["decoder.atmos_head.bias"]
del d["decoder.atmos_head.weight"]
del d["decoder.atmos_head.bias"]

assert weight.shape[0] == 5 * self.patch_size**2
assert bias.shape[0] == 5 * self.patch_size**2
weight = weight.reshape(self.patch_size**2, 5, -1)
bias = bias.reshape(self.patch_size**2, 5)

for i, name in enumerate(("z", "u", "v", "t", "q")):
d[f"decoder.atmos_heads.{name}.weight"] = weight[:, i]
d[f"decoder.atmos_heads.{name}.bias"] = bias[:, i]

self.load_state_dict(d, strict=strict)


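To see concretely what the compatibility shim above does, here is a toy, self-contained run of the same renaming on a fake state dict. The `net.` prefix handling and the `encoder.surf_token_embeds` split mirror the code in the diff; the tensor sizes are made up for illustration.

```python
import torch

patch_size, embed_dim, history = 4, 8, 2

# A fake "old-style" checkpoint: keys carry a `net.` prefix and the surface
# embedding is one joint weight covering 4 surface + 3 static variables.
d = {
    "net.encoder.surf_token_embeds.weight": torch.randn(
        embed_dim, 4 + 3, history, patch_size, patch_size
    )
}

# Strip the `net.` prefix, as in `load_checkpoint`.
for k, v in list(d.items()):
    if k.startswith("net."):
        del d[k]
        d[k[4:]] = v

# Split the joint ID-indexed weight into one name-keyed entry per variable.
weight = d.pop("encoder.surf_token_embeds.weight")
for i, name in enumerate(("2t", "10u", "10v", "msl", "lsm", "z", "slt")):
    d[f"encoder.surf_token_embeds.weights.{name}"] = weight[:, [i]]

print(sorted(d))  # Seven per-variable keys, ready for the name-based model.
print(d["encoder.surf_token_embeds.weights.2t"].shape)  # torch.Size([8, 1, 2, 4, 4])
```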
27 changes: 13 additions & 14 deletions aurora/model/decoder.py
@@ -11,8 +11,6 @@
from aurora.model.perceiver import PerceiverResampler
from aurora.model.util import (
check_lat_lon_dtype,
create_var_map,
get_ids_for_var_map,
init_weights,
unpatchify,
)
@@ -60,8 +58,6 @@ def __init__(
self.patch_size = patch_size
self.surf_vars = surf_vars
self.atmos_vars = atmos_vars
self.surf_var_map = create_var_map(surf_vars)
self.atmos_var_map = create_var_map(atmos_vars)
self.embed_dim = embed_dim

self.level_decoder = PerceiverResampler(
@@ -76,8 +72,12 @@ def __init__(
ln_eps=perceiver_ln_eps,
)

self.surf_head = nn.Linear(embed_dim, len(surf_vars) * patch_size**2)
self.atmos_head = nn.Linear(embed_dim, len(atmos_vars) * patch_size**2)
self.surf_heads = nn.ParameterDict(
{name: nn.Linear(embed_dim, patch_size**2) for name in surf_vars}
)
self.atmos_heads = nn.ParameterDict(
{name: nn.Linear(embed_dim, patch_size**2) for name in atmos_vars}
)

self.atmos_levels_embed = nn.Linear(embed_dim, embed_dim)

@@ -145,10 +145,10 @@ def forward(
W=patch_res[2],
)

# Decode surface vars.
x_surf = self.surf_head(x[..., :1, :]) # (B, L, 1, V_S*p*p)
surf_var_ids = get_ids_for_var_map(surf_vars, self.surf_var_map, x_surf.device)
surf_preds = unpatchify(x_surf, len(self.surf_vars), H, W, self.patch_size)[:, surf_var_ids]
# Decode surface vars. Run the head for every surface-level variable.
x_surf = torch.stack([self.surf_heads[name](x[..., :1, :]) for name in surf_vars], dim=-1)
x_surf = x_surf.reshape(*x_surf.shape[:3], -1) # (B, L, 1, V_S*p*p)
surf_preds = unpatchify(x_surf, len(surf_vars), H, W, self.patch_size)
surf_preds = surf_preds.squeeze(2) # (B, V_S, H, W)

# Embed the atmospheric levels.
@@ -162,10 +162,9 @@ def forward(
x_atmos = self.deaggregate_levels(levels_embed, x[..., 1:, :]) # (B, L, C_A, D)

# Decode the atmospheric vars.
x_atmos = self.atmos_head(x_atmos) # (B, L, C_A, V_A*p*p)
atmos_var_ids = get_ids_for_var_map(atmos_vars, self.atmos_var_map, x.device)
atmos_preds = unpatchify(x_atmos, len(self.atmos_vars), H, W, self.patch_size)
atmos_preds = atmos_preds[:, atmos_var_ids]
x_atmos = torch.stack([self.atmos_heads[name](x_atmos) for name in atmos_vars], dim=-1)
x_atmos = x_atmos.reshape(*x_atmos.shape[:3], -1) # (B, L, C_A, V_A*p*p)
atmos_preds = unpatchify(x_atmos, len(atmos_vars), H, W, self.patch_size)

return Batch(
{v: surf_preds[:, i] for i, v in enumerate(surf_vars)},
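As a rough standalone sketch of the per-variable decoding pattern above (illustrative shapes; `nn.ModuleDict` is used here for a minimal runnable example, whereas the diff stores the heads in an `nn.ParameterDict`):

```python
import torch
from torch import nn

embed_dim, patch_size = 64, 4
surf_vars = ("2t", "10u", "10v", "msl")

# One small linear head per surface variable, keyed by name.
heads = nn.ModuleDict({name: nn.Linear(embed_dim, patch_size**2) for name in surf_vars})

B, L = 2, 8
x = torch.randn(B, L, 1, embed_dim)  # Surface token of shape (B, L, 1, D).

# Run every head and stack the outputs along a trailing variable dimension,
# then flatten, matching the (B, L, 1, V_S*p*p) comment in the diff.
x_surf = torch.stack([heads[name](x) for name in surf_vars], dim=-1)  # (B, L, 1, p*p, V_S)
x_surf = x_surf.reshape(*x_surf.shape[:3], -1)
print(x_surf.shape)  # torch.Size([2, 8, 1, 64])
```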
20 changes: 10 additions & 10 deletions aurora/model/encoder.py
@@ -19,8 +19,6 @@
from aurora.model.posencoding import pos_scale_enc
from aurora.model.util import (
check_lat_lon_dtype,
create_var_map,
get_ids_for_var_map,
init_weights,
)

@@ -78,8 +76,6 @@ def __init__(

# We treat the static variables as surface variables in the model.
surf_vars = surf_vars + static_vars if static_vars is not None else surf_vars
self.surf_var_map = create_var_map(surf_vars)
self.atmos_var_map = create_var_map(atmos_vars)

# Latent tokens
assert latent_levels > 1, "At least two latent levels are required."
@@ -102,10 +98,16 @@ def __init__(
# Patch embeddings
assert max_history_size > 0, "At least one history step is required."
self.surf_token_embeds = LevelPatchEmbed(
len(surf_vars), patch_size, embed_dim, max_history_size
surf_vars,
patch_size,
embed_dim,
max_history_size,
)
self.atmos_token_embeds = LevelPatchEmbed(
len(atmos_vars), patch_size, embed_dim, max_history_size
atmos_vars,
patch_size,
embed_dim,
max_history_size,
)

# Learnable pressure level aggregation
@@ -194,14 +196,12 @@ def forward(self, batch: Batch, lead_time: timedelta) -> torch.Tensor:

# Patch embed the surface level.
x_surf = rearrange(x_surf, "b t v h w -> b v t h w")
surf_ids = get_ids_for_var_map(surf_vars, self.surf_var_map, x_surf.device)
x_surf = self.surf_token_embeds(x_surf, surf_ids) # (B, L, D)
x_surf = self.surf_token_embeds(x_surf, surf_vars) # (B, L, D)
dtype = x_surf.dtype # When using mixed precision, we need to keep track of the dtype.

# Patch embed the atmospheric levels.
atmos_ids = get_ids_for_var_map(atmos_vars, self.atmos_var_map, x_atmos.device)
x_atmos = rearrange(x_atmos, "b t v c h w -> (b c) v t h w")
x_atmos = self.atmos_token_embeds(x_atmos, atmos_ids)
x_atmos = self.atmos_token_embeds(x_atmos, atmos_vars)
x_atmos = rearrange(x_atmos, "(b c) l d -> b c l d", b=B, c=C)

# Add surface level encoding. This helps the model distinguish between surface and
44 changes: 26 additions & 18 deletions aurora/model/patchembed.py
@@ -17,7 +17,7 @@ class LevelPatchEmbed(nn.Module):

def __init__(
self,
max_vars: int,
var_names: tuple[str, ...],
patch_size: int,
embed_dim: int,
history_size: int = 1,
@@ -27,7 +27,7 @@ def __init__(
"""Initialise.

Args:
max_vars (int): Maximum number of variables to embed.
var_names (tuple[str, ...]): Variables to embed.
patch_size (int): Patch size.
embed_dim (int): Embedding dimensionality.
history_size (int, optional): Number of history dimensions. Defaults to `1`.
@@ -38,18 +38,19 @@ def __init__(
"""
super().__init__()

self.max_vars = max_vars
self.var_names = var_names
self.kernel_size = (history_size,) + to_2tuple(patch_size)
self.flatten = flatten
self.embed_dim = embed_dim

weight = torch.cat(
# Shape (C_out, C_in, T, H, W). `C_in = 1` here because we're embedding every variable
# separately.
[torch.empty(embed_dim, 1, *self.kernel_size) for _ in range(max_vars)],
dim=1,
self.weights = nn.ParameterDict(
{
# Shape (C_out, C_in, T, H, W). `C_in = 1` here because we're embedding every
# variable separately.
name: nn.Parameter(torch.empty(embed_dim, 1, *self.kernel_size))
for name in var_names
}
)
self.weight = nn.Parameter(weight)
self.bias = nn.Parameter(torch.empty(embed_dim))
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

@@ -63,40 +64,47 @@ def init_weights(self) -> None:
#
# https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573
#
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
for weight in self.weights.values():
nn.init.kaiming_uniform_(weight, a=math.sqrt(5))

# The following initialisation is taken from
#
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv3d
#
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(next(iter(self.weights.values())))
if fan_in != 0:
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)

def forward(self, x: torch.Tensor, var_ids: list[int]) -> torch.Tensor:
def forward(self, x: torch.Tensor, var_names: tuple[str, ...]) -> torch.Tensor:
"""Run the embedding.

Args:
x (:class:`torch.Tensor`): Tensor to embed of a shape of `(B, V, T, H, W)`.
var_ids (list[int]): A list of variable IDs. The length should be equal to `V`.
var_names (tuple[str, ...]): Names of the variables in `x`. The length should be equal
to `V`.

Returns:
:class:`torch.Tensor`: Embedded tensor of shape `(B, L, D)` if flattened,
where `L = H * W / P^2`. Otherwise, the shape is `(B, D, H', W')`.

"""
B, V, T, H, W = x.shape
assert len(var_ids) == V, f"{V} != {len(var_ids)}."
assert len(var_names) == V, f"{V} != {len(var_names)}."
assert self.kernel_size[0] >= T, f"{T} > {self.kernel_size[0]}."
assert H % self.kernel_size[1] == 0, f"{H} % {self.kernel_size[1]} != 0."
assert W % self.kernel_size[2] == 0, f"{W} % {self.kernel_size[2]} != 0."
assert max(var_ids) < self.max_vars, f"{max(var_ids)} >= {self.max_vars}."
assert min(var_ids) >= 0, f"{min(var_ids)} < 0."
assert len(set(var_ids)) == len(var_ids), f"{var_ids} contains duplicates."
assert len(set(var_names)) == len(var_names), f"{var_names} contains duplicates."

# Select the weights of the variables and history dimensions that are present in the batch.
weight = self.weight[:, var_ids, :T, ...] # (C_out, C_in, T, H, W)
weight = torch.cat(
[
# (C_out, C_in, T, H, W)
self.weights[name][:, :, :T, ...]
for name in var_names
],
dim=1,
)
# Adjust the stride if history is smaller than maximum.
stride = (T,) + self.kernel_size[1:]

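The forward pass above builds its convolution kernel on the fly from the name-keyed parameters. A minimal sketch of that mechanism under made-up shapes, showing that a batch can carry any subset of the variables in any order:

```python
import torch
import torch.nn.functional as F
from torch import nn

embed_dim, patch_size, max_history = 64, 4, 2
var_names = ("2t", "10u", "10v", "msl")

# One (C_out, 1, T, P, P) kernel per variable, keyed by name.
weights = nn.ParameterDict(
    {
        name: nn.Parameter(torch.randn(embed_dim, 1, max_history, patch_size, patch_size))
        for name in var_names
    }
)
bias = nn.Parameter(torch.zeros(embed_dim))

# A batch containing only a subset of the variables, in its own order.
present = ("msl", "2t")
B, T, H, W = 1, 2, 32, 64
x = torch.randn(B, len(present), T, H, W)

# Concatenate exactly the kernels for the variables that are present and run
# a single 3D convolution over (time, height, width) patches.
weight = torch.cat([weights[name][:, :, :T] for name in present], dim=1)
proj = F.conv3d(x, weight, bias, stride=(T, patch_size, patch_size))
print(proj.shape)  # torch.Size([1, 64, 1, 8, 16])
```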
33 changes: 0 additions & 33 deletions aurora/model/util.py
@@ -9,8 +9,6 @@

__all__ = [
"unpatchify",
"create_var_map",
"get_ids_for_var_map",
"check_lat_lon_dtype",
"maybe_adjust_windows",
"init_weights",
@@ -43,37 +41,6 @@ def unpatchify(x: torch.Tensor, V: int, H: int, W: int, P: int) -> torch.Tensor:
return x


def create_var_map(variables: tuple[str, ...]) -> dict[str, int]:
"""Create dictionary where the keys are variable names and values are unique IDs.

Args:
variables (tuple[str, ...]): Variable strings.

Returns:
dict[str, int]: Variable map dictionary.
"""
return {v: i for i, v in enumerate(variables)}


def get_ids_for_var_map(
variables: tuple,
var_maps: dict,
device: torch.cuda.device,
) -> torch.Tensor:
"""Construct a tensor of variable IDs after retrieving those from a variable map created with
:func:`.create_var_map`.

Args:
variables (tuples[str, ...]): Variables to retrieve the IDs for.
var_maps (dict[str, int]): Variable map constructed with :func:`.create_var_map`.
device (torch.cuda.device): Device.

Returns:
torch.Tensor: Tensor of variable IDs found in `var_map`.
"""
return torch.tensor([var_maps[v] for v in variables], device=device)


def check_lat_lon_dtype(lat: torch.Tensor, lon: torch.Tensor) -> None:
"""Assert that `lat` and `lon` are at least `float32`s."""
assert lat.dtype in [torch.float32, torch.float64], f"Latitude num. unstable: {lat.dtype}."
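For readers tracing why these two helpers could be deleted: the old scheme turned names into integer IDs and sliced one shared tensor, while the reworked code simply keys per-variable tensors by name. A toy contrast with purely illustrative shapes:

```python
import torch

variables = ("2t", "10u", "10v", "msl")
wanted = ("msl", "2t")

# Old, ID-based: build a name -> index map and slice a shared weight tensor.
var_map = {v: i for i, v in enumerate(variables)}  # what create_var_map did
ids = torch.tensor([var_map[v] for v in wanted])   # what get_ids_for_var_map did
shared_weight = torch.randn(8, len(variables))
old_selected = shared_weight[:, ids]

# New, name-based: keep one tensor per variable and look it up directly.
per_var_weight = {v: torch.randn(8, 1) for v in variables}
new_selected = torch.cat([per_var_weight[v] for v in wanted], dim=1)

print(old_selected.shape, new_selected.shape)  # torch.Size([8, 2]) torch.Size([8, 2])
```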
11 changes: 0 additions & 11 deletions docs/beware.md
@@ -42,14 +42,3 @@ If you changed the model and added or removed parameters, you need to set `strict=False` when
loading a checkpoint `Aurora.load_checkpoint(..., strict=False)`.
Importantly, enabling or disabling LoRA for a model that was trained respectively without or
with LoRA changes the parameters!

## Extending the Model with New Surface-Level Variables

Whereas we have attempted to design a robust and flexible model,
inevitably some unfortunate design choices slipped through.

A notable unfortunate design choice is that extending the model with a new surface-level
variable breaks compatibility with existing checkpoints.
It is possible to hack around this in a relatively simple way.
We are working on a more principled fix.
Please open an issue if this is a problem for you.
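The `strict=False` note kept above is what makes the now-removed warning unnecessary. A minimal sketch of extending the surface variables and still loading a pretrained checkpoint; it assumes the package exposes `Aurora` at the top level, and the extra variable name `"tp"`, the Hugging Face repository, and the checkpoint file name are illustrative assumptions, not taken from this diff:

```python
from aurora import Aurora

# Hypothetical: extend the default surface variables with an extra one.
model = Aurora(surf_vars=("2t", "10u", "10v", "msl", "tp"))

# With the name-based parametrisation, only the parameters for "tp" are missing
# from the checkpoint, so loading with strict=False fills in everything else.
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False)
```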