From 8db850c8259ab8f95e54303009d314ee85ac9b2b Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Fri, 9 Aug 2024 17:02:07 +0200 Subject: [PATCH 1/2] add clip parameter to polygon_query; tests missing --- src/spatialdata/_core/query/spatial_query.py | 21 +++++++++++++++++++- src/spatialdata/_core/spatialdata.py | 2 ++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index e312589d..dea2280a 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -13,7 +13,7 @@ from dask.dataframe import DataFrame as DaskDataFrame from datatree import DataTree from geopandas import GeoDataFrame -from shapely.geometry import MultiPolygon, Polygon +from shapely.geometry import MultiPolygon, Point, Polygon from xarray import DataArray from spatialdata import to_polygons @@ -758,6 +758,7 @@ def polygon_query( polygon: Polygon | MultiPolygon, target_coordinate_system: str, filter_table: bool = True, + clip: bool = False, shapes: bool = True, points: bool = True, images: bool = True, @@ -777,6 +778,12 @@ def polygon_query( filter_table Specifies whether to filter the tables to only include tables that annotate elements in the retrieved SpatialData object of the query. + clip + If `True`, the shapes are clipped to the polygon. This behavior is implemented only when querying + polygons/multipolygons or circles, and it is ignored for other types of elements (images, labels, points). + Importantly, when clipping is enabled, the circles will be converted to polygons before the clipping. This may + affect downstream operations that rely on the circle radius or on performance, so it is recommended to disable + clipping when querying circles or when querying a `SpatialData` object that contains circles. shapes [Deprecated] This argument is now ignored and will be removed. Please filter the SpatialData object before calling this function. @@ -810,6 +817,7 @@ def _( polygon: Polygon | MultiPolygon, target_coordinate_system: str, filter_table: bool = True, + clip: bool = False, shapes: bool = True, points: bool = True, images: bool = True, @@ -825,6 +833,7 @@ def _( polygon_query, polygon=polygon, target_coordinate_system=target_coordinate_system, + clip=clip, ) new_elements[element_type] = queried_elements @@ -891,6 +900,7 @@ def _( element: GeoDataFrame, polygon: Polygon | MultiPolygon, target_coordinate_system: str, + clip: bool = False, **kwargs: Any, ) -> GeoDataFrame | None: from spatialdata.transformations import get_transformation, set_transformation @@ -912,9 +922,18 @@ def _( queried_shapes = element[indices] queried_shapes.index = buffered[indices][OLD_INDEX] queried_shapes.index.name = None + + if clip: + if isinstance(element.geometry.iloc[0], Point): + queried_shapes = buffered[indices] + queried_shapes.index = buffered[indices][OLD_INDEX] + queried_shapes.index.name = None + queried_shapes = queried_shapes.clip(polygon_gdf, keep_geom_type=True) + del buffered[OLD_INDEX] if OLD_INDEX in queried_shapes.columns: del queried_shapes[OLD_INDEX] + transformation = get_transformation(buffered, target_coordinate_system) queried_shapes = ShapesModel.parse(queried_shapes) set_transformation(queried_shapes, transformation, target_coordinate_system) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 87e89ba7..c1ccfab7 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -2225,6 +2225,7 @@ def polygon( polygon: Polygon | MultiPolygon, target_coordinate_system: str, filter_table: bool = True, + clip: bool = False, ) -> SpatialData: """ Perform a polygon query on the SpatialData object. @@ -2239,6 +2240,7 @@ def polygon( polygon=polygon, target_coordinate_system=target_coordinate_system, filter_table=filter_table, + clip=clip, ) def __call__(self, request: BaseSpatialRequest, **kwargs) -> SpatialData: # type: ignore[no-untyped-def] From 1cb3f4d875862cf1db22986fef16cf04c2b99ae3 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Mon, 12 Aug 2024 14:30:02 +0200 Subject: [PATCH 2/2] tests and changelog --- CHANGELOG.md | 10 +++++--- tests/core/query/test_spatial_query.py | 35 ++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 54c770f0..7cdb72db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,18 +10,22 @@ and this project adheres to [Semantic Versioning][]. ## [0.x.x] - 2024-xx-xx +### Minor + +- Added `clip: bool = False` parameter to `polygon_query()` #670 + ## [0.2.2] - 2024-08-07 -# Major +### Major - New disk format for shapes using `GeoParquet` (the change is backward compatible) #542 -# Minor +### Minor - Add `return_background` as argument to `get_centroids` and `get_element_instances` #621 - Ability to save data using older disk formats #542 -# Fixed +### Fixed - Circles validation now checks for inf or nan radii #653 - Bug with table name in torch dataset #654 @LLehner diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index 74cd5128..18c7614d 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -11,6 +11,7 @@ from datatree import DataTree from geopandas import GeoDataFrame from shapely import MultiPolygon, Point, Polygon +from spatialdata._core.data_extent import get_extent from spatialdata._core.query.spatial_query import ( BaseSpatialRequest, BoundingBoxRequest, @@ -686,3 +687,37 @@ def test_spatial_query_different_axes(full_sdata, name: str): return raise RuntimeError(f"Unexpected type {type(original)}") + + +def test_query_with_clipping(sdata_blobs): + circles = sdata_blobs["blobs_circles"] + circles.index = [10, 100, 1] + polygons = sdata_blobs["blobs_polygons"] + polygons.index = [10, 100, 1] + + # define square to use as query geometry + minx = 120 + maxx = 170 + miny = 150 + maxy = 210 + x_coords = [minx, maxx, maxx, minx, minx] + y_coords = [miny, miny, maxy, maxy, miny] + polygon = Polygon(zip(x_coords, y_coords)) + + queried_circles = polygon_query(circles, polygon=polygon, target_coordinate_system="global", clip=True) + queried_polygons = polygon_query(polygons, polygon=polygon, target_coordinate_system="global", clip=True) + + assert queried_circles.index.tolist() == [100] + assert queried_polygons.index.tolist() == [100] + + extent_circles = get_extent(queried_circles) + extent_polygons = get_extent(queried_polygons) + + def query_polyon_contains_queried_data(extent: dict[str, tuple[float, float]]) -> None: + assert extent["x"][0] >= minx + assert extent["x"][1] <= maxx + assert extent["y"][0] >= miny + assert extent["y"][1] <= maxy + + query_polyon_contains_queried_data(extent_circles) + query_polyon_contains_queried_data(extent_polygons)