Commit

Merge pull request #1584 from h-mayorquin/fix_edf_handle
EDFIO: Alleviate EDF single handle problem
zm711 authored Oct 21, 2024
2 parents 8e180f4 + bc0f518 commit e1a8e16
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions neo/rawio/edfrawio.py
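
The "single handle problem" in the title refers to pyedflib's EdfReader holding an open handle on the EDF file, so that in practice a second EdfReader cannot be created on the same path while the first is still open. Previously the reader was opened in _parse_header and stayed open afterwards, which meant two EDFRawIO instances (or a separate pyedflib reader) could not address the same file. This commit instead caches everything that later accessors need while the file is open, closes the handle at the end of _parse_header, and reopens it only for the duration of each chunk read. Below is a minimal sketch of that open-use-close pattern; it is not part of the commit, read_one_chunk and its arguments are hypothetical, and only the pyedflib calls already used in edfrawio.py are assumed.

import numpy as np
from pyedflib import EdfReader


def read_one_chunk(filename, channel_index, i_start, n_samples):
    # acquire the handle only while data are actually being read
    reader = EdfReader(filename)
    try:
        # pyedflib fills a caller-provided int32 buffer with the digital samples
        buffer = np.empty(n_samples, dtype=np.int32)
        reader.read_digital_signal(channel_index, i_start, n_samples, buffer)
    finally:
        # releasing the handle here lets another reader open the same file
        reader.close()
    return buffer
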
@@ -83,15 +83,14 @@ def _parse_header(self):
         # or continuous EDF+ files ('EDF+C' in header)
         if ("EDF+" in file_version_header) and ("EDF+C" not in file_version_header):
             raise ValueError("Only continuous EDF+ files are currently supported.")
-
-        self.edf_reader = EdfReader(self.filename)
+        self._open_reader()
         # load headers, signal information and
         self.edf_header = self.edf_reader.getHeader()
         self.signal_headers = self.edf_reader.getSignalHeaders()

         # add annotations to header
-        annotations = self.edf_reader.readAnnotations()
-        self.signal_annotations = [[s, d, a] for s, d, a in zip(*annotations)]
+        self._edf_annotations = self.edf_reader.readAnnotations()
+        self.signal_annotations = [[s, d, a] for s, d, a in zip(*self._edf_annotations)]

         # 1 stream = 1 sampling rate
         stream_characteristics = []
@@ -120,7 +119,7 @@ def _parse_header(self):
             signal_channels.append((ch_name, chan_id, sr, dtype, units, gain, offset, stream_id, buffer_id))

         # convert channel index lists to arrays for indexing
-        self.stream_idx_to_chidx = {k: np.array(v) for k, v in self.stream_idx_to_chidx.items()}
+        self.stream_idx_to_chidx = {k: np.asarray(v) for k, v in self.stream_idx_to_chidx.items()}

         signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)

@@ -174,6 +173,15 @@ def _parse_header(self):
                 for array_key in array_keys:
                     array_anno = {array_key: [h[array_key] for h in self.signal_headers]}
                     seg_ann["signals"].append({"__array_annotations__": array_anno})
+
+        # We store the following attributes for rapid access without needing the reader
+
+        self._t_stop = self.edf_reader.datarecord_duration * self.edf_reader.datarecords_in_file
+        # use sample count of first signal in stream
+        self._stream_index_samples = {stream_index : self.edf_reader.getNSamples()[chidx][0] for stream_index, chidx in self.stream_idx_to_chidx.items()}
+        self._number_of_events = len(self.edf_reader.readAnnotations()[0])
+
+        self.close()

     def _get_stream_channels(self, stream_index):
         return self.header["signal_channels"][self.stream_idx_to_chidx[stream_index]]
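
The block added above is the core of the fix: the total duration, per-stream sample counts, event count and annotations are snapshotted into plain attributes while the file is still open, so accessors such as _segment_t_stop and _get_signal_size (changed in the next hunk) no longer need a live reader. A standalone sketch of the idea follows, with a hypothetical snapshot_metadata helper and dict standing in for the cached attributes; only pyedflib members already used above are assumed.

from pyedflib import EdfReader


def snapshot_metadata(filename):
    reader = EdfReader(filename)
    try:
        meta = {
            "t_stop": reader.datarecord_duration * reader.datarecords_in_file,
            "n_samples_per_channel": reader.getNSamples(),
            "n_events": len(reader.readAnnotations()[0]),  # annotation onsets
        }
    finally:
        # nothing after parsing needs the handle any more
        reader.close()
    return meta
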
@@ -183,14 +191,11 @@ def _segment_t_start(self, block_index, seg_index):
         return 0.0 # in seconds

     def _segment_t_stop(self, block_index, seg_index):
-        t_stop = self.edf_reader.datarecord_duration * self.edf_reader.datarecords_in_file
         # this must return an float scale in second
-        return t_stop
+        return self._t_stop

     def _get_signal_size(self, block_index, seg_index, stream_index):
-        chidx = self.stream_idx_to_chidx[stream_index][0]
-        # use sample count of first signal in stream
-        return self.edf_reader.getNSamples()[chidx]
+        return self._stream_index_samples[stream_index]

     def _get_signal_t_start(self, block_index, seg_index, stream_index):
         return 0.0 # EDF does not provide temporal offset information
@@ -219,12 +224,13 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, strea

         # load data into numpy array buffer
         data = []
+        self._open_reader()
         for i, channel_idx in enumerate(selected_channel_idxs):
             # use int32 for compatibility with pyedflib
             buffer = np.empty(n, dtype=np.int32)
             self.edf_reader.read_digital_signal(channel_idx, i_start, n, buffer)
             data.append(buffer)
-
+        self._close_reader()
         # downgrade to int16 as this is what is used in the edf file format
         # use fortran (column major) order to be more efficient after transposing
         data = np.asarray(data, dtype=np.int16, order="F")
@@ -247,11 +253,11 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, spike_channel_index,
         return None

     def _event_count(self, block_index, seg_index, event_channel_index):
-        return len(self.edf_reader.readAnnotations()[0])
+        return self._number_of_events

     def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
         # these time should be already in seconds
-        timestamps, durations, labels = self.edf_reader.readAnnotations()
+        timestamps, durations, labels = self._edf_annotations
         if t_start is None:
             t_start = self.segment_t_start(block_index, seg_index)
         if t_stop is None:
@@ -281,6 +287,9 @@ def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index):
     def _rescale_epoch_duration(self, raw_duration, dtype, event_channel_index):
         return np.asarray(raw_duration, dtype=dtype)

+    def _open_reader(self):
+        self.edf_reader = EdfReader(self.filename)
+
     def __enter__(self):
         return self
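
With the handle released at the end of _parse_header and again after every chunk read, two readers can now work with the same file without colliding. A hedged usage sketch, assuming neo's public EDFRawIO API (parse_header, get_signal_size, get_analogsignal_chunk) and a placeholder file path:

from neo.rawio import EDFRawIO

reader = EDFRawIO(filename="recording.edf")  # placeholder path
reader.parse_header()  # opens the file, caches the metadata shown above, then closes the handle
n = reader.get_signal_size(block_index=0, seg_index=0, stream_index=0)
chunk = reader.get_analogsignal_chunk(
    block_index=0, seg_index=0, i_start=0, i_stop=min(n, 1000), stream_index=0
)  # the file is reopened only for the duration of this call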

