Skip to content

Commit

Permalink
Fix rle computation for masks that begin with a 1
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Sep 8, 2023
1 parent 1ecde49 commit db666af
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 10 deletions.
9 changes: 6 additions & 3 deletions micro_sam/_vendored.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
out = []
for mask in tensor:
diffs = mask[1:] != mask[:-1] # pairwise unequal (string safe)
indices = np.append(np.where(diffs), n - 1) # must include last element posi
counts = np.diff(np.append(-1, indices)) # run lengths
out.append({"size": [h, w], "counts": counts.tolist()})
indices = np.append(np.where(diffs), n - 1) # must include last element position
# count needs to start with 0 if the mask begins with 1
counts = [] if mask[0] == 0 else [0]
# compute the actual RLE
counts += np.diff(np.append(-1, indices)).tolist()
out.append({"size": [h, w], "counts": counts})
return out
7 changes: 0 additions & 7 deletions test/test_vendored.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,7 @@ def _get_mask_to_rle_pytorch_data(self):
one_hot[i, masks == idx] = 1
one_hot = torch.from_numpy(one_hot)

# make sure that all corner pixels are zero
one_hot[:, 0, 0] = 0
one_hot[:, -1, -1] = 0
one_hot[:, 0, -1] = 0
one_hot[:, -1, 0] = 0

expected_result = mask_to_rle_pytorch_sam(one_hot)

return one_hot, expected_result

def test_mask_to_rle_pytorch(self):
Expand Down

0 comments on commit db666af

Please sign in to comment.