Skip to content

Templates.is_scaled > Templates.is_in_uV #4036

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def compute_gt_templates(recording, gt_sorting, ms_before=2.0, ms_after=3.0, ret
channel_ids=recording.channel_ids,
unit_ids=gt_sorting.unit_ids,
probe=recording.get_probe(),
is_scaled=return_scaled,
is_in_uV=return_scaled,
)
return gt_templates

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
channel_ids=self.sorting_analyzer.channel_ids,
unit_ids=unit_ids,
probe=self.sorting_analyzer.get_probe(),
is_scaled=self.sorting_analyzer.return_in_uV,
is_in_uV=self.sorting_analyzer.return_in_uV,
)
else:
raise ValueError("`outputs` must be 'numpy' or 'Templates'")
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/core/sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def from_snr(
return_scaled = templates_or_sorting_analyzer.return_scaled
elif isinstance(templates_or_sorting_analyzer, Templates):
assert noise_levels is not None, "To compute sparsity from snr you need to provide noise_levels"
return_scaled = templates_or_sorting_analyzer.is_scaled
return_scaled = templates_or_sorting_analyzer.is_in_uV

mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool")

Expand Down Expand Up @@ -491,9 +491,9 @@ def from_amplitude(cls, templates_or_sorting_analyzer, threshold, amplitude_mode
"You can set `return_scaled=True` when computing the templates."
)
elif isinstance(templates_or_sorting_analyzer, Templates):
assert templates_or_sorting_analyzer.is_scaled, (
assert templates_or_sorting_analyzer.is_in_uV, (
"To compute sparsity from amplitude you need to have scaled templates. "
"You can set `is_scaled=True` when creating the Templates object."
"You can set `is_in_uV=True` when creating the Templates object."
)

mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool")
Expand Down
16 changes: 8 additions & 8 deletions src/spikeinterface/core/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class Templates:
Array of unit IDs. If `None`, defaults to an array of increasing integers.
probe: Probe, default: None
A `probeinterface.Probe` object
is_scaled : bool, optional default: True
is_in_uV : bool, optional default: True
If True, it means that the templates are in uV, otherwise they are in raw ADC values.
check_for_consistent_sparsity : bool, optional default: None
When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the
Expand Down Expand Up @@ -61,7 +61,7 @@ class Templates:
templates_array: np.ndarray
sampling_frequency: float
nbefore: int
is_scaled: bool = True
is_in_uV: bool = True

sparsity_mask: np.ndarray = None
channel_ids: np.ndarray = None
Expand Down Expand Up @@ -206,7 +206,7 @@ def to_sparse(self, sparsity):
unit_ids=self.unit_ids,
probe=self.probe,
check_for_consistent_sparsity=self.check_for_consistent_sparsity,
is_scaled=self.is_scaled,
is_in_uV=self.is_in_uV,
)

def get_one_template_dense(self, unit_index):
Expand Down Expand Up @@ -257,7 +257,7 @@ def to_dict(self):
"unit_ids": self.unit_ids,
"sampling_frequency": self.sampling_frequency,
"nbefore": self.nbefore,
"is_scaled": self.is_scaled,
"is_in_uV": self.is_in_uV,
"probe": self.probe.to_dict() if self.probe is not None else None,
}

Expand All @@ -270,7 +270,7 @@ def from_dict(cls, data):
unit_ids=np.asarray(data["unit_ids"]),
sampling_frequency=data["sampling_frequency"],
nbefore=data["nbefore"],
is_scaled=data["is_scaled"],
is_in_uV=data["is_in_uV"],
probe=data["probe"] if data["probe"] is None else Probe.from_dict(data["probe"]),
)

Expand Down Expand Up @@ -304,7 +304,7 @@ def add_templates_to_zarr_group(self, zarr_group: "zarr.Group") -> None:

zarr_group.attrs["sampling_frequency"] = self.sampling_frequency
zarr_group.attrs["nbefore"] = self.nbefore
zarr_group.attrs["is_scaled"] = self.is_scaled
zarr_group.attrs["is_in_uV"] = self.is_in_uV

