99import geopandas as gpd
1010import matplotlib
1111import matplotlib .pyplot as plt
12+ import matplotlib .ticker
1213import numpy as np
1314import pandas as pd
1415import scanpy as sc
@@ -141,6 +142,15 @@ def _render_shapes(
141142 color_source_vector = color_source_vector [mask ]
142143 color_vector = color_vector [mask ]
143144
145+ # continuous case: leave NaNs as NaNs; utils maps them to na_color during draw
146+ if color_source_vector is None and not values_are_categorical :
147+ color_vector = np .asarray (color_vector , dtype = float )
148+ if np .isnan (color_vector ).any ():
149+ nan_count = int (np .isnan (color_vector ).sum ())
150+ logger .warning (
151+ f"Found { nan_count } NaN values in color data. These observations will be colored with the 'na_color'."
152+ )
153+
144154 # Using dict.fromkeys here since set returns in arbitrary order
145155 # remove the color of NaN values, else it might be assigned to a category
146156 # order of color in the palette should agree to order of occurence
@@ -195,7 +205,10 @@ def _render_shapes(
195205
196206 # Handle circles encoded as points with radius
197207 if is_point .any ():
198- scale = shapes [is_point ]["radius" ] * render_params .scale
208+ radius_values = shapes [is_point ]["radius" ]
209+ # Convert to numeric, replacing non-numeric values with NaN
210+ radius_numeric = pd .to_numeric (radius_values , errors = "coerce" )
211+ scale = radius_numeric * render_params .scale
199212 shapes .loc [is_point , "geometry" ] = _geometry [is_point ].buffer (scale .to_numpy ())
200213
201214 # apply transformations to the individual points
@@ -218,6 +231,20 @@ def _render_shapes(
218231
219232 # in case we are coloring by a column in table
220233 if col_for_color is not None and col_for_color not in transformed_element .columns :
234+ # Ensure color vector length matches the number of shapes
235+ if len (color_vector ) != len (transformed_element ):
236+ if len (color_vector ) == 1 :
237+ # If single color, broadcast to all shapes
238+ color_vector = [color_vector [0 ]] * len (transformed_element )
239+ else :
240+ # If lengths don't match, pad or truncate to match
241+ if len (color_vector ) > len (transformed_element ):
242+ color_vector = color_vector [: len (transformed_element )]
243+ else :
244+ # Pad with the last color or na_color
245+ na_color = render_params .cmap_params .na_color .get_hex_with_alpha ()
246+ color_vector = list (color_vector ) + [na_color ] * (len (transformed_element ) - len (color_vector ))
247+
221248 transformed_element [col_for_color ] = color_vector if color_source_vector is None else color_source_vector
222249 # Render shapes with datashader
223250 color_by_categorical = col_for_color is not None and color_source_vector is not None
@@ -447,12 +474,13 @@ def _render_shapes(
447474 path .vertices = trans .transform (path .vertices )
448475
449476 if not values_are_categorical :
450- # If the user passed a Normalize object with vmin/vmax we'll use those,
451- # if not we'll use the min/max of the color_vector
452- _cax .set_clim (
453- vmin = render_params .cmap_params .norm .vmin or min (color_vector ),
454- vmax = render_params .cmap_params .norm .vmax or max (color_vector ),
455- )
477+ vmin = render_params .cmap_params .norm .vmin
478+ vmax = render_params .cmap_params .norm .vmax
479+ if vmin is None :
480+ vmin = float (np .nanmin (color_vector ))
481+ if vmax is None :
482+ vmax = float (np .nanmax (color_vector ))
483+ _cax .set_clim (vmin = vmin , vmax = vmax )
456484
457485 if (
458486 len (set (color_vector )) != 1
0 commit comments