Skip to content

Commit

Permalink
Merge branch 'development' into spatial_slice_tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu authored Aug 6, 2024
2 parents fcfd9b6 + 804aeb0 commit ce0bd32
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 9 deletions.
10 changes: 5 additions & 5 deletions ocf_data_sampler/datasets/pvnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_dataset_dict(config: Configuration) -> dict[xr.DataArray, dict[xr.DataAr
datasets = {}

# We always assume GSP will be included
gsp_config = config.input_data.gsps
gsp_config = config.input_data.gsp

da_gsp = open_gsp(zarr_path=gsp_config.gsp_zarr_path)
da_gsp = add_t0_idx_and_sample_period_duration(da_gsp, gsp_config)
Expand Down Expand Up @@ -429,12 +429,12 @@ def process_and_combine_datasets(dataset_dict: dict, config: Configuration) -> N

def compute(xarray_dict: dict) -> dict:
"""Eagerly load a nested dictionary of xarray DataArrays"""
for k, v in d.items():
for k, v in xarray_dict.items():
if isinstance(v, dict):
d[k] = compute(v)
xarray_dict[k] = compute(v)
else:
d[k] = v.compute(scheduler="single-threaded")
return d
xarray_dict[k] = v.compute(scheduler="single-threaded")
return xarray_dict


def get_locations(gs_gsp: xr.DataArray) -> list[Location]:
Expand Down
19 changes: 15 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@



@pytest.fixture(scope="session")
def config_filename():
    """Return the path of the PVNet test configuration YAML.

    The path is built relative to the directory containing this file.
    """
    tests_dir = os.path.dirname(os.path.abspath(__file__))
    return f"{tests_dir}/test_data/pvnet_test_config.yaml"

@pytest.fixture(scope="session")
def sat_zarr_path():

Expand All @@ -28,9 +32,18 @@ def sat_zarr_path():
# Transpose to variables, time, y, x (just in case)
ds = ds.transpose("variable", "time", "y_geostationary", "x_geostationary")

# subtract 200,000 from x_geostationary, this is to make sure the fix index is within the satellite image
ds["x_geostationary"] = ds["x_geostationary"] - 200_000

# Add some NaNs
ds["data"].values[:, :, 0, 0] = np.nan

# make sure channel values are strings
ds["variable"] = ds["variable"].astype(str)

# add data attrs area
ds["data"].attrs["area"] = 'msg_seviri_rss_3km:\n description: MSG SEVIRI Rapid Scanning Service area definition with 3 km resolution\n projection:\n proj: geos\n lon_0: 9.5\n h: 35785831\n x_0: 0\n y_0: 0\n a: 6378169\n rf: 295.488065897014\n no_defs: null\n type: crs\n shape:\n height: 298\n width: 615\n area_extent:\n lower_left_xy: [28503.830075263977, 5090183.970808983]\n upper_right_xy: [-1816744.1169023514, 4196063.827395439]\n units: m\n'

# Specify chunking
ds = ds.chunk({"time": 10, "variable": -1, "y_geostationary": -1, "x_geostationary": -1})

Expand All @@ -44,10 +57,9 @@ def sat_zarr_path():

@pytest.fixture(scope="session")
def ds_nwp_ukv():
init_times = pd.date_range(start="2022-09-01 00:00", freq="180min", periods=24 * 7)
init_times = pd.date_range(start="2023-01-01 00:00", freq="180min", periods=24 * 7)
steps = pd.timedelta_range("0h", "10h", freq="1h")


# This is much faster:
x = np.linspace(-239_000, 857_000, 100)
y = np.linspace(-183_000, 1225_000, 200)
Expand Down Expand Up @@ -88,7 +100,7 @@ def nwp_ukv_zarr_path(ds_nwp_ukv):

@pytest.fixture(scope="session")
def ds_uk_gsp():
times = pd.date_range("2022-09-01 00:00", "2022-09-02 00:00", freq="30min")
times = pd.date_range("2023-01-01 00:00", "2023-01-02 00:00", freq="30min")
gsp_ids = np.arange(0, 318)
capacity = np.ones((len(times), len(gsp_ids)))
generation = np.random.uniform(0, 200, size=(len(times), len(gsp_ids)))
Expand All @@ -98,7 +110,6 @@ def ds_uk_gsp():
("gsp_id", gsp_ids),
)


da_cap = xr.DataArray(
capacity,
coords=coords,
Expand Down
43 changes: 43 additions & 0 deletions tests/dataset/test_pvnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
import tempfile

from ocf_data_sampler.datasets.pvnet import PVNetDataset
from ocf_datapipes.config.load import load_yaml_configuration
from ocf_datapipes.config.save import save_yaml_configuration


@pytest.fixture()
def pvnet_config_filename(config_filename, nwp_ukv_zarr_path, uk_gsp_zarr_path, sat_zarr_path):
    """Yield the path of a temporary PVNet config pointing at the test zarr stores.

    The base configuration is loaded from ``config_filename``; its NWP (ukv),
    satellite and GSP zarr paths are replaced with the paths supplied by the
    other fixtures, and the result is saved as a YAML file inside a temporary
    directory.
    """
    config = load_yaml_configuration(config_filename)

    # Point each input source at its test zarr store
    input_data = config.input_data
    input_data.gsp.gsp_zarr_path = uk_gsp_zarr_path
    input_data.satellite.satellite_zarr_path = sat_zarr_path
    input_data.nwp['ukv'].nwp_zarr_path = nwp_ukv_zarr_path

    with tempfile.TemporaryDirectory() as tmpdir:
        filename = f"{tmpdir}/configuration.yaml"
        save_yaml_configuration(config, filename)
        # Yield from inside the context manager so the temporary directory
        # is still present while the requesting test runs
        yield filename


def test_pvnet(pvnet_config_filename):
    """Smoke test: build a PVNetDataset from the test config and generate one sample."""
    dataset = PVNetDataset(pvnet_config_filename)

    # Report how many t0 times are available
    print(f"Found {len(dataset.valid_t0_times)} possible samples")

    # Resolve the first index pair into its t0 time and its location
    idx = 0
    t_index, loc_index = dataset.index_pairs[idx]
    t0 = dataset.valid_t0_times[t_index]
    location = dataset.locations[loc_index]

    # Print coords
    print(t0)
    print(location)

    # Generate the sample - not printed since it's big
    _ = dataset[idx]
43 changes: 43 additions & 0 deletions tests/test_data/pvnet_test_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Test configuration for PVNet.
# Every *_zarr_path below is a placeholder: the test fixture overwrites it
# with a real path before saving a copy of this config to a temporary file.
general:
  description: Test config for PVNet
  name: pvnet_test

input_data:
  default_history_minutes: 60
  default_forecast_minutes: 120

  gsp:
    gsp_zarr_path: set_in_temp_file
    history_minutes: 60
    forecast_minutes: 120
    time_resolution_minutes: 30
    dropout_timedeltas_minutes: null
    dropout_fraction: 0

  nwp:
    ukv:
      nwp_provider: ukv
      nwp_zarr_path: set_in_temp_file
      history_minutes: 30
      forecast_minutes: 15
      time_resolution_minutes: 60
      nwp_channels:
        - t # 2-metre temperature
      nwp_image_size_pixels_height: 2
      nwp_image_size_pixels_width: 2
      dropout_timedeltas_minutes: [-180]
      dropout_fraction: 1.0
      max_staleness_minutes: null

  satellite:
    satellite_zarr_path: set_in_temp_file
    history_minutes: 30
    forecast_minutes: 0
    live_delay_minutes: 0
    time_resolution_minutes: 15
    satellite_channels:
      - IR_016
    satellite_image_size_pixels_height: 2
    satellite_image_size_pixels_width: 2
    dropout_timedeltas_minutes: null
    dropout_fraction: 0

0 comments on commit ce0bd32

Please sign in to comment.