diff --git a/src/omero_metadata/rois.py b/src/omero_metadata/rois.py index ccafe90a..2e889aa4 100644 --- a/src/omero_metadata/rois.py +++ b/src/omero_metadata/rois.py @@ -33,7 +33,8 @@ def __init__(self, msg='Invalid labels found'): def mask_from_binary_image( - binim, rgba=None, z=None, c=None, t=None, text=None): + binim, rgba=None, z=None, c=None, t=None, text=None, + raise_on_no_mask=True): """ Create a mask shape from a binary image (background=0) :param numpy.array binim: Binary 2D array, must contain values [0, 1] only @@ -42,6 +43,8 @@ def mask_from_binary_image( :param c: Optional C-index for the mask :param t: Optional T-index for the mask :param text: Optional text for the mask + :param raise_on_no_mask: If True (default) throw an exception if no mask + found, otherwise return an empty Mask :return: An OMERO mask :raises NoMaskFound: If no labels were found :raises InvalidBinaryImage: If the maximum labels is greater than 1 @@ -49,16 +52,22 @@ def mask_from_binary_image( # Find bounding box to minimise size of mask xmask = binim.sum(0).nonzero()[0] ymask = binim.sum(1).nonzero()[0] - if not any(xmask) and not any(ymask): - raise NoMaskFound() - - x0 = min(xmask) - w = max(xmask) - x0 + 1 - y0 = min(ymask) - h = max(ymask) - y0 + 1 - submask = binim[y0:(y0 + h), x0:(x0 + w)] - if not np.array_equal(np.unique(submask), [0, 1]): - raise InvalidBinaryImage() + if any(xmask) and any(ymask): + x0 = min(xmask) + w = max(xmask) - x0 + 1 + y0 = min(ymask) + h = max(ymask) - y0 + 1 + submask = binim[y0:(y0 + h), x0:(x0 + w)] + if not np.array_equal(np.unique(submask), [0, 1]): + raise InvalidBinaryImage() + else: + if raise_on_no_mask: + raise NoMaskFound() + x0 = 0 + w = 0 + y0 = 0 + h = 0 + submask = [] mask = MaskI() # BUG in older versions of Numpy: @@ -87,7 +96,8 @@ def mask_from_binary_image( def masks_from_label_image( - labelim, rgba=None, z=None, c=None, t=None, text=None): + labelim, rgba=None, z=None, c=None, t=None, text=None, + raise_on_no_mask=True): """ Create mask shapes from a label image (background=0) :param numpy.array labelim: 2D label array @@ -96,13 +106,13 @@ def masks_from_label_image( :param c: Optional C-index for the mask :param t: Optional T-index for the mask :param text: Optional text for the mask + :param raise_on_no_mask: If True (default) throw an exception if no mask + found, otherwise return an empty Mask :return: A list of OMERO masks in label order ([] if no labels found) """ masks = [] for i in xrange(1, labelim.max() + 1): - try: - mask = mask_from_binary_image(labelim == i, rgba, z, c, t, text) - masks.append(mask) - except NoMaskFound: - pass + mask = mask_from_binary_image(labelim == i, rgba, z, c, t, text, + raise_on_no_mask) + masks.append(mask) return masks diff --git a/test/unit/test_masks.py b/test/unit/test_masks.py index 7590119e..a9cab0a0 100644 --- a/test/unit/test_masks.py +++ b/test/unit/test_masks.py @@ -14,6 +14,7 @@ from omero_metadata import ( mask_from_binary_image, masks_from_label_image, + NoMaskFound, ) @@ -96,3 +97,39 @@ def test_masks_from_label_image(self, label_image, args): assert unwrap(mask.getTheC()) is None assert unwrap(mask.getTheT()) is None assert unwrap(mask.getTextValue()) is None + + @pytest.mark.parametrize('args', [ + {}, + {'rgba': (255, 128, 64, 128), 'z': 1, 'c': 2, 't': 3, 'text': 'test'} + ]) + @pytest.mark.parametrize('raise_on_no_mask', [ + True, + False, + ]) + def test_empty_mask_from_binary_image(self, args, raise_on_no_mask): + empty_binary_image = np.array([[0]]) + if raise_on_no_mask: + with pytest.raises(NoMaskFound): + mask = mask_from_binary_image( + empty_binary_image, raise_on_no_mask=raise_on_no_mask, + **args) + else: + mask = mask_from_binary_image( + empty_binary_image, raise_on_no_mask=raise_on_no_mask, + **args) + assert unwrap(mask.getWidth()) == 0 + assert unwrap(mask.getHeight()) == 0 + assert unwrap(mask.getX()) == 0 + assert unwrap(mask.getY()) == 0 + assert np.array_equal(mask.getBytes(), []) + + if args: + assert unwrap(mask.getTheZ()) == 1 + assert unwrap(mask.getTheC()) == 2 + assert unwrap(mask.getTheT()) == 3 + assert unwrap(mask.getTextValue()) == 'test' + else: + assert unwrap(mask.getTheZ()) is None + assert unwrap(mask.getTheC()) is None + assert unwrap(mask.getTheT()) is None + assert unwrap(mask.getTextValue()) is None