From 155bad3010ba35982f19bebae8d2ba499a4c5224 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Tue, 6 Aug 2024 11:50:51 +0000 Subject: [PATCH] clean up time selection tests --- ocf_data_sampler/select/select_time_slice.py | 21 +++-- tests/select/test_select_time_slice.py | 95 ++++++++++++++------ 2 files changed, 79 insertions(+), 37 deletions(-) diff --git a/ocf_data_sampler/select/select_time_slice.py b/ocf_data_sampler/select/select_time_slice.py index beb1d2c..2a02974 100644 --- a/ocf_data_sampler/select/select_time_slice.py +++ b/ocf_data_sampler/select/select_time_slice.py @@ -3,7 +3,6 @@ import numpy as np from datetime import timedelta -from typing import Optional def _sel_fillnan(ds, start_dt, end_dt, sample_period_duration: timedelta): @@ -16,8 +15,8 @@ def _sel_fillnan(ds, start_dt, end_dt, sample_period_duration: timedelta): return ds.reindex(time_utc=requested_times) -def _sel_default(ds, start_dt, end_dt, time_period_duration: timedelta): - # Note 'time_period_duration' is not used but need as its is needed so it's the same as _sel_fillnan +def _sel_default(ds, start_dt, end_dt, sample_period_duration: timedelta): + # Note 'sample_period_duration' is not used but need as its is needed so it's the same as _sel_fillnan return ds.sel(time_utc=slice(start_dt, end_dt)) @@ -30,11 +29,11 @@ def select_time_slice( ds: xr.Dataset | xr.DataArray, t0, sample_period_duration: timedelta, - history_duration: Optional[timedelta] = None, - forecast_duration: Optional[timedelta] = None, - interval_start: Optional[timedelta] = None, - interval_end: Optional[timedelta] = None, - fill_selection: Optional[bool] = False, + history_duration: timedelta | None = None, + forecast_duration: timedelta | None = None, + interval_start: timedelta | None = None, + interval_end: timedelta | None = None, + fill_selection: bool = False, max_steps_gap: int = 0, ): used_duration = history_duration is not None and forecast_duration is not None @@ -72,9 +71,9 @@ def select_time_slice_nwp( sample_period_duration: timedelta, history_duration: timedelta, forecast_duration: timedelta, - dropout_timedeltas: Optional[list[timedelta]] = None, - dropout_frac: Optional[float] = 0, - accum_channels: Optional[list[str]] = [], + dropout_timedeltas: list[timedelta] | None = None, + dropout_frac: float | None = 0, + accum_channels: list[str] = [], channel_dim_name: str = "channel", ): diff --git a/tests/select/test_select_time_slice.py b/tests/select/test_select_time_slice.py index fd7c025..eafa440 100644 --- a/tests/select/test_select_time_slice.py +++ b/tests/select/test_select_time_slice.py @@ -4,46 +4,89 @@ from datetime import timedelta import numpy as np import pandas as pd +import xarray as xr +import pytest -def test_select_time_slice(sat_zarr_path): - sat = open_sat_data(sat_zarr_path) - t0 = pd.Timestamp(sat.time_utc[3].values) +@pytest.fixture(scope="module") +def da_sat_like(): + # Create dummy data which looks like satellite data + x = np.arange(-100, 100) + y = np.arange(-100, 100) + datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq="5min") + da_sat = xr.DataArray( + np.random.normal(size=(len(datetimes), len(x), len(y))), + coords=dict( + time_utc=(["time_utc"], datetimes), + x_geostationary=(["x_geostationary"], x), + y_geostationary=(["y_geostationary"], y), + + ) + ) + return da_sat + + +def test_select_time_slice(da_sat_like): + t0 = pd.Timestamp("2024-01-02 12:00") + + forecast_duration = timedelta(minutes=0) + history_duration = timedelta(minutes=30) + freq = timedelta(minutes=5) + + # Expect to return these timestamps from the selection + expected_datetimes = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq) + + # Test history and forecast param usage + sat_sample = select_time_slice( + ds=da_sat_like, + t0=t0, + history_duration=history_duration, + forecast_duration=forecast_duration, + sample_period_duration=freq, + ) + + assert (sat_sample.time_utc == expected_datetimes).all() + + # Test interval param usage sat_sample = select_time_slice( - ds=sat, + ds=da_sat_like, t0=t0, - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=5), - forecast_duration=timedelta(minutes=5), + interval_start=-history_duration, + interval_end=forecast_duration, + sample_period_duration=freq, ) - assert len(sat_sample.time_utc) == 3 - assert (sat_sample.time_utc == pd.date_range( - t0 - timedelta(minutes=5), t0 + timedelta(minutes=5), freq=timedelta(minutes=5) - )).all() + assert (sat_sample.time_utc == expected_datetimes).all() -# TODO could to test with intervals, but we might want to remove this functionaility +def test_select_time_slice_out_of_bounds(da_sat_like): + t0 = pd.Timestamp("2024-01-02 00:30") -def test_select_time_slice_out_of_bounds(sat_zarr_path): - sat = open_sat_data(sat_zarr_path) - t0 = pd.Timestamp(sat.time_utc[-1].values) + forecast_duration = timedelta(minutes=0) + history_duration = timedelta(minutes=60) + freq = timedelta(minutes=5) + + # Expect to return these timestamps from the selection + expected_datetimes = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq) sat_sample = select_time_slice( - ds=sat, + ds=da_sat_like, t0=t0, - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=5), - forecast_duration=timedelta(minutes=5), - fill_selection=True, + history_duration=history_duration, + forecast_duration=forecast_duration, + sample_period_duration=freq, + fill_selection=True ) - assert len(sat_sample.time_utc) == 3 - assert (sat_sample.time_utc == pd.date_range( - t0 - timedelta(minutes=5), t0 + timedelta(minutes=5), freq=timedelta(minutes=5) - )).all() + assert (sat_sample.time_utc == expected_datetimes).all() + # Correct number of time steps are all NaN - sat_sel = sat_sample.isel(x_geostationary=0, y_geostationary=0, channel=0) - assert np.isnan(sat_sel.values).sum() == 1 + all_nan_space = sat_sample.isnull().all(dim=("x_geostationary", "y_geostationary")) + + # Check all the values before the first timestamp available in the data are NaN + assert all_nan_space.sel(time_utc=slice(None, "2024-01-01 23:55")).all(dim="time_utc") + + # check all the values after the first timestamp available in the data are not NaN + assert not all_nan_space.sel(time_utc=slice("2024-01-02 00:00", None)).any(dim="time_utc")