Skip to content

Commit

Permalink
Merge pull request #9 from openclimatefix/spatial_slice_tests
Browse files Browse the repository at this point in the history
Spatial slice tests
  • Loading branch information
dfulu authored Aug 6, 2024
2 parents 804aeb0 + ce0bd32 commit c292728
Show file tree
Hide file tree
Showing 13 changed files with 383 additions and 262 deletions.
12 changes: 4 additions & 8 deletions ocf_data_sampler/load/gsp.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
from pathlib import Path
import pkg_resources

import xarray as xr

import pandas as pd
import xarray as xr


def open_gsp(zarr_path: str | Path):
def open_gsp(zarr_path: str | Path) -> xr.DataArray:

# Load GSP generation xr.Dataset
ds = xr.open_zarr(zarr_path)

# Rename to standard time name
ds = ds.rename({
"datetime_gmt": "time_utc",
})
ds = ds.rename({"datetime_gmt": "time_utc"})

# Load UK GSP locations
df_gsp_loc = pd.read_csv(
Expand All @@ -31,7 +28,6 @@ def open_gsp(zarr_path: str | Path):

)

# Return dataarray
return ds["generation_mw"]
return ds.generation_mw


2 changes: 0 additions & 2 deletions ocf_data_sampler/load/nwp/nwp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from pathlib import Path

import xarray as xr

from ocf_data_sampler.load.nwp.providers.ukv import open_ukv
from ocf_data_sampler.load.nwp.providers.ecmwf import open_ifs

