From 23cea7f0a27de8108179fb72fc28bcff26b29a56 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Fri, 2 Aug 2024 15:46:35 +0000 Subject: [PATCH 1/4] stash spatial slice tests --- .../select/select_spatial_slice.py | 265 +++++++----------- tests/select/test_select_spatial_slice.py | 68 +++++ 2 files changed, 172 insertions(+), 161 deletions(-) create mode 100644 tests/select/test_select_spatial_slice.py diff --git a/ocf_data_sampler/select/select_spatial_slice.py b/ocf_data_sampler/select/select_spatial_slice.py index 19eee30..dbaf4df 100644 --- a/ocf_data_sampler/select/select_spatial_slice.py +++ b/ocf_data_sampler/select/select_spatial_slice.py @@ -1,11 +1,9 @@ """Select spatial slices""" import logging -from typing import Optional, Union import numpy as np import xarray as xr -from scipy.spatial import KDTree from ocf_datapipes.utils import Location from ocf_datapipes.utils.geospatial import ( @@ -23,35 +21,40 @@ # -------------------------------- utility functions -------------------------------- -def convert_coords_to_match_xarray(x, y, from_coords, xr_data): +def convert_coords_to_match_xarray( + x: float | np.ndarray, + y: float | np.ndarray, + from_coords: str, + da: xr.DataArray + ): """Convert x and y coords to cooridnate system matching xarray data Args: x: Float or array-like y: Float or array-like from_coords: String describing coordinate system of x and y - xr_data: xarray data object to which coordinates should be matched + da: DataArray to which coordinates should be matched """ - xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data) + target_coords, *_ = spatial_coord_type(da) assert from_coords in ["osgb", "lon_lat"] - assert xr_coords in ["geostationary", "osgb", "lon_lat"] + assert target_coords in ["geostationary", "osgb", "lon_lat"] - if xr_coords == "geostationary": + if target_coords == "geostationary": if from_coords == "osgb": - x, y = osgb_to_geostationary_area_coords(x, y, xr_data) + x, y = osgb_to_geostationary_area_coords(x, y, 
da) elif from_coords == "lon_lat": - x, y = lon_lat_to_geostationary_area_coords(x, y, xr_data) + x, y = lon_lat_to_geostationary_area_coords(x, y, da) - elif xr_coords == "lon_lat": + elif target_coords == "lon_lat": if from_coords == "osgb": x, y = osgb_to_lon_lat(x, y) # else the from_coords=="lon_lat" and we don't need to convert - elif xr_coords == "osgb": + elif target_coords == "osgb": if from_coords == "lon_lat": x, y = lon_lat_to_osgb(x, y) @@ -61,19 +64,19 @@ def convert_coords_to_match_xarray(x, y, from_coords, xr_data): def _get_idx_of_pixel_closest_to_poi( - xr_data: xr.DataArray, + da: xr.DataArray, location: Location, ) -> Location: """ Return x and y index location of pixel at center of region of interest. Args: - xr_data: Xarray dataset - location: Center + da: xarray DataArray + location: Location to find index of Returns: The Location for the center pixel """ - xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data) + xr_coords, x_dim, y_dim = spatial_coord_type(da) if xr_coords not in ["osgb", "lon_lat"]: raise NotImplementedError(f"Only 'osgb' and 'lon_lat' are supported - not '{xr_coords}'") @@ -83,15 +86,15 @@ def _get_idx_of_pixel_closest_to_poi( location.x, location.y, from_coords=location.coordinate_system, - xr_data=xr_data, + da=da, ) # Check that the requested point lies within the data - assert xr_data[xr_x_dim].min() < x < xr_data[xr_x_dim].max() - assert xr_data[xr_y_dim].min() < y < xr_data[xr_y_dim].max() + assert da[x_dim].min() < x < da[x_dim].max() + assert da[y_dim].min() < y < da[y_dim].max() - x_index = xr_data.get_index(xr_x_dim) - y_index = xr_data.get_index(xr_y_dim) + x_index = da.get_index(x_dim) + y_index = da.get_index(y_dim) closest_x = x_index.get_indexer([x], method="nearest")[0] closest_y = y_index.get_indexer([y], method="nearest")[0] @@ -100,103 +103,47 @@ def _get_idx_of_pixel_closest_to_poi( def _get_idx_of_pixel_closest_to_poi_geostationary( - xr_data: xr.DataArray, + da: xr.DataArray, center_osgb: 
Location, ) -> Location: """ Return x and y index location of pixel at center of region of interest. Args: - xr_data: Xarray dataset + da: xarray DataArray center_osgb: Center in OSGB coordinates Returns: Location for the center pixel in geostationary coordinates """ - xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data) + _, x_dim, y_dim = spatial_coord_type(da) - x, y = osgb_to_geostationary_area_coords(x=center_osgb.x, y=center_osgb.y, xr_data=xr_data) + x, y = osgb_to_geostationary_area_coords(x=center_osgb.x, y=center_osgb.y, xr_data=da) center_geostationary = Location(x=x, y=y, coordinate_system="geostationary") # Check that the requested point lies within the data - assert xr_data[xr_x_dim].min() < x < xr_data[xr_x_dim].max() - assert xr_data[xr_y_dim].min() < y < xr_data[xr_y_dim].max() + assert da[x_dim].min() < x < da[x_dim].max() + assert da[y_dim].min() < y < da[y_dim].max() # Get the index into x and y nearest to x_center_geostationary and y_center_geostationary: x_index_at_center = searchsorted( - xr_data[xr_x_dim].values, center_geostationary.x, assume_ascending=True + da[x_dim].values, center_geostationary.x, assume_ascending=True ) # y_geostationary is in descending order: y_index_at_center = searchsorted( - xr_data[xr_y_dim].values, center_geostationary.y, assume_ascending=False + da[y_dim].values, center_geostationary.y, assume_ascending=False ) return Location(x=x_index_at_center, y=y_index_at_center, coordinate_system="idx") -def _get_points_from_unstructured_grids( - xr_data: xr.DataArray, - location: Location, - location_idx_name: str = "values", - num_points: int = 1, -): - """ - Get the closest points from an unstructured grid (i.e. Icosahedral grid) - - This is primarily used for the Icosahedral grid, which is not a regular grid, - and so is not an image - - Args: - xr_data: Xarray dataset - location: Location of center point - location_idx_name: Name of the index values dimension - (i.e. 
where we index into to get the lat/lon for that point) - num_points: Number of points to return (should be width * height) - - Returns: - The closest points from the grid - """ - xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data) - assert xr_coords == "lon_lat" - - # Check if need to convert from different coordinate system to lat/lon - if location.coordinate_system == "osgb": - longitude, latitude = osgb_to_lon_lat(x=location.x, y=location.y) - location = Location( - x=longitude, - y=latitude, - coordinate_system="lon_lat", - ) - elif location.coordinate_system == "geostationary": - raise NotImplementedError( - "Does not currently support geostationary coordinates when using unstructured grids" - ) - - # Extract lat, lon, and locidx data - lat = xr_data.longitude.values - lon = xr_data.latitude.values - locidx = xr_data[location_idx_name].values - - # Create a KDTree - tree = KDTree(list(zip(lat, lon))) - - # Query with the [longitude, latitude] of your point - _, idx = tree.query([location.x, location.y], k=num_points) - - # Retrieve the location_idxs for these grid points - location_idxs = locidx[idx] - - data = xr_data.sel({location_idx_name: location_idxs}) - return data - - # ---------------------------- sub-functions for slicing ---------------------------- def _slice_patial_spatial_pixel_window_from_xarray( - xr_data, + da, left_idx, right_idx, top_idx, @@ -205,77 +152,89 @@ def _slice_patial_spatial_pixel_window_from_xarray( right_pad_pixels, top_pad_pixels, bottom_pad_pixels, - xr_x_dim, - xr_y_dim, + x_dim, + y_dim, ): """Return spatial window of given pixel size when window partially overlaps input data""" - dx = np.median(np.diff(xr_data[xr_x_dim].values)) - dy = np.median(np.diff(xr_data[xr_y_dim].values)) + dx = np.median(np.diff(da[x_dim].values)) + dy = np.median(np.diff(da[y_dim].values)) if left_pad_pixels > 0: assert right_pad_pixels == 0 x_sel = np.concatenate( [ - xr_data[xr_x_dim].values[0] - np.arange(left_pad_pixels, 0, -1) * dx, - 
xr_data[xr_x_dim].values[0:right_idx], + da[x_dim].values[0] - np.arange(left_pad_pixels, 0, -1) * dx, + da[x_dim].values[0:right_idx], ] ) - xr_data = xr_data.isel({xr_x_dim: slice(0, right_idx)}).reindex({xr_x_dim: x_sel}) + da = da.isel({x_dim: slice(0, right_idx)}).reindex({x_dim: x_sel}) elif right_pad_pixels > 0: assert left_pad_pixels == 0 x_sel = np.concatenate( [ - xr_data[xr_x_dim].values[left_idx:], - xr_data[xr_x_dim].values[-1] + np.arange(1, right_pad_pixels + 1) * dx, + da[x_dim].values[left_idx:], + da[x_dim].values[-1] + np.arange(1, right_pad_pixels + 1) * dx, ] ) - xr_data = xr_data.isel({xr_x_dim: slice(left_idx, None)}).reindex({xr_x_dim: x_sel}) + da = da.isel({x_dim: slice(left_idx, None)}).reindex({x_dim: x_sel}) else: - xr_data = xr_data.isel({xr_x_dim: slice(left_idx, right_idx)}) + da = da.isel({x_dim: slice(left_idx, right_idx)}) if top_pad_pixels > 0: assert bottom_pad_pixels == 0 y_sel = np.concatenate( [ - xr_data[xr_y_dim].values[0] - np.arange(top_pad_pixels, 0, -1) * dy, - xr_data[xr_y_dim].values[0:bottom_idx], + da[y_dim].values[0] - np.arange(top_pad_pixels, 0, -1) * dy, + da[y_dim].values[0:bottom_idx], ] ) - xr_data = xr_data.isel({xr_y_dim: slice(0, bottom_idx)}).reindex({xr_y_dim: y_sel}) + da = da.isel({y_dim: slice(0, bottom_idx)}).reindex({y_dim: y_sel}) elif bottom_pad_pixels > 0: assert top_pad_pixels == 0 y_sel = np.concatenate( [ - xr_data[xr_y_dim].values[top_idx:], - xr_data[xr_y_dim].values[-1] + np.arange(1, bottom_pad_pixels + 1) * dy, + da[y_dim].values[top_idx:], + da[y_dim].values[-1] + np.arange(1, bottom_pad_pixels + 1) * dy, ] ) - xr_data = xr_data.isel({xr_y_dim: slice(top_idx, None)}).reindex({xr_x_dim: y_sel}) + da = da.isel({y_dim: slice(top_idx, None)}).reindex({x_dim: y_sel}) else: - xr_data = xr_data.isel({xr_y_dim: slice(top_idx, bottom_idx)}) + da = da.isel({y_dim: slice(top_idx, bottom_idx)}) - return xr_data + return da -def slice_spatial_pixel_window_from_xarray( - xr_data, center_idx, 
width_pixels, height_pixels, xr_x_dim, xr_y_dim, allow_partial_slice +def _select_spatial_slice_pixels( + da: xr.DataArray, + center_idx: Location, + width_pixels: int, + height_pixels: int, + x_dim: str, + y_dim: str, + allow_partial_slice: bool, ): """Select a spatial slice from an xarray object Args: - xr_data: Xarray object - center_idx: Location object describing the centre of the window + da: xarray DataArray to slice from + center_idx: Location object describing the centre of the window with index coordinates width_pixels: Window with in pixels height_pixels: Window height in pixels - xr_x_dim: Name of the x-dimension in the xr_data - xr_y_dim: Name of the y-dimension in the xr_data + x_dim: Name of the x-dimension in `da` + y_dim: Name of the y-dimension in `da` allow_partial_slice: Whether to allow a partially filled window """ + + assert center_idx.coordinate_system == "idx" + # TODO: It shouldn't take much effort to allow height and width to be odd + assert (width_pixels % 2)==0, "Width must be an even number" + assert (height_pixels % 2)==0, "Height must be an even number" + half_width = width_pixels // 2 half_height = height_pixels // 2 @@ -284,8 +243,8 @@ def slice_spatial_pixel_window_from_xarray( top_idx = int(center_idx.y - half_height) bottom_idx = int(center_idx.y + half_height) - data_width_pixels = len(xr_data[xr_x_dim]) - data_height_pixels = len(xr_data[xr_y_dim]) + data_width_pixels = len(da[x_dim]) + data_height_pixels = len(da[y_dim]) left_pad_required = left_idx < 0 right_pad_required = right_idx >= data_width_pixels @@ -305,8 +264,8 @@ def slice_spatial_pixel_window_from_xarray( (bottom_idx - (data_height_pixels - 1)) if bottom_pad_required else 0 ) - xr_data = _slice_patial_spatial_pixel_window_from_xarray( - xr_data, + da = _slice_patial_spatial_pixel_window_from_xarray( + da, left_idx, right_idx, top_idx, @@ -315,8 +274,8 @@ def slice_spatial_pixel_window_from_xarray( right_pad_pixels, top_pad_pixels, bottom_pad_pixels, - xr_x_dim, - 
xr_y_dim, + x_dim, + y_dim, ) else: raise ValueError( @@ -327,35 +286,34 @@ def slice_spatial_pixel_window_from_xarray( ) else: - xr_data = xr_data.isel( + da = da.isel( { - xr_x_dim: slice(left_idx, right_idx), - xr_y_dim: slice(top_idx, bottom_idx), + x_dim: slice(left_idx, right_idx), + y_dim: slice(top_idx, bottom_idx), } ) - assert len(xr_data[xr_x_dim]) == width_pixels, ( - f"Expected x-dim len {width_pixels} got {len(xr_data[xr_x_dim])} " + assert len(da[x_dim]) == width_pixels, ( + f"Expected x-dim len {width_pixels} got {len(da[x_dim])} " f"for location {center_idx} for slice {left_idx}:{right_idx}" ) - assert len(xr_data[xr_y_dim]) == height_pixels, ( - f"Expected y-dim len {height_pixels} got {len(xr_data[xr_y_dim])} " + assert len(da[y_dim]) == height_pixels, ( + f"Expected y-dim len {height_pixels} got {len(da[y_dim])} " f"for location {center_idx} for slice {top_idx}:{bottom_idx}" ) - return xr_data + return da # ---------------------------- main functions for slicing --------------------------- def select_spatial_slice_pixels( - xr_data: Union[xr.Dataset, xr.DataArray], + da: xr.DataArray, location: Location, - roi_width_pixels: int, - roi_height_pixels: int, + width_pixels: int, + height_pixels: int, allow_partial_slice: bool = False, - location_idx_name: Optional[str] = None, ): """ Select spatial slice based off pixels from location point of interest @@ -367,43 +325,28 @@ def select_spatial_slice_pixels( input data. Args: - xr_data: Xarray DataArray or Dataset to slice from + da: xarray DataArray to slice from location: Location of interest - roi_height_pixels: ROI height in pixels - roi_width_pixels: ROI width in pixels + height_pixels: Height of the slice in pixels + width_pixels: Width of the slice in pixels allow_partial_slice: Whether to allow a partial slice. 
- location_idx_name: Name for location index of unstructured grid data, - None if not relevant """ - xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data) - if location_idx_name is not None: - selected = _get_points_from_unstructured_grids( - xr_data=xr_data, - location=location, - location_idx_name=location_idx_name, - num_points=roi_width_pixels * roi_height_pixels, - ) - else: - if xr_coords == "geostationary": - center_idx: Location = _get_idx_of_pixel_closest_to_poi_geostationary( - xr_data=xr_data, - center_osgb=location, - ) - else: - center_idx: Location = _get_idx_of_pixel_closest_to_poi( - xr_data=xr_data, - location=location, - ) + xr_coords, x_dim, y_dim = spatial_coord_type(da) - selected = slice_spatial_pixel_window_from_xarray( - xr_data, - center_idx, - roi_width_pixels, - roi_height_pixels, - xr_x_dim, - xr_y_dim, - allow_partial_slice=allow_partial_slice, - ) + if xr_coords == "geostationary": + center_idx: Location = _get_idx_of_pixel_closest_to_poi_geostationary(da, location) + else: + center_idx: Location = _get_idx_of_pixel_closest_to_poi(da, location) + + selected = _select_spatial_slice_pixels( + da, + center_idx, + width_pixels, + height_pixels, + x_dim, + y_dim, + allow_partial_slice=allow_partial_slice, + ) return selected \ No newline at end of file diff --git a/tests/select/test_select_spatial_slice.py b/tests/select/test_select_spatial_slice.py new file mode 100644 index 0000000..c8d194c --- /dev/null +++ b/tests/select/test_select_spatial_slice.py @@ -0,0 +1,68 @@ +import numpy as np +import xarray as xr +from ocf_datapipes.utils import Location + +from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels + + +def test_select_spatial_slice_pixels(): + # Create dummy data + x = np.arange(100) + y = np.arange(100)[::-1] + + da = xr.DataArray( + np.random.normal(size=(len(x), len(y))), + coords=dict( + x_osgb=(["x_osgb"], x), + y_osgb=(["y_osgb"], y), + ) + ) + + location = Location(x=10, y=10, 
coordinate_system="osgb") + + # Select window which lies within data + da_sliced = select_spatial_slice_pixels( + da, + location, + width_pixels=10, + height_pixels=10, + allow_partial_slice=True, + ) + + + assert isinstance(da_sliced, xr.DataArray) + assert (da_sliced.x_osgb.values == np.arange(5, 15)).all() + assert (da_sliced.y_osgb.values == np.arange(15, 5, -1)).all() + assert not da_sliced.isnull().any() + + + # Select window where the edge of the window lies at the edge of the data + da_sliced = select_spatial_slice_pixels( + da, + location, + width_pixels=20, + height_pixels=20, + allow_partial_slice=True, + ) + + assert isinstance(da_sliced, xr.DataArray) + assert (da_sliced.x_osgb.values == np.arange(0, 20)).all() + assert (da_sliced.y_osgb.values == np.arange(20, 0, -1)).all() + assert not da_sliced.isnull().any() + + # Select window which is partially outside the boundary of the data + da_sliced = select_spatial_slice_pixels( + da, + location, + width_pixels=30, + height_pixels=30, + allow_partial_slice=True, + ) + + assert isinstance(da_sliced, xr.DataArray) + assert (da_sliced.x_osgb.values == np.arange(-5, 25)).all() + assert (da_sliced.y_osgb.values == np.arange(25, -5, -1)).all() + assert da_sliced.isnull().sum() == 275 + + + From 387c8d074dc0c08976b8b5a7bbf73ee6070ede16 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Tue, 6 Aug 2024 09:32:52 +0000 Subject: [PATCH 2/4] make coords always in order x, y and always increasing --- ocf_data_sampler/load/gsp.py | 12 ++--- ocf_data_sampler/load/nwp/nwp.py | 2 - ocf_data_sampler/load/nwp/providers/ecmwf.py | 20 ++++++--- ocf_data_sampler/load/nwp/providers/ukv.py | 22 +++++---- ocf_data_sampler/load/nwp/providers/utils.py | 8 ---- ocf_data_sampler/load/satellite.py | 47 +++++++------------- ocf_data_sampler/load/utils.py | 29 ++++++++++++ tests/conftest.py | 2 +- tests/load/test_load_nwp.py | 4 ++ tests/load/test_load_satellite.py | 6 +-- 10 files changed, 83 insertions(+), 69 deletions(-) create mode 
100644 ocf_data_sampler/load/utils.py diff --git a/ocf_data_sampler/load/gsp.py b/ocf_data_sampler/load/gsp.py index 2280611..368b3e7 100644 --- a/ocf_data_sampler/load/gsp.py +++ b/ocf_data_sampler/load/gsp.py @@ -1,20 +1,17 @@ from pathlib import Path import pkg_resources -import xarray as xr - import pandas as pd +import xarray as xr -def open_gsp(zarr_path: str | Path): +def open_gsp(zarr_path: str | Path) -> xr.DataArray: # Load GSP generation xr.Dataset ds = xr.open_zarr(zarr_path) # Rename to standard time name - ds = ds.rename({ - "datetime_gmt": "time_utc", - }) + ds = ds.rename({"datetime_gmt": "time_utc"}) # Load UK GSP locations df_gsp_loc = pd.read_csv( @@ -31,7 +28,6 @@ def open_gsp(zarr_path: str | Path): ) - # Return dataarray - return ds["generation_mw"] + return ds.generation_mw diff --git a/ocf_data_sampler/load/nwp/nwp.py b/ocf_data_sampler/load/nwp/nwp.py index dd02958..c74c1a4 100755 --- a/ocf_data_sampler/load/nwp/nwp.py +++ b/ocf_data_sampler/load/nwp/nwp.py @@ -1,7 +1,5 @@ from pathlib import Path -import xarray as xr - from ocf_data_sampler.load.nwp.providers.ukv import open_ukv from ocf_data_sampler.load.nwp.providers.ecmwf import open_ifs diff --git a/ocf_data_sampler/load/nwp/providers/ecmwf.py b/ocf_data_sampler/load/nwp/providers/ecmwf.py index a3e8809..46ebfb3 100755 --- a/ocf_data_sampler/load/nwp/providers/ecmwf.py +++ b/ocf_data_sampler/load/nwp/providers/ecmwf.py @@ -1,10 +1,8 @@ """ECMWF provider loaders""" from pathlib import Path import xarray as xr -from ocf_data_sampler.load.nwp.providers.utils import ( - open_zarr_paths, check_time_unique_increasing -) - +from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths +from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing def open_ifs(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray: """ @@ -19,13 +17,21 @@ def open_ifs(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray: # Open the data ds = 
open_zarr_paths(zarr_path) - ds = ds.transpose("init_time", "step", "variable", "latitude", "longitude") + # Rename ds = ds.rename( { "init_time": "init_time_utc", "variable": "channel", } ) - # Sanity checks. - check_time_unique_increasing(ds) + + # Check the timestamps are unique and increasing + check_time_unique_increasing(ds.init_time_utc) + + # Make sure the spatial coords are in increasing order + ds = make_spatial_coords_increasing(ds, x_coord="longitude", y_coord="latitude") + + ds = ds.transpose("init_time_utc", "step", "channel", "longitude", "latitude") + + # TODO: should we control the dtype of the DataArray? return ds.ECMWF_UK diff --git a/ocf_data_sampler/load/nwp/providers/ukv.py b/ocf_data_sampler/load/nwp/providers/ukv.py index 087cead..586817a 100755 --- a/ocf_data_sampler/load/nwp/providers/ukv.py +++ b/ocf_data_sampler/load/nwp/providers/ukv.py @@ -4,10 +4,8 @@ from pathlib import Path - -from ocf_data_sampler.load.nwp.providers.utils import ( - open_zarr_paths, check_time_unique_increasing -) +from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths +from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing def open_ukv(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray: @@ -23,19 +21,25 @@ def open_ukv(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray: # Open the data ds = open_zarr_paths(zarr_path) - ds = ds.transpose("init_time", "step", "variable", "y", "x") + # Rename ds = ds.rename( { "init_time": "init_time_utc", "variable": "channel", - "y": "y_osgb", "x": "x_osgb", + "y": "y_osgb", } ) - # Sanity checks. - assert ds.y_osgb[0] > ds.y_osgb[1], "UKV must run from top-to-bottom." 
- check_time_unique_increasing(ds) + # Check the timestamps are unique and increasing + check_time_unique_increasing(ds.init_time_utc) + + # Make sure the spatial coords are in increasing order + ds = make_spatial_coords_increasing(ds, x_coord="x_osgb", y_coord="y_osgb") + + ds = ds.transpose("init_time_utc", "step", "channel", "x_osgb", "y_osgb") + + # TODO: should we control the dtype of the DataArray? return ds.UKV diff --git a/ocf_data_sampler/load/nwp/providers/utils.py b/ocf_data_sampler/load/nwp/providers/utils.py index 9a6d15d..16babf2 100755 --- a/ocf_data_sampler/load/nwp/providers/utils.py +++ b/ocf_data_sampler/load/nwp/providers/utils.py @@ -1,6 +1,5 @@ from pathlib import Path import xarray as xr -import pandas as pd def open_zarr_paths( @@ -33,10 +32,3 @@ def open_zarr_paths( chunks="auto", ) return ds - - -def check_time_unique_increasing(ds: xr.Dataset) -> None: - """Check that the time dimension is unique and increasing""" - time = pd.DatetimeIndex(ds.init_time_utc) - assert time.is_unique - assert time.is_monotonic_increasing \ No newline at end of file diff --git a/ocf_data_sampler/load/satellite.py b/ocf_data_sampler/load/satellite.py index 5596ce9..5b8474b 100755 --- a/ocf_data_sampler/load/satellite.py +++ b/ocf_data_sampler/load/satellite.py @@ -1,14 +1,11 @@ """Satellite loader""" -import logging import subprocess from pathlib import Path import pandas as pd import xarray as xr - - -_log = logging.getLogger(__name__) +from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing def _get_single_sat_data(zarr_path: Path | str) -> xr.DataArray: @@ -73,9 +70,8 @@ def open_sat_data(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArra ``` """ + # Open the data if isinstance(zarr_path, (list, tuple)): - message_files_list = "\n - " + "\n - ".join([str(s) for s in zarr_path]) - _log.info(f"Opening satellite data: {message_files_list}") ds = xr.combine_nested( [_get_single_sat_data(path) for path in 
zarr_path], concat_dim="time", @@ -83,34 +79,23 @@ def open_sat_data(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArra join="override", ) else: - _log.info(f"Opening satellite data: {zarr_path}") ds = _get_single_sat_data(zarr_path) + # Rename + ds = ds.rename( + { + "variable": "channel", + "time": "time_utc", + } + ) - ds = ds.rename({"variable": "channel"}) - - # Rename coords to be more explicit about exactly what some coordinates hold: - # Note that `rename` renames *both* the coordinates and dimensions, and keeps - # the connection between the dims and coordinates, so we don't have to manually - # use `data_array.set_index()`. - ds = ds.rename({"time": "time_utc"}) - - # Flip coordinates to top-left first - if ds.y_geostationary[0] < ds.y_geostationary[-1]: - ds = ds.isel(y_geostationary=slice(None, None, -1)) - if ds.x_geostationary[0] > ds.x_geostationary[-1]: - ds = ds.isel(x_geostationary=slice(None, None, -1)) - - # Ensure the y and x coords are in the right order (top-left first): - assert ds.y_geostationary[0] > ds.y_geostationary[-1] - assert ds.x_geostationary[0] < ds.x_geostationary[-1] + # Check the timestamps are unique and increasing + check_time_unique_increasing(ds.time_utc) - ds = ds.transpose("time_utc", "channel", "y_geostationary", "x_geostationary") + # Make sure the spatial coords are in increasing order + ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary") - # Sanity checks! - datetime_index = pd.DatetimeIndex(ds.time_utc) - assert datetime_index.is_unique - assert datetime_index.is_monotonic_increasing + ds = ds.transpose("time_utc", "channel", "x_geostationary", "y_geostationary") - # Return DataArray - return ds["data"] \ No newline at end of file + # TODO: should we control the dtype of the DataArray? 
+ return ds.data \ No newline at end of file diff --git a/ocf_data_sampler/load/utils.py b/ocf_data_sampler/load/utils.py new file mode 100644 index 0000000..1c012b0 --- /dev/null +++ b/ocf_data_sampler/load/utils.py @@ -0,0 +1,29 @@ +import xarray as xr +import pandas as pd + +def check_time_unique_increasing(datetimes) -> None: + """Check that the time dimension is unique and increasing""" + time = pd.DatetimeIndex(datetimes) + assert time.is_unique + assert time.is_monotonic_increasing + +def make_spatial_coords_increasing(ds: xr.Dataset, x_coord: str, y_coord: str) -> xr.Dataset: + """Make sure the spatial coordinates are in increasing order + + Args: + ds: Xarray Dataset + x_coord: Name of the x coordinate + y_coord: Name of the y coordinate + """ + + # Make sure the coords are in increasing order + if ds[x_coord][0] > ds[x_coord][-1]: + ds = ds.isel({x_coord:slice(None, None, -1)}) + if ds[y_coord][0] > ds[y_coord][-1]: + ds = ds.isel({y_coord:slice(None, None, -1)}) + + # Check the coords are all increasing now + assert (ds[x_coord].diff(dim=x_coord) > 0).all() + assert (ds[y_coord].diff(dim=y_coord) > 0).all() + + return ds \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index b39a6c5..053f2a6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -50,7 +50,7 @@ def ds_nwp_ukv(): # This is much faster: x = np.linspace(-239_000, 857_000, 100) - y = np.linspace(-183_000, 1225_000, 100)[::-1] # UKV data must run top to bottom + y = np.linspace(-183_000, 1225_000, 200) variables = ["si10", "dswrf", "t", "prate"] coords = ( diff --git a/tests/load/test_load_nwp.py b/tests/load/test_load_nwp.py index 714a9fb..076ab87 100755 --- a/tests/load/test_load_nwp.py +++ b/tests/load/test_load_nwp.py @@ -1,5 +1,6 @@ import pandas as pd from xarray import DataArray +import numpy as np from ocf_data_sampler.load.nwp import open_nwp @@ -7,6 +8,9 @@ def test_load_ukv(nwp_ukv_zarr_path): da = open_nwp(zarr_path=nwp_ukv_zarr_path, 
provider="ukv") assert isinstance(da, DataArray) + assert da.dims == ("init_time_utc", "step", "channel", "x_osgb", "y_osgb") + assert da.shape == (24 * 7, 11, 4, 100, 200) + assert np.issubdtype(da.dtype, np.number) def _test_load_ecmwf(ecmwf_nwp_zarr_path): da = open_nwp(zarr_path=ecmwf_nwp_zarr_path, provider="ecmwf") diff --git a/tests/load/test_load_satellite.py b/tests/load/test_load_satellite.py index 33aecb4..a1858e5 100755 --- a/tests/load/test_load_satellite.py +++ b/tests/load/test_load_satellite.py @@ -7,8 +7,8 @@ def test_open_satellite(sat_zarr_path): da = open_sat_data(zarr_path=sat_zarr_path) assert isinstance(da, xr.DataArray) - assert da.dims == ("time_utc", "channel", "y_geostationary", "x_geostationary") - assert da.shape == (576, 11, 20, 49) - assert da.dtype == np.float32 + assert da.dims == ("time_utc", "channel", "x_geostationary", "y_geostationary") + assert da.shape == (576, 11, 49, 20) + assert np.issubdtype(da.dtype, np.number) From 092147d5687518cd26880915dc4b584515f33948 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Tue, 6 Aug 2024 09:37:09 +0000 Subject: [PATCH 3/4] note TODO --- ocf_data_sampler/select/find_contiguous_t0_time_periods.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ocf_data_sampler/select/find_contiguous_t0_time_periods.py b/ocf_data_sampler/select/find_contiguous_t0_time_periods.py index 201f8d5..28756f2 100644 --- a/ocf_data_sampler/select/find_contiguous_t0_time_periods.py +++ b/ocf_data_sampler/select/find_contiguous_t0_time_periods.py @@ -256,6 +256,12 @@ def intersection_of_2_dataframes_of_periods(a: pd.DataFrame, b: pd.DataFrame) -> # b: |--| |---| |---| |------| |-| # In all five, `a` must always start before `b` ends, # and `a` must always end after `b` starts: + + # TODO: <= and >= because we should allow overlap time periods of length 1. e.g. + # a: |----| or |---| + # b: |--| |---| + # These aren't allowed if we use < and >. 
+ overlapping_periods = b[(a_period.start_dt < b.end_dt) & (a_period.end_dt > b.start_dt)] # There are two ways in which two periods may *not* overlap: From fcfd9b6999cd7ecbaf7059ea6c0a6a0164ad5c27 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Tue, 6 Aug 2024 11:07:20 +0000 Subject: [PATCH 4/4] add spatial slice fixes and tests --- .../select/select_spatial_slice.py | 88 +++++++------ tests/select/test_select_spatial_slice.py | 124 +++++++++++++++--- 2 files changed, 152 insertions(+), 60 deletions(-) diff --git a/ocf_data_sampler/select/select_spatial_slice.py b/ocf_data_sampler/select/select_spatial_slice.py index dbaf4df..33dc9c7 100644 --- a/ocf_data_sampler/select/select_spatial_slice.py +++ b/ocf_data_sampler/select/select_spatial_slice.py @@ -62,7 +62,8 @@ def convert_coords_to_match_xarray( return x, y - +# TODO: This function and _get_idx_of_pixel_closest_to_poi_geostationary() should not be separate +# We should combine them, and consider making a Coord class to help with this def _get_idx_of_pixel_closest_to_poi( da: xr.DataArray, location: Location, @@ -131,9 +132,8 @@ def _get_idx_of_pixel_closest_to_poi_geostationary( da[x_dim].values, center_geostationary.x, assume_ascending=True ) - # y_geostationary is in descending order: y_index_at_center = searchsorted( - da[y_dim].values, center_geostationary.y, assume_ascending=False + da[y_dim].values, center_geostationary.y, assume_ascending=True ) return Location(x=x_index_at_center, y=y_index_at_center, coordinate_system="idx") @@ -142,36 +142,41 @@ def _get_idx_of_pixel_closest_to_poi_geostationary( # ---------------------------- sub-functions for slicing ---------------------------- -def _slice_patial_spatial_pixel_window_from_xarray( +def _select_partial_spatial_slice_pixels( da, left_idx, right_idx, - top_idx, bottom_idx, + top_idx, left_pad_pixels, right_pad_pixels, - top_pad_pixels, bottom_pad_pixels, + top_pad_pixels, x_dim, y_dim, ): """Return spatial window of given pixel size when window 
partially overlaps input data""" + # We should never be padding on both sides of a window. This would mean our desired window is + # larger than the size of the input data + assert left_pad_pixels==0 or right_pad_pixels==0 + assert bottom_pad_pixels==0 or top_pad_pixels==0 + dx = np.median(np.diff(da[x_dim].values)) dy = np.median(np.diff(da[y_dim].values)) + # Pad the left of the window if left_pad_pixels > 0: - assert right_pad_pixels == 0 x_sel = np.concatenate( [ - da[x_dim].values[0] - np.arange(left_pad_pixels, 0, -1) * dx, + da[x_dim].values[0] + np.arange(-left_pad_pixels, 0) * dx, da[x_dim].values[0:right_idx], ] ) da = da.isel({x_dim: slice(0, right_idx)}).reindex({x_dim: x_sel}) + # Pad the right of the window elif right_pad_pixels > 0: - assert left_pad_pixels == 0 x_sel = np.concatenate( [ da[x_dim].values[left_idx:], @@ -180,31 +185,33 @@ def _slice_patial_spatial_pixel_window_from_xarray( ) da = da.isel({x_dim: slice(left_idx, None)}).reindex({x_dim: x_sel}) + # No left-right padding required else: da = da.isel({x_dim: slice(left_idx, right_idx)}) - if top_pad_pixels > 0: - assert bottom_pad_pixels == 0 + # Pad the bottom of the window + if bottom_pad_pixels > 0: y_sel = np.concatenate( [ - da[y_dim].values[0] - np.arange(top_pad_pixels, 0, -1) * dy, - da[y_dim].values[0:bottom_idx], + da[y_dim].values[0] + np.arange(-bottom_pad_pixels, 0) * dy, + da[y_dim].values[0:top_idx], ] ) - da = da.isel({y_dim: slice(0, bottom_idx)}).reindex({y_dim: y_sel}) + da = da.isel({y_dim: slice(0, top_idx)}).reindex({y_dim: y_sel}) - elif bottom_pad_pixels > 0: - assert top_pad_pixels == 0 + # Pad the top of the window + elif top_pad_pixels > 0: y_sel = np.concatenate( [ - da[y_dim].values[top_idx:], - da[y_dim].values[-1] + np.arange(1, bottom_pad_pixels + 1) * dy, + da[y_dim].values[bottom_idx:], + da[y_dim].values[-1] + np.arange(1, top_pad_pixels + 1) * dy, ] ) - da = da.isel({y_dim: slice(top_idx, None)}).reindex({x_dim: y_sel}) + da = da.isel({y_dim: 
slice(bottom_idx, None)}).reindex({y_dim: y_sel}) + # No bottom-top padding required else: - da = da.isel({y_dim: slice(top_idx, bottom_idx)}) + da = da.isel({y_dim: slice(bottom_idx, top_idx)}) return da @@ -240,48 +247,47 @@ def _select_spatial_slice_pixels( left_idx = int(center_idx.x - half_width) right_idx = int(center_idx.x + half_width) - top_idx = int(center_idx.y - half_height) - bottom_idx = int(center_idx.y + half_height) + bottom_idx = int(center_idx.y - half_height) + top_idx = int(center_idx.y + half_height) data_width_pixels = len(da[x_dim]) data_height_pixels = len(da[y_dim]) left_pad_required = left_idx < 0 - right_pad_required = right_idx >= data_width_pixels - top_pad_required = top_idx < 0 - bottom_pad_required = bottom_idx >= data_height_pixels + right_pad_required = right_idx > data_width_pixels + bottom_pad_required = bottom_idx < 0 + top_pad_required = top_idx > data_height_pixels - pad_required = any( - [left_pad_required, right_pad_required, top_pad_required, bottom_pad_required] - ) + pad_required = left_pad_required | right_pad_required | bottom_pad_required | top_pad_required if pad_required: if allow_partial_slice: + left_pad_pixels = (-left_idx) if left_pad_required else 0 - right_pad_pixels = (right_idx - (data_width_pixels - 1)) if right_pad_required else 0 - top_pad_pixels = (-top_idx) if top_pad_required else 0 - bottom_pad_pixels = ( - (bottom_idx - (data_height_pixels - 1)) if bottom_pad_required else 0 - ) + right_pad_pixels = (right_idx - data_width_pixels) if right_pad_required else 0 + + bottom_pad_pixels = (-bottom_idx) if bottom_pad_required else 0 + top_pad_pixels = (top_idx - data_height_pixels) if top_pad_required else 0 - da = _slice_patial_spatial_pixel_window_from_xarray( + + da = _select_partial_spatial_slice_pixels( da, left_idx, right_idx, - top_idx, bottom_idx, + top_idx, left_pad_pixels, right_pad_pixels, - top_pad_pixels, bottom_pad_pixels, + top_pad_pixels, x_dim, y_dim, ) else: raise ValueError( - f"Window for
location {center_idx} not available. Missing (left, right, top, " - f"bottom) pixels = ({left_pad_required}, {right_pad_required}, " - f"{top_pad_required}, {bottom_pad_required}). " + f"Window for location {center_idx} not available. Missing (left, right, bottom, " + f"top) pixels = ({left_pad_required}, {right_pad_required}, " + f"{bottom_pad_required}, {top_pad_required}). " f"You may wish to set `allow_partial_slice=True`" ) @@ -289,7 +295,7 @@ def _select_spatial_slice_pixels( da = da.isel( { x_dim: slice(left_idx, right_idx), - y_dim: slice(top_idx, bottom_idx), + y_dim: slice(bottom_idx, top_idx), } ) @@ -299,7 +305,7 @@ def _select_spatial_slice_pixels( ) assert len(da[y_dim]) == height_pixels, ( f"Expected y-dim len {height_pixels} got {len(da[y_dim])} " - f"for location {center_idx} for slice {top_idx}:{bottom_idx}" + f"for location {center_idx} for slice {bottom_idx}:{top_idx}" ) return da diff --git a/tests/select/test_select_spatial_slice.py b/tests/select/test_select_spatial_slice.py index c8d194c..2d7d5bc 100644 --- a/tests/select/test_select_spatial_slice.py +++ b/tests/select/test_select_spatial_slice.py @@ -1,14 +1,17 @@ import numpy as np import xarray as xr from ocf_datapipes.utils import Location +import pytest -from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels +from ocf_data_sampler.select.select_spatial_slice import ( + select_spatial_slice_pixels, _get_idx_of_pixel_closest_to_poi +) - -def test_select_spatial_slice_pixels(): +@pytest.fixture(scope="module") +def da(): # Create dummy data - x = np.arange(100) - y = np.arange(100)[::-1] + x = np.arange(-100, 100) + y = np.arange(-100, 100) da = xr.DataArray( np.random.normal(size=(len(x), len(y))), @@ -17,13 +20,29 @@ def test_select_spatial_slice_pixels(): y_osgb=(["y_osgb"], y), ) ) + return da + + +def test_get_idx_of_pixel_closest_to_poi(da): + + idx_location = _get_idx_of_pixel_closest_to_poi( + da, + location=Location(x=10, y=10, 
coordinate_system="osgb"), + ) + + assert idx_location.coordinate_system == "idx" + assert idx_location.x == 110 + assert idx_location.y == 110 - location = Location(x=10, y=10, coordinate_system="osgb") - # Select window which lies within data + + +def test_select_spatial_slice_pixels(da): + + # Select window which lies within x-y bounds of the data da_sliced = select_spatial_slice_pixels( da, - location, + location=Location(x=-90, y=-80, coordinate_system="osgb"), width_pixels=10, height_pixels=10, allow_partial_slice=True, @@ -31,38 +50,105 @@ def test_select_spatial_slice_pixels(): assert isinstance(da_sliced, xr.DataArray) - assert (da_sliced.x_osgb.values == np.arange(5, 15)).all() - assert (da_sliced.y_osgb.values == np.arange(15, 5, -1)).all() + assert (da_sliced.x_osgb.values == np.arange(-95, -85)).all() + assert (da_sliced.y_osgb.values == np.arange(-85, -75)).all() + # No padding in this case so no NaNs assert not da_sliced.isnull().any() - # Select window where the edge of the window lies at the edge of the data + # Select window where the edge of the window lies right on the edge of the data da_sliced = select_spatial_slice_pixels( da, - location, + location=Location(x=-90, y=-80, coordinate_system="osgb"), width_pixels=20, height_pixels=20, allow_partial_slice=True, ) assert isinstance(da_sliced, xr.DataArray) - assert (da_sliced.x_osgb.values == np.arange(0, 20)).all() - assert (da_sliced.y_osgb.values == np.arange(20, 0, -1)).all() + assert (da_sliced.x_osgb.values == np.arange(-100, -80)).all() + assert (da_sliced.y_osgb.values == np.arange(-90, -70)).all() + # No padding in this case so no NaNs assert not da_sliced.isnull().any() - # Select window which is partially outside the boundary of the data + # Select window which is partially outside the boundary of the data - padded on left da_sliced = select_spatial_slice_pixels( da, - location, + location=Location(x=-90, y=-80, coordinate_system="osgb"), width_pixels=30, height_pixels=30, 
allow_partial_slice=True, ) assert isinstance(da_sliced, xr.DataArray) - assert (da_sliced.x_osgb.values == np.arange(-5, 25)).all() - assert (da_sliced.y_osgb.values == np.arange(25, -5, -1)).all() - assert da_sliced.isnull().sum() == 275 + assert (da_sliced.x_osgb.values == np.arange(-105, -75)).all() + assert (da_sliced.y_osgb.values == np.arange(-95, -65)).all() + # Data has been padded on left by 5 NaN pixels + assert da_sliced.isnull().sum() == 5*len(da_sliced.y_osgb) + + + # Select window which is partially outside the boundary of the data - padded on right + da_sliced = select_spatial_slice_pixels( + da, + location=Location(x=90, y=-80, coordinate_system="osgb"), + width_pixels=30, + height_pixels=30, + allow_partial_slice=True, + ) + + assert isinstance(da_sliced, xr.DataArray) + assert (da_sliced.x_osgb.values == np.arange(75, 105)).all() + assert (da_sliced.y_osgb.values == np.arange(-95, -65)).all() + # Data has been padded on right by 5 NaN pixels + assert da_sliced.isnull().sum() == 5*len(da_sliced.y_osgb) + + + location = Location(x=-90, y=-0, coordinate_system="osgb") + + # Select window which is partially outside the boundary of the data - padded on top + da_sliced = select_spatial_slice_pixels( + da, + location=Location(x=-90, y=95, coordinate_system="osgb"), + width_pixels=20, + height_pixels=20, + allow_partial_slice=True, + ) + + assert isinstance(da_sliced, xr.DataArray) + assert (da_sliced.x_osgb.values == np.arange(-100, -80)).all() + assert (da_sliced.y_osgb.values == np.arange(85, 105)).all() + # Data has been padded on top by 5 NaN pixels + assert da_sliced.isnull().sum() == 5*len(da_sliced.x_osgb) + + # Select window which is partially outside the boundary of the data - padded on bottom + da_sliced = select_spatial_slice_pixels( + da, + location=Location(x=-90, y=-95, coordinate_system="osgb"), + width_pixels=20, + height_pixels=20, + allow_partial_slice=True, + ) + + assert isinstance(da_sliced, xr.DataArray) + assert 
(da_sliced.x_osgb.values == np.arange(-100, -80)).all() + assert (da_sliced.y_osgb.values == np.arange(-105, -85)).all() + # Data has been padded on bottom by 5 NaN pixels + assert da_sliced.isnull().sum() == 5*len(da_sliced.x_osgb) + + # Select window which is partially outside the boundary of the data - padded right and bottom + da_sliced = select_spatial_slice_pixels( + da, + location=Location(x=90, y=-80, coordinate_system="osgb"), + width_pixels=50, + height_pixels=50, + allow_partial_slice=True, + ) + + assert isinstance(da_sliced, xr.DataArray) + assert (da_sliced.x_osgb.values == np.arange(65, 115)).all() + assert (da_sliced.y_osgb.values == np.arange(-105, -55)).all() + # Data has been padded on right by 15 pixels and bottom by 5 NaN pixels + assert da_sliced.isnull().sum() == 15*len(da_sliced.y_osgb) + 5*len(da_sliced.x_osgb) - 15*5