Proposal to implement a sorting_analyzer.merge_units() syntax (#3043)
Implement merging at SortingAnalyzer level and update curation
yger authored Jul 15, 2024
1 parent a14ee81 commit fe74d45
Showing 31 changed files with 1,425 additions and 225 deletions.
Binary file added doc/images/spikeinterface_gui.png
14 changes: 14 additions & 0 deletions doc/modules/core.rst
@@ -377,6 +377,20 @@ backends without writing to disk. So, you can compute an extension *in-memory* w
you have decided on your desired parameters you can either use :code:`compute` with :code:`save=True` or use :code:`save_as`
to write everything out to disk.


Finally, the :code:`SortingAnalyzer` object can be used directly to curate a spike sorting output by selecting/removing units
and merging unit groups.

.. code-block:: python

    sorting_analyzer_select = sorting_analyzer.select_units(unit_ids=[0, 1, 2, 3])
    sorting_analyzer_remove = sorting_analyzer.remove_units(remove_unit_ids=[0])
    sorting_analyzer_merge = sorting_analyzer.merge_units(merge_unit_groups=[[0, 1], [2, 3]])

All computed extensions will be automatically propagated or merged when curating. Please refer to the
:ref:`modules/curation` documentation for more information.
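For example, extensions computed before a merge are still available on the merged analyzer. A minimal sketch, assuming a ``recording``/``sorting`` pair is already loaded:

.. code-block:: python

    sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording)
    sorting_analyzer.compute(["random_spikes", "waveforms", "templates"])

    # merging propagates the computed extensions to the merged units
    merged_analyzer = sorting_analyzer.merge_units(merge_unit_groups=[[0, 1]])
    assert merged_analyzer.has_extension("templates")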


Event
-----

298 changes: 230 additions & 68 deletions doc/modules/curation.rst

Large diffs are not rendered by default.

172 changes: 154 additions & 18 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -9,6 +9,8 @@
* ComputeNoiseLevels, which is very convenient to have
"""

import warnings

import numpy as np

from .sortinganalyzer import AnalyzerExtension, register_result_extension
@@ -76,6 +78,20 @@ def _select_extension_data(self, unit_ids):
new_data["random_spikes_indices"] = np.flatnonzero(selected_mask[keep_spike_mask])
return new_data

def _merge_extension_data(
self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
):
new_data = dict()
random_spikes_indices = self.data["random_spikes_indices"]
if keep_mask is None:
new_data["random_spikes_indices"] = random_spikes_indices.copy()
else:
spikes = self.sorting_analyzer.sorting.to_spike_vector()
selected_mask = np.zeros(spikes.size, dtype=bool)
selected_mask[random_spikes_indices] = True
new_data["random_spikes_indices"] = np.flatnonzero(selected_mask[keep_mask])
return new_data
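The ``keep_mask`` branch above re-indexes the random selection because censoring spikes shifts every later position in the spike vector. A toy NumPy sketch of the same re-indexing trick, with hypothetical values:

import numpy as np

# 10 spikes in the spike vector, 4 of them randomly selected
random_spikes_indices = np.array([1, 3, 6, 8])
keep_mask = np.ones(10, dtype=bool)
keep_mask[[3, 7]] = False  # two spikes censored by the merge

selected_mask = np.zeros(10, dtype=bool)
selected_mask[random_spikes_indices] = True

# the selected spike at index 3 is itself censored; the survivors shift left
print(np.flatnonzero(selected_mask[keep_mask]))  # [1 5 6]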

def _get_data(self):
return self.data["random_spikes_indices"]

@@ -224,18 +240,66 @@ def _select_extension_data(self, unit_ids):

return new_data

def get_waveforms_one_unit(
self,
unit_id,
force_dense: bool = False,
def _merge_extension_data(
self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
):
new_data = dict()

waveforms = self.data["waveforms"]
some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()
if keep_mask is not None:
spike_indices = self.sorting_analyzer.get_extension("random_spikes").get_data()
valid = keep_mask[spike_indices]
some_spikes = some_spikes[valid]
waveforms = waveforms[valid]
else:
waveforms = waveforms.copy()

old_sparsity = self.sorting_analyzer.sparsity
if old_sparsity is not None:
# we need a realignment inside each group because we take the channel-intersection sparsity
for group_ids in merge_unit_groups:
group_indices = self.sorting_analyzer.sorting.ids_to_indices(group_ids)
group_sparsity_mask = old_sparsity.mask[group_indices, :]
group_selection = []
for unit_id in group_ids:
unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id)
selection = np.flatnonzero(some_spikes["unit_index"] == unit_index)
group_selection.append(selection)
_inplace_sparse_realign_waveforms(waveforms, group_selection, group_sparsity_mask)

old_num_chans = int(np.max(np.sum(old_sparsity.mask, axis=1)))
new_num_chans = int(np.max(np.sum(new_sorting_analyzer.sparsity.mask, axis=1)))
if new_num_chans < old_num_chans:
waveforms = waveforms[:, :, :new_num_chans]

return dict(waveforms=waveforms)
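The realignment above is needed because the merged unit keeps only the channels shared by every unit in its group. A small sketch of that intersection step, using toy masks rather than the real sparsity object:

import numpy as np

group_sparsity_mask = np.array([[True, True, True, False],   # unit A: dense channels 0, 1, 2
                                [False, True, True, True]])  # unit B: dense channels 1, 2, 3
common_mask = np.all(group_sparsity_mask, axis=0)
print(np.flatnonzero(common_mask))  # [1 2]: channels kept for the merged unit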

def get_waveforms_one_unit(self, unit_id, force_dense: bool = False):
"""
Returns the waveforms of a unit id.
Parameters
----------
unit_id : int or str
The unit id to return waveforms for
force_dense : bool, default: False
If True, and SortingAnalyzer must be sparse then only waveforms on sparse channels are returned.
Returns
-------
waveforms: np.array
The waveforms (num_waveforms, num_samples, num_channels).
In case sparsity is used, only the waveforms on sparse channels are returned.
"""
sorting = self.sorting_analyzer.sorting
unit_index = sorting.id_to_index(unit_id)

waveforms = self.data["waveforms"]
some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()

spike_mask = some_spikes["unit_index"] == unit_index
wfs = self.data["waveforms"][spike_mask, :, :]
wfs = waveforms[spike_mask, :, :]

if self.sorting_analyzer.sparsity is not None:
chan_inds = self.sorting_analyzer.sparsity.unit_id_to_channel_indices[unit_id]
@@ -252,6 +316,22 @@ def _get_data(self):
return self.data["waveforms"]


def _inplace_sparse_realign_waveforms(waveforms, group_selection, group_sparsity_mask):
# this is used by the "waveforms" extension but also by "pca"

# the common mask is the intersection of the channel masks of all units in the group
common_mask = np.all(group_sparsity_mask, axis=0)

for i in range(len(group_selection)):
chan_mask = group_sparsity_mask[i, :]
sel = group_selection[i]
# waveforms of this unit, restricted to its own sparse channels
wfs = waveforms[sel, :, :][:, :, : np.sum(chan_mask)]
# keep only the channels shared by the whole group
keep_mask = common_mask[chan_mask]
wfs = wfs[:, :, keep_mask]
# write back left-aligned on the common channels and zero-pad the rest
waveforms[:, :, : wfs.shape[2]][sel, :, :] = wfs
waveforms[:, :, wfs.shape[2] :][sel, :, :] = 0.0
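A toy check of this helper, with hypothetical values chosen so each sample encodes its dense channel id:

import numpy as np

# 4 spikes (2 per unit), 2 samples, 3 sparse channel slots
waveforms = np.zeros((4, 2, 3), dtype="float32")
waveforms[0:2] = [0.0, 10.0, 20.0]   # unit A sits on dense channels 0, 1, 2
waveforms[2:4] = [10.0, 20.0, 30.0]  # unit B sits on dense channels 1, 2, 3

group_selection = [np.array([0, 1]), np.array([2, 3])]
group_sparsity_mask = np.array([[True, True, True, False], [False, True, True, True]])

_inplace_sparse_realign_waveforms(waveforms, group_selection, group_sparsity_mask)
print(waveforms[0, 0], waveforms[2, 0])  # [10. 20.  0.] [10. 20.  0.]: both units aligned on channels 1, 2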


compute_waveforms = ComputeWaveforms.function_factory()
register_result_extension(ComputeWaveforms)

@@ -298,16 +378,13 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=N

waveforms_extension = self.sorting_analyzer.get_extension("waveforms")
if waveforms_extension is not None:
nbefore = waveforms_extension.nbefore
nafter = waveforms_extension.nafter
else:
nbefore = int(ms_before * self.sorting_analyzer.sampling_frequency / 1000.0)
nafter = int(ms_after * self.sorting_analyzer.sampling_frequency / 1000.0)
ms_before = waveforms_extension.params["ms_before"]
ms_after = waveforms_extension.params["ms_after"]

params = dict(
operators=operators,
nbefore=nbefore,
nafter=nafter,
ms_before=ms_before,
ms_after=ms_after,
)
return params

@@ -316,6 +393,7 @@ def _run(self, verbose=False, **job_kwargs):

if self.sorting_analyzer.has_extension("waveforms"):
self._compute_and_append_from_waveforms(self.params["operators"])

else:
for operator in self.params["operators"]:
if operator not in ("average", "std"):
@@ -380,7 +458,6 @@ def _compute_and_append_from_waveforms(self, operators):
"random_spikes"
), "compute templates requires the random_spikes extension. You can run sorting_analyzer.get_random_spikes()"
some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()

for unit_index, unit_id in enumerate(unit_ids):
spike_mask = some_spikes["unit_index"] == unit_index
wfs = waveforms[spike_mask, :, :]
@@ -410,11 +487,33 @@ def _compute_and_append_from_waveforms(self, operators):

@property
def nbefore(self):
return self.params["nbefore"]
if "ms_before" not in self.params:
# compatibility: analyzers saved between February 2024 and July 2024 stored 'nbefore'
self.params["ms_before"] = self.params["nbefore"] * 1000.0 / self.sorting_analyzer.sampling_frequency
warnings.warn(
"The 'nbefore' parameter is deprecated and has been replaced by 'ms_before' in the params. "
"You can save the sorting_analyzer to update the params.",
DeprecationWarning,
stacklevel=2,
)

nbefore = int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0)
return nbefore

@property
def nafter(self):
return self.params["nafter"]
if "ms_after" not in self.params:
# compatibility: analyzers saved between February 2024 and July 2024 stored 'nafter'
warnings.warn(
"The 'nafter' parameter is deprecated and has been replaced by 'ms_after' in the params. "
"You can save the sorting_analyzer to update the params.",
DeprecationWarning,
stacklevel=2,
)
self.params["ms_after"] = self.params["nafter"] * 1000.0 / self.sorting_analyzer.sampling_frequency

nafter = int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0)
return nafter
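Both compatibility shims simply round-trip the legacy sample count through milliseconds; a sketch with hypothetical values:

sampling_frequency = 30000.0
nbefore_legacy = 30  # as stored by analyzers saved before this change
ms_before = nbefore_legacy * 1000.0 / sampling_frequency  # 1.0 ms
assert int(ms_before * sampling_frequency / 1000.0) == nbefore_legacy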

def _select_extension_data(self, unit_ids):
keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids))
@@ -425,12 +524,43 @@ def _select_extension_data(self, unit_ids):

return new_data

def _merge_extension_data(
self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
):

all_new_units = new_sorting_analyzer.unit_ids
new_data = dict()
counts = self.sorting_analyzer.sorting.count_num_spikes_per_unit()
for key, arr in self.data.items():
new_data[key] = np.zeros((len(all_new_units), arr.shape[1], arr.shape[2]), dtype=arr.dtype)
for unit_index, unit_id in enumerate(all_new_units):
if unit_id not in new_unit_ids:
keep_unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id)
new_data[key][unit_index] = arr[keep_unit_index, :, :]
else:
merge_group = merge_unit_groups[list(new_unit_ids).index(unit_id)]
keep_unit_indices = self.sorting_analyzer.sorting.ids_to_indices(merge_group)
# weighted average of the templates, using each unit's spike count as weight
weights = np.zeros(len(merge_group), dtype=np.float32)
for count, merge_unit_id in enumerate(merge_group):
weights[count] = counts[merge_unit_id]
weights /= weights.sum()
new_data[key][unit_index] = (arr[keep_unit_indices, :, :] * weights[:, np.newaxis, np.newaxis]).sum(
0
)
if new_sorting_analyzer.sparsity is not None:
chan_ids = new_sorting_analyzer.sparsity.unit_id_to_channel_indices[unit_id]
mask = ~np.isin(np.arange(arr.shape[2]), chan_ids)
new_data[key][unit_index][:, mask] = 0

return new_data
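A toy version of the spike-count weighting above, with hypothetical counts and templates shaped (num_units, num_samples, num_channels):

import numpy as np

templates = np.array([[[4.0]], [[8.0]]])  # two single-sample, single-channel templates
weights = np.array([300.0, 100.0])        # spike counts of the two merged units
weights /= weights.sum()                  # -> [0.75, 0.25]
merged = (templates * weights[:, np.newaxis, np.newaxis]).sum(0)
print(merged)  # [[5.]] = 0.75 * 4.0 + 0.25 * 8.0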

def _get_data(self, operator="average", percentile=None, outputs="numpy"):
if operator != "percentile":
key = operator
else:
assert percentile is not None, "You must provide percentile=..."
key = f"pencentile_{percentile}"
key = f"percentile_{percentile}"

templates_array = self.data[key]

@@ -582,6 +712,12 @@ def _select_extension_data(self, unit_ids):
# this does not depend on units
return self.data

def _merge_extension_data(
self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
):
# this does not depend on units
return self.data.copy()

def _run(self, verbose=False):
self.data["noise_levels"] = get_noise_levels(
self.sorting_analyzer.recording, return_scaled=self.sorting_analyzer.return_scaled, **self.params
2 changes: 2 additions & 0 deletions src/spikeinterface/core/generate.py
@@ -973,6 +973,8 @@ def inject_some_split_units(sorting, split_ids: list, num_split=2, output_ids=Fa
-------
sorting_with_split : NumpySorting
A sorting with split units.
other_ids : dict
A dictionary mapping each split unit id to its newly created unit ids. Returned only if output_ids is True.
"""
unit_ids = sorting.unit_ids
assert unit_ids.dtype.kind == "i"
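A usage sketch for the documented second return value, built on a toy NumpySorting (integer unit ids, as the assert below the docstring requires):

import numpy as np
from spikeinterface.core import NumpySorting
from spikeinterface.core.generate import inject_some_split_units

sorting = NumpySorting.from_unit_dict(
    {0: np.arange(0, 30000, 100), 1: np.arange(50, 30000, 100)}, sampling_frequency=30000.0
)
sorting_with_split, other_ids = inject_some_split_units(sorting, split_ids=[0], num_split=2, output_ids=True)
print(other_ids)  # maps each split unit id to its newly created unit ids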
