diff --git a/samgeo/common.py b/samgeo/common.py index 6b3e042b..047dbf34 100644 --- a/samgeo/common.py +++ b/samgeo/common.py @@ -1931,7 +1931,7 @@ def sam_map_gui(sam, basemap="SATELLITE", repeat_mode=True, out_dir=None, **kwar description="Mask opacity:", min=0, max=1, - value=0.5, + value=0.7, readout=True, continuous_update=True, layout=widgets.Layout(width=widget_width, padding=padding), @@ -2182,12 +2182,20 @@ def segment_button_click(change): filename = f"masks_{random_string()}.tif" filename = os.path.join(out_dir, filename) - sam.predict( - point_coords=point_coords, - point_labels=point_labels, - point_crs="EPSG:4326", - output=filename, - ) + if sam.model_version == "sam": + sam.predict( + point_coords=point_coords, + point_labels=point_labels, + point_crs="EPSG:4326", + output=filename, + ) + elif sam.model_version == "sam2": + sam.predict_by_points( + point_coords_batch=point_coords, + point_labels_batch=point_labels, + point_crs="EPSG:4326", + output=filename, + ) if m.find_layer("Masks") is not None: m.remove_layer(m.find_layer("Masks")) if m.find_layer("Regularized") is not None: @@ -2200,18 +2208,16 @@ def segment_button_click(change): os.remove(sam.prediction_fp) except: pass - # Skip the image layer if localtileserver is not available try: m.add_raster( filename, nodata=0, - cmap="Blues", + cmap="tab20", opacity=opacity_slider.value, layer_name="Masks", zoom_to_layer=False, ) - if rectangular.value: vector = filename.replace(".tif", ".gpkg") vector_rec = filename.replace(".tif", "_rect.gpkg") diff --git a/samgeo/samgeo.py b/samgeo/samgeo.py index a30d3573..9d0d7e01 100644 --- a/samgeo/samgeo.py +++ b/samgeo/samgeo.py @@ -73,6 +73,7 @@ def __init__( self.checkpoint = checkpoint self.model_type = model_type + self.model_version = "sam" self.device = device self.sam_kwargs = sam_kwargs # Optional arguments for fine-tuning the SAM model self.source = None # Store the input image path diff --git a/samgeo/samgeo2.py b/samgeo/samgeo2.py index 15c7b759..b1b29e74 100644 --- a/samgeo/samgeo2.py +++ b/samgeo/samgeo2.py @@ -133,6 +133,7 @@ def __init__( hydra_overrides_extra = [] self.model_id = model_id + self.model_version = "sam2" self.device = device if video: @@ -701,8 +702,9 @@ def predict_by_points( point_crs: Optional[str] = None, output: Optional[str] = None, index: Optional[int] = None, + unique: bool = True, mask_multiplier: int = 255, - dtype: str = "float32", + dtype: str = "int32", return_results: bool = False, **kwargs: Any, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: @@ -733,7 +735,7 @@ def predict_by_points( which will save the mask with the highest score. mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1]. - dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32. + dtype (np.dtype, optional): The data type of the output image. Defaults to np.int32. return_results (bool, optional): Whether to return the predicted masks, scores, and logits. Defaults to False. @@ -771,12 +773,28 @@ def predict_by_points( labels = point_labels_batch elif isinstance(point_coords_batch, list): - points = point_coords_batch - num_points = points.shape[0] + if point_crs is not None: + point_coords_batch = common.coords_to_xy( + self.source, point_coords_batch, point_crs + ) + points = [] + points.append([[point] for point in point_coords_batch]) + + num_points = len(points) if point_labels_batch is None: labels = np.array([[1] for i in range(num_points)]) + elif isinstance(point_labels_batch, list): + labels = [] + labels.append([[label] for label in point_labels_batch]) else: labels = point_labels_batch + + points = np.array(points[0]) + labels = np.array(labels[0]) + + elif isinstance(point_coords_batch, np.ndarray): + points = point_coords_batch + labels = point_labels_batch else: raise ValueError("point_coords must be a list, a GeoDataFrame, or a path.") @@ -813,7 +831,13 @@ def predict_by_points( self.logits = logits if output is not None: - self.save_prediction(output, index, mask_multiplier, dtype, **kwargs) + self.save_masks( + output, + foreground=True, + unique=unique, + mask_multiplier=mask_multiplier, + **kwargs, + ) if return_results: return output_masks, scores, logits