From 5dc9727309d9954ba7c734b67e89e97ed4d258cb Mon Sep 17 00:00:00 2001 From: Qiusheng Wu Date: Sun, 13 Oct 2024 21:35:40 -0400 Subject: [PATCH] Add support for box prompts for SamGeo2 --- docs/examples/box_prompts.ipynb | 6 +- docs/examples/sam2_box_prompts.ipynb | 429 +++++++++++++++++++++++++++ mkdocs.yml | 1 + samgeo/common.py | 16 +- samgeo/samgeo2.py | 56 +--- 5 files changed, 457 insertions(+), 51 deletions(-) create mode 100644 docs/examples/sam2_box_prompts.ipynb diff --git a/docs/examples/box_prompts.ipynb b/docs/examples/box_prompts.ipynb index 344c9e5a..65a8a187 100644 --- a/docs/examples/box_prompts.ipynb +++ b/docs/examples/box_prompts.ipynb @@ -9,7 +9,7 @@ "[![image](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/opengeos/segment-geospatial/blob/main/docs/examples/box_prompts.ipynb)\n", "[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/segment-geospatial/blob/main/docs/examples/box_prompts.ipynb)\n", "\n", - "This notebook shows how to generate object masks from text prompts with the Segment Anything Model (SAM). \n", + "This notebook shows how to generate object masks from box prompts with the Segment Anything Model (SAM). \n", "\n", "Make sure you use GPU runtime for this notebook. For Google Colab, go to `Runtime` -> `Change runtime type` and select `GPU` as the hardware accelerator. " ] @@ -131,10 +131,6 @@ "source": [ "## Initialize SAM class\n", "\n", - "The initialization of the LangSAM class might take a few minutes. The initialization downloads the model weights and sets up the model for inference.\n", - "\n", - "Specify the file path to the model checkpoint. If it is not specified, the model will to downloaded to the working directory.\n", - "\n", "Set `automatic=False` to disable the `SamAutomaticMaskGenerator` and enable the `SamPredictor`." ] }, diff --git a/docs/examples/sam2_box_prompts.ipynb b/docs/examples/sam2_box_prompts.ipynb new file mode 100644 index 00000000..a5eb0316 --- /dev/null +++ b/docs/examples/sam2_box_prompts.ipynb @@ -0,0 +1,429 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Segmenting remote sensing imagery with box prompts\n", + "\n", + "[![image](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/opengeos/segment-geospatial/blob/main/docs/examples/sam2_box_prompts.ipynb)\n", + "[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/segment-geospatial/blob/main/docs/examples/sam2_box_prompts.ipynb)\n", + "\n", + "This notebook shows how to generate object masks from box prompts with the Segment Anything Model 2 (SAM 2). \n", + "\n", + "Make sure you use GPU runtime for this notebook. For Google Colab, go to `Runtime` -> `Change runtime type` and select `GPU` as the hardware accelerator. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Install dependencies\n", + "\n", + "Uncomment and run the following cell to install the required dependencies." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %pip install segment-geospatial" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import leafmap\n", + "from samgeo import SamGeo2, raster_to_vector, regularize" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create an interactive map" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m = leafmap.Map(center=[47.653287, -117.588070], zoom=16, height=\"800px\")\n", + "m.add_basemap(\"Satellite\")\n", + "m" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download a sample image\n", + "\n", + "Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if m.user_roi is not None:\n", + " bbox = m.user_roi_bounds()\n", + "else:\n", + " bbox = [-117.6029, 47.65, -117.5936, 47.6563]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = \"satellite.tif\"\n", + "leafmap.map_tiles_to_geotiff(\n", + " output=image, bbox=bbox, zoom=18, source=\"Satellite\", overwrite=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also use your own image. Uncomment and run the following cell to use your own image." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# image = '/path/to/your/own/image.tif'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display the downloaded image on the map." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m.layers[-1].visible = False\n", + "m.add_raster(image, layer_name=\"Image\")\n", + "m" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize SAM class\n", + "\n", + "Set `automatic=False` to enable the `SAM2ImagePredictor`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sam = SamGeo2(\n", + " model_id=\"sam2-hiera-large\",\n", + " automatic=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Specify the image to segment. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sam.set_image(image)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display the map. Use the drawing tools to draw some rectangles around the features you want to extract, such as trees, buildings." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create bounding boxes\n", + "\n", + "If no rectangles are drawn, the default bounding boxes will be used as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if m.user_rois is not None:\n", + " boxes = m.user_rois\n", + "else:\n", + " boxes = [\n", + " [-117.5995, 47.6518, -117.5988, 47.652],\n", + " [-117.5987, 47.6518, -117.5979, 47.652],\n", + " ]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Segment the image\n", + "\n", + "Use the `predict()` method to segment the image with specified bounding boxes. The `boxes` parameter accepts a list of bounding box coordinates in the format of [[left, bottom, right, top], [left, bottom, right, top], ...], a GeoJSON dictionary, or a file path to a GeoJSON file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sam.predict(boxes=boxes, point_crs=\"EPSG:4326\", output=\"mask.tif\", dtype=\"uint8\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Display the result\n", + "\n", + "Add the segmented image to the map." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m.add_raster(\"mask.tif\", cmap=\"viridis\", nodata=0, layer_name=\"Mask\")\n", + "m" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use an existing vector file as box prompts\n", + "\n", + "Alternatively, you can specify a file path to a vector file. Let's download a sample vector file from GitHub." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "url = \"https://github.com/opengeos/datasets/releases/download/samgeo/building_bboxes.geojson\"\n", + "geojson = \"building_bboxes.geojson\"\n", + "leafmap.download_file(url, geojson)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display the vector data on the map." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m = leafmap.Map()\n", + "m.add_raster(image, layer_name=\"Image\")\n", + "style = {\n", + " \"color\": \"#ffff00\",\n", + " \"weight\": 2,\n", + " \"fillColor\": \"#7c4185\",\n", + " \"fillOpacity\": 0,\n", + "}\n", + "m.add_vector(geojson, style=style, zoom_to_layer=True, layer_name=\"Bboxes\")\n", + "m" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![image](https://github.com/user-attachments/assets/95e8d2a5-9354-4694-b928-195a85bbb2e6)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Segment image with box prompts\n", + "\n", + "Segment the image using the specified file path to the vector mask." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_masks = \"building_masks.tif\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sam.predict(\n", + " boxes=geojson,\n", + " point_crs=\"EPSG:4326\",\n", + " output=output_masks,\n", + " dtype=\"uint8\",\n", + " multimask_output=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display the segmented masks on the map." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m.add_raster(\n", + " output_masks, cmap=\"jet\", nodata=0, opacity=0.5, layer_name=\"Building masks\"\n", + ")\n", + "m" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![image](https://github.com/user-attachments/assets/6f2d4f1f-dfc1-4dfa-8acb-642e1afb9c4a)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Convert raster to vector" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_vector = \"building_vector.geojson\"\n", + "raster_to_vector(output_masks, output_vector)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Regularize building footprints" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_regularized = \"building_regularized.geojson\"\n", + "regularize(output_vector, output_regularized)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m.add_vector(\n", + " output_regularized, style=style, layer_name=\"Building regularized\", info_mode=None\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![image](https://github.com/user-attachments/assets/c4b77056-9fd1-4ce8-9740-1b9d4f993040)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/mkdocs.yml b/mkdocs.yml index d235dc2b..34178871 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -64,6 +64,7 @@ nav: - examples/sam2_automatic.ipynb - examples/sam2_predictor.ipynb - examples/sam2_video.ipynb + - examples/sam2_box_prompts.ipynb - Workshops: - workshops/purdue.ipynb - workshops/cn_workshop.ipynb diff --git a/samgeo/common.py b/samgeo/common.py index d316a360..724b4f65 100644 --- a/samgeo/common.py +++ b/samgeo/common.py @@ -406,13 +406,12 @@ def tms_to_geotiff( """ - import os + import re import io import math import itertools import concurrent.futures - import numpy from PIL import Image try: @@ -490,7 +489,18 @@ def resolution_to_zoom_level(resolution): gdal.UseExceptions() web_mercator = osr.SpatialReference() - web_mercator.ImportFromEPSG(3857) + try: + web_mercator.ImportFromEPSG(3857) + except RuntimeError as e: + # https://github.com/PDAL/PDAL/issues/2544#issuecomment-637995923 + if "PROJ" in str(e): + pattern = r"/[\w/]+" + match = re.search(pattern, str(e)) + if match: + file_path = match.group(0) + os.environ["PROJ_LIB"] = file_path + os.environ["GDAL_DATA"] = file_path.replace("proj", "gdal") + web_mercator.ImportFromEPSG(3857) WKT_3857 = web_mercator.ExportToWkt() diff --git a/samgeo/samgeo2.py b/samgeo/samgeo2.py index 0860fb17..a7820e4d 100644 --- a/samgeo/samgeo2.py +++ b/samgeo/samgeo2.py @@ -657,42 +657,21 @@ def predict( if isinstance(boxes, list) and (point_crs is not None): coords = common.bbox_to_xy(self.source, boxes, point_crs) input_boxes = np.array(coords) - if isinstance(coords[0], int): - input_boxes = input_boxes[None, :] - else: - input_boxes = torch.tensor(input_boxes, device=self.device) - input_boxes = predictor.transform.apply_boxes_torch( - input_boxes, self.image.shape[:2] - ) + elif isinstance(boxes, list) and (point_crs is None): input_boxes = np.array(boxes) - if isinstance(boxes[0], int): - input_boxes = input_boxes[None, :] self.boxes = input_boxes - if ( - boxes is None - or (len(boxes) == 1) - or (len(boxes) == 4 and isinstance(boxes[0], float)) - ): - if isinstance(boxes, list) and isinstance(boxes[0], list): - boxes = boxes[0] - masks, scores, logits = predictor.predict( - point_coords, - point_labels, - input_boxes, - mask_input, - multimask_output, - return_logits, - ) - else: - masks, scores, logits = predictor.predict_torch( - point_coords=point_coords, - point_labels=point_coords, - boxes=input_boxes, - multimask_output=True, - ) + masks, scores, logits = predictor.predict( + point_coords=point_coords, + point_labels=point_labels, + box=input_boxes, + mask_input=mask_input, + multimask_output=multimask_output, + return_logits=return_logits, + normalize_coords=normalize_coords, + ) self.masks = masks self.scores = scores @@ -709,16 +688,6 @@ def predict( if return_results: return masks, scores, logits - return self.predictor.predict( - point_coords=point_coords, - point_labels=point_labels, - box=boxes, - mask_input=mask_input, - multimask_output=multimask_output, - return_logits=return_logits, - normalize_coords=normalize_coords, - ) - def predict_batch( self, point_coords_batch: List[np.ndarray] = None, @@ -933,10 +902,11 @@ def tensor_to_numpy( image_np = np.array(image_pil) if index is None: - index = 1 + index = 0 masks = masks[:, index, :, :] - masks = masks.squeeze(1) + if len(masks.shape) == 4 and masks.shape[1] == 1: + masks = masks.squeeze(1) if boxes is None or (len(boxes) == 0): # No "object" instances found print("No objects found in the image.")