Skip to content

Commit a9cf363

Browse files
Sonja-StockhausSonja Stockhaustimtreis
authored
add shapes parameter to render shapes as hex/circle/square (#474)
Co-authored-by: Sonja Stockhaus <[email protected]> Co-authored-by: Tim Treis <[email protected]>
1 parent ce5a103 commit a9cf363

22 files changed

+233
-5
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections import OrderedDict
66
from copy import deepcopy
77
from pathlib import Path
8-
from typing import Any
8+
from typing import Any, Literal
99

1010
import matplotlib.pyplot as plt
1111
import numpy as np
@@ -170,6 +170,7 @@ def render_shapes(
170170
method: str | None = None,
171171
table_name: str | None = None,
172172
table_layer: str | None = None,
173+
shape: Literal["circle", "hex", "square"] | None = None,
173174
**kwargs: Any,
174175
) -> sd.SpatialData:
175176
"""
@@ -242,6 +243,9 @@ def render_shapes(
242243
table_layer: str | None
243244
Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in
244245
:attr:`sdata.table.X` is used for coloring.
246+
shape: Literal["circle", "hex", "square"] | None
247+
If None (default), the shapes are rendered as they are. Else, if either of "circle", "hex" or "square" is
248+
specified, the shapes are converted to a circle/hexagon/square before rendering.
245249
246250
**kwargs : Any
247251
Additional arguments for customization. This can include:
@@ -286,6 +290,7 @@ def render_shapes(
286290
scale=scale,
287291
table_name=table_name,
288292
table_layer=table_layer,
293+
shape=shape,
289294
method=method,
290295
ds_reduction=kwargs.get("datashader_reduction"),
291296
)
@@ -318,6 +323,7 @@ def render_shapes(
318323
transfunc=kwargs.get("transfunc"),
319324
table_name=param_values["table_name"],
320325
table_layer=param_values["table_layer"],
326+
shape=param_values["shape"],
321327
zorder=n_steps,
322328
method=param_values["method"],
323329
ds_reduction=param_values["ds_reduction"],

src/spatialdata_plot/pl/render.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from spatialdata_plot.pl.utils import (
3838
_ax_show_and_transform,
3939
_convert_alpha_to_datashader_range,
40+
_convert_shapes,
4041
_create_image_from_datashader_result,
4142
_datashader_aggregate_with_function,
4243
_datashader_map_aggregate_to_color,
@@ -163,6 +164,15 @@ def _render_shapes(
163164
trans, trans_data = _prepare_transformation(sdata_filt.shapes[element], coordinate_system)
164165

165166
shapes = gpd.GeoDataFrame(shapes, geometry="geometry")
167+
# convert shapes if necessary
168+
if render_params.shape is not None:
169+
current_type = shapes["geometry"].type
170+
if not (render_params.shape == "circle" and (current_type == "Point").all()):
171+
logger.info(f"Converting {shapes.shape[0]} shapes to {render_params.shape}.")
172+
max_extent = np.max(
173+
[shapes.total_bounds[2] - shapes.total_bounds[0], shapes.total_bounds[3] - shapes.total_bounds[1]]
174+
)
175+
shapes = _convert_shapes(shapes, render_params.shape, max_extent)
166176

167177
# Determine which method to use for rendering
168178
method = render_params.method
@@ -186,17 +196,17 @@ def _render_shapes(
186196
# Handle circles encoded as points with radius
187197
if is_point.any():
188198
scale = shapes[is_point]["radius"] * render_params.scale
189-
sdata_filt.shapes[element].loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy())
199+
shapes.loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy())
190200

191201
# apply transformations to the individual points
192202
tm = trans.get_matrix()
193-
transformed_element = sdata_filt.shapes[element].transform(
203+
transformed_geometry = shapes["geometry"].transform(
194204
lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm.T)[:, :2]
195205
)
196206
transformed_element = ShapesModel.parse(
197207
gpd.GeoDataFrame(
198-
data=sdata_filt.shapes[element].drop("geometry", axis=1),
199-
geometry=transformed_element,
208+
data=shapes.drop("geometry", axis=1),
209+
geometry=transformed_geometry,
200210
)
201211
)
202212

src/spatialdata_plot/pl/render_params.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ class ShapesRenderParams:
211211
zorder: int = 0
212212
table_name: str | None = None
213213
table_layer: str | None = None
214+
shape: Literal["circle", "hex", "square"] | None = None
214215
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None
215216

216217

src/spatialdata_plot/pl/utils.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import math
34
import os
45
import warnings
56
from collections import OrderedDict
@@ -51,6 +52,7 @@
5152
from scanpy.plotting._tools.scatterplots import _add_categorical_legend
5253
from scanpy.plotting._utils import add_colors_for_categorical_sample_annotation
5354
from scanpy.plotting.palettes import default_20, default_28, default_102
55+
from scipy.spatial import ConvexHull
5456
from skimage.color import label2rgb
5557
from skimage.morphology import erosion, square
5658
from skimage.segmentation import find_boundaries
@@ -1818,6 +1820,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
18181820
if size < 0:
18191821
raise ValueError("Parameter 'size' must be a positive number.")
18201822

1823+
if element_type == "shapes" and (shape := param_dict.get("shape")) is not None:
1824+
if not isinstance(shape, str):
1825+
raise TypeError("Parameter 'shape' must be a String from ['circle', 'hex', 'square'] if not None.")
1826+
if shape not in ["circle", "hex", "square"]:
1827+
raise ValueError(
1828+
f"'{shape}' is not supported for 'shape', please choose from[None, 'circle', 'hex', 'square']."
1829+
)
1830+
18211831
table_name = param_dict.get("table_name")
18221832
table_layer = param_dict.get("table_layer")
18231833
if table_name and not isinstance(param_dict["table_name"], str):
@@ -2030,6 +2040,7 @@ def _validate_shape_render_params(
20302040
scale: float | int,
20312041
table_name: str | None,
20322042
table_layer: str | None,
2043+
shape: Literal["circle", "hex", "square"] | None,
20332044
method: str | None,
20342045
ds_reduction: str | None,
20352046
) -> dict[str, dict[str, Any]]:
@@ -2049,6 +2060,7 @@ def _validate_shape_render_params(
20492060
"scale": scale,
20502061
"table_name": table_name,
20512062
"table_layer": table_layer,
2063+
"shape": shape,
20522064
"method": method,
20532065
"ds_reduction": ds_reduction,
20542066
}
@@ -2069,6 +2081,7 @@ def _validate_shape_render_params(
20692081
element_params[el]["norm"] = param_dict["norm"]
20702082
element_params[el]["scale"] = param_dict["scale"]
20712083
element_params[el]["table_layer"] = param_dict["table_layer"]
2084+
element_params[el]["shape"] = param_dict["shape"]
20722085

20732086
element_params[el]["color"] = param_dict["color"]
20742087

@@ -2487,6 +2500,39 @@ def _prepare_transformation(
24872500
return trans, trans_data
24882501

24892502

2503+
def _get_datashader_trans_matrix_of_single_element(
2504+
trans: Identity | Scale | Affine | MapAxis | Translation,
2505+
) -> ArrayLike:
2506+
flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]])
2507+
tm: ArrayLike = trans.to_affine_matrix(("x", "y"), ("x", "y"))
2508+
2509+
if isinstance(trans, Identity):
2510+
return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
2511+
if isinstance(trans, (Scale | Affine)):
2512+
# idea: "flip the y-axis", apply transformation, flip back
2513+
flip_and_transform: ArrayLike = flip_matrix @ tm @ flip_matrix
2514+
return flip_and_transform
2515+
if isinstance(trans, MapAxis):
2516+
# no flipping needed
2517+
return tm
2518+
# for a Translation, we need the transposed transformation matrix
2519+
tm_T = tm.T
2520+
assert isinstance(tm_T, np.ndarray)
2521+
return tm_T
2522+
2523+
2524+
def _get_transformation_matrix_for_datashader(
2525+
trans: Scale | Identity | Affine | MapAxis | Translation | SDSequence,
2526+
) -> ArrayLike:
2527+
"""Get the affine matrix needed to transform shapes for rendering with datashader."""
2528+
if isinstance(trans, SDSequence):
2529+
tm = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
2530+
for x in trans.transformations:
2531+
tm = tm @ _get_datashader_trans_matrix_of_single_element(x)
2532+
return tm
2533+
return _get_datashader_trans_matrix_of_single_element(trans)
2534+
2535+
24902536
def _datashader_map_aggregate_to_color(
24912537
agg: DataArray,
24922538
cmap: str | list[str] | ListedColormap,
@@ -2588,6 +2634,124 @@ def _hex_no_alpha(hex: str) -> str:
25882634
raise ValueError("Invalid hex color length: must be either '#RRGGBB' or '#RRGGBBAA'")
25892635

25902636

2637+
def _convert_shapes(
2638+
shapes: GeoDataFrame, target_shape: str, max_extent: float, warn_above_extent_fraction: float = 0.5
2639+
) -> GeoDataFrame:
2640+
"""Convert the shapes stored in a GeoDataFrame (geometry column) to the target_shape."""
2641+
# NOTE: possible follow-up: when converting equally sized shapes to hex, automatically scale resulting hexagons
2642+
# so that they are perfectly adjacent to each other
2643+
2644+
if warn_above_extent_fraction < 0.0 or warn_above_extent_fraction > 1.0:
2645+
warn_above_extent_fraction = 0.5 # set to default if the value is outside [0, 1]
2646+
warn_shape_size = False
2647+
2648+
# define individual conversion methods
2649+
def _circle_to_hexagon(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
2650+
vertices = [
2651+
(center.x + radius * math.cos(math.radians(angle)), center.y + radius * math.sin(math.radians(angle)))
2652+
for angle in range(0, 360, 60)
2653+
]
2654+
return shapely.Polygon(vertices), None
2655+
2656+
def _circle_to_square(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
2657+
vertices = [
2658+
(center.x + radius * math.cos(math.radians(angle)), center.y + radius * math.sin(math.radians(angle)))
2659+
for angle in range(45, 360, 90)
2660+
]
2661+
return shapely.Polygon(vertices), None
2662+
2663+
def _circle_to_circle(center: shapely.Point, radius: float) -> tuple[shapely.Point, float]:
2664+
return center, radius
2665+
2666+
def _polygon_to_hexagon(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]:
2667+
center, radius = _polygon_to_circle(polygon)
2668+
return _circle_to_hexagon(center, radius)
2669+
2670+
def _polygon_to_square(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]:
2671+
center, radius = _polygon_to_circle(polygon)
2672+
return _circle_to_square(center, radius)
2673+
2674+
def _polygon_to_circle(polygon: shapely.Polygon) -> tuple[shapely.Point, float]:
2675+
coords = np.array(polygon.exterior.coords)
2676+
circle_points = coords[ConvexHull(coords).vertices]
2677+
center = np.mean(circle_points, axis=0)
2678+
radius = max(float(np.linalg.norm(p - center)) for p in circle_points)
2679+
assert isinstance(radius, float) # shut up mypy
2680+
if 2 * radius > max_extent * warn_above_extent_fraction:
2681+
nonlocal warn_shape_size
2682+
warn_shape_size = True
2683+
return shapely.Point(center), radius
2684+
2685+
def _multipolygon_to_hexagon(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]:
2686+
center, radius = _multipolygon_to_circle(multipolygon)
2687+
return _circle_to_hexagon(center, radius)
2688+
2689+
def _multipolygon_to_square(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]:
2690+
center, radius = _multipolygon_to_circle(multipolygon)
2691+
return _circle_to_square(center, radius)
2692+
2693+
def _multipolygon_to_circle(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Point, float]:
2694+
coords = []
2695+
for polygon in multipolygon.geoms:
2696+
coords.extend(polygon.exterior.coords)
2697+
points = np.array(coords)
2698+
circle_points = points[ConvexHull(points).vertices]
2699+
center = np.mean(circle_points, axis=0)
2700+
radius = max(float(np.linalg.norm(p - center)) for p in circle_points)
2701+
assert isinstance(radius, float) # shut up mypy
2702+
if 2 * radius > max_extent * warn_above_extent_fraction:
2703+
nonlocal warn_shape_size
2704+
warn_shape_size = True
2705+
return shapely.Point(center), radius
2706+
2707+
# define dict with all conversion methods
2708+
if target_shape == "circle":
2709+
conversion_methods = {
2710+
"Point": _circle_to_circle,
2711+
"Polygon": _polygon_to_circle,
2712+
"Multipolygon": _multipolygon_to_circle,
2713+
}
2714+
pass
2715+
elif target_shape == "hex":
2716+
conversion_methods = {
2717+
"Point": _circle_to_hexagon,
2718+
"Polygon": _polygon_to_hexagon,
2719+
"Multipolygon": _multipolygon_to_hexagon,
2720+
}
2721+
else:
2722+
conversion_methods = {
2723+
"Point": _circle_to_square,
2724+
"Polygon": _polygon_to_square,
2725+
"Multipolygon": _multipolygon_to_square,
2726+
}
2727+
2728+
# convert every shape
2729+
for i in range(shapes.shape[0]):
2730+
if shapes["geometry"][i].type == "Point":
2731+
converted, radius = conversion_methods["Point"](shapes["geometry"][i], shapes["radius"][i]) # type: ignore
2732+
elif shapes["geometry"][i].type == "Polygon":
2733+
converted, radius = conversion_methods["Polygon"](shapes["geometry"][i]) # type: ignore
2734+
elif shapes["geometry"][i].type == "MultiPolygon":
2735+
converted, radius = conversion_methods["Multipolygon"](shapes["geometry"][i]) # type: ignore
2736+
else:
2737+
error_type = shapes["geometry"][i].type
2738+
raise ValueError(f"Converting shape {error_type} to {target_shape} is not supported.")
2739+
shapes["geometry"][i] = converted
2740+
if radius is not None:
2741+
if "radius" not in shapes.columns:
2742+
shapes["radius"] = np.nan
2743+
shapes["radius"][i] = radius
2744+
2745+
if warn_shape_size:
2746+
logger.info(
2747+
f"When converting the shapes, the size of at least one target shape extends "
2748+
f"{warn_above_extent_fraction * 100}% of the original total bound of the shapes. The conversion"
2749+
" might not give satisfying results in this scenario."
2750+
)
2751+
2752+
return shapes
2753+
2754+
25912755
def _convert_alpha_to_datashader_range(alpha: float) -> float:
25922756
"""Convert alpha from the range [0, 1] to the range [0, 255] used in datashader."""
25932757
# prevent a value of 255, bc that led to fully colored test plots instead of just colored points/shapes
13.7 KB
Loading
9.98 KB
Loading
10.5 KB
Loading
10.5 KB
Loading
9.98 KB
Loading
13.9 KB
Loading

0 commit comments

Comments
 (0)