diff --git a/micro_sam/_vendored.py b/micro_sam/_vendored.py
index 9cbe239d..f574012a 100644
--- a/micro_sam/_vendored.py
+++ b/micro_sam/_vendored.py
@@ -83,7 +83,10 @@ def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
     out = []
     for mask in tensor:
         diffs = mask[1:] != mask[:-1]  # pairwise unequal (string safe)
-        indices = np.append(np.where(diffs), n - 1)  # must include last element posi
-        counts = np.diff(np.append(-1, indices))  # run lengths
-        out.append({"size": [h, w], "counts": counts.tolist()})
+        indices = np.append(np.where(diffs), n - 1)  # must include last element position
+        # count needs to start with 0 if the mask begins with 1
+        counts = [] if mask[0] == 0 else [0]
+        # compute the actual RLE
+        counts += np.diff(np.append(-1, indices)).tolist()
+        out.append({"size": [h, w], "counts": counts})
     return out
diff --git a/test/test_vendored.py b/test/test_vendored.py
index 07429b4b..a0d7f703 100644
--- a/test/test_vendored.py
+++ b/test/test_vendored.py
@@ -52,14 +52,7 @@ def _get_mask_to_rle_pytorch_data(self):
             one_hot[i, masks == idx] = 1
         one_hot = torch.from_numpy(one_hot)
 
-        # make sure that all corner pixels are zero
-        one_hot[:, 0, 0] = 0
-        one_hot[:, -1, -1] = 0
-        one_hot[:, 0, -1] = 0
-        one_hot[:, -1, 0] = 0
-
         expected_result = mask_to_rle_pytorch_sam(one_hot)
-
         return one_hot, expected_result
 
     def test_mask_to_rle_pytorch(self):
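For reference, the behavioral change is easiest to see on a mask whose first (flattened) pixel is foreground. Below is a minimal sketch, not part of the diff, assuming the vendored function keeps the (N, H, W) torch tensor input convention of the SAM original; the expected counts follow the COCO uncompressed-RLE rule that runs alternate starting with background.

import torch

from micro_sam._vendored import mask_to_rle_pytorch

# a single 2x2 mask whose top-left pixel is set, so the flattened mask begins with 1
mask = torch.tensor([[[1, 0],
                      [0, 0]]], dtype=torch.uint8)

rle = mask_to_rle_pytorch(mask)[0]
# Before the fix: counts == [1, 3], which COCO tools read as "1 background pixel,
# then 3 foreground pixels". With the fix: counts == [0, 1, 3], i.e. zero background
# pixels, one foreground pixel, three background pixels.
print(rle["size"], rle["counts"])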