From f375e923a52b9044a01b4cb39de816ec2b8aae7e Mon Sep 17 00:00:00 2001 From: Qiusheng Wu Date: Sun, 6 Oct 2024 23:47:17 -0500 Subject: [PATCH] Add region_groups function (#334) --- requirements.txt | 3 + samgeo/common.py | 167 +++++++++++++++++++++++++++++++++++++++++++++- samgeo/samgeo2.py | 54 +++++++++++++++ 3 files changed, 223 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a68d8398..50c7ca1e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,10 @@ patool pycocotools pyproj rasterio +rioxarray sam2 +scikit-image +scikit-learn segment-anything-hq segment-anything-py timm diff --git a/samgeo/common.py b/samgeo/common.py index 860e1a28..b145c069 100644 --- a/samgeo/common.py +++ b/samgeo/common.py @@ -7,7 +7,7 @@ import cv2 import numpy as np from tqdm import tqdm -from typing import List, Optional, Union +from typing import List, Optional, Union, Tuple, Any import shapely import pyproj import rasterio @@ -3520,3 +3520,168 @@ def geotiff_to_jpg_batch(input_folder: str, output_folder: str = None) -> str: geotiff_to_jpg(geotiff_path, output_path) return output_folder + + +def region_groups( + image: Union[str, "xr.DataArray", np.ndarray], + connectivity: int = 1, + min_size: int = 10, + max_size: Optional[int] = None, + threshold: Optional[int] = None, + properties: Optional[List[str]] = None, + out_csv: Optional[str] = None, + out_vector: Optional[str] = None, + out_image: Optional[str] = None, + **kwargs: Any, +) -> Union[Tuple[np.ndarray, "pd.DataFrame"], Tuple["xr.DataArray", "pd.DataFrame"]]: + """ + Segment regions in an image and filter them based on size. + + Args: + image (Union[str, xr.DataArray, np.ndarray]): Input image, can be a file + path, xarray DataArray, or numpy array. + connectivity (int, optional): Connectivity for labeling. Defaults to 1 + for 4-connectivity. Use 2 for 8-connectivity. + min_size (int, optional): Minimum size of regions to keep. Defaults to 10. + max_size (Optional[int], optional): Maximum size of regions to keep. + Defaults to None. + threshold (Optional[int], optional): Threshold for filling holes. + Defaults to None, which is equal to min_size. + properties (Optional[List[str]], optional): List of properties to measure. + See https://scikit-image.org/docs/stable/api/skimage.measure.html#skimage.measure.regionprops + Defaults to None. + out_csv (Optional[str], optional): Path to save the properties as a CSV file. + Defaults to None. + out_vector (Optional[str], optional): Path to save the vector file. + Defaults to None. + out_image (Optional[str], optional): Path to save the output image. + Defaults to None. + + Returns: + Union[Tuple[np.ndarray, pd.DataFrame], Tuple[xr.DataArray, pd.DataFrame]]: Labeled image and properties DataFrame. + """ + import rioxarray as rxr + import xarray as xr + from skimage import measure + import pandas as pd + import scipy.ndimage as ndi + + if isinstance(image, str): + ds = rxr.open_rasterio(image) + da = ds.sel(band=1) + array = da.values.squeeze() + elif isinstance(image, xr.DataArray): + da = image + array = image.values.squeeze() + elif isinstance(image, np.ndarray): + array = image + else: + raise ValueError( + "The input image must be a file path, xarray DataArray, or numpy array." + ) + + if threshold is None: + threshold = min_size + + if properties is None: + properties = [ + "label", + "area", + "area_bbox", + "area_convex", + "area_filled", + "axis_major_length", + "axis_minor_length", + "eccentricity", + "equivalent_diameter_area", + "extent", + "orientation", + "perimeter", + "solidity", + ] + + label_image = measure.label(array, connectivity=connectivity) + props = measure.regionprops_table(label_image, properties=properties) + + df = pd.DataFrame(props) + + # Get the labels of regions with area smaller than the threshold + small_regions = df[df["area"] < min_size]["label"].values + # Set the corresponding labels in the label_image to zero + for region_label in small_regions: + label_image[label_image == region_label] = 0 + + if max_size is not None: + large_regions = df[df["area"] > max_size]["label"].values + for region_label in large_regions: + label_image[label_image == region_label] = 0 + + # Find the background (holes) which are zeros + holes = label_image == 0 + + # Label the holes (connected components in the background) + labeled_holes, _ = ndi.label(holes) + + # Measure properties of the labeled holes, including area and bounding box + hole_props = measure.regionprops(labeled_holes) + + # Loop through each hole and fill it if it is smaller than the threshold + for prop in hole_props: + if prop.area < threshold: + # Get the coordinates of the small hole + coords = prop.coords + + # Find the surrounding region's ID (non-zero value near the hole) + surrounding_region_values = [] + for coord in coords: + x, y = coord + # Get a 3x3 neighborhood around the hole pixel + neighbors = label_image[max(0, x - 1) : x + 2, max(0, y - 1) : y + 2] + # Exclude the hole pixels (zeros) and get region values + region_values = neighbors[neighbors != 0] + if region_values.size > 0: + surrounding_region_values.append( + region_values[0] + ) # Take the first non-zero value + + if surrounding_region_values: + # Fill the hole with the mode (most frequent) of the surrounding region values + fill_value = max( + set(surrounding_region_values), key=surrounding_region_values.count + ) + label_image[coords[:, 0], coords[:, 1]] = fill_value + + label_image, num_labels = measure.label( + label_image, connectivity=connectivity, return_num=True + ) + props = measure.regionprops_table(label_image, properties=properties) + + df = pd.DataFrame(props) + df["elongation"] = df["axis_major_length"] / df["axis_minor_length"] + + dtype = "uint8" + if num_labels > 255 and num_labels <= 65535: + dtype = "uint16" + elif num_labels > 65535: + dtype = "uint32" + + if out_csv is not None: + df.to_csv(out_csv, index=False) + + if isinstance(image, np.ndarray): + return label_image, df + else: + da.values = label_image + if out_image is not None: + da.rio.to_raster(out_image, dtype=dtype) + if out_vector is not None: + tmp_vector = temp_file_path(".gpkg") + raster_to_vector(out_image, tmp_vector) + gdf = gpd.read_file(tmp_vector) + gdf["label"] = gdf["value"].astype(int) + gdf.drop(columns=["value"], inplace=True) + gdf2 = pd.merge(gdf, df, on="label", how="left") + gdf2.to_file(out_vector) + gdf2.sort_values("label", inplace=True) + df = gdf2 + return da, df diff --git a/samgeo/samgeo2.py b/samgeo/samgeo2.py index a0e1f3e5..f0e7555a 100644 --- a/samgeo/samgeo2.py +++ b/samgeo/samgeo2.py @@ -1514,3 +1514,57 @@ def raster_to_vector(self, raster, vector, simplify_tolerance=None, **kwargs): common.raster_to_vector( raster, vector, simplify_tolerance=simplify_tolerance, **kwargs ) + + def region_groups( + self, + image: Union[str, "xr.DataArray", np.ndarray], + connectivity: int = 1, + min_size: int = 10, + max_size: Optional[int] = None, + threshold: Optional[int] = None, + properties: Optional[List[str]] = None, + out_csv: Optional[str] = None, + out_vector: Optional[str] = None, + out_image: Optional[str] = None, + **kwargs: Any, + ) -> Union[ + Tuple[np.ndarray, "pd.DataFrame"], Tuple["xr.DataArray", "pd.DataFrame"] + ]: + """ + Segment regions in an image and filter them based on size. + + Args: + image (Union[str, xr.DataArray, np.ndarray]): Input image, can be a file + path, xarray DataArray, or numpy array. + connectivity (int, optional): Connectivity for labeling. Defaults to 1 + for 4-connectivity. Use 2 for 8-connectivity. + min_size (int, optional): Minimum size of regions to keep. Defaults to 10. + max_size (Optional[int], optional): Maximum size of regions to keep. + Defaults to None. + threshold (Optional[int], optional): Threshold for filling holes. + Defaults to None, which is equal to min_size. + properties (Optional[List[str]], optional): List of properties to measure. + See https://scikit-image.org/docs/stable/api/skimage.measure.html#skimage.measure.regionprops + Defaults to None. + out_csv (Optional[str], optional): Path to save the properties as a CSV file. + Defaults to None. + out_vector (Optional[str], optional): Path to save the vector file. + Defaults to None. + out_image (Optional[str], optional): Path to save the output image. + Defaults to None. + + Returns: + Union[Tuple[np.ndarray, pd.DataFrame], Tuple[xr.DataArray, pd.DataFrame]]: Labeled image and properties DataFrame. + """ + return common.region_groups( + image, + connectivity=connectivity, + min_size=min_size, + max_size=max_size, + threshold=threshold, + properties=properties, + out_csv=out_csv, + out_vector=out_vector, + out_image=out_image, + **kwargs, + )