Skip to content

Commit

Permalink
Merge branch 'master' into final-2.0-fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 authored Jan 17, 2025
2 parents dcb9440 + 7f6e973 commit b7a2248
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 44 deletions.
7 changes: 5 additions & 2 deletions neo/core/analogsignal.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,10 @@ def __new__(
"""
if copy is not None:
raise ValueError(
"`copy` is now deprecated in Neo due to removal in NumPy 2.0 and will be removed in 0.15.0."
"`copy` is now deprecated in Neo due to removal in Quantites to support Numpy 2.0. "
"In order to facilitate the deprecation copy can be set to None but will raise an "
"error if set to True/False since this will silently do nothing. This argument will be completely "
"removed in Neo 0.15.0. Please update your code base as necessary."
)

signal = cls._rescale(signal, units=units)
Expand All @@ -210,7 +213,7 @@ def __new__(
obj.shape = (-1, 1)

if t_start is None:
raise ValueError("t_start cannot be None")
raise ValueError("`t_start` cannot be None")
obj._t_start = t_start

obj._sampling_rate = _get_sampling_rate(sampling_rate, sampling_period)
Expand Down
5 changes: 4 additions & 1 deletion neo/core/imagesequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,10 @@ def __new__(

if copy is not None:
raise ValueError(
"`copy` is now deprecated in Neo due to removal in NumPy 2.0 and will be removed in 0.15.0."
"`copy` is now deprecated in Neo due to removal in Quantites to support Numpy 2.0. "
"In order to facilitate the deprecation copy can be set to None but will raise an "
"error if set to True/False since this will silently do nothing. This argument will be completely "
"removed in Neo 0.15.0. Please update your code base as necessary."
)

if spatial_scale is None:
Expand Down
5 changes: 4 additions & 1 deletion neo/core/irregularlysampledsignal.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,10 @@ def __new__(

if copy is not None:
raise ValueError(
"`copy` is now deprecated in Neo due to removal in NumPy 2.0 and will be removed in 0.15.0."
"`copy` is now deprecated in Neo due to removal in Quantites to support Numpy 2.0. "
"In order to facilitate the deprecation copy can be set to None but will raise an "
"error if set to True/False since this will silently do nothing. This argument will be completely "
"removed in Neo 0.15.0. Please update your code base as necessary."
)

signal = cls._rescale(signal, units=units)
Expand Down
12 changes: 10 additions & 2 deletions neo/core/spiketrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,12 @@ def normalize_times_array(times, units=None, dtype=None, copy=None):
"""

if copy is not None:
raise ValueError("`copy` is now deprecated in Neo due to removal in NumPy 2.0 and will be removed in 0.15.0.")
raise ValueError(
"`copy` is now deprecated in Neo due to removal in Quantites to support Numpy 2.0. "
"In order to facilitate the deprecation copy can be set to None but will raise an "
"error if set to True/False since this will silently do nothing. This argument will be completely "
"removed in Neo 0.15.0. Please update your code base as necessary."
)

if dtype is None:
if not hasattr(times, "dtype"):
Expand Down Expand Up @@ -352,7 +357,10 @@ def __new__(
"""
if copy is not None:
raise ValueError(
"`copy` is now deprecated in Neo due to removal in NumPy 2.0 and will be removed in 0.15.0."
"`copy` is now deprecated in Neo due to removal in Quantites to support Numpy 2.0. "
"In order to facilitate the deprecation copy can be set to None but will raise an "
"error if set to True/False since this will silently do nothing. This argument will be completely "
"removed in Neo 0.15.0. Please update your code base as necessary."
)

if len(times) != 0 and waveforms is not None and len(times) != waveforms.shape[0]:
Expand Down
16 changes: 12 additions & 4 deletions neo/rawio/blackrockrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ def _parse_header(self):
segment_mask = ev_ids == data_bl
if data[segment_mask].size > 0:
t = data[segment_mask][-1]["timestamp"] / self.__nev_basic_header["timestamp_resolution"]

max_nev_time = max(max_nev_time, t)
if max_nev_time > t_stop:
t_stop = max_nev_time
Expand Down Expand Up @@ -680,7 +681,8 @@ def _get_timestamp_slice(self, timestamp, seg_index, t_start, t_stop):
if t_start is None:
t_start = self._seg_t_starts[seg_index]
if t_stop is None:
t_stop = self._seg_t_stops[seg_index]
t_stop = self._seg_t_stops[seg_index] + 1 / float(
self.__nev_basic_header['timestamp_resolution'])

if t_start is None:
ind_start = None
Expand Down Expand Up @@ -713,10 +715,16 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start,
)
unit_spikes = all_spikes[mask]

wf_dtype = self.__nev_params("waveform_dtypes")[channel_id]
wf_size = self.__nev_params("waveform_size")[channel_id]
wf_dtype = self.__nev_params('waveform_dtypes')[channel_id]
wf_size = self.__nev_params('waveform_size')[channel_id]
wf_byte_size = np.dtype(wf_dtype).itemsize * wf_size

dt1 = [
('extra', 'S{}'.format(unit_spikes['waveform'].dtype.itemsize - wf_byte_size)),
('ch_waveform', 'S{}'.format(wf_byte_size))]

waveforms = unit_spikes['waveform'].view(dt1)['ch_waveform'].flatten().view(wf_dtype)

waveforms = unit_spikes["waveform"].flatten().view(wf_dtype)
waveforms = waveforms.reshape(int(unit_spikes.size), 1, int(wf_size))

timestamp = unit_spikes["timestamp"]
Expand Down
116 changes: 82 additions & 34 deletions neo/rawio/spikeglxrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,11 @@ class SpikeGLXRawIO(BaseRawWithBufferApiIO):
Notes
-----
* Contrary to other implementations this IO reads the entire folder and subfolders and:
deals with several segments based on the `_gt0`, `_gt1`, `_gt2`, etc postfixes
deals with all signals "imec0", "imec1" for neuropixel probes and also
external signal like"nidq". This is the "device"
* For imec device both "ap" and "lf" are extracted so one device have several "streams"
* There are several versions depending the neuropixel probe generation (`1.x`/`2.x`/`3.x`)
* Here, we assume that the `meta` file has the same structure across all generations.
* This IO is developed based on neuropixel generation 2.0, single shank recordings.
* This IO reads the entire folder and subfolders locating the `.bin` and `.meta` files
* Handles gates and triggers as segments (based on the `_gt0`, `_gt1`, `_t0` , `_t1` in filenames)
* Handles all signals coming from different acquisition cards ("imec0", "imec1", etc) in a typical
PXIe chassis setup and also external signal like "nidq".
* For imec devices both "ap" and "lf" are extracted so even a one device setup will have several "streams"
Examples
--------
Expand Down Expand Up @@ -125,7 +122,6 @@ def _parse_header(self):
stream_names = sorted(list(srates.keys()), key=lambda e: srates[e])[::-1]
nb_segment = np.unique([info["seg_index"] for info in self.signals_info_list]).size

# self._memmaps = {}
self.signals_info_dict = {}
# one unique block
self._buffer_descriptions = {0: {}}
Expand Down Expand Up @@ -166,7 +162,6 @@ def _parse_header(self):

stream_id = stream_name

stream_index = stream_names.index(info["stream_name"])
signal_streams.append((stream_name, stream_id, buffer_id))

# add channels to global list
Expand Down Expand Up @@ -229,14 +224,25 @@ def _parse_header(self):
spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)

# deal with nb_segment and t_start/t_stop per segment
self._t_starts = {seg_index: 0.0 for seg_index in range(nb_segment)}

self._t_starts = {stream_name: {} for stream_name in stream_names}
self._t_stops = {seg_index: 0.0 for seg_index in range(nb_segment)}
for seg_index in range(nb_segment):
for stream_name in stream_names:

for stream_name in stream_names:
for seg_index in range(nb_segment):
info = self.signals_info_dict[seg_index, stream_name]

frame_start = float(info["meta"]["firstSample"])
sampling_frequency = info["sampling_rate"]
t_start = frame_start / sampling_frequency

self._t_starts[stream_name][seg_index] = t_start
t_stop = info["sample_length"] / info["sampling_rate"]
self._t_stops[seg_index] = max(self._t_stops[seg_index], t_stop)




# fille into header dict
self.header = {}
self.header["nb_block"] = 1
Expand All @@ -250,7 +256,6 @@ def _parse_header(self):
# insert some annotation at some place
self._generate_minimal_annotations()
self._generate_minimal_annotations()
block_ann = self.raw_annotations["blocks"][0]

for seg_index in range(nb_segment):
seg_ann = self.raw_annotations["blocks"][0]["segments"][seg_index]
Expand Down Expand Up @@ -282,7 +287,8 @@ def _segment_t_stop(self, block_index, seg_index):
return self._t_stops[seg_index]

def _get_signal_t_start(self, block_index, seg_index, stream_index):
return 0.0
stream_name = self.header["signal_streams"][stream_index]["name"]
return self._t_starts[stream_name][seg_index]

def _event_count(self, event_channel_idx, block_index=None, seg_index=None):
timestamps, _, _ = self._get_event_timestamps(block_index, seg_index, event_channel_idx, None, None)
Expand Down Expand Up @@ -354,23 +360,56 @@ def scan_files(dirname):
if len(info_list) == 0:
raise FileNotFoundError(f"No appropriate combination of .meta and .bin files were detected in {dirname}")

# the segment index will depend on both 'gate_num' and 'trigger_num'
# so we order by 'gate_num' then 'trigger_num'
# None is before any int
def make_key(info):
k0 = info["gate_num"]
if k0 is None:
k0 = -1
k1 = info["trigger_num"]
if k1 is None:
k1 = -1
return (k0, k1)

order_key = list({make_key(info) for info in info_list})
order_key = sorted(order_key)
# This sets non-integers values before integers
normalize = lambda x: x if isinstance(x, int) else -1

# Segment index is determined by the gate_num and trigger_num in that order
def get_segment_tuple(info):
# Create a key from the normalized gate_num and trigger_num
gate_num = normalize(info.get("gate_num"))
trigger_num = normalize(info.get("trigger_num"))
return (gate_num, trigger_num)

unique_segment_tuples = {get_segment_tuple(info) for info in info_list}
sorted_keys = sorted(unique_segment_tuples)

# Map each unique key to a corresponding index
segment_tuple_to_segment_index = {key: idx for idx, key in enumerate(sorted_keys)}

for info in info_list:
info["seg_index"] = order_key.index(make_key(info))
info["seg_index"] = segment_tuple_to_segment_index[get_segment_tuple(info)]


# Probe index calculation
# The calculation is ordered by slot, port, dock in that order, this is the number that appears in the filename
# after imec when using native names (e.g. imec0, imec1, etc.)
def get_probe_tuple(info):
slot = normalize(info.get("probe_slot"))
port = normalize(info.get("probe_port"))
dock = normalize(info.get("probe_dock"))
return (slot, port, dock)

# TODO: handle one box case
info_list_imec = [info for info in info_list if info.get("device") != "nidq"]
unique_probe_tuples = {get_probe_tuple(info) for info in info_list_imec}
sorted_probe_keys = sorted(unique_probe_tuples)
probe_tuple_to_probe_index = {key: idx for idx, key in enumerate(sorted_probe_keys)}

for info in info_list:
if info.get("device") == "nidq":
info["device_index"] = "" # TODO: Handle multi nidq case, maybe use meta["typeNiEnabled"]
else:
info["device_index"] = probe_tuple_to_probe_index[get_probe_tuple(info)]

# Define stream base on device [imec|nidq], device_index and stream_kind [ap|lf] for imec
for info in info_list:
device_kind = info["device_kind"]
device_index = info["device_index"]
stream_kind = f".{info['stream_kind']}" if info["stream_kind"] else ""
stream_name = f"{device_kind}{device_index}{stream_kind}"

info["stream_name"] = stream_name

return info_list


Expand Down Expand Up @@ -488,13 +527,15 @@ def extract_stream_info(meta_file, meta):
else:
# NIDQ case
has_sync_trace = False
fname = Path(meta_file).stem

bin_file_path = meta["fileName"]
fname = Path(bin_file_path).stem

run_name, gate_num, trigger_num, device, stream_kind = parse_spikeglx_fname(fname)

if "imec" in fname.split(".")[-2]:
device = fname.split(".")[-2]
stream_kind = fname.split(".")[-1]
stream_name = device + "." + stream_kind
units = "uV"
# please note the 1e6 in gain for this uV

Expand Down Expand Up @@ -534,7 +575,6 @@ def extract_stream_info(meta_file, meta):
else:
device = fname.split(".")[-1]
stream_kind = ""
stream_name = device
units = "V"
channel_gains = np.ones(num_chan)

Expand All @@ -550,6 +590,10 @@ def extract_stream_info(meta_file, meta):
gain_factor = float(meta["niAiRangeMax"]) / 32768
channel_gains = per_channel_gain * gain_factor

probe_slot = meta.get("imDatPrb_slot", None)
probe_port = meta.get("imDatPrb_port", None)
probe_dock = meta.get("imDatPrb_dock", None)

info = {}
info["fname"] = fname
info["meta"] = meta
Expand All @@ -563,12 +607,16 @@ def extract_stream_info(meta_file, meta):
info["trigger_num"] = trigger_num
info["device"] = device
info["stream_kind"] = stream_kind
info["stream_name"] = stream_name
# All non-production probes (phase 3B onwards) have "typeThis", otherwise revert to file parsing
info["device_kind"] = meta.get("typeThis", device.split(".")[0])
info["units"] = units
info["channel_names"] = [txt.split(";")[0] for txt in meta["snsChanMap"]]
info["channel_gains"] = channel_gains
info["channel_offsets"] = np.zeros(info["num_chan"])
info["has_sync_trace"] = has_sync_trace
info["probe_slot"] = int(probe_slot) if probe_slot else None
info["probe_port"] = int(probe_port) if probe_port else None
info["probe_dock"] = int(probe_dock) if probe_dock else None

if "nidq" in device:
info["digital_channels"] = []
Expand Down
61 changes: 61 additions & 0 deletions neo/test/rawiotest/test_spikeglxrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class TestSpikeGLXRawIO(BaseTestRawIO, unittest.TestCase):
"spikeglx/NP2_subset_with_sync",
# NP-ultra
"spikeglx/np_ultra_stub",
# Filename changed by the user, multi-dock
"spikeglx/multi_probe_multi_dock_multi_shank_filename_without_info",
# CatGT
"spikeglx/multi_trigger_multi_gate/CatGT/CatGT-A",
"spikeglx/multi_trigger_multi_gate/CatGT/CatGT-B",
Expand Down Expand Up @@ -110,6 +112,65 @@ def test_nidq_digital_channel(self):
atol = 0.001
assert np.allclose(on_diff, 1, atol=atol)

def test_t_start_reading(self):
"""Test that t_start values are correctly read for all streams and segments."""

# Expected t_start values for each stream and segment
expected_t_starts = {
'imec0.ap': {
0: 15.319535472007237,
1: 15.339535431281986,
2: 21.284723325294053,
3: 21.3047232845688
},
'imec1.ap': {
0: 15.319554693264516,
1: 15.339521518106308,
2: 21.284735282142822,
3: 21.304702106984614
},
'imec0.lf': {
0: 15.3191688060872,
1: 15.339168765361949,
2: 21.284356659374016,
3: 21.304356618648765
},
'imec1.lf': {
0: 15.319321358082725,
1: 15.339321516521915,
2: 21.284568614155827,
3: 21.30456877259502
}
}

# Initialize the RawIO
rawio = SpikeGLXRawIO(self.get_local_path("spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI4"))
rawio.parse_header()

# Get list of stream names
stream_names = rawio.header["signal_streams"]["name"]

# Test t_start for each stream and segment
for stream_name, expected_values in expected_t_starts.items():
# Get stream index
stream_index = list(stream_names).index(stream_name)

# Check each segment
for seg_index, expected_t_start in expected_values.items():
actual_t_start = rawio.get_signal_t_start(
block_index=0,
seg_index=seg_index,
stream_index=stream_index
)

# Use numpy.testing for proper float comparison
np.testing.assert_allclose(
actual_t_start,
expected_t_start,
rtol=1e-9,
atol=1e-9,
err_msg=f"Mismatch in t_start for stream '{stream_name}', segment {seg_index}"
)

if __name__ == "__main__":
unittest.main()

0 comments on commit b7a2248

Please sign in to comment.