diff --git a/ocf_data_sampler/numpy_batch/satellite.py b/ocf_data_sampler/numpy_batch/satellite.py index c4c8d5b..6cdb270 100644 --- a/ocf_data_sampler/numpy_batch/satellite.py +++ b/ocf_data_sampler/numpy_batch/satellite.py @@ -1,23 +1,30 @@ """Convert Satellite to NumpyBatch""" import xarray as xr -from ocf_datapipes.batch import BatchKey, NumpyBatch +class SatelliteBatchKey: -def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> NumpyBatch: + satellite_actual = 'satellite_actual' + satellite_time_utc = 'satellite_time_utc' + satellite_x_geostationary = 'satellite_x_geostationary' + satellite_y_geostationary = 'satellite_y_geostationary' + satellite_t0_idx = 'satellite_t0_idx' + + +def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict: """Convert from Xarray to NumpyBatch""" - example: NumpyBatch = { - BatchKey.satellite_actual: da.values, - BatchKey.satellite_time_utc: da.time_utc.values.astype(float), + example = { + SatelliteBatchKey.satellite_actual: da.values, + SatelliteBatchKey.satellite_time_utc: da.time_utc.values.astype(float), } for batch_key, dataset_key in ( - (BatchKey.satellite_x_geostationary, "x_geostationary"), - (BatchKey.satellite_y_geostationary, "y_geostationary"), + (SatelliteBatchKey.satellite_x_geostationary, "x_geostationary"), + (SatelliteBatchKey.satellite_y_geostationary, "y_geostationary"), ): example[batch_key] = da[dataset_key].values if t0_idx is not None: - example[BatchKey.satellite_t0_idx] = t0_idx + example[SatelliteBatchKey.satellite_t0_idx] = t0_idx return example \ No newline at end of file diff --git a/ocf_data_sampler/numpy_batch/sun_position.py b/ocf_data_sampler/numpy_batch/sun_position.py index 0866801..423222e 100644 --- a/ocf_data_sampler/numpy_batch/sun_position.py +++ b/ocf_data_sampler/numpy_batch/sun_position.py @@ -2,7 +2,6 @@ import pvlib import numpy as np import pandas as pd -from ocf_datapipes.batch import BatchKey, NumpyBatch def calculate_azimuth_and_elevation( @@ -38,7 +37,7 @@ def make_sun_position_numpy_batch( lon: float, lat: float, key_preffix: str = "gsp" -) -> NumpyBatch: +) -> dict: """Creates NumpyBatch with standardized solar coordinates Args: @@ -58,9 +57,9 @@ def make_sun_position_numpy_batch( elevation = elevation / 180 + 0.5 # Make NumpyBatch - sun_numpy_batch: NumpyBatch = { - BatchKey[key_preffix + "_solar_azimuth"]: azimuth, - BatchKey[key_preffix + "_solar_elevation"]: elevation, + sun_numpy_batch = { + key_preffix + "_solar_azimuth": azimuth, + key_preffix + "_solar_elevation": elevation, } return sun_numpy_batch \ No newline at end of file