diff --git a/src/spikeinterface/benchmark/tests/common_benchmark_testing.py b/src/spikeinterface/benchmark/tests/common_benchmark_testing.py index 076e40b5a2..57708d273d 100644 --- a/src/spikeinterface/benchmark/tests/common_benchmark_testing.py +++ b/src/spikeinterface/benchmark/tests/common_benchmark_testing.py @@ -74,7 +74,7 @@ def compute_gt_templates(recording, gt_sorting, ms_before=2.0, ms_after=3.0, ret channel_ids=recording.channel_ids, unit_ids=gt_sorting.unit_ids, probe=recording.get_probe(), - is_scaled=return_scaled, + is_in_uV=return_scaled, ) return gt_templates diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index e237342f5b..87d6358ae9 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -646,7 +646,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save channel_ids=self.sorting_analyzer.channel_ids, unit_ids=unit_ids, probe=self.sorting_analyzer.get_probe(), - is_scaled=self.sorting_analyzer.return_in_uV, + is_in_uV=self.sorting_analyzer.return_in_uV, ) else: raise ValueError("`outputs` must be 'numpy' or 'Templates'") diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index f760243fb5..886de89835 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -439,7 +439,7 @@ def from_snr( return_scaled = templates_or_sorting_analyzer.return_scaled elif isinstance(templates_or_sorting_analyzer, Templates): assert noise_levels is not None, "To compute sparsity from snr you need to provide noise_levels" - return_scaled = templates_or_sorting_analyzer.is_scaled + return_scaled = templates_or_sorting_analyzer.is_in_uV mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") @@ -491,9 +491,9 @@ def from_amplitude(cls, templates_or_sorting_analyzer, threshold, amplitude_mode "You can set `return_scaled=True` when computing the templates." ) elif isinstance(templates_or_sorting_analyzer, Templates): - assert templates_or_sorting_analyzer.is_scaled, ( + assert templates_or_sorting_analyzer.is_in_uV, ( "To compute sparsity from amplitude you need to have scaled templates. " - "You can set `is_scaled=True` when creating the Templates object." + "You can set `is_in_uV=True` when creating the Templates object." ) mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index cfc155149f..b155cc0bc4 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -31,7 +31,7 @@ class Templates: Array of unit IDs. If `None`, defaults to an array of increasing integers. probe: Probe, default: None A `probeinterface.Probe` object - is_scaled : bool, optional default: True + is_in_uV : bool, optional default: True If True, it means that the templates are in uV, otherwise they are in raw ADC values. check_for_consistent_sparsity : bool, optional default: None When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the @@ -61,7 +61,7 @@ class Templates: templates_array: np.ndarray sampling_frequency: float nbefore: int - is_scaled: bool = True + is_in_uV: bool = True sparsity_mask: np.ndarray = None channel_ids: np.ndarray = None @@ -206,7 +206,7 @@ def to_sparse(self, sparsity): unit_ids=self.unit_ids, probe=self.probe, check_for_consistent_sparsity=self.check_for_consistent_sparsity, - is_scaled=self.is_scaled, + is_in_uV=self.is_in_uV, ) def get_one_template_dense(self, unit_index): @@ -257,7 +257,7 @@ def to_dict(self): "unit_ids": self.unit_ids, "sampling_frequency": self.sampling_frequency, "nbefore": self.nbefore, - "is_scaled": self.is_scaled, + "is_in_uV": self.is_in_uV, "probe": self.probe.to_dict() if self.probe is not None else None, } @@ -270,7 +270,7 @@ def from_dict(cls, data): unit_ids=np.asarray(data["unit_ids"]), sampling_frequency=data["sampling_frequency"], nbefore=data["nbefore"], - is_scaled=data["is_scaled"], + is_in_uV=data["is_in_uV"], probe=data["probe"] if data["probe"] is None else Probe.from_dict(data["probe"]), ) @@ -304,7 +304,7 @@ def add_templates_to_zarr_group(self, zarr_group: "zarr.Group") -> None: zarr_group.attrs["sampling_frequency"] = self.sampling_frequency zarr_group.attrs["nbefore"] = self.nbefore - zarr_group.attrs["is_scaled"] = self.is_scaled + zarr_group.attrs["is_in_uV"] = self.is_in_uV if self.sparsity_mask is not None: zarr_group.create_dataset("sparsity_mask", data=self.sparsity_mask) @@ -362,7 +362,7 @@ def from_zarr_group(cls, zarr_group: "zarr.Group") -> "Templates": nbefore = zarr_group.attrs["nbefore"] # TODO: Consider eliminating the True and make it required - is_scaled = zarr_group.attrs.get("is_scaled", True) + is_in_uV = zarr_group.attrs.get("is_in_uV", True) sparsity_mask = None if "sparsity_mask" in zarr_group: @@ -380,7 +380,7 @@ def from_zarr_group(cls, zarr_group: "zarr.Group") -> "Templates": channel_ids=channel_ids, unit_ids=unit_ids, probe=probe, - is_scaled=is_scaled, + is_in_uV=is_in_uV, ) @staticmethod diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 3dd9fd82ae..7107c9f1b0 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -22,7 +22,7 @@ def get_dense_templates_array(one_object: Templates | SortingAnalyzer, return_sc The dense templates (num_units, num_samples, num_channels) """ if isinstance(one_object, Templates): - if return_scaled != one_object.is_scaled: + if return_scaled != one_object.is_in_uV: raise ValueError( f"get_dense_templates_array: return_scaled={return_scaled} is not possible Templates has the reverse" ) @@ -165,7 +165,7 @@ def get_template_extremum_channel( channel_ids = templates_or_sorting_analyzer.channel_ids # if SortingAnalyzer need to use global SortingAnalyzer return_scaled otherwise - # we use the Templates is_scaled + # we use the Templates is_in_uV if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): # For backward compatibility if hasattr(templates_or_sorting_analyzer, "return_scaled"): @@ -173,7 +173,7 @@ def get_template_extremum_channel( else: return_scaled = templates_or_sorting_analyzer.return_in_uV else: - return_scaled = templates_or_sorting_analyzer.is_scaled + return_scaled = templates_or_sorting_analyzer.is_in_uV peak_values = get_template_amplitudes( templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, return_scaled=return_scaled @@ -218,7 +218,7 @@ def get_template_extremum_channel_peak_shift(templates_or_sorting_analyzer, peak shifts = {} # We need to use the SortingAnalyzer return_scaled - # We need to use the Templates is_scaled + # We need to use the Templates is_in_uV if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): # For backward compatibility if hasattr(templates_or_sorting_analyzer, "return_scaled"): @@ -226,7 +226,7 @@ def get_template_extremum_channel_peak_shift(templates_or_sorting_analyzer, peak else: return_scaled = templates_or_sorting_analyzer.return_in_uV else: - return_scaled = templates_or_sorting_analyzer.is_scaled + return_scaled = templates_or_sorting_analyzer.is_in_uV templates_array = get_dense_templates_array(templates_or_sorting_analyzer, return_scaled=return_scaled) @@ -291,7 +291,7 @@ def get_template_extremum_amplitude( else: return_scaled = templates_or_sorting_analyzer.return_in_uV else: - return_scaled = templates_or_sorting_analyzer.is_scaled + return_scaled = templates_or_sorting_analyzer.is_in_uV extremum_amplitudes = get_template_amplitudes( templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, return_scaled=return_scaled, abs_value=abs_value diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index 4e0a0c8567..953527c8d5 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -7,7 +7,7 @@ from probeinterface import generate_multi_columns_probe -def generate_test_template(template_type, is_scaled=True) -> Templates: +def generate_test_template(template_type, is_in_uV=True) -> Templates: num_units = 3 num_samples = 5 num_channels = 4 @@ -28,7 +28,7 @@ def generate_test_template(template_type, is_scaled=True) -> Templates: probe=probe, unit_ids=unit_ids, channel_ids=channel_ids, - is_scaled=is_scaled, + is_in_uV=is_in_uV, ) elif template_type == "sparse": # sparse with sparse templates sparsity_mask = np.array( @@ -53,7 +53,7 @@ def generate_test_template(template_type, is_scaled=True) -> Templates: sampling_frequency=sampling_frequency, nbefore=nbefore, probe=probe, - is_scaled=is_scaled, + is_in_uV=is_in_uV, unit_ids=unit_ids, channel_ids=channel_ids, ) @@ -68,16 +68,16 @@ def generate_test_template(template_type, is_scaled=True) -> Templates: sampling_frequency=sampling_frequency, nbefore=nbefore, probe=probe, - is_scaled=is_scaled, + is_in_uV=is_in_uV, unit_ids=unit_ids, channel_ids=channel_ids, ) -@pytest.mark.parametrize("is_scaled", [True, False]) +@pytest.mark.parametrize("is_in_uV", [True, False]) @pytest.mark.parametrize("template_type", ["dense", "sparse"]) -def test_pickle_serialization(template_type, is_scaled, tmp_path): - template = generate_test_template(template_type, is_scaled) +def test_pickle_serialization(template_type, is_in_uV, tmp_path): + template = generate_test_template(template_type, is_in_uV) # Dump to pickle pkl_path = tmp_path / "templates.pkl" @@ -91,10 +91,10 @@ def test_pickle_serialization(template_type, is_scaled, tmp_path): assert template == template_reloaded -@pytest.mark.parametrize("is_scaled", [True, False]) +@pytest.mark.parametrize("is_in_uV", [True, False]) @pytest.mark.parametrize("template_type", ["dense", "sparse"]) -def test_json_serialization(template_type, is_scaled): - template = generate_test_template(template_type, is_scaled) +def test_json_serialization(template_type, is_in_uV): + template = generate_test_template(template_type, is_in_uV) json_str = template.to_json() template_reloaded_from_json = Templates.from_json(json_str) @@ -102,10 +102,10 @@ def test_json_serialization(template_type, is_scaled): assert template == template_reloaded_from_json -@pytest.mark.parametrize("is_scaled", [True, False]) +@pytest.mark.parametrize("is_in_uV", [True, False]) @pytest.mark.parametrize("template_type", ["dense", "sparse"]) -def test_get_dense_templates(template_type, is_scaled): - template = generate_test_template(template_type, is_scaled) +def test_get_dense_templates(template_type, is_in_uV): + template = generate_test_template(template_type, is_in_uV) dense_templates = template.get_dense_templates() assert dense_templates.shape == (template.num_units, template.num_samples, template.num_channels) @@ -115,10 +115,10 @@ def test_initialization_fail_with_dense_templates(): template = generate_test_template(template_type="sparse_with_dense_templates") -@pytest.mark.parametrize("is_scaled", [True, False]) +@pytest.mark.parametrize("is_in_uV", [True, False]) @pytest.mark.parametrize("template_type", ["dense", "sparse"]) -def test_save_and_load_zarr(template_type, is_scaled, tmp_path): - original_template = generate_test_template(template_type, is_scaled) +def test_save_and_load_zarr(template_type, is_in_uV, tmp_path): + original_template = generate_test_template(template_type, is_in_uV) zarr_path = tmp_path / "templates.zarr" original_template.to_zarr(str(zarr_path)) @@ -129,10 +129,10 @@ def test_save_and_load_zarr(template_type, is_scaled, tmp_path): assert original_template == loaded_template -@pytest.mark.parametrize("is_scaled", [True, False]) +@pytest.mark.parametrize("is_in_uV", [True, False]) @pytest.mark.parametrize("template_type", ["dense", "sparse"]) -def test_select_units(template_type, is_scaled): - template = generate_test_template(template_type, is_scaled) +def test_select_units(template_type, is_in_uV): + template = generate_test_template(template_type, is_in_uV) selected_unit_ids = ["unit_a", "unit_c"] selected_unit_ids_indices = [0, 2] @@ -149,10 +149,10 @@ def test_select_units(template_type, is_scaled): assert np.array_equal(selected_template.sparsity_mask, template.sparsity_mask[selected_unit_ids_indices]) -@pytest.mark.parametrize("is_scaled", [True, False]) +@pytest.mark.parametrize("is_in_uV", [True, False]) @pytest.mark.parametrize("template_type", ["dense"]) -def test_select_channels(template_type, is_scaled): - template = generate_test_template(template_type, is_scaled) +def test_select_channels(template_type, is_in_uV): + template = generate_test_template(template_type, is_in_uV) selected_channel_ids = ["channel1", "channel3"] selected_channel_ids_indices = [0, 2] diff --git a/src/spikeinterface/core/tests/test_template_tools.py b/src/spikeinterface/core/tests/test_template_tools.py index 524a53f2c4..a28680612a 100644 --- a/src/spikeinterface/core/tests/test_template_tools.py +++ b/src/spikeinterface/core/tests/test_template_tools.py @@ -48,7 +48,7 @@ def _get_templates_object_from_sorting_analyzer(sorting_analyzer): sparsity_mask=None, channel_ids=sorting_analyzer.channel_ids, unit_ids=sorting_analyzer.unit_ids, - is_scaled=sorting_analyzer.return_in_uV, + is_in_uV=sorting_analyzer.return_in_uV, ) return templates diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index 3eda7e3e72..70c9803921 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -166,7 +166,7 @@ def from_static_templates(cls, templates: Templates): nbefore=templates.nbefore, probe=templates.probe, sparsity_mask=templates.sparsity_mask, - is_scaled=templates.is_scaled, + is_in_uV=templates.is_in_uV, unit_ids=templates.unit_ids, channel_ids=templates.channel_ids, ) diff --git a/src/spikeinterface/generation/drifting_generator.py b/src/spikeinterface/generation/drifting_generator.py index e5b7ceba84..d63271370d 100644 --- a/src/spikeinterface/generation/drifting_generator.py +++ b/src/spikeinterface/generation/drifting_generator.py @@ -459,7 +459,7 @@ def generate_drifting_recording( sampling_frequency=sampling_frequency, nbefore=nbefore, probe=probe, - is_scaled=True, + is_in_uV=True, ) drifting_templates = DriftingTemplates.from_static_templates(templates) diff --git a/src/spikeinterface/generation/tests/test_drift_tools.py b/src/spikeinterface/generation/tests/test_drift_tools.py index d7d4a2159b..a7ab0da700 100644 --- a/src/spikeinterface/generation/tests/test_drift_tools.py +++ b/src/spikeinterface/generation/tests/test_drift_tools.py @@ -65,7 +65,7 @@ def make_some_templates(): sampling_frequency=sampling_frequency, nbefore=nbefore, probe=probe, - is_scaled=True, + is_in_uV=True, ) return templates diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index cfa89ad106..d0c07b5db4 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -325,7 +325,7 @@ def compute_grid_convolution( def get_return_scaled(sorting_analyzer_or_templates): if isinstance(sorting_analyzer_or_templates, Templates): - return_scaled = sorting_analyzer_or_templates.is_scaled + return_scaled = sorting_analyzer_or_templates.is_in_uV else: return_scaled = sorting_analyzer_or_templates.return_scaled return return_scaled diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index da2a6f3807..0ba8dac751 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -221,7 +221,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): nbefore=nbefore, sparsity_mask=None, probe=recording_for_peeler.get_probe(), - is_scaled=False, + is_in_uV=False, ) # TODO : try other methods for sparsity diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 94df58c97d..f0ae71b57d 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -213,7 +213,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): channel_ids=recording.channel_ids, unit_ids=templates.unit_ids[valid_templates], probe=recording.get_probe(), - is_scaled=False, + is_in_uV=False, ) if params["debug"]: diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py index a954474f5d..1aac77c071 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py @@ -180,7 +180,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): nbefore=nbefore, sparsity_mask=None, probe=recording.get_probe(), - is_scaled=False, + is_in_uV=False, ) labels, peak_labels = remove_duplicates_via_matching( diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index e67e1907f1..f3f21e1bd4 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -158,7 +158,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): channel_ids=recording.channel_ids, unit_ids=unit_ids[valid_templates], probe=recording.get_probe(), - is_scaled=False, + is_in_uV=False, ) sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"]) diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 7b982ca879..86b51fa190 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -269,7 +269,7 @@ def get_templates_from_peaks_and_recording( channel_ids=recording.channel_ids, unit_ids=labels, probe=recording.get_probe(), - is_scaled=False, + is_in_uV=False, ) return templates @@ -352,7 +352,7 @@ def get_templates_from_peaks_and_svd( channel_ids=recording.channel_ids, unit_ids=labels, probe=recording.get_probe(), - is_scaled=False, + is_in_uV=False, ) return templates diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index e83b5d246f..ac49f50a23 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -412,7 +412,7 @@ def remove_empty_templates(templates): channel_ids=templates.channel_ids, unit_ids=templates.unit_ids[not_empty], probe=templates.probe, - is_scaled=templates.is_scaled, + is_in_uV=templates.is_in_uV, ) diff --git a/src/spikeinterface/widgets/drift_templates.py b/src/spikeinterface/widgets/drift_templates.py index a86df64f5c..0a9c667fc7 100644 --- a/src/spikeinterface/widgets/drift_templates.py +++ b/src/spikeinterface/widgets/drift_templates.py @@ -113,7 +113,7 @@ def _update_ipywidget(self, keep_lims=False): templates_array, self.drifting_templates.sampling_frequency, self.drifting_templates.nbefore, - is_scaled=self.drifting_templates.is_scaled, + is_in_uV=self.drifting_templates.is_in_uV, sparsity_mask=None, channel_ids=self.drifting_templates.channel_ids, unit_ids=self.drifting_templates.unit_ids,