From 3e8fd1bbe5b7457b3f599a786e0ea20495f74bb2 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Thu, 7 Sep 2023 16:13:02 +1000 Subject: [PATCH 1/5] Vendor batched_mask_to_box function from segment_anything/util/amg.py so we can fix a bug affecting our code --- micro_sam/_vendored.py | 58 ++++++++++++++++++++++++++++++ micro_sam/instance_segmentation.py | 6 ++-- 2 files changed, 61 insertions(+), 3 deletions(-) create mode 100644 micro_sam/_vendored.py diff --git a/micro_sam/_vendored.py b/micro_sam/_vendored.py new file mode 100644 index 00000000..87635ff1 --- /dev/null +++ b/micro_sam/_vendored.py @@ -0,0 +1,58 @@ +""" +Functions from other third party libraries. + +We can remove these functions once the bug affecting our code is fixed upstream. + +The license type of the thrid party software project must be compatible with +the software license the micro-sam project is distributed under. +""" +import torch + + +# segment_anything/util/amg.py +# https://github.com/facebookresearch/segment-anything +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py index 50ff0751..cab6147d 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -32,7 +32,7 @@ from . import util from .prompt_based_segmentation import segment_from_mask - +from ._vendored import batched_mask_to_box # # Utility Functionality @@ -196,7 +196,7 @@ def _postprocess_small_regions(self, mask_data, min_area, nms_thresh): # recalculate boxes and remove any new duplicates masks = torch.cat(new_masks, dim=0) - boxes = amg_utils.batched_mask_to_box(masks) + boxes = batched_mask_to_box(masks) keep_by_nms = batched_nms( boxes.float(), torch.as_tensor(scores, dtype=torch.float), @@ -270,7 +270,7 @@ def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None): # threshold masks and calculate boxes data["masks"] = data["masks"] > self._predictor.model.mask_threshold data["masks"] = data["masks"].type(torch.int) - data["boxes"] = amg_utils.batched_mask_to_box(data["masks"]) + data["boxes"] = batched_mask_to_box(data["masks"]) # compress to RLE data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w) From bbcc64958338d65a91ab31b1ccb7d812d6e1cbac Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Thu, 7 Sep 2023 16:14:14 +1000 Subject: [PATCH 2/5] Endure input to batched_mask_to_box is boolean array, otherwise result output is incorrect --- micro_sam/_vendored.py | 1 + 1 file changed, 1 insertion(+) diff --git a/micro_sam/_vendored.py b/micro_sam/_vendored.py index 87635ff1..8ee3ed59 100644 --- a/micro_sam/_vendored.py +++ b/micro_sam/_vendored.py @@ -16,6 +16,7 @@ def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: Calculates boxes in XYXY format around masks. Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. """ + assert masks.dtype == torch.bool # torch.max below raises an error on empty inputs, just skip in this case if torch.numel(masks) == 0: From 72a30dcaac0af31fe46a365f75daa1463a78ec07 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Thu, 7 Sep 2023 16:26:24 +1000 Subject: [PATCH 3/5] Make batched_mask_to_box compatible with MPS Pytorch backend (Apple Silicon) --- micro_sam/_vendored.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/micro_sam/_vendored.py b/micro_sam/_vendored.py index 8ee3ed59..ee3efa66 100644 --- a/micro_sam/_vendored.py +++ b/micro_sam/_vendored.py @@ -32,16 +32,18 @@ def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: # Get top and bottom edges in_height, _ = torch.max(masks, dim=-1) - in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + in_height_coords = in_height * torch.arange(h, dtype=torch.int, device=in_height.device)[None, :] bottom_edges, _ = torch.max(in_height_coords, dim=-1) in_height_coords = in_height_coords + h * (~in_height) + in_height_coords = in_height_coords.type(torch.int) top_edges, _ = torch.min(in_height_coords, dim=-1) # Get left and right edges in_width, _ = torch.max(masks, dim=-2) - in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + in_width_coords = in_width * torch.arange(w, dtype=torch.int, device=in_width.device)[None, :] right_edges, _ = torch.max(in_width_coords, dim=-1) in_width_coords = in_width_coords + w * (~in_width) + in_width_coords = in_width_coords.type(torch.int) left_edges, _ = torch.min(in_width_coords, dim=-1) # If the mask is empty the right edge will be to the left of the left edge. From 2c41d23a678332093854daa0b2f4d9ec42ae1a2a Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Thu, 7 Sep 2023 17:12:56 +1000 Subject: [PATCH 4/5] Only pass boolean masks to batched_mask_to_box function (bug means int masks produce incorrect results) --- micro_sam/instance_segmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py index cab6147d..775ae829 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -269,7 +269,7 @@ def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None): # threshold masks and calculate boxes data["masks"] = data["masks"] > self._predictor.model.mask_threshold - data["masks"] = data["masks"].type(torch.int) + data["masks"] = data["masks"].type(torch.bool) data["boxes"] = batched_mask_to_box(data["masks"]) # compress to RLE From 6981148bcf2d0e26f386a4b62e0d5affebc36a40 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Thu, 7 Sep 2023 17:13:41 +1000 Subject: [PATCH 5/5] Add tests for vendored batched_mask_to_box function from segment_anything/util/amg.py --- test/test_vendored.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 test/test_vendored.py diff --git a/test/test_vendored.py b/test/test_vendored.py new file mode 100644 index 00000000..ddce5476 --- /dev/null +++ b/test/test_vendored.py @@ -0,0 +1,43 @@ +import unittest + +import numpy as np +import torch + + +class TestVendored(unittest.TestCase): + def setUp(self): + mask_numpy = np.zeros((10,10)).astype(bool) + mask_numpy[7:9, 3:5] = True + self.mask = mask_numpy + self.expected_result = [3, 7, 4, 8] + + def test_cpu_batched_mask_to_box(self): + from micro_sam._vendored import batched_mask_to_box + + device = "cpu" + mask = torch.as_tensor(self.mask, dtype=torch.bool, device=device) + expected_result = torch.as_tensor(self.expected_result, dtype=torch.int, device=device) + result = batched_mask_to_box(mask) + assert all(result == expected_result) + + @unittest.skipIf(not torch.cuda.is_available(), + "CUDA Pytorch backend is not available") + def test_cuda_batched_mask_to_box(self): + from micro_sam._vendored import batched_mask_to_box + + device = "cuda" + mask = torch.as_tensor(self.mask, dtype=torch.bool, device=device) + expected_result = torch.as_tensor(self.expected_result, dtype=torch.int, device=device) + result = batched_mask_to_box(mask) + assert all(result == expected_result) + + @unittest.skipIf(not (torch.backends.mps.is_available() and torch.backends.mps.is_built()), + "MPS Pytorch backend is not available") + def test_mps_batched_mask_to_box(self): + from micro_sam._vendored import batched_mask_to_box + + device = "mps" + mask = torch.as_tensor(self.mask, dtype=torch.bool, device=device) + expected_result = torch.as_tensor(self.expected_result, dtype=torch.int, device=device) + result = batched_mask_to_box(mask) + assert all(result == expected_result)