From 7147a6441ae41556f1419ddd0cc38ef60f6d0f74 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Thu, 29 Aug 2024 12:50:37 -0400 Subject: [PATCH] Refactor of labels logic (#336) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/spatialdata_plot/pl/render.py | 121 +++++++----------------------- 1 file changed, 27 insertions(+), 94 deletions(-) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 6be76107..d96464c8 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -650,11 +650,9 @@ def _render_images( stacked = np.stack([layers[c] for c in channels], axis=-1) else: # -> use given cmap for each channel channel_cmaps = [render_params.cmap_params.cmap] * n_channels - # Apply cmaps to each channel, add up and normalize to [0, 1] stacked = ( np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) / n_channels ) - # Remove alpha channel so we can overwrite it from render_params.alpha stacked = stacked[:, :, :3] logger.warning( "One cmap was given for multiple channels and is now used for each channel. " @@ -676,11 +674,7 @@ def _render_images( seed_colors = _get_colors_for_categorical_obs(list(range(n_channels))) channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] - - # Apply cmaps to each channel and add up colored = np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) - - # Remove alpha channel so we can overwrite it from render_params.alpha colored = colored[:, :, :3] _ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder) @@ -691,24 +685,16 @@ def _render_images( raise ValueError("If 'palette' is provided, its length must match the number of channels.") channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in palette if isinstance(c, str)] - - # Apply cmaps to each channel and add up colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0) - - # Remove alpha channel so we can overwrite it from render_params.alpha colored = colored[:, :, :3] _ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder) elif palette is None and got_multiple_cmaps: channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr] - - # Apply cmaps to each channel, add up and normalize to [0, 1] colored = ( np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) / n_channels ) - - # Remove alpha channel so we can overwrite it from render_params.alpha colored = colored[:, :, :3] _ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder) @@ -794,119 +780,66 @@ def _render_labels( table_name=table_name, ) - # default case: no contour, just fill - # if fill_alpha and outline_alpha are the same, we're technically also at a no-outline situation - if render_params.outline_alpha == 0.0 or render_params.outline_alpha == render_params.fill_alpha: - labels_infill = _map_color_seg( + def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) -> matplotlib.image.AxesImage: + labels = _map_color_seg( seg=label.values, cell_id=instance_id, color_vector=color_vector, color_source_vector=color_source_vector, cmap_params=render_params.cmap_params, - seg_erosionpx=None, - seg_boundaries=False, + seg_erosionpx=seg_erosionpx, + seg_boundaries=seg_boundaries, na_color=render_params.cmap_params.na_color, na_color_modified_by_user=render_params.cmap_params.na_color_modified_by_user, ) + _cax = ax.imshow( - labels_infill, + labels, rasterized=True, cmap=None if categorical else render_params.cmap_params.cmap, norm=None if categorical else render_params.cmap_params.norm, - alpha=render_params.fill_alpha, + alpha=alpha, origin="lower", zorder=render_params.zorder, ) _cax.set_transform(trans_data) cax = ax.add_image(_cax) + return cax # noqa: RET504 + + # default case: no contour, just fill + # since contour_px is passed to skimage.morphology.erosion to create the contour, + # any border thickness is only within the label, not outside. Therefore, the case + # of fill_alpha == outline_alpha is equivalent to fill-only + if (render_params.fill_alpha > 0.0 and render_params.outline_alpha == 0.0) or ( + render_params.fill_alpha == render_params.outline_alpha + ): + cax = _draw_labels(seg_erosionpx=None, seg_boundaries=False, alpha=render_params.fill_alpha) alpha_to_decorate_ax = render_params.fill_alpha # outline-only case - if render_params.fill_alpha == 0.0 and render_params.outline_alpha != 0.0: - labels_contour = _map_color_seg( - seg=label.values, - cell_id=instance_id, - color_vector=color_vector, - color_source_vector=color_source_vector, - cmap_params=render_params.cmap_params, - seg_erosionpx=render_params.contour_px, - seg_boundaries=True, - na_color=render_params.cmap_params.na_color, - na_color_modified_by_user=render_params.cmap_params.na_color_modified_by_user, + elif render_params.fill_alpha == 0.0 and render_params.outline_alpha > 0.0: + cax = _draw_labels( + seg_erosionpx=render_params.contour_px, seg_boundaries=True, alpha=render_params.outline_alpha ) - _cax = ax.imshow( - labels_contour, - rasterized=True, - cmap=None if categorical else render_params.cmap_params.cmap, - norm=None if categorical else render_params.cmap_params.norm, - alpha=render_params.outline_alpha, - origin="lower", - zorder=render_params.zorder, - ) - _cax.set_transform(trans_data) - cax = ax.add_image(_cax) alpha_to_decorate_ax = render_params.outline_alpha # pretty case: both outline and infill - if ( - render_params.fill_alpha > 0.0 - and render_params.outline_alpha > 0.0 - and render_params.fill_alpha != render_params.outline_alpha - ): + elif render_params.fill_alpha > 0.0 and render_params.outline_alpha > 0.0: # first plot the infill ... - label_infill = _map_color_seg( - seg=label.values, - cell_id=instance_id, - color_vector=color_vector, - color_source_vector=color_source_vector, - cmap_params=render_params.cmap_params, - seg_erosionpx=None, - seg_boundaries=False, - na_color=render_params.cmap_params.na_color, - na_color_modified_by_user=render_params.cmap_params.na_color_modified_by_user, - ) - - _cax_infill = ax.imshow( - label_infill, - rasterized=True, - cmap=None if categorical else render_params.cmap_params.cmap, - norm=None if categorical else render_params.cmap_params.norm, - alpha=render_params.fill_alpha, - origin="lower", - zorder=render_params.zorder, - ) - _cax_infill.set_transform(trans_data) - cax_infill = ax.add_image(_cax_infill) + cax_infill = _draw_labels(seg_erosionpx=None, seg_boundaries=False, alpha=render_params.fill_alpha) # ... then overlay the contour - label_contour = _map_color_seg( - seg=label.values, - cell_id=instance_id, - color_vector=color_vector, - color_source_vector=color_source_vector, - cmap_params=render_params.cmap_params, - seg_erosionpx=render_params.contour_px, - seg_boundaries=True, - na_color=render_params.cmap_params.na_color, - na_color_modified_by_user=render_params.cmap_params.na_color_modified_by_user, + cax_contour = _draw_labels( + seg_erosionpx=render_params.contour_px, seg_boundaries=True, alpha=render_params.outline_alpha ) - _cax_contour = ax.imshow( - label_contour, - rasterized=True, - cmap=None if categorical else render_params.cmap_params.cmap, - norm=None if categorical else render_params.cmap_params.norm, - alpha=render_params.outline_alpha, - origin="lower", - zorder=render_params.zorder, - ) - _cax_contour.set_transform(trans_data) - cax_contour = ax.add_image(_cax_contour) - # pass the less-transparent _cax for the legend cax = cax_infill if render_params.fill_alpha > render_params.outline_alpha else cax_contour alpha_to_decorate_ax = max(render_params.fill_alpha, render_params.outline_alpha) + else: + raise ValueError("Parameters 'fill_alpha' and 'outline_alpha' cannot both be 0.") + _ = _decorate_axs( ax=ax, cax=cax,