diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 1569914562..1cbe848a18 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -40,11 +40,7 @@ class CircusClustering: """ _default_params = { - "hdbscan_kwargs": { - "min_cluster_size": 10, - "allow_single_cluster": True, - "min_samples" : 5 - }, + "hdbscan_kwargs": {"min_cluster_size": 10, "allow_single_cluster": True, "min_samples": 5}, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, @@ -202,7 +198,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): original_labels = peaks["channel_index"] from spikeinterface.sortingcomponents.clustering.split import split_clusters - min_size = 2*params["hdbscan_kwargs"].get("min_cluster_size", 10) + min_size = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 10) peak_labels, _ = split_clusters( original_labels, @@ -216,7 +212,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): waveforms_sparse_mask=sparse_mask, min_size_split=min_size, clusterer_kwargs=d["hdbscan_kwargs"], - n_pca_features=[2, 4, 6, 8, 10] + n_pca_features=[2, 4, 6, 8, 10], ), **params["recursive_kwargs"], **job_kwargs, diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index b880660d20..570a393c23 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -220,7 +220,7 @@ def split( if not isinstance(n_pca_features, np.ndarray): n_pca_features = np.array([n_pca_features]) - + n_pca_features = n_pca_features[n_pca_features <= aligned_wfs.shape[1]] flatten_features = aligned_wfs.reshape(aligned_wfs.shape[0], -1) @@ -230,6 +230,7 @@ def split( if flatten_features.shape[1] > n_pca: from sklearn.decomposition import PCA + # from sklearn.decomposition import TruncatedSVD # tsvd = TruncatedSVD(n_pca_features) tsvd = PCA(n_pca, whiten=True) @@ -258,7 +259,7 @@ def split( else: raise ValueError(f"wrong clusterer {clusterer}") - #DEBUG = True + # DEBUG = True DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -279,14 +280,15 @@ def split( ax = axs[1] ax.plot(flatten_wfs[mask][sl].T, color=colors[k], alpha=0.5) - ax.set_xlabel('PCA features') + ax.set_xlabel("PCA features") axs[0].set_title(f"{clusterer} {is_split} {peak_indices[0]} {n_pca}, recursion_level={recursion_level}") import time - plt.savefig(f'split_{recursion_level}_{time.time()}.png') + + plt.savefig(f"split_{recursion_level}_{time.time()}.png") plt.close() - #plt.show() - + # plt.show() + if is_split: break