diff --git a/ocf_data_sampler/torch_datasets/datasets/site.py b/ocf_data_sampler/torch_datasets/datasets/site.py
index 7e3b610..bc439ad 100644
--- a/ocf_data_sampler/torch_datasets/datasets/site.py
+++ b/ocf_data_sampler/torch_datasets/datasets/site.py
@@ -241,29 +241,30 @@ def process_and_combine_site_sample_dict(
 
     # add datetime features
     datetimes = pd.DatetimeIndex(combined_sample_dataset.site__time_utc.values)
-    datetime_features = make_datetime_numpy_dict(datetimes=datetimes, key_prefix="site")
-    datetime_features_xr = xr.Dataset(datetime_features, coords={"site__time_utc": datetimes})
-    combined_sample_dataset = xr.merge([combined_sample_dataset, datetime_features_xr])
+    datetime_features = make_datetime_numpy_dict(datetimes=datetimes, key_prefix="site_")
+    combined_sample_dataset = combined_sample_dataset.assign_coords(
+        {k: ("site__time_utc", v) for k, v in datetime_features.items()}
+    )
 
     # add sun features
     sun_position_features = make_sun_position_numpy_sample(
         datetimes=datetimes,
         lon=combined_sample_dataset.site__longitude.values,
         lat=combined_sample_dataset.site__latitude.values,
-        key_prefix="site",
+        key_prefix="site_",
     )
-    sun_position_features_xr = xr.Dataset(
-        sun_position_features, coords={"site__time_utc": datetimes}
+    combined_sample_dataset = combined_sample_dataset.assign_coords(
+        {k: ("site__time_utc", v) for k, v in sun_position_features.items()}
     )
-    combined_sample_dataset = xr.merge([combined_sample_dataset, sun_position_features_xr])
 
     # TODO include t0_index in xr dataset?
 
     # Fill any nan values
     return combined_sample_dataset.fillna(0.0)
 
 
-
-def merge_data_arrays(self, normalised_data_arrays: list[Tuple[str, xr.DataArray]]) -> xr.Dataset:
+def merge_data_arrays(
+    self, normalised_data_arrays: list[Tuple[str, xr.DataArray]]
+) -> xr.Dataset:
     """
     Combine a list of DataArrays into a single Dataset with unique naming conventions.
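The `site.py` change above swaps the `xr.merge` of separate feature Datasets for `assign_coords` on the existing `site__time_utc` dimension. A minimal sketch of that pattern, with toy arrays and a hand-written feature dict standing in for the real outputs of `make_datetime_numpy_dict` / `make_sun_position_numpy_sample`; only the `site__` coordinate names follow the convention the tests check:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Toy dataset standing in for combined_sample_dataset: one "site" variable
# on the "site__time_utc" dimension.
times = pd.date_range("2024-06-01 12:00", periods=4, freq="30min")
ds = xr.Dataset(
    {"site": ("site__time_utc", np.random.rand(4))},
    coords={"site__time_utc": times},
)

# Illustrative stand-in for the per-timestep feature dicts; the real keys come
# from make_datetime_numpy_dict / make_sun_position_numpy_sample with key_prefix="site_".
features = {
    "site__solar_azimuth": np.random.rand(4),
    "site__solar_elevation": np.random.rand(4),
}

# Attach the features as coordinates along the existing time dimension,
# rather than merging them in as separate data variables.
ds = ds.assign_coords({k: ("site__time_utc", v) for k, v in features.items()})

assert "site__solar_azimuth" in ds.coords  # features now travel with the dataset
```

Because the features are coordinates on `site__time_utc` rather than extra data variables, they stay attached to the sample through selection, serialisation, and DataLoader iteration, which is what the updated tests below check.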
diff --git a/tests/torch_datasets/test_site.py b/tests/torch_datasets/test_site.py
index ecee74e..703ac27 100644
--- a/tests/torch_datasets/test_site.py
+++ b/tests/torch_datasets/test_site.py
@@ -3,6 +3,8 @@
 from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset, convert_from_dataset_to_dict_datasets
 from xarray import Dataset, DataArray
+from torch.utils.data import DataLoader
+
 
 
 def test_site(site_config_filename):
 
@@ -18,17 +20,45 @@ def test_site(site_config_filename):
     assert isinstance(sample, Dataset)
 
     # Expected dimensions and data variables
-    expected_dims = {'satellite__x_geostationary', 'site__time_utc', 'nwp-ukv__target_time_utc',
-                     'nwp-ukv__x_osgb', 'satellite__channel', 'satellite__y_geostationary',
-                     'satellite__time_utc', 'nwp-ukv__channel', 'nwp-ukv__y_osgb', 'site_solar_azimuth',
-                     'site_solar_elevation', 'site_date_cos', 'site_time_cos', 'site_time_sin', 'site_date_sin'}
+    expected_dims = {
+        "satellite__x_geostationary",
+        "site__time_utc",
+        "nwp-ukv__target_time_utc",
+        "nwp-ukv__x_osgb",
+        "satellite__channel",
+        "satellite__y_geostationary",
+        "satellite__time_utc",
+        "nwp-ukv__channel",
+        "nwp-ukv__y_osgb",
+    }
+
+    expected_coords_subset = {
+        "site__solar_azimuth",
+        "site__solar_elevation",
+        "site__date_cos",
+        "site__time_cos",
+        "site__time_sin",
+        "site__date_sin",
+    }
     expected_data_vars = {"nwp-ukv", "satellite", "site"}
 
+    import xarray as xr
+
+    sample.to_netcdf("sample.nc")
+    sample = xr.open_dataset("sample.nc")
+
     # Check dimensions
-    assert set(sample.dims) == expected_dims, f"Missing or extra dimensions: {set(sample.dims) ^ expected_dims}"
+    assert (
+        set(sample.dims) == expected_dims
+    ), f"Missing or extra dimensions: {set(sample.dims) ^ expected_dims}"
 
     # Check data variables
-    assert set(sample.data_vars) == expected_data_vars, f"Missing or extra data variables: {set(sample.data_vars) ^ expected_data_vars}"
+    assert (
+        set(sample.data_vars) == expected_data_vars
+    ), f"Missing or extra data variables: {set(sample.data_vars) ^ expected_data_vars}"
+
+    for coords in expected_coords_subset:
+        assert coords in sample.coords
 
     # check the shape of the data is correct
     # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
@@ -38,6 +68,7 @@ def test_site(site_config_filename):
     # 1.5 hours of 30 minute data (inclusive)
     assert sample["site"].values.shape == (4,)
 
+
 def test_site_time_filter_start(site_config_filename):
 
     # Create dataset object
@@ -74,11 +105,51 @@ def test_convert_from_dataset_to_dict_datasets(site_config_filename):
 
     assert isinstance(sample, dict)
 
-    print(sample.keys())
-
     for key in ["nwp", "satellite", "site"]:
         assert key in sample
 
+
+def test_site_dataset_with_dataloader(site_config_filename):
+    # Create dataset object
+    dataset = SitesDataset(site_config_filename)
+
+    expected_coords = {
+        "site__solar_azimuth",
+        "site__solar_elevation",
+        "site__date_cos",
+        "site__time_cos",
+        "site__time_sin",
+        "site__date_sin",
+    }
+
+    sample = dataset[0]
+    for key in expected_coords:
+        assert key in sample
+
+    dataloader_kwargs = dict(
+        shuffle=False,
+        batch_size=None,
+        sampler=None,
+        batch_sampler=None,
+        num_workers=1,
+        collate_fn=None,
+        pin_memory=False,  # Only using CPU to prepare samples so pinning is not beneficial
+        drop_last=False,
+        timeout=0,
+        worker_init_fn=None,
+        prefetch_factor=1,
+        persistent_workers=False,  # Not needed since we only enter the dataloader loop once
+    )
+
+    dataloader = DataLoader(dataset, **dataloader_kwargs)
+
+    for i, sample in zip(range(1), dataloader):
+
+        # check that the expected coords are in the sample
+        for key in expected_coords:
+            assert key in sample
+
+
 def test_process_and_combine_site_sample_dict(site_config_filename):
     # Load config
     # config = load_yaml_configuration(pvnet_config_filename)
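A small, self-contained sketch of the round trip that the `to_netcdf` / `open_dataset` lines in `test_site` appear to exercise: coordinates attached with `assign_coords` are written to and read back from netCDF intact. The temporary path and toy values are assumptions for illustration only (the test itself writes a hard-coded `sample.nc`):

```python
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd
import xarray as xr

# Toy sample with one coordinate of the kind added by assign_coords.
times = pd.date_range("2024-06-01 12:00", periods=4, freq="30min")
sample = xr.Dataset(
    {"site": ("site__time_utc", np.random.rand(4))},
    coords={"site__time_utc": times},
).assign_coords(site__solar_elevation=("site__time_utc", np.random.rand(4)))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "sample.nc"  # hypothetical temp path
    sample.to_netcdf(path)

    with xr.open_dataset(path) as reloaded:
        # The coordinate assigned via assign_coords survives serialisation.
        assert "site__solar_elevation" in reloaded.coords
```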