Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed Oct 18, 2024
1 parent 6be35c0 commit 16184e9
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 12 deletions.
5 changes: 2 additions & 3 deletions tests/numpy_batch/test_gsp.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from ocf_datapipes.batch import BatchKey
from ocf_data_sampler.load.gsp import open_gsp

from ocf_data_sampler.numpy_batch import convert_gsp_to_numpy_batch

from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey

def test_convert_gsp_to_numpy_batch(uk_gsp_zarr_path):

Expand All @@ -19,5 +18,5 @@ def test_convert_gsp_to_numpy_batch(uk_gsp_zarr_path):
assert isinstance(numpy_batch, dict)

# Assert the shape of the numpy batch
assert (numpy_batch[BatchKey.gsp] == da.values).all()
assert (numpy_batch[GSPBatchKey.gsp] == da.values).all()

2 changes: 1 addition & 1 deletion tests/numpy_batch/test_nwp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ocf_data_sampler.numpy_batch import convert_nwp_to_numpy_batch

from ocf_datapipes.batch import NWPBatchKey
from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey

@pytest.fixture(scope="module")
def da_nwp_like():
Expand Down
14 changes: 7 additions & 7 deletions tests/numpy_batch/test_sun_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
calculate_azimuth_and_elevation, make_sun_position_numpy_batch
)

from ocf_datapipes.batch import NumpyBatch, BatchKey
from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey


@pytest.mark.parametrize("lat", [0, 5, 10, 23.5])
Expand Down Expand Up @@ -71,11 +71,11 @@ def test_make_sun_position_numpy_batch():

batch = make_sun_position_numpy_batch(datetimes, lon, lat, key_preffix="gsp")

assert BatchKey.gsp_solar_elevation in batch
assert BatchKey.gsp_solar_azimuth in batch
assert GSPBatchKey.gsp_solar_elevation in batch
assert GSPBatchKey.gsp_solar_azimuth in batch

# The solar coords are normalised in the function
assert (batch[BatchKey.gsp_solar_elevation]>=0).all()
assert (batch[BatchKey.gsp_solar_elevation]<=1).all()
assert (batch[BatchKey.gsp_solar_azimuth]>=0).all()
assert (batch[BatchKey.gsp_solar_azimuth]<=1).all()
assert (batch[GSPBatchKey.gsp_solar_elevation]>=0).all()
assert (batch[GSPBatchKey.gsp_solar_elevation]<=1).all()
assert (batch[GSPBatchKey.gsp_solar_azimuth]>=0).all()
assert (batch[GSPBatchKey.gsp_solar_azimuth]<=1).all()
2 changes: 1 addition & 1 deletion tests/select/test_select_spatial_slice.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import xarray as xr
from ocf_datapipes.utils import Location
from ocf_data_sampler.select.location import Location
import pytest

from ocf_data_sampler.select.select_spatial_slice import (
Expand Down

0 comments on commit 16184e9

Please sign in to comment.