Skip to content

Commit a9476c2

Browse files
authored
Merge pull request #7 from mbari-org/roicluster
RoiCluster
2 parents 76ec895 + bd1696c commit a9476c2

File tree

4 files changed

+79
-16
lines changed

4 files changed

+79
-16
lines changed

sdcat/__main__.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import click
88
from sdcat.logger import err, info, create_logger_file
99
from sdcat import __version__
10-
from sdcat.cluster.commands import run_cluster
10+
from sdcat.cluster.commands import run_cluster_det, run_cluster_roi
1111
from sdcat.detect.commands import run_detect
1212

1313

@@ -28,7 +28,19 @@ def cli():
2828
pass
2929

3030
cli.add_command(run_detect)
31-
cli.add_command(run_cluster)
31+
32+
33+
@cli.group(name="cluster")
def cli_cluster():
    """
    Commands related to clustering detections or ROI images
    """
    # The docstring above is what Click displays as the help text for
    # the `cluster` group, so it must describe clustering.
    pass


# NOTE: the @cli.group decorator above already registers `cli_cluster` on
# `cli`, so a separate cli.add_command(cli_cluster) call is not needed.
# Sub-commands exposed as `cluster detections` and `cluster roi`.
cli_cluster.add_command(run_cluster_det)
cli_cluster.add_command(run_cluster_roi)
3244

3345

3446
if __name__ == '__main__':

sdcat/cluster/cluster.py

+24-10
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from sklearn.metrics.pairwise import cosine_similarity
1616
from sklearn.preprocessing import MinMaxScaler
1717
from sdcat.logger import info, warn, debug, err
18-
from sdcat.cluster.utils import cluster_grid, crop_square_image
18+
from sdcat.cluster.utils import cluster_grid, crop_square_image, square_image
1919
from sdcat.cluster.embedding import fetch_embedding, has_cached_embedding, compute_norm_embedding
2020

