diff --git a/sdcat/cluster/cluster.py b/sdcat/cluster/cluster.py index a7ef7f8..4590595 100644 --- a/sdcat/cluster/cluster.py +++ b/sdcat/cluster/cluster.py @@ -44,6 +44,7 @@ def _run_hdbscan_assign( min_similarity: float, min_cluster_size: int, min_samples: int, + use_tsne: bool, ancillary_df: pd.DataFrame, out_path: Path) -> tuple: """ @@ -55,6 +56,7 @@ def _run_hdbscan_assign( :param min_similarity: The minimum similarity score to use for clustering reassignment :param min_cluster_size: The minimum number of samples in a cluster :param min_samples: The number of samples in a neighborhood for a point + :param use_tsne: Whether to use t-SNE for dimensionality reduction :param ancillary_df: (optional) Ancillary data to include in the clustering :param out_path: The output path to save the clustering artifacts to :return: The average similarity score for each cluster, exemplar_df, cluster ids, cluster means, and coverage @@ -84,7 +86,7 @@ def _run_hdbscan_assign( perplexity = min(30, num_samples - 1) # TSN-E does not work well when we have a few samples - if num_samples > 100: + if num_samples > 100 and use_tsne: tsne = TSNE(n_components=2, perplexity=perplexity, metric="cosine", n_jobs=8, random_state=42, verbose=True) embedding = tsne.fit_transform(df.values) else: @@ -242,6 +244,7 @@ def cluster_vits( min_similarity: float, min_cluster_size: int, min_samples: int, + use_tsne: bool = False, roi: bool = False) -> pd.DataFrame: """ Cluster the crops using the VITS embeddings. :param prefix: A unique prefix to save artifacts from clustering @@ -319,6 +322,7 @@ def cluster_vits( min_similarity, min_cluster_size, min_samples, + use_tsne, ancillary_df, output_path / prefix) diff --git a/sdcat/cluster/commands.py b/sdcat/cluster/commands.py index b80cca1..c8151aa 100644 --- a/sdcat/cluster/commands.py +++ b/sdcat/cluster/commands.py @@ -26,13 +26,14 @@ @common_args.config_ini @common_args.start_image @common_args.end_image -@click.option('--det-dir', help='Input folder(s) with raw detection results', multiple=True) -@click.option('--save-dir', help='Output directory to save clustered detection results') +@common_args.use_tsne +@common_args.alpha +@common_args.cluster_selection_epsilon +@common_args.min_cluster_size +@click.option('--det-dir', help='Input folder(s) with raw detection results', multiple=True, required=True) +@click.option('--save-dir', help='Output directory to save clustered detection results', required=True) @click.option('--device', help='Device to use, e.g. cpu or cuda:0', type=str) -@click.option('--alpha', help='Alpha is a parameter that controls the linkage. See https://hdbscan.readthedocs.io/en/latest/parameter_selection.html. Default is 0.92. Increase for less conservative clustering, e.g. 1.0', type=float) -@click.option('--cluster-selection-epsilon', help='Epsilon is a parameter that controls the linkage. Default is 0. Increase for less conservative clustering', type=float) -@click.option('--min-cluster-size', help='The minimum number of samples in a group for that group to be considered a cluster. Default is 2. Increase for less conservative clustering, e.g. 5, 15', type=int) -def run_cluster_det(det_dir, save_dir, device, config_ini, alpha, cluster_selection_epsilon, min_cluster_size, start_image, end_image): +def run_cluster_det(det_dir, save_dir, device, config_ini, alpha, cluster_selection_epsilon, min_cluster_size, start_image, end_image, use_tsne): config = cfg.Config(config_ini) max_area = int(config('cluster', 'max_area')) min_area = int(config('cluster', 'min_area')) @@ -250,7 +251,7 @@ def is_day(utc_dt): # Cluster the detections df_cluster = cluster_vits(prefix, model, df, save_dir, alpha, cluster_selection_epsilon, min_similarity, - min_cluster_size, min_samples) + min_cluster_size, min_samples, use_tsne) # Merge the results with the original DataFrame df.update(df_cluster) @@ -263,13 +264,14 @@ def is_day(utc_dt): @click.command('roi', help='Cluster roi. See cluster --config-ini to override cluster defaults.') @common_args.config_ini -@click.option('--roi-dir', help='Input folder(s) with raw ROI images', multiple=True) -@click.option('--save-dir', help='Output directory to save clustered detection results') +@common_args.use_tsne +@common_args.alpha +@common_args.cluster_selection_epsilon +@common_args.min_cluster_size +@click.option('--roi-dir', help='Input folder(s) with raw ROI images', multiple=True, required=True) +@click.option('--save-dir', help='Output directory to save clustered detection results', required=True) @click.option('--device', help='Device to use, e.g. cpu or cuda:0', type=str) -@click.option('--alpha', help='Alpha is a parameter that controls the linkage. See https://hdbscan.readthedocs.io/en/latest/parameter_selection.html. Default is 0.92. Increase for less conservative clustering, e.g. 1.0', type=float) -@click.option('--cluster-selection-epsilon', help='Epsilon is a parameter that controls the linkage. Default is 0. Increase for less conservative clustering', type=float) -@click.option('--min-cluster-size', help='The minimum number of samples in a group for that group to be considered a cluster. Default is 2. Increase for less conservative clustering, e.g. 5, 15', type=int) -def run_cluster_roi(roi_dir, save_dir, device, config_ini, alpha, cluster_selection_epsilon, min_cluster_size): +def run_cluster_roi(roi_dir, save_dir, device, config_ini, alpha, cluster_selection_epsilon, min_cluster_size, use_tsne): config = cfg.Config(config_ini) min_samples = int(config('cluster', 'min_samples')) alpha = alpha if alpha else float(config('cluster', 'alpha')) @@ -346,7 +348,7 @@ def run_cluster_roi(roi_dir, save_dir, device, config_ini, alpha, cluster_select # Cluster the detections df_cluster = cluster_vits(prefix, model, df, save_dir, alpha, cluster_selection_epsilon, min_similarity, - min_cluster_size, min_samples, roi=True) + min_cluster_size, min_samples, use_tsne, roi=True) # Merge the results with the original DataFrame df.update(df_cluster) diff --git a/sdcat/common_args.py b/sdcat/common_args.py index 05b3f05..6090c33 100644 --- a/sdcat/common_args.py +++ b/sdcat/common_args.py @@ -1,4 +1,4 @@ -# sightwire, Apache-2.0 license +# sdcat, Apache-2.0 license # Filename: common_args.py # Description: Common arguments for processing commands @@ -18,4 +18,23 @@ end_image = click.option('--end-image', type=str, - help='End image name') \ No newline at end of file + help='End image name') + +alpha = click.option('--alpha', + type=float, + help='Alpha is a parameter that controls the linkage. See https://hdbscan.readthedocs.io/en/latest/parameter_selection.html. ' + 'Default is 0.92. Increase for less conservative clustering, e.g. 1.0') + +cluster_selection_epsilon = click.option('--cluster-selection-epsilon', + type=float, + help='Epsilon is a parameter that controls the linkage. ' + 'Default is 0. Increase for less conservative clustering') + +min_cluster_size = click.option('--min-cluster-size', + type=int, + help='The minimum number of samples in a group for that group to be considered a cluster. ' + 'Default is 2. Increase for less conservative clustering, e.g. 5, 15') + +use_tsne = click.option('--use-tsne', + is_flag=True, + help='Use t-SNE for dimensionality reduction. Default is False') \ No newline at end of file