Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Select time test #5

Merged
merged 8 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 44 additions & 30 deletions ocf_data_sampler/select/select_time_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,53 @@
import pandas as pd
import numpy as np

from datetime import timedelta
from typing import Optional
from datetime import timedelta, datetime


def _sel_fillnan(
    da: xr.DataArray,
    start_dt: datetime,
    end_dt: datetime,
    sample_period_duration: timedelta,
) -> xr.DataArray:
    """Reindex `da` onto a regular time grid between two timestamps.

    The grid is built with `pd.date_range` from `start_dt` to `end_dt` using
    `sample_period_duration` as the step. `da` is then reindexed along its
    `time_utc` coordinate, so grid times that are absent from `da` appear in
    the result as missing (NaN) values rather than being dropped.

    Args:
        da: Data to select from; must have a `time_utc` coordinate.
        start_dt: First timestamp of the requested grid.
        end_dt: Upper bound of the requested grid.
        sample_period_duration: Spacing between consecutive requested times.

    Returns:
        `da` reindexed so that `time_utc` equals the requested grid.
    """
    time_grid = pd.date_range(
        start=start_dt,
        end=end_dt,
        freq=sample_period_duration,
    )
    return da.reindex({"time_utc": time_grid})


def _sel_default(
    da: xr.DataArray,
    start_dt: datetime,
    end_dt: datetime,
    sample_period_duration: timedelta,
) -> xr.DataArray:
    """Slice `da` along `time_utc` between two timestamps, without gap filling.

    Only times already present in `da` are returned; nothing is added for
    missing steps.

    Args:
        da: Data to select from; must have a `time_utc` coordinate.
        start_dt: Lower bound of the label-based slice.
        end_dt: Upper bound of the label-based slice.
        sample_period_duration: Not used by this function.
            NOTE(review): presumably kept so all `_sel_*` helpers share one
            call signature - confirm against the caller.

    Returns:
        The portion of `da` whose `time_utc` labels fall within the slice.
    """
    time_window = slice(start_dt, end_dt)
    return da.sel({"time_utc": time_window})


# TODO either implement this or remove it, which would tidy up the code
def _sel_fillinterp(
    da: xr.DataArray,
    start_dt: datetime,
    end_dt: datetime,
    sample_period_duration: timedelta,
) -> xr.DataArray:
    """Select a time slice, filling missing times by linear interpolation.

    This selection mode is not implemented yet.

    Args:
        da: Data to select from.
        start_dt: Start of the requested time range.
        end_dt: End of the requested time range.
        sample_period_duration: Spacing between consecutive requested times.

    Raises:
        NotImplementedError: Always, until interpolation is implemented.
    """
    # `NotImplemented` is a sentinel reserved for binary-operator methods;
    # returning it here would hand callers a non-DataArray value that looks
    # like a successful selection. Fail loudly instead.
    raise NotImplementedError(
        "Selecting a time slice with interpolated gap filling is not implemented"
    )


def select_time_slice(
ds: xr.Dataset | xr.DataArray,
t0,
t0: datetime,
sample_period_duration: timedelta,
history_duration: Optional[timedelta] = None,
forecast_duration: Optional[timedelta] = None,
interval_start: Optional[timedelta] = None,
interval_end: Optional[timedelta] = None,
fill_selection: Optional[bool] = False,
history_duration: timedelta | None = None,
forecast_duration: timedelta | None = None,
interval_start: timedelta | None = None,
interval_end: timedelta | None = None,
fill_selection: bool = False,
max_steps_gap: int = 0,
):
"""Select a time slice from a Dataset or DataArray."""
used_duration = history_duration is not None and forecast_duration is not None
used_intervals = interval_start is not None and interval_end is not None
assert used_duration ^ used_intervals, "Either durations, or intervals must be supplied"
Expand All @@ -30,38 +61,21 @@ def select_time_slice(
interval_start = np.timedelta64(interval_start)
interval_end = np.timedelta64(interval_end)

def _sel_fillnan(ds, start_dt, end_dt):
requested_times = pd.date_range(
start_dt,
end_dt,
freq=sample_period_duration,
)
# Missing time indexes are returned with all NaN values
return ds.reindex(time_utc=requested_times)

def _sel_default(ds, start_dt, end_dt):
return ds.sel(time_utc=slice(start_dt, end_dt))

def _sel_fillinterp(ds, start_dt, end_dt):
return NotImplemented


if fill_selection and max_steps_gap == 0:
_sel = _sel_fillnan
elif fill_selection and max_steps_gap > 0:
_sel = _sel_fillinterp
else:
_sel = _sel_default


t0_datetime_utc = pd.Timestamp(t0)
start_dt = t0_datetime_utc + interval_start
end_dt = t0_datetime_utc + interval_end

start_dt = start_dt.ceil(sample_period_duration)
end_dt = end_dt.ceil(sample_period_duration)

return _sel(ds, start_dt, end_dt)
return _sel(ds, start_dt, end_dt, sample_period_duration)


def select_time_slice_nwp(
Expand All @@ -70,9 +84,9 @@ def select_time_slice_nwp(
sample_period_duration: timedelta,
history_duration: timedelta,
forecast_duration: timedelta,
dropout_timedeltas: Optional[list[timedelta]] = None,
dropout_frac: Optional[float] = 0,
accum_channels: Optional[list[str]] = [],
dropout_timedeltas: list[timedelta] | None = None,
dropout_frac: float | None = 0,
accum_channels: list[str] = [],
channel_dim_name: str = "channel",
):

Expand Down Expand Up @@ -139,7 +153,7 @@ def select_time_slice_nwp(
unique_init_times = np.unique(selected_init_times)
# - find the min and max steps we slice over. Max is extended due to diff
min_step = min(steps)
max_step = max(steps) + (ds.step[1] - ds.step[0])
max_step = max(steps) + sample_period_duration

xr_min = ds.sel(
{
Expand Down
Loading
Loading