Expand Down
20 changes: 13 additions & 7 deletions ocf_data_sampler/load/nwp/providers/ecmwf.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""ECMWF provider loaders"""
from pathlib import Path
import xarray as xr
from ocf_data_sampler.load.nwp.providers.utils import (
open_zarr_paths, check_time_unique_increasing
)

from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing

def open_ifs(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
"""
Expand All @@ -19,13 +17,21 @@ def open_ifs(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
# Open the data
ds = open_zarr_paths(zarr_path)

ds = ds.transpose("init_time", "step", "variable", "latitude", "longitude")
# Rename
ds = ds.rename(
{
"init_time": "init_time_utc",
"variable": "channel",
}
)
# Sanity checks.
check_time_unique_increasing(ds)

# Check the timestamps are unique and increasing
check_time_unique_increasing(ds.init_time_utc)

# Make sure the spatial coords are in increasing order
ds = make_spatial_coords_increasing(ds, x_coord="longitude", y_coord="latitude")

ds = ds.transpose("init_time_utc", "step", "channel", "longitude", "latitude")

# TODO: should we control the dtype of the DataArray?
return ds.ECMWF_UK
22 changes: 13 additions & 9 deletions ocf_data_sampler/load/nwp/providers/ukv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@

from pathlib import Path


from ocf_data_sampler.load.nwp.providers.utils import (
open_zarr_paths, check_time_unique_increasing
)
from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing


def open_ukv(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
Expand All @@ -23,19 +21,25 @@ def open_ukv(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray:
# Open the data
ds = open_zarr_paths(zarr_path)

ds = ds.transpose("init_time", "step", "variable", "y", "x")
# Rename
ds = ds.rename(
{
"init_time": "init_time_utc",
"variable": "channel",
"y": "y_osgb",
"x": "x_osgb",
"y": "y_osgb",
}
)

# Sanity checks.
assert ds.y_osgb[0] > ds.y_osgb[1], "UKV must run from top-to-bottom."
check_time_unique_increasing(ds)
# Check the timestamps are unique and increasing
check_time_unique_increasing(ds.init_time_utc)

# Make sure the spatial coords are in increasing order
ds = make_spatial_coords_increasing(ds, x_coord="x_osgb", y_coord="y_osgb")

ds = ds.transpose("init_time_utc", "step", "channel", "x_osgb", "y_osgb")

# TODO: should we control the dtype of the DataArray?
return ds.UKV


8 changes: 0 additions & 8 deletions ocf_data_sampler/load/nwp/providers/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from pathlib import Path
import xarray as xr
import pandas as pd


def open_zarr_paths(
Expand Down Expand Up @@ -33,10 +32,3 @@ def open_zarr_paths(
chunks="auto",
)
return ds


def check_time_unique_increasing(ds: xr.Dataset) -> None:
"""Check that the time dimension is unique and increasing"""
time = pd.DatetimeIndex(ds.init_time_utc)
assert time.is_unique
assert time.is_monotonic_increasing
49 changes: 16 additions & 33 deletions ocf_data_sampler/load/satellite.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
"""Satellite loader"""

import logging
import subprocess
from pathlib import Path

import pandas as pd
import xarray as xr


_log = logging.getLogger(__name__)
from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing


def _get_single_sat_data(zarr_path: Path | str) -> xr.DataArray:
Expand Down Expand Up @@ -73,46 +70,32 @@ def open_sat_data(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArra
```
"""

# Open the data
if isinstance(zarr_path, (list, tuple)):
message_files_list = "\n - " + "\n - ".join([str(s) for s in zarr_path])
_log.info(f"Opening satellite data: {message_files_list}")
ds = xr.combine_nested(
[_get_single_sat_data(path) for path in zarr_path],
concat_dim="time",
combine_attrs="override",
join="override",
)
else:
_log.info(f"Opening satellite data: {zarr_path}")
ds = _get_single_sat_data(zarr_path)

if "variable" in ds.coords:
ds = ds.rename({"variable": "channel"})
if "channels" in ds.coords:
ds = ds.rename({"channels": "channel"})

# Rename coords to be more explicit about exactly what some coordinates hold:
# Note that `rename` renames *both* the coordinates and dimensions, and keeps
# the connection between the dims and coordinates, so we don't have to manually
# use `data_array.set_index()`.
ds = ds.rename({"time": "time_utc"})

# Flip coordinates to top-left first
if ds.y_geostationary[0] < ds.y_geostationary[-1]:
ds = ds.isel(y_geostationary=slice(None, None, -1))
if ds.x_geostationary[0] > ds.x_geostationary[-1]:
ds = ds.isel(x_geostationary=slice(None, None, -1))
# Rename
ds = ds.rename(
{
"variable": "channel",
"time": "time_utc",
}
)

# Ensure the y and x coords are in the right order (top-left first):
assert ds.y_geostationary[0] > ds.y_geostationary[-1]
assert ds.x_geostationary[0] < ds.x_geostationary[-1]
# Check the timestamps are unique and increasing
check_time_unique_increasing(ds.time_utc)

ds = ds.transpose("time_utc", "channel", "y_geostationary", "x_geostationary")
# Make sure the spatial coords are in increasing order
ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary")

# Sanity checks!
datetime_index = pd.DatetimeIndex(ds.time_utc)
assert datetime_index.is_unique
assert datetime_index.is_monotonic_increasing
ds = ds.transpose("time_utc", "channel", "x_geostationary", "y_geostationary")

# Return DataArray
return ds["data"]
# TODO: should we control the dtype of the DataArray?
return ds.data
29 changes: 29 additions & 0 deletions ocf_data_sampler/load/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import xarray as xr
import pandas as pd

def check_time_unique_increasing(datetimes) -> None:
    """Check that the given timestamps are unique and strictly increasing.

    Args:
        datetimes: Sequence of datetime-like values (anything accepted by
            ``pd.DatetimeIndex``, e.g. an xarray time coordinate).

    Raises:
        ValueError: If the timestamps contain duplicates or are not
            monotonically increasing.
    """
    time = pd.DatetimeIndex(datetimes)
    # Raise explicitly rather than `assert` so the validation still runs
    # under `python -O`, where asserts are stripped.
    if not time.is_unique:
        raise ValueError("Time dimension contains duplicate timestamps")
    if not time.is_monotonic_increasing:
        raise ValueError("Time dimension is not monotonically increasing")

def make_spatial_coords_increasing(ds: xr.Dataset, x_coord: str, y_coord: str) -> xr.Dataset:
    """Make sure the spatial coordinates are in increasing order.

    Args:
        ds: Xarray Dataset with 1-D spatial coordinates
        x_coord: Name of the x coordinate
        y_coord: Name of the y coordinate

    Returns:
        The dataset with both spatial coordinates strictly increasing.

    Raises:
        ValueError: If either coordinate is not strictly monotonic, since a
            simple reversal cannot make such a coordinate increasing.
    """

    # Reverse either coordinate that currently runs high-to-low
    if ds[x_coord][0] > ds[x_coord][-1]:
        ds = ds.isel({x_coord: slice(None, None, -1)})
    if ds[y_coord][0] > ds[y_coord][-1]:
        ds = ds.isel({y_coord: slice(None, None, -1)})

    # A reversal only fixes strictly decreasing coords; anything still not
    # strictly increasing here was non-monotonic input. Raise explicitly
    # rather than `assert` so the check survives `python -O`.
    if not (ds[x_coord].diff(dim=x_coord) > 0).all():
        raise ValueError(f"Coordinate '{x_coord}' is not strictly monotonic")
    if not (ds[y_coord].diff(dim=y_coord) > 0).all():
        raise ValueError(f"Coordinate '{y_coord}' is not strictly monotonic")

    return ds
6 changes: 6 additions & 0 deletions ocf_data_sampler/select/find_contiguous_t0_time_periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,12 @@ def intersection_of_2_dataframes_of_periods(a: pd.DataFrame, b: pd.DataFrame) ->
# b: |--| |---| |---| |------| |-|
# In all five, `a` must always start before `b` ends,
# and `a` must always end after `b` starts:

# TODO: <= and >= because we should allow overlap time periods of length 1. e.g.
# a: |----| or |---|
# b: |--| |---|
# These aren't allowed if we use < and >.

overlapping_periods = b[(a_period.start_dt < b.end_dt) & (a_period.end_dt > b.start_dt)]

# There are two ways in which two periods may *not* overlap:
Expand Down
Loading

0 comments on commit c292728

Please sign in to comment.