diff --git a/aurora/batch.py b/aurora/batch.py index 38a8ece..b3f33c9 100644 --- a/aurora/batch.py +++ b/aurora/batch.py @@ -23,7 +23,7 @@ class Metadata: lat (:class:`torch.Tensor`): Latitudes. lon (:class:`torch.Tensor`): Longitudes. time (tuple[datetime, ...]): For every batch element, the time. - atmos_levels (tuple[int |float, ...): Pressure levels for the atmospheric variables in hPa. + atmos_levels (tuple[int |float, ...]): Pressure levels for the atmospheric variables in hPa. """ lat: torch.Tensor @@ -31,6 +31,16 @@ class Metadata: time: tuple[datetime, ...] atmos_levels: tuple[int | float, ...] + def __post_init__(self): + if not torch.all(self.lat[1:] - self.lat[:-1] < 0): + raise ValueError("Latitudes must be strictly decreasing.") + if not (torch.all(self.lat <= 90) and torch.all(self.lat >= -90)): + raise ValueError("Latitudes must be in the range [-90, 90].") + if not torch.all(self.lon[1:] - self.lon[:-1] > 0): + raise ValueError("Longitudes must be strictly increasing.") + if not (torch.all(self.lon >= 0) and torch.all(self.lon < 360)): + raise ValueError("Longitudes must be in the range [0, 360).") + @dataclasses.dataclass class Batch: @@ -83,6 +93,21 @@ def unnormalise(self) -> "Batch": def crop(self, patch_size: int) -> "Batch": """Crop the variables in the batch to patch size `patch_size`.""" h, w = self.spatial_shape - assert h % patch_size == 0 - assert w % patch_size == 0 - return self + + if w % patch_size != 0: + raise ValueError("Width of the data must be a multiple of the patch size.") + + if h % patch_size == 0: + return self + elif h % patch_size == 1: + return Batch( + surf_vars={k: v[..., :-1, :] for k, v in self.surf_vars.items()}, + static_vars={k: v[..., :-1, :] for k, v in self.static_vars.items()}, + atmos_vars={k: v[..., :-1, :] for k, v in self.atmos_vars.items()}, + metadata=self.metadata, + ) + else: + raise ValueError( + f"There can at most be one latitude too many, " + f"but there are {h % patch_size} too many." + )