Merge branch 'main' into wesselb/static-vars-fudge-factor

wesselb authored Sep 15, 2024
2 parents 4018121 + 8fecaee commit 88fb245
Showing 19 changed files with 492 additions and 213 deletions.
12 changes: 9 additions & 3 deletions README.md
@@ -61,6 +61,12 @@ Install with `pip`:
pip install microsoft-aurora
```

Or with `conda` / `mamba`:

```bash
mamba install microsoft-aurora -c conda-forge
```
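
Either way, a quick import is enough to confirm the installation (a minimal sketch; note that the package is imported as `aurora` even though the distribution is named `microsoft-aurora`):

```python
# Sanity check that the package can be imported.
import aurora

print("Aurora installed at", aurora.__file__)
```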

Run the pretrained small model on random data:

```python
@@ -122,14 +128,14 @@ This code has not been developed nor tested for non-academic purposes and hence

### Limitations
Although Aurora was trained to accurately predict future weather and air pollution,
- Aurora is based on neural networks, which means that there are no strict guarantees that predicts will always be accurate.
+ Aurora is based on neural networks, which means that there are no strict guarantees that predictions will always be accurate.
Altering the inputs, providing a sample that was not in the training set,
or even providing a sample that was in the training set but is simply unlucky may result in arbitrarily poor predictions.
In addition, even though Aurora was trained on a wide variety of data sets,
it is possible that Aurora inherits biases present in any one of those data sets.
A forecasting system like Aurora is only one piece of the puzzle in a weather prediction pipeline,
- and its outputs are not meant to be directly used by people or business to plan their operations.
- A series of additional verification tests are needed before it could become operationally useful.
+ and its outputs are not meant to be directly used by people or businesses to plan their operations.
+ A series of additional verification tests are needed before it can become operationally useful.

### Data
The models included in the code have been trained on a variety of publicly available data.
101 changes: 101 additions & 0 deletions aurora/batch.py
@@ -2,9 +2,12 @@

import dataclasses
from datetime import datetime
from functools import partial
from typing import Callable

import numpy as np
import torch
from scipy.interpolate import RegularGridInterpolator as RGI

from aurora.normalisation import (
normalise_atmos_var,
@@ -169,3 +172,101 @@ def to(self, device: str | torch.device) -> "Batch":
def type(self, t: type) -> "Batch":
"""Convert everything to type `t`."""
return self._fmap(lambda x: x.type(t))

def regrid(self, res: float) -> "Batch":
"""Regrid the batch to a `res` degrees resolution.
This results in `float32` data on the CPU.
This function is not optimised for either speed or accuracy. Use at your own risk.
"""

shape = (round(180 / res) + 1, round(360 / res))
lat_new = torch.from_numpy(np.linspace(90, -90, shape[0]))
lon_new = torch.from_numpy(np.linspace(0, 360, shape[1], endpoint=False))
interpolate_res = partial(
interpolate,
lat=self.metadata.lat,
lon=self.metadata.lon,
lat_new=lat_new,
lon_new=lon_new,
)

return Batch(
surf_vars={k: interpolate_res(v) for k, v in self.surf_vars.items()},
static_vars={k: interpolate_res(v) for k, v in self.static_vars.items()},
atmos_vars={k: interpolate_res(v) for k, v in self.atmos_vars.items()},
metadata=Metadata(
lat=lat_new,
lon=lon_new,
atmos_levels=self.metadata.atmos_levels,
time=self.metadata.time,
rollout_step=self.metadata.rollout_step,
),
)


def interpolate(
v: torch.Tensor,
lat: torch.Tensor,
lon: torch.Tensor,
lat_new: torch.Tensor,
lon_new: torch.Tensor,
) -> torch.Tensor:
"""Interpolate a variable `v` with latitudes `lat` and longitudes `lon` to new latitudes
`lat_new` and new longitudes `lon_new`."""
# Perform the interpolation in double precision.
return torch.from_numpy(
interpolate_numpy(
v.double().numpy(),
lat.double().numpy(),
lon.double().numpy(),
lat_new.double().numpy(),
lon_new.double().numpy(),
)
).float()


def interpolate_numpy(
v: np.ndarray,
lat: np.ndarray,
lon: np.ndarray,
lat_new: np.ndarray,
lon_new: np.ndarray,
) -> np.ndarray:
"""Like :func:`.interpolate`, but for NumPy tensors."""

# Implement periodic longitudes in `lon`.
assert (np.diff(lon) > 0).all()
lon = np.concatenate((lon[-1:] - 360, lon, lon[:1] + 360))

# Merge all batch dimensions into one.
batch_shape = v.shape[:-2]
v = v.reshape(-1, *v.shape[-2:])

# Loop over all batch elements.
vs_regridded = []
for vi in v:
# Implement periodic longitudes in `vi`.
vi = np.concatenate((vi[:, -1:], vi, vi[:, :1]), axis=1)

rgi = RGI(
(lat, lon),
vi,
method="linear",
bounds_error=False, # Allow out of bounds, for the latitudes.
fill_value=None, # Extrapolate latitudes if they are out of bounds.
)
lat_new_grid, lon_new_grid = np.meshgrid(
lat_new,
lon_new,
indexing="ij",
sparse=True,
)
vs_regridded.append(rgi((lat_new_grid, lon_new_grid)))

# Recreate the batch dimensions.
v_regridded = np.stack(vs_regridded, axis=0)
v_regridded = v_regridded.reshape(*batch_shape, lat_new.shape[0], lon_new.shape[0])

return v_regridded
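
For context, here is a minimal sketch of how the new `Batch.regrid` method might be used. The random batch below is an illustrative assumption: the variable names follow the checkpoint-loading code in `aurora/model/aurora.py`, but the exact shapes and `Metadata` fields are not prescribed by this commit.

```python
from datetime import datetime

import torch

from aurora.batch import Batch, Metadata  # `Metadata` is assumed to live alongside `Batch`

# Build a small random batch on a 1-degree grid (shapes and metadata are illustrative).
batch = Batch(
    surf_vars={k: torch.randn(1, 2, 181, 360) for k in ("2t", "10u", "10v", "msl")},
    static_vars={k: torch.randn(181, 360) for k in ("lsm", "z", "slt")},
    atmos_vars={k: torch.randn(1, 2, 4, 181, 360) for k in ("z", "u", "v", "t", "q")},
    metadata=Metadata(
        lat=torch.linspace(90, -90, 181),
        lon=torch.linspace(0, 360, 361)[:-1],
        time=(datetime(2020, 6, 1, 12, 0),),
        atmos_levels=(100, 250, 500, 850),
        rollout_step=0,
    ),
)

# Regrid to a coarser 4-degree grid. The result is `float32` data on the CPU.
batch_coarse = batch.regrid(res=4.0)
print(batch_coarse.metadata.lat.shape)  # torch.Size([46])
print(batch_coarse.metadata.lon.shape)  # torch.Size([90])
```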
109 changes: 90 additions & 19 deletions aurora/model/aurora.py
@@ -1,18 +1,22 @@
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

import contextlib
import dataclasses
from datetime import timedelta
from functools import partial
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
)

from aurora.batch import Batch
from aurora.model.decoder import Perceiver3DDecoder
from aurora.model.encoder import Perceiver3DEncoder
from aurora.model.lora import LoRAMode
from aurora.model.swin3d import Swin3DTransformerBackbone
from aurora.model.swin3d import BasicLayer3D, Swin3DTransformerBackbone

__all__ = ["Aurora", "AuroraSmall", "AuroraHighRes"]

@@ -49,19 +53,17 @@ def __init__(
lora_steps: int = 40,
lora_mode: LoRAMode = "single",
surf_stats: Optional[dict[str, tuple[float, float]]] = None,
autocast: bool = False,
) -> None:
"""Construct an instance of the model.
Args:
surf_vars (tuple[str, ...], optional): All surface-level variables supported by the
- model. The model is sensitive to the order of `surf_vars`! Currently, adding
- one more variable here causes the model to incorrectly load the static variables.
- It is possible to hack around this. We are working on a more principled fix. Please
- open an issue if this is a problem for you.
+ model.
static_vars (tuple[str, ...], optional): All static variables supported by the
- model. The model is sensitive to the order of `static_vars`!
+ model.
atmos_vars (tuple[str, ...], optional): All atmospheric variables supported by the
- model. The model is sensitive to the order of `atmos-vars`!
+ model.
window_size (tuple[int, int, int], optional): Vertical height, height, and width of the
window of the underlying Swin transformer.
encoder_depths (tuple[int, ...], optional): Number of blocks in each encoder layer.
@@ -100,12 +102,15 @@ def __init__(
surf_stats (dict[str, tuple[float, float]], optional): For these surface-level
variables, adjust the normalisation to the given tuple consisting of a new location
and scale.
autocast (bool, optional): Use `torch.autocast` to reduce memory usage. Defaults to
`False`.
"""
super().__init__()
self.surf_vars = surf_vars
self.atmos_vars = atmos_vars
self.patch_size = patch_size
self.surf_stats = surf_stats or dict()
self.autocast = autocast

self.encoder = Perceiver3DEncoder(
surf_vars=surf_vars,
@@ -159,9 +164,6 @@ def forward(self, batch: Batch) -> Batch:
Args:
batch (:class:`Batch`): Batch to run the model on.
- Raises:
- ValueError: If no metric is provided.
Returns:
:class:`Batch`: Prediction for the batch.
"""
@@ -190,12 +192,13 @@ def forward(self, batch: Batch) -> Batch:
batch,
lead_time=timedelta(hours=6),
)
- x = self.backbone(
-     x,
-     lead_time=timedelta(hours=6),
-     patch_res=patch_res,
-     rollout_step=batch.metadata.rollout_step,
- )
+ with torch.autocast(device_type="cuda") if self.autocast else contextlib.nullcontext():
+     x = self.backbone(
+         x,
+         lead_time=timedelta(hours=6),
+         patch_res=patch_res,
+         rollout_step=batch.metadata.rollout_step,
+     )
pred = self.decoder(
x,
batch,
@@ -204,7 +207,6 @@
)

# Remove batch and history dimension from static variables.
- B, T = next(iter(batch.surf_vars.values()))[0]
pred = dataclasses.replace(
pred,
static_vars={k: v[0, 0] for k, v in batch.static_vars.items()},
@@ -231,20 +233,89 @@ def load_checkpoint(self, repo: str, name: str, strict: bool = True) -> None:
strict (bool, optional): Error if the model parameters are not exactly equal to the
parameters in the checkpoint. Defaults to `True`.
"""
path = hf_hub_download(repo_id=repo, filename=name)
self.load_checkpoint_local(path, strict=strict)

def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
"""Load a checkpoint directly from a file.
Args:
path (str): Path to the checkpoint.
strict (bool, optional): Error if the model parameters are not exactly equal to the
parameters in the checkpoint. Defaults to `True`.
"""
# Assume that all parameters are either on the CPU or on the GPU.
device = next(self.parameters()).device

- path = hf_hub_download(repo_id=repo, filename=name)
d = torch.load(path, map_location=device, weights_only=True)

# Rename keys to ensure compatibility.
# You can safely ignore the cumbersome processing below: we modified the model after
# training it, and the code below manually adapts old checkpoints so that they remain
# compatible with the new model.

# Remove a possible prefix from the keys.
for k, v in list(d.items()):
if k.startswith("net."):
del d[k]
d[k[4:]] = v

# Convert the ID-based parametrisation to a name-based parametrisation.

if "encoder.surf_token_embeds.weight" in d:
weight = d["encoder.surf_token_embeds.weight"]
del d["encoder.surf_token_embeds.weight"]

assert weight.shape[1] == 4 + 3
for i, name in enumerate(("2t", "10u", "10v", "msl", "lsm", "z", "slt")):
d[f"encoder.surf_token_embeds.weights.{name}"] = weight[:, [i]]

if "encoder.atmos_token_embeds.weight" in d:
weight = d["encoder.atmos_token_embeds.weight"]
del d["encoder.atmos_token_embeds.weight"]

assert weight.shape[1] == 5
for i, name in enumerate(("z", "u", "v", "t", "q")):
d[f"encoder.atmos_token_embeds.weights.{name}"] = weight[:, [i]]

if "decoder.surf_head.weight" in d:
weight = d["decoder.surf_head.weight"]
bias = d["decoder.surf_head.bias"]
del d["decoder.surf_head.weight"]
del d["decoder.surf_head.bias"]

assert weight.shape[0] == 4 * self.patch_size**2
assert bias.shape[0] == 4 * self.patch_size**2
weight = weight.reshape(self.patch_size**2, 4, -1)
bias = bias.reshape(self.patch_size**2, 4)

for i, name in enumerate(("2t", "10u", "10v", "msl")):
d[f"decoder.surf_heads.{name}.weight"] = weight[:, i]
d[f"decoder.surf_heads.{name}.bias"] = bias[:, i]

if "decoder.atmos_head.weight" in d:
weight = d["decoder.atmos_head.weight"]
bias = d["decoder.atmos_head.bias"]
del d["decoder.atmos_head.weight"]
del d["decoder.atmos_head.bias"]

assert weight.shape[0] == 5 * self.patch_size**2
assert bias.shape[0] == 5 * self.patch_size**2
weight = weight.reshape(self.patch_size**2, 5, -1)
bias = bias.reshape(self.patch_size**2, 5)

for i, name in enumerate(("z", "u", "v", "t", "q")):
d[f"decoder.atmos_heads.{name}.weight"] = weight[:, i]
d[f"decoder.atmos_heads.{name}.bias"] = bias[:, i]

self.load_state_dict(d, strict=strict)

def configure_activation_checkpointing(self):
"""Configure activation checkpointing.
This is required in order to compute gradients without running out of memory.
"""
apply_activation_checkpointing(self, check_fn=lambda x: isinstance(x, BasicLayer3D))


AuroraSmall = partial(
Aurora,
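Finally, a minimal sketch of how the new `autocast` flag, checkpoint loading, and activation checkpointing introduced in this diff might fit together. The Hugging Face repository and checkpoint name are placeholders rather than real releases, and `batch` is assumed to be a `Batch` like the one sketched after `aurora/batch.py` above.

```python
from aurora.model.aurora import AuroraSmall  # exported via `__all__` in this module

# `autocast=True` wraps the backbone call in `torch.autocast` to reduce memory usage on GPU.
model = AuroraSmall(autocast=True)

# Load pretrained weights; the repository and file name below are placeholders.
model.load_checkpoint("<hf-org>/<hf-repo>", "<checkpoint-file>.ckpt")
# Alternatively, load a checkpoint from disk:
# model.load_checkpoint_local("/path/to/checkpoint.ckpt")

# When fine-tuning, activation checkpointing is required to compute gradients without
# running out of memory.
model.configure_activation_checkpointing()

model = model.to("cuda")
pred = model(batch.to("cuda"))  # `batch` is a `Batch` prepared as in the sketch above
```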
