diff --git a/pykilosort/gui/__init__.py b/pykilosort/gui/__init__.py index a1f2549..47c5457 100644 --- a/pykilosort/gui/__init__.py +++ b/pykilosort/gui/__init__.py @@ -10,7 +10,7 @@ from .probe_view_box import ProbeViewBox from .run_box import RunBox from .settings_box import SettingsBox -from .sorter import KiloSortWorker, filter_and_whiten, find_good_channels +from .sorter import filter_and_whiten, get_predicted_traces, KiloSortWorker from .sanity_plots import SanityPlotWidget from .main import KiloSortGUI diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index c73b8c2..859a5bd 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -2,6 +2,7 @@ import typing as t import pyqtgraph as pg from cupy import asnumpy +from datetime import datetime from pykilosort.gui.logger import setup_logger from pykilosort.gui.minor_gui_elements import controls_popup_text from pykilosort.gui.palettes import COLORMAP_COLORS @@ -13,8 +14,8 @@ class DataViewBox(QtWidgets.QGroupBox): - channelChanged = QtCore.pyqtSignal() - modeChanged = QtCore.pyqtSignal(str) + channelChanged = QtCore.pyqtSignal(int, int) + modeChanged = QtCore.pyqtSignal(str, int) def __init__(self, parent): QtWidgets.QGroupBox.__init__(self, parent=parent) @@ -225,48 +226,49 @@ def on_views_clicked(self): @QtCore.pyqtSlot(int) def on_wheel_scroll(self, direction): - if self.gui.context is not None: + if self.context_set(): self.shift_current_time(direction) @QtCore.pyqtSlot(int) def on_wheel_scroll_plus_control(self, direction): - if self.gui.context is not None: - if self.traces_button.isChecked(): + if self.context_set(): + if self.traces_mode_active(): self.shift_primary_channel(direction) else: self.change_displayed_channel_count(direction) @QtCore.pyqtSlot(int) def on_wheel_scroll_plus_shift(self, direction): - if self.gui.context is not None: + if self.context_set(): self.change_plot_range(direction) @QtCore.pyqtSlot(int) def on_wheel_scroll_plus_alt(self, direction): - if self.gui.context is not None: + if self.context_set(): self.change_plot_scaling(direction) + @QtCore.pyqtSlot() def toggle_mode_from_click(self): - if self.traces_button.isChecked(): - self.modeChanged.emit("traces") + if self.traces_mode_active(): + self.modeChanged.emit("traces", self.get_currently_displayed_channel_count()) self.view_buttons_group.setExclusive(False) self.update_plot() - if self.colormap_button.isChecked(): - self.modeChanged.emit("colormap") + if self.colormap_mode_active(): + self.modeChanged.emit("colormap", self.get_currently_displayed_channel_count()) self._traces_to_colormap_toggle() self.update_plot() def toggle_mode(self): - if self.colormap_button.isChecked(): + if self.colormap_mode_active(): self.traces_button.toggle() - self.modeChanged.emit("traces") + self.modeChanged.emit("traces", self.get_currently_displayed_channel_count()) self.view_buttons_group.setExclusive(False) self.update_plot() - elif self.traces_button.isChecked(): + elif self.traces_mode_active(): self.colormap_button.toggle() - self.modeChanged.emit("colormap") + self.modeChanged.emit("colormap", self.get_currently_displayed_channel_count()) self._traces_to_colormap_toggle() self.update_plot() @@ -301,9 +303,18 @@ def _traces_to_colormap_toggle(self): self.view_buttons_group.setExclusive(True) + def traces_mode_active(self): + return self.traces_button.isChecked() + + def colormap_mode_active(self): + return self.colormap_button.isChecked() + + def context_set(self): + return self.gui.context is not None + def change_primary_channel(self, channel): self.primary_channel = channel - self.channelChanged.emit() + self.channelChanged.emit(self.primary_channel, self.get_currently_displayed_channel_count()) self.update_plot() def shift_primary_channel(self, shift): @@ -312,11 +323,11 @@ def shift_primary_channel(self, shift): total_channels = self.get_total_channels() if (0 <= primary_channel < total_channels) and total_channels is not None: self.primary_channel = primary_channel - self.channelChanged.emit() + self.channelChanged.emit(self.primary_channel, self.get_currently_displayed_channel_count()) self.update_plot() def get_currently_displayed_channel_count(self): - if self.traces_button.isChecked(): + if self.traces_mode_active(): return self.channels_displayed_traces else: count = self.channels_displayed_colormap @@ -325,7 +336,7 @@ def get_currently_displayed_channel_count(self): return count def set_currently_displayed_channel_count(self, count): - if self.traces_button.isChecked(): + if self.traces_mode_active(): self.channels_displayed_traces = count else: self.channels_displayed_colormap = count @@ -334,22 +345,25 @@ def get_total_channels(self): return self.gui.probe_view_box.total_channels def change_displayed_channel_count(self, direction): + current_channel = self.primary_channel total_channels = self.get_total_channels() current_count = self.get_currently_displayed_channel_count() new_count = current_count + (direction * 5) - if 0 < new_count <= total_channels: + if (current_channel + new_count) <= total_channels: self.set_currently_displayed_channel_count(new_count) + self.channelChanged.emit(self.primary_channel, new_count) self.refresh_plot_on_displayed_channel_count_change() elif new_count <= 0 and current_count != 1: self.set_currently_displayed_channel_count(1) + self.channelChanged.emit(self.primary_channel, 1) self.refresh_plot_on_displayed_channel_count_change() - elif new_count > total_channels: - self.set_currently_displayed_channel_count(total_channels) + elif (current_channel + new_count) > total_channels: + self.set_currently_displayed_channel_count(total_channels - current_channel) + self.channelChanged.emit(self.primary_channel, total_channels) self.refresh_plot_on_displayed_channel_count_change() def refresh_plot_on_displayed_channel_count_change(self): - self.channelChanged.emit() self.plot_item.clear() self.create_plot_items() self.update_plot() @@ -371,14 +385,14 @@ def change_plot_range(self, direction): self.update_plot() def change_plot_scaling(self, direction): - if self.traces_button.isChecked(): + if self.traces_mode_active(): scale_factor = self.scale_factor * (1.1 ** direction) if 0.1 < scale_factor < 10.0: self.scale_factor = scale_factor self.update_plot() - if self.colormap_button.isChecked(): + if self.colormap_mode_active(): colormap_min = self.colormap_min + (direction * 0.05) colormap_max = self.colormap_max - (direction * 0.05) if 0.0 <= colormap_min < colormap_max <= 1.0: @@ -390,15 +404,9 @@ def change_plot_scaling(self, direction): self.update_plot() - def change_channel_display(self, direction): - if self.traces_button.isChecked(): - self.shift_primary_channel(direction) - else: - self.change_displayed_channel_count(direction) - def scene_clicked(self, ev): - if self.gui.context is not None: - if self.traces_button.isChecked(): + if self.context_set(): + if self.traces_mode_active(): x_pos = ev.pos().x() else: x_pos = self.colormap_image.mapFromScene(ev.pos()).x() @@ -411,7 +419,7 @@ def scene_clicked(self, ev): self.shift_current_time(direction=-1) def seek_clicked(self, ev): - if self.gui.context is not None: + if self.context_set(): new_time = self.seek_view_box.mapSceneToView(ev.pos()).x() seek_range_min = self.seek_range[0] seek_range_max = self.seek_range[1] @@ -445,7 +453,7 @@ def change_sorting_status(self, status_dict): self.enable_view_buttons() def enable_view_buttons(self): - if self.colormap_button.isChecked(): + if self.colormap_mode_active(): if self.prediction_button.isChecked() or self.residual_button.isChecked(): self.raw_button.click() else: @@ -532,13 +540,13 @@ def add_traces_to_plot_items( ): """ Update plot items with traces. - + Loops over traces and plots each trace using the setData() method of pyqtgraph's PlotCurveItem. The color of the trace depends on the mode requested (raw, whitened, prediction, residual). Bad channels are plotted in a different color. Each trace is also scaled by a certain factor defined in self.traces_scaling_factor. - + Parameters ---------- traces : numpy.ndarray @@ -549,22 +557,17 @@ def add_traces_to_plot_items( view : str One of "raw", "whitened", "prediction" and "residual" views """ - for c, channel in enumerate( - range( - self.primary_channel, - self.primary_channel + self.channels_displayed_traces, - ) - ): + for c, good in enumerate(good_channels): try: curve = self.traces_plot_items[view][c] color = ( self.traces_curve_color[view] - if good_channels[channel] + if good else self.bad_channel_color ) curve.setPen(color=color, width=1) curve.setData( - traces.T[channel] * + traces[:, c] * self.scale_factor * self.traces_scaling_factor[view] ) @@ -622,11 +625,16 @@ def get_whitened_traces( raw_data=raw_data, params=params, probe=probe, nSkipCov=nSkipCov ) + good_channels = intermediate.igood.ravel() \ + if "igood" in intermediate \ + else np.ones_like(probe.chanMapBackup, dtype=bool) + whitened_traces = filter_and_whiten( raw_traces=raw_traces, params=params, probe=probe, whitening_matrix=self.whitening_matrix, + good_channels=good_channels, ) return whitened_traces @@ -635,7 +643,7 @@ def update_plot(self, context=None): if context is None: context = self.gui.context - if context is not None: + if context is not None: # since context may still be None if self.colormap_image is not None: self.plot_item.removeItem(self.colormap_image) self.colormap_image = None @@ -644,7 +652,10 @@ def update_plot(self, context=None): probe = context.probe raw_data = context.raw_data intermediate = context.intermediate - good_channels = intermediate.igood.ravel() + try: + good_channels = intermediate.igood.ravel() + except AttributeError: + good_channels = np.ones_like(probe.chanMapBackup, dtype=bool) sample_rate = raw_data.sample_rate @@ -655,20 +666,39 @@ def update_plot(self, context=None): self.data_view_widget.setXRange(0, time_range, padding=0.0) self.data_view_widget.setLimits(xMin=0, xMax=time_range) - if self.traces_button.isChecked(): + orig_chan_map = probe.chanMapBackup + max_channels = np.size(orig_chan_map) + + start_channel = self.primary_channel + active_channels = self.get_currently_displayed_channel_count() + if active_channels is None: + active_channels = self.get_total_channels() + end_channel = start_channel + active_channels + + # good channels after start_channel to display + to_display = np.arange(start_channel, end_channel, dtype=int) + to_display = to_display[to_display < max_channels] + + raw_traces = raw_data[start_time:end_time] + + if self.traces_mode_active(): self._update_traces(params=params, probe=probe, raw_data=raw_data, + raw_traces=raw_traces, + to_display=to_display, intermediate=intermediate, good_channels=good_channels, start_time=start_time, end_time=end_time, ) - if self.colormap_button.isChecked(): + if self.colormap_mode_active(): self._update_colormap(params=params, probe=probe, raw_data=raw_data, + raw_traces=raw_traces, + to_display=to_display, intermediate=intermediate, good_channels=good_channels, start_time=start_time, @@ -686,15 +716,24 @@ def update_plot(self, context=None): self.data_view_widget.autoRange() - def _update_traces(self, params, probe, raw_data, intermediate, good_channels, start_time, end_time): + def _update_traces( + self, + params, + probe, + raw_data, + raw_traces, + to_display, + intermediate, + good_channels, + start_time, + end_time + ): self.hide_inactive_traces() - raw_traces = raw_data[start_time:end_time] - if self.raw_button.isChecked(): self.add_traces_to_plot_items( - traces=raw_traces, - good_channels=good_channels, + traces=raw_traces[:, to_display], + good_channels=good_channels[to_display], view="raw", ) @@ -713,9 +752,11 @@ def _update_traces(self, params, probe, raw_data, intermediate, good_channels, s else: whitened_traces = self.whitened_traces + processed_traces = np.zeros_like(raw_traces, dtype=np.int16) + processed_traces[:, good_channels] = whitened_traces self.add_traces_to_plot_items( - traces=whitened_traces, - good_channels=good_channels, + traces=processed_traces[:, to_display], + good_channels=good_channels[to_display], view="whitened", ) @@ -731,9 +772,11 @@ def _update_traces(self, params, probe, raw_data, intermediate, good_channels, s else: prediction_traces = self.prediction_traces + processed_traces = np.zeros_like(raw_traces, dtype=np.int16) + processed_traces[:, good_channels] = prediction_traces self.add_traces_to_plot_items( - traces=prediction_traces, - good_channels=good_channels, + traces=processed_traces[:, to_display], + good_channels=good_channels[to_display], view="prediction", ) @@ -773,27 +816,35 @@ def _update_traces(self, params, probe, raw_data, intermediate, good_channels, s else: residual_traces = self.residual_traces + processed_traces = np.zeros_like(raw_traces, dtype=np.int16) + processed_traces[:, good_channels] = residual_traces self.add_traces_to_plot_items( - traces=residual_traces, - good_channels=good_channels, + traces=processed_traces[:, to_display], + good_channels=good_channels[to_display], view="residual", ) - def _update_colormap(self, params, probe, raw_data, intermediate, good_channels, start_time, end_time): + def _update_colormap( + self, + params, + probe, + raw_data, + raw_traces, + to_display, + intermediate, + good_channels, + start_time, + end_time + ): self.hide_traces() - raw_traces = raw_data[start_time:end_time] - - start_channel = self.primary_channel - displayed_channels = self.channels_displayed_colormap - if displayed_channels is None: - displayed_channels = self.get_total_channels() - end_channel = start_channel + displayed_channels + processed_traces = np.zeros_like(raw_traces, dtype=np.int16) if self.raw_button.isChecked(): colormap_min, colormap_max = -32.0, 32.0 + self.add_image_to_plot( - raw_traces[:, start_channel:end_channel], + raw_traces[:, to_display], colormap_min, colormap_max, ) @@ -813,8 +864,9 @@ def _update_colormap(self, params, probe, raw_data, intermediate, good_channels, else: whitened_traces = self.whitened_traces colormap_min, colormap_max = -4.0, 4.0 + processed_traces[:, good_channels] = whitened_traces self.add_image_to_plot( - whitened_traces[:, start_channel:end_channel], + processed_traces[:, to_display], colormap_min, colormap_max, ) @@ -831,8 +883,9 @@ def _update_colormap(self, params, probe, raw_data, intermediate, good_channels, else: prediction_traces = self.prediction_traces colormap_min, colormap_max = -4.0, 4.0 + processed_traces[:, good_channels] = prediction_traces self.add_image_to_plot( - prediction_traces[:, start_channel:end_channel], + processed_traces[:, to_display], colormap_min, colormap_max, ) @@ -873,8 +926,9 @@ def _update_colormap(self, params, probe, raw_data, intermediate, good_channels, else: residual_traces = self.residual_traces colormap_min, colormap_max = -4.0, 4.0 + processed_traces[:, good_channels] = residual_traces self.add_image_to_plot( - residual_traces[:, start_channel:end_channel], + processed_traces[:, to_display], colormap_min, colormap_max, ) diff --git a/pykilosort/gui/main.py b/pykilosort/gui/main.py index 9c9e222..98cf7e6 100644 --- a/pykilosort/gui/main.py +++ b/pykilosort/gui/main.py @@ -16,7 +16,7 @@ ) from pykilosort.gui.logger import setup_logger from pykilosort.params import KilosortParams -from pykilosort.utils import Context +from pykilosort.utils import Context, copy_bunch, extend_probe from PyQt5 import QtCore, QtGui, QtWidgets logger = setup_logger(__name__) @@ -163,10 +163,11 @@ def setup(self): self.run_box.updateContext.connect(self.update_context) self.run_box.disableInput.connect(self.disable_all_input) self.run_box.sortingStepStatusUpdate.connect(self.update_sorting_status) + self.run_box.updateProbeView.connect(self.update_probe_view) def change_channel_display(self, direction): if self.context is not None: - self.data_view_box.change_channel_display(direction) + self.data_view_box.shift_primary_channel(direction) def shift_data(self, time_shift): if self.context is not None: @@ -229,7 +230,7 @@ def set_parameters(self): self.prepare_for_new_context() self.load_raw_data() self.setup_context() - self.setup_probe_view() + self.update_probe_view() self.setup_data_view() self.update_run_box() @@ -252,38 +253,27 @@ def setup_data_view(self): self.data_view_box.create_plot_items() self.data_view_box.update_plot(self.context) - def update_context_with_good_channels(self): - worker = KiloSortWorker( - self.context, self.data_path, self.results_directory, ["goodchannels"] - ) - - worker.foundGoodChannels.connect(self.update_context) - - QtWidgets.QApplication.setOverrideCursor(QtGui.QCursor(QtCore.Qt.WaitCursor)) - worker.start() - while worker.isRunning(): - QtWidgets.QApplication.processEvents() - QtWidgets.QApplication.restoreOverrideCursor() - def setup_context(self): context_path = Path( os.path.join(self.working_directory, ".kilosort", self.raw_data.name) ) self.context = Context(context_path=context_path) - self.context.probe = self.probe_layout + probe_layout = self.probe_layout + probe_layout.Nchan = self.num_channels + self.context.probe = extend_probe(probe_layout) + self.context.raw_probe = extend_probe(copy_bunch(probe_layout)) self.context.params = self.params self.context.raw_data = self.raw_data self.context.load() - self.update_context_with_good_channels() - @QtCore.pyqtSlot(object) def update_context(self, context): self.context = context - def setup_probe_view(self): + @QtCore.pyqtSlot() + def update_probe_view(self): self.probe_view_box.set_layout(self.context) def update_run_box(self): @@ -294,6 +284,7 @@ def update_run_box(self): @QtCore.pyqtSlot(dict) def update_sorting_status(self, status_dict): self.data_view_box.change_sorting_status(status_dict) + self.probe_view_box.change_sorting_status(status_dict) def get_context(self): return self.context diff --git a/pykilosort/gui/minor_gui_elements.py b/pykilosort/gui/minor_gui_elements.py index b4dd080..b6fd29d 100644 --- a/pykilosort/gui/minor_gui_elements.py +++ b/pykilosort/gui/minor_gui_elements.py @@ -199,6 +199,7 @@ def construct_probe(self): probe.yc = self.y_coords probe.kcoords = self.k_coords probe.chanMap = self.channel_map + probe.chanMapBackup = probe.chanMap.copy() probe.bad_channels = self.bad_channels probe.NchanTOT = len(self.x_coords) diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index a4223e7..960031e 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -33,6 +33,12 @@ def __init__(self, parent): self.channel_map_dict = {} self.good_channels = None + self.sorting_status = { + "preprocess": False, + "spikesort": False, + "export": False + } + self.configuration = { "active_channel": "g", "good_channel": "b", @@ -40,7 +46,7 @@ def __init__(self, parent): } self.active_data_view_mode = "colormap" - self.primary_channel = None + self.primary_channel = 0 self.active_channels = [] def setup(self): @@ -56,20 +62,13 @@ def setup(self): layout.addWidget(self.probe_view, 95) self.setLayout(layout) - def set_active_channels(self): - if self.active_data_view_mode == "traces": - displayed_channels = self.gui.data_view_box.channels_displayed_traces - else: - displayed_channels = self.gui.data_view_box.channels_displayed_colormap - if displayed_channels is None: - displayed_channels = self.total_channels - + def set_active_channels(self, channels_displayed): primary_channel = self.primary_channel channel_map = np.array(self.channel_map) primary_channel_position = int(np.where(channel_map == primary_channel)[0]) end_channel_position = np.where( - channel_map == primary_channel + displayed_channels + channel_map == primary_channel + channels_displayed )[0] # prevent the last displayed channel would be set as the end channel in the case that # `primary_channel + displayed_channels` exceeds the total number of channels in the channel map @@ -83,7 +82,13 @@ def set_active_channels(self): def set_layout(self, context): self.probe_view.clear() - self.set_active_layout(context.probe, context.intermediate.igood) + probe = context.raw_probe + try: + good_channels = context.intermediate.igood + except AttributeError: + good_channels = None + + self.set_active_layout(probe, good_channels) self.update_probe_view() @@ -97,7 +102,10 @@ def set_active_layout(self, probe, good_channels=None): self.total_channels = self.active_layout.NchanTOT self.channel_map = self.active_layout.chanMap if good_channels is None: - self.good_channels = np.ones_like(self.channel_map, dtype=bool) + self.good_channels = np.ones_like( + self.active_layout.chanMapBackup, + dtype=bool + ) else: self.good_channels = good_channels @@ -110,16 +118,15 @@ def on_points_clicked(self, points): channel = self.channel_map[index] self.channelSelected.emit(channel) - def synchronize_data_view_mode(self, string): - old_mode = self.active_data_view_mode - self.active_data_view_mode = string - - if old_mode != self.active_data_view_mode and self.primary_channel is not None: + @QtCore.pyqtSlot(str, int) + def synchronize_data_view_mode(self, mode: str, channels_displayed: int): + if self.active_data_view_mode != mode: self.probe_view.clear() - self.update_probe_view() + self.update_probe_view(channels_displayed=channels_displayed) + self.active_data_view_mode = mode - def synchronize_primary_channel(self): - self.primary_channel = self.gui.data_view_box.primary_channel + def change_sorting_status(self, status_dict): + self.sorting_status = status_dict def generate_spots_list(self): spots = [] @@ -129,7 +136,7 @@ def generate_spots_list(self): for ind, (x_pos, y_pos) in enumerate(zip(self.xc, self.yc)): pos = (x_pos, y_pos) good_channel = self.good_channels[ind] - is_active = np.isin(ind, self.active_channels) + is_active = np.isin(ind, self.active_channels).tolist() if not good_channel: color = self.configuration["bad_channel"] elif good_channel and is_active: @@ -144,10 +151,15 @@ def generate_spots_list(self): return spots - @QtCore.pyqtSlot() - def update_probe_view(self): - self.synchronize_primary_channel() - self.set_active_channels() + @QtCore.pyqtSlot(int, int) + def update_probe_view(self, primary_channel=None, channels_displayed=None): + if primary_channel is not None: + self.primary_channel = primary_channel + + if channels_displayed is None: + channels_displayed = self.total_channels + + self.set_active_channels(channels_displayed) self.create_plot() @QtCore.pyqtSlot(object) @@ -168,7 +180,7 @@ def reset(self): self.clear_plot() self.reset_current_probe_layout() self.reset_active_data_view_mode() - self.primary_channel = None + self.primary_channel = 0 def reset_active_data_view_mode(self): self.active_data_view_mode = "colormap" diff --git a/pykilosort/gui/run_box.py b/pykilosort/gui/run_box.py index 515149c..e6dc39a 100644 --- a/pykilosort/gui/run_box.py +++ b/pykilosort/gui/run_box.py @@ -6,6 +6,7 @@ class RunBox(QtWidgets.QGroupBox): updateContext = QtCore.pyqtSignal(object) + updateProbeView = QtCore.pyqtSignal() sortingStepStatusUpdate = QtCore.pyqtSignal(dict) disableInput = QtCore.pyqtSignal(bool) @@ -117,6 +118,7 @@ def set_sorting_step_status(self, step, status): @QtCore.pyqtSlot(object) def finished_preprocess(self, context): self.updateContext.emit(context) + self.updateProbeView.emit() self.set_sorting_step_status("preprocess", True) @QtCore.pyqtSlot(object) diff --git a/pykilosort/gui/sorter.py b/pykilosort/gui/sorter.py index e71e77c..9437f61 100644 --- a/pykilosort/gui/sorter.py +++ b/pykilosort/gui/sorter.py @@ -2,36 +2,11 @@ import numpy as np from numba import jit from pykilosort.main import run_export, run_preprocess, run_spikesort -from pykilosort.preprocess import get_good_channels, get_whitening_matrix, gpufilter +from pykilosort.preprocess import gpufilter from PyQt5 import QtCore -def find_good_channels(context): - params = context.params - probe = context.probe - raw_data = context.raw_data - intermediate = context.intermediate - - if "igood" not in intermediate: - if params.minfr_goodchannels > 0: # discard channels that have very few spikes - # determine bad channels - with context.time("good_channels"): - intermediate.igood = get_good_channels( - raw_data=raw_data, probe=probe, params=params - ) - intermediate.igood = intermediate.igood.ravel() - # Cache the result. - context.write(igood=intermediate.igood) - - else: - intermediate.igood = np.ones_like(probe.chanMap, dtype=bool) - - probe.Nchan = len(probe.chanMap) - context.probe = probe - return context - - -def filter_and_whiten(raw_traces, params, probe, whitening_matrix): +def filter_and_whiten(raw_traces, params, probe, whitening_matrix, good_channels): sample_rate = params.fs high_pass_freq = params.fshigh low_pass_freq = params.fslow @@ -43,7 +18,7 @@ def filter_and_whiten(raw_traces, params, probe, whitening_matrix): filtered_data = gpufilter( buff=cp.asarray(raw_traces, dtype=np.float32), - chanMap=probe.chanMap, + chanMap=probe.chanMapBackup[good_channels], fs=sample_rate, fslow=low_pass_freq, fshigh=high_pass_freq, @@ -57,20 +32,6 @@ def filter_and_whiten(raw_traces, params, probe, whitening_matrix): return whitened_array.get() -def get_whitened_traces(raw_data, probe, params, whitening_matrix): - if whitening_matrix is None: - whitening_matrix = get_whitening_matrix( - raw_data=raw_data, probe=probe, params=params, nSkipCov=100 - ) - whitened_traces = filter_and_whiten( - raw_traces=raw_data, - params=params, - probe=probe, - whitening_matrix=whitening_matrix, - ) - return whitened_traces, whitening_matrix - - @jit(nopython=True) def get_predicted_traces( matrix_U: np.ndarray, @@ -122,7 +83,6 @@ def get_predicted_traces( class KiloSortWorker(QtCore.QThread): - foundGoodChannels = QtCore.pyqtSignal(object) finishedPreprocess = QtCore.pyqtSignal(object) finishedSpikesort = QtCore.pyqtSignal(object) finishedAll = QtCore.pyqtSignal(object) @@ -162,7 +122,3 @@ def run(self): if "export" in self.steps: run_export(self.context, self.data_path, self.output_directory) self.finishedAll.emit(self.context) - - if "goodchannels" in self.steps: - self.context = find_good_channels(self.context) - self.foundGoodChannels.emit(self.context) diff --git a/pykilosort/utils.py b/pykilosort/utils.py index df5f7d2..a16eea8 100644 --- a/pykilosort/utils.py +++ b/pykilosort/utils.py @@ -394,6 +394,7 @@ def load_probe(probe_path): probe.yc.append([pos[c][1] for c in ch]) probe.kcoords.append([cg for c in ch]) probe.chanMap = np.concatenate(probe.chanMap).ravel().astype(np.int32) + probe.chanMapBackup = probe.chanMap.copy() probe.xc = np.concatenate(probe.xc) probe.yc = np.concatenate(probe.yc) probe.kcoords = np.concatenate(probe.kcoords) @@ -406,6 +407,7 @@ def load_probe(probe_path): probe.yc = mat['ycoords'].ravel().astype(np.float64) probe.kcoords = mat.get('kcoords', np.zeros(nc)).ravel().astype(np.float64) probe.chanMap = (mat['chanMap'] - 1).ravel().astype(np.int32) # NOTE: 0-indexing in Python + probe.chanMapBackup = probe.chanMap.copy() probe.NchanTOT = len(probe.chanMap) # NOTE: should match the # of columns in the raw data for n in _required_keys: @@ -444,6 +446,79 @@ def create_prb(probe): return probe_prb +def extend_probe( + probe_layout: Bunch +) -> Bunch: + """ + Extend probe layout to account for extra num_channels. + + The probe layout selected by the user may have a different + number of channels than the requested number of channels in + the dataset. The function attempts to smartly extend the + layout of the probe to match the requested number of + channels. + + In case the requested number of channels is less than the + total channels on the layout, the original probe layout + is returned. + + Parameters + ---------- + probe_layout : Bunch + Input probe layout which might have to be extended. + + Returns + ------- + probe_layout : Bunch + Possibly extended probe layout. + + """ + if len(probe_layout["xc"]) >= probe_layout["Nchan"]: + # if the requested number of channels is less than the number of + # channels on the probe, return original probe + return probe_layout + else: + n_channels = probe_layout["Nchan"] + xc = probe_layout["xc"] + yc = probe_layout["yc"] + + # the assumption here is that the probe layout has a repetitive pattern + unique_x = np.sort(np.unique(probe_layout["xc"])) # unique values for the x-axis + unique_y = np.sort(np.unique(probe_layout["yc"])) # unique values for the y-axis + + kcoords = probe_layout["kcoords"] + chan_map = probe_layout["chanMap"] + + new_channels = n_channels - len(xc) # number of new channels to be added + + # create new properties + # this will probably break if new_channels > len(unique_x/y) + append_x = unique_x[-new_channels] + append_y = unique_y[-new_channels] + + append_chan_map = np.array([ + i for i in np.arange(chan_map[-1], + chan_map[-1] + new_channels) + ]) + 1 # to account for 0-ordering + append_kcoords = np.zeros(new_channels) + + # append new properties to existing properties + new_xc = np.append(xc, append_x) + new_yc = np.append(yc, append_y) + new_chan_map = np.append(chan_map, append_chan_map) + new_kcoords = np.append(kcoords, append_kcoords) + + # save new properties into probe layout + probe_layout["xc"] = new_xc + probe_layout["yc"] = new_yc + probe_layout["chanMap"] = new_chan_map + probe_layout["chanMapBackup"] = new_chan_map.copy() + probe_layout["kcoords"] = new_kcoords + probe_layout["NchanTOT"] = len(new_chan_map) + + return probe_layout + + def plot_dissimilarity_matrices(ccb, ccbsort, plot_widget): ccb = cp.asnumpy(ccb) ccbsort = cp.asnumpy(ccbsort)