Skip to content

Commit c9bae23

Browse files
Fix NaN handling and transparency in shape rendering (#503)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 491dae9 commit c9bae23

File tree

6 files changed

+393
-308
lines changed

6 files changed

+393
-308
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import geopandas as gpd
1010
import matplotlib
1111
import matplotlib.pyplot as plt
12+
import matplotlib.ticker
1213
import numpy as np
1314
import pandas as pd
1415
import scanpy as sc
@@ -141,6 +142,15 @@ def _render_shapes(
141142
color_source_vector = color_source_vector[mask]
142143
color_vector = color_vector[mask]
143144

145+
# continuous case: leave NaNs as NaNs; utils maps them to na_color during draw
146+
if color_source_vector is None and not values_are_categorical:
147+
color_vector = np.asarray(color_vector, dtype=float)
148+
if np.isnan(color_vector).any():
149+
nan_count = int(np.isnan(color_vector).sum())
150+
logger.warning(
151+
f"Found {nan_count} NaN values in color data. These observations will be colored with the 'na_color'."
152+
)
153+
144154
# Using dict.fromkeys here since set returns in arbitrary order
145155
# remove the color of NaN values, else it might be assigned to a category
146156
# order of color in the palette should agree to order of occurence
@@ -195,7 +205,10 @@ def _render_shapes(
195205

196206
# Handle circles encoded as points with radius
197207
if is_point.any():
198-
scale = shapes[is_point]["radius"] * render_params.scale
208+
radius_values = shapes[is_point]["radius"]
209+
# Convert to numeric, replacing non-numeric values with NaN
210+
radius_numeric = pd.to_numeric(radius_values, errors="coerce")
211+
scale = radius_numeric * render_params.scale
199212
shapes.loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy())
200213

201214
# apply transformations to the individual points
@@ -218,6 +231,20 @@ def _render_shapes(
218231

219232
# in case we are coloring by a column in table
220233
if col_for_color is not None and col_for_color not in transformed_element.columns:
234+
# Ensure color vector length matches the number of shapes
235+
if len(color_vector) != len(transformed_element):
236+
if len(color_vector) == 1:
237+
# If single color, broadcast to all shapes
238+
color_vector = [color_vector[0]] * len(transformed_element)
239+
else:
240+
# If lengths don't match, pad or truncate to match
241+
if len(color_vector) > len(transformed_element):
242+
color_vector = color_vector[: len(transformed_element)]
243+
else:
244+
# Pad with the last color or na_color
245+
na_color = render_params.cmap_params.na_color.get_hex_with_alpha()
246+
color_vector = list(color_vector) + [na_color] * (len(transformed_element) - len(color_vector))
247+
221248
transformed_element[col_for_color] = color_vector if color_source_vector is None else color_source_vector
222249
# Render shapes with datashader
223250
color_by_categorical = col_for_color is not None and color_source_vector is not None
@@ -447,12 +474,13 @@ def _render_shapes(
447474
path.vertices = trans.transform(path.vertices)
448475

449476
if not values_are_categorical:
450-
# If the user passed a Normalize object with vmin/vmax we'll use those,
451-
# if not we'll use the min/max of the color_vector
452-
_cax.set_clim(
453-
vmin=render_params.cmap_params.norm.vmin or min(color_vector),
454-
vmax=render_params.cmap_params.norm.vmax or max(color_vector),
455-
)
477+
vmin = render_params.cmap_params.norm.vmin
478+
vmax = render_params.cmap_params.norm.vmax
479+
if vmin is None:
480+
vmin = float(np.nanmin(color_vector))
481+
if vmax is None:
482+
vmax = float(np.nanmax(color_vector))
483+
_cax.set_clim(vmin=vmin, vmax=vmax)
456484

457485
if (
458486
len(set(color_vector)) != 1

0 commit comments

Comments
 (0)