Skip to content

Commit

Permalink
feat: added square black pad resize for roi
Browse files Browse the repository at this point in the history
  • Loading branch information
danellecline committed May 16, 2024
1 parent 6bb998a commit 473aa34
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 14 deletions.
22 changes: 15 additions & 7 deletions sdcat/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from hdbscan import HDBSCAN
from sklearn.metrics.pairwise import cosine_similarity
from sdcat.logger import info, warn, debug, err
from sdcat.cluster.utils import cluster_grid, crop_square_image
from sdcat.cluster.utils import cluster_grid, crop_square_image, square_image
from sdcat.cluster.embedding import fetch_embedding, has_cached_embedding, compute_norm_embedding

if find_spec("cuml"):
Expand Down Expand Up @@ -222,12 +222,14 @@ def cluster_vits(
cluster_selection_epsilon: float,
min_similarity: float,
min_cluster_size: int,
min_samples: int):
min_samples: int,
roi: bool = False) -> pd.DataFrame:
""" Cluster the crops using the VITS embeddings.
:param prefix: A unique prefix to save artifacts from clustering
:param model: The model to use for clustering
:param df_dets: The dataframe with the detections
:param output_path: The output path to save the clustering artifacts to
:param roi: Whether the detections are already cropped to the ROI
:param cluster_selection_epsilon: The epsilon parameter for HDBSCAN
:param alpha: The alpha parameter for HDBSCAN
:param min_similarity: The minimum similarity score to use for -1 cluster reassignment
Expand All @@ -245,12 +247,18 @@ def cluster_vits(

# Skip cropping if all the crops are already done
if num_crop != len(df_dets):
# Crop and squaring the images in parallel using multiprocessing to speed up the processing
info(f'Cropping {len(df_dets)} detections in parallel using {multiprocessing.cpu_count()} processes...')
num_processes = min(multiprocessing.cpu_count(), len(df_dets))
with multiprocessing.Pool(num_processes) as pool:
args = [(row, 224) for index, row in df_dets.iterrows()]
pool.starmap(crop_square_image, args)
if roi == True:
info('ROI crops already exist. Creating square crops in parallel using {multiprocessing.cpu_count()} processes...')
with multiprocessing.Pool(num_processes) as pool:
args = [(row, 224) for index, row in df_dets.iterrows()]
pool.starmap(crop_square_image, args)
else:
# Crop and squaring the images in parallel using multiprocessing to speed up the processing
info(f'Cropping {len(df_dets)} detections in parallel using {multiprocessing.cpu_count()} processes...')
with multiprocessing.Pool(num_processes) as pool:
args = [(row, 224) for index, row in df_dets.iterrows()]
pool.starmap(crop_square_image, args)

# Drop any rows with crop_path that have files that don't exist - sometimes the crops fail
df_dets = df_dets[df_dets['crop_path'].apply(lambda x: os.path.exists(x))]
Expand Down
11 changes: 5 additions & 6 deletions sdcat/cluster/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,12 +293,10 @@ def run_cluster_roi(roi_dir, save_dir, device, config_ini, alpha, cluster_select
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)


# Grab all images from the input directories
supported_extensions = ['.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG']
images = []

detections = []
roi_path = Path(roi_dir)
for ext in supported_extensions:
images.extend(list(roi_path.rglob(f'*{ext}')))
Expand All @@ -316,9 +314,10 @@ def run_cluster_roi(roi_dir, save_dir, device, config_ini, alpha, cluster_select
# Sort the dataframe by image_path to make sure the images are in order for start_image and end_image filtering
df = df.sort_values(by='image_path')

# Add in a column for the unique crop name for each detection with a unique id
# create a unique uuid based on the md5 hash of the box in the row
df['crop_path'] = df['image_path']
# Create a unique crop name for each detection with a unique id
crop_path = save_dir / 'crops'
crop_path.mkdir(parents=True, exist_ok=True)
df['crop_path'] = df.apply(lambda row: f"{crop_path}/{uuid.uuid5(uuid.NAMESPACE_DNS, row['image_path'])}.png", axis=1)

# Add in a column for the unique crop name for each detection with a unique id
df['cluster_id'] = -1 # -1 is the default value and means that the image is not in a cluster
Expand All @@ -342,7 +341,7 @@ def run_cluster_roi(roi_dir, save_dir, device, config_ini, alpha, cluster_select

# Cluster the detections
df_cluster = cluster_vits(prefix, model, df, save_dir, alpha, cluster_selection_epsilon, min_similarity,
min_cluster_size, min_samples)
min_cluster_size, min_samples, roi=True)

# Merge the results with the original DataFrame
df.update(df_cluster)
Expand Down
38 changes: 37 additions & 1 deletion sdcat/cluster/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,45 @@ def gen_grid(with_attention: bool):
# gen_grid(with_attention=True)


def square_image(row, square_dim: int):
    """
    Pad an image to a square with black bars, then resize it to square_dim x square_dim.

    The source image is pasted onto the center of a black canvas whose side is the
    longest dimension of the image, so the aspect ratio of the content is preserved
    when the canvas is resized. The padded canvas (not the raw source image) is what
    gets resized and saved.

    :param row: detection row exposing `image_path` (source image) and `crop_path` (output file)
    :param square_dim: side length in pixels of the saved square image
    :return: None; nothing is written if the source image is missing or the crop already exists
    """
    try:
        if not Path(row.image_path).exists():
            warn(f'Skipping {row.crop_path} because the image {row.image_path} does not exist')
            return

        if Path(row.crop_path).exists():  # If the crop already exists, skip it
            return

        # The context manager guarantees the source file handle is released even on failure
        with Image.open(row.image_path) as img:
            # Use the dimensions of the image actually on disk so the paste offsets
            # are correct even if the row metadata is stale
            width, height = img.size
            max_side = max(width, height)

            # Create a new square image with a black background
            padded = Image.new('RGB', (max_side, max_side), (0, 0, 0))

            # Paste the original image onto the center of the black square
            padded.paste(img, ((max_side - width) // 2, (max_side - height) // 2))

        # Resize the padded square (not the unpadded source) to square_dim x square_dim
        resized = padded.resize((square_dim, square_dim), Image.LANCZOS)

        # Save the image
        resized.save(row.crop_path)
        resized.close()
        padded.close()
    except Exception as e:
        exception(f'Error cropping {row.image_path} {e}')
        raise

def crop_square_image(row, square_dim: int):
"""
Crop the image to a square padding the shorted dimension, then resize it to square_dim x square_dim
Crop the image to a square padding the shortest dimension, then resize it to square_dim x square_dim
This also adjusts the crop to make sure the crop is fully in the frame, otherwise the crop that
exceeds the frame is filled with black bars - these produce clusters of "edge" objects instead
of the detection
Expand Down

0 comments on commit 473aa34

Please sign in to comment.