2121
if find_spec("multicore_tsne"):
@@ -227,12 +227,14 @@ def cluster_vits(
227227
cluster_selection_epsilon: float,
228228
min_similarity: float,
229229
min_cluster_size: int,
230-
min_samples: int):
230+
min_samples: int,
231+
roi: bool = False) -> pd.DataFrame:
231232
""" Cluster the crops using the VITS embeddings.
232233
:param prefix: A unique prefix to save artifacts from clustering
233234
:param model: The model to use for clustering
234235
:param df_dets: The dataframe with the detections
235236
:param output_path: The output path to save the clustering artifacts to
237+
:param roi: Whether the detections are already cropped to the ROI
236238
:param cluster_selection_epsilon: The epsilon parameter for HDBSCAN
237239
:param alpha: The alpha parameter for HDBSCAN
238240
:param min_similarity: The minimum similarity score to use for -1 cluster reassignment
@@ -250,12 +252,18 @@ def cluster_vits(
250252

251253
# Skip cropping if all the crops are already done
252254
if num_crop != len(df_dets):
253-
# Crop and squaring the images in parallel using multiprocessing to speed up the processing
254-
info(f'Cropping {len(df_dets)} detections in parallel using {multiprocessing.cpu_count()} processes...')
255255
num_processes = min(multiprocessing.cpu_count(), len(df_dets))
256-
with multiprocessing.Pool(num_processes) as pool:
257-
args = [(row, 224) for index, row in df_dets.iterrows()]
258-
pool.starmap(crop_square_image, args)
256+
if roi == True:
257+
info('ROI crops already exist. Creating square crops in parallel using {multiprocessing.cpu_count()} processes...')
258+
with multiprocessing.Pool(num_processes) as pool:
259+
args = [(row, 224) for index, row in df_dets.iterrows()]
260+
pool.starmap(square_image, args)
261+
else:
262+
# Crop and squaring the images in parallel using multiprocessing to speed up the processing
263+
info(f'Cropping {len(df_dets)} detections in parallel using {multiprocessing.cpu_count()} processes...')
264+
with multiprocessing.Pool(num_processes) as pool:
265+
args = [(row, 224) for index, row in df_dets.iterrows()]
266+
pool.starmap(crop_square_image, args)
259267

260268
# Drop any rows with crop_path that have files that don't exist - sometimes the crops fail
261269
df_dets = df_dets[df_dets['crop_path'].apply(lambda x: os.path.exists(x))]
@@ -279,9 +287,15 @@ def cluster_vits(
279287
(output_path / prefix).mkdir(parents=True)
280288

281289
# Remove everything except ancillary data to include in clustering
282-
ancillary_df = df_dets.drop(
283-
columns=['x', 'y', 'xx', 'xy', 'w', 'h', 'image_width', 'image_height', 'cluster_id', 'cluster', 'score',
284-
'class', 'image_path', 'crop_path'])
290+
columns = ['x', 'y', 'xx', 'xy', 'w', 'h', 'image_width', 'image_height', 'cluster_id', 'cluster', 'score',
291+
'class', 'image_path', 'crop_path']
292+
# Check if the columns exist in the dataframe
293+
if all(col in df_dets.columns for col in columns):
294+
ancillary_df = df_dets.drop(
295+
columns=['x', 'y', 'xx', 'xy', 'w', 'h', 'image_width', 'image_height', 'cluster_id', 'cluster', 'score',
296+
'class', 'image_path', 'crop_path'])
297+
else:
298+
ancillary_df = df_dets
285299

286300
# Cluster the images
287301
cluster_sim, unique_clusters, cluster_means, coverage = _run_hdbscan_assign(prefix,

sdcat/cluster/commands.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414
import pandas as pd
1515
import pytz
1616
import torch
17+
from PIL import Image
1718

1819
from sdcat import common_args
1920
from sdcat.config import config as cfg
2021
from sdcat.logger import info, err, warn
2122
from sdcat.cluster.cluster import cluster_vits
2223

2324

24-
@click.command('cluster', help='Cluster detections. See cluster --config-ini to override cluster defaults.')
25+
@click.command('detections', help='Cluster detections. See cluster --config-ini to override cluster defaults.')
2526
@common_args.config_ini
2627
@common_args.start_image
2728
@common_args.end_image
@@ -31,7 +32,7 @@
3132
@click.option('--alpha', help='Alpha is a parameter that controls the linkage. See https://hdbscan.readthedocs.io/en/latest/parameter_selection.html. Default is 0.92. Increase for less conservative clustering, e.g. 1.0', type=float)
3233
@click.option('--cluster-selection-epsilon', help='Epsilon is a parameter that controls the linkage. Default is 0. Increase for less conservative clustering', type=float)
3334
@click.option('--min-cluster-size', help='The minimum number of samples in a group for that group to be considered a cluster. Default is 2. Increase for less conservative clustering, e.g. 5, 15', type=int)
34-
def run_cluster(det_dir, save_dir, device, config_ini, alpha, cluster_selection_epsilon, min_cluster_size, start_image, end_image):
35+
def run_cluster_det(det_dir, save_dir, device, config_ini, alpha, cluster_selection_epsilon, min_cluster_size, start_image, end_image):
3536
config = cfg.Config(config_ini)
3637
max_area = int(config('cluster', 'max_area'))
3738
min_area = int(config('cluster', 'min_area'))
@@ -259,7 +260,7 @@ def is_day(utc_dt):
259260
shutil.copy(Path(config_ini), save_dir / f'{prefix}_config.ini')
260261
else:
261262
warn(f'No detections found to cluster')
262-
263+
263264
@click.command('roi', help='Cluster roi. See cluster --config-ini to override cluster defaults.')
264265
@common_args.config_ini
265266
@click.option('--roi-dir', help='Input folder(s) with raw ROI images', multiple=True)

sdcat/cluster/utils.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,45 @@ def gen_grid(with_attention: bool):
9292
# gen_grid(with_attention=True)
9393

9494

95+
def square_image(row, square_dim: int):
    """
    Pad an image to a square with black bars, resize it to
    square_dim x square_dim and save it to the crop path.

    Nothing is written if the source image is missing (a warning is logged)
    or if the crop file already exists.

    :param row: record exposing the attributes image_path, crop_path, image_width and image_height
                (presumably a dataframe row - confirm against the caller)
    :param square_dim: dimension of the square image
    :return: None
    :raises Exception: any error raised while reading, padding or saving the image is logged and re-raised
    """
    try:
        if not Path(row.image_path).exists():
            warn(f'Skipping {row.crop_path} because the image {row.image_path} does not exist')
            return

        if Path(row.crop_path).exists():  # If the crop already exists, skip it
            return

        # Determine the size of the new square; cast to plain ints so that
        # non-native integer types in the row are accepted as image sizes
        width = int(row.image_width)
        height = int(row.image_height)
        max_side = max(width, height)

        # Create a new square image with a black background
        padded = Image.new('RGB', (max_side, max_side), (0, 0, 0))

        # Paste the original image onto the center of the new image; the
        # context manager closes the source file handle once pasted
        with Image.open(row.image_path) as img:
            padded.paste(img, ((max_side - width) // 2, (max_side - height) // 2))

        # Resize the padded square (not the original image, which would
        # distort the aspect ratio) to square_dim x square_dim
        resized = padded.resize((square_dim, square_dim), Image.LANCZOS)

        # Save the image
        resized.save(row.crop_path)
    except Exception as e:
        exception(f'Error cropping {row.image_path} {e}')
        # Bare raise preserves the original traceback
        raise
130+
95131
def crop_square_image(row, square_dim: int):
96132
"""
97-
Crop the image to a square padding the shorted dimension, then resize it to square_dim x square_dim
133+
Crop the image to a square padding the shortest dimension, then resize it to square_dim x square_dim
98134
This also adjusts the crop to make sure the crop is fully in the frame, otherwise the crop that
99135
exceeds the frame is filled with black bars - these produce clusters of "edge" objects instead
100136
of the detection

0 commit comments

Comments
 (0)