Skip to content

Commit

Permalink
Fixing edge case drop_dimension (#188)
Browse files Browse the repository at this point in the history
* Fixing edge case drop_dimension

* fix pre-commit

* fix pre-commit
  • Loading branch information
clausmichele authored Nov 2, 2023
1 parent 416ccdb commit cfe15ce
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def drop_dimension(data: RasterCube, name: str) -> RasterCube:
raise DimensionLabelCountMismatch(
f"The number of dimension labels exceeds one, which requires a reducer. Dimension ({name}) has {len(data[name])} labels."
)
return data.drop_vars(name).squeeze()
return data.drop_vars(name).squeeze(name)


def create_raster_cube() -> RasterCube:
Expand Down
4 changes: 3 additions & 1 deletion tests/test_dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_add_dimension(temporal_interval, bounding_box, random_raster_data):
assert output_cube_2.openeo.temporal_dims[1] == "weird"


@pytest.mark.parametrize("size", [(30, 30, 20, 2)])
@pytest.mark.parametrize("size", [(30, 30, 1, 2)])
@pytest.mark.parametrize("dtype", [np.float32])
def test_drop_dimension(temporal_interval, bounding_box, random_raster_data):
input_cube = create_fake_rastercube(
Expand All @@ -63,4 +63,6 @@ def test_drop_dimension(temporal_interval, bounding_box, random_raster_data):
suitable_cube = input_cube.where(input_cube.bands == "B02", drop=True)

output_cube = drop_dimension(suitable_cube, DIM_TO_DROP)
DIMS_TO_KEEP = tuple(filter(lambda y: y != DIM_TO_DROP, input_cube.dims))
assert DIM_TO_DROP not in output_cube.dims
assert DIMS_TO_KEEP == output_cube.dims

0 comments on commit cfe15ce

Please sign in to comment.