diff --git a/CHANGELOG.md b/CHANGELOG.md
index 139a647..473cf0b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning][].
 
 ### Added
 
-- the datashader reduction method is now a user parameter (#309)
+- The user can now specify `datashader_reduction` to control the reduction method datashader uses when coloring by continuous values (#309)
 
 ### Changed
 
diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py
index 0dbe7a9..2b502b0 100644
--- a/src/spatialdata_plot/pl/basic.py
+++ b/src/spatialdata_plot/pl/basic.py
@@ -227,10 +227,15 @@ def render_shapes(
             Name of the table containing the color(s) columns. If one name is given than the table is used for each
             spatial element to be plotted if the table annotates it. If you want to use different tables for particular
             elements, as specified under element.
+
         **kwargs : Any
-            Additional arguments to be passed to cmap and norm. And:
-            datashader_reduction: Literal["sum", "mean", "any", "count", "m2", "mode", "std", "var", "max", "min"]|None
-            reduction method to use when using datashader and coloring by continuous values. Default: ds.sum()
+            Additional arguments for customization. This can include:
+
+            datashader_reduction : Literal[
+                "sum", "mean", "any", "count", "m2", "mode", "std", "var", "max", "min"
+            ], default: "sum"
+                Reduction method for datashader when coloring by continuous values.
+
 
         Notes
         -----
@@ -259,13 +264,13 @@ def render_shapes(
             scale=scale,
             table_name=table_name,
             method=method,
+            ds_reduction=kwargs.get("datashader_reduction", None),
         )
 
         sdata = self._copy()
         sdata = _verify_plotting_tree(sdata)
         n_steps = len(sdata.plotting_tree.keys())
         outline_params = _set_outline(outline_alpha > 0, outline_width, outline_color)
-
         for element, param_values in params_dict.items():
             cmap_params = _prepare_cmap_norm(
                 cmap=cmap,
@@ -287,8 +292,8 @@ def render_shapes(
                 transfunc=kwargs.get("transfunc", None),
                 table_name=param_values["table_name"],
                 zorder=n_steps,
-                method=method,
-                reduction=kwargs.get("datashader_reduction", None),
+                method=param_values["method"],
+                ds_reduction=param_values["ds_reduction"],
             )
             n_steps += 1
 
@@ -358,10 +363,14 @@ def render_points(
             Name of the table containing the color(s) columns. If one name is given than the table is used for each
             spatial element to be plotted if the table annotates it. If you want to use different tables for particular
             elements, as specified under element.
-        kwargs
-            Additional arguments to be passed to cmap and norm. And:
-            datashader_reduction: Literal["sum", "mean", "any", "count", "m2", "mode", "std", "var", "max", "min"]|None
-            reduction method to use when using datashader and coloring by continuous values. Default: ds.sum()
+
+        **kwargs : Any
+            Additional arguments for customization. This can include:
+
+            datashader_reduction : Literal[
+                "sum", "mean", "any", "count", "m2", "mode", "std", "var", "max", "min"
+            ], default: "sum"
+                Reduction method for datashader when coloring by continuous values.

        Returns
        -------
@@ -412,7 +421,7 @@ def render_points(
                 table_name=param_values["table_name"],
                 zorder=n_steps,
                 method=method,
-                reduction=kwargs.get("datashader_reduction", None),
+                ds_reduction=kwargs.get("datashader_reduction", None),
             )
             n_steps += 1
 
diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py
index 6b8fae1..27a467a 100644
--- a/src/spatialdata_plot/pl/render.py
+++ b/src/spatialdata_plot/pl/render.py
@@ -158,13 +158,18 @@ def _render_shapes(
     # Determine which method to use for rendering
     method = render_params.method
+
     if method is None:
         method = "datashader" if len(shapes) > 10000 else "matplotlib"
-    elif method not in ["matplotlib", "datashader"]:
-        raise ValueError("Method must be either 'matplotlib' or 'datashader'.")
+
     if method != "matplotlib":
         # we only notify the user when we switched away from matplotlib
-        logger.info(f"Using '{method}' as plotting backend.")
+        logger.info(
+            f"Using '{method}' backend with '{render_params.ds_reduction}' as reduction"
+            " method to speed up plotting. Depending on the reduction method, the value"
+            " range of the plot might change. Set method to 'matplotlib' to disable"
+            " this behaviour."
+        )
 
     if method == "datashader":
         trans = mtransforms.Affine2D(matrix=affine_trans) + ax.transData
@@ -197,13 +202,13 @@ def _render_shapes(
                 sdata_filt.shapes[element], geometry="geometry", agg=ds.by(col_for_color, ds.count())
             )
         else:
-            reduction_name = render_params.reduction if render_params.reduction is not None else "sum"
+            reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "sum"
             logger.info(
                 f'Using the datashader reduction "{reduction_name}". "max" will give an output very close '
                 "to the matplotlib result."
             )
             agg = _datashader_aggregate_with_function(
-                render_params.reduction, cvs, sdata_filt.shapes[element], col_for_color, "shapes"
+                render_params.ds_reduction, cvs, sdata_filt.shapes[element], col_for_color, "shapes"
             )
         # save min and max values for drawing the colorbar
         aggregate_with_reduction = (agg.min(), agg.max())
@@ -437,13 +442,13 @@ def _render_points(
         if color_by_categorical:
             agg = cvs.points(sdata_filt.points[element], "x", "y", agg=ds.by(col_for_color, ds.count()))
         else:
-            reduction_name = render_params.reduction if render_params.reduction is not None else "sum"
+            reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "sum"
             logger.info(
                 f'Using the datashader reduction "{reduction_name}". "max" will give an output very close '
                 "to the matplotlib result."
             )
             agg = _datashader_aggregate_with_function(
-                render_params.reduction, cvs, sdata_filt.points[element], col_for_color, "points"
+                render_params.ds_reduction, cvs, sdata_filt.points[element], col_for_color, "points"
             )
         # save min and max values for drawing the colorbar
         aggregate_with_reduction = (agg.min(), agg.max())
@@ -475,7 +480,7 @@
                 how="linear",
             )
         else:
-            spread_how = _datshader_get_how_kw_for_spread(render_params.reduction)
+            spread_how = _datshader_get_how_kw_for_spread(render_params.ds_reduction)
             agg = ds.tf.spread(agg, px=px, how=spread_how)
             aggregate_with_reduction = (agg.min(), agg.max())
         ds_result = ds.tf.shade(
diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py
index 55c0854..b6fb1ed 100644
--- a/src/spatialdata_plot/pl/render_params.py
+++ b/src/spatialdata_plot/pl/render_params.py
@@ -89,7 +89,7 @@ class ShapesRenderParams:
     method: str | None = None
     zorder: int = 0
     table_name: str | None = None
-    reduction: Literal["sum", "mean", "any", "count", "m2", "mode", "std", "var", "max", "min"] | None = None
+    ds_reduction: Literal["sum", "mean", "any", "count", "m2", "mode", "std", "var", "max", "min"] | None = None
 
 
 @dataclass
@@ -108,7 +108,7 @@ class PointsRenderParams:
     method: str | None = None
     zorder: int = 0
     table_name: str | None = None
-    reduction: Literal["sum", "mean", "any", "count", "m2", "mode", "std", "var"] | None = None
+    ds_reduction: Literal["sum", "mean", "any", "count", "m2", "mode", "std", "var"] | None = None
 
 
 @dataclass
diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py
index 98a1e8f..81ac62e 100644
--- a/src/spatialdata_plot/pl/utils.py
+++ b/src/spatialdata_plot/pl/utils.py
@@ -1507,7 +1507,6 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
             "Parameter 'element' must be a string. If you want to display more elements, pass `element` "
            "as `None` or chain pl.render(...).pl.render(...).pl.show()"
         )
-
     if element_type == "images":
         param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].images.keys())
     elif element_type == "labels":
@@ -1638,9 +1637,20 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
     if param_dict.get("table_name") and not isinstance(param_dict["table_name"], str):
         raise TypeError("Parameter 'table_name' must be a string .")
 
-    if param_dict.get("method") not in ["matplotlib", "datashader", None]:
+    # assigned separately because the walrus form below would bind True/False to 'method':
+    # method := param_dict.get("method") not in ["matplotlib", "datashader", None]
+    method = param_dict.get("method")
+    if method not in ["matplotlib", "datashader", None]:
         raise ValueError("If specified, parameter 'method' must be either 'matplotlib' or 'datashader'.")
 
+    valid_ds_reduction_methods = ["sum", "mean", "any", "count", "m2", "mode", "std", "var", "max", "min"]
+    ds_reduction = param_dict.get("ds_reduction")
+    if ds_reduction and (ds_reduction not in valid_ds_reduction_methods):
+        raise ValueError(f"Parameter 'ds_reduction' must be one of the following: {valid_ds_reduction_methods}.")
+
+    if method == "datashader" and ds_reduction is None:
+        param_dict["ds_reduction"] = "sum"
+
     return param_dict
 
 
@@ -1778,6 +1788,7 @@ def _validate_shape_render_params(
     scale: float | int,
     table_name: str | None,
     method: str | None,
+    ds_reduction: str | None,
 ) -> dict[str, dict[str, Any]]:
     param_dict: dict[str, Any] = {
         "sdata": sdata,
@@ -1795,6 +1806,7 @@
         "scale": scale,
         "table_name": table_name,
         "method": method,
+        "ds_reduction": ds_reduction,
     }
     param_dict = _type_check_params(param_dict, "shapes")
 
@@ -1828,6 +1840,7 @@
         element_params[el]["palette"] = param_dict["palette"] if param_dict["col_for_color"] is not None else None
         element_params[el]["groups"] = param_dict["groups"] if param_dict["col_for_color"] is not None else None
         element_params[el]["method"] = param_dict["method"]
+        element_params[el]["ds_reduction"] = param_dict["ds_reduction"]
 
     return element_params
 
 
@@ -2041,9 +2054,10 @@ def _create_image_from_datashader_result(
     ds_result: ds.transfer_functions.Image, factor: float, ax: Axes
 ) -> tuple[MaskedArray[np.float64, Any], matplotlib.transforms.CompositeGenericTransform]:
     # create SpatialImage from datashader output to get it back to original size
-    rgba_image = np.transpose(ds_result.to_numpy().base, (2, 0, 1))
+    rgba_image_data = ds_result.to_numpy().base
+    rgba_image_data = np.transpose(rgba_image_data, (2, 0, 1))
     rgba_image = Image2DModel.parse(
-        rgba_image,
+        rgba_image_data,
         dims=("c", "y", "x"),
         transformations={"global": Scale([1, factor, factor], ("c", "y", "x"))},
     )
@@ -2085,40 +2099,47 @@ def _datashader_aggregate_with_function(
         reduction = "sum"
 
     reduction_function_map = {
-        "sum": ds.sum(column=col_for_color),
-        "mean": ds.mean(column=col_for_color),
-        "any": ds.any(column=col_for_color),
-        "count": ds.count(column=col_for_color),
-        "m2": ds.reductions.m2(column=col_for_color),
-        "mode": ds.reductions.mode(column=col_for_color),
-        "std": ds.std(column=col_for_color),
-        "var": ds.var(column=col_for_color),
-        "max": ds.max(column=col_for_color),
-        "min": ds.min(column=col_for_color),
+        "sum": ds.sum,
+        "mean": ds.mean,
+        "any": ds.any,
+        "count": ds.count,
+        "m2": ds.reductions.m2,
+        "mode": ds.reductions.mode,
+        "std": ds.std,
+        "var": ds.var,
+        "max": ds.max,
+        "min": ds.min,
     }
 
-    if reduction not in reduction_function_map:
-        raise ValueError(
-            f"Reduction {reduction} is not supported, please use one of the following: sum, mean, any, count, m2, mode"
-            ", std, var, max, min."
-        )
-    if element_type not in ["points", "shapes"]:
+    try:
+        reduction_function = reduction_function_map[reduction](column=col_for_color)
+    except KeyError as e:
         raise ValueError(
-            f"utils._datashader_aggregate_with_function() should only be called with points or shapes, not with"
-            f" {element_type}."
-        )
+            f"Reduction '{reduction}' is not supported. Please use one of: {', '.join(reduction_function_map.keys())}."
+        ) from e
+
+    element_function_map = {
+        "points": cvs.points,
+        "shapes": cvs.polygons,
+    }
+
+    try:
+        element_function = element_function_map[element_type]
+    except KeyError as e:
+        raise ValueError(f"Element type '{element_type}' is not supported. Use 'points' or 'shapes'.") from e
 
     if element_type == "points":
-        return cvs.points(spatial_element, "x", "y", agg=reduction_function_map[reduction])
-    return cvs.polygons(spatial_element, geometry="geometry", agg=reduction_function_map[reduction])
+        return element_function(spatial_element, "x", "y", agg=reduction_function)
+
+    # element_type is "shapes"
+    return element_function(spatial_element, geometry="geometry", agg=reduction_function)
 
 
 def _datshader_get_how_kw_for_spread(
     reduction: Literal["sum", "mean", "any", "count", "m2", "mode", "std", "var", "max", "min"] | None
 ) -> str:
     # Get the best input for the how argument of ds.tf.spread(), needed for numerical values
-    if reduction is None:
-        reduction = "sum"
+    reduction = reduction or "sum"
 
     reduction_to_how_map = {
         "sum": "add",
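
For reference, a minimal usage sketch of the new `datashader_reduction` keyword, assuming a SpatialData object whose shapes element is annotated by a table; the Zarr path, element name, and color column below are placeholders:

```python
import spatialdata as sd
import spatialdata_plot  # noqa: F401  # registers the .pl accessor

sdata = sd.read_zarr("data.zarr")  # hypothetical dataset

# Force the datashader backend and override the default "sum" reduction with "mean".
# Omitting `datashader_reduction` while method="datashader" falls back to "sum".
sdata.pl.render_shapes(
    element="cells",
    color="gene_expression",
    method="datashader",
    datashader_reduction="mean",
).pl.show()
```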