Skip to content

Commit

Permalink
add ds shapes outlines, fix shapes fill_alpha behavior, fix std/var/a…
Browse files Browse the repository at this point in the history
…ny reductions
  • Loading branch information
Sonja Stockhaus committed Oct 1, 2024
1 parent ceb4fd2 commit 59a19da
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 34 deletions.
4 changes: 2 additions & 2 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def render_shapes(
Additional arguments for customization. This can include:
datashader_reduction : Literal[
"sum", "mean", "any", "count", "mode", "std", "var", "max", "min"
"sum", "mean", "any", "count", "std", "var", "max", "min"
], default: "sum"
Reduction method for datashader when coloring by continuous values. Defaults to 'sum'.
Expand Down Expand Up @@ -368,7 +368,7 @@ def render_points(
Additional arguments for customization. This can include:
datashader_reduction : Literal[
"sum", "mean", "any", "count", "mode", "std", "var", "max", "min"
"sum", "mean", "any", "count", "std", "var", "max", "min"
], default: "sum"
Reduction method for datashader when coloring by continuous values. Defaults to 'sum'.
Expand Down
70 changes: 52 additions & 18 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,14 @@ def _render_shapes(
aggregate_with_reduction = (agg.min(), agg.max())
else:
agg = cvs.polygons(sdata_filt.shapes[element], geometry="geometry", agg=ds.count())
# render outlines if needed
render_outlines = render_params.outline_alpha > 0
if render_outlines:
agg_outlines = cvs.line(
sdata_filt.shapes[element],
geometry="geometry",
line_width=render_params.outline_params.linewidth,
)

color_key = (
[x[:-2] for x in color_vector.categories.values]
Expand All @@ -222,36 +230,56 @@ def _render_shapes(
else None
)

ds_cmap = None
if color_vector is not None:
ds_cmap = color_vector[0]
if isinstance(ds_cmap, str) and ds_cmap[0] == "#":
ds_cmap = ds_cmap[:-2]
# if color_vector is not None and (
# isinstance(color_vector[0], str) and len(color_vector[0]) == 9 and color_vector[0][0] == "#"
# ):
# color_vector = [x[:-2] for x in color_vector]

ds_result = (
ds.tf.shade(
if color_by_categorical or col_for_color is None:
ds_cmap = None
if color_vector is not None:
ds_cmap = color_vector[0]
if isinstance(ds_cmap, str) and ds_cmap[0] == "#":
ds_cmap = ds_cmap[:-2]

ds_result = ds.tf.shade(
agg,
cmap=ds_cmap,
color_key=color_key,
min_alpha=np.min([254, render_params.fill_alpha * 255]),
how="linear",
)
if color_by_categorical or col_for_color is None
else ds.tf.shade(
elif aggregate_with_reduction is not None: # to shut up mypy
ds_cmap = render_params.cmap_params.cmap
if aggregate_with_reduction[0] == aggregate_with_reduction[1]:
ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False)
aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1)

ds_result = ds.tf.shade(
agg,
cmap=render_params.cmap_params.cmap,
cmap=ds_cmap,
how="linear",
min_alpha=np.min([254, render_params.fill_alpha * 255]),
)

# shade outlines if needed
outline_color = render_params.outline_params.outline_color
if isinstance(outline_color, str) and outline_color.startswith("#") and len(outline_color) == 9:
outline_color = outline_color[:-2]

if render_outlines:
ds_outlines = ds.tf.shade(
agg_outlines,
cmap=outline_color,
min_alpha=np.min([254, render_params.outline_alpha * 255]),
how="linear",
)
)

rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax)
_cax = _ax_show_and_transform(
rgba_image, trans_data, ax, zorder=render_params.zorder, alpha=render_params.fill_alpha
)
# render outline image if needed
if render_outlines:
rgba_image, trans_data = _create_image_from_datashader_result(ds_outlines, factor, ax)
_ax_show_and_transform(
rgba_image, trans_data, ax, zorder=render_params.zorder, alpha=render_params.outline_alpha
)

cax = None
if aggregate_with_reduction is not None:
Expand Down Expand Up @@ -491,16 +519,22 @@ def _render_points(
ds.tf.spread(agg, px=px),
cmap=color_vector[0],
color_key=color_key,
min_alpha=np.min([255, render_params.alpha * 255]), # value 150 is arbitrarily chosen
min_alpha=np.min([254, render_params.alpha * 255]), # value 150 is arbitrarily chosen
how="linear",
)
else:
spread_how = _datshader_get_how_kw_for_spread(render_params.ds_reduction)
agg = ds.tf.spread(agg, px=px, how=spread_how)
aggregate_with_reduction = (agg.min(), agg.max())

ds_cmap = render_params.cmap_params.cmap
if aggregate_with_reduction[0] == aggregate_with_reduction[1]:
ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False)
aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1)

ds_result = ds.tf.shade(
agg,
cmap=render_params.cmap_params.cmap,
cmap=ds_cmap,
how="linear",
)

Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class ShapesRenderParams:
method: str | None = None
zorder: int = 0
table_name: str | None = None
ds_reduction: Literal["sum", "mean", "any", "count", "mode", "std", "var", "max", "min"] | None = None
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None


@dataclass
Expand All @@ -108,7 +108,7 @@ class PointsRenderParams:
method: str | None = None
zorder: int = 0
table_name: str | None = None
ds_reduction: Literal["sum", "mean", "any", "count", "mode", "std", "var", "max", "min"] | None = None
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None


@dataclass
Expand Down
17 changes: 10 additions & 7 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,7 +1649,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
"any",
"count",
# "m2", -> not intended to be used alone (see https://datashader.org/api.html#datashader.reductions.m2)
"mode",
# "mode", -> not supported for points (see https://datashader.org/api.html#datashader.reductions.mode)
"std",
"var",
"max",
Expand Down Expand Up @@ -2086,7 +2086,7 @@ def _create_image_from_datashader_result(


def _datashader_aggregate_with_function(
reduction: Literal["sum", "mean", "any", "count", "mode", "std", "var", "max", "min"] | None,
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
cvs: Canvas,
spatial_element: GeoDataFrame | dask.dataframe.core.DataFrame,
col_for_color: str | None,
Expand Down Expand Up @@ -2114,7 +2114,6 @@ def _datashader_aggregate_with_function(
"mean": ds.mean,
"any": ds.any,
"count": ds.count,
"mode": ds.reductions.mode,
"std": ds.std,
"var": ds.var,
"max": ds.max,
Expand All @@ -2139,14 +2138,19 @@ def _datashader_aggregate_with_function(
raise ValueError(f"Element type '{element_type}' is not supported. Use 'points' or 'shapes'.") from e

if element_type == "points":
return element_function(spatial_element, "x", "y", agg=reduction_function)
points_aggregate = element_function(spatial_element, "x", "y", agg=reduction_function)
if reduction == "any":
# replace False/True by nan/1
points_aggregate = points_aggregate.astype(int)
points_aggregate = points_aggregate.where(points_aggregate > 0)
return points_aggregate

# is shapes
return element_function(spatial_element, geometry="geometry", agg=reduction_function)


def _datshader_get_how_kw_for_spread(
reduction: Literal["sum", "mean", "any", "count", "mode", "std", "var", "max", "min"] | None
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None
) -> str:
# Get the best input for the how argument of ds.tf.spread(), needed for numerical values
reduction = reduction or "sum"
Expand All @@ -2156,7 +2160,6 @@ def _datshader_get_how_kw_for_spread(
"mean": "source",
"any": "source",
"count": "add",
"mode": "source",
"std": "source",
"var": "source",
"max": "max",
Expand All @@ -2165,7 +2168,7 @@ def _datshader_get_how_kw_for_spread(

if reduction not in reduction_to_how_map:
raise ValueError(
f"Reduction {reduction} is not supported, please use one of the following: sum, mean, any, count, mode"
f"Reduction {reduction} is not supported, please use one of the following: sum, mean, any, count"
", std, var, max, min."
)

Expand Down
5 changes: 0 additions & 5 deletions tests/pl/test_render_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,6 @@ def test_plot_datashader_can_use_count_as_reduction(self, sdata_blobs: SpatialDa
element="blobs_points", size=40, color="instance_id", method="datashader", datashader_reduction="count"
).pl.show()

def test_plot_datashader_can_use_mode_as_reduction(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_points(
element="blobs_points", size=40, color="instance_id", method="datashader", datashader_reduction="mode"
).pl.show()

def test_plot_datashader_can_use_std_as_reduction(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_points(
element="blobs_points", size=40, color="instance_id", method="datashader", datashader_reduction="std"
Expand Down
36 changes: 36 additions & 0 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,12 @@ def test_plot_shapes_categorical_color(self, sdata_blobs: SpatialData):
def test_plot_datashader_can_render_shapes(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(method="datashader").pl.show()

def test_plot_datashader_can_render_colored_shapes(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(method="datashader", color="red").pl.show()

def test_plot_datashader_can_render_with_different_alpha(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(method="datashader", fill_alpha=0.7).pl.show()

def test_plot_datashader_can_color_by_category(self, sdata_blobs: SpatialData):
RNG = np.random.default_rng(seed=42)
n_obs = len(sdata_blobs["blobs_polygons"])
Expand All @@ -317,8 +323,38 @@ def test_plot_datashader_can_color_by_value(self, sdata_blobs: SpatialData):
sdata_blobs.shapes["blobs_polygons"]["value"] = [1, 10, 1, 20, 1]
sdata_blobs.pl.render_shapes(element="blobs_polygons", color="value", method="datashader").pl.show()

def test_plot_datashader_can_color_by_identical_value(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
sdata_blobs.shapes["blobs_polygons"]["value"] = [1, 1, 1, 1, 1]
sdata_blobs.pl.render_shapes(element="blobs_polygons", color="value", method="datashader").pl.show()

def test_plot_datashader_shades_with_linear_cmap(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
sdata_blobs.shapes["blobs_polygons"]["value"] = [1, 2, 1, 20, 1]
sdata_blobs.pl.render_shapes(element="blobs_polygons", color="value", method="datashader").pl.show()

def test_plot_datashader_can_render_with_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(method="datashader", element="blobs_polygons", outline_alpha=1).pl.show()

def test_plot_datashader_can_render_with_diff_alpha_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(method="datashader", element="blobs_polygons", outline_alpha=0.5).pl.show()

def test_plot_datashader_can_render_with_diff_width_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(method="datashader", element="blobs_polygons", outline_width=5.0).pl.show()

def test_plot_datashader_can_render_with_colored_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(
method="datashader", element="blobs_polygons", outline_alpha=1, outline_color="red"
).pl.show()

def test_plot_datashader_can_render_with_rgb_colored_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(
method="datashader", element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 0.0, 1.0)
).pl.show()

def test_plot_datashader_can_render_with_rgba_colored_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(
method="datashader", element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 1.0, 0.0, 1.0)
).pl.show()

0 comments on commit 59a19da

Please sign in to comment.