Skip to content

Commit

Permalink
Minor fix to casting masks in AMG post-processing (#780)
Browse files Browse the repository at this point in the history
* Minor fix to casting masks in nms

* Cast to torch.bool
  • Loading branch information
anwai98 authored Nov 14, 2024
1 parent 9b055c3 commit 031dae1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions micro_sam/_vendored.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def njit(func):


def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
"""Calculates boxes in XYXY format around masks. Return [0,0,0,0] for an empty mask.
"""Calculates boxes in XYXY format around masks. Return [0, 0, 0, 0] for an empty mask.
For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
Expand All @@ -38,7 +38,7 @@ def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
It further ensures that inputs are boolean tensors, otherwise the function yields wrong results.
See https://github.com/facebookresearch/segment-anything/issues/552 for details.
"""
assert masks.dtype == torch.bool
assert masks.dtype == torch.bool, masks.dtype

# torch.max below raises an error on empty inputs, just skip in this case
if torch.numel(masks) == 0:
Expand Down
2 changes: 1 addition & 1 deletion micro_sam/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):

# recalculate boxes and remove any new duplicates
masks = torch.cat(new_masks, dim=0)
boxes = batched_mask_to_box(masks)
boxes = batched_mask_to_box(masks.to(torch.bool)) # Casting this to boolean as we work with one-hot labels.
keep_by_nms = batched_nms(
boxes.float(),
torch.as_tensor(scores, dtype=torch.float),
Expand Down

0 comments on commit 031dae1

Please sign in to comment.