From 031dae1115d5bb72e2d99951dc9e052d92661e25 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Thu, 14 Nov 2024 19:24:04 +0100 Subject: [PATCH] Minor fix to casting masks in AMG post-processing (#780) * Minor fix to casting masks in nms * Cast to torch.bool --- micro_sam/_vendored.py | 4 ++-- micro_sam/instance_segmentation.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/micro_sam/_vendored.py b/micro_sam/_vendored.py index 976f8b4a7..864a9a684 100644 --- a/micro_sam/_vendored.py +++ b/micro_sam/_vendored.py @@ -29,7 +29,7 @@ def njit(func): def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: - """Calculates boxes in XYXY format around masks. Return [0,0,0,0] for an empty mask. + """Calculates boxes in XYXY format around masks. Return [0, 0, 0, 0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. @@ -38,7 +38,7 @@ def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: It further ensures that inputs are boolean tensors, otherwise the function yields wrong results. See https://github.com/facebookresearch/segment-anything/issues/552 for details. """ - assert masks.dtype == torch.bool + assert masks.dtype == torch.bool, masks.dtype # torch.max below raises an error on empty inputs, just skip in this case if torch.numel(masks) == 0: diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py index c03c7a4fa..44968bf29 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -215,7 +215,7 @@ def _postprocess_small_regions(self, mask_data, min_area, nms_thresh): # recalculate boxes and remove any new duplicates masks = torch.cat(new_masks, dim=0) - boxes = batched_mask_to_box(masks) + boxes = batched_mask_to_box(masks.to(torch.bool)) # Casting this to boolean as we work with one-hot labels. keep_by_nms = batched_nms( boxes.float(), torch.as_tensor(scores, dtype=torch.float),