Skip to content

Commit

Permalink
Add region_groups function (#334)
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs authored Oct 7, 2024
1 parent 2b4c015 commit f375e92
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 1 deletion.
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ patool
pycocotools
pyproj
rasterio
rioxarray
sam2
scikit-image
scikit-learn
segment-anything-hq
segment-anything-py
timm
Expand Down
167 changes: 166 additions & 1 deletion samgeo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import cv2
import numpy as np
from tqdm import tqdm
from typing import List, Optional, Union
from typing import List, Optional, Union, Tuple, Any
import shapely
import pyproj
import rasterio
Expand Down Expand Up @@ -3520,3 +3520,168 @@ def geotiff_to_jpg_batch(input_folder: str, output_folder: str = None) -> str:
geotiff_to_jpg(geotiff_path, output_path)

return output_folder


def region_groups(
image: Union[str, "xr.DataArray", np.ndarray],
connectivity: int = 1,
min_size: int = 10,
max_size: Optional[int] = None,
threshold: Optional[int] = None,
properties: Optional[List[str]] = None,
out_csv: Optional[str] = None,
out_vector: Optional[str] = None,
out_image: Optional[str] = None,
**kwargs: Any,
) -> Union[Tuple[np.ndarray, "pd.DataFrame"], Tuple["xr.DataArray", "pd.DataFrame"]]:
"""
Segment regions in an image and filter them based on size.
Args:
image (Union[str, xr.DataArray, np.ndarray]): Input image, can be a file
path, xarray DataArray, or numpy array.
connectivity (int, optional): Connectivity for labeling. Defaults to 1
for 4-connectivity. Use 2 for 8-connectivity.
min_size (int, optional): Minimum size of regions to keep. Defaults to 10.
max_size (Optional[int], optional): Maximum size of regions to keep.
Defaults to None.
threshold (Optional[int], optional): Threshold for filling holes.
Defaults to None, which is equal to min_size.
properties (Optional[List[str]], optional): List of properties to measure.
See https://scikit-image.org/docs/stable/api/skimage.measure.html#skimage.measure.regionprops
Defaults to None.
out_csv (Optional[str], optional): Path to save the properties as a CSV file.
Defaults to None.
out_vector (Optional[str], optional): Path to save the vector file.
Defaults to None.
out_image (Optional[str], optional): Path to save the output image.
Defaults to None.
Returns:
Union[Tuple[np.ndarray, pd.DataFrame], Tuple[xr.DataArray, pd.DataFrame]]: Labeled image and properties DataFrame.
"""
import rioxarray as rxr
import xarray as xr
from skimage import measure
import pandas as pd
import scipy.ndimage as ndi

if isinstance(image, str):
ds = rxr.open_rasterio(image)
da = ds.sel(band=1)
array = da.values.squeeze()
elif isinstance(image, xr.DataArray):
da = image
array = image.values.squeeze()
elif isinstance(image, np.ndarray):
array = image
else:
raise ValueError(
"The input image must be a file path, xarray DataArray, or numpy array."
)

if threshold is None:
threshold = min_size

if properties is None:
properties = [
"label",
"area",
"area_bbox",
"area_convex",
"area_filled",
"axis_major_length",
"axis_minor_length",
"eccentricity",
"equivalent_diameter_area",
"extent",
"orientation",
"perimeter",
"solidity",
]

label_image = measure.label(array, connectivity=connectivity)
props = measure.regionprops_table(label_image, properties=properties)

df = pd.DataFrame(props)

# Get the labels of regions with area smaller than the threshold
small_regions = df[df["area"] < min_size]["label"].values
# Set the corresponding labels in the label_image to zero
for region_label in small_regions:
label_image[label_image == region_label] = 0

if max_size is not None:
large_regions = df[df["area"] > max_size]["label"].values
for region_label in large_regions:
label_image[label_image == region_label] = 0

# Find the background (holes) which are zeros
holes = label_image == 0

# Label the holes (connected components in the background)
labeled_holes, _ = ndi.label(holes)

# Measure properties of the labeled holes, including area and bounding box
hole_props = measure.regionprops(labeled_holes)

# Loop through each hole and fill it if it is smaller than the threshold
for prop in hole_props:
if prop.area < threshold:
# Get the coordinates of the small hole
coords = prop.coords

# Find the surrounding region's ID (non-zero value near the hole)
surrounding_region_values = []
for coord in coords:
x, y = coord
# Get a 3x3 neighborhood around the hole pixel
neighbors = label_image[max(0, x - 1) : x + 2, max(0, y - 1) : y + 2]
# Exclude the hole pixels (zeros) and get region values
region_values = neighbors[neighbors != 0]
if region_values.size > 0:
surrounding_region_values.append(
region_values[0]
) # Take the first non-zero value

if surrounding_region_values:
# Fill the hole with the mode (most frequent) of the surrounding region values
fill_value = max(
set(surrounding_region_values), key=surrounding_region_values.count
)
label_image[coords[:, 0], coords[:, 1]] = fill_value

label_image, num_labels = measure.label(
label_image, connectivity=connectivity, return_num=True
)
props = measure.regionprops_table(label_image, properties=properties)

df = pd.DataFrame(props)
df["elongation"] = df["axis_major_length"] / df["axis_minor_length"]

dtype = "uint8"
if num_labels > 255 and num_labels <= 65535:
dtype = "uint16"
elif num_labels > 65535:
dtype = "uint32"

if out_csv is not None:
df.to_csv(out_csv, index=False)

if isinstance(image, np.ndarray):
return label_image, df
else:
da.values = label_image
if out_image is not None:
da.rio.to_raster(out_image, dtype=dtype)
if out_vector is not None:
tmp_vector = temp_file_path(".gpkg")
raster_to_vector(out_image, tmp_vector)
gdf = gpd.read_file(tmp_vector)
gdf["label"] = gdf["value"].astype(int)
gdf.drop(columns=["value"], inplace=True)
gdf2 = pd.merge(gdf, df, on="label", how="left")
gdf2.to_file(out_vector)
gdf2.sort_values("label", inplace=True)
df = gdf2
return da, df
54 changes: 54 additions & 0 deletions samgeo/samgeo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,3 +1514,57 @@ def raster_to_vector(self, raster, vector, simplify_tolerance=None, **kwargs):
common.raster_to_vector(
raster, vector, simplify_tolerance=simplify_tolerance, **kwargs
)

def region_groups(
    self,
    image: Union[str, "xr.DataArray", np.ndarray],
    connectivity: int = 1,
    min_size: int = 10,
    max_size: Optional[int] = None,
    threshold: Optional[int] = None,
    properties: Optional[List[str]] = None,
    out_csv: Optional[str] = None,
    out_vector: Optional[str] = None,
    out_image: Optional[str] = None,
    **kwargs: Any,
) -> Union[
    Tuple[np.ndarray, "pd.DataFrame"], Tuple["xr.DataArray", "pd.DataFrame"]
]:
    """
    Label the regions of an image and drop those outside the size limits.

    This method forwards every argument unchanged to ``common.region_groups``
    and returns its result.

    Args:
        image (Union[str, xr.DataArray, np.ndarray]): Image to process, given as
            a file path, an xarray DataArray, or a numpy array.
        connectivity (int, optional): Labeling connectivity; 1 means
            4-connectivity and 2 means 8-connectivity. Defaults to 1.
        min_size (int, optional): Smallest region size to keep. Defaults to 10.
        max_size (Optional[int], optional): Largest region size to keep.
            Defaults to None.
        threshold (Optional[int], optional): Size threshold for filling holes.
            Defaults to None, which is equal to min_size.
        properties (Optional[List[str]], optional): Region properties to measure.
            See https://scikit-image.org/docs/stable/api/skimage.measure.html#skimage.measure.regionprops
            Defaults to None.
        out_csv (Optional[str], optional): CSV path for the measured properties.
            Defaults to None.
        out_vector (Optional[str], optional): Path of the vector file to write.
            Defaults to None.
        out_image (Optional[str], optional): Path of the labeled image to write.
            Defaults to None.
        **kwargs (Any): Extra keyword arguments passed through as-is.

    Returns:
        Union[Tuple[np.ndarray, pd.DataFrame], Tuple[xr.DataArray, pd.DataFrame]]: Labeled image and properties DataFrame.
    """
    # Collect the named options once, then hand everything over in one call.
    options = {
        "connectivity": connectivity,
        "min_size": min_size,
        "max_size": max_size,
        "threshold": threshold,
        "properties": properties,
        "out_csv": out_csv,
        "out_vector": out_vector,
        "out_image": out_image,
    }
    options.update(kwargs)
    return common.region_groups(image, **options)

0 comments on commit f375e92

Please sign in to comment.