diff --git a/coast/__init__.py b/coast/__init__.py
index 37279aa6..b9388ca2 100644
--- a/coast/__init__.py
+++ b/coast/__init__.py
@@ -28,6 +28,7 @@
 from .data.copernicus import Copernicus, Product
 from ._utils.experiments_file_handling import experiments
 from ._utils.experiments_file_handling import nemo_filename_maker
+from ._utils.coordinates import Coordinates2D, Coordinates3D, Coordinates4D, Coordinates
 
 # Set default for logging level when coast is imported
 import logging
diff --git a/coast/_utils/coordinates.py b/coast/_utils/coordinates.py
new file mode 100644
index 00000000..2e1f04e2
--- /dev/null
+++ b/coast/_utils/coordinates.py
@@ -0,0 +1,38 @@
+"""Coordinate containers describing points in up to four dimensions (X, Y, Z and time)."""
+from dataclasses import dataclass
+from numpy import number, datetime64
+from numbers import Number
+from typing import Union, Optional
+
+
+# A single numeric axis value; None means "do not constrain on this axis".
+Numeric = Optional[Union[Number, number]]
+# Any single axis value: numeric for X/Y/Z, datetime64 for T.
+Coordinate = Union[Numeric, datetime64]
+
+
+@dataclass
+class Coordinates2D:
+    """Represent a point in one-to-two-dimensional space with optional X and Y coordinates."""
+
+    x: Numeric
+    y: Numeric
+
+
+@dataclass
+class Coordinates3D(Coordinates2D):
+    """Represent a point in one-to-three-dimensional space with optional X, Y, and Z coordinates."""
+
+    z: Numeric
+
+
+@dataclass
+class Coordinates4D(Coordinates3D):
+    """Represent a point in one-to-four-dimensional spacetime with optional X, Y, Z, and T coordinates."""
+
+    # Optional like the other axes: None leaves the time axis unconstrained.
+    t: Optional[datetime64]
+
+
+# Any of the supported coordinate containers.
+Coordinates = Union[Coordinates2D, Coordinates3D, Coordinates4D]
diff --git a/coast/data/coast.py b/coast/data/coast.py
index c0966701..02d61d20 100644
--- a/coast/data/coast.py
+++ b/coast/data/coast.py
@@ -1,6 +1,5 @@
 """The coast class is the main access point into this package."""
-from typing import Any, Dict, List
-
+from typing import Any, Dict, List, Optional
 from dask import array
 import xarray as xr
 import matplotlib
@@ -9,6 +8,7 @@
 from dask.distributed import Client
 import copy
 from .._utils.logging_util import get_slug, debug, info, warn, warning
+from .._utils.coordinates import Coordinates, Coordinates3D, Coordinates4D, Coordinate
 from .opendap import OpendapInfo
 
 
@@ -507,6 +507,202 @@ def plot_cartopy(self, var: str, plot_var: array, params, time_counter: int = 0)
         info("Displaying plot!")
         plt.show()
 
+    def set_constraint(self, start: Coordinates, end: Coordinates, drop: bool = True) -> None:
+        """Constrain the underlying dataset to values within an arbitrarily sized hyperrectangle.
+
+        Coordinates that exceed the boundaries of the dataset will wrap around. i.e. a longitude value of 190 applied
+        to a dataset spanning -180 to 180 would wrap to -170.
+
+        Args:
+            start: The start coordinates of the shape to define.
+            end: The end coordinates of the shape to define.
+            drop: Whether values should be dropped from the constrained dataset (if False, they will be NaNed).
+        """
+        self.dataset = self.constrain(start, end, drop=drop)
+
+    def constrain(self, start: Coordinates, end: Coordinates, drop: bool = True) -> xr.Dataset:
+        """Return the underlying dataset with values constrained to an arbitrarily sized hyperrectangle.
+
+        Coordinates that exceed the boundaries of the dataset will wrap around. i.e. a longitude value of 190 applied
+        to a dataset spanning -180 to 180 would wrap to -170.
+
+        Args:
+            start: The start coordinates of the shape to define.
+            end: The end coordinates of the shape to define.
+            drop: Whether values should be dropped from the constrained dataset (if False, they will be NaNed).
+
+        Returns:
+            The underlying dataset with values constrained to within the defined selection.
+        """
+        return constrain(self.dataset, start, end, drop=drop)
+
+    @property
+    def x_dim(self) -> xr.DataArray:
+        """Return the X coordinate array of the underlying dataset."""
+        return x_dim(self.dataset)
+
+    @property
+    def y_dim(self) -> xr.DataArray:
+        """Return the Y coordinate array of the underlying dataset."""
+        return y_dim(self.dataset)
+
+    @property
+    def z_dim(self) -> xr.DataArray:
+        """Return the Z coordinate array of the underlying dataset."""
+        return z_dim(self.dataset)
+
+    @property
+    def t_dim(self) -> xr.DataArray:
+        """Return the T[ime] coordinate array of the underlying dataset."""
+        return t_dim(self.dataset)
+
+    def get_coord(self, dim: str) -> xr.DataArray:
+        """Get the coordinate array for a dimension from the underlying dataset.
+
+        Args:
+            dim: The name of the dimension (i.e. "x", "y", "z", or "t").
+
+        Returns:
+            The corresponding coordinate array from the underlying dataset.
+        """
+        return get_coord(self.dataset, dim)
+
     def plot_movie(self):
         """Plot movie."""
         raise NotImplementedError
