Commit
Merge pull request #140 from openclimatefix/issue/datetime-bug
Fix: sun and datetimes dataloader issue
peterdudfield authored Jan 22, 2025
2 parents bd9fca2 + f2a0b50 commit e751276
Showing 2 changed files with 89 additions and 17 deletions.
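What changed, in brief: process_and_combine_site_sample_dict previously built the datetime and sun-position features into a separate xr.Dataset and merged it into the sample with xr.merge, which made each feature a standalone data variable. The fix attaches the features as coordinates on the existing site__time_utc dimension via assign_coords, and changes key_prefix from "site" to "site_" so the resulting names (e.g. site__solar_azimuth) follow the same double-underscore convention as site__time_utc. The new tests check that these coordinates survive a NetCDF round-trip and a pass through a PyTorch DataLoader. Below is a minimal sketch of the before/after pattern; the feature values are illustrative, and it assumes (as the diff implies) that the feature helpers return a dict mapping feature name to a 1-D numpy array aligned with the time axis:

    import numpy as np
    import pandas as pd
    import xarray as xr

    times = pd.date_range("2025-01-22T00:00", periods=4, freq="30min")
    ds = xr.Dataset(coords={"site__time_utc": times})

    # Stand-in for what make_datetime_numpy_dict returns (illustrative values)
    features = {
        "site__time_sin": np.sin(2 * np.pi * times.hour / 24),
        "site__time_cos": np.cos(2 * np.pi * times.hour / 24),
    }

    # Before: features merged in as standalone data variables
    # ds = xr.merge([ds, xr.Dataset(features, coords={"site__time_utc": times})])

    # After: features attached as coordinates along the existing time dimension
    ds = ds.assign_coords({k: ("site__time_utc", v) for k, v in features.items()})
    assert "site__time_sin" in ds.coords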
19 changes: 10 additions & 9 deletions ocf_data_sampler/torch_datasets/datasets/site.py
@@ -241,29 +241,30 @@ def process_and_combine_site_sample_dict(

    # add datetime features
    datetimes = pd.DatetimeIndex(combined_sample_dataset.site__time_utc.values)
-    datetime_features = make_datetime_numpy_dict(datetimes=datetimes, key_prefix="site")
-    datetime_features_xr = xr.Dataset(datetime_features, coords={"site__time_utc": datetimes})
-    combined_sample_dataset = xr.merge([combined_sample_dataset, datetime_features_xr])
+    datetime_features = make_datetime_numpy_dict(datetimes=datetimes, key_prefix="site_")
+    combined_sample_dataset = combined_sample_dataset.assign_coords(
+        {k: ("site__time_utc", v) for k, v in datetime_features.items()}
+    )

    # add sun features
    sun_position_features = make_sun_position_numpy_sample(
        datetimes=datetimes,
        lon=combined_sample_dataset.site__longitude.values,
        lat=combined_sample_dataset.site__latitude.values,
-        key_prefix="site",
+        key_prefix="site_",
    )
-    sun_position_features_xr = xr.Dataset(
-        sun_position_features, coords={"site__time_utc": datetimes}
+    combined_sample_dataset = combined_sample_dataset.assign_coords(
+        {k: ("site__time_utc", v) for k, v in sun_position_features.items()}
    )
-    combined_sample_dataset = xr.merge([combined_sample_dataset, sun_position_features_xr])

    # TODO include t0_index in xr dataset?

    # Fill any nan values
    return combined_sample_dataset.fillna(0.0)


-def merge_data_arrays(self, normalised_data_arrays: list[Tuple[str, xr.DataArray]]) -> xr.Dataset:
+def merge_data_arrays(
+    self, normalised_data_arrays: list[Tuple[str, xr.DataArray]]
+) -> xr.Dataset:
    """
    Combine a list of DataArrays into a single Dataset with unique naming conventions.
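A note on the prefix change: the before/after names in the tests (site_solar_azimuth becoming site__solar_azimuth) imply that the feature helpers join key_prefix and the feature name with a single underscore, so passing "site_" yields names matching the site__time_utc convention. A hypothetical sketch of that naming rule, for illustration only (the real helpers live in ocf_data_sampler):

    def make_key(key_prefix: str, feature: str) -> str:
        # Hypothetical helper: join prefix and feature with one underscore,
        # as the expected names in the tests imply.
        return f"{key_prefix}_{feature}"

    assert make_key("site", "solar_azimuth") == "site_solar_azimuth"    # old name
    assert make_key("site_", "solar_azimuth") == "site__solar_azimuth"  # new name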
87 changes: 79 additions & 8 deletions tests/torch_datasets/test_site.py
@@ -3,6 +3,8 @@
from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset, convert_from_dataset_to_dict_datasets
from xarray import Dataset, DataArray

+from torch.utils.data import DataLoader
+

def test_site(site_config_filename):

@@ -18,17 +20,45 @@ def test_site(site_config_filename):
    assert isinstance(sample, Dataset)

    # Expected dimensions and data variables
-    expected_dims = {'satellite__x_geostationary', 'site__time_utc', 'nwp-ukv__target_time_utc',
-                     'nwp-ukv__x_osgb', 'satellite__channel', 'satellite__y_geostationary',
-                     'satellite__time_utc', 'nwp-ukv__channel', 'nwp-ukv__y_osgb', 'site_solar_azimuth',
-                     'site_solar_elevation', 'site_date_cos', 'site_time_cos', 'site_time_sin', 'site_date_sin'}
+    expected_dims = {
+        "satellite__x_geostationary",
+        "site__time_utc",
+        "nwp-ukv__target_time_utc",
+        "nwp-ukv__x_osgb",
+        "satellite__channel",
+        "satellite__y_geostationary",
+        "satellite__time_utc",
+        "nwp-ukv__channel",
+        "nwp-ukv__y_osgb",
+    }
+
+    expected_coords_subset = {
+        "site__solar_azimuth",
+        "site__solar_elevation",
+        "site__date_cos",
+        "site__time_cos",
+        "site__time_sin",
+        "site__date_sin",
+    }
+
    expected_data_vars = {"nwp-ukv", "satellite", "site"}
+
+    import xarray as xr
+
+    sample.to_netcdf("sample.nc")
+    sample = xr.open_dataset("sample.nc")

    # Check dimensions
-    assert set(sample.dims) == expected_dims, f"Missing or extra dimensions: {set(sample.dims) ^ expected_dims}"
+    assert (
+        set(sample.dims) == expected_dims
+    ), f"Missing or extra dimensions: {set(sample.dims) ^ expected_dims}"
    # Check data variables
-    assert set(sample.data_vars) == expected_data_vars, f"Missing or extra data variables: {set(sample.data_vars) ^ expected_data_vars}"
+    assert (
+        set(sample.data_vars) == expected_data_vars
+    ), f"Missing or extra data variables: {set(sample.data_vars) ^ expected_data_vars}"
+
+    for coords in expected_coords_subset:
+        assert coords in sample.coords

    # check the shape of the data is correct
    # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
@@ -38,6 +68,7 @@ def test_site(site_config_filename):
    # 1.5 hours of 30 minute data (inclusive)
    assert sample["site"].values.shape == (4,)

+

def test_site_time_filter_start(site_config_filename):

    # Create dataset object
@@ -74,11 +105,51 @@ def test_convert_from_dataset_to_dict_datasets(site_config_filename):

    assert isinstance(sample, dict)

+    print(sample.keys())
+
    for key in ["nwp", "satellite", "site"]:
        assert key in sample

+
+def test_site_dataset_with_dataloader(site_config_filename):
+    # Create dataset object
+    dataset = SitesDataset(site_config_filename)
+
+    expected_coods = {
+        "site__solar_azimuth",
+        "site__solar_elevation",
+        "site__date_cos",
+        "site__time_cos",
+        "site__time_sin",
+        "site__date_sin",
+    }
+
+    sample = dataset[0]
+    for key in expected_coods:
+        assert key in sample
+
+    dataloader_kwargs = dict(
+        shuffle=False,
+        batch_size=None,
+        sampler=None,
+        batch_sampler=None,
+        num_workers=1,
+        collate_fn=None,
+        pin_memory=False,  # Only using CPU to prepare samples so pinning is not beneficial
+        drop_last=False,
+        timeout=0,
+        worker_init_fn=None,
+        prefetch_factor=1,
+        persistent_workers=False,  # Not needed since we only enter the dataloader loop once
+    )
+
+    dataloader = DataLoader(dataset, collate_fn=None, batch_size=None)
+
+    for i, sample in zip(range(1), dataloader):
+
+        # check that the expected coordinates are in the sample
+        for key in expected_coods:
+            assert key in sample
+
+
def test_process_and_combine_site_sample_dict(site_config_filename):
    # Load config
    # config = load_yaml_configuration(pvnet_config_filename)
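The new test_site_dataset_with_dataloader drives the dataset through a PyTorch DataLoader with batch_size=None, which disables automatic batching so each sample is yielded on its own rather than stacked into tensor batches. A self-contained sketch of that loading pattern follows; SitesDataset needs a real config file, so the toy dataset class here is an illustrative stand-in, not the repository's code:

    import numpy as np
    import pandas as pd
    import xarray as xr
    from torch.utils.data import DataLoader, Dataset as TorchDataset


    class ToySiteDataset(TorchDataset):
        """Illustrative stand-in that yields xarray Datasets like SitesDataset."""

        def __len__(self) -> int:
            return 2

        def __getitem__(self, idx: int) -> xr.Dataset:
            times = pd.date_range("2025-01-22", periods=4, freq="30min")
            site = xr.DataArray(
                np.zeros(len(times)),
                dims="site__time_utc",
                coords={"site__time_utc": times},
            )
            return xr.Dataset({"site": site})


    # batch_size=None disables automatic batching: samples pass through uncollated,
    # and the coordinates attached to the data variables travel with them.
    dataloader = DataLoader(ToySiteDataset(), batch_size=None, collate_fn=None)
    for sample in dataloader:
        assert isinstance(sample, xr.Dataset)
        assert "site__time_utc" in sample.coords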
