Skip to content

Commit

Permalink
Implement numpy based RLE function
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Sep 8, 2023
1 parent 3db6c0e commit f403034
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 23 deletions.
30 changes: 29 additions & 1 deletion micro_sam/_vendored.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""
Functions from other third party libraries.
We can remove these functions once the bug affecting our code is fixed upstream.
We can remove these functions once the bugs affecting our code is fixed upstream.
The license type of the thrid party software project must be compatible with
the software license the micro-sam project is distributed under.
"""
from typing import Any, Dict, List

import numpy as np
import torch


Expand Down Expand Up @@ -59,3 +62,28 @@ def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
out = out[0]

return out


# segment_anything/util/amg.py
# https://github.com/facebookresearch/segment-anything
def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
"""Calculates the runlength encoding of binary input masks.
Implementation based on
https://stackoverflow.com/questions/1066758/find-length-of-sequences-of-identical-values-in-a-numpy-array-run-length-encodi
"""
# Put in fortran order and flatten h, w
b, h, w = tensor.shape
tensor = tensor.permute(0, 2, 1).flatten(1)
tensor = tensor.detach().cpu().numpy()

n = tensor.shape[1]

# encode the rle for the individual masks
out = []
for mask in tensor:
diffs = mask[1:] != mask[:-1] # pairwise unequal (string safe)
indices = np.append(np.where(diffs), n - 1) # must include last element posi
counts = np.diff(np.append(-1, indices)) # run lengths
out.append({"size": [h, w], "counts": counts.tolist()})
return out
8 changes: 5 additions & 3 deletions micro_sam/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from . import util
from .prompt_based_segmentation import segment_from_mask
from ._vendored import batched_mask_to_box
from ._vendored import batched_mask_to_box, mask_to_rle_pytorch

#
# Utility Functionality
Expand Down Expand Up @@ -208,7 +208,8 @@ def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):
for i_mask in keep_by_nms:
if scores[i_mask] == 0.0:
mask_torch = masks[i_mask].unsqueeze(0)
mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0]
# mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0]
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
mask_data.filter(keep_by_nms)

Expand Down Expand Up @@ -274,7 +275,8 @@ def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):

# compress to RLE
data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
# data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
data["rles"] = mask_to_rle_pytorch(data["masks"])
del data["masks"]

return data
Expand Down
79 changes: 60 additions & 19 deletions test/test_vendored.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,82 @@
import numpy as np
import torch

from segment_anything.utils.amg import mask_to_rle_pytorch as mask_to_rle_pytorch_sam
from skimage.draw import random_shapes


class TestVendored(unittest.TestCase):
def setUp(self):
mask_numpy = np.zeros((10,10)).astype(bool)
def _get_mask_to_box_data():
mask_numpy = np.zeros((10, 10)).astype(bool)
mask_numpy[7:9, 3:5] = True
self.mask = mask_numpy
self.expected_result = [3, 7, 4, 8]
expected_result = [3, 7, 4, 8]
return mask_numpy, expected_result

def test_cpu_batched_mask_to_box(self):
def _test_batched_mask_to_box(self, device):
from micro_sam._vendored import batched_mask_to_box

device = "cpu"
mask, expected_result = self._get_mask_to_box_data()
mask = torch.as_tensor(self.mask, dtype=torch.bool, device=device)
expected_result = torch.as_tensor(self.expected_result, dtype=torch.int, device=device)
result = batched_mask_to_box(mask)
assert all(result == expected_result)

def test_cpu_batched_mask_to_box(self):
self._test_batched_mask_to_box(device="cpu")

@unittest.skipIf(not torch.cuda.is_available(),
"CUDA Pytorch backend is not available")
def test_cuda_batched_mask_to_box(self):
from micro_sam._vendored import batched_mask_to_box

device = "cuda"
mask = torch.as_tensor(self.mask, dtype=torch.bool, device=device)
expected_result = torch.as_tensor(self.expected_result, dtype=torch.int, device=device)
result = batched_mask_to_box(mask)
assert all(result == expected_result)
self._test_batched_mask_to_box(device="cuda")

@unittest.skipIf(not (torch.backends.mps.is_available() and torch.backends.mps.is_built()),
"MPS Pytorch backend is not available")
def test_mps_batched_mask_to_box(self):
from micro_sam._vendored import batched_mask_to_box
self._test_batched_mask_to_box(device="mps")

device = "mps"
mask = torch.as_tensor(self.mask, dtype=torch.bool, device=device)
expected_result = torch.as_tensor(self.expected_result, dtype=torch.int, device=device)
result = batched_mask_to_box(mask)
assert all(result == expected_result)
def _get_mask_to_rle_pytorch_data(self):
shape = (128, 256)

# randm shapes for 6 masks
n_masks = 6
masks, _ = random_shapes(shape, min_shapes=n_masks, max_shapes=n_masks)
masks = masks.astype("uint32").sum(axis=-1)

bg_id = 765 # bg val is 3 * 255 = 765
mask_ids = np.setdiff1d(np.unique(masks), bg_id)

one_hot = np.zeros((len(mask_ids),) + shape, dtype=bool)
for i, idx in enumerate(mask_ids):
one_hot[i, masks == idx] = 1
one_hot = torch.from_numpy(one_hot)

# make sure that all corner pixels are zero
one_hot[:, 0, 0] = 0
one_hot[:, -1, -1] = 0
one_hot[:, 0, -1] = 0
one_hot[:, -1, 0] = 0

expected_result = mask_to_rle_pytorch_sam(one_hot)

return one_hot, expected_result

def test_mask_to_rle_pytorch(self):
from micro_sam._vendored import mask_to_rle_pytorch

masks, expected_result = self._get_mask_to_rle_pytorch_data()
expected_size = masks.shape[1] * masks.shape[2]

# make sure that the RLE's are consistent (their sum needs to be equal to the number of pixels)
for rle in expected_result:
assert sum(rle["counts"]) == expected_size, f"{sum(rle['counts'])}, {expected_size}"

result = mask_to_rle_pytorch(masks)
for rle in result:
assert sum(rle["counts"]) == expected_size, f"{sum(rle['counts'])}, {expected_size}"

# make sure that the RLE's agree
assert result == expected_result


if __name__ == "__main__":
unittest.main()

0 comments on commit f403034

Please sign in to comment.