+
+
+def create_constraint(start: Coordinate, end: Coordinate, dim: xr.DataArray) -> np.typing.NDArray[bool]:
+    """Create a mask to exclude coordinates that do not fall within a range of two arbitrary values.
+
+    Coordinates that exceed the boundaries of the dataset will wrap around. i.e. a longitude value of 190 applied
+    to a dataset spanning -180 to 180 would wrap to -170. A range at least as wide as the coordinate array selects
+    everything, and a reversed range (start > end) selects nothing.
+
+    Args:
+        start: The start of the range of values to constrain within.
+        end: The end of the range of values to constrain within.
+        dim: The coordinate array to constrain values from.
+
+    Returns:
+        A mask that can be applied to dim to exclude unwanted values.
+    """
+    # Keep NumPy scalars (not .item()) so datetime64 axes compare and subtract correctly
+    minimum = dim.min().values
+    maximum = dim.max().values
+    span = maximum - minimum
+    start_inside = minimum <= start <= maximum
+    end_inside = minimum <= end <= maximum
+
+    plain = np.logical_and(dim >= start, dim <= end)
+    # Nothing to wrap: the range is reversed, already inside the array, or the array holds a single value
+    if start > end or (start_inside and end_inside) or maximum == minimum:
+        return plain
+    # The range covers the whole axis, so wrapping it would wrongly shrink the selection
+    if end - start >= span:
+        return xr.ones_like(dim, dtype=bool)
+
+    # Fold out-of-range bounds back onto the axis; in-range bounds (including the maximum itself) stay put
+    lower = start if start_inside else minimum + (start - minimum) % span
+    upper = end if end_inside else minimum + (end - minimum) % span
+    if lower <= upper:
+        return np.logical_and(dim >= lower, dim <= upper)
+    # The folded range crosses the seam, so keep both ends of the axis
+    return np.logical_or(dim >= lower, dim <= upper)
+
+
+def get_coord(dataset: xr.Dataset, dim: str) -> xr.DataArray:
+    """Get the coordinate array for a dimension in a dataset.
+
+    Args:
+        dataset: The dataset to interrogate.
+        dim: The name of the dimension (i.e. "x", "y", "z", or "t").
+
+    Returns:
+        The corresponding coordinate array from the provided dataset.
+    """
+    # TODO Really not a fan of this, is there an easier way to get the mapping?
+    return dataset[list(dataset[f"{dim.lower()}_dim"].coords)[0]]
+
+
+def x_dim(dataset: xr.Dataset) -> xr.DataArray:
+    """Get the X coordinate array of a dataset.
+
+    Args:
+        dataset: The dataset to interrogate.
+
+    Returns:
+        The corresponding coordinate array from the provided dataset.
+    """
+    return get_coord(dataset, "x")
+
+
+def y_dim(dataset: xr.Dataset) -> xr.DataArray:
+    """Get the Y coordinate array of a dataset.
+
+    Args:
+        dataset: The dataset to interrogate.
+
+    Returns:
+        The corresponding coordinate array from the provided dataset.
+    """
+    return get_coord(dataset, "y")
+
+
+def z_dim(dataset: xr.Dataset) -> xr.DataArray:
+    """Get the Z coordinate array of a dataset.
+
+    Args:
+        dataset: The dataset to interrogate.
+
+    Returns:
+        The corresponding coordinate array from the provided dataset.
+    """
+    return get_coord(dataset, "z")
+
+
+def t_dim(dataset: xr.Dataset) -> xr.DataArray:
+    """Get the T[ime] coordinate array of a dataset.
+
+    Args:
+        dataset: The dataset to interrogate.
+
+    Returns:
+        The corresponding coordinate array from the provided dataset.
+    """
+    return get_coord(dataset, "t")
+
+
+def constrain(dataset: xr.Dataset, start: Coordinates, end: Coordinates, drop: bool = True) -> xr.Dataset:
+    """Constrain values within a dataset to an arbitrarily sized hyperrectangle.
+
+    Coordinates that exceed the boundaries of the dataset will wrap around. i.e. a longitude value of 190 applied
+    to a dataset spanning -180 to 180 would wrap to -170. An axis is skipped when both of its values are None.
+
+    Args:
+        dataset: The dataset to constrain values from.
+        start: The start coordinates of the shape to define.
+        end: The end coordinates of the shape to define.
+        drop: Whether values should be dropped from the constrained dataset (if False, they will be NaNed).
+
+    Returns:
+        The provided dataset with values constrained to within the defined selection.
+
+    Raises:
+        TypeError: If start and end are not the same kind of coordinates.
+        ValueError: If an axis has a value in only one of start and end.
+    """
+    if type(start) is not type(end):
+        raise TypeError("Coordinates must be of the same dimensionality!")
+
+    constrained = dataset
+    # 2D coordinates have no z/t and 3D coordinates have no t; getattr treats those axes as unset
+    for axis in ("x", "y", "z", "t"):
+        lower = getattr(start, axis, None)
+        upper = getattr(end, axis, None)
+        if lower is None and upper is None:
+            continue
+        if lower is None or upper is None:
+            raise ValueError(f"Tried to constrain on {axis.upper()} with a missing paired value!")
+        constrained = constrained.where(create_constraint(lower, upper, get_coord(constrained, axis)), drop=drop)
+    return constrained
diff --git a/tests/test_subsetting.py b/tests/test_subsetting.py
new file mode 100644
index 00000000..0c9330de
--- /dev/null
+++ b/tests/test_subsetting.py
@@ -0,0 +1,69 @@
+import logging
+from os import environ
+from pathlib import Path
+import pytest
+from coast import Copernicus, Coast, Gridded, Coordinates2D
+
+
+DATABASE = "nrt"
+PRODUCT_ID = "global-analysis-forecast-phy-001-024"
+CONFIG = (Path(__file__).parent.parent / "config" / "example_cmems_grid_t.json").resolve(strict=True)
+USERNAME = environ.get("COPERNICUS_USERNAME")
+PASSWORD = environ.get("COPERNICUS_PASSWORD")
+CREDENTIALS = USERNAME is not None and PASSWORD is not None
+
+
+@pytest.fixture(name="copernicus")
+def copernicus_fixture() -> Copernicus:
+    """Return a functional Copernicus data accessor."""
+    return Copernicus(USERNAME, PASSWORD, DATABASE)
+
+
+@pytest.fixture(name="gridded")
+def gridded_fixture(copernicus) -> Gridded:
+    """Return a Gridded instance built from the Copernicus product and the example CMEMS config."""
+    forecast = copernicus.get_product(PRODUCT_ID)
+    return Gridded(fn_data=forecast, config=str(CONFIG))
+
+
+@pytest.mark.skipif(condition=not CREDENTIALS, reason="Copernicus credentials are not set.")
+def test_2d(gridded):
+    """Check that constraining on X and Y keeps exactly the requested rectangle."""
+    start = Coordinates2D(10, 13)
+    end = Coordinates2D(20, 50)
+    # Validate test values
+    assert gridded.dataset.longitude.min().item() < start.x
+    assert gridded.dataset.latitude.min().item() < start.y
+    assert gridded.dataset.longitude.max().item() > end.x
+    assert gridded.dataset.latitude.max().item() > end.y
+
+    # Constrain dataset
+    constrained = gridded.constrain(start, end)
+
+    # Check constrained dataset
+    assert constrained.longitude.min().item() == start.x
+    assert constrained.latitude.min().item() == start.y
+    assert constrained.longitude.max().item() == end.x
+    assert constrained.latitude.max().item() == end.y
+
+
+@pytest.mark.skipif(condition=not CREDENTIALS, reason="Copernicus credentials are not set.")
+def test_wrap(gridded):
+    """Check that an X range running past the maximum longitude wraps around to the minimum."""
+    start = Coordinates2D(175, 50)
+    end = Coordinates2D(gridded.dataset.longitude.max().item() + 5, 60)
+    wrapped = gridded.dataset.longitude.min().item() + 5
+
+    # Constrain dataset
+    constrained = gridded.constrain(start, end)
+
+    # Check constrained dataset
+    longitudes = constrained.longitude
+    assert wrapped < start.x
+    assert wrapped in longitudes
+    # Nothing strictly between the wrapped end and the start of the selection may survive
+    assert not ((longitudes > wrapped) & (longitudes < start.x)).any().item()
+
+
+if not CREDENTIALS:
+    logging.warning("https://marine.copernicus.eu/ credentials not set, integration tests will not be run!")