From d6153c3c7a28a25f9caeddf99c405ff8973a9f2d Mon Sep 17 00:00:00 2001 From: Wessel Bruinsma Date: Tue, 6 Aug 2024 16:15:42 +0200 Subject: [PATCH] Add model test --- aurora/batch.py | 32 ++++++- aurora/model/aurora.py | 26 +++++- aurora/normalisation.py | 202 ++++++++++++++++++++++++++++++++++++++++ tests/test_model.py | 91 +++++++++++++++--- 4 files changed, 334 insertions(+), 17 deletions(-) create mode 100644 aurora/normalisation.py diff --git a/aurora/batch.py b/aurora/batch.py index dde7b9d..38a8ece 100644 --- a/aurora/batch.py +++ b/aurora/batch.py @@ -5,6 +5,13 @@ import torch +from aurora.normalisation import ( + normalise_atmos_var, + normalise_surf_var, + unnormalise_atmos_var, + unnormalise_surf_var, +) + __all__ = ["Metadata", "Batch"] @@ -47,16 +54,35 @@ class Batch: @property def spatial_shape(self) -> tuple[int, int]: """Get the spatial shape from an arbitrary surface-level variable.""" - return list(self.surf_vars.values())[0].shape[-2:] + return next(iter(self.surf_vars.values())).shape[-2:] def normalise(self) -> "Batch": """Normalise all variables in the batch.""" - return self + return Batch( + surf_vars={k: normalise_surf_var(v, k) for k, v in self.surf_vars.items()}, + static_vars={k: normalise_surf_var(v, k) for k, v in self.static_vars.items()}, + atmos_vars={ + k: normalise_atmos_var(v, k, self.metadata.atmos_levels) + for k, v in self.atmos_vars.items() + }, + metadata=self.metadata, + ) def unnormalise(self) -> "Batch": """Unnormalise all variables in the batch.""" - return self + return Batch( + surf_vars={k: unnormalise_surf_var(v, k) for k, v in self.surf_vars.items()}, + static_vars={k: unnormalise_surf_var(v, k) for k, v in self.static_vars.items()}, + atmos_vars={ + k: unnormalise_atmos_var(v, k, self.metadata.atmos_levels) + for k, v in self.atmos_vars.items() + }, + metadata=self.metadata, + ) def crop(self, patch_size: int) -> "Batch": """Crop the variables in the batch to patch size `patch_size`.""" + h, w = self.spatial_shape + assert h % patch_size == 0 + assert w % patch_size == 0 return self diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py index 92b8951..f71247a 100644 --- a/aurora/model/aurora.py +++ b/aurora/model/aurora.py @@ -1,5 +1,6 @@ """Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" +import dataclasses from datetime import timedelta from functools import partial @@ -119,6 +120,13 @@ def forward(self, batch: Batch) -> Batch: W // self.encoder.patch_size, ) + # Insert batch and history dimension for static variables. + B, T = next(iter(batch.surf_vars.values())).shape[:2] + batch = dataclasses.replace( + batch, + static_vars={k: v[None, None].repeat(B, T, 1, 1) for k, v in batch.static_vars.items()}, + ) + x = self.encoder( batch, lead_time=timedelta(hours=6), @@ -136,13 +144,25 @@ def forward(self, batch: Batch) -> Batch: patch_res=patch_res, ) - # TODO: Ensure time dim is present and time is right. + # Remove batch and history dimension from static variables. + B, T = next(iter(batch.surf_vars.values()))[0] + pred = dataclasses.replace( + pred, + static_vars={k: v[0, 0] for k, v in batch.static_vars.items()}, + ) + + # Insert history dimension in prediction. The time should already be right. + pred = dataclasses.replace( + pred, + surf_vars={k: v[:, None] for k, v in pred.surf_vars.items()}, + atmos_vars={k: v[:, None] for k, v in pred.atmos_vars.items()}, + ) pred = pred.unnormalise() return pred - def load_checkpoint(self, repo: str, name: str) -> None: + def load_checkpoint(self, repo: str, name: str, strict: bool = True) -> None: path = hf_hub_download(repo_id=repo, filename=name) d = torch.load(path, map_location="cpu") @@ -152,7 +172,7 @@ def load_checkpoint(self, repo: str, name: str) -> None: del d[k] d[k[4:]] = v - self.load_state_dict(d, strict=True) + self.load_state_dict(d, strict=strict) AuroraSmall = partial( diff --git a/aurora/normalisation.py b/aurora/normalisation.py new file mode 100644 index 0000000..96fb277 --- /dev/null +++ b/aurora/normalisation.py @@ -0,0 +1,202 @@ +"""Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" + +from functools import partial + +import torch + +__all__ = [ + "normalise_surf_var", + "normalise_atmos_var", + "unnormalise_surf_var", + "unnormalise_atmos_var", +] + + +def normalise_surf_var( + x: torch.Tensor, + name: str, + unnormalise: bool = False, +) -> torch.Tensor: + """Normalise a surface-level variable.""" + location = locations[name] + scale = scales[name] + if unnormalise: + return x * scale + location + else: + return (x - location) / scale + + +def normalise_atmos_var( + x: torch.Tensor, + name: str, + atmos_levels: tuple[int | float, ...], + unnormalise: bool = False, +) -> torch.Tensor: + """Normalise an atmospheric variable.""" + level_locations: list[int | float] = [] + level_scales: list[int | float] = [] + for level in atmos_levels: + level_locations.append(locations[f"{name}_{level}"]) + level_scales.append(scales[f"{name}_{level}"]) + location = torch.tensor(level_locations, dtype=x.dtype, device=x.device) + scale = torch.tensor(level_scales, dtype=x.dtype, device=x.device) + + if unnormalise: + return x * scale[..., None, None] + location[..., None, None] + else: + return (x - location[..., None, None]) / scale[..., None, None] + + +unnormalise_surf_var = partial(normalise_surf_var, unnormalise=True) +unnormalise_atmos_var = partial(normalise_atmos_var, unnormalise=True) + + +locations: dict[str, float] = { + "z": -1.386496e03, + "lsm": 0.000000e00, + "slt": 0.000000e00, + "2t": 2.785140e02, + "10u": -5.135059e-02, + "10v": 1.891580e-01, + "msl": 1.009578e05, + "z_50": 1.993730e05, + "z_100": 1.576421e05, + "z_150": 1.331414e05, + "z_200": 1.153300e05, + "z_250": 1.012231e05, + "z_300": 8.941415e04, + "z_400": 6.998038e04, + "z_500": 5.411537e04, + "z_600": 4.064833e04, + "z_700": 2.892882e04, + "z_850": 1.374978e04, + "z_925": 7.015005e03, + "z_1000": 7.381545e02, + "u_50": 5.653076e00, + "u_100": 1.027951e01, + "u_150": 1.354061e01, + "u_200": 1.420915e01, + "u_250": 1.334584e01, + "u_300": 1.180173e01, + "u_400": 8.817291e00, + "u_500": 6.563273e00, + "u_600": 4.814521e00, + "u_700": 3.345237e00, + "u_850": 1.418379e00, + "u_925": 6.172657e-01, + "u_1000": -3.328723e-02, + "v_50": 4.226111e-03, + "v_100": 1.411897e-02, + "v_150": -3.697671e-02, + "v_200": -4.507801e-02, + "v_250": -2.980338e-02, + "v_300": -2.294770e-02, + "v_400": -1.771003e-02, + "v_500": -2.387986e-02, + "v_600": -2.716674e-02, + "v_700": 2.153583e-02, + "v_850": 1.428150e-01, + "v_925": 2.053480e-01, + "v_1000": 1.867637e-01, + "t_50": 2.124864e02, + "t_100": 2.084042e02, + "t_150": 2.133201e02, + "t_200": 2.180615e02, + "t_250": 2.227710e02, + "t_300": 2.288696e02, + "t_400": 2.421368e02, + "t_500": 2.529492e02, + "t_600": 2.611347e02, + "t_700": 2.674010e02, + "t_850": 2.745600e02, + "t_925": 2.773572e02, + "t_1000": 2.810130e02, + "q_50": 2.678180e-06, + "q_100": 2.633677e-06, + "q_150": 5.254625e-06, + "q_200": 1.940632e-05, + "q_250": 5.773618e-05, + "q_300": 1.273861e-04, + "q_400": 3.855659e-04, + "q_500": 8.529599e-04, + "q_600": 1.541429e-03, + "q_700": 2.431637e-03, + "q_850": 4.575618e-03, + "q_925": 6.033134e-03, + "q_1000": 7.030342e-03, +} + +scales: dict[str, float] = { + "z": 5.884467e04, + "lsm": 1.000000e00, + "slt": 7.000000e00, # 7 or 8?! + "2t": 2.122036e01, + "10u": 5.547512e00, + "10v": 4.765339e00, + "msl": 1.332246e03, + "z_50": 5.875553e03, + "z_100": 5.510640e03, + "z_150": 5.823912e03, + "z_200": 5.820169e03, + "z_250": 5.536585e03, + "z_300": 5.091916e03, + "z_400": 4.150851e03, + "z_500": 3.353187e03, + "z_600": 2.695808e03, + "z_700": 2.136436e03, + "z_850": 1.470321e03, + "z_925": 1.228997e03, + "z_1000": 1.072307e03, + "u_50": 1.529281e01, + "u_100": 1.352611e01, + "u_150": 1.604335e01, + "u_200": 1.767630e01, + "u_250": 1.796710e01, + "u_300": 1.711917e01, + "u_400": 1.434276e01, + "u_500": 1.198419e01, + "u_600": 1.033421e01, + "u_700": 9.168821e00, + "u_850": 8.188043e00, + "u_925": 7.940808e00, + "u_1000": 6.141778e00, + "v_50": 7.058931e00, + "v_100": 7.479310e00, + "v_150": 9.571990e00, + "v_200": 1.188069e01, + "v_250": 1.338039e01, + "v_300": 1.334044e01, + "v_400": 1.122955e01, + "v_500": 9.181708e00, + "v_600": 7.803569e00, + "v_700": 6.871040e00, + "v_850": 6.264443e00, + "v_925": 6.470644e00, + "v_1000": 5.308203e00, + "t_50": 1.026284e01, + "t_100": 1.252901e01, + "t_150": 8.928709e00, + "t_200": 7.189547e00, + "t_250": 8.529282e00, + "t_300": 1.071679e01, + "t_400": 1.269102e01, + "t_500": 1.306447e01, + "t_600": 1.342046e01, + "t_700": 1.476523e01, + "t_850": 1.558880e01, + "t_925": 1.608798e01, + "t_1000": 1.713983e01, + "q_50": 3.571687e-07, + "q_100": 5.703754e-07, + "q_150": 3.794077e-06, + "q_200": 2.267534e-05, + "q_250": 7.446644e-05, + "q_300": 1.684361e-04, + "q_400": 5.078644e-04, + "q_500": 1.079294e-03, + "q_600": 1.769722e-03, + "q_700": 2.549169e-03, + "q_850": 4.112368e-03, + "q_925": 5.071058e-03, + "q_1000": 5.913548e-03, +} diff --git a/tests/test_model.py b/tests/test_model.py index c6278c8..fd8aa1a 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,30 +1,99 @@ """Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" import os +import pickle from datetime import datetime +from typing import TypedDict +import numpy as np import torch +from huggingface_hub import hf_hub_download from aurora import AuroraSmall, Batch, Metadata -def test_aurora_small(): +class SavedMetadata(TypedDict): + """Type of metadata of a saved test batch.""" + + lat: np.ndarray + lon: np.ndarray + time: list[datetime] + atmos_levels: list[int | float] + + +class SavedBatch(TypedDict): + """Type of a saved test batch.""" + + surf_vars: dict[str, np.ndarray] + static_vars: dict[str, np.ndarray] + atmos_vars: dict[str, np.ndarray] + metadata: SavedMetadata + + +def test_aurora_small() -> None: model = AuroraSmall() - model.load_checkpoint(os.environ["HUGGINGFACE_REPO"], "aurora-0.25-small-pretrained.ckpt") + # Load test input. + path = hf_hub_download( + repo_id=os.environ["HUGGINGFACE_REPO"], + filename="aurora-0.25-small-pretrained-test-input.pickle", + ) + with open(path, "rb") as f: + test_input: SavedBatch = pickle.load(f) + # Load test output. + path = hf_hub_download( + repo_id=os.environ["HUGGINGFACE_REPO"], + filename="aurora-0.25-small-pretrained-test-output.pickle", + ) + with open(path, "rb") as f: + test_output: SavedBatch = pickle.load(f) + + # Load static variables. + path = hf_hub_download( + repo_id=os.environ["HUGGINGFACE_REPO"], + filename="aurora-0.25-static.pickle", + ) + with open(path, "rb") as f: + static_vars: dict[str, np.ndarray] = pickle.load(f) + + # Select the test region for the static variables. For convenience, these are included wholly. + lat_inds = range(140, 140 + 32) + lon_inds = range(0, 0 + 64) + static_vars = {k: v[lat_inds, :][:, lon_inds] for k, v in static_vars.items()} + + # Construct a proper batch from the test input. batch = Batch( - surf_vars={k: torch.randn(1, 2, 16, 32) for k in ("2t", "10u", "10v", "msl")}, - static_vars={k: torch.randn(1, 2, 16, 32) for k in ("lsm", "z", "slt")}, - atmos_vars={k: torch.randn(1, 2, 4, 16, 32) for k in ("z", "u", "v", "t", "q")}, + surf_vars={k: torch.from_numpy(v) for k, v in test_input["surf_vars"].items()}, + static_vars={k: torch.from_numpy(v) for k, v in static_vars.items()}, + atmos_vars={k: torch.from_numpy(v) for k, v in test_input["atmos_vars"].items()}, metadata=Metadata( - lat=torch.linspace(90, -90, 17)[:-1], # Cut off the south pole. - lon=torch.linspace(0, 360, 32 + 1)[:-1], - time=(datetime(2020, 6, 1, 12, 0),), - atmos_levels=(100, 250, 500, 850), + lat=torch.from_numpy(test_input["metadata"]["lat"]), + lon=torch.from_numpy(test_input["metadata"]["lon"]), + atmos_levels=tuple(test_input["metadata"]["atmos_levels"]), + time=tuple(test_input["metadata"]["time"]), ), ) - prediction = model.forward(batch) + # Load the checkpoint and run the model. + model.load_checkpoint(os.environ["HUGGINGFACE_REPO"], "aurora-0.25-small-pretrained.ckpt") + with torch.no_grad(): + torch.manual_seed(0) # Very important to seed! The test data was generated using this. + pred = model.forward(batch) + + def assert_approx_equality(v_out, v_ref) -> None: + err_rel = ((v_out - v_ref) / (v_ref + 1e-10)).abs().mean() + assert err_rel <= 1e-4 + + # Check the outputs. + for k in pred.surf_vars: + assert_approx_equality(pred.surf_vars[k], test_output["surf_vars"][k]) + for k in pred.static_vars: + assert_approx_equality(pred.static_vars[k], static_vars[k]) + for k in pred.atmos_vars: + assert_approx_equality(pred.atmos_vars[k], test_output["atmos_vars"][k]) - assert isinstance(prediction, Batch) + np.testing.assert_allclose(pred.metadata.lon, test_output["metadata"]["lon"]) + np.testing.assert_allclose(pred.metadata.lat, test_output["metadata"]["lat"]) + assert pred.metadata.atmos_levels == tuple(test_output["metadata"]["atmos_levels"]) + assert pred.metadata.time == tuple(test_output["metadata"]["time"])