Skip to content

Commit

Permalink
minor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
timtreis committed Sep 14, 2024
1 parent a2b66e1 commit 8b0b24d
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 49 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning][].

### Added

- the datashader reduction method is now a user parameter (#309)
- The user can now specify `datashader_reduction` to control the rendering behaviour (#309)

### Changed

Expand Down
31 changes: 20 additions & 11 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,15 @@ def render_shapes(
Name of the table containing the color(s) columns. If one name is given than the table is used for each
spatial element to be plotted if the table annotates it. If you want to use different tables for particular
elements, as specified under element.
**kwargs : Any
Additional arguments to be passed to cmap and norm. And:
datashader_reduction: Literal["sum", "mean", "any", "count", "m2", "mode", "std", "var", "max", "min"]|None
reduction method to use when using datashader and coloring by continuous values. Default: ds.sum()
Additional arguments for customization. This can include:
datashader_reduction : Literal[
"sum", "mean", "any", "count", "m2", "mode", "std", "var", "max", "min"
], default: "sum"
Reduction method for datashader when coloring by continuous values. Defaults to 'sum'.
Notes
-----
Expand Down Expand Up @@ -259,13 +264,13 @@ def render_shapes(
scale=scale,
table_name=table_name,
method=method,
ds_reduction=kwargs.get("datashader_reduction", None),
)

sdata = self._copy()
sdata = _verify_plotting_tree(sdata)
n_steps = len(sdata.plotting_tree.keys())
outline_params = _set_outline(outline_alpha > 0, outline_width, outline_color)

for element, param_values in params_dict.items():
cmap_params = _prepare_cmap_norm(
cmap=cmap,
Expand All @@ -287,8 +292,8 @@ def render_shapes(
transfunc=kwargs.get("transfunc", None),
table_name=param_values["table_name"],
zorder=n_steps,
method=method,
reduction=kwargs.get("datashader_reduction", None),
method=param_values["method"],
ds_reduction=param_values["ds_reduction"],
)
n_steps += 1

Expand Down Expand Up @@ -358,10 +363,14 @@ def render_points(
Name of the table containing the color(s) columns. If one name is given than the table is used for each
spatial element to be plotted if the table annotates it. If you want to use different tables for particular
elements, as specified under element.
kwargs
Additional arguments to be passed to cmap and norm. And:
datashader_reduction: Literal["sum", "mean", "any", "count", "m2", "mode", "std", "var", "max", "min"]|None
reduction method to use when using datashader and coloring by continuous values. Default: ds.sum()
**kwargs : Any
Additional arguments for customization. This can include:
datashader_reduction : Literal[
"sum", "mean", "any", "count", "m2", "mode", "std", "var", "max", "min"
], default: "sum"
Reduction method for datashader when coloring by continuous values. Defaults to 'sum'.
Returns
-------
Expand Down Expand Up @@ -412,7 +421,7 @@ def render_points(
table_name=param_values["table_name"],
zorder=n_steps,
method=method,
reduction=kwargs.get("datashader_reduction", None),
ds_reduction=kwargs.get("datashader_reduction", None),
)
n_steps += 1

Expand Down
21 changes: 13 additions & 8 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,18 @@ def _render_shapes(

# Determine which method to use for rendering
method = render_params.method

if method is None:
method = "datashader" if len(shapes) > 10000 else "matplotlib"
elif method not in ["matplotlib", "datashader"]:
raise ValueError("Method must be either 'matplotlib' or 'datashader'.")

if method != "matplotlib":
# we only notify the user when we switched away from matplotlib
logger.info(f"Using '{method}' as plotting backend.")
logger.info(
f"Using '{method}' backend with '{render_params.ds_reduction}' as reduction"
" method to speed up plotting. Depending on the reduction method, the value"
" range of the plot might change. Set method to 'matplotlib' do disable"
" this behaviour."
)

if method == "datashader":
trans = mtransforms.Affine2D(matrix=affine_trans) + ax.transData
Expand Down Expand Up @@ -197,13 +202,13 @@ def _render_shapes(
sdata_filt.shapes[element], geometry="geometry", agg=ds.by(col_for_color, ds.count())
)
else:
reduction_name = render_params.reduction if render_params.reduction is not None else "sum"
reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "sum"
logger.info(
f'Using the datashader reduction "{reduction_name}". "max" will give an output very close '
"to the matplotlib result."
)
agg = _datashader_aggregate_with_function(
render_params.reduction, cvs, sdata_filt.shapes[element], col_for_color, "shapes"
render_params.ds_reduction, cvs, sdata_filt.shapes[element], col_for_color, "shapes"
)
# save min and max values for drawing the colorbar
aggregate_with_reduction = (agg.min(), agg.max())
Expand Down Expand Up @@ -437,13 +442,13 @@ def _render_points(
if color_by_categorical:
agg = cvs.points(sdata_filt.points[element], "x", "y", agg=ds.by(col_for_color, ds.count()))
else:
reduction_name = render_params.reduction if render_params.reduction is not None else "sum"
reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "sum"
logger.info(
f'Using the datashader reduction "{reduction_name}". "max" will give an output very close '
"to the matplotlib result."
)
agg = _datashader_aggregate_with_function(
render_params.reduction, cvs, sdata_filt.points[element], col_for_color, "points"
render_params.ds_reduction, cvs, sdata_filt.points[element], col_for_color, "points"
)
# save min and max values for drawing the colorbar
aggregate_with_reduction = (agg.min(), agg.max())
Expand Down Expand Up @@ -475,7 +480,7 @@ def _render_points(
how="linear",
)
else:
spread_how = _datshader_get_how_kw_for_spread(render_params.reduction)
spread_how = _datshader_get_how_kw_for_spread(render_params.ds_reduction)
agg = ds.tf.spread(agg, px=px, how=spread_how)
aggregate_with_reduction = (agg.min(), agg.max())
ds_result = ds.tf.shade(
Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class ShapesRenderParams:
method: str | None = None
zorder: int = 0
table_name: str | None = None
reduction: Literal["sum", "mean", "any", "count", "m2", "mode", "std", "var", "max", "min"] | None = None
ds_reduction: Literal["sum", "mean", "any", "count", "m2", "mode", "std", "var", "max", "min"] | None = None


@dataclass
Expand All @@ -108,7 +108,7 @@ class PointsRenderParams:
method: str | None = None
zorder: int = 0
table_name: str | None = None
reduction: Literal["sum", "mean", "any", "count", "m2", "mode", "std", "var"] | None = None
ds_reduction: Literal["sum", "mean", "any", "count", "m2", "mode", "std", "var"] | None = None


@dataclass
Expand Down
75 changes: 48 additions & 27 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1507,7 +1507,6 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
"Parameter 'element' must be a string. If you want to display more elements, pass `element` "
"as `None` or chain pl.render(...).pl.render(...).pl.show()"
)

if element_type == "images":
param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].images.keys())
elif element_type == "labels":
Expand Down Expand Up @@ -1638,9 +1637,20 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
if param_dict.get("table_name") and not isinstance(param_dict["table_name"], str):
raise TypeError("Parameter 'table_name' must be a string .")

if param_dict.get("method") not in ["matplotlib", "datashader", None]:
# like this because the following would assign True/False to 'method'
# method := param_dict.get("method") not in ["matplotlib", "datashader", None]
method = param_dict.get("method")
if method not in ["matplotlib", "datashader", None]:
raise ValueError("If specified, parameter 'method' must be either 'matplotlib' or 'datashader'.")

valid_ds_reduction_methods = ["sum", "mean", "any", "count", "m2", "mode", "std", "var", "max", "min"]
ds_reduction = param_dict.get("ds_reduction")
if ds_reduction and (ds_reduction not in valid_ds_reduction_methods):
raise ValueError(f"Parameter 'ds_reduction' must be one of the following: {valid_ds_reduction_methods}.")

if method == "datashader" and ds_reduction is None:
param_dict["ds_reduction"] = "sum"

return param_dict


Expand Down Expand Up @@ -1778,6 +1788,7 @@ def _validate_shape_render_params(
scale: float | int,
table_name: str | None,
method: str | None,
ds_reduction: str | None,
) -> dict[str, dict[str, Any]]:
param_dict: dict[str, Any] = {
"sdata": sdata,
Expand All @@ -1795,6 +1806,7 @@ def _validate_shape_render_params(
"scale": scale,
"table_name": table_name,
"method": method,
"ds_reduction": ds_reduction,
}
param_dict = _type_check_params(param_dict, "shapes")

Expand Down Expand Up @@ -1828,6 +1840,7 @@ def _validate_shape_render_params(
element_params[el]["palette"] = param_dict["palette"] if param_dict["col_for_color"] is not None else None
element_params[el]["groups"] = param_dict["groups"] if param_dict["col_for_color"] is not None else None
element_params[el]["method"] = param_dict["method"]
element_params[el]["ds_reduction"] = param_dict["ds_reduction"]

return element_params

Expand Down Expand Up @@ -2041,9 +2054,10 @@ def _create_image_from_datashader_result(
ds_result: ds.transfer_functions.Image, factor: float, ax: Axes
) -> tuple[MaskedArray[np.float64, Any], matplotlib.transforms.CompositeGenericTransform]:
# create SpatialImage from datashader output to get it back to original size
rgba_image = np.transpose(ds_result.to_numpy().base, (2, 0, 1))
rgba_image_data = ds_result.to_numpy().base
rgba_image_data = np.transpose(rgba_image_data, (2, 0, 1))
rgba_image = Image2DModel.parse(
rgba_image,
rgba_image_data,
dims=("c", "y", "x"),
transformations={"global": Scale([1, factor, factor], ("c", "y", "x"))},
)
Expand Down Expand Up @@ -2085,40 +2099,47 @@ def _datashader_aggregate_with_function(
reduction = "sum"

reduction_function_map = {
"sum": ds.sum(column=col_for_color),
"mean": ds.mean(column=col_for_color),
"any": ds.any(column=col_for_color),
"count": ds.count(column=col_for_color),
"m2": ds.reductions.m2(column=col_for_color),
"mode": ds.reductions.mode(column=col_for_color),
"std": ds.std(column=col_for_color),
"var": ds.var(column=col_for_color),
"max": ds.max(column=col_for_color),
"min": ds.min(column=col_for_color),
"sum": ds.sum,
"mean": ds.mean,
"any": ds.any,
"count": ds.count,
"m2": ds.reductions.m2,
"mode": ds.reductions.mode,
"std": ds.std,
"var": ds.var,
"max": ds.max,
"min": ds.min,
}

if reduction not in reduction_function_map:
raise ValueError(
f"Reduction {reduction} is not supported, please use one of the following: sum, mean, any, count, m2, mode"
", std, var, max, min."
)
if element_type not in ["points", "shapes"]:
try:
reduction_function = reduction_function_map[reduction](column=col_for_color)
except KeyError as e:
raise ValueError(
f"utils._datashader_aggregate_with_function() should only be called with points or shapes, not with"
f" {element_type}."
)
f"Reduction '{reduction}' is not supported. Please use one of: {', '.join(reduction_function_map.keys())}."
) from e

element_function_map = {
"points": cvs.points,
"shapes": cvs.polygons,
}

try:
element_function = element_function_map[element_type]
except KeyError as e:
raise ValueError(f"Element type '{element_type}' is not supported. Use 'points' or 'shapes'.") from e

if element_type == "points":
return cvs.points(spatial_element, "x", "y", agg=reduction_function_map[reduction])
return cvs.polygons(spatial_element, geometry="geometry", agg=reduction_function_map[reduction])
return element_function(spatial_element, "x", "y", agg=reduction_function)

# is shapes
return element_function(spatial_element, geometry="geometry", agg=reduction_function)


def _datshader_get_how_kw_for_spread(
reduction: Literal["sum", "mean", "any", "count", "m2", "mode", "std", "var", "max", "min"] | None
) -> str:
# Get the best input for the how argument of ds.tf.spread(), needed for numerical values
if reduction is None:
reduction = "sum"
reduction = reduction or "sum"

reduction_to_how_map = {
"sum": "add",
Expand Down

0 comments on commit 8b0b24d

Please sign in to comment.