Merge pull request #21 from microsoft/wesselb/static-vars-fudge-factor
Implement adjustment for the normalisation of surface-level variables
a-lucic authored Sep 15, 2024
2 parents 8fecaee + 88fb245 commit 00a1a7a
Showing 4 changed files with 64 additions and 17 deletions.
40 changes: 32 additions & 8 deletions aurora/batch.py
@@ -75,23 +75,47 @@ def spatial_shape(self) -> tuple[int, int]:
"""Get the spatial shape from an arbitrary surface-level variable."""
return next(iter(self.surf_vars.values())).shape[-2:]

def normalise(self) -> "Batch":
"""Normalise all variables in the batch."""
def normalise(self, surf_stats: dict[str, tuple[float, float]]) -> "Batch":
"""Normalise all variables in the batch.
Args:
surf_stats (dict[str, tuple[float, float]]): For these surface-level variables, adjust
the normalisation to the given tuple consisting of a new location and scale.
Returns:
:class:`.Batch`: Normalised batch.
"""
return Batch(
surf_vars={k: normalise_surf_var(v, k) for k, v in self.surf_vars.items()},
static_vars={k: normalise_surf_var(v, k) for k, v in self.static_vars.items()},
surf_vars={
k: normalise_surf_var(v, k, stats=surf_stats) for k, v in self.surf_vars.items()
},
static_vars={
k: normalise_surf_var(v, k, stats=surf_stats) for k, v in self.static_vars.items()
},
atmos_vars={
k: normalise_atmos_var(v, k, self.metadata.atmos_levels)
for k, v in self.atmos_vars.items()
},
metadata=self.metadata,
)

def unnormalise(self) -> "Batch":
"""Unnormalise all variables in the batch."""
def unnormalise(self, surf_stats: dict[str, tuple[float, float]]) -> "Batch":
"""Unnormalise all variables in the batch.
Args:
surf_stats (dict[str, tuple[float, float]]): For these surface-level variables, adjust
the normalisation to the given tuple consisting of a new location and scale.
Returns:
:class:`.Batch`: Unnormalised batch.
"""
return Batch(
surf_vars={k: unnormalise_surf_var(v, k) for k, v in self.surf_vars.items()},
static_vars={k: unnormalise_surf_var(v, k) for k, v in self.static_vars.items()},
surf_vars={
k: unnormalise_surf_var(v, k, stats=surf_stats) for k, v in self.surf_vars.items()
},
static_vars={
k: unnormalise_surf_var(v, k, stats=surf_stats) for k, v in self.static_vars.items()
},
atmos_vars={
k: unnormalise_atmos_var(v, k, self.metadata.atmos_levels)
for k, v in self.atmos_vars.items()
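
The hunk above is truncated, but the new call sites are visible: both `normalise` and `unnormalise` now take a `surf_stats` mapping. Below is a hedged sketch of exercising them on a toy batch; the tensor shapes and `Metadata` fields follow the package's documented example, and the contents are random placeholders rather than real data.

```python
from datetime import datetime

import torch

from aurora import Batch, Metadata

# Toy batch following the shapes documented for the package:
# surface variables are (b, t, h, w), static variables (h, w),
# and atmospheric variables (b, t, c, h, w).
batch = Batch(
    surf_vars={"2t": torch.randn(1, 2, 17, 32)},
    static_vars={"z": torch.randn(17, 32)},
    atmos_vars={"t": torch.randn(1, 2, 4, 17, 32)},
    metadata=Metadata(
        lat=torch.linspace(90, -90, 17),
        lon=torch.linspace(0, 360, 32 + 1)[:-1],
        time=(datetime(2020, 6, 1, 12, 0),),
        atmos_levels=(100, 250, 500, 850),
    ),
)

# The adjustment from this commit: override the location/scale used to
# normalise the static variable "z" (values taken from the AuroraHighRes
# configuration in this diff).
surf_stats = {"z": (-3.270407e03, 6.540335e04)}

normalised = batch.normalise(surf_stats=surf_stats)
restored = normalised.unnormalise(surf_stats=surf_stats)
```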
12 changes: 10 additions & 2 deletions aurora/model/aurora.py
@@ -4,6 +4,7 @@
import dataclasses
from datetime import timedelta
from functools import partial
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
@@ -51,6 +52,7 @@ def __init__(
use_lora: bool = True,
lora_steps: int = 40,
lora_mode: LoRAMode = "single",
surf_stats: Optional[dict[str, tuple[float, float]]] = None,
autocast: bool = False,
) -> None:
"""Construct an instance of the model.
@@ -97,13 +99,17 @@ def __init__(
lora_mode (str, optional): LoRA mode. `"single"` uses the same LoRA for all roll-out
steps, and `"all"` uses a different LoRA for every roll-out step. Defaults to
`"single"`.
surf_stats (dict[str, tuple[float, float]], optional): For these surface-level
variables, adjust the normalisation to the given tuple consisting of a new location
and scale.
autocast (bool, optional): Use `torch.autocast` to reduce memory usage. Defaults to
`False`.
"""
super().__init__()
self.surf_vars = surf_vars
self.atmos_vars = atmos_vars
self.patch_size = patch_size
self.surf_stats = surf_stats or dict()
self.autocast = autocast

self.encoder = Perceiver3DEncoder(
@@ -164,7 +170,7 @@ def forward(self, batch: Batch) -> Batch:
# Get the first parameter. We'll derive the data type and device from this parameter.
p = next(self.parameters())
batch = batch.type(p.dtype)
batch = batch.normalise()
batch = batch.normalise(surf_stats=self.surf_stats)
batch = batch.crop(patch_size=self.patch_size)
batch = batch.to(p.device)

@@ -213,7 +219,7 @@ def forward(self, batch: Batch) -> Batch:
atmos_vars={k: v[:, None] for k, v in pred.atmos_vars.items()},
)

pred = pred.unnormalise()
pred = pred.unnormalise(surf_stats=self.surf_stats)

return pred

@@ -327,4 +333,6 @@ def configure_activation_checkpointing(self):
patch_size=10,
encoder_depths=(6, 8, 8),
decoder_depths=(8, 8, 6),
# One particular static variable requires a different normalisation.
surf_stats={"z": (-3.270407e03, 6.540335e04)},
)
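
For the constructor change, here is a minimal sketch of passing (or omitting) the new `surf_stats` argument. `AuroraSmall` is the package's small preconfigured model, and the override values below are illustrative placeholders, not statistics from this commit.

```python
from aurora import AuroraSmall

# Override the normalisation location/scale for a chosen surface-level variable.
# The (location, scale) values here are placeholders for illustration only.
model = AuroraSmall(surf_stats={"2t": (285.0, 15.0)})

# Omitting the argument keeps the built-in statistics: the constructor falls
# back to an empty dict via `self.surf_stats = surf_stats or dict()`.
default_model = AuroraSmall()
```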
9 changes: 7 additions & 2 deletions aurora/normalisation.py
@@ -1,6 +1,7 @@
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

from functools import partial
from typing import Optional

import torch

@@ -15,11 +16,15 @@
def normalise_surf_var(
x: torch.Tensor,
name: str,
stats: Optional[dict[str, tuple[float, float]]] = None,
unnormalise: bool = False,
) -> torch.Tensor:
"""Normalise a surface-level variable."""
location = locations[name]
scale = scales[name]
if stats and name in stats:
location, scale = stats[name]
else:
location = locations[name]
scale = scales[name]
if unnormalise:
return x * scale + location
else:
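
A hedged usage sketch of the updated `normalise_surf_var`, assuming the module path shown in the diff (`aurora.normalisation`); the override statistics are illustrative.

```python
import torch

from aurora.normalisation import normalise_surf_var

x = torch.randn(1, 2, 17, 32)

# Default path: fall back to the built-in location/scale for "2t".
y_default = normalise_surf_var(x, "2t")

# Override path: `stats` supplies a (location, scale) tuple; placeholder values.
stats = {"2t": (285.0, 15.0)}
y_override = normalise_surf_var(x, "2t", stats=stats)

# Using the same stats with `unnormalise=True` inverts the transformation.
x_back = normalise_surf_var(y_override, "2t", stats=stats, unnormalise=True)
assert torch.allclose(x_back, x, atol=1e-5)
```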
20 changes: 15 additions & 5 deletions docs/models.md
@@ -158,11 +158,21 @@ For optimal performance, the model requires the following variables and pressure
### Static Variables

Aurora 0.1° Fine-Tuned requires
[static variables from IFS HRES analysis](https://rda.ucar.edu/datasets/ds113.1/).
However, due to the way the model was trained,
the model requires these variables to be scaled slightly differently.
Therefore, you should use the static variables provided in
[the HuggingFace repository](https://huggingface.co/microsoft/aurora/blob/main/aurora-0.1-static.pickle).
[static variables from IFS HRES analysis](https://rda.ucar.edu/datasets/ds113.1/) regridded
to 0.1° resolution.
Because of differences between implementations of regridding methods, the resulting static
variables might not be exactly equal to the ones we used during training.
For this reason, we have also uploaded
[the exact static variables that we used during training](https://huggingface.co/microsoft/aurora/blob/main/aurora-0.1-static.pickle).
To use these, you must turn off the normalisation adjustment by instantiating
the model in the following way:

```python
from aurora import AuroraHighRes

model = AuroraHighRes(surf_stats=None) # Use static variables from HF repo.
model.load_checkpoint("microsoft/aurora", "aurora-0.1-finetuned.ckpt")
```
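
For contrast, a sketch of the default construction, which keeps the `surf_stats` adjustment for `"z"` added to `AuroraHighRes` in this commit and is therefore the path to use with static variables you regridded yourself.

```python
from aurora import AuroraHighRes

# Default construction keeps the adjusted normalisation for the static
# variable "z", i.e. surf_stats={"z": (-3.270407e03, 6.540335e04)}.
model = AuroraHighRes()
model.load_checkpoint("microsoft/aurora", "aurora-0.1-finetuned.ckpt")
```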

### Notes

