From 963289685ecfa567146898bbfee7bd4203f315e3 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Tue, 9 Apr 2024 11:52:09 -0500 Subject: [PATCH] fix: robust channel to electrode mapping and ordering --- element_array_ephys/ephys_organoids.py | 53 +++++++++++++++----------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/element_array_ephys/ephys_organoids.py b/element_array_ephys/ephys_organoids.py index d3b95493..332ff22c 100644 --- a/element_array_ephys/ephys_organoids.py +++ b/element_array_ephys/ephys_organoids.py @@ -290,7 +290,7 @@ def make(self, key): try: data = intanrhdreader.load_file(file) except OSError: - raise OSError(f"OS error occured when loading file {file.name}") + raise OSError(f"OS error occurred when loading file {file.name}") if not header: header = data.pop("header") @@ -736,14 +736,14 @@ def make(self, key): sorting_dir / "si_sorting.pkl" ) - unit_peak_channel_map: dict[int, int] = si.get_template_extremum_channel( - we, outputs="index" - ) # {unit: peak_channel_index} + unit_peak_channel: dict[int, int] = si.get_template_extremum_channel( + we, outputs="id" + ) # {unit: peak_channel_id} spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit() # {unit: spike_count} - spikes = si_sorting.to_spike_vector(extremum_channel_inds=unit_peak_channel_map) + spikes = si_sorting.to_spike_vector() # Get electrode & channel info probe_info = (probe.Probe * EphysSessionProbe & key).fetch1() @@ -762,8 +762,13 @@ def make(self, key): ) electrode_query &= f'electrode IN {tuple(probe_info["used_electrodes"])}' - channel_info = electrode_query.fetch(as_dict=True, order_by="electrode") - channel_info: dict[int, dict] = {ch_idx: ch for ch_idx, ch in enumerate(channel_info)} + channel2electrode_map = electrode_query.fetch(as_dict=True) + channel2electrode_map: dict[int, dict] = { + chn.pop("channel_idx"): chn for chn in channel2electrode_map + } # e.g., {0: {'organoid_id': 'O09', + + # reorder channel2electrode_map according to recording channel ids + channel2electrode_map = {chn_id: channel2electrode_map[int(chn_id)] for chn_id in we.channel_ids} # Get unit id to quality label mapping try: @@ -783,7 +788,7 @@ def make(self, key): # Get electrode where peak unit activity is recorded peak_electrode_ind = np.array( [ - channel_info[unit_peak_channel_map[unit_id]]["electrode"] + channel2electrode_map[unit_peak_channel[unit_id]]["electrode"] for unit_id in si_sorting.unit_ids ] ) @@ -791,7 +796,7 @@ def make(self, key): # Get channel depth channel_depth_ind = np.array( [ - channel_info[unit_peak_channel_map[unit_id]]["y_coord"] + channel2electrode_map[unit_peak_channel[unit_id]]["y_coord"] for unit_id in si_sorting.unit_ids ] ) @@ -816,7 +821,7 @@ def make(self, key): units.append( { **key, - **channel_info[unit_peak_channel_map[unit_id]], + **channel2electrode_map[unit_peak_channel[unit_id]], "unit": unit_id, "cluster_quality_label": cluster_quality_label_map.get( unit_id, "n.a." @@ -908,9 +913,9 @@ def make(self, key): ) electrode_query &= f'electrode IN {tuple(probe_info["used_electrodes"])}' - channel_info = electrode_query.fetch(as_dict=True, order_by="channel_idx") - channel_info: dict[int, dict] = { - ch.pop("channel_idx"): key | ch for ch in channel_info + channel2electrode_map = electrode_query.fetch(as_dict=True) + channel2electrode_map: dict[int, dict] = { + chn.pop("channel_idx"): chn for chn in channel2electrode_map } # e.g., {0: {'organoid_id': 'O09', waveform_dir = output_dir / sorter_name / "waveform" @@ -921,8 +926,11 @@ def make(self, key): unit_id_to_peak_channel_map: dict[int, np.ndarray] = ( si.ChannelSparsity.from_best_channels( we, 1, peak_sign="neg" - ).unit_id_to_channel_indices - ) # {unit: peak_channel_index} + ).unit_id_to_channel_ids + ) # {unit: peak_channel_id} + + # reorder channel2electrode_map according to recording channel ids + channel2electrode_map = {chn_id: channel2electrode_map[int(chn_id)] for chn_id in we.channel_ids} # Get mean waveform for each unit from all channels mean_waveforms = we.get_all_templates( @@ -933,23 +941,24 @@ def make(self, key): unit_electrode_waveforms = [] for unit in (CuratedClustering.Unit & key).fetch("KEY", order_by="unit"): + unit_waveforms = we.get_template( + unit_id=unit["unit"], mode="average", force_dense=True + ) # (sample x channel) + peak_chn_idx = list(we.channel_ids).index(unit_id_to_peak_channel_map[unit["unit"]][0]) unit_peak_waveform.append( { **unit, - "peak_electrode_waveform": we.get_template( - unit_id=unit["unit"], mode="average", force_dense=True - )[:, unit_id_to_peak_channel_map[unit["unit"]][0]], + "peak_electrode_waveform": unit_waveforms[:, peak_chn_idx], } ) - unit_electrode_waveforms.extend( [ { **unit, - **channel_info[c], - "waveform_mean": mean_waveforms[unit["unit"] - 1, :, c], + **channel2electrode_map[c], + "waveform_mean": mean_waveforms[unit["unit"] - 1, :, c_idx], } - for c in channel_info + for c_idx, c in enumerate(channel2electrode_map) ] )