diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 83705b94db..f2597d162f 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -357,8 +357,8 @@ def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None): for unit_ind in range(n): mask = spikes["unit_index"] == unit_ind - valid = pair_mask[unit_ind, unit_ind+1:] - valid_indices = np.arange(unit_ind+1, n)[valid] + valid = pair_mask[unit_ind, unit_ind + 1 :] + valid_indices = np.arange(unit_ind + 1, n)[valid] if len(valid_indices) > 0: ind = kdtree.kneighbors(data[mask], return_distance=False) ind = ind.flatten() @@ -366,7 +366,7 @@ def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None): ind = ind[mask_2] chan_inds, all_counts = np.unique(spikes["unit_index"][ind], return_counts=True) all_counts = all_counts.astype(float) - #all_counts /= all_spike_counts[chan_inds] + # all_counts /= all_spike_counts[chan_inds] best_indices = np.argsort(all_counts)[::-1][0:] pair_mask[unit_ind] &= np.isin(np.arange(n), chan_inds[best_indices]) return pair_mask