Merge pull request #3292 from chrishalcrow/quality_metrics_update
Do not delete quality and template metrics on recompute
alejoe91 authored Sep 13, 2024
2 parents f64ad5f + 6c8889d commit 72d072a
Showing 7 changed files with 458 additions and 20 deletions.
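To make the behaviour change concrete, here is a hedged usage sketch from the user's side. It mirrors the fixture in the new conftest.py and the calls in the new tests further down this page; it is an illustration, not part of the commit itself.

from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer

# Build a small analyzer, mirroring the new conftest.py fixture below.
recording, sorting = generate_ground_truth_recording(durations=[2.0], num_units=10, seed=1205)
analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")
analyzer.compute({"random_spikes": {"seed": 1205}, "waveforms": {}, "templates": {}})

# Compute all template metrics, then recompute only "exp_decay".
analyzer.compute("template_metrics")
analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}})

# With this commit, previously computed metrics (e.g. "half_width") are kept
# rather than deleted; per the commit title, the same applies to quality metrics.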
63 changes: 57 additions & 6 deletions src/spikeinterface/postprocessing/template_metrics.py
@@ -62,6 +62,8 @@ class ComputeTemplateMetrics(AnalyzerExtension):
For more on generating a ChannelSparsity, see the `~spikeinterface.compute_sparsity()` function.
include_multi_channel_metrics : bool, default: False
Whether to compute multi-channel metrics
delete_existing_metrics : bool, default: False
If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metrics_kwargs` are unchanged.
metrics_kwargs : dict
Additional arguments to pass to the metric functions. Including:
* recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7
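Continuing the sketch at the top of this page, the two keyword arguments documented above interact as follows; this is an illustration only, with the metric names and kwargs taken from the tests added below.

# Changing metrics_kwargs invalidates previously computed metrics: a warning is
# emitted and only the newly requested metrics are kept (see _set_params below).
analyzer.compute(
    {"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}}
)

# delete_existing_metrics=True drops every metric not listed in metric_names,
# even when metrics_kwargs are unchanged.
analyzer.compute(
    {"template_metrics": {"metric_names": ["exp_decay"], "delete_existing_metrics": True}}
)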
@@ -109,9 +111,12 @@ def _set_params(
sparsity=None,
metrics_kwargs=None,
include_multi_channel_metrics=False,
delete_existing_metrics=False,
**other_kwargs,
):

import pandas as pd

# TODO alessio can you check this : this used to be in the function but now we have ComputeTemplateMetrics.function_factory()
if include_multi_channel_metrics or (
metric_names is not None and any([m in get_multi_channel_template_metric_names() for m in metric_names])
@@ -139,12 +144,36 @@ def _set_params(
metrics_kwargs_ = _default_function_kwargs.copy()
metrics_kwargs_.update(metrics_kwargs)

metrics_to_compute = metric_names
tm_extension = self.sorting_analyzer.get_extension("template_metrics")
if delete_existing_metrics is False and tm_extension is not None:

existing_params = tm_extension.params["metrics_kwargs"]
# checks that existing metrics were calculated using the same params
if existing_params != metrics_kwargs_:
warnings.warn(
f"The parameters used to calculate the previous template metrics are different"
f"than those used now.\nPrevious parameters: {existing_params}\nCurrent "
f"parameters: {metrics_kwargs_}\nDeleting previous template metrics..."
)
tm_extension.params["metric_names"] = []
existing_metric_names = []
else:
existing_metric_names = tm_extension.params["metric_names"]

existing_metric_names_propogated = [
metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute
]
metric_names = metrics_to_compute + existing_metric_names_propogated

params = dict(
metric_names=[str(name) for name in np.unique(metric_names)],
metric_names=metric_names,
sparsity=sparsity,
peak_sign=peak_sign,
upsampling_factor=int(upsampling_factor),
metrics_kwargs=metrics_kwargs_,
delete_existing_metrics=delete_existing_metrics,
metrics_to_compute=metrics_to_compute,
)

return params
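The name-merging performed in _set_params above can be summarised by this standalone sketch (plain Python; the example metric names are illustrative):

# Requested metrics come first; previously computed metrics that were not
# re-requested are appended, and the request order is preserved
# (the old np.unique call, which re-sorted the names, is gone).
metrics_to_compute = ["exp_decay"]
existing_metric_names = ["half_width", "exp_decay", "peak_trough_ratio"]

propagated = [name for name in existing_metric_names if name not in metrics_to_compute]
metric_names = metrics_to_compute + propagated
assert metric_names == ["exp_decay", "half_width", "peak_trough_ratio"]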
@@ -158,6 +187,7 @@ def _merge_extension_data(
):
import pandas as pd

metric_names = self.params["metric_names"]
old_metrics = self.data["metrics"]

all_unit_ids = new_sorting_analyzer.unit_ids
@@ -166,19 +196,20 @@ def _merge_extension_data(
metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)

metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs)
metrics.loc[new_unit_ids, :] = self._compute_metrics(
new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs
)

new_data = dict(metrics=metrics)
return new_data

def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job_kwargs):
def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs):
"""
Compute template metrics.
"""
import pandas as pd
from scipy.signal import resample_poly

metric_names = self.params["metric_names"]
sparsity = self.params["sparsity"]
peak_sign = self.params["peak_sign"]
upsampling_factor = self.params["upsampling_factor"]
@@ -290,10 +321,30 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job_kwargs):
return template_metrics

def _run(self, verbose=False):
self.data["metrics"] = self._compute_metrics(
sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose

delete_existing_metrics = self.params["delete_existing_metrics"]
metrics_to_compute = self.params["metrics_to_compute"]

# compute the metrics which have been specified by the user
computed_metrics = self._compute_metrics(
sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metrics_to_compute
)

existing_metrics = []
tm_extension = self.sorting_analyzer.get_extension("template_metrics")
if (
delete_existing_metrics is False
and tm_extension is not None
and tm_extension.data.get("metrics") is not None
):
existing_metrics = tm_extension.params["metric_names"]

# append the metrics which were previously computed
for metric_name in set(existing_metrics).difference(metrics_to_compute):
computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name]

self.data["metrics"] = computed_metrics

def _get_data(self):
return self.data["metrics"]
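A minimal pandas sketch of the append step in _run above, using made-up metric values and unit ids:

import pandas as pd

# Metrics recomputed in this run...
computed_metrics = pd.DataFrame({"exp_decay": [0.10, 0.20]}, index=["unit0", "unit1"])
# ...and metrics stored by a previous run of the extension.
previous_metrics = pd.DataFrame(
    {"exp_decay": [0.30, 0.40], "half_width": [1.0, 1.1]}, index=["unit0", "unit1"]
)

existing_metrics = list(previous_metrics.columns)
metrics_to_compute = ["exp_decay"]

# Only columns that were not recomputed are carried over; recomputed columns win.
for metric_name in set(existing_metrics).difference(metrics_to_compute):
    computed_metrics[metric_name] = previous_metrics[metric_name]

assert list(computed_metrics.columns) == ["exp_decay", "half_width"]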

33 changes: 33 additions & 0 deletions src/spikeinterface/postprocessing/tests/conftest.py
@@ -0,0 +1,33 @@
import pytest

from spikeinterface.core import (
generate_ground_truth_recording,
create_sorting_analyzer,
)


def _small_sorting_analyzer():
recording, sorting = generate_ground_truth_recording(
durations=[2.0],
num_units=10,
seed=1205,
)

sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")

extensions_to_compute = {
"random_spikes": {"seed": 1205},
"noise_levels": {"seed": 1205},
"waveforms": {},
"templates": {"operators": ["average", "median"]},
"spike_amplitudes": {},
}

sorting_analyzer.compute(extensions_to_compute)

return sorting_analyzer


@pytest.fixture(scope="module")
def small_sorting_analyzer():
return _small_sorting_analyzer()
102 changes: 102 additions & 0 deletions src/spikeinterface/postprocessing/tests/test_template_metrics.py
@@ -1,6 +1,108 @@
from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite
from spikeinterface.postprocessing import ComputeTemplateMetrics
import pytest
import csv

from spikeinterface.postprocessing.template_metrics import _single_channel_metric_name_to_func

template_metrics = list(_single_channel_metric_name_to_func.keys())


def test_compute_new_template_metrics(small_sorting_analyzer):
"""
Computes template metrics then computes a subset of template metrics, and checks
that the old template metrics are not deleted.
Then computes template metrics with new parameters and checks that old metrics
are deleted.
"""

# calculate just exp_decay
small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}})
template_metric_extension = small_sorting_analyzer.get_extension("template_metrics")

assert "exp_decay" in list(template_metric_extension.get_data().keys())
assert "half_width" not in list(template_metric_extension.get_data().keys())

# calculate all template metrics
small_sorting_analyzer.compute("template_metrics")
# calculate just exp_decay - this should not delete any other metrics
small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}})
template_metric_extension = small_sorting_analyzer.get_extension("template_metrics")

assert set(template_metrics) == set(template_metric_extension.get_data().keys())

# calculate just exp_decay with delete_existing_metrics
small_sorting_analyzer.compute(
{"template_metrics": {"metric_names": ["exp_decay"], "delete_existing_metrics": True}}
)
template_metric_extension = small_sorting_analyzer.get_extension("template_metrics")
computed_metric_names = template_metric_extension.get_data().keys()

for metric_name in template_metrics:
if metric_name == "exp_decay":
assert metric_name in computed_metric_names
else:
assert metric_name not in computed_metric_names

# check that, when parameters are changed, the old metrics are deleted
small_sorting_analyzer.compute(
{"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}}
)


def test_metric_names_in_same_order(small_sorting_analyzer):
"""
Computes specified template metrics and checks that their order is propagated.
"""
specified_metric_names = ["peak_trough_ratio", "num_negative_peaks", "half_width"]
small_sorting_analyzer.compute("template_metrics", metric_names=specified_metric_names)
tm_keys = small_sorting_analyzer.get_extension("template_metrics").get_data().keys()
for i in range(3):
assert specified_metric_names[i] == tm_keys[i]


def test_save_template_metrics(small_sorting_analyzer, create_cache_folder):
"""
Computes template metrics for an analyzer saved in binary folder format, then computes subsets of
template metrics and checks that they are saved correctly.
"""

small_sorting_analyzer.compute("template_metrics")

cache_folder = create_cache_folder
output_folder = cache_folder / "sorting_analyzer"

folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder)
template_metrics_filename = output_folder / "extensions" / "template_metrics" / "metrics.csv"

with open(template_metrics_filename) as metrics_file:
saved_metrics = csv.reader(metrics_file)
metric_names = next(saved_metrics)

for metric_name in template_metrics:
assert metric_name in metric_names

folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=False)

with open(template_metrics_filename) as metrics_file:
saved_metrics = csv.reader(metrics_file)
metric_names = next(saved_metrics)

for metric_name in template_metrics:
assert metric_name in metric_names

folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=True)

with open(template_metrics_filename) as metrics_file:
saved_metrics = csv.reader(metrics_file)
metric_names = next(saved_metrics)

for metric_name in template_metrics:
if metric_name == "half_width":
assert metric_name in metric_names
else:
assert metric_name not in metric_names


class TestTemplateMetrics(AnalyzerExtensionCommonTestSuite):
13 changes: 13 additions & 0 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -69,6 +69,9 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs):
return num_spikes


_default_params["num_spikes"] = {}


def compute_firing_rates(sorting_analyzer, unit_ids=None):
"""
Compute the firing rate across segments.
@@ -98,6 +101,9 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None):
return firing_rates


_default_params["firing_rate"] = {}


def compute_presence_ratios(sorting_analyzer, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, unit_ids=None):
"""
Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold.
@@ -1550,3 +1556,10 @@ def compute_sd_ratio(
sd_ratio[unit_id] = unit_std / std_noise

return sd_ratio


_default_params["sd_ratio"] = dict(
censored_period_ms=4.0,
correct_for_drift=True,
correct_for_template_itself=True,
)
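For context, the _default_params entries added above register, per metric, the keyword arguments a caller may override. A generic sketch of how such a registry is typically consumed follows; the resolve_params helper is hypothetical and not part of the SpikeInterface API:

# Hypothetical illustration of a per-metric default-parameter registry.
_default_params = {}
_default_params["num_spikes"] = {}  # metric that takes no extra kwargs
_default_params["sd_ratio"] = dict(
    censored_period_ms=4.0,
    correct_for_drift=True,
    correct_for_template_itself=True,
)

def resolve_params(metric_name, user_params=None):
    # Start from the registered defaults, then overlay any user overrides.
    params = dict(_default_params.get(metric_name, {}))
    params.update(user_params or {})
    return params

print(resolve_params("sd_ratio", {"correct_for_drift": False}))
# {'censored_period_ms': 4.0, 'correct_for_drift': False, 'correct_for_template_itself': True}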