
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jan 9, 2025
1 parent 7cdb239 commit efbb9a7
Showing 2 changed files with 11 additions and 13 deletions.
src/spikeinterface/sortingcomponents/clustering/circus.py (10 changes: 3 additions & 7 deletions)
@@ -40,11 +40,7 @@ class CircusClustering:
     """

     _default_params = {
-        "hdbscan_kwargs": {
-            "min_cluster_size": 10,
-            "allow_single_cluster": True,
-            "min_samples" : 5
-        },
+        "hdbscan_kwargs": {"min_cluster_size": 10, "allow_single_cluster": True, "min_samples": 5},
         "cleaning_kwargs": {},
         "waveforms": {"ms_before": 2, "ms_after": 2},
         "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
@@ -202,7 +198,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         original_labels = peaks["channel_index"]
         from spikeinterface.sortingcomponents.clustering.split import split_clusters

-        min_size = 2*params["hdbscan_kwargs"].get("min_cluster_size", 10)
+        min_size = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 10)

         peak_labels, _ = split_clusters(
             original_labels,
@@ -216,7 +212,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
                 waveforms_sparse_mask=sparse_mask,
                 min_size_split=min_size,
                 clusterer_kwargs=d["hdbscan_kwargs"],
-                n_pca_features=[2, 4, 6, 8, 10]
+                n_pca_features=[2, 4, 6, 8, 10],
             ),
             **params["recursive_kwargs"],
             **job_kwargs,
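The collapsed hdbscan_kwargs line above is a formatting-only change; the keys it carries (min_cluster_size, allow_single_cluster, min_samples) are standard parameters of the hdbscan library. As a rough, hypothetical illustration of how defaults like these reach a clusterer (this is not the actual CircusClustering code path, and the feature matrix is synthetic):

# Illustrative sketch only: shows how hdbscan_kwargs-style defaults map onto
# hdbscan.HDBSCAN; in spikeinterface the clustering happens inside split_clusters().
import numpy as np
import hdbscan

hdbscan_kwargs = {"min_cluster_size": 10, "allow_single_cluster": True, "min_samples": 5}

features = np.random.randn(500, 6)  # hypothetical per-peak feature matrix
labels = hdbscan.HDBSCAN(**hdbscan_kwargs).fit_predict(features)  # label -1 marks noise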
src/spikeinterface/sortingcomponents/clustering/split.py (14 changes: 8 additions & 6 deletions)
@@ -220,7 +220,7 @@ def split(

         if not isinstance(n_pca_features, np.ndarray):
             n_pca_features = np.array([n_pca_features])

         n_pca_features = n_pca_features[n_pca_features <= aligned_wfs.shape[1]]
         flatten_features = aligned_wfs.reshape(aligned_wfs.shape[0], -1)

@@ -230,6 +230,7 @@

             if flatten_features.shape[1] > n_pca:
                 from sklearn.decomposition import PCA
+
                 # from sklearn.decomposition import TruncatedSVD
                 # tsvd = TruncatedSVD(n_pca_features)
                 tsvd = PCA(n_pca, whiten=True)
@@ -258,7 +259,7 @@ def split(
             else:
                 raise ValueError(f"wrong clusterer {clusterer}")

-            #DEBUG = True
+            # DEBUG = True
             DEBUG = False
             if DEBUG:
                 import matplotlib.pyplot as plt
@@ -279,14 +280,15 @@

                     ax = axs[1]
                     ax.plot(flatten_wfs[mask][sl].T, color=colors[k], alpha=0.5)
-                    ax.set_xlabel('PCA features')
+                    ax.set_xlabel("PCA features")

                 axs[0].set_title(f"{clusterer} {is_split} {peak_indices[0]} {n_pca}, recursion_level={recursion_level}")
                 import time
-                plt.savefig(f'split_{recursion_level}_{time.time()}.png')
+
+                plt.savefig(f"split_{recursion_level}_{time.time()}.png")
                 plt.close()
-                #plt.show()
+                # plt.show()

             if is_split:
                 break

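The split.py hunks above only touch formatting, but they sketch the underlying loop: candidate PCA dimensionalities (n_pca_features) are capped at the second dimension of the aligned waveforms, each candidate is used to whiten-project the flattened waveforms, and the projection is clustered until a split is found. A minimal standalone sketch of that pattern, assuming hdbscan as the clusterer and made-up array shapes (this is not the real split() implementation):

# Simplified sketch of the n_pca_features sweep seen in split.py; synthetic data,
# hdbscan assumed as clusterer, no alignment or sparsity handling.
import numpy as np
import hdbscan
from sklearn.decomposition import PCA

aligned_wfs = np.random.randn(300, 40, 8)  # hypothetical (num_spikes, num_samples, num_channels)
n_pca_features = np.array([2, 4, 6, 8, 10])
n_pca_features = n_pca_features[n_pca_features <= aligned_wfs.shape[1]]

flatten_features = aligned_wfs.reshape(aligned_wfs.shape[0], -1)

is_split = False
for n_pca in n_pca_features:
    if flatten_features.shape[1] > n_pca:
        final_features = PCA(n_pca, whiten=True).fit_transform(flatten_features)
    else:
        final_features = flatten_features

    labels = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=5).fit_predict(final_features)
    is_split = np.setdiff1d(labels, [-1]).size > 1  # split if more than one non-noise cluster
    if is_split:
        break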
