Commit d38dbf4

Merge pull request #3588 from h-mayorquin/use_strings_as_ids_in_generators

Use strings as ids in generators

alejoe91 authored Jan 7, 2025
2 parents 6fde997 + 7dea3b2 commit d38dbf4
Showing 16 changed files with 89 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesorting.py
@@ -135,7 +135,7 @@ def get_total_duration(self) -> float:

     def get_unit_spike_train(
         self,
-        unit_id,
+        unit_id: str | int,
         segment_index: Union[int, None] = None,
         start_frame: Union[int, None] = None,
        end_frame: Union[int, None] = None,
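
For context, a minimal usage sketch of the widened signature (assuming only the public generate_sorting API; this is not code from the commit):

    from spikeinterface.core import generate_sorting

    sorting = generate_sorting(num_units=3, durations=[1.0], sampling_frequency=30000.0)
    # unit_id may be a str or an int, but it must match the sorting's actual ids,
    # which the generators below now emit as strings
    train = sorting.get_unit_spike_train(unit_id="1", segment_index=0)
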
13 changes: 8 additions & 5 deletions src/spikeinterface/core/generate.py
@@ -2,7 +2,7 @@
 import math
 import warnings
 import numpy as np
-from typing import Literal
+from typing import Literal, Optional
 from math import ceil

 from .basesorting import SpikeVectorSortingSegment
@@ -134,7 +134,7 @@ def generate_sorting(
     seed = _ensure_seed(seed)
     rng = np.random.default_rng(seed)
     num_segments = len(durations)
-    unit_ids = np.arange(num_units)
+    unit_ids = [str(idx) for idx in np.arange(num_units)]

     spikes = []
     for segment_index in range(num_segments):
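
The visible effect of this hunk on the generator output, as a hedged sketch:

    from spikeinterface.core import generate_sorting

    sorting = generate_sorting(num_units=3, durations=[0.1], sampling_frequency=30000.0, seed=0)
    print(sorting.unit_ids)  # string ids "0", "1", "2" instead of np.arange(3)
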
@@ -1111,7 +1111,7 @@ def __init__(
         """

-        unit_ids = np.arange(num_units)
+        unit_ids = [str(idx) for idx in np.arange(num_units)]
         super().__init__(sampling_frequency, unit_ids)

         self.num_units = num_units
@@ -1138,6 +1138,7 @@ def __init__(
                 firing_rates=firing_rates,
                 refractory_period_seconds=self.refractory_period_seconds,
                 seed=segment_seed,
+                unit_ids=unit_ids,
                 t_start=None,
             )
             self.add_sorting_segment(segment)
@@ -1161,6 +1162,7 @@ def __init__(
         firing_rates: float | np.ndarray,
         refractory_period_seconds: float | np.ndarray,
         seed: int,
+        unit_ids: list[str],
         t_start: Optional[float] = None,
     ):
         self.num_units = num_units
@@ -1177,7 +1179,8 @@ def __init__(
         self.refractory_period_seconds = np.full(num_units, self.refractory_period_seconds, dtype="float64")

         self.segment_seed = seed
-        self.units_seed = {unit_id: self.segment_seed + hash(unit_id) for unit_id in range(num_units)}
+        self.units_seed = {unit_id: abs(self.segment_seed + hash(unit_id)) for unit_id in unit_ids}
+
         self.num_samples = math.ceil(sampling_frequency * duration)
         super().__init__(t_start)
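
A standalone sketch of the per-unit seed derivation above (names copied from the hunk; note that Python salts hash() of str per process unless PYTHONHASHSEED is fixed, so string-derived seeds are only stable within a single run or under a fixed hash seed):

    import numpy as np

    segment_seed = 42
    unit_ids = ["0", "1", "2"]
    # abs() keeps each derived seed non-negative, as np.random.default_rng requires
    units_seed = {unit_id: abs(segment_seed + hash(unit_id)) for unit_id in unit_ids}
    rng = np.random.default_rng(units_seed["0"])
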
@@ -1280,7 +1283,7 @@ def __init__(
         noise_block_size: int = 30000,
     ):

-        channel_ids = np.arange(num_channels)
+        channel_ids = [str(idx) for idx in np.arange(num_channels)]
         dtype = np.dtype(dtype).name  # Cast to string for serialization
         if dtype not in ("float32", "float64"):
             raise ValueError(f"'dtype' must be 'float32' or 'float64' but is {dtype}")
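
A hedged sketch of the effect on channel ids, assuming NoiseGeneratorRecording also accepts the usual num_channels, sampling_frequency, and durations parameters alongside those shown above:

    from spikeinterface.core.generate import NoiseGeneratorRecording

    rec = NoiseGeneratorRecording(num_channels=4, sampling_frequency=30000.0, durations=[1.0])
    print(rec.get_channel_ids())  # string ids "0" through "3" instead of np.arange(4)
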
10 changes: 5 additions & 5 deletions src/spikeinterface/core/tests/test_basesnippets.py
@@ -41,8 +41,8 @@ def test_BaseSnippets(create_cache_folder):
     assert snippets.get_num_segments() == len(duration)
     assert snippets.get_num_channels() == num_channels

-    assert np.all(snippets.ids_to_indices([0, 1, 2]) == [0, 1, 2])
-    assert np.all(snippets.ids_to_indices([0, 1, 2], prefer_slice=True) == slice(0, 3, None))
+    assert np.all(snippets.ids_to_indices(["0", "1", "2"]) == [0, 1, 2])
+    assert np.all(snippets.ids_to_indices(["0", "1", "2"], prefer_slice=True) == slice(0, 3, None))

     # annotations / properties
     snippets.annotate(gre="ta")
@@ -60,7 +60,7 @@ def test_BaseSnippets(create_cache_folder):
     )

     # missing property
-    snippets.set_property("string_property", ["ciao", "bello"], ids=[0, 1])
+    snippets.set_property("string_property", ["ciao", "bello"], ids=["0", "1"])
     values = snippets.get_property("string_property")
     assert values[2] == ""

@@ -70,14 +70,14 @@ def test_BaseSnippets(create_cache_folder):
         snippets.set_property,
         key="string_property_nan",
         values=["hola", "chabon"],
-        ids=[0, 1],
+        ids=["0", "1"],
         missing_value=np.nan,
     )

     # int properties without missing values raise an error
     assert_raises(Exception, snippets.set_property, key="int_property", values=[5, 6], ids=[1, 2])

-    snippets.set_property("int_property", [5, 6], ids=[1, 2], missing_value=200)
+    snippets.set_property("int_property", [5, 6], ids=["1", "2"], missing_value=200)
     values = snippets.get_property("int_property")
     assert values.dtype.kind == "i"

@@ -38,10 +38,12 @@ def test_channelsaggregationrecording():

         assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg))
         assert np.allclose(
-            traces2_0, recording_agg.get_traces(channel_ids=[str(num_channels + channel_ids[0])], segment_index=seg)
+            traces2_0,
+            recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg),
         )
         assert np.allclose(
-            traces3_2, recording_agg.get_traces(channel_ids=[str(2 * num_channels + channel_ids[2])], segment_index=seg)
+            traces3_2,
+            recording_agg.get_traces(channel_ids=[str(2 * num_channels + int(channel_ids[2]))], segment_index=seg),
         )
         # all traces
         traces1 = recording1.get_traces(segment_index=seg)
8 changes: 8 additions & 0 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
@@ -31,6 +31,14 @@ def get_dataset():
         noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
         seed=2205,
     )
+
+    # TODO: the tests or the sorting analyzer make assumptions about the ids being integers
+    # So keeping this the way it was
+    integer_channel_ids = [int(id) for id in recording.get_channel_ids()]
+    integer_unit_ids = [int(id) for id in sorting.get_unit_ids()]
+
+    recording = recording.rename_channels(new_channel_ids=integer_channel_ids)
+    sorting = sorting.rename_units(new_unit_ids=integer_unit_ids)
     return recording, sorting


Expand Down
22 changes: 13 additions & 9 deletions src/spikeinterface/core/tests/test_unitsselectionsorting.py
@@ -10,39 +10,43 @@
 def test_basic_functions():
     sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0)

-    sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2])
-    assert np.array_equal(sorting2.unit_ids, [0, 2])
+    sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"])
+    assert np.array_equal(sorting2.unit_ids, ["0", "2"])
     assert sorting2.get_parent() == sorting

-    sorting3 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "b"])
+    sorting3 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "b"])
     assert np.array_equal(sorting3.unit_ids, ["a", "b"])

     assert np.array_equal(
-        sorting.get_unit_spike_train(0, segment_index=0), sorting2.get_unit_spike_train(0, segment_index=0)
+        sorting.get_unit_spike_train(unit_id="0", segment_index=0),
+        sorting2.get_unit_spike_train(unit_id="0", segment_index=0),
     )
     assert np.array_equal(
-        sorting.get_unit_spike_train(0, segment_index=0), sorting3.get_unit_spike_train("a", segment_index=0)
+        sorting.get_unit_spike_train(unit_id="0", segment_index=0),
+        sorting3.get_unit_spike_train(unit_id="a", segment_index=0),
     )

     assert np.array_equal(
-        sorting.get_unit_spike_train(2, segment_index=0), sorting2.get_unit_spike_train(2, segment_index=0)
+        sorting.get_unit_spike_train(unit_id="2", segment_index=0),
+        sorting2.get_unit_spike_train(unit_id="2", segment_index=0),
     )
     assert np.array_equal(
-        sorting.get_unit_spike_train(2, segment_index=0), sorting3.get_unit_spike_train("b", segment_index=0)
+        sorting.get_unit_spike_train(unit_id="2", segment_index=0),
+        sorting3.get_unit_spike_train(unit_id="b", segment_index=0),
     )


 def test_failure_with_non_unique_unit_ids():
     seed = 10
     sorting = generate_sorting(num_units=3, durations=[0.100], sampling_frequency=30000.0, seed=seed)
     with pytest.raises(AssertionError):
-        sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "a"])
+        sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "a"])


 def test_custom_cache_spike_vector():
     sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0)

-    sub_sorting = UnitsSelectionSorting(sorting, unit_ids=[2, 0], renamed_unit_ids=["b", "a"])
+    sub_sorting = UnitsSelectionSorting(sorting, unit_ids=["2", "0"], renamed_unit_ids=["b", "a"])
     cached_spike_vector = sub_sorting.to_spike_vector(use_cache=True)
     computed_spike_vector = sub_sorting.to_spike_vector(use_cache=False)
     assert np.all(cached_spike_vector == computed_spike_vector)
5 changes: 5 additions & 0 deletions src/spikeinterface/curation/tests/common.py
@@ -19,6 +19,11 @@ def make_sorting_analyzer(sparse=True):
         seed=2205,
     )

+    channel_ids_as_integers = [id for id in range(recording.get_num_channels())]
+    unit_ids_as_integers = [id for id in range(sorting.get_num_units())]
+    recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
+    sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)
+
     sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory", sparse=sparse)
     sorting_analyzer.compute("random_spikes")
     sorting_analyzer.compute("waveforms", **job_kwargs)
@@ -49,6 +49,9 @@ def test_gh_curation():
     Test curation using GitHub URI.
     """
     sorting = generate_sorting(num_units=10)
+    unit_ids_as_int = [id for id in range(sorting.get_num_units())]
+    sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int)
+
     # curated link:
     # https://figurl.org/f?v=npm://@fi-sci/figurl-sortingview@12/dist&d=sha1://058ab901610aa9d29df565595a3cc2a81a1b08e5
     gh_uri = "gh://SpikeInterface/spikeinterface/main/src/spikeinterface/curation/tests/sv-sorting-curation.json"
@@ -76,6 +79,8 @@ def test_sha1_curation():
     Test curation using SHA1 URI.
     """
     sorting = generate_sorting(num_units=10)
+    unit_ids_as_int = [id for id in range(sorting.get_num_units())]
+    sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int)

     # from SHA1
     # curated link:
@@ -105,6 +110,8 @@ def test_json_curation():
     Test curation using a JSON file.
     """
     sorting = generate_sorting(num_units=10)
+    unit_ids_as_int = [id for id in range(sorting.get_num_units())]
+    sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int)

     # from curation.json
     json_file = parent_folder / "sv-sorting-curation.json"
@@ -248,6 +255,8 @@ def test_json_no_merge_curation():
     Test curation with no merges using a JSON file.
     """
     sorting = generate_sorting(num_units=10)
+    unit_ids_as_int = [id for id in range(sorting.get_num_units())]
+    sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int)

     json_file = parent_folder / "sv-sorting-curation-no-merge.json"
     sorting_curated = apply_sortingview_curation(sorting, uri_or_json=json_file)
6 changes: 6 additions & 0 deletions src/spikeinterface/extractors/tests/test_mdaextractors.py
@@ -9,6 +9,12 @@ def test_mda_extractors(create_cache_folder):
     cache_folder = create_cache_folder
     rec, sort = generate_ground_truth_recording(durations=[10.0], num_units=10)

+    ids_as_integers = [id for id in range(rec.get_num_channels())]
+    rec = rec.rename_channels(new_channel_ids=ids_as_integers)
+
+    ids_as_integers = [id for id in range(sort.get_num_units())]
+    sort = sort.rename_units(new_unit_ids=ids_as_integers)
+
     MdaRecordingExtractor.write_recording(rec, cache_folder / "mdatest")
     rec_mda = MdaRecordingExtractor(cache_folder / "mdatest")
     probe = rec_mda.get_probe()
@@ -23,6 +23,11 @@ def get_dataset():
         seed=2205,
     )

+    channel_ids_as_integers = [id for id in range(recording.get_num_channels())]
+    unit_ids_as_integers = [id for id in range(sorting.get_num_units())]
+    recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
+    sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)
+
     # since templates are going to be averaged and this might be a problem for amplitude scaling
     # we select the 3 units with the largest templates to split
     analyzer_raw = create_sorting_analyzer(sorting, recording, format="memory", sparse=False)
8 changes: 4 additions & 4 deletions src/spikeinterface/preprocessing/tests/test_clip.py
@@ -14,12 +14,12 @@ def test_clip():
     rec1 = clip(rec, a_min=-1.5)
     rec1.save(verbose=False)

-    traces0 = rec0.get_traces(segment_index=0, channel_ids=[1])
+    traces0 = rec0.get_traces(segment_index=0, channel_ids=["1"])
     assert traces0.shape[1] == 1

     assert np.all(-2 <= traces0[0] <= 3)

-    traces1 = rec1.get_traces(segment_index=0, channel_ids=[0, 1])
+    traces1 = rec1.get_traces(segment_index=0, channel_ids=["0", "1"])
     assert traces1.shape[1] == 2

     assert np.all(-1.5 <= traces1[1])
@@ -34,11 +34,11 @@ def test_blank_staturation():
     rec1 = blank_staturation(rec, quantile_threshold=0.01, direction="both", chunk_size=10000)
     rec1.save(verbose=False)

-    traces0 = rec0.get_traces(segment_index=0, channel_ids=[1])
+    traces0 = rec0.get_traces(segment_index=0, channel_ids=["1"])
     assert traces0.shape[1] == 1
     assert np.all(traces0 < 3.0)

-    traces1 = rec1.get_traces(segment_index=0, channel_ids=[0])
+    traces1 = rec1.get_traces(segment_index=0, channel_ids=["0"])
     assert traces1.shape[1] == 1
     # use a smaller value to be sure
     a_min = rec1._recording_segments[0].a_min
@@ -163,7 +163,9 @@ def test_output_values():
     expected_weights = np.r_[np.tile(np.exp(-2), 3), np.exp(-4)]
     expected_weights /= np.sum(expected_weights)

-    si_interpolated_recording = spre.interpolate_bad_channels(recording, bad_channel_indexes, sigma_um=1, p=1)
+    si_interpolated_recording = spre.interpolate_bad_channels(
+        recording, bad_channel_ids=bad_channel_ids, sigma_um=1, p=1
+    )
     si_interpolated = si_interpolated_recording.get_traces()

     expected_ts = si_interpolated[:, 1:] @ expected_weights
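
For reference, a hedged usage sketch of the keyword form adopted here; bad_channel_ids takes channel ids (now strings), not positional indexes. generate_ground_truth_recording is used only to obtain a self-contained recording with a probe attached:

    import spikeinterface.preprocessing as spre
    from spikeinterface.core import generate_ground_truth_recording

    recording, _ = generate_ground_truth_recording(durations=[1.0], num_units=5)
    # pass ids, e.g. ["1"], rather than integer indexes like [1]
    interpolated = spre.interpolate_bad_channels(recording, bad_channel_ids=["1"], sigma_um=1, p=1)
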
@@ -15,7 +15,7 @@ def test_normalize_by_quantile():
     rec2 = normalize_by_quantile(rec, mode="by_channel")
     rec2.save(verbose=False)

-    traces = rec2.get_traces(segment_index=0, channel_ids=[1])
+    traces = rec2.get_traces(segment_index=0, channel_ids=["1"])
     assert traces.shape[1] == 1

     rec2 = normalize_by_quantile(rec, mode="pool_channel")
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/tests/test_rectify.py
@@ -15,7 +15,7 @@ def test_rectify():
     rec2 = rectify(rec)
     rec2.save(verbose=False)

-    traces = rec2.get_traces(segment_index=0, channel_ids=[1])
+    traces = rec2.get_traces(segment_index=0, channel_ids=["1"])
     assert traces.shape[1] == 1

     # import matplotlib.pyplot as plt
10 changes: 10 additions & 0 deletions src/spikeinterface/qualitymetrics/tests/conftest.py
@@ -16,6 +16,11 @@ def small_sorting_analyzer():
         seed=1205,
     )

+    channel_ids_as_integers = [id for id in range(recording.get_num_channels())]
+    unit_ids_as_integers = [id for id in range(sorting.get_num_units())]
+    recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
+    sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)
+
     sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"])

     sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")
@@ -60,6 +65,11 @@ def sorting_analyzer_simple():
         seed=1205,
     )

+    channel_ids_as_integers = [id for id in range(recording.get_num_channels())]
+    unit_ids_as_integers = [id for id in range(sorting.get_num_units())]
+    recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
+    sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)
+
     sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True)

     sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205)
6 changes: 6 additions & 0 deletions src/spikeinterface/sortingcomponents/tests/common.py
@@ -21,4 +21,10 @@ def make_dataset():
         noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"),
         seed=2205,
     )
+
+    channel_ids_as_integers = [id for id in range(recording.get_num_channels())]
+    unit_ids_as_integers = [id for id in range(sorting.get_num_units())]
+    recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
+    sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)
+
     return recording, sorting
