Skip to content

Commit

Permalink
add spatial slice fixes and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu committed Aug 6, 2024
1 parent 092147d commit fcfd9b6
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 60 deletions.
88 changes: 47 additions & 41 deletions ocf_data_sampler/select/select_spatial_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def convert_coords_to_match_xarray(

return x, y


# TODO: This function and _get_idx_of_pixel_closest_to_poi_geostationary() should not be separate
# We should combine them, and consider making a Coord class to help with this
def _get_idx_of_pixel_closest_to_poi(
da: xr.DataArray,
location: Location,
Expand Down Expand Up @@ -131,9 +132,8 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(
da[x_dim].values, center_geostationary.x, assume_ascending=True
)

# y_geostationary is in descending order:
y_index_at_center = searchsorted(
da[y_dim].values, center_geostationary.y, assume_ascending=False
da[y_dim].values, center_geostationary.y, assume_ascending=True
)

return Location(x=x_index_at_center, y=y_index_at_center, coordinate_system="idx")
Expand All @@ -142,36 +142,41 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(
# ---------------------------- sub-functions for slicing ----------------------------


def _select_partial_slice_along_dim(da, dim, start_idx, stop_idx, pad_before, pad_after):
    """Slice `da` along one dimension, padding the coordinate where the window leaves the data.

    Args:
        da: The xarray object to slice
        dim: Name of the dimension to slice along
        start_idx: Index of the first pixel of the window (may be negative if `pad_before > 0`)
        stop_idx: Index one past the last pixel of the window (used as a slice stop)
        pad_before: Number of pixels the window extends before the first coordinate value
        pad_after: Number of pixels the window extends past the last coordinate value

    Returns:
        The sliced object. Padded positions are created by `reindex` onto extrapolated
        coordinate values, so they hold the reindex fill value (NaN by default).
    """
    coord_values = da[dim].values
    # The median spacing is used to extrapolate the coordinate beyond the data.
    # NOTE(review): this presumes a regularly spaced coordinate - confirm for all inputs
    step = np.median(np.diff(coord_values))

    # Window starts before the data: keep everything up to `stop_idx` and prepend new coords
    if pad_before > 0:
        padded_coords = np.concatenate(
            [
                coord_values[0] + np.arange(-pad_before, 0) * step,
                coord_values[0:stop_idx],
            ]
        )
        return da.isel({dim: slice(0, stop_idx)}).reindex({dim: padded_coords})

    # Window ends after the data: keep everything from `start_idx` and append new coords
    if pad_after > 0:
        padded_coords = np.concatenate(
            [
                coord_values[start_idx:],
                coord_values[-1] + np.arange(1, pad_after + 1) * step,
            ]
        )
        return da.isel({dim: slice(start_idx, None)}).reindex({dim: padded_coords})

    # Window lies fully inside the data along this dimension
    return da.isel({dim: slice(start_idx, stop_idx)})


def _select_partial_spatial_slice_pixels(
    da,
    left_idx,
    right_idx,
    bottom_idx,
    top_idx,
    left_pad_pixels,
    right_pad_pixels,
    bottom_pad_pixels,
    top_pad_pixels,
    x_dim,
    y_dim,
):
    """Return spatial window of given pixel size when window partially overlaps input data

    Args:
        da: The xarray object to slice
        left_idx: Index of the first window pixel along `x_dim`
        right_idx: Slice-stop index of the window along `x_dim`
        bottom_idx: Index of the first window pixel along `y_dim`
        top_idx: Slice-stop index of the window along `y_dim`
        left_pad_pixels: Number of pixels to pad before the first x coordinate
        right_pad_pixels: Number of pixels to pad after the last x coordinate
        bottom_pad_pixels: Number of pixels to pad before the first y coordinate
        top_pad_pixels: Number of pixels to pad after the last y coordinate
        x_dim: Name of the x dimension
        y_dim: Name of the y dimension

    Returns:
        The window, with positions outside the input data filled by `reindex`
    """

    # We should never be padding on both sides of a window. This would mean our desired window is
    # larger than the size of the input data
    assert left_pad_pixels == 0 or right_pad_pixels == 0
    assert bottom_pad_pixels == 0 or top_pad_pixels == 0

    # Both axes go through the same helper so the x and y handling cannot drift apart. In
    # particular the y axis is always cropped with the y indices (`bottom_idx`/`top_idx`), never
    # with `left_idx`, which would drop valid rows whenever the two indices differ.
    da = _select_partial_slice_along_dim(
        da, x_dim, left_idx, right_idx, left_pad_pixels, right_pad_pixels
    )
    da = _select_partial_slice_along_dim(
        da, y_dim, bottom_idx, top_idx, bottom_pad_pixels, top_pad_pixels
    )

    return da

Expand Down Expand Up @@ -240,56 +247,55 @@ def _select_spatial_slice_pixels(

left_idx = int(center_idx.x - half_width)
right_idx = int(center_idx.x + half_width)
top_idx = int(center_idx.y - half_height)
bottom_idx = int(center_idx.y + half_height)
bottom_idx = int(center_idx.y - half_height)
top_idx = int(center_idx.y + half_height)

data_width_pixels = len(da[x_dim])
data_height_pixels = len(da[y_dim])

left_pad_required = left_idx < 0
right_pad_required = right_idx >= data_width_pixels
top_pad_required = top_idx < 0
bottom_pad_required = bottom_idx >= data_height_pixels
right_pad_required = right_idx > data_width_pixels
bottom_pad_required = bottom_idx < 0
top_pad_required = top_idx > data_height_pixels

pad_required = any(
[left_pad_required, right_pad_required, top_pad_required, bottom_pad_required]
)
pad_required = left_pad_required | right_pad_required | bottom_pad_required | top_pad_required

if pad_required:
if allow_partial_slice:

left_pad_pixels = (-left_idx) if left_pad_required else 0
right_pad_pixels = (right_idx - (data_width_pixels - 1)) if right_pad_required else 0
top_pad_pixels = (-top_idx) if top_pad_required else 0
bottom_pad_pixels = (
(bottom_idx - (data_height_pixels - 1)) if bottom_pad_required else 0
)
right_pad_pixels = (right_idx - data_width_pixels) if right_pad_required else 0

bottom_pad_pixels = (-bottom_idx) if bottom_pad_required else 0
top_pad_pixels = (top_idx - data_height_pixels) if top_pad_required else 0

da = _slice_patial_spatial_pixel_window_from_xarray(

da = _select_partial_spatial_slice_pixels(
da,
left_idx,
right_idx,
top_idx,
bottom_idx,
top_idx,
left_pad_pixels,
right_pad_pixels,
top_pad_pixels,
bottom_pad_pixels,
top_pad_pixels,
x_dim,
y_dim,
)
else:
raise ValueError(
f"Window for location {center_idx} not available. Missing (left, right, top, "
f"bottom) pixels = ({left_pad_required}, {right_pad_required}, "
f"{top_pad_required}, {bottom_pad_required}). "
f"Window for location {center_idx} not available. Missing (left, right, bottom, "
f"top) pixels = ({left_pad_required}, {right_pad_required}, "
f"{bottom_pad_required}, {top_pad_required}). "
f"You may wish to set `allow_partial_slice=True`"
)

else:
da = da.isel(
{
x_dim: slice(left_idx, right_idx),
y_dim: slice(top_idx, bottom_idx),
y_dim: slice(bottom_idx, top_idx),
}
)

Expand All @@ -299,7 +305,7 @@ def _select_spatial_slice_pixels(
)
assert len(da[y_dim]) == height_pixels, (
f"Expected y-dim len {height_pixels} got {len(da[y_dim])} "
f"for location {center_idx} for slice {top_idx}:{bottom_idx}"
f"for location {center_idx} for slice {bottom_idx}:{top_idx}"
)

return da
Expand Down
124 changes: 105 additions & 19 deletions tests/select/test_select_spatial_slice.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import numpy as np
import xarray as xr
from ocf_datapipes.utils import Location
import pytest

from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels
from ocf_data_sampler.select.select_spatial_slice import (
select_spatial_slice_pixels, _get_idx_of_pixel_closest_to_poi
)


def test_select_spatial_slice_pixels():
@pytest.fixture(scope="module")
def da():
# Create dummy data
x = np.arange(100)
y = np.arange(100)[::-1]
x = np.arange(-100, 100)
y = np.arange(-100, 100)

da = xr.DataArray(
np.random.normal(size=(len(x), len(y))),
Expand All @@ -17,52 +20,135 @@ def test_select_spatial_slice_pixels():
y_osgb=(["y_osgb"], y),
)
)
return da


def test_get_idx_of_pixel_closest_to_poi(da):
    """The closest-pixel lookup returns an index-space Location for an OSGB point."""

    # The fixture coordinates run from -100 upwards in steps of 1, so the coordinate value 10
    # sits at position 110 on both axes
    point_of_interest = Location(x=10, y=10, coordinate_system="osgb")

    idx_location = _get_idx_of_pixel_closest_to_poi(da, location=point_of_interest)

    assert idx_location.coordinate_system == "idx"
    assert idx_location.x == 110
    assert idx_location.y == 110

location = Location(x=10, y=10, coordinate_system="osgb")

# Select window which lies within data


def test_select_spatial_slice_pixels(da):
    """Check square windows inside the data, flush with its edge, and padded past each edge."""

    # The fixture coordinates run from -100 to 99 on both axes in steps of 1.
    # Each case is:
    #   (centre x, centre y, window size in pixels,
    #    first expected x coord, first expected y coord,
    #    number of padded pixels along x, number of padded pixels along y)
    cases = [
        # Window which lies within x-y bounds of the data
        (-90, -80, 10, -95, -85, 0, 0),
        # Window where the edge of the window lies right on the edge of the data
        (-90, -80, 20, -100, -90, 0, 0),
        # Window partially outside the boundary of the data - padded on left
        (-90, -80, 30, -105, -95, 5, 0),
        # Window partially outside the boundary of the data - padded on right
        (90, -80, 30, 75, -95, 5, 0),
        # Window partially outside the boundary of the data - padded on top
        (-90, 95, 20, -100, 85, 0, 5),
        # Window partially outside the boundary of the data - padded on bottom
        (-90, -95, 20, -100, -105, 0, 5),
        # Window partially outside the boundary of the data - padded right and bottom
        (90, -80, 50, 65, -105, 15, 5),
    ]

    for x, y, size, x_start, y_start, x_pad, y_pad in cases:
        case = f"location=({x}, {y}), size={size}"

        da_sliced = select_spatial_slice_pixels(
            da,
            location=Location(x=x, y=y, coordinate_system="osgb"),
            width_pixels=size,
            height_pixels=size,
            allow_partial_slice=True,
        )

        assert isinstance(da_sliced, xr.DataArray), case
        assert (da_sliced.x_osgb.values == np.arange(x_start, x_start + size)).all(), case
        assert (da_sliced.y_osgb.values == np.arange(y_start, y_start + size)).all(), case

        # Padded pixels are NaN: full columns for x padding plus full rows for y padding, minus
        # the corner that both strips share. With no padding there must be no NaNs at all
        expected_nans = (
            x_pad * len(da_sliced.y_osgb) + y_pad * len(da_sliced.x_osgb) - x_pad * y_pad
        )
        assert int(da_sliced.isnull().sum()) == expected_nans, case



0 comments on commit fcfd9b6

Please sign in to comment.