diff --git a/micro_sam/automatic_segmentation.py b/micro_sam/automatic_segmentation.py index 6cf9a486..a051e61d 100644 --- a/micro_sam/automatic_segmentation.py +++ b/micro_sam/automatic_segmentation.py @@ -106,7 +106,7 @@ def automatic_instance_segmentation( ndim = image_data.ndim if ndim is None else ndim if ndim == 2: - if image_data.ndim != 2 or image_data.shape[-1] != 3: + if (image_data.ndim != 2) and (image_data.ndim != 3 and image_data.shape[-1] != 3): raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}") # Precompute the image embeddings. @@ -135,7 +135,7 @@ def automatic_instance_segmentation( else: instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0) else: - if image_data.ndim != 3 or image_data.shape[-1] != 3: + if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3): raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}") instances = automatic_3d_segmentation(