Skip to content

Commit

Permalink
Implement crop and validate lats and lons
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Aug 9, 2024
1 parent c5403af commit 37bb4d8
Showing 1 changed file with 29 additions and 4 deletions.
33 changes: 29 additions & 4 deletions aurora/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,24 @@ class Metadata:
lat (:class:`torch.Tensor`): Latitudes.
lon (:class:`torch.Tensor`): Longitudes.
time (tuple[datetime, ...]): For every batch element, the time.
atmos_levels (tuple[int |float, ...): Pressure levels for the atmospheric variables in hPa.
atmos_levels (tuple[int |float, ...]): Pressure levels for the atmospheric variables in hPa.
"""

lat: torch.Tensor
lon: torch.Tensor
time: tuple[datetime, ...]
atmos_levels: tuple[int | float, ...]

def __post_init__(self):
if not torch.all(self.lat[1:] - self.lat[:-1] < 0):
raise ValueError("Latitudes must be strictly decreasing.")
if not (torch.all(self.lat <= 90) and torch.all(self.lat >= -90)):
raise ValueError("Latitudes must be in the range [-90, 90].")
if not torch.all(self.lon[1:] - self.lon[:-1] > 0):
raise ValueError("Longitudes must be strictly increasing.")
if not (torch.all(self.lon >= 0) and torch.all(self.lon < 360)):
raise ValueError("Longitudes must be in the range [0, 360).")


@dataclasses.dataclass
class Batch:
Expand Down Expand Up @@ -83,6 +93,21 @@ def unnormalise(self) -> "Batch":
def crop(self, patch_size: int) -> "Batch":
"""Crop the variables in the batch to patch size `patch_size`."""
h, w = self.spatial_shape
assert h % patch_size == 0
assert w % patch_size == 0
return self

if w % patch_size != 0:
raise ValueError("Width of the data must be a multiple of the patch size.")

if h % patch_size == 0:
return self
elif h % patch_size == 1:
return Batch(
surf_vars={k: v[..., :-1, :] for k, v in self.surf_vars.items()},
static_vars={k: v[..., :-1, :] for k, v in self.static_vars.items()},
atmos_vars={k: v[..., :-1, :] for k, v in self.atmos_vars.items()},
metadata=self.metadata,
)
else:
raise ValueError(
f"There can at most be one latitude too many, "
f"but there are {h % patch_size} too many."
)

0 comments on commit 37bb4d8

Please sign in to comment.