Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add clip parameter to polygon_query; tests missing #670

Merged
merged 2 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,22 @@ and this project adheres to [Semantic Versioning][].

## [0.x.x] - 2024-xx-xx

### Minor

- Added `clip: bool = False` parameter to `polygon_query()` #670

## [0.2.2] - 2024-08-07

# Major
### Major

- New disk format for shapes using `GeoParquet` (the change is backward compatible) #542

# Minor
### Minor

- Add `return_background` as argument to `get_centroids` and `get_element_instances` #621
- Ability to save data using older disk formats #542

# Fixed
### Fixed

- Circles validation now checks for inf or nan radii #653
- Bug with table name in torch dataset #654 @LLehner
Expand Down
21 changes: 20 additions & 1 deletion src/spatialdata/_core/query/spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dask.dataframe import DataFrame as DaskDataFrame
from datatree import DataTree
from geopandas import GeoDataFrame
from shapely.geometry import MultiPolygon, Polygon
from shapely.geometry import MultiPolygon, Point, Polygon
from xarray import DataArray

from spatialdata import to_polygons
Expand Down Expand Up @@ -758,6 +758,7 @@ def polygon_query(
polygon: Polygon | MultiPolygon,
target_coordinate_system: str,
filter_table: bool = True,
clip: bool = False,
shapes: bool = True,
points: bool = True,
images: bool = True,
Expand All @@ -777,6 +778,12 @@ def polygon_query(
filter_table
Specifies whether to filter the tables to only include tables that annotate elements in the retrieved
SpatialData object of the query.
clip
If `True`, the shapes are clipped to the polygon. This behavior is implemented only when querying
polygons/multipolygons or circles, and it is ignored for other types of elements (images, labels, points).
Importantly, when clipping is enabled, the circles will be converted to polygons before the clipping. This may
affect downstream operations that rely on the circle radius or on performance, so it is recommended to disable
clipping when querying circles or when querying a `SpatialData` object that contains circles.
shapes [Deprecated]
This argument is now ignored and will be removed. Please filter the SpatialData object before calling this
function.
Expand Down Expand Up @@ -810,6 +817,7 @@ def _(
polygon: Polygon | MultiPolygon,
target_coordinate_system: str,
filter_table: bool = True,
clip: bool = False,
shapes: bool = True,
points: bool = True,
images: bool = True,
Expand All @@ -825,6 +833,7 @@ def _(
polygon_query,
polygon=polygon,
target_coordinate_system=target_coordinate_system,
clip=clip,
)
new_elements[element_type] = queried_elements

Expand Down Expand Up @@ -891,6 +900,7 @@ def _(
element: GeoDataFrame,
polygon: Polygon | MultiPolygon,
target_coordinate_system: str,
clip: bool = False,
**kwargs: Any,
) -> GeoDataFrame | None:
from spatialdata.transformations import get_transformation, set_transformation
Expand All @@ -912,9 +922,18 @@ def _(
queried_shapes = element[indices]
queried_shapes.index = buffered[indices][OLD_INDEX]
queried_shapes.index.name = None

if clip:
if isinstance(element.geometry.iloc[0], Point):
queried_shapes = buffered[indices]
queried_shapes.index = buffered[indices][OLD_INDEX]
queried_shapes.index.name = None
queried_shapes = queried_shapes.clip(polygon_gdf, keep_geom_type=True)

del buffered[OLD_INDEX]
if OLD_INDEX in queried_shapes.columns:
del queried_shapes[OLD_INDEX]

transformation = get_transformation(buffered, target_coordinate_system)
queried_shapes = ShapesModel.parse(queried_shapes)
set_transformation(queried_shapes, transformation, target_coordinate_system)
Expand Down
2 changes: 2 additions & 0 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2225,6 +2225,7 @@ def polygon(
polygon: Polygon | MultiPolygon,
target_coordinate_system: str,
filter_table: bool = True,
clip: bool = False,
) -> SpatialData:
"""
Perform a polygon query on the SpatialData object.
Expand All @@ -2239,6 +2240,7 @@ def polygon(
polygon=polygon,
target_coordinate_system=target_coordinate_system,
filter_table=filter_table,
clip=clip,
)

def __call__(self, request: BaseSpatialRequest, **kwargs) -> SpatialData: # type: ignore[no-untyped-def]
Expand Down
35 changes: 35 additions & 0 deletions tests/core/query/test_spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from datatree import DataTree
from geopandas import GeoDataFrame
from shapely import MultiPolygon, Point, Polygon
from spatialdata._core.data_extent import get_extent
from spatialdata._core.query.spatial_query import (
BaseSpatialRequest,
BoundingBoxRequest,
Expand Down Expand Up @@ -686,3 +687,37 @@ def test_spatial_query_different_axes(full_sdata, name: str):
return

raise RuntimeError(f"Unexpected type {type(original)}")


def test_query_with_clipping(sdata_blobs):
circles = sdata_blobs["blobs_circles"]
circles.index = [10, 100, 1]
polygons = sdata_blobs["blobs_polygons"]
polygons.index = [10, 100, 1]

# define square to use as query geometry
minx = 120
maxx = 170
miny = 150
maxy = 210
x_coords = [minx, maxx, maxx, minx, minx]
y_coords = [miny, miny, maxy, maxy, miny]
polygon = Polygon(zip(x_coords, y_coords))

queried_circles = polygon_query(circles, polygon=polygon, target_coordinate_system="global", clip=True)
queried_polygons = polygon_query(polygons, polygon=polygon, target_coordinate_system="global", clip=True)

assert queried_circles.index.tolist() == [100]
assert queried_polygons.index.tolist() == [100]

extent_circles = get_extent(queried_circles)
extent_polygons = get_extent(queried_polygons)

def query_polyon_contains_queried_data(extent: dict[str, tuple[float, float]]) -> None:
assert extent["x"][0] >= minx
assert extent["x"][1] <= maxx
assert extent["y"][0] >= miny
assert extent["y"][1] <= maxy

query_polyon_contains_queried_data(extent_circles)
query_polyon_contains_queried_data(extent_polygons)
Loading