diff --git a/dask_image/ndmeasure/__init__.py b/dask_image/ndmeasure/__init__.py index 5f4eaf48..396d6ebf 100644 --- a/dask_image/ndmeasure/__init__.py +++ b/dask_image/ndmeasure/__init__.py @@ -4,12 +4,12 @@ import functools import operator import warnings -from dask import compute, delayed import dask.array as da import dask.bag as db import dask.dataframe as dd import numpy as np +from dask import compute, delayed from . import _utils from ._utils import _label @@ -378,7 +378,6 @@ def label(image, structure=None, wrap_axes=None): relabeled = _label.relabel_blocks(block_labeled, new_labeling) n = da.max(relabeled) - return (relabeled, n) @@ -402,7 +401,9 @@ def labeled_comprehension(image, Parameters ---------- image : ndarray - N-D image data + Intensity image with same size as ``label_image``, plus optionally + an extra dimension for multichannel data. The extra channel dimension, + if present, must be the last axis. label_image : ndarray, optional Image features noted by integers. If None (default), all values. index : int or sequence of ints, optional @@ -441,14 +442,16 @@ def labeled_comprehension(image, args = (image,) if pass_positions: positions = _utils._ravel_shape_indices( - image.shape, chunks=image.chunks + image.shape, chunks=image.chunks, skip_trailing_dim=image.ndim != label_image.ndim ) args = (image, positions) result = np.empty(index.shape, dtype=object) for i in np.ndindex(index.shape): lbl_mtch_i = (label_image == index[i]) - args_lbl_mtch_i = tuple(e[lbl_mtch_i] for e in args) + args_lbl_mtch_i = tuple( + e[lbl_mtch_i] if e.ndim == lbl_mtch_i.ndim else e.reshape(-1, e.shape[-1])[lbl_mtch_i.reshape(-1)] for e in + args) result[i] = _utils._labeled_comprehension_func( func, out_dtype, default_1d, *args_lbl_mtch_i ) diff --git a/dask_image/ndmeasure/_utils/__init__.py b/dask_image/ndmeasure/_utils/__init__.py index bcb5b70f..28b0e1dd 100644 --- a/dask_image/ndmeasure/_utils/__init__.py +++ b/dask_image/ndmeasure/_utils/__init__.py @@ -31,9 +31,10 @@ def _norm_input_labels_index(image, label_image=None, index=None): FutureWarning ) - if image.shape != label_image.shape: + image_shape = image.shape if image.ndim == label_image.ndim else image.shape[:-1] + if image_shape != label_image.shape: # allow trailing channel raise ValueError( - "The image and label_image arrays must be the same shape." + f"The image and label_image arrays must be the same shape. {image_shape} != {label_image.shape}" ) return (image, label_image, index) @@ -47,7 +48,7 @@ def _ravel_shape_indices_kernel(*args): return sum(args2) -def _ravel_shape_indices(dimensions, dtype=int, chunks=None): +def _ravel_shape_indices(dimensions, dtype=int, chunks=None, skip_trailing_dim:bool=False): """ Gets the raveled indices shaped like input. """ @@ -60,7 +61,7 @@ def _ravel_shape_indices(dimensions, dtype=int, chunks=None): dtype=dtype, chunks=c ) - for i, c in enumerate(chunks) + for i, c in enumerate(chunks[:-1] if skip_trailing_dim else chunks) ] indices = da.blockwise(