Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: sun and datetimes dataloader issue #140

Merged
merged 4 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions ocf_data_sampler/torch_datasets/datasets/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,9 @@ def process_and_combine_site_sample_dict(
# add datetime features
datetimes = pd.DatetimeIndex(combined_sample_dataset.site__time_utc.values)
datetime_features = make_datetime_numpy_dict(datetimes=datetimes, key_prefix="site")
peterdudfield marked this conversation as resolved.
Show resolved Hide resolved
datetime_features_xr = xr.Dataset(datetime_features, coords={"site__time_utc": datetimes})
combined_sample_dataset = xr.merge([combined_sample_dataset, datetime_features_xr])
combined_sample_dataset = combined_sample_dataset.assign_coords(
{k: ("site__time_utc", v) for k, v in datetime_features.items()}
)

# add sun features
sun_position_features = make_sun_position_numpy_sample(
Expand All @@ -252,18 +253,18 @@ def process_and_combine_site_sample_dict(
lat=combined_sample_dataset.site__latitude.values,
key_prefix="site",
peterdudfield marked this conversation as resolved.
Show resolved Hide resolved
)
sun_position_features_xr = xr.Dataset(
sun_position_features, coords={"site__time_utc": datetimes}
combined_sample_dataset = combined_sample_dataset.assign_coords(
{k: ("site__time_utc", v) for k, v in sun_position_features.items()}
)
combined_sample_dataset = xr.merge([combined_sample_dataset, sun_position_features_xr])

# TODO include t0_index in xr dataset?

# Fill any nan values
return combined_sample_dataset.fillna(0.0)


def merge_data_arrays(self, normalised_data_arrays: list[Tuple[str, xr.DataArray]]) -> xr.Dataset:
def merge_data_arrays(
self, normalised_data_arrays: list[Tuple[str, xr.DataArray]]
) -> xr.Dataset:
"""
Combine a list of DataArrays into a single Dataset with unique naming conventions.

Expand Down
87 changes: 79 additions & 8 deletions tests/torch_datasets/test_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset, convert_from_dataset_to_dict_datasets
from xarray import Dataset, DataArray

from torch.utils.data import DataLoader


def test_site(site_config_filename):

Expand All @@ -18,17 +20,45 @@ def test_site(site_config_filename):
assert isinstance(sample, Dataset)

# Expected dimensions and data variables
expected_dims = {'satellite__x_geostationary', 'site__time_utc', 'nwp-ukv__target_time_utc',
'nwp-ukv__x_osgb', 'satellite__channel', 'satellite__y_geostationary',
'satellite__time_utc', 'nwp-ukv__channel', 'nwp-ukv__y_osgb', 'site_solar_azimuth',
'site_solar_elevation', 'site_date_cos', 'site_time_cos', 'site_time_sin', 'site_date_sin'}
expected_dims = {
"satellite__x_geostationary",
"site__time_utc",
"nwp-ukv__target_time_utc",
"nwp-ukv__x_osgb",
"satellite__channel",
"satellite__y_geostationary",
"satellite__time_utc",
"nwp-ukv__channel",
"nwp-ukv__y_osgb",
}

expected_coords_subset = {
"site_solar_azimuth",
"site_solar_elevation",
"site_date_cos",
"site_time_cos",
"site_time_sin",
"site_date_sin",
}

expected_data_vars = {"nwp-ukv", "satellite", "site"}

import xarray as xr

sample.to_netcdf("sample.nc")
sample = xr.open_dataset("sample.nc")

# Check dimensions
assert set(sample.dims) == expected_dims, f"Missing or extra dimensions: {set(sample.dims) ^ expected_dims}"
assert (
set(sample.dims) == expected_dims
), f"Missing or extra dimensions: {set(sample.dims) ^ expected_dims}"
# Check data variables
assert set(sample.data_vars) == expected_data_vars, f"Missing or extra data variables: {set(sample.data_vars) ^ expected_data_vars}"
assert (
set(sample.data_vars) == expected_data_vars
), f"Missing or extra data variables: {set(sample.data_vars) ^ expected_data_vars}"

for coords in expected_coords_subset:
assert coords in sample.coords

# check the shape of the data is correct
# 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
Expand All @@ -38,6 +68,7 @@ def test_site(site_config_filename):
# 1.5 hours of 30 minute data (inclusive)
assert sample["site"].values.shape == (4,)


def test_site_time_filter_start(site_config_filename):

# Create dataset object
Expand Down Expand Up @@ -74,11 +105,51 @@ def test_convert_from_dataset_to_dict_datasets(site_config_filename):

assert isinstance(sample, dict)

print(sample.keys())

for key in ["nwp", "satellite", "site"]:
assert key in sample


def test_site_dataset_with_dataloader(site_config_filename):
    """Check the site datetime/sun features are present when loading via a DataLoader.

    The same set of keys is checked twice: once on a sample indexed directly
    from the dataset, and once on the first sample yielded by a torch
    ``DataLoader`` wrapping that dataset.

    Args:
        site_config_filename: Configuration passed straight to ``SitesDataset``
            (supplied by the test suite's fixture of the same name).
    """
    # Create dataset object
    dataset = SitesDataset(site_config_filename)

    expected_coords = {
        "site_solar_azimuth",
        "site_solar_elevation",
        "site_date_cos",
        "site_time_cos",
        "site_time_sin",
        "site_date_sin",
    }

    # Sample taken directly from the dataset
    sample = dataset[0]
    for key in expected_coords:
        assert key in sample

    # NOTE: a larger dict of dataloader kwargs (num_workers=1, prefetch_factor=1, ...)
    # used to be built here but was never passed to the DataLoader; it has been
    # removed so the test reads as what it actually runs.
    dataloader = DataLoader(dataset, collate_fn=None, batch_size=None)

    # Only the first sample is needed. next() raises if the loader yields
    # nothing, rather than letting the test pass without checking anything.
    sample = next(iter(dataloader))

    # Check that every expected coordinate is in the sample from the dataloader
    for key in expected_coords:
        assert key in sample


def test_process_and_combine_site_sample_dict(site_config_filename):
# Load config
# config = load_yaml_configuration(pvnet_config_filename)
Expand Down
Loading