diff --git a/micro_sam/_vendored.py b/micro_sam/_vendored.py index ee3efa66..9cbe239d 100644 --- a/micro_sam/_vendored.py +++ b/micro_sam/_vendored.py @@ -1,11 +1,14 @@ """ Functions from other third party libraries. -We can remove these functions once the bug affecting our code is fixed upstream. +We can remove these functions once the bugs affecting our code is fixed upstream. The license type of the thrid party software project must be compatible with the software license the micro-sam project is distributed under. """ +from typing import Any, Dict, List + +import numpy as np import torch @@ -59,3 +62,28 @@ def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: out = out[0] return out + + +# segment_anything/util/amg.py +# https://github.com/facebookresearch/segment-anything +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """Calculates the runlength encoding of binary input masks. + + Implementation based on + https://stackoverflow.com/questions/1066758/find-length-of-sequences-of-identical-values-in-a-numpy-array-run-length-encodi + """ + # Put in fortran order and flatten h, w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + tensor = tensor.detach().cpu().numpy() + + n = tensor.shape[1] + + # encode the rle for the individual masks + out = [] + for mask in tensor: + diffs = mask[1:] != mask[:-1] # pairwise unequal (string safe) + indices = np.append(np.where(diffs), n - 1) # must include last element posi + counts = np.diff(np.append(-1, indices)) # run lengths + out.append({"size": [h, w], "counts": counts.tolist()}) + return out diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py index 775ae829..de6f4bba 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -32,7 +32,7 @@ from . import util from .prompt_based_segmentation import segment_from_mask -from ._vendored import batched_mask_to_box +from ._vendored import batched_mask_to_box, mask_to_rle_pytorch # # Utility Functionality @@ -208,7 +208,8 @@ def _postprocess_small_regions(self, mask_data, min_area, nms_thresh): for i_mask in keep_by_nms: if scores[i_mask] == 0.0: mask_torch = masks[i_mask].unsqueeze(0) - mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0] + # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0] + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly mask_data.filter(keep_by_nms) @@ -274,7 +275,8 @@ def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None): # compress to RLE data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w) - data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"]) + # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"]) + data["rles"] = mask_to_rle_pytorch(data["masks"]) del data["masks"] return data diff --git a/test/test_vendored.py b/test/test_vendored.py index ddce5476..b74fdff2 100644 --- a/test/test_vendored.py +++ b/test/test_vendored.py @@ -3,41 +3,82 @@ import numpy as np import torch +from segment_anything.utils.amg import mask_to_rle_pytorch as mask_to_rle_pytorch_sam +from skimage.draw import random_shapes + class TestVendored(unittest.TestCase): - def setUp(self): - mask_numpy = np.zeros((10,10)).astype(bool) + def _get_mask_to_box_data(): + mask_numpy = np.zeros((10, 10)).astype(bool) mask_numpy[7:9, 3:5] = True - self.mask = mask_numpy - self.expected_result = [3, 7, 4, 8] + expected_result = [3, 7, 4, 8] + return mask_numpy, expected_result - def test_cpu_batched_mask_to_box(self): + def _test_batched_mask_to_box(self, device): from micro_sam._vendored import batched_mask_to_box - device = "cpu" + mask, expected_result = self._get_mask_to_box_data() mask = torch.as_tensor(self.mask, dtype=torch.bool, device=device) expected_result = torch.as_tensor(self.expected_result, dtype=torch.int, device=device) result = batched_mask_to_box(mask) assert all(result == expected_result) + def test_cpu_batched_mask_to_box(self): + self._test_batched_mask_to_box(device="cpu") + @unittest.skipIf(not torch.cuda.is_available(), "CUDA Pytorch backend is not available") def test_cuda_batched_mask_to_box(self): - from micro_sam._vendored import batched_mask_to_box - - device = "cuda" - mask = torch.as_tensor(self.mask, dtype=torch.bool, device=device) - expected_result = torch.as_tensor(self.expected_result, dtype=torch.int, device=device) - result = batched_mask_to_box(mask) - assert all(result == expected_result) + self._test_batched_mask_to_box(device="cuda") @unittest.skipIf(not (torch.backends.mps.is_available() and torch.backends.mps.is_built()), "MPS Pytorch backend is not available") def test_mps_batched_mask_to_box(self): - from micro_sam._vendored import batched_mask_to_box + self._test_batched_mask_to_box(device="mps") - device = "mps" - mask = torch.as_tensor(self.mask, dtype=torch.bool, device=device) - expected_result = torch.as_tensor(self.expected_result, dtype=torch.int, device=device) - result = batched_mask_to_box(mask) - assert all(result == expected_result) + def _get_mask_to_rle_pytorch_data(self): + shape = (128, 256) + + # randm shapes for 6 masks + n_masks = 6 + masks, _ = random_shapes(shape, min_shapes=n_masks, max_shapes=n_masks) + masks = masks.astype("uint32").sum(axis=-1) + + bg_id = 765 # bg val is 3 * 255 = 765 + mask_ids = np.setdiff1d(np.unique(masks), bg_id) + + one_hot = np.zeros((len(mask_ids),) + shape, dtype=bool) + for i, idx in enumerate(mask_ids): + one_hot[i, masks == idx] = 1 + one_hot = torch.from_numpy(one_hot) + + # make sure that all corner pixels are zero + one_hot[:, 0, 0] = 0 + one_hot[:, -1, -1] = 0 + one_hot[:, 0, -1] = 0 + one_hot[:, -1, 0] = 0 + + expected_result = mask_to_rle_pytorch_sam(one_hot) + + return one_hot, expected_result + + def test_mask_to_rle_pytorch(self): + from micro_sam._vendored import mask_to_rle_pytorch + + masks, expected_result = self._get_mask_to_rle_pytorch_data() + expected_size = masks.shape[1] * masks.shape[2] + + # make sure that the RLE's are consistent (their sum needs to be equal to the number of pixels) + for rle in expected_result: + assert sum(rle["counts"]) == expected_size, f"{sum(rle['counts'])}, {expected_size}" + + result = mask_to_rle_pytorch(masks) + for rle in result: + assert sum(rle["counts"]) == expected_size, f"{sum(rle['counts'])}, {expected_size}" + + # make sure that the RLE's agree + assert result == expected_result + + +if __name__ == "__main__": + unittest.main()