Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ def render_labels(
norm: Normalize | None = None,
na_color: ColorLike | None = "default",
outline_alpha: float | int = 0.0,
fill_alpha: float | int = 0.4,
fill_alpha: float | int | None = None,
scale: str | None = None,
table_name: str | None = None,
table_layer: str | None = None,
Expand Down Expand Up @@ -643,8 +643,9 @@ def render_labels(
won't be shown.
outline_alpha : float | int, default 0.0
Alpha value for the outline of the labels. Invisible by default.
fill_alpha : float | int, default 0.4
Alpha value for the fill of the labels.
fill_alpha : float | int, optional.
Alpha value for the fill of the labels. When no alpha is implied by the passed color, a default value of 0.4
is used.
scale : str | None
Influences the resolution of the rendering. Possibilities for setting this parameter:
1) None (default). The image is rasterized to fit the canvas size. For multiscale images, the best scale
Expand Down Expand Up @@ -702,6 +703,7 @@ def render_labels(
sdata.plotting_tree[f"{n_steps + 1}_render_labels"] = LabelsRenderParams(
element=element,
color=param_values["color"],
col_for_color=param_values["col_for_color"],
groups=param_values["groups"],
contour_px=param_values["contour_px"],
cmap_params=cmap_params,
Expand Down Expand Up @@ -984,13 +986,13 @@ def show(

if wanted_labels_on_this_cs:
if (table := params_copy.table_name) is not None:
assert isinstance(params_copy.color, str)
colors = sc.get.obs_df(sdata[table], [params_copy.color])
if isinstance(colors[params_copy.color].dtype, pd.CategoricalDtype):
assert isinstance(params_copy.col_for_color, str)
colors = sc.get.obs_df(sdata[table], [params_copy.col_for_color])
if isinstance(colors[params_copy.col_for_color].dtype, pd.CategoricalDtype):
_maybe_set_colors(
source=sdata[table],
target=sdata[table],
key=params_copy.color,
key=params_copy.col_for_color,
palette=params_copy.palette,
)

Expand Down
19 changes: 12 additions & 7 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,7 +1084,8 @@ def _render_labels(
table_name = render_params.table_name
table_layer = render_params.table_layer
palette = render_params.palette
color = render_params.color
color = render_params.color.get_hex() if render_params.color else None
col_for_color = render_params.col_for_color
groups = render_params.groups
scale = render_params.scale

Expand Down Expand Up @@ -1137,18 +1138,21 @@ def _render_labels(
sdata=sdata_filt,
element=label,
element_name=element,
value_to_plot=color,
# value_to_plot=color, # TODO
value_to_plot=col_for_color,
groups=groups,
palette=palette,
na_color=render_params.cmap_params.na_color,
# na_color=render_params.cmap_params.na_color, # TODO
na_color=render_params.color if render_params.color is not None else render_params.cmap_params.na_color,
cmap_params=render_params.cmap_params,
table_name=table_name,
table_layer=table_layer,
render_type="labels",
)

# rasterize could have removed labels from label
# only problematic if color is specified
if rasterize and color is not None:
if rasterize and (color is not None or col_for_color is not None):
labels_in_rasterized_image = np.unique(label.values)
mask = np.isin(instance_id, labels_in_rasterized_image)
instance_id = instance_id[mask]
Expand All @@ -1157,8 +1161,8 @@ def _render_labels(
color_vector = color_vector.remove_unused_categories()
assert color_source_vector is not None
color_source_vector = color_source_vector[mask]
else:
assert color_source_vector is None
# else:
# assert color_source_vector is None # TODO: delete?

def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) -> matplotlib.image.AxesImage:
labels = _map_color_seg(
Expand Down Expand Up @@ -1228,7 +1232,8 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
cax=cax,
fig_params=fig_params,
adata=table,
value_to_plot=color,
# value_to_plot=color, # TODO
value_to_plot=col_for_color,
color_source_vector=color_source_vector,
color_vector=color_vector,
palette=palette,
Expand Down
3 changes: 2 additions & 1 deletion src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ class LabelsRenderParams:

cmap_params: CmapParams
element: str
color: str | None = None
color: Color | None = None
col_for_color: str | None = None
groups: str | list[str] | None = None
contour_px: int | None = None
outline: bool = False
Expand Down
38 changes: 26 additions & 12 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,10 +777,11 @@ def _set_color_source_vec(
alpha: float = 1.0,
table_name: str | None = None,
table_layer: str | None = None,
render_type: Literal["points"] | None = None,
render_type: Literal["points", "labels"] | None = None,
) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]:
if value_to_plot is None and element is not None:
color = np.full(len(element), na_color.get_hex_with_alpha())
n_elements = len(element) if render_type != "labels" else len(dask.array.unique(element.data).compute())
color = np.full(n_elements, na_color.get_hex_with_alpha())
return color, color, False

# Figure out where to get the color from
Expand Down Expand Up @@ -1000,7 +1001,7 @@ def _get_categorical_color_mapping(
alpha: float = 1,
groups: list[str] | str | None = None,
palette: list[str] | str | None = None,
render_type: Literal["points"] | None = None,
render_type: Literal["points", "labels"] | None = None,
) -> Mapping[str, str]:
if not isinstance(color_source_vector, Categorical):
raise TypeError(f"Expected `categories` to be a `Categorical`, but got {type(color_source_vector).__name__}")
Expand Down Expand Up @@ -1648,15 +1649,15 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
}:
if not isinstance(color, str | tuple | list):
raise TypeError("Parameter 'color' must be a string or a tuple/list of floats.")
if element_type in {"shapes", "points"}:
if element_type in {"shapes", "points", "labels"}:
if _is_color_like(color):
logger.info("Value for parameter 'color' appears to be a color, using it as such.")
param_dict["col_for_color"] = None
param_dict["color"] = Color(color)
if param_dict["color"].alpha_is_user_defined():
if element_type == "points" and param_dict.get("alpha") is None:
param_dict["alpha"] = param_dict["color"].get_alpha_as_float()
elif element_type == "shapes" and param_dict.get("fill_alpha") is None:
elif element_type in ["shapes", "labels"] and param_dict.get("fill_alpha") is None:
param_dict["fill_alpha"] = param_dict["color"].get_alpha_as_float()
else:
logger.info(
Expand All @@ -1668,7 +1669,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
param_dict["color"] = None
else:
raise ValueError(f"{color} is not a valid RGB(A) array and therefore can't be used as 'color' value.")
elif "color" in param_dict and element_type != "labels":
elif "color" in param_dict and element_type != "images":
param_dict["col_for_color"] = None

if outline_width := param_dict.get("outline_width"):
Expand Down Expand Up @@ -1754,6 +1755,8 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
elif element_type == "shapes":
# set default fill_alpha for shapes if not given by user explicitly or implicitly (as part of color)
param_dict["fill_alpha"] = 1.0
elif element_type == "labels":
param_dict["fill_alpha"] = 0.4

if (cmap := param_dict.get("cmap")) is not None and (palette := param_dict.get("palette")) is not None:
raise ValueError("Both `palette` and `cmap` are specified. Please specify only one of them.")
Expand Down Expand Up @@ -1894,7 +1897,7 @@ def _validate_label_render_params(
element: str | None,
cmap: list[Colormap | str] | Colormap | str | None,
color: str | None,
fill_alpha: float | int,
fill_alpha: float | int | None,
contour_px: int | None,
groups: list[str] | str | None,
palette: list[str] | str | None,
Expand Down Expand Up @@ -1939,12 +1942,23 @@ def _validate_label_render_params(
element_params[el]["table_layer"] = param_dict["table_layer"]

element_params[el]["table_name"] = None
element_params[el]["color"] = None
color = param_dict["color"]
if color is not None:
color, table_name = _validate_col_for_column_table(sdata, el, color, param_dict["table_name"], labels=True)

# element_params[el]["color"] = None # TODO: delete
# color = param_dict["color"]
# if color is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this still in debugging or ready @Sonja-Stockhaus ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's ready for tests and I think I wanted to remove the comments afterwards

# color, table_name = _validate_col_for_column_table(sdata, el, color, param_dict["table_name"],
# labels=True)
# element_params[el]["table_name"] = table_name
# element_params[el]["color"] = color
element_params[el]["color"] = param_dict["color"]

element_params[el]["col_for_color"] = None
if (col_for_color := param_dict["col_for_color"]) is not None:
col_for_color, table_name = _validate_col_for_column_table(
sdata, el, col_for_color, param_dict["table_name"], labels=True
)
element_params[el]["table_name"] = table_name
element_params[el]["color"] = color
element_params[el]["col_for_color"] = col_for_color

element_params[el]["palette"] = param_dict["palette"] if element_params[el]["table_name"] is not None else None
element_params[el]["groups"] = param_dict["groups"] if element_params[el]["table_name"] is not None else None
Expand Down