Skip to content

Commit 491dae9

Browse files
authored
Added option visium_hex with which Visium spots are rendered as continous hexes (#501)
1 parent a9cf363 commit 491dae9

File tree

6 files changed

+129
-20
lines changed

6 files changed

+129
-20
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def render_shapes(
170170
method: str | None = None,
171171
table_name: str | None = None,
172172
table_layer: str | None = None,
173-
shape: Literal["circle", "hex", "square"] | None = None,
173+
shape: Literal["circle", "hex", "visium_hex", "square"] | None = None,
174174
**kwargs: Any,
175175
) -> sd.SpatialData:
176176
"""
@@ -243,9 +243,11 @@ def render_shapes(
243243
table_layer: str | None
244244
Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in
245245
:attr:`sdata.table.X` is used for coloring.
246-
shape: Literal["circle", "hex", "square"] | None
246+
shape: Literal["circle", "hex", "visium_hex", "square"] | None
247247
If None (default), the shapes are rendered as they are. Else, if either of "circle", "hex" or "square" is
248-
specified, the shapes are converted to a circle/hexagon/square before rendering.
248+
specified, the shapes are converted to a circle/hexagon/square before rendering. If "visium_hex" is
249+
specified, the shapes are assumed to be Visium spots and the size of the hexagons is adjusted to be adjacent
250+
to each other.
249251
250252
**kwargs : Any
251253
Additional arguments for customization. This can include:

src/spatialdata_plot/pl/render_params.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ class ShapesRenderParams:
211211
zorder: int = 0
212212
table_name: str | None = None
213213
table_layer: str | None = None
214-
shape: Literal["circle", "hex", "square"] | None = None
214+
shape: Literal["circle", "hex", "visium_hex", "square"] | None = None
215215
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None
216216

217217

src/spatialdata_plot/pl/utils.py

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,7 +1802,9 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
18021802
if (norm := param_dict.get("norm")) is not None:
18031803
if element_type in {"images", "labels"} and not isinstance(norm, Normalize):
18041804
raise TypeError("Parameter 'norm' must be of type Normalize.")
1805-
if element_type in ["shapes", "points"] and not isinstance(norm, bool | Normalize):
1805+
if element_type in {"shapes", "points"} and not isinstance(
1806+
norm, bool | Normalize
1807+
):
18061808
raise TypeError("Parameter 'norm' must be a boolean or a mpl.Normalize.")
18071809

18081810
if (scale := param_dict.get("scale")) is not None:
@@ -1821,11 +1823,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
18211823
raise ValueError("Parameter 'size' must be a positive number.")
18221824

18231825
if element_type == "shapes" and (shape := param_dict.get("shape")) is not None:
1826+
valid_shapes = {"circle", "hex", "visium_hex", "square"}
18241827
if not isinstance(shape, str):
1825-
raise TypeError("Parameter 'shape' must be a String from ['circle', 'hex', 'square'] if not None.")
1826-
if shape not in ["circle", "hex", "square"]:
1828+
raise TypeError(
1829+
f"Parameter 'shape' must be a String from {valid_shapes} if not None."
1830+
)
1831+
if shape not in valid_shapes:
18271832
raise ValueError(
1828-
f"'{shape}' is not supported for 'shape', please choose from[None, 'circle', 'hex', 'square']."
1833+
f"'{shape}' is not supported for 'shape', please choose from {valid_shapes}."
18291834
)
18301835

18311836
table_name = param_dict.get("table_name")
@@ -2040,7 +2045,7 @@ def _validate_shape_render_params(
20402045
scale: float | int,
20412046
table_name: str | None,
20422047
table_layer: str | None,
2043-
shape: Literal["circle", "hex", "square"] | None,
2048+
shape: Literal["circle", "hex", "visium_hex", "square"] | None,
20442049
method: str | None,
20452050
ds_reduction: str | None,
20462051
) -> dict[str, dict[str, Any]]:
@@ -2647,9 +2652,10 @@ def _convert_shapes(
26472652

26482653
# define individual conversion methods
26492654
def _circle_to_hexagon(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
2655+
# Create hexagon with point at top (30° offset from standard orientation)
26502656
vertices = [
26512657
(center.x + radius * math.cos(math.radians(angle)), center.y + radius * math.sin(math.radians(angle)))
2652-
for angle in range(0, 360, 60)
2658+
for angle in range(30, 390, 60) # Start at 30° and go every 60°
26532659
]
26542660
return shapely.Polygon(vertices), None
26552661

@@ -2718,6 +2724,62 @@ def _multipolygon_to_circle(multipolygon: shapely.MultiPolygon) -> tuple[shapely
27182724
"Polygon": _polygon_to_hexagon,
27192725
"Multipolygon": _multipolygon_to_hexagon,
27202726
}
2727+
elif target_shape == "visium_hex":
2728+
# For visium_hex, we only support Points and warn for other geometry types
2729+
point_centers = []
2730+
non_point_count = 0
2731+
2732+
for i in range(shapes.shape[0]):
2733+
if shapes["geometry"][i].type == "Point":
2734+
point_centers.append((shapes["geometry"][i].x, shapes["geometry"][i].y))
2735+
else:
2736+
non_point_count += 1
2737+
2738+
if non_point_count > 0:
2739+
warnings.warn(
2740+
f"visium_hex conversion only supports Point geometries. Found {non_point_count} non-Point geometries "
2741+
f"that will be converted using regular hex conversion. Consider using shape='hex' for mixed geometry types.",
2742+
UserWarning,
2743+
stacklevel=2,
2744+
)
2745+
2746+
if len(point_centers) < 2:
2747+
# If we have fewer than 2 points, fall back to regular hex conversion
2748+
conversion_methods = {
2749+
"Point": _circle_to_hexagon,
2750+
"Polygon": _polygon_to_hexagon,
2751+
"Multipolygon": _multipolygon_to_hexagon,
2752+
}
2753+
else:
2754+
# Calculate typical spacing between point centers
2755+
centers_array = np.array(point_centers)
2756+
distances = []
2757+
for i in range(len(point_centers)):
2758+
for j in range(i + 1, len(point_centers)):
2759+
dist = np.linalg.norm(centers_array[i] - centers_array[j])
2760+
distances.append(dist)
2761+
2762+
# Use min dist of closest neighbors as the side length for radius calc
2763+
side_length = np.min(distances)
2764+
hex_radius = (side_length * 2.0 / math.sqrt(3)) / 2.0
2765+
2766+
# Create conversion methods
2767+
def _circle_to_visium_hex(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
2768+
return _circle_to_hexagon(center, hex_radius)
2769+
2770+
def _polygon_to_visium_hex(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]:
2771+
# Fall back to regular hex conversion for non-points
2772+
return _polygon_to_hexagon(polygon)
2773+
2774+
def _multipolygon_to_visium_hex(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]:
2775+
# Fall back to regular hex conversion for non-points
2776+
return _multipolygon_to_hexagon(multipolygon)
2777+
2778+
conversion_methods = {
2779+
"Point": _circle_to_visium_hex,
2780+
"Polygon": _polygon_to_visium_hex,
2781+
"Multipolygon": _multipolygon_to_visium_hex,
2782+
}
27212783
else:
27222784
conversion_methods = {
27232785
"Point": _circle_to_square,
12 KB
Loading

tests/conftest.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
from abc import ABC, ABCMeta
23
from collections.abc import Callable
34
from functools import wraps
@@ -525,3 +526,29 @@ def _get_sdata_with_multiple_images(share_coordinate_system: str = "all"):
525526
return sdata
526527

527528
return _get_sdata_with_multiple_images
529+
530+
531+
@pytest.fixture
532+
def sdata_hexagonal_grid_spots():
533+
"""Create a hexagonal grid of points for testing visium_hex functionality."""
534+
from shapely.geometry import Point
535+
from spatialdata.models import ShapesModel
536+
537+
spacing = 10.0
538+
n_rows, n_cols = 4, 4
539+
540+
points = []
541+
for i, j in itertools.product(range(n_rows), range(n_cols)):
542+
# Offset every second row by half the spacing for proper hexagonal packing
543+
x = j * spacing + (i % 2) * spacing / 2
544+
y = i * spacing * 0.866 # sqrt(3)/2 for proper hexagonal spacing
545+
points.append(Point(x, y))
546+
547+
# Create GeoDataFrame with radius column
548+
gdf = GeoDataFrame(geometry=points)
549+
gdf["radius"] = 2.0 # Small radius for original circles
550+
551+
# Use ShapesModel.parse() to create a properly validated GeoDataFrame
552+
shapes_gdf = ShapesModel.parse(gdf)
553+
554+
return SpatialData(shapes={"spots": shapes_gdf})

tests/pl/test_render_shapes.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_plot_can_render_circles_with_outline(self, sdata_blobs: SpatialData):
4040
sdata_blobs.pl.render_shapes(element="blobs_circles", outline_alpha=1).pl.show()
4141

4242
def test_plot_can_render_circles_with_colored_outline(self, sdata_blobs: SpatialData):
43-
sdata_blobs.pl.render_shapes(element="blobs_circles", outline_color="red").pl.show()
43+
sdata_blobs.pl.render_shapes(element="blobs_circles", outline_alpha=1, outline_color="red").pl.show()
4444

4545
def test_plot_can_render_polygons(self, sdata_blobs: SpatialData):
4646
sdata_blobs.pl.render_shapes(element="blobs_polygons").pl.show()
@@ -49,13 +49,17 @@ def test_plot_can_render_polygons_with_outline(self, sdata_blobs: SpatialData):
4949
sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_alpha=1).pl.show()
5050

5151
def test_plot_can_render_polygons_with_str_colored_outline(self, sdata_blobs: SpatialData):
52-
sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_color="red").pl.show()
52+
sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_alpha=1, outline_color="red").pl.show()
5353

5454
def test_plot_can_render_polygons_with_rgb_colored_outline(self, sdata_blobs: SpatialData):
55-
sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_color=(0.0, 0.0, 1.0, 1.0)).pl.show()
55+
sdata_blobs.pl.render_shapes(
56+
element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 0.0, 1.0, 1.0)
57+
).pl.show()
5658

5759
def test_plot_can_render_polygons_with_rgba_colored_outline(self, sdata_blobs: SpatialData):
58-
sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_color=(0.0, 1.0, 0.0, 1.0)).pl.show()
60+
sdata_blobs.pl.render_shapes(
61+
element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 1.0, 0.0, 1.0)
62+
).pl.show()
5963

6064
def test_plot_can_render_empty_geometry(self, sdata_blobs: SpatialData):
6165
sdata_blobs.shapes["blobs_circles"].at[0, "geometry"] = gpd.points_from_xy([None], [None])[0]
@@ -65,7 +69,7 @@ def test_plot_can_render_circles_with_default_outline_width(self, sdata_blobs: S
6569
sdata_blobs.pl.render_shapes(element="blobs_circles", outline_alpha=1).pl.show()
6670

6771
def test_plot_can_render_circles_with_specified_outline_width(self, sdata_blobs: SpatialData):
68-
sdata_blobs.pl.render_shapes(element="blobs_circles", outline_width=3.0).pl.show()
72+
sdata_blobs.pl.render_shapes(element="blobs_circles", outline_alpha=1, outline_width=3.0).pl.show()
6973

7074
def test_plot_can_render_multipolygons(self):
7175
def _make_multi():
@@ -402,19 +406,23 @@ def test_plot_datashader_can_render_with_diff_alpha_outline(self, sdata_blobs: S
402406
sdata_blobs.pl.render_shapes(method="datashader", element="blobs_polygons", outline_alpha=0.5).pl.show()
403407

404408
def test_plot_datashader_can_render_with_diff_width_outline(self, sdata_blobs: SpatialData):
405-
sdata_blobs.pl.render_shapes(method="datashader", element="blobs_polygons", outline_width=5.0).pl.show()
409+
sdata_blobs.pl.render_shapes(
410+
method="datashader", element="blobs_polygons", outline_alpha=1.0, outline_width=5.0
411+
).pl.show()
406412

407413
def test_plot_datashader_can_render_with_colored_outline(self, sdata_blobs: SpatialData):
408-
sdata_blobs.pl.render_shapes(method="datashader", element="blobs_polygons", outline_color="red").pl.show()
414+
sdata_blobs.pl.render_shapes(
415+
method="datashader", element="blobs_polygons", outline_alpha=1, outline_color="red"
416+
).pl.show()
409417

410418
def test_plot_datashader_can_render_with_rgb_colored_outline(self, sdata_blobs: SpatialData):
411419
sdata_blobs.pl.render_shapes(
412-
method="datashader", element="blobs_polygons", outline_color=(0.0, 0.0, 1.0)
420+
method="datashader", element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 0.0, 1.0)
413421
).pl.show()
414422

415423
def test_plot_datashader_can_render_with_rgba_colored_outline(self, sdata_blobs: SpatialData):
416424
sdata_blobs.pl.render_shapes(
417-
method="datashader", element="blobs_polygons", outline_color=(0.0, 1.0, 0.0, 1.0)
425+
method="datashader", element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 1.0, 0.0, 1.0)
418426
).pl.show()
419427

420428
def test_plot_can_set_clims_clip(self, sdata_blobs: SpatialData):
@@ -593,6 +601,12 @@ def test_plot_can_render_multipolygons_to_square(self, sdata_blobs: SpatialData)
593601
def test_plot_can_render_multipolygons_to_circle(self, sdata_blobs: SpatialData):
594602
sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="circle").pl.show()
595603

604+
def test_plot_visium_hex_hexagonal_grid(self, sdata_hexagonal_grid_spots: SpatialData):
605+
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")
606+
607+
sdata_hexagonal_grid_spots.pl.render_shapes(element="spots", shape="circle").pl.show(ax=axs[0])
608+
sdata_hexagonal_grid_spots.pl.render_shapes(element="spots", shape="visium_hex").pl.show(ax=axs[1])
609+
596610
def test_plot_datashader_can_render_circles_to_hex(self, sdata_blobs: SpatialData):
597611
sdata_blobs.pl.render_shapes(element="blobs_circles", shape="hex", method="datashader").pl.show()
598612

@@ -616,6 +630,7 @@ def test_plot_datashader_can_render_multipolygons_to_square(self, sdata_blobs: S
616630

617631
def test_plot_datashader_can_render_multipolygons_to_circle(self, sdata_blobs: SpatialData):
618632
sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="circle", method="datashader").pl.show()
633+
619634
def test_plot_can_render_shapes_with_double_outline(self, sdata_blobs: SpatialData):
620635
sdata_blobs.pl.render_shapes("blobs_circles", outline_width=(10.0, 5.0)).pl.show()
621636

@@ -631,7 +646,10 @@ def test_plot_can_render_double_outline_with_diff_alpha(self, sdata_blobs: Spati
631646

632647
def test_plot_outline_alpha_takes_precedence(self, sdata_blobs: SpatialData):
633648
sdata_blobs.pl.render_shapes(
634-
element="blobs_circles", outline_color=("#ff660033", "#33aa0066"), outline_width=(20, 10), outline_alpha=1.0
649+
element="blobs_circles",
650+
outline_color=("#ff660033", "#33aa0066"),
651+
outline_width=(20, 10),
652+
outline_alpha=(1.0, 1.0),
635653
).pl.show()
636654

637655
def test_plot_datashader_can_render_shapes_with_double_outline(self, sdata_blobs: SpatialData):

0 commit comments

Comments
 (0)