Skip to content

Commit

Permalink
clean up time selection tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu committed Aug 6, 2024
1 parent 52ffb47 commit 155bad3
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 37 deletions.
21 changes: 10 additions & 11 deletions ocf_data_sampler/select/select_time_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np

from datetime import timedelta
from typing import Optional


def _sel_fillnan(ds, start_dt, end_dt, sample_period_duration: timedelta):
Expand All @@ -16,8 +15,8 @@ def _sel_fillnan(ds, start_dt, end_dt, sample_period_duration: timedelta):
return ds.reindex(time_utc=requested_times)


def _sel_default(ds, start_dt, end_dt, time_period_duration: timedelta):
# Note: 'time_period_duration' is not used, but it is needed so that this function takes the same arguments as _sel_fillnan
def _sel_default(ds, start_dt, end_dt, sample_period_duration: timedelta):
# Note: 'sample_period_duration' is not used, but it is needed so that this function takes the same arguments as _sel_fillnan
return ds.sel(time_utc=slice(start_dt, end_dt))


Expand All @@ -30,11 +29,11 @@ def select_time_slice(
ds: xr.Dataset | xr.DataArray,
t0,
sample_period_duration: timedelta,
history_duration: Optional[timedelta] = None,
forecast_duration: Optional[timedelta] = None,
interval_start: Optional[timedelta] = None,
interval_end: Optional[timedelta] = None,
fill_selection: Optional[bool] = False,
history_duration: timedelta | None = None,
forecast_duration: timedelta | None = None,
interval_start: timedelta | None = None,
interval_end: timedelta | None = None,
fill_selection: bool = False,
max_steps_gap: int = 0,
):
used_duration = history_duration is not None and forecast_duration is not None
Expand Down Expand Up @@ -72,9 +71,9 @@ def select_time_slice_nwp(
sample_period_duration: timedelta,
history_duration: timedelta,
forecast_duration: timedelta,
dropout_timedeltas: Optional[list[timedelta]] = None,
dropout_frac: Optional[float] = 0,
accum_channels: Optional[list[str]] = [],
dropout_timedeltas: list[timedelta] | None = None,
dropout_frac: float | None = 0,
accum_channels: list[str] = [],
channel_dim_name: str = "channel",
):

Expand Down
95 changes: 69 additions & 26 deletions tests/select/test_select_time_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,46 +4,89 @@
from datetime import timedelta
import numpy as np
import pandas as pd
import xarray as xr
import pytest


def test_select_time_slice(sat_zarr_path):
sat = open_sat_data(sat_zarr_path)
t0 = pd.Timestamp(sat.time_utc[3].values)
@pytest.fixture(scope="module")
def da_sat_like():
# Create dummy data which looks like satellite data
x = np.arange(-100, 100)
y = np.arange(-100, 100)
datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq="5min")

da_sat = xr.DataArray(
np.random.normal(size=(len(datetimes), len(x), len(y))),
coords=dict(
time_utc=(["time_utc"], datetimes),
x_geostationary=(["x_geostationary"], x),
y_geostationary=(["y_geostationary"], y),

)
)
return da_sat


def test_select_time_slice(da_sat_like):
t0 = pd.Timestamp("2024-01-02 12:00")

forecast_duration = timedelta(minutes=0)
history_duration = timedelta(minutes=30)
freq = timedelta(minutes=5)

# Expect to return these timestamps from the selection
expected_datetimes = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq)

# Test history and forecast param usage
sat_sample = select_time_slice(
ds=da_sat_like,
t0=t0,
history_duration=history_duration,
forecast_duration=forecast_duration,
sample_period_duration=freq,
)

assert (sat_sample.time_utc == expected_datetimes).all()

# Test interval param usage
sat_sample = select_time_slice(
ds=sat,
ds=da_sat_like,
t0=t0,
sample_period_duration=timedelta(minutes=5),
history_duration=timedelta(minutes=5),
forecast_duration=timedelta(minutes=5),
interval_start=-history_duration,
interval_end=forecast_duration,
sample_period_duration=freq,
)

assert len(sat_sample.time_utc) == 3
assert (sat_sample.time_utc == pd.date_range(
t0 - timedelta(minutes=5), t0 + timedelta(minutes=5), freq=timedelta(minutes=5)
)).all()
assert (sat_sample.time_utc == expected_datetimes).all()


# TODO: could test with intervals, but we might want to remove this functionality
def test_select_time_slice_out_of_bounds(da_sat_like):

t0 = pd.Timestamp("2024-01-02 00:30")

def test_select_time_slice_out_of_bounds(sat_zarr_path):
sat = open_sat_data(sat_zarr_path)
t0 = pd.Timestamp(sat.time_utc[-1].values)
forecast_duration = timedelta(minutes=0)
history_duration = timedelta(minutes=60)
freq = timedelta(minutes=5)

# Expect to return these timestamps from the selection
expected_datetimes = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq)

sat_sample = select_time_slice(
ds=sat,
ds=da_sat_like,
t0=t0,
sample_period_duration=timedelta(minutes=5),
history_duration=timedelta(minutes=5),
forecast_duration=timedelta(minutes=5),
fill_selection=True,
history_duration=history_duration,
forecast_duration=forecast_duration,
sample_period_duration=freq,
fill_selection=True
)

assert len(sat_sample.time_utc) == 3
assert (sat_sample.time_utc == pd.date_range(
t0 - timedelta(minutes=5), t0 + timedelta(minutes=5), freq=timedelta(minutes=5)
)).all()
assert (sat_sample.time_utc == expected_datetimes).all()

# Correct number of time steps are all NaN
sat_sel = sat_sample.isel(x_geostationary=0, y_geostationary=0, channel=0)
assert np.isnan(sat_sel.values).sum() == 1
all_nan_space = sat_sample.isnull().all(dim=("x_geostationary", "y_geostationary"))

# Check all the values before the first timestamp available in the data are NaN
assert all_nan_space.sel(time_utc=slice(None, "2024-01-01 23:55")).all(dim="time_utc")

# Check all the values from the first timestamp available in the data onwards are not NaN
assert not all_nan_space.sel(time_utc=slice("2024-01-02 00:00", None)).any(dim="time_utc")

0 comments on commit 155bad3

Please sign in to comment.