if self.sparsity_mask is not None:
zarr_group.create_dataset("sparsity_mask", data=self.sparsity_mask)
Expand Down Expand Up @@ -362,7 +362,7 @@ def from_zarr_group(cls, zarr_group: "zarr.Group") -> "Templates":
nbefore = zarr_group.attrs["nbefore"]

# TODO: Consider eliminating the True and make it required
is_scaled = zarr_group.attrs.get("is_scaled", True)
is_in_uV = zarr_group.attrs.get("is_in_uV", True)

sparsity_mask = None
if "sparsity_mask" in zarr_group:
Expand All @@ -380,7 +380,7 @@ def from_zarr_group(cls, zarr_group: "zarr.Group") -> "Templates":
channel_ids=channel_ids,
unit_ids=unit_ids,
probe=probe,
is_scaled=is_scaled,
is_in_uV=is_in_uV,
)

@staticmethod
Expand Down
12 changes: 6 additions & 6 deletions src/spikeinterface/core/template_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_dense_templates_array(one_object: Templates | SortingAnalyzer, return_sc
The dense templates (num_units, num_samples, num_channels)
"""
if isinstance(one_object, Templates):
if return_scaled != one_object.is_scaled:
if return_scaled != one_object.is_in_uV:
raise ValueError(
f"get_dense_templates_array: return_scaled={return_scaled} is not possible Templates has the reverse"
)
Expand Down Expand Up @@ -165,15 +165,15 @@ def get_template_extremum_channel(
channel_ids = templates_or_sorting_analyzer.channel_ids

# if SortingAnalyzer need to use global SortingAnalyzer return_scaled otherwise
# we use the Templates is_scaled
# we use the Templates is_in_uV
if isinstance(templates_or_sorting_analyzer, SortingAnalyzer):
# For backward compatibility
if hasattr(templates_or_sorting_analyzer, "return_scaled"):
return_scaled = templates_or_sorting_analyzer.return_scaled
else:
return_scaled = templates_or_sorting_analyzer.return_in_uV
else:
return_scaled = templates_or_sorting_analyzer.is_scaled
return_scaled = templates_or_sorting_analyzer.is_in_uV

peak_values = get_template_amplitudes(
templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, return_scaled=return_scaled
Expand Down Expand Up @@ -218,15 +218,15 @@ def get_template_extremum_channel_peak_shift(templates_or_sorting_analyzer, peak
shifts = {}

# We need to use the SortingAnalyzer return_scaled
# We need to use the Templates is_scaled
# We need to use the Templates is_in_uV
if isinstance(templates_or_sorting_analyzer, SortingAnalyzer):
# For backward compatibility
if hasattr(templates_or_sorting_analyzer, "return_scaled"):
return_scaled = templates_or_sorting_analyzer.return_scaled
else:
return_scaled = templates_or_sorting_analyzer.return_in_uV
else:
return_scaled = templates_or_sorting_analyzer.is_scaled
return_scaled = templates_or_sorting_analyzer.is_in_uV

templates_array = get_dense_templates_array(templates_or_sorting_analyzer, return_scaled=return_scaled)

Expand Down Expand Up @@ -291,7 +291,7 @@ def get_template_extremum_amplitude(
else:
return_scaled = templates_or_sorting_analyzer.return_in_uV
else:
return_scaled = templates_or_sorting_analyzer.is_scaled
return_scaled = templates_or_sorting_analyzer.is_in_uV

extremum_amplitudes = get_template_amplitudes(
templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, return_scaled=return_scaled, abs_value=abs_value
Expand Down
44 changes: 22 additions & 22 deletions src/spikeinterface/core/tests/test_template_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from probeinterface import generate_multi_columns_probe


def generate_test_template(template_type, is_scaled=True) -> Templates:
def generate_test_template(template_type, is_in_uV=True) -> Templates:
num_units = 3
num_samples = 5
num_channels = 4
Expand All @@ -28,7 +28,7 @@ def generate_test_template(template_type, is_scaled=True) -> Templates:
probe=probe,
unit_ids=unit_ids,
channel_ids=channel_ids,
is_scaled=is_scaled,
is_in_uV=is_in_uV,
)
elif template_type == "sparse": # sparse with sparse templates
sparsity_mask = np.array(
Expand All @@ -53,7 +53,7 @@ def generate_test_template(template_type, is_scaled=True) -> Templates:
sampling_frequency=sampling_frequency,
nbefore=nbefore,
probe=probe,
is_scaled=is_scaled,
is_in_uV=is_in_uV,
unit_ids=unit_ids,
channel_ids=channel_ids,
)
Expand All @@ -68,16 +68,16 @@ def generate_test_template(template_type, is_scaled=True) -> Templates:
sampling_frequency=sampling_frequency,
nbefore=nbefore,
probe=probe,
is_scaled=is_scaled,
is_in_uV=is_in_uV,
unit_ids=unit_ids,
channel_ids=channel_ids,
)


@pytest.mark.parametrize("is_scaled", [True, False])
@pytest.mark.parametrize("is_in_uV", [True, False])
@pytest.mark.parametrize("template_type", ["dense", "sparse"])
def test_pickle_serialization(template_type, is_scaled, tmp_path):
template = generate_test_template(template_type, is_scaled)
def test_pickle_serialization(template_type, is_in_uV, tmp_path):
template = generate_test_template(template_type, is_in_uV)

# Dump to pickle
pkl_path = tmp_path / "templates.pkl"
Expand All @@ -91,21 +91,21 @@ def test_pickle_serialization(template_type, is_scaled, tmp_path):
assert template == template_reloaded


@pytest.mark.parametrize("is_scaled", [True, False])
@pytest.mark.parametrize("is_in_uV", [True, False])
@pytest.mark.parametrize("template_type", ["dense", "sparse"])
def test_json_serialization(template_type, is_scaled):
template = generate_test_template(template_type, is_scaled)
def test_json_serialization(template_type, is_in_uV):
template = generate_test_template(template_type, is_in_uV)

json_str = template.to_json()
template_reloaded_from_json = Templates.from_json(json_str)

assert template == template_reloaded_from_json


@pytest.mark.parametrize("is_scaled", [True, False])
@pytest.mark.parametrize("is_in_uV", [True, False])
@pytest.mark.parametrize("template_type", ["dense", "sparse"])
def test_get_dense_templates(template_type, is_scaled):
template = generate_test_template(template_type, is_scaled)
def test_get_dense_templates(template_type, is_in_uV):
template = generate_test_template(template_type, is_in_uV)
dense_templates = template.get_dense_templates()
assert dense_templates.shape == (template.num_units, template.num_samples, template.num_channels)

Expand All @@ -115,10 +115,10 @@ def test_initialization_fail_with_dense_templates():
template = generate_test_template(template_type="sparse_with_dense_templates")


@pytest.mark.parametrize("is_scaled", [True, False])
@pytest.mark.parametrize("is_in_uV", [True, False])
@pytest.mark.parametrize("template_type", ["dense", "sparse"])
def test_save_and_load_zarr(template_type, is_scaled, tmp_path):
original_template = generate_test_template(template_type, is_scaled)
def test_save_and_load_zarr(template_type, is_in_uV, tmp_path):
original_template = generate_test_template(template_type, is_in_uV)

zarr_path = tmp_path / "templates.zarr"
original_template.to_zarr(str(zarr_path))
Expand All @@ -129,10 +129,10 @@ def test_save_and_load_zarr(template_type, is_scaled, tmp_path):
assert original_template == loaded_template


@pytest.mark.parametrize("is_scaled", [True, False])
@pytest.mark.parametrize("is_in_uV", [True, False])
@pytest.mark.parametrize("template_type", ["dense", "sparse"])
def test_select_units(template_type, is_scaled):
template = generate_test_template(template_type, is_scaled)
def test_select_units(template_type, is_in_uV):
template = generate_test_template(template_type, is_in_uV)
selected_unit_ids = ["unit_a", "unit_c"]
selected_unit_ids_indices = [0, 2]

Expand All @@ -149,10 +149,10 @@ def test_select_units(template_type, is_scaled):
assert np.array_equal(selected_template.sparsity_mask, template.sparsity_mask[selected_unit_ids_indices])


@pytest.mark.parametrize("is_scaled", [True, False])
@pytest.mark.parametrize("is_in_uV", [True, False])
@pytest.mark.parametrize("template_type", ["dense"])
def test_select_channels(template_type, is_scaled):
template = generate_test_template(template_type, is_scaled)
def test_select_channels(template_type, is_in_uV):
template = generate_test_template(template_type, is_in_uV)
selected_channel_ids = ["channel1", "channel3"]
selected_channel_ids_indices = [0, 2]

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/tests/test_template_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _get_templates_object_from_sorting_analyzer(sorting_analyzer):
sparsity_mask=None,
channel_ids=sorting_analyzer.channel_ids,
unit_ids=sorting_analyzer.unit_ids,
is_scaled=sorting_analyzer.return_in_uV,
is_in_uV=sorting_analyzer.return_in_uV,
)
return templates

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/generation/drift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def from_static_templates(cls, templates: Templates):
nbefore=templates.nbefore,
probe=templates.probe,
sparsity_mask=templates.sparsity_mask,
is_scaled=templates.is_scaled,
is_in_uV=templates.is_in_uV,
unit_ids=templates.unit_ids,
channel_ids=templates.channel_ids,
)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/generation/drifting_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def generate_drifting_recording(
sampling_frequency=sampling_frequency,
nbefore=nbefore,
probe=probe,
is_scaled=True,
is_in_uV=True,
)

drifting_templates = DriftingTemplates.from_static_templates(templates)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/generation/tests/test_drift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def make_some_templates():
sampling_frequency=sampling_frequency,
nbefore=nbefore,
probe=probe,
is_scaled=True,
is_in_uV=True,
)

return templates
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/localization_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def compute_grid_convolution(

def get_return_scaled(sorting_analyzer_or_templates):
if isinstance(sorting_analyzer_or_templates, Templates):
return_scaled = sorting_analyzer_or_templates.is_scaled
return_scaled = sorting_analyzer_or_templates.is_in_uV
else:
return_scaled = sorting_analyzer_or_templates.return_scaled
return return_scaled
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/internal/tridesclous2.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
nbefore=nbefore,
sparsity_mask=None,
probe=recording_for_peeler.get_probe(),
is_scaled=False,
is_in_uV=False,
)

# TODO : try other methods for sparsity
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sortingcomponents/clustering/circus.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
channel_ids=recording.channel_ids,
unit_ids=templates.unit_ids[valid_templates],
probe=recording.get_probe(),
is_scaled=False,
is_in_uV=False,
)

if params["debug"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
nbefore=nbefore,
sparsity_mask=None,
probe=recording.get_probe(),
is_scaled=False,
is_in_uV=False,
)

labels, peak_labels = remove_duplicates_via_matching(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
channel_ids=recording.channel_ids,
unit_ids=unit_ids[valid_templates],
probe=recording.get_probe(),
is_scaled=False,
is_in_uV=False,
)

sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"])
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/sortingcomponents/clustering/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def get_templates_from_peaks_and_recording(
channel_ids=recording.channel_ids,
unit_ids=labels,
probe=recording.get_probe(),
is_scaled=False,
is_in_uV=False,
)

return templates
Expand Down Expand Up @@ -352,7 +352,7 @@ def get_templates_from_peaks_and_svd(
channel_ids=recording.channel_ids,
unit_ids=labels,
probe=recording.get_probe(),
is_scaled=False,
is_in_uV=False,
)

return templates
2 changes: 1 addition & 1 deletion src/spikeinterface/sortingcomponents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def remove_empty_templates(templates):
channel_ids=templates.channel_ids,
unit_ids=templates.unit_ids[not_empty],
probe=templates.probe,
is_scaled=templates.is_scaled,
is_in_uV=templates.is_in_uV,
)


Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/drift_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _update_ipywidget(self, keep_lims=False):
templates_array,
self.drifting_templates.sampling_frequency,
self.drifting_templates.nbefore,
is_scaled=self.drifting_templates.is_scaled,
is_in_uV=self.drifting_templates.is_in_uV,
sparsity_mask=None,
channel_ids=self.drifting_templates.channel_ids,
unit_ids=self.drifting_templates.unit_ids,
Expand Down