Skip to content

Commit

Permalink
Fix save masks bug (#332)
Browse files Browse the repository at this point in the history
* Fix save masks bug

* Fix save masks bug for SAM 1
  • Loading branch information
giswqs authored Oct 5, 2024
1 parent dab8b84 commit 837f5a5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
7 changes: 4 additions & 3 deletions samgeo/samgeo.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,8 @@ def save_masks(

# Generate a mask of objects with unique values
if unique:
# Sort the masks by area in ascending order
sorted_masks = sorted(masks, key=(lambda x: x["area"]), reverse=False)
# Sort the masks by area in descending order
sorted_masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)

# Create an output image with the same size as the input image
objects = np.zeros(
Expand All @@ -276,9 +276,10 @@ def save_masks(
)
)
# Assign a unique value to each object
count = len(sorted_masks)
for index, ann in enumerate(sorted_masks):
m = ann["segmentation"]
objects[m] = index + 1
objects[m] = count - index

# Generate a binary mask
else:
Expand Down
7 changes: 4 additions & 3 deletions samgeo/samgeo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,8 @@ def save_masks(

# Generate a mask of objects with unique values
if unique:
# Sort the masks by area in ascending order
sorted_masks = sorted(masks, key=(lambda x: x["area"]), reverse=False)
# Sort the masks by area in descending order
sorted_masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)

# Create an output image with the same size as the input image
objects = np.zeros(
Expand All @@ -304,9 +304,10 @@ def save_masks(
)
)
# Assign a unique value to each object
count = len(sorted_masks)
for index, ann in enumerate(sorted_masks):
m = ann["segmentation"]
objects[m] = index + 1
objects[m] = count - index

# Generate a binary mask
else:
Expand Down

0 comments on commit 837f5a5

Please sign in to comment.