Skip to content

Commit

Permalink
Fix save masks bug
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Oct 19, 2024
1 parent a67328a commit 4cea1ab
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 15 deletions.
26 changes: 16 additions & 10 deletions samgeo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1931,7 +1931,7 @@ def sam_map_gui(sam, basemap="SATELLITE", repeat_mode=True, out_dir=None, **kwar
description="Mask opacity:",
min=0,
max=1,
value=0.5,
value=0.7,
readout=True,
continuous_update=True,
layout=widgets.Layout(width=widget_width, padding=padding),
Expand Down Expand Up @@ -2182,12 +2182,20 @@ def segment_button_click(change):

filename = f"masks_{random_string()}.tif"
filename = os.path.join(out_dir, filename)
sam.predict(
point_coords=point_coords,
point_labels=point_labels,
point_crs="EPSG:4326",
output=filename,
)
if sam.model_version == "sam":
sam.predict(
point_coords=point_coords,
point_labels=point_labels,
point_crs="EPSG:4326",
output=filename,
)
elif sam.model_version == "sam2":
sam.predict_by_points(
point_coords_batch=point_coords,
point_labels_batch=point_labels,
point_crs="EPSG:4326",
output=filename,
)
if m.find_layer("Masks") is not None:
m.remove_layer(m.find_layer("Masks"))
if m.find_layer("Regularized") is not None:
Expand All @@ -2200,18 +2208,16 @@ def segment_button_click(change):
os.remove(sam.prediction_fp)
except:
pass

# Skip the image layer if localtileserver is not available
try:
m.add_raster(
filename,
nodata=0,
cmap="Blues",
cmap="tab20",
opacity=opacity_slider.value,
layer_name="Masks",
zoom_to_layer=False,
)

if rectangular.value:
vector = filename.replace(".tif", ".gpkg")
vector_rec = filename.replace(".tif", "_rect.gpkg")
Expand Down
1 change: 1 addition & 0 deletions samgeo/samgeo.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(

self.checkpoint = checkpoint
self.model_type = model_type
self.model_version = "sam"
self.device = device
self.sam_kwargs = sam_kwargs # Optional arguments for fine-tuning the SAM model
self.source = None # Store the input image path
Expand Down
34 changes: 29 additions & 5 deletions samgeo/samgeo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(
hydra_overrides_extra = []

self.model_id = model_id
self.model_version = "sam2"
self.device = device

if video:
Expand Down Expand Up @@ -701,8 +702,9 @@ def predict_by_points(
point_crs: Optional[str] = None,
output: Optional[str] = None,
index: Optional[int] = None,
unique: bool = True,
mask_multiplier: int = 255,
dtype: str = "float32",
dtype: str = "int32",
return_results: bool = False,
**kwargs: Any,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -733,7 +735,7 @@ def predict_by_points(
which will save the mask with the highest score.
mask_multiplier (int, optional): The mask multiplier for the output mask,
which is usually a binary mask [0, 1].
dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.
dtype (np.dtype, optional): The data type of the output image. Defaults to np.int32.
return_results (bool, optional): Whether to return the predicted masks,
scores, and logits. Defaults to False.
Expand Down Expand Up @@ -771,12 +773,28 @@ def predict_by_points(
labels = point_labels_batch

elif isinstance(point_coords_batch, list):
points = point_coords_batch
num_points = points.shape[0]
if point_crs is not None:
point_coords_batch = common.coords_to_xy(
self.source, point_coords_batch, point_crs
)
points = []
points.append([[point] for point in point_coords_batch])

num_points = len(points)
if point_labels_batch is None:
labels = np.array([[1] for i in range(num_points)])
elif isinstance(point_labels_batch, list):
labels = []
labels.append([[label] for label in point_labels_batch])
else:
labels = point_labels_batch

points = np.array(points[0])
labels = np.array(labels[0])

elif isinstance(point_coords_batch, np.ndarray):
points = point_coords_batch
labels = point_labels_batch
else:
raise ValueError("point_coords must be a list, a GeoDataFrame, or a path.")

Expand Down Expand Up @@ -813,7 +831,13 @@ def predict_by_points(
self.logits = logits

if output is not None:
self.save_prediction(output, index, mask_multiplier, dtype, **kwargs)
self.save_masks(
output,
foreground=True,
unique=unique,
mask_multiplier=mask_multiplier,
**kwargs,
)

if return_results:
return output_masks, scores, logits
Expand Down

0 comments on commit 4cea1ab

Please sign in to comment.