Skip to content

Commit

Permalink
Refactor of labels logic (#336)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
timtreis and pre-commit-ci[bot] authored Aug 29, 2024
1 parent c6973bd commit 7147a64
Showing 1 changed file with 27 additions and 94 deletions.
121 changes: 27 additions & 94 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,11 +650,9 @@ def _render_images(
stacked = np.stack([layers[c] for c in channels], axis=-1)
else: # -> use given cmap for each channel
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
# Apply cmaps to each channel, add up and normalize to [0, 1]
stacked = (
np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) / n_channels
)
# Remove alpha channel so we can overwrite it from render_params.alpha
stacked = stacked[:, :, :3]
logger.warning(
"One cmap was given for multiple channels and is now used for each channel. "
Expand All @@ -676,11 +674,7 @@ def _render_images(
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))

channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]

# Apply cmaps to each channel and add up
colored = np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0)

# Remove alpha channel so we can overwrite it from render_params.alpha
colored = colored[:, :, :3]

_ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder)
Expand All @@ -691,24 +685,16 @@ def _render_images(
raise ValueError("If 'palette' is provided, its length must match the number of channels.")

channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in palette if isinstance(c, str)]

# Apply cmaps to each channel and add up
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)

# Remove alpha channel so we can overwrite it from render_params.alpha
colored = colored[:, :, :3]

_ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder)

elif palette is None and got_multiple_cmaps:
channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr]

# Apply cmaps to each channel, add up and normalize to [0, 1]
colored = (
np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) / n_channels
)

# Remove alpha channel so we can overwrite it from render_params.alpha
colored = colored[:, :, :3]

_ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder)
Expand Down Expand Up @@ -794,119 +780,66 @@ def _render_labels(
table_name=table_name,
)

# default case: no contour, just fill
# if fill_alpha and outline_alpha are the same, we're technically also at a no-outline situation
if render_params.outline_alpha == 0.0 or render_params.outline_alpha == render_params.fill_alpha:
labels_infill = _map_color_seg(
def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) -> matplotlib.image.AxesImage:
labels = _map_color_seg(
seg=label.values,
cell_id=instance_id,
color_vector=color_vector,
color_source_vector=color_source_vector,
cmap_params=render_params.cmap_params,
seg_erosionpx=None,
seg_boundaries=False,
seg_erosionpx=seg_erosionpx,
seg_boundaries=seg_boundaries,
na_color=render_params.cmap_params.na_color,
na_color_modified_by_user=render_params.cmap_params.na_color_modified_by_user,
)

_cax = ax.imshow(
labels_infill,
labels,
rasterized=True,
cmap=None if categorical else render_params.cmap_params.cmap,
norm=None if categorical else render_params.cmap_params.norm,
alpha=render_params.fill_alpha,
alpha=alpha,
origin="lower",
zorder=render_params.zorder,
)
_cax.set_transform(trans_data)
cax = ax.add_image(_cax)
return cax # noqa: RET504

# default case: no contour, just fill
# since contour_px is passed to skimage.morphology.erosion to create the contour,
# any border thickness is only within the label, not outside. Therefore, the case
# of fill_alpha == outline_alpha is equivalent to fill-only
if (render_params.fill_alpha > 0.0 and render_params.outline_alpha == 0.0) or (
render_params.fill_alpha == render_params.outline_alpha
):
cax = _draw_labels(seg_erosionpx=None, seg_boundaries=False, alpha=render_params.fill_alpha)
alpha_to_decorate_ax = render_params.fill_alpha

# outline-only case
if render_params.fill_alpha == 0.0 and render_params.outline_alpha != 0.0:
labels_contour = _map_color_seg(
seg=label.values,
cell_id=instance_id,
color_vector=color_vector,
color_source_vector=color_source_vector,
cmap_params=render_params.cmap_params,
seg_erosionpx=render_params.contour_px,
seg_boundaries=True,
na_color=render_params.cmap_params.na_color,
na_color_modified_by_user=render_params.cmap_params.na_color_modified_by_user,
elif render_params.fill_alpha == 0.0 and render_params.outline_alpha > 0.0:
cax = _draw_labels(
seg_erosionpx=render_params.contour_px, seg_boundaries=True, alpha=render_params.outline_alpha
)
_cax = ax.imshow(
labels_contour,
rasterized=True,
cmap=None if categorical else render_params.cmap_params.cmap,
norm=None if categorical else render_params.cmap_params.norm,
alpha=render_params.outline_alpha,
origin="lower",
zorder=render_params.zorder,
)
_cax.set_transform(trans_data)
cax = ax.add_image(_cax)
alpha_to_decorate_ax = render_params.outline_alpha

# pretty case: both outline and infill
if (
render_params.fill_alpha > 0.0
and render_params.outline_alpha > 0.0
and render_params.fill_alpha != render_params.outline_alpha
):
elif render_params.fill_alpha > 0.0 and render_params.outline_alpha > 0.0:
# first plot the infill ...
label_infill = _map_color_seg(
seg=label.values,
cell_id=instance_id,
color_vector=color_vector,
color_source_vector=color_source_vector,
cmap_params=render_params.cmap_params,
seg_erosionpx=None,
seg_boundaries=False,
na_color=render_params.cmap_params.na_color,
na_color_modified_by_user=render_params.cmap_params.na_color_modified_by_user,
)

_cax_infill = ax.imshow(
label_infill,
rasterized=True,
cmap=None if categorical else render_params.cmap_params.cmap,
norm=None if categorical else render_params.cmap_params.norm,
alpha=render_params.fill_alpha,
origin="lower",
zorder=render_params.zorder,
)
_cax_infill.set_transform(trans_data)
cax_infill = ax.add_image(_cax_infill)
cax_infill = _draw_labels(seg_erosionpx=None, seg_boundaries=False, alpha=render_params.fill_alpha)

# ... then overlay the contour
label_contour = _map_color_seg(
seg=label.values,
cell_id=instance_id,
color_vector=color_vector,
color_source_vector=color_source_vector,
cmap_params=render_params.cmap_params,
seg_erosionpx=render_params.contour_px,
seg_boundaries=True,
na_color=render_params.cmap_params.na_color,
na_color_modified_by_user=render_params.cmap_params.na_color_modified_by_user,
cax_contour = _draw_labels(
seg_erosionpx=render_params.contour_px, seg_boundaries=True, alpha=render_params.outline_alpha
)

_cax_contour = ax.imshow(
label_contour,
rasterized=True,
cmap=None if categorical else render_params.cmap_params.cmap,
norm=None if categorical else render_params.cmap_params.norm,
alpha=render_params.outline_alpha,
origin="lower",
zorder=render_params.zorder,
)
_cax_contour.set_transform(trans_data)
cax_contour = ax.add_image(_cax_contour)

# pass the less-transparent _cax for the legend
cax = cax_infill if render_params.fill_alpha > render_params.outline_alpha else cax_contour
alpha_to_decorate_ax = max(render_params.fill_alpha, render_params.outline_alpha)

else:
raise ValueError("Parameters 'fill_alpha' and 'outline_alpha' cannot both be 0.")

_ = _decorate_axs(
ax=ax,
cax=cax,
Expand Down

0 comments on commit 7147a64

Please sign in to comment.