Skip to content

Update/Fix for MaxWell systems #4018

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 102 additions & 17 deletions src/spikeinterface/extractors/neoextractors/maxwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def install_maxwell_plugin(self, force_download=False):
auto_install_maxwell_hdf5_compression_plugin(force_download=False)


_maxwell_event_dtype = np.dtype([("frame", "int64"), ("state", "int8"), ("time", "float64")])
_maxwell_event_dtype = np.dtype(
[("id", "int8"), ("frame", "uint32"), ("time", "float64"), ("state", "uint32"), ("message", "object")]
)


class MaxwellEventExtractor(BaseEvent):
Expand All @@ -107,9 +109,41 @@ def __init__(self, file_path):
version = int(h5_file["version"][0].decode())
fs = 20000

if version < 20190530:
raise NotImplementedError(f"Version {self.version} not supported")

# get ttl events
bits = h5_file["bits"]
bit_states = bits["bits"]
channel_ids = np.unique(bit_states[bit_states != 0])

channel_ids = np.zeros((0), dtype=np.int8)
if len(bits) > 0:
bit_state = bits["bits"]
channel_ids = np.int8(np.unique(bit_state[bit_state != 0]))
if -1 in channel_ids or 1 in channel_ids:
raise ValueError("TTL bits cannot be -1 or 1.")

# access data_store from h5_file
data_store_keys = [x for x in h5_file["data_store"].keys()]
data_store_keys_id = [
("events" in h5_file["data_store"][x].keys()) and ("groups" in h5_file["data_store"][x].keys())
for x in data_store_keys
]
data_store = data_store_keys[data_store_keys_id.index(True)]

# get stim events
event_raw = h5_file["data_store"][data_store]["events"]
channel_ids_stim = np.int8(np.unique([x[1] for x in event_raw]))
if -1 in channel_ids_stim or 0 in channel_ids_stim:
raise ValueError("Stimulation bits cannot be -1 or 0.")
if len(channel_ids) > 0:
if set(channel_ids) & set(channel_ids_stim):
raise ValueError("TTL and stimulation bits overlap.")
channel_ids = np.concatenate((channel_ids, channel_ids_stim), dtype=np.int8)

# set spike events channel == -1
spike_raw = h5_file["data_store"][data_store]["spikes"]
if len(spike_raw) > 0:
channel_ids = np.concatenate((channel_ids, [-1]), dtype=np.int8)

BaseEvent.__init__(self, channel_ids, structured_dtype=_maxwell_event_dtype)
event_segment = MaxwellEventSegment(h5_file, version, fs)
Expand All @@ -125,22 +159,73 @@ def __init__(self, h5_file, version, fs):
self.fs = fs

def get_events(self, channel_id, start_time, end_time):
if self.version != 20160704:
raise NotImplementedError(f"Version {self.version} not supported")
bits = self.bits

# get ttl events
channel_ids = np.zeros((0), dtype=np.int8)
bit_channel = np.zeros((0), dtype=np.int8)
bit_frameno = np.zeros((0), dtype=np.uint32)
bit_state = np.zeros((0), dtype=np.uint32)
bit_message = np.zeros((0), dtype=object)
if len(bits) > 0:
good_idx = np.where(bits["bits"] != 0)[0]
channel_ids = np.concatenate((channel_ids, np.int8(np.unique(bits["bits"][good_idx]))))
if 1 in channel_ids:
raise ValueError("TTL bits cannot be 1.")
bit_channel = np.concatenate((bit_channel, np.uint8(bits["bits"][good_idx])))
bit_frameno = np.concatenate((bit_frameno, np.uint32(bits["frameno"][good_idx])))
bit_state = np.concatenate((bit_state, np.uint32(bits["bits"][good_idx])))
bit_message = np.concatenate((bit_message, [b"{}\n"] * len(bit_state)), dtype=object)

# access data_store from h5_file
h5_file = self.h5_file
data_store_keys = [x for x in h5_file["data_store"].keys()]
data_store_keys_id = [
("events" in h5_file["data_store"][x].keys()) and ("groups" in h5_file["data_store"][x].keys())
for x in data_store_keys
]
data_store = data_store_keys[data_store_keys_id.index(True)]

# get stim events
event_raw = h5_file["data_store"][data_store]["events"]
channel_ids_stim = np.int8(np.unique([x[1] for x in event_raw]))
stim_arr = np.array(event_raw)
bit_channel_stim = stim_arr["eventtype"]
bit_frameno_stim = stim_arr["frameno"]
bit_state_stim = stim_arr["eventid"]
bit_message_stim = stim_arr["eventmessage"]

# get spike events
spike_raw = h5_file["data_store"][data_store]["spikes"]
if len(spike_raw) > 0:
channel_ids_spike = np.int8([-1])
spike_arr = np.array(spike_raw)
bit_channel_spike = -np.ones(len(spike_arr), dtype=np.int8)
bit_frameno_spike = spike_arr["frameno"]
bit_state_spike = spike_arr["channel"]
bit_message_spike = spike_arr["amplitude"]

# final array in order: spikes, stims, ttl
bit_channel = np.concatenate((bit_channel_spike, bit_channel_stim, bit_channel))
bit_frameno = np.concatenate((bit_frameno_spike, bit_frameno_stim, bit_frameno))
bit_state = np.concatenate((bit_state_spike, bit_state_stim, bit_state))
bit_message = np.concatenate((bit_message_spike, bit_message_stim, bit_message))

first_frame = h5_file["data_store"][data_store]["groups/routed/frame_nos"][0]
bit_frameno = bit_frameno - first_frame

framevals = self.h5_file["sig"][-2:, 0]
first_frame = framevals[1] << 16 | framevals[0]
ttl_frames = self.bits["frameno"] - first_frame
ttl_states = self.bits["bits"]
if channel_id is not None:
bits_channel_idx = np.where((ttl_states == channel_id) | (ttl_states == 0))[0]
ttl_frames = ttl_frames[bits_channel_idx]
ttl_states = ttl_states[bits_channel_idx]
ttl_states[ttl_states == 0] = -1
event = np.zeros(len(ttl_frames), dtype=_maxwell_event_dtype)
event["frame"] = ttl_frames
event["time"] = ttl_frames / self.fs
event["state"] = ttl_states
good_idx = np.where(bit_channel == channel_id)[0]
bit_channel = bit_channel[good_idx]
bit_frameno = bit_frameno[good_idx]
bit_state = bit_state[good_idx]
bit_message = bit_message[good_idx]
event = np.zeros(len(bit_channel), dtype=_maxwell_event_dtype)
event["id"] = bit_channel
event["frame"] = bit_frameno
event["time"] = np.float64(bit_frameno) / self.fs
event["state"] = bit_state
event["message"] = bit_message

if start_time is not None:
event = event[event["time"] >= start_time]
Expand Down
15 changes: 15 additions & 0 deletions src/spikeinterface/extractors/neoextractors/neobaseextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,21 @@ def __init__(
# need neo 0.10.0
signal_channels = self.neo_reader.header["signal_channels"]
mask = signal_channels["stream_id"] == stream_id

# remove all duplicate channel assigments corresponding to different electrodes (channel is a mix of mulitple electrode signals)
mask_id = np.argwhere(mask).flatten()
signal_channels_chan, _ = map(list, zip(*(x.split(" ") for x in signal_channels[mask]["name"])))
[u, u_c] = np.unique(signal_channels_chan, return_counts=True)
for i in np.argwhere(u_c > 1).flatten():
mask[mask_id[np.argwhere(signal_channels_chan == u[i])[:].flatten()]] = False

# remove subsequent duplicated electrodes (single electrode saved to multiple channels)
mask_id = np.argwhere(mask).flatten()
_, signal_channels_elec = map(list, zip(*(x.split(" ") for x in signal_channels[mask]["name"])))
[u, u_c] = np.unique(signal_channels_elec, return_counts=True)
for i in np.argwhere(u_c > 1).flatten():
mask[mask_id[np.argwhere(signal_channels_elec == u[i])[1:].flatten()]] = False

signal_channels = signal_channels[mask]

if use_names_as_ids:
Expand Down
Loading