Skip to content

Commit

Permalink
feat: add support for --save-roi --roi-size (#18)
Browse files Browse the repository at this point in the history
Added `--save-roi` and `--roi-size` options to `sdcat detect`. This saves the crops in a location compatible with the clustering stage, but the crops can also be used outside of sdcat. Crops are saved under `det_filtered/crops`:

     ├── det_filtered                    # The filtered detections from the model
            ├── crops                       # Crops of the detections
  • Loading branch information
danellecline authored Feb 20, 2025
1 parent bde5bf2 commit 9a801ac
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 10 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ The sdcat toolkit generates data in the following folders. Here, we assume both
│   ├── DSC01861.csv
│   └── DSC01922.csv
├── det_filtered # The filtered detections from the model
├── det_filtered_clustered # Clustered detections from the model
├── crops # Crops of the detections
├── dino_vits8...date # The clustering results - one folder per each run of the clustering algorithm
├── dino_vits8..exemplars.csv # Exemplar embeddings - examples with the highest cosine similarity within a cluster
Expand Down
32 changes: 29 additions & 3 deletions sdcat/detect/commands.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import hashlib
import multiprocessing
import os
import shutil
import uuid
from pathlib import Path

import click
Expand All @@ -11,6 +13,7 @@
from sahi.postprocess.combine import nms

from sdcat import common_args
from sdcat.cluster.utils import crop_square_image
from sdcat.config import config as cfg
from sdcat.config.config import default_config_ini
from sdcat.detect.model_util import create_model
Expand All @@ -31,6 +34,8 @@
@click.option('--show', is_flag=True, help='Show algorithm steps.')
@click.option('--image-dir', required=True, help='Directory with images to run sliced detection.')
@click.option('--save-dir', required=True, help='Save detections to this directory.')
@click.option('--save-roi', is_flag=True, help='Save each region of interest/detection.')
@click.option('--roi-size', type=int, default=224, help='Rescale the region of interest.')
@click.option('--device', default='cpu', help='Device to use, e.g. cpu or cuda:0')
@click.option('--spec-remove', is_flag=True, help='Run specularity removal algorithm on the images before processing. '
'**CAUTION**this is slow. Set --scale-percent to < 100 to speed-up')
Expand All @@ -47,7 +52,7 @@
@click.option('--overlap-height-ratio',type=float, default=0.4, help='Overlap height ratio for NMS')
@click.option('--clahe', is_flag=True, help='Run the CLAHE algorithm to contrast enhance before detection useful images with non-uniform lighting')

def run_detect(show: bool, image_dir: str, save_dir: str, model: str, model_type:str,
def run_detect(show: bool, image_dir: str, save_dir: str, save_roi:bool, roi_size: int, model: str, model_type:str,
slice_size_width: int, slice_size_height: int, scale_percent: int,
postprocess_match_metric: str, overlap_width_ratio: float, overlap_height_ratio: float,
device: str, conf: float, skip_sahi: bool, skip_saliency: bool, spec_remove: bool,
Expand Down Expand Up @@ -95,8 +100,13 @@ def run_detect(show: bool, image_dir: str, save_dir: str, model: str, model_type

save_path_det_raw = save_path_base / 'det_raw' / 'csv'
save_path_det_filtered = save_path_base / 'det_filtered' / 'csv'
save_path_det_roi = save_path_base / 'det_filtered' / 'crops'
save_path_viz = save_path_base / 'vizresults'

if save_roi:
save_path_det_roi.mkdir(parents=True, exist_ok=True)
for f in save_path_det_roi.rglob('*'):
os.remove(f)
save_path_det_raw.mkdir(parents=True, exist_ok=True)
save_path_det_filtered.mkdir(parents=True, exist_ok=True)
save_path_viz.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -280,12 +290,28 @@ def run_detect(show: bool, image_dir: str, save_dir: str, model: str, model_type
df_final['w'] = (df_final['xx'] - df_final['x'])
df_final['h'] = (df_final['xy'] - df_final['y'])

# Save DataFrame to CSV file including image_width and image_height columns
df_final.to_csv(pred_out_csv.as_posix(), index=False, header=True)
if save_roi:
# Add in a column for the unique crop name for each detection with a unique id
# create a deterministic UUID (uuid5, i.e. SHA-1 based) from the box coordinates in the row
df_final['crop_path'] = df_final.apply(lambda
row: f"{save_path_det_roi}/{uuid.uuid5(uuid.NAMESPACE_DNS, str(row['x']) + str(row['y']) + str(row['xx']) + str(row['xy']))}.png",
axis=1)

num_processes = min(multiprocessing.cpu_count(), len(df_final))
# Crop and square the images in parallel using multiprocessing to speed up the processing
info(f'Cropping {len(df_final)} detections in parallel using {num_processes} processes...')
with multiprocessing.Pool(num_processes) as pool:
args = [(row, roi_size) for index, row in df_final.iterrows()]
pool.starmap(crop_square_image, args)

info(f'Found {len(pred_list)} total localizations in {f} with {len(df_combined)} after NMS')
info(f'Slice width: {slice_size_width} height: {slice_size_height}')

# Save DataFrame to CSV file including image_width and image_height columns
info(f'Detections saved to {pred_out_csv}')
df_final.to_csv(pred_out_csv.as_posix(), index=False, header=True)
if save_roi: info(f"ROI crops saved in {save_path_det_roi}")

save_stats = save_path_base / 'stats.txt'
with open(save_stats, 'w') as sf:
sf.write(f"Statistics for {f}:\n")
Expand Down
9 changes: 5 additions & 4 deletions sdcat/detect/model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ def create_model(model:str, conf:float, device:str, model_type=None):
'model_type': 'yolov5',
'model_path': lambda: hf_hub_download("MBARI-org/yolov5-uav-18k", "yolov5x6-uav-18k.pt")
},
'MBARI-org/yolov5x6-uavs-oneclass': {
'model_type': 'yolov5',
'model_path': lambda: hf_hub_download("MBARI-org/yolov5x6-uavs-oneclass", "best_uavs_oneclass.pt")
'MBARI-org/yolo11x-uavs-detect': {
'model_type': 'yolo11',
'model_path': lambda: hf_hub_download("MBARI-org/yolo11x-uavs-detect", "uavs-oneclass-best.pt")
},
'FathomNet/MBARI-315k-yolov5': {
'model_type': 'yolov5',
Expand All @@ -90,7 +90,8 @@ def create_model(model:str, conf:float, device:str, model_type=None):
}

if model not in model_map:
raise ValueError(f"Unknown model: {model}")
raise ValueError(f"Unknown model: {model}. Available models: {list(model_map.keys())}, "
f"or provide a local file path. You can also use the --model-type option to specify the model type.")

model_info = model_map[model]
model_type = model_info['model_type']
Expand Down
4 changes: 2 additions & 2 deletions tests/test_detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def run_detect(data_dir: Path, scale: int) -> int:
data_dir.as_posix()], stdout=subprocess.PIPE)

# Wait for the process to finish
proc.wait()
proc.wait(5000)

# Verify that the process finished successfully
assert proc.returncode == 0
Expand Down Expand Up @@ -79,4 +79,4 @@ def test_plankton():
test_bird()
test_pinniped()
test_plankton()
print('All tests passed')
print('All tests passed')

0 comments on commit 9a801ac

Please sign in to comment.