From e225356a056c049ddeb227d2a0d317997f83d0f2 Mon Sep 17 00:00:00 2001 From: Shashwat Sridhar Date: Thu, 26 Nov 2020 16:38:33 -0500 Subject: [PATCH 01/42] remove calculation of good channels during data load --- pykilosort/gui/main.py | 27 +++++++++------------------ pykilosort/gui/probe_view_box.py | 15 ++++++++------- pykilosort/gui/run_box.py | 2 ++ pykilosort/gui/sorter.py | 32 +------------------------------- 4 files changed, 20 insertions(+), 56 deletions(-) diff --git a/pykilosort/gui/main.py b/pykilosort/gui/main.py index e682ce4..8260523 100644 --- a/pykilosort/gui/main.py +++ b/pykilosort/gui/main.py @@ -163,6 +163,7 @@ def setup(self): self.run_box.updateContext.connect(self.update_context) self.run_box.disableInput.connect(self.disable_all_input) self.run_box.sortingStepStatusUpdate.connect(self.update_sorting_status) + self.run_box.updateProbeView.connect(self.update_probe_view) def change_channel_display(self, direction): if self.context is not None: @@ -229,7 +230,7 @@ def set_parameters(self): self.prepare_for_new_context() self.load_raw_data() self.setup_context() - self.setup_probe_view() + self.update_probe_view() self.setup_data_view() self.update_run_box() @@ -252,38 +253,28 @@ def setup_data_view(self): self.data_view_box.create_plot_items() self.data_view_box.update_plot(self.context) - def update_context_with_good_channels(self): - worker = KiloSortWorker( - self.context, self.data_path, self.results_directory, ["goodchannels"] - ) - - worker.foundGoodChannels.connect(self.update_context) - - QtWidgets.QApplication.setOverrideCursor(QtGui.QCursor(QtCore.Qt.WaitCursor)) - worker.start() - while worker.isRunning(): - QtWidgets.QApplication.processEvents() - QtWidgets.QApplication.restoreOverrideCursor() - def setup_context(self): context_path = Path( os.path.join(self.working_directory, ".kilosort", self.raw_data.name) ) self.context = Context(context_path=context_path) - self.context.probe = self.probe_layout + probe_layout = self.probe_layout + probe_layout.Nchan = len(probe_layout.chanMap) + self.context.probe = probe_layout self.context.params = self.params self.context.raw_data = self.raw_data - self.context.load() + self.context.intermediate.igood = np.ones_like(probe_layout.chanMap, dtype=bool) - self.update_context_with_good_channels() + self.context.load() @QtCore.pyqtSlot(object) def update_context(self, context): self.context = context - def setup_probe_view(self): + @QtCore.pyqtSlot() + def update_probe_view(self): self.probe_view_box.set_layout(self.context) def update_run_box(self): diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index a4223e7..63f8686 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -83,11 +83,14 @@ def set_active_channels(self): def set_layout(self, context): self.probe_view.clear() - self.set_active_layout(context.probe, context.intermediate.igood) + probe = context.probe + good_channels = context.intermediate.igood + + self.set_active_layout(probe, good_channels) self.update_probe_view() - def set_active_layout(self, probe, good_channels=None): + def set_active_layout(self, probe, good_channels): self.active_layout = probe self.kcoords = self.active_layout.kcoords self.xc, self.yc = self.active_layout.xc, self.active_layout.yc @@ -96,10 +99,7 @@ def set_active_layout(self, probe, good_channels=None): self.channel_map_dict[(xc, yc)] = ind self.total_channels = self.active_layout.NchanTOT self.channel_map = self.active_layout.chanMap - if good_channels is None: - self.good_channels = np.ones_like(self.channel_map, dtype=bool) - else: - self.good_channels = good_channels + self.good_channels = good_channels def on_points_clicked(self, points): selected_point = points.ptsClicked[0] @@ -153,7 +153,8 @@ def update_probe_view(self): @QtCore.pyqtSlot(object) def preview_probe(self, probe): self.probe_view.clear() - self.set_active_layout(probe) + good_channels_dummy = np.ones_like(probe.chanMap, dtype=bool) + self.set_active_layout(probe, good_channels_dummy) self.create_plot(connect=False) def create_plot(self, connect=True): diff --git a/pykilosort/gui/run_box.py b/pykilosort/gui/run_box.py index 515149c..e6dc39a 100644 --- a/pykilosort/gui/run_box.py +++ b/pykilosort/gui/run_box.py @@ -6,6 +6,7 @@ class RunBox(QtWidgets.QGroupBox): updateContext = QtCore.pyqtSignal(object) + updateProbeView = QtCore.pyqtSignal() sortingStepStatusUpdate = QtCore.pyqtSignal(dict) disableInput = QtCore.pyqtSignal(bool) @@ -117,6 +118,7 @@ def set_sorting_step_status(self, step, status): @QtCore.pyqtSlot(object) def finished_preprocess(self, context): self.updateContext.emit(context) + self.updateProbeView.emit() self.set_sorting_step_status("preprocess", True) @QtCore.pyqtSlot(object) diff --git a/pykilosort/gui/sorter.py b/pykilosort/gui/sorter.py index e71e77c..06f72de 100644 --- a/pykilosort/gui/sorter.py +++ b/pykilosort/gui/sorter.py @@ -2,35 +2,10 @@ import numpy as np from numba import jit from pykilosort.main import run_export, run_preprocess, run_spikesort -from pykilosort.preprocess import get_good_channels, get_whitening_matrix, gpufilter +from pykilosort.preprocess import get_whitening_matrix, gpufilter from PyQt5 import QtCore -def find_good_channels(context): - params = context.params - probe = context.probe - raw_data = context.raw_data - intermediate = context.intermediate - - if "igood" not in intermediate: - if params.minfr_goodchannels > 0: # discard channels that have very few spikes - # determine bad channels - with context.time("good_channels"): - intermediate.igood = get_good_channels( - raw_data=raw_data, probe=probe, params=params - ) - intermediate.igood = intermediate.igood.ravel() - # Cache the result. - context.write(igood=intermediate.igood) - - else: - intermediate.igood = np.ones_like(probe.chanMap, dtype=bool) - - probe.Nchan = len(probe.chanMap) - context.probe = probe - return context - - def filter_and_whiten(raw_traces, params, probe, whitening_matrix): sample_rate = params.fs high_pass_freq = params.fshigh @@ -122,7 +97,6 @@ def get_predicted_traces( class KiloSortWorker(QtCore.QThread): - foundGoodChannels = QtCore.pyqtSignal(object) finishedPreprocess = QtCore.pyqtSignal(object) finishedSpikesort = QtCore.pyqtSignal(object) finishedAll = QtCore.pyqtSignal(object) @@ -162,7 +136,3 @@ def run(self): if "export" in self.steps: run_export(self.context, self.data_path, self.output_directory) self.finishedAll.emit(self.context) - - if "goodchannels" in self.steps: - self.context = find_good_channels(self.context) - self.foundGoodChannels.emit(self.context) From e738d40d61926bd4c3c46c468f9e6e397acecde7 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Mon, 30 Nov 2020 14:45:16 -0500 Subject: [PATCH 02/42] remove calculation of good channels on data load --- pykilosort/gui/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pykilosort/gui/__init__.py b/pykilosort/gui/__init__.py index 3330c2a..39005e4 100644 --- a/pykilosort/gui/__init__.py +++ b/pykilosort/gui/__init__.py @@ -10,7 +10,7 @@ from .probe_view_box import ProbeViewBox from .run_box import RunBox from .settings_box import SettingsBox -from .sorter import KiloSortWorker, filter_and_whiten, find_good_channels +from .sorter import KiloSortWorker, filter_and_whiten from .sanity_plots import SanityPlotWidget from .main import KiloSortGUI From 8ef540a3340e943837d6fcdf0485e7c3ce255f69 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 1 Dec 2020 14:17:54 -0500 Subject: [PATCH 03/42] fix last channel marked as inactive --- pykilosort/gui/probe_view_box.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index 63f8686..334cc34 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -79,7 +79,7 @@ def set_active_channels(self): end_channel_position = int(end_channel_position) self.active_channels = channel_map[ primary_channel_position:end_channel_position+1 - ].tolist() + +1].tolist() def set_layout(self, context): self.probe_view.clear() From 76404ebfe6202b91ec0e869ced0169b13bebd6a1 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Thu, 3 Dec 2020 09:36:19 -0500 Subject: [PATCH 04/42] Connect sorting status to probe view box --- pykilosort/gui/main.py | 1 + pykilosort/gui/probe_view_box.py | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/pykilosort/gui/main.py b/pykilosort/gui/main.py index 8260523..1ce112b 100644 --- a/pykilosort/gui/main.py +++ b/pykilosort/gui/main.py @@ -285,6 +285,7 @@ def update_run_box(self): @QtCore.pyqtSlot(dict) def update_sorting_status(self, status_dict): self.data_view_box.change_sorting_status(status_dict) + self.probe_view_box.change_sorting_status(status_dict) def get_context(self): return self.context diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index 334cc34..d053361 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -33,6 +33,12 @@ def __init__(self, parent): self.channel_map_dict = {} self.good_channels = None + self.sorting_status = { + "preprocess": False, + "spikesort": False, + "export": False + } + self.configuration = { "active_channel": "g", "good_channel": "b", @@ -121,6 +127,9 @@ def synchronize_data_view_mode(self, string): def synchronize_primary_channel(self): self.primary_channel = self.gui.data_view_box.primary_channel + def change_sorting_status(self, status_dict): + self.sorting_status = status_dict + def generate_spots_list(self): spots = [] size = 10 From 584cc208ba4e1cf9fbabc8d4eb6b5ab547906d84 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Thu, 3 Dec 2020 09:36:43 -0500 Subject: [PATCH 05/42] minor bug fix --- pykilosort/gui/probe_view_box.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index d053361..da77682 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -138,7 +138,7 @@ def generate_spots_list(self): for ind, (x_pos, y_pos) in enumerate(zip(self.xc, self.yc)): pos = (x_pos, y_pos) good_channel = self.good_channels[ind] - is_active = np.isin(ind, self.active_channels) + is_active = np.isin(ind, self.active_channels).tolist() if not good_channel: color = self.configuration["bad_channel"] elif good_channel and is_active: From 9374cc607ac214ecbcbe7f30267b28ff2b9670aa Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Fri, 4 Dec 2020 14:04:20 -0500 Subject: [PATCH 06/42] use raw_probe to set layout in probe view box --- pykilosort/gui/main.py | 1 + pykilosort/gui/probe_view_box.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pykilosort/gui/main.py b/pykilosort/gui/main.py index 1ce112b..357cd47 100644 --- a/pykilosort/gui/main.py +++ b/pykilosort/gui/main.py @@ -262,6 +262,7 @@ def setup_context(self): probe_layout = self.probe_layout probe_layout.Nchan = len(probe_layout.chanMap) self.context.probe = probe_layout + self.context.raw_probe = copy_bunch(probe_layout) self.context.params = self.params self.context.raw_data = self.raw_data diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index da77682..edf9e3a 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -89,7 +89,7 @@ def set_active_channels(self): def set_layout(self, context): self.probe_view.clear() - probe = context.probe + probe = context.raw_probe good_channels = context.intermediate.igood self.set_active_layout(probe, good_channels) From dcdacbf30876a6e6a99de1ed8590eeae9031ec59 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Mon, 7 Dec 2020 08:58:10 -0500 Subject: [PATCH 07/42] fixup! use raw_probe to set layout in probe view box --- pykilosort/gui/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pykilosort/gui/main.py b/pykilosort/gui/main.py index 357cd47..581f3cc 100644 --- a/pykilosort/gui/main.py +++ b/pykilosort/gui/main.py @@ -16,7 +16,7 @@ ) from pykilosort.gui.logger import setup_logger from pykilosort.params import KilosortParams -from pykilosort.utils import Context +from pykilosort.utils import Context, copy_bunch from PyQt5 import QtCore, QtGui, QtWidgets logger = setup_logger(__name__) From 2eb0cc89d3f746cd03d232ea8ab70cd8724bbba3 Mon Sep 17 00:00:00 2001 From: Shashwat Sridhar Date: Thu, 26 Nov 2020 16:38:33 -0500 Subject: [PATCH 08/42] remove calculation of good channels during data load --- pykilosort/gui/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pykilosort/gui/__init__.py b/pykilosort/gui/__init__.py index 39005e4..98bd049 100644 --- a/pykilosort/gui/__init__.py +++ b/pykilosort/gui/__init__.py @@ -10,7 +10,7 @@ from .probe_view_box import ProbeViewBox from .run_box import RunBox from .settings_box import SettingsBox -from .sorter import KiloSortWorker, filter_and_whiten +from .sorter import filter_and_whiten, KiloSortWorker from .sanity_plots import SanityPlotWidget from .main import KiloSortGUI From ca649fb7e8b1f9c2836e76b2972477a911329b92 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 1 Dec 2020 14:17:54 -0500 Subject: [PATCH 09/42] fix last channel marked as inactive --- pykilosort/gui/probe_view_box.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index edf9e3a..fefcf62 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -85,7 +85,7 @@ def set_active_channels(self): end_channel_position = int(end_channel_position) self.active_channels = channel_map[ primary_channel_position:end_channel_position+1 - +1].tolist() + ].tolist() def set_layout(self, context): self.probe_view.clear() From e42345efd00e202c6dfa753644b1a1fa5f7bfc10 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 18:13:32 -0400 Subject: [PATCH 10/42] fix minor imports and remove redundant function --- pykilosort/gui/__init__.py | 2 +- pykilosort/gui/sorter.py | 16 +--------------- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/pykilosort/gui/__init__.py b/pykilosort/gui/__init__.py index 98bd049..cdecd9a 100644 --- a/pykilosort/gui/__init__.py +++ b/pykilosort/gui/__init__.py @@ -10,7 +10,7 @@ from .probe_view_box import ProbeViewBox from .run_box import RunBox from .settings_box import SettingsBox -from .sorter import filter_and_whiten, KiloSortWorker +from .sorter import filter_and_whiten, get_predicted_traces, KiloSortWorker from .sanity_plots import SanityPlotWidget from .main import KiloSortGUI diff --git a/pykilosort/gui/sorter.py b/pykilosort/gui/sorter.py index 06f72de..644b118 100644 --- a/pykilosort/gui/sorter.py +++ b/pykilosort/gui/sorter.py @@ -2,7 +2,7 @@ import numpy as np from numba import jit from pykilosort.main import run_export, run_preprocess, run_spikesort -from pykilosort.preprocess import get_whitening_matrix, gpufilter +from pykilosort.preprocess import gpufilter from PyQt5 import QtCore @@ -32,20 +32,6 @@ def filter_and_whiten(raw_traces, params, probe, whitening_matrix): return whitened_array.get() -def get_whitened_traces(raw_data, probe, params, whitening_matrix): - if whitening_matrix is None: - whitening_matrix = get_whitening_matrix( - raw_data=raw_data, probe=probe, params=params, nSkipCov=100 - ) - whitened_traces = filter_and_whiten( - raw_traces=raw_data, - params=params, - probe=probe, - whitening_matrix=whitening_matrix, - ) - return whitened_traces, whitening_matrix - - @jit(nopython=True) def get_predicted_traces( matrix_U: np.ndarray, From d2254be3e996c1d17aa7649081292590c5b93dee Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 18:14:54 -0400 Subject: [PATCH 11/42] remove redundant import --- pykilosort/gui/data_view_box.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index cdf7b61..6a771c1 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -2,6 +2,7 @@ import typing as t import pyqtgraph as pg from cupy import asnumpy +from datetime import datetime from pykilosort.gui.logger import setup_logger from pykilosort.gui.minor_gui_elements import controls_popup_text from pykilosort.gui.palettes import COLORMAP_COLORS From 870b88c777f4fa20270e013c5dc99eb1162675ba Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 18:26:15 -0400 Subject: [PATCH 12/42] code readability improvements --- pykilosort/gui/data_view_box.py | 49 +++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index 6a771c1..2af7e80 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -236,46 +236,46 @@ def on_views_clicked(self): @QtCore.pyqtSlot(int) def on_wheel_scroll(self, direction): - if self.gui.context is not None: + if self.context_set(): self.shift_current_time(direction) @QtCore.pyqtSlot(int) def on_wheel_scroll_plus_control(self, direction): - if self.gui.context is not None: - if self.traces_button.isChecked(): + if self.context_set(): + if self.traces_mode_active(): self.shift_primary_channel(direction) else: self.change_displayed_channel_count(direction) @QtCore.pyqtSlot(int) def on_wheel_scroll_plus_shift(self, direction): - if self.gui.context is not None: + if self.context_set(): self.change_plot_range(direction) @QtCore.pyqtSlot(int) def on_wheel_scroll_plus_alt(self, direction): - if self.gui.context is not None: + if self.context_set(): self.change_plot_scaling(direction) def toggle_mode_from_click(self): - if self.traces_button.isChecked(): + if self.traces_mode_active(): self.modeChanged.emit("traces") self.view_buttons_group.setExclusive(False) self.update_plot() - if self.colormap_button.isChecked(): + if self.colormap_mode_active(): self.modeChanged.emit("colormap") self._traces_to_colormap_toggle() self.update_plot() def toggle_mode(self): - if self.colormap_button.isChecked(): + if self.colormap_mode_active(): self.traces_button.toggle() self.modeChanged.emit("traces") self.view_buttons_group.setExclusive(False) self.update_plot() - elif self.traces_button.isChecked(): + elif self.traces_mode_active(): self.colormap_button.toggle() self.modeChanged.emit("colormap") self._traces_to_colormap_toggle() @@ -312,6 +312,15 @@ def _traces_to_colormap_toggle(self): self.view_buttons_group.setExclusive(True) + def traces_mode_active(self): + return self.traces_button.isChecked() + + def colormap_mode_active(self): + return self.colormap_button.isChecked() + + def context_set(self): + return self.gui.context is not None + def change_primary_channel(self, channel): self.primary_channel = channel self.channelChanged.emit() @@ -327,7 +336,7 @@ def shift_primary_channel(self, shift): self.update_plot() def get_currently_displayed_channel_count(self): - if self.traces_button.isChecked(): + if self.traces_mode_active(): return self.channels_displayed_traces else: count = self.channels_displayed_colormap @@ -336,7 +345,7 @@ def get_currently_displayed_channel_count(self): return count def set_currently_displayed_channel_count(self, count): - if self.traces_button.isChecked(): + if self.traces_mode_active(): self.channels_displayed_traces = count else: self.channels_displayed_colormap = count @@ -382,14 +391,14 @@ def change_plot_range(self, direction): self.update_plot() def change_plot_scaling(self, direction): - if self.traces_button.isChecked(): + if self.traces_mode_active(): scale_factor = self.scale_factor * (1.1 ** direction) if 0.1 < scale_factor < 10.0: self.scale_factor = scale_factor self.update_plot() - if self.colormap_button.isChecked(): + if self.colormap_mode_active(): colormap_min = self.colormap_min + (direction * 0.05) colormap_max = self.colormap_max - (direction * 0.05) if 0.0 <= colormap_min < colormap_max <= 1.0: @@ -408,8 +417,8 @@ def change_channel_display(self, direction): self.change_displayed_channel_count(direction) def scene_clicked(self, ev): - if self.gui.context is not None: - if self.traces_button.isChecked(): + if self.context_set(): + if self.traces_mode_active(): x_pos = ev.pos().x() else: x_pos = self.colormap_image.mapFromScene(ev.pos()).x() @@ -422,7 +431,7 @@ def scene_clicked(self, ev): self.shift_current_time(direction=-1) def seek_clicked(self, ev): - if self.gui.context is not None: + if self.context_set(): new_time = self.seek_view_box.mapSceneToView(ev.pos()).x() seek_range_min = self.seek_range[0] seek_range_max = self.seek_range[1] @@ -456,7 +465,7 @@ def change_sorting_status(self, status_dict): self.enable_view_buttons() def enable_view_buttons(self): - if self.colormap_button.isChecked(): + if self.colormap_mode_active(): if self.prediction_button.isChecked() or self.residual_button.isChecked(): self.raw_button.click() else: @@ -552,13 +561,13 @@ def add_traces_to_plot_items( ): """ Update plot items with traces. - + Loops over traces and plots each trace using the setData() method of pyqtgraph's PlotCurveItem. The color of the trace depends on the mode requested (raw, whitened, prediction, residual). Bad channels are plotted in a different color. Each trace is also scaled by a certain factor defined in self.traces_scaling_factor. - + Parameters ---------- traces : numpy.ndarray @@ -685,7 +694,7 @@ def update_plot(self, context=None): end_time=end_time, ) - if self.colormap_button.isChecked(): + if self.colormap_mode_active(): self._update_colormap(params=params, probe=probe, raw_data=raw_data, From bc2c44e20aee8b35ba35feb45cc2717c9ab6df33 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 18:34:18 -0400 Subject: [PATCH 13/42] change data_view and probe_view sync signal * change signal signature --- pykilosort/gui/data_view_box.py | 17 ++++++++++------- pykilosort/gui/probe_view_box.py | 13 +++++++++---- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index 2af7e80..e12a62a 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -14,7 +14,7 @@ class DataViewBox(QtWidgets.QGroupBox): - channelChanged = QtCore.pyqtSignal() + channelChanged = QtCore.pyqtSignal(int, int) modeChanged = QtCore.pyqtSignal(str) def __init__(self, parent): @@ -323,7 +323,7 @@ def context_set(self): def change_primary_channel(self, channel): self.primary_channel = channel - self.channelChanged.emit() + self.channelChanged.emit(self.primary_channel, self.get_currently_displayed_channel_count()) self.update_plot() def shift_primary_channel(self, shift): @@ -332,7 +332,7 @@ def shift_primary_channel(self, shift): total_channels = self.get_total_channels() if (0 <= primary_channel < total_channels) and total_channels is not None: self.primary_channel = primary_channel - self.channelChanged.emit() + self.channelChanged.emit(self.primary_channel, self.get_currently_displayed_channel_count()) self.update_plot() def get_currently_displayed_channel_count(self): @@ -354,22 +354,25 @@ def get_total_channels(self): return self.gui.probe_view_box.total_channels def change_displayed_channel_count(self, direction): + current_channel = self.primary_channel total_channels = self.get_total_channels() current_count = self.get_currently_displayed_channel_count() new_count = current_count + (direction * 5) - if 0 < new_count <= total_channels: + if (current_channel + new_count) <= total_channels: self.set_currently_displayed_channel_count(new_count) + self.channelChanged.emit(self.primary_channel, new_count) self.refresh_plot_on_displayed_channel_count_change() elif new_count <= 0 and current_count != 1: self.set_currently_displayed_channel_count(1) + self.channelChanged.emit(self.primary_channel, 1) self.refresh_plot_on_displayed_channel_count_change() - elif new_count > total_channels: - self.set_currently_displayed_channel_count(total_channels) + elif (current_channel + new_count) > total_channels: + self.set_currently_displayed_channel_count(total_channels - current_channel) + self.channelChanged.emit(self.primary_channel, total_channels) self.refresh_plot_on_displayed_channel_count_change() def refresh_plot_on_displayed_channel_count_change(self): - self.channelChanged.emit() self.plot_item.clear() self.create_plot_items() self.update_plot() diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index fefcf62..0eb2215 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -153,10 +153,15 @@ def generate_spots_list(self): return spots - @QtCore.pyqtSlot() - def update_probe_view(self): - self.synchronize_primary_channel() - self.set_active_channels() + @QtCore.pyqtSlot(int, int) + def update_probe_view(self, primary_channel=None, channels_displayed=None): + if primary_channel is not None: + self.primary_channel = primary_channel + + if channels_displayed is None: + channels_displayed = self.total_channels + + self.set_active_channels(channels_displayed) self.create_plot() @QtCore.pyqtSlot(object) From df5a3ffdfb4e5fc24c78ba76d13dac71db160d16 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 18:37:46 -0400 Subject: [PATCH 14/42] minor changes to data_view and probe_view sync * add missing pyqtSlot decorators * coda readability improvements --- pykilosort/gui/data_view_box.py | 1 + pykilosort/gui/probe_view_box.py | 12 ++++-------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index e12a62a..a69091d 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -257,6 +257,7 @@ def on_wheel_scroll_plus_alt(self, direction): if self.context_set(): self.change_plot_scaling(direction) + @QtCore.pyqtSlot() def toggle_mode_from_click(self): if self.traces_mode_active(): self.modeChanged.emit("traces") diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index 0eb2215..c2eb5ba 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -116,16 +116,12 @@ def on_points_clicked(self, points): channel = self.channel_map[index] self.channelSelected.emit(channel) - def synchronize_data_view_mode(self, string): - old_mode = self.active_data_view_mode - self.active_data_view_mode = string - - if old_mode != self.active_data_view_mode and self.primary_channel is not None: + @QtCore.pyqtSlot(str) + def synchronize_data_view_mode(self, mode: str): + if self.active_data_view_mode != mode: self.probe_view.clear() self.update_probe_view() - - def synchronize_primary_channel(self): - self.primary_channel = self.gui.data_view_box.primary_channel + self.active_data_view_mode = mode def change_sorting_status(self, status_dict): self.sorting_status = status_dict From 38fd0251591c010c609658e92201ebbc0ceb9e1e Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 18:39:11 -0400 Subject: [PATCH 15/42] remove redundant function --- pykilosort/gui/data_view_box.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index a69091d..049215a 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -414,12 +414,6 @@ def change_plot_scaling(self, direction): self.update_plot() - def change_channel_display(self, direction): - if self.traces_button.isChecked(): - self.shift_primary_channel(direction) - else: - self.change_displayed_channel_count(direction) - def scene_clicked(self, ev): if self.context_set(): if self.traces_mode_active(): From 3ac053cd1cd29f9908d6fd155de4505f2e4f1ee5 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 19:04:50 -0400 Subject: [PATCH 16/42] introduce chanMapBackup as a property of Probe * preserves original chanMap, which is overwritten during preprocess * use chanMapBackup to generate default good_channels arrays * remove risky usages of chanMap --- pykilosort/gui/data_view_box.py | 10 +++- pykilosort/gui/main.py | 10 ++-- pykilosort/gui/minor_gui_elements.py | 1 + pykilosort/gui/probe_view_box.py | 13 +++-- pykilosort/gui/sorter.py | 4 +- pykilosort/utils.py | 75 ++++++++++++++++++++++++++++ 6 files changed, 100 insertions(+), 13 deletions(-) diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index 049215a..543c924 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -649,11 +649,16 @@ def get_whitened_traces( raw_data=raw_data, params=params, probe=probe, nSkipCov=nSkipCov ) + good_channels = intermediate.igood.ravel() \ + if "igood" in intermediate \ + else np.ones_like(probe.chanMapBackup) + whitened_traces = filter_and_whiten( raw_traces=raw_traces, params=params, probe=probe, whitening_matrix=self.whitening_matrix, + good_channels=good_channels, ) return whitened_traces @@ -671,7 +676,10 @@ def update_plot(self, context=None): probe = context.probe raw_data = context.raw_data intermediate = context.intermediate - good_channels = intermediate.igood.ravel() + try: + good_channels = intermediate.igood.ravel() + except AttributeError: + good_channels = np.ones_like(probe.chanMapBackup, dtype=bool) sample_rate = raw_data.sample_rate diff --git a/pykilosort/gui/main.py b/pykilosort/gui/main.py index 581f3cc..d66c5c2 100644 --- a/pykilosort/gui/main.py +++ b/pykilosort/gui/main.py @@ -16,7 +16,7 @@ ) from pykilosort.gui.logger import setup_logger from pykilosort.params import KilosortParams -from pykilosort.utils import Context, copy_bunch +from pykilosort.utils import Context, copy_bunch, extend_probe from PyQt5 import QtCore, QtGui, QtWidgets logger = setup_logger(__name__) @@ -260,14 +260,12 @@ def setup_context(self): self.context = Context(context_path=context_path) probe_layout = self.probe_layout - probe_layout.Nchan = len(probe_layout.chanMap) - self.context.probe = probe_layout - self.context.raw_probe = copy_bunch(probe_layout) + probe_layout.Nchan = self.num_channels + self.context.probe = extend_probe(probe_layout) + self.context.raw_probe = extend_probe(copy_bunch(probe_layout)) self.context.params = self.params self.context.raw_data = self.raw_data - self.context.intermediate.igood = np.ones_like(probe_layout.chanMap, dtype=bool) - self.context.load() @QtCore.pyqtSlot(object) diff --git a/pykilosort/gui/minor_gui_elements.py b/pykilosort/gui/minor_gui_elements.py index b793b78..8d5e237 100644 --- a/pykilosort/gui/minor_gui_elements.py +++ b/pykilosort/gui/minor_gui_elements.py @@ -199,6 +199,7 @@ def construct_probe(self): probe.yc = self.y_coords probe.kcoords = self.k_coords probe.chanMap = self.channel_map + probe.chanMapBackup = probe.chanMap.copy() probe.bad_channels = self.bad_channels probe.NchanTOT = len(self.x_coords) diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index c2eb5ba..eb3c8af 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -96,7 +96,7 @@ def set_layout(self, context): self.update_probe_view() - def set_active_layout(self, probe, good_channels): + def set_active_layout(self, probe, good_channels=None): self.active_layout = probe self.kcoords = self.active_layout.kcoords self.xc, self.yc = self.active_layout.xc, self.active_layout.yc @@ -105,7 +105,13 @@ def set_active_layout(self, probe, good_channels): self.channel_map_dict[(xc, yc)] = ind self.total_channels = self.active_layout.NchanTOT self.channel_map = self.active_layout.chanMap - self.good_channels = good_channels + if good_channels is None: + self.good_channels = np.ones_like( + self.active_layout.chanMapBackup, + dtype=bool + ) + else: + self.good_channels = good_channels def on_points_clicked(self, points): selected_point = points.ptsClicked[0] @@ -163,8 +169,7 @@ def update_probe_view(self, primary_channel=None, channels_displayed=None): @QtCore.pyqtSlot(object) def preview_probe(self, probe): self.probe_view.clear() - good_channels_dummy = np.ones_like(probe.chanMap, dtype=bool) - self.set_active_layout(probe, good_channels_dummy) + self.set_active_layout(probe) self.create_plot(connect=False) def create_plot(self, connect=True): diff --git a/pykilosort/gui/sorter.py b/pykilosort/gui/sorter.py index 644b118..9437f61 100644 --- a/pykilosort/gui/sorter.py +++ b/pykilosort/gui/sorter.py @@ -6,7 +6,7 @@ from PyQt5 import QtCore -def filter_and_whiten(raw_traces, params, probe, whitening_matrix): +def filter_and_whiten(raw_traces, params, probe, whitening_matrix, good_channels): sample_rate = params.fs high_pass_freq = params.fshigh low_pass_freq = params.fslow @@ -18,7 +18,7 @@ def filter_and_whiten(raw_traces, params, probe, whitening_matrix): filtered_data = gpufilter( buff=cp.asarray(raw_traces, dtype=np.float32), - chanMap=probe.chanMap, + chanMap=probe.chanMapBackup[good_channels], fs=sample_rate, fslow=low_pass_freq, fshigh=high_pass_freq, diff --git a/pykilosort/utils.py b/pykilosort/utils.py index 32ff282..64e2d04 100644 --- a/pykilosort/utils.py +++ b/pykilosort/utils.py @@ -394,6 +394,7 @@ def load_probe(probe_path): probe.yc.append([pos[c][1] for c in ch]) probe.kcoords.append([cg for c in ch]) probe.chanMap = np.concatenate(probe.chanMap).ravel().astype(np.int32) + probe.chanMapBackup = probe.chanMap.copy() probe.xc = np.concatenate(probe.xc) probe.yc = np.concatenate(probe.yc) probe.kcoords = np.concatenate(probe.kcoords) @@ -406,6 +407,7 @@ def load_probe(probe_path): probe.yc = mat['ycoords'].ravel().astype(np.float64) probe.kcoords = mat.get('kcoords', np.zeros(nc)).ravel().astype(np.float64) probe.chanMap = (mat['chanMap'] - 1).ravel().astype(np.int32) # NOTE: 0-indexing in Python + probe.chanMapBackup = probe.chanMap.copy() probe.NchanTOT = len(probe.chanMap) # NOTE: should match the # of columns in the raw data for n in _required_keys: @@ -444,6 +446,79 @@ def create_prb(probe): return probe_prb +def extend_probe( + probe_layout: Bunch +) -> Bunch: + """ + Extend probe layout to account for extra num_channels. + + The probe layout selected by the user may have a different + number of channels than the requested number of channels in + the dataset. The function attempts to smartly extend the + layout of the probe to match the requested number of + channels. + + In case the requested number of channels is less than the + total channels on the layout, the original probe layout + is returned. + + Parameters + ---------- + probe_layout : Bunch + Input probe layout which might have to be extended. + + Returns + ------- + probe_layout : Bunch + Possibly extended probe layout. + + """ + if len(probe_layout["xc"]) >= probe_layout["Nchan"]: + # if the requested number of channels is less than the number of + # channels on the probe, return original probe + return probe_layout + else: + n_channels = probe_layout["Nchan"] + xc = probe_layout["xc"] + yc = probe_layout["yc"] + + # the assumption here is that the probe layout has a repetitive pattern + unique_x = np.sort(np.unique(probe_layout["xc"])) # unique values for the x-axis + unique_y = np.sort(np.unique(probe_layout["yc"])) # unique values for the y-axis + + kcoords = probe_layout["kcoords"] + chan_map = probe_layout["chanMap"] + + new_channels = n_channels - len(xc) # number of new channels to be added + + # create new properties + # this will probably break if new_channels > len(unique_x/y) + append_x = unique_x[-new_channels] + append_y = unique_y[-new_channels] + + append_chan_map = np.array([ + i for i in np.arange(chan_map[-1], + chan_map[-1] + new_channels) + ]) + 1 # to account for 0-ordering + append_kcoords = np.zeros(new_channels) + + # append new properties to existing properties + new_xc = np.append(xc, append_x) + new_yc = np.append(yc, append_y) + new_chan_map = np.append(chan_map, append_chan_map) + new_kcoords = np.append(kcoords, append_kcoords) + + # save new properties into probe layout + probe_layout["xc"] = new_xc + probe_layout["yc"] = new_yc + probe_layout["chanMap"] = new_chan_map + probe_layout["chanMapBackup"] = new_chan_map.copy() + probe_layout["kcoords"] = new_kcoords + probe_layout["NchanTOT"] = len(new_chan_map) + + return probe_layout + + def plot_dissimilarity_matrices(ccb, ccbsort, plot_widget): ccb = cp.asnumpy(ccb) ccbsort = cp.asnumpy(ccbsort) From 5fdd7909e96396c910e04418296cc13c9e2b2c68 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 19:06:08 -0400 Subject: [PATCH 17/42] refactor add_traces_to_plot_items() --- pykilosort/gui/data_view_box.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index 543c924..8966884 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -576,22 +576,17 @@ def add_traces_to_plot_items( view : str One of "raw", "whitened", "prediction" and "residual" views """ - for c, channel in enumerate( - range( - self.primary_channel, - self.primary_channel + self.channels_displayed_traces, - ) - ): + for c, good in enumerate(good_channels): try: curve = self.traces_plot_items[view][c] color = ( self.traces_curve_color[view] - if good_channels[channel] + if good else self.bad_channel_color ) curve.setPen(color=color, width=1) curve.setData( - traces.T[channel] * + traces[:, c] * self.scale_factor * self.traces_scaling_factor[view] ) From 850cef75b4109185cd7f8b4ec31848b623095fab Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 19:08:35 -0400 Subject: [PATCH 18/42] same behaviour for ctrl+up/down & ctrl+scroll --- pykilosort/gui/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pykilosort/gui/main.py b/pykilosort/gui/main.py index d66c5c2..30f26d0 100644 --- a/pykilosort/gui/main.py +++ b/pykilosort/gui/main.py @@ -167,7 +167,7 @@ def setup(self): def change_channel_display(self, direction): if self.context is not None: - self.data_view_box.change_channel_display(direction) + self.data_view_box.shift_primary_channel(direction) def shift_data(self, time_shift): if self.context is not None: From 3eb64c924e61cd04c0985d0adc1eca8cc6dc5a85 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 19:10:05 -0400 Subject: [PATCH 19/42] code refactor --- pykilosort/gui/probe_view_box.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index eb3c8af..582dede 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -46,7 +46,7 @@ def __init__(self, parent): } self.active_data_view_mode = "colormap" - self.primary_channel = None + self.primary_channel = 0 self.active_channels = [] def setup(self): @@ -62,20 +62,13 @@ def setup(self): layout.addWidget(self.probe_view, 95) self.setLayout(layout) - def set_active_channels(self): - if self.active_data_view_mode == "traces": - displayed_channels = self.gui.data_view_box.channels_displayed_traces - else: - displayed_channels = self.gui.data_view_box.channels_displayed_colormap - if displayed_channels is None: - displayed_channels = self.total_channels - + def set_active_channels(self, channels_displayed): primary_channel = self.primary_channel channel_map = np.array(self.channel_map) primary_channel_position = int(np.where(channel_map == primary_channel)[0]) end_channel_position = np.where( - channel_map == primary_channel + displayed_channels + channel_map == primary_channel + channels_displayed )[0] # prevent the last displayed channel would be set as the end channel in the case that # `primary_channel + displayed_channels` exceeds the total number of channels in the channel map @@ -90,7 +83,10 @@ def set_active_channels(self): def set_layout(self, context): self.probe_view.clear() probe = context.raw_probe - good_channels = context.intermediate.igood + try: + good_channels = context.intermediate.igood + except AttributeError: + good_channels = None self.set_active_layout(probe, good_channels) @@ -184,7 +180,7 @@ def reset(self): self.clear_plot() self.reset_current_probe_layout() self.reset_active_data_view_mode() - self.primary_channel = None + self.primary_channel = 0 def reset_active_data_view_mode(self): self.active_data_view_mode = "colormap" From 6442092f608985eb4d383fcaf1f295269662bc2a Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 19:13:57 -0400 Subject: [PATCH 20/42] major refactor of plotting code * correctly account for good channels and determine channels to be displayed --- pykilosort/gui/data_view_box.py | 93 ++++++++++++++++++++++++--------- 1 file changed, 68 insertions(+), 25 deletions(-) diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index 8966884..ab8f36e 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -662,7 +662,7 @@ def update_plot(self, context=None): if context is None: context = self.gui.context - if context is not None: + if context is not None: # since context may still be None if self.colormap_image is not None: self.plot_item.removeItem(self.colormap_image) self.colormap_image = None @@ -685,10 +685,27 @@ def update_plot(self, context=None): self.data_view_widget.setXRange(0, time_range, padding=0.0) self.data_view_widget.setLimits(xMin=0, xMax=time_range) - if self.traces_button.isChecked(): + orig_chan_map = probe.chanMapBackup + max_channels = np.size(orig_chan_map) + + start_channel = self.primary_channel + active_channels = self.get_currently_displayed_channel_count() + if active_channels is None: + active_channels = self.get_total_channels() + end_channel = start_channel + active_channels + + # good channels after start_channel to display + to_display = np.arange(start_channel, end_channel, dtype=int) + to_display = to_display[to_display < max_channels] + + raw_traces = raw_data[start_time:end_time] + + if self.traces_mode_active(): self._update_traces(params=params, probe=probe, raw_data=raw_data, + raw_traces=raw_traces, + to_display=to_display, intermediate=intermediate, good_channels=good_channels, start_time=start_time, @@ -699,6 +716,8 @@ def update_plot(self, context=None): self._update_colormap(params=params, probe=probe, raw_data=raw_data, + raw_traces=raw_traces, + to_display=to_display, intermediate=intermediate, good_channels=good_channels, start_time=start_time, @@ -716,15 +735,24 @@ def update_plot(self, context=None): self.data_view_widget.autoRange() - def _update_traces(self, params, probe, raw_data, intermediate, good_channels, start_time, end_time): + def _update_traces( + self, + params, + probe, + raw_data, + raw_traces, + to_display, + intermediate, + good_channels, + start_time, + end_time + ): self.hide_inactive_traces() - raw_traces = raw_data[start_time:end_time] - if self.raw_button.isChecked(): self.add_traces_to_plot_items( - traces=raw_traces, - good_channels=good_channels, + traces=raw_traces[:, to_display], + good_channels=good_channels[to_display], view="raw", ) @@ -743,9 +771,11 @@ def _update_traces(self, params, probe, raw_data, intermediate, good_channels, s else: whitened_traces = self.whitened_traces + processed_traces = np.zeros_like(raw_traces, dtype=np.int16) + processed_traces[:, good_channels] = whitened_traces self.add_traces_to_plot_items( - traces=whitened_traces, - good_channels=good_channels, + traces=processed_traces[:, to_display], + good_channels=good_channels[to_display], view="whitened", ) @@ -761,9 +791,11 @@ def _update_traces(self, params, probe, raw_data, intermediate, good_channels, s else: prediction_traces = self.prediction_traces + processed_traces = np.zeros_like(raw_traces, dtype=np.int16) + processed_traces[:, good_channels] = prediction_traces self.add_traces_to_plot_items( - traces=prediction_traces, - good_channels=good_channels, + traces=processed_traces[:, to_display], + good_channels=good_channels[to_display], view="prediction", ) @@ -803,27 +835,35 @@ def _update_traces(self, params, probe, raw_data, intermediate, good_channels, s else: residual_traces = self.residual_traces + processed_traces = np.zeros_like(raw_traces, dtype=np.int16) + processed_traces[:, good_channels] = residual_traces self.add_traces_to_plot_items( - traces=residual_traces, - good_channels=good_channels, + traces=processed_traces[:, to_display], + good_channels=good_channels[to_display], view="residual", ) - def _update_colormap(self, params, probe, raw_data, intermediate, good_channels, start_time, end_time): + def _update_colormap( + self, + params, + probe, + raw_data, + raw_traces, + to_display, + intermediate, + good_channels, + start_time, + end_time + ): self.hide_traces() - raw_traces = raw_data[start_time:end_time] - - start_channel = self.primary_channel - displayed_channels = self.channels_displayed_colormap - if displayed_channels is None: - displayed_channels = self.get_total_channels() - end_channel = start_channel + displayed_channels + processed_traces = np.zeros_like(raw_traces, dtype=np.int16) if self.raw_button.isChecked(): colormap_min, colormap_max = -32.0, 32.0 + self.add_image_to_plot( - raw_traces[:, start_channel:end_channel], + raw_traces[:, to_display], colormap_min, colormap_max, ) @@ -843,8 +883,9 @@ def _update_colormap(self, params, probe, raw_data, intermediate, good_channels, else: whitened_traces = self.whitened_traces colormap_min, colormap_max = -4.0, 4.0 + processed_traces[:, good_channels] = whitened_traces self.add_image_to_plot( - whitened_traces[:, start_channel:end_channel], + processed_traces[:, to_display], colormap_min, colormap_max, ) @@ -861,8 +902,9 @@ def _update_colormap(self, params, probe, raw_data, intermediate, good_channels, else: prediction_traces = self.prediction_traces colormap_min, colormap_max = -4.0, 4.0 + processed_traces[:, good_channels] = prediction_traces self.add_image_to_plot( - prediction_traces[:, start_channel:end_channel], + processed_traces[:, to_display], colormap_min, colormap_max, ) @@ -903,8 +945,9 @@ def _update_colormap(self, params, probe, raw_data, intermediate, good_channels, else: residual_traces = self.residual_traces colormap_min, colormap_max = -4.0, 4.0 + processed_traces[:, good_channels] = residual_traces self.add_image_to_plot( - residual_traces[:, start_channel:end_channel], + processed_traces[:, to_display], colormap_min, colormap_max, ) From ead0277a069875d33124a6fdbf729f039d778923 Mon Sep 17 00:00:00 2001 From: Shashwat Sridhar Date: Thu, 26 Nov 2020 16:38:33 -0500 Subject: [PATCH 21/42] remove calculation of good channels during data load --- pykilosort/gui/main.py | 27 +++++++++------------------ pykilosort/gui/probe_view_box.py | 15 ++++++++------- pykilosort/gui/run_box.py | 2 ++ pykilosort/gui/sorter.py | 32 +------------------------------- 4 files changed, 20 insertions(+), 56 deletions(-) diff --git a/pykilosort/gui/main.py b/pykilosort/gui/main.py index 9c9e222..0f56c81 100644 --- a/pykilosort/gui/main.py +++ b/pykilosort/gui/main.py @@ -163,6 +163,7 @@ def setup(self): self.run_box.updateContext.connect(self.update_context) self.run_box.disableInput.connect(self.disable_all_input) self.run_box.sortingStepStatusUpdate.connect(self.update_sorting_status) + self.run_box.updateProbeView.connect(self.update_probe_view) def change_channel_display(self, direction): if self.context is not None: @@ -229,7 +230,7 @@ def set_parameters(self): self.prepare_for_new_context() self.load_raw_data() self.setup_context() - self.setup_probe_view() + self.update_probe_view() self.setup_data_view() self.update_run_box() @@ -252,38 +253,28 @@ def setup_data_view(self): self.data_view_box.create_plot_items() self.data_view_box.update_plot(self.context) - def update_context_with_good_channels(self): - worker = KiloSortWorker( - self.context, self.data_path, self.results_directory, ["goodchannels"] - ) - - worker.foundGoodChannels.connect(self.update_context) - - QtWidgets.QApplication.setOverrideCursor(QtGui.QCursor(QtCore.Qt.WaitCursor)) - worker.start() - while worker.isRunning(): - QtWidgets.QApplication.processEvents() - QtWidgets.QApplication.restoreOverrideCursor() - def setup_context(self): context_path = Path( os.path.join(self.working_directory, ".kilosort", self.raw_data.name) ) self.context = Context(context_path=context_path) - self.context.probe = self.probe_layout + probe_layout = self.probe_layout + probe_layout.Nchan = len(probe_layout.chanMap) + self.context.probe = probe_layout self.context.params = self.params self.context.raw_data = self.raw_data - self.context.load() + self.context.intermediate.igood = np.ones_like(probe_layout.chanMap, dtype=bool) - self.update_context_with_good_channels() + self.context.load() @QtCore.pyqtSlot(object) def update_context(self, context): self.context = context - def setup_probe_view(self): + @QtCore.pyqtSlot() + def update_probe_view(self): self.probe_view_box.set_layout(self.context) def update_run_box(self): diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index a4223e7..63f8686 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -83,11 +83,14 @@ def set_active_channels(self): def set_layout(self, context): self.probe_view.clear() - self.set_active_layout(context.probe, context.intermediate.igood) + probe = context.probe + good_channels = context.intermediate.igood + + self.set_active_layout(probe, good_channels) self.update_probe_view() - def set_active_layout(self, probe, good_channels=None): + def set_active_layout(self, probe, good_channels): self.active_layout = probe self.kcoords = self.active_layout.kcoords self.xc, self.yc = self.active_layout.xc, self.active_layout.yc @@ -96,10 +99,7 @@ def set_active_layout(self, probe, good_channels=None): self.channel_map_dict[(xc, yc)] = ind self.total_channels = self.active_layout.NchanTOT self.channel_map = self.active_layout.chanMap - if good_channels is None: - self.good_channels = np.ones_like(self.channel_map, dtype=bool) - else: - self.good_channels = good_channels + self.good_channels = good_channels def on_points_clicked(self, points): selected_point = points.ptsClicked[0] @@ -153,7 +153,8 @@ def update_probe_view(self): @QtCore.pyqtSlot(object) def preview_probe(self, probe): self.probe_view.clear() - self.set_active_layout(probe) + good_channels_dummy = np.ones_like(probe.chanMap, dtype=bool) + self.set_active_layout(probe, good_channels_dummy) self.create_plot(connect=False) def create_plot(self, connect=True): diff --git a/pykilosort/gui/run_box.py b/pykilosort/gui/run_box.py index 515149c..e6dc39a 100644 --- a/pykilosort/gui/run_box.py +++ b/pykilosort/gui/run_box.py @@ -6,6 +6,7 @@ class RunBox(QtWidgets.QGroupBox): updateContext = QtCore.pyqtSignal(object) + updateProbeView = QtCore.pyqtSignal() sortingStepStatusUpdate = QtCore.pyqtSignal(dict) disableInput = QtCore.pyqtSignal(bool) @@ -117,6 +118,7 @@ def set_sorting_step_status(self, step, status): @QtCore.pyqtSlot(object) def finished_preprocess(self, context): self.updateContext.emit(context) + self.updateProbeView.emit() self.set_sorting_step_status("preprocess", True) @QtCore.pyqtSlot(object) diff --git a/pykilosort/gui/sorter.py b/pykilosort/gui/sorter.py index e71e77c..06f72de 100644 --- a/pykilosort/gui/sorter.py +++ b/pykilosort/gui/sorter.py @@ -2,35 +2,10 @@ import numpy as np from numba import jit from pykilosort.main import run_export, run_preprocess, run_spikesort -from pykilosort.preprocess import get_good_channels, get_whitening_matrix, gpufilter +from pykilosort.preprocess import get_whitening_matrix, gpufilter from PyQt5 import QtCore -def find_good_channels(context): - params = context.params - probe = context.probe - raw_data = context.raw_data - intermediate = context.intermediate - - if "igood" not in intermediate: - if params.minfr_goodchannels > 0: # discard channels that have very few spikes - # determine bad channels - with context.time("good_channels"): - intermediate.igood = get_good_channels( - raw_data=raw_data, probe=probe, params=params - ) - intermediate.igood = intermediate.igood.ravel() - # Cache the result. - context.write(igood=intermediate.igood) - - else: - intermediate.igood = np.ones_like(probe.chanMap, dtype=bool) - - probe.Nchan = len(probe.chanMap) - context.probe = probe - return context - - def filter_and_whiten(raw_traces, params, probe, whitening_matrix): sample_rate = params.fs high_pass_freq = params.fshigh @@ -122,7 +97,6 @@ def get_predicted_traces( class KiloSortWorker(QtCore.QThread): - foundGoodChannels = QtCore.pyqtSignal(object) finishedPreprocess = QtCore.pyqtSignal(object) finishedSpikesort = QtCore.pyqtSignal(object) finishedAll = QtCore.pyqtSignal(object) @@ -162,7 +136,3 @@ def run(self): if "export" in self.steps: run_export(self.context, self.data_path, self.output_directory) self.finishedAll.emit(self.context) - - if "goodchannels" in self.steps: - self.context = find_good_channels(self.context) - self.foundGoodChannels.emit(self.context) From f2c77ab34fce3affde70c316d084ff4e747550e8 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Mon, 30 Nov 2020 14:45:16 -0500 Subject: [PATCH 22/42] remove calculation of good channels on data load --- pykilosort/gui/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pykilosort/gui/__init__.py b/pykilosort/gui/__init__.py index a1f2549..671148e 100644 --- a/pykilosort/gui/__init__.py +++ b/pykilosort/gui/__init__.py @@ -10,7 +10,7 @@ from .probe_view_box import ProbeViewBox from .run_box import RunBox from .settings_box import SettingsBox -from .sorter import KiloSortWorker, filter_and_whiten, find_good_channels +from .sorter import KiloSortWorker, filter_and_whiten from .sanity_plots import SanityPlotWidget from .main import KiloSortGUI From 4f1d1bcc6a4b6776066c6157a36fb8c3ac69e299 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 1 Dec 2020 14:17:54 -0500 Subject: [PATCH 23/42] fix last channel marked as inactive --- pykilosort/gui/probe_view_box.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index 63f8686..334cc34 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -79,7 +79,7 @@ def set_active_channels(self): end_channel_position = int(end_channel_position) self.active_channels = channel_map[ primary_channel_position:end_channel_position+1 - ].tolist() + +1].tolist() def set_layout(self, context): self.probe_view.clear() From cd2e9647cbe333b74ccf3b435382548deed24e69 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Thu, 3 Dec 2020 09:36:19 -0500 Subject: [PATCH 24/42] Connect sorting status to probe view box --- pykilosort/gui/main.py | 1 + pykilosort/gui/probe_view_box.py | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/pykilosort/gui/main.py b/pykilosort/gui/main.py index 0f56c81..9831209 100644 --- a/pykilosort/gui/main.py +++ b/pykilosort/gui/main.py @@ -285,6 +285,7 @@ def update_run_box(self): @QtCore.pyqtSlot(dict) def update_sorting_status(self, status_dict): self.data_view_box.change_sorting_status(status_dict) + self.probe_view_box.change_sorting_status(status_dict) def get_context(self): return self.context diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index 334cc34..d053361 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -33,6 +33,12 @@ def __init__(self, parent): self.channel_map_dict = {} self.good_channels = None + self.sorting_status = { + "preprocess": False, + "spikesort": False, + "export": False + } + self.configuration = { "active_channel": "g", "good_channel": "b", @@ -121,6 +127,9 @@ def synchronize_data_view_mode(self, string): def synchronize_primary_channel(self): self.primary_channel = self.gui.data_view_box.primary_channel + def change_sorting_status(self, status_dict): + self.sorting_status = status_dict + def generate_spots_list(self): spots = [] size = 10 From a1b7c22e4bbd02a28fcfe74d6230c6f89a1c4061 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Thu, 3 Dec 2020 09:36:43 -0500 Subject: [PATCH 25/42] minor bug fix --- pykilosort/gui/probe_view_box.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index d053361..da77682 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -138,7 +138,7 @@ def generate_spots_list(self): for ind, (x_pos, y_pos) in enumerate(zip(self.xc, self.yc)): pos = (x_pos, y_pos) good_channel = self.good_channels[ind] - is_active = np.isin(ind, self.active_channels) + is_active = np.isin(ind, self.active_channels).tolist() if not good_channel: color = self.configuration["bad_channel"] elif good_channel and is_active: From bc8bc0259ab5fe08be49f514c00761742bccd7cb Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Fri, 4 Dec 2020 14:04:20 -0500 Subject: [PATCH 26/42] use raw_probe to set layout in probe view box --- pykilosort/gui/main.py | 1 + pykilosort/gui/probe_view_box.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pykilosort/gui/main.py b/pykilosort/gui/main.py index 9831209..68bb4dc 100644 --- a/pykilosort/gui/main.py +++ b/pykilosort/gui/main.py @@ -262,6 +262,7 @@ def setup_context(self): probe_layout = self.probe_layout probe_layout.Nchan = len(probe_layout.chanMap) self.context.probe = probe_layout + self.context.raw_probe = copy_bunch(probe_layout) self.context.params = self.params self.context.raw_data = self.raw_data diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index da77682..edf9e3a 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -89,7 +89,7 @@ def set_active_channels(self): def set_layout(self, context): self.probe_view.clear() - probe = context.probe + probe = context.raw_probe good_channels = context.intermediate.igood self.set_active_layout(probe, good_channels) From 8749817cad8bfbfaf1e7ea1e5f0693e09184f414 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Mon, 7 Dec 2020 08:58:10 -0500 Subject: [PATCH 27/42] fixup! use raw_probe to set layout in probe view box --- pykilosort/gui/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pykilosort/gui/main.py b/pykilosort/gui/main.py index 68bb4dc..4df7025 100644 --- a/pykilosort/gui/main.py +++ b/pykilosort/gui/main.py @@ -16,7 +16,7 @@ ) from pykilosort.gui.logger import setup_logger from pykilosort.params import KilosortParams -from pykilosort.utils import Context +from pykilosort.utils import Context, copy_bunch from PyQt5 import QtCore, QtGui, QtWidgets logger = setup_logger(__name__) From e4f0d3e8648d0fdbc641a3e5b9889b8d91a7b811 Mon Sep 17 00:00:00 2001 From: Shashwat Sridhar Date: Thu, 26 Nov 2020 16:38:33 -0500 Subject: [PATCH 28/42] remove calculation of good channels during data load --- pykilosort/gui/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pykilosort/gui/__init__.py b/pykilosort/gui/__init__.py index 671148e..32a4f35 100644 --- a/pykilosort/gui/__init__.py +++ b/pykilosort/gui/__init__.py @@ -10,7 +10,7 @@ from .probe_view_box import ProbeViewBox from .run_box import RunBox from .settings_box import SettingsBox -from .sorter import KiloSortWorker, filter_and_whiten +from .sorter import filter_and_whiten, KiloSortWorker from .sanity_plots import SanityPlotWidget from .main import KiloSortGUI From 0054b5b9c8906f64c586b8327f1d2f667182e521 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 1 Dec 2020 14:17:54 -0500 Subject: [PATCH 29/42] fix last channel marked as inactive --- pykilosort/gui/probe_view_box.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index edf9e3a..fefcf62 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -85,7 +85,7 @@ def set_active_channels(self): end_channel_position = int(end_channel_position) self.active_channels = channel_map[ primary_channel_position:end_channel_position+1 - +1].tolist() + ].tolist() def set_layout(self, context): self.probe_view.clear() From f81bf257615b85227139095e3a312607ac3271ab Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 18:13:32 -0400 Subject: [PATCH 30/42] fix minor imports and remove redundant function --- pykilosort/gui/__init__.py | 2 +- pykilosort/gui/sorter.py | 16 +--------------- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/pykilosort/gui/__init__.py b/pykilosort/gui/__init__.py index 32a4f35..47c5457 100644 --- a/pykilosort/gui/__init__.py +++ b/pykilosort/gui/__init__.py @@ -10,7 +10,7 @@ from .probe_view_box import ProbeViewBox from .run_box import RunBox from .settings_box import SettingsBox -from .sorter import filter_and_whiten, KiloSortWorker +from .sorter import filter_and_whiten, get_predicted_traces, KiloSortWorker from .sanity_plots import SanityPlotWidget from .main import KiloSortGUI diff --git a/pykilosort/gui/sorter.py b/pykilosort/gui/sorter.py index 06f72de..644b118 100644 --- a/pykilosort/gui/sorter.py +++ b/pykilosort/gui/sorter.py @@ -2,7 +2,7 @@ import numpy as np from numba import jit from pykilosort.main import run_export, run_preprocess, run_spikesort -from pykilosort.preprocess import get_whitening_matrix, gpufilter +from pykilosort.preprocess import gpufilter from PyQt5 import QtCore @@ -32,20 +32,6 @@ def filter_and_whiten(raw_traces, params, probe, whitening_matrix): return whitened_array.get() -def get_whitened_traces(raw_data, probe, params, whitening_matrix): - if whitening_matrix is None: - whitening_matrix = get_whitening_matrix( - raw_data=raw_data, probe=probe, params=params, nSkipCov=100 - ) - whitened_traces = filter_and_whiten( - raw_traces=raw_data, - params=params, - probe=probe, - whitening_matrix=whitening_matrix, - ) - return whitened_traces, whitening_matrix - - @jit(nopython=True) def get_predicted_traces( matrix_U: np.ndarray, From 9a5a3aee39a9262f8208ee00777fd0f23d9c5b90 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 18:14:54 -0400 Subject: [PATCH 31/42] remove redundant import --- pykilosort/gui/data_view_box.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index c73b8c2..df19b74 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -2,6 +2,7 @@ import typing as t import pyqtgraph as pg from cupy import asnumpy +from datetime import datetime from pykilosort.gui.logger import setup_logger from pykilosort.gui.minor_gui_elements import controls_popup_text from pykilosort.gui.palettes import COLORMAP_COLORS From 9a11013e7f3b7014c9a696853706ce97c8c15213 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 18:26:15 -0400 Subject: [PATCH 32/42] code readability improvements --- pykilosort/gui/data_view_box.py | 49 +++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index df19b74..1ed34e9 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -226,46 +226,46 @@ def on_views_clicked(self): @QtCore.pyqtSlot(int) def on_wheel_scroll(self, direction): - if self.gui.context is not None: + if self.context_set(): self.shift_current_time(direction) @QtCore.pyqtSlot(int) def on_wheel_scroll_plus_control(self, direction): - if self.gui.context is not None: - if self.traces_button.isChecked(): + if self.context_set(): + if self.traces_mode_active(): self.shift_primary_channel(direction) else: self.change_displayed_channel_count(direction) @QtCore.pyqtSlot(int) def on_wheel_scroll_plus_shift(self, direction): - if self.gui.context is not None: + if self.context_set(): self.change_plot_range(direction) @QtCore.pyqtSlot(int) def on_wheel_scroll_plus_alt(self, direction): - if self.gui.context is not None: + if self.context_set(): self.change_plot_scaling(direction) def toggle_mode_from_click(self): - if self.traces_button.isChecked(): + if self.traces_mode_active(): self.modeChanged.emit("traces") self.view_buttons_group.setExclusive(False) self.update_plot() - if self.colormap_button.isChecked(): + if self.colormap_mode_active(): self.modeChanged.emit("colormap") self._traces_to_colormap_toggle() self.update_plot() def toggle_mode(self): - if self.colormap_button.isChecked(): + if self.colormap_mode_active(): self.traces_button.toggle() self.modeChanged.emit("traces") self.view_buttons_group.setExclusive(False) self.update_plot() - elif self.traces_button.isChecked(): + elif self.traces_mode_active(): self.colormap_button.toggle() self.modeChanged.emit("colormap") self._traces_to_colormap_toggle() @@ -302,6 +302,15 @@ def _traces_to_colormap_toggle(self): self.view_buttons_group.setExclusive(True) + def traces_mode_active(self): + return self.traces_button.isChecked() + + def colormap_mode_active(self): + return self.colormap_button.isChecked() + + def context_set(self): + return self.gui.context is not None + def change_primary_channel(self, channel): self.primary_channel = channel self.channelChanged.emit() @@ -317,7 +326,7 @@ def shift_primary_channel(self, shift): self.update_plot() def get_currently_displayed_channel_count(self): - if self.traces_button.isChecked(): + if self.traces_mode_active(): return self.channels_displayed_traces else: count = self.channels_displayed_colormap @@ -326,7 +335,7 @@ def get_currently_displayed_channel_count(self): return count def set_currently_displayed_channel_count(self, count): - if self.traces_button.isChecked(): + if self.traces_mode_active(): self.channels_displayed_traces = count else: self.channels_displayed_colormap = count @@ -372,14 +381,14 @@ def change_plot_range(self, direction): self.update_plot() def change_plot_scaling(self, direction): - if self.traces_button.isChecked(): + if self.traces_mode_active(): scale_factor = self.scale_factor * (1.1 ** direction) if 0.1 < scale_factor < 10.0: self.scale_factor = scale_factor self.update_plot() - if self.colormap_button.isChecked(): + if self.colormap_mode_active(): colormap_min = self.colormap_min + (direction * 0.05) colormap_max = self.colormap_max - (direction * 0.05) if 0.0 <= colormap_min < colormap_max <= 1.0: @@ -398,8 +407,8 @@ def change_channel_display(self, direction): self.change_displayed_channel_count(direction) def scene_clicked(self, ev): - if self.gui.context is not None: - if self.traces_button.isChecked(): + if self.context_set(): + if self.traces_mode_active(): x_pos = ev.pos().x() else: x_pos = self.colormap_image.mapFromScene(ev.pos()).x() @@ -412,7 +421,7 @@ def scene_clicked(self, ev): self.shift_current_time(direction=-1) def seek_clicked(self, ev): - if self.gui.context is not None: + if self.context_set(): new_time = self.seek_view_box.mapSceneToView(ev.pos()).x() seek_range_min = self.seek_range[0] seek_range_max = self.seek_range[1] @@ -446,7 +455,7 @@ def change_sorting_status(self, status_dict): self.enable_view_buttons() def enable_view_buttons(self): - if self.colormap_button.isChecked(): + if self.colormap_mode_active(): if self.prediction_button.isChecked() or self.residual_button.isChecked(): self.raw_button.click() else: @@ -533,13 +542,13 @@ def add_traces_to_plot_items( ): """ Update plot items with traces. - + Loops over traces and plots each trace using the setData() method of pyqtgraph's PlotCurveItem. The color of the trace depends on the mode requested (raw, whitened, prediction, residual). Bad channels are plotted in a different color. Each trace is also scaled by a certain factor defined in self.traces_scaling_factor. - + Parameters ---------- traces : numpy.ndarray @@ -666,7 +675,7 @@ def update_plot(self, context=None): end_time=end_time, ) - if self.colormap_button.isChecked(): + if self.colormap_mode_active(): self._update_colormap(params=params, probe=probe, raw_data=raw_data, From 6db8a12910292d635ce0523ee7768d4546ddf438 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 18:34:18 -0400 Subject: [PATCH 33/42] change data_view and probe_view sync signal * change signal signature --- pykilosort/gui/data_view_box.py | 17 ++++++++++------- pykilosort/gui/probe_view_box.py | 13 +++++++++---- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index 1ed34e9..11d5013 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -14,7 +14,7 @@ class DataViewBox(QtWidgets.QGroupBox): - channelChanged = QtCore.pyqtSignal() + channelChanged = QtCore.pyqtSignal(int, int) modeChanged = QtCore.pyqtSignal(str) def __init__(self, parent): @@ -313,7 +313,7 @@ def context_set(self): def change_primary_channel(self, channel): self.primary_channel = channel - self.channelChanged.emit() + self.channelChanged.emit(self.primary_channel, self.get_currently_displayed_channel_count()) self.update_plot() def shift_primary_channel(self, shift): @@ -322,7 +322,7 @@ def shift_primary_channel(self, shift): total_channels = self.get_total_channels() if (0 <= primary_channel < total_channels) and total_channels is not None: self.primary_channel = primary_channel - self.channelChanged.emit() + self.channelChanged.emit(self.primary_channel, self.get_currently_displayed_channel_count()) self.update_plot() def get_currently_displayed_channel_count(self): @@ -344,22 +344,25 @@ def get_total_channels(self): return self.gui.probe_view_box.total_channels def change_displayed_channel_count(self, direction): + current_channel = self.primary_channel total_channels = self.get_total_channels() current_count = self.get_currently_displayed_channel_count() new_count = current_count + (direction * 5) - if 0 < new_count <= total_channels: + if (current_channel + new_count) <= total_channels: self.set_currently_displayed_channel_count(new_count) + self.channelChanged.emit(self.primary_channel, new_count) self.refresh_plot_on_displayed_channel_count_change() elif new_count <= 0 and current_count != 1: self.set_currently_displayed_channel_count(1) + self.channelChanged.emit(self.primary_channel, 1) self.refresh_plot_on_displayed_channel_count_change() - elif new_count > total_channels: - self.set_currently_displayed_channel_count(total_channels) + elif (current_channel + new_count) > total_channels: + self.set_currently_displayed_channel_count(total_channels - current_channel) + self.channelChanged.emit(self.primary_channel, total_channels) self.refresh_plot_on_displayed_channel_count_change() def refresh_plot_on_displayed_channel_count_change(self): - self.channelChanged.emit() self.plot_item.clear() self.create_plot_items() self.update_plot() diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index fefcf62..0eb2215 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -153,10 +153,15 @@ def generate_spots_list(self): return spots - @QtCore.pyqtSlot() - def update_probe_view(self): - self.synchronize_primary_channel() - self.set_active_channels() + @QtCore.pyqtSlot(int, int) + def update_probe_view(self, primary_channel=None, channels_displayed=None): + if primary_channel is not None: + self.primary_channel = primary_channel + + if channels_displayed is None: + channels_displayed = self.total_channels + + self.set_active_channels(channels_displayed) self.create_plot() @QtCore.pyqtSlot(object) From a94a65b1f5a8be5d85bc655ed187ef4393c7cbd6 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 18:37:46 -0400 Subject: [PATCH 34/42] minor changes to data_view and probe_view sync * add missing pyqtSlot decorators * coda readability improvements --- pykilosort/gui/data_view_box.py | 1 + pykilosort/gui/probe_view_box.py | 12 ++++-------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index 11d5013..ff4079e 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -247,6 +247,7 @@ def on_wheel_scroll_plus_alt(self, direction): if self.context_set(): self.change_plot_scaling(direction) + @QtCore.pyqtSlot() def toggle_mode_from_click(self): if self.traces_mode_active(): self.modeChanged.emit("traces") diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index 0eb2215..c2eb5ba 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -116,16 +116,12 @@ def on_points_clicked(self, points): channel = self.channel_map[index] self.channelSelected.emit(channel) - def synchronize_data_view_mode(self, string): - old_mode = self.active_data_view_mode - self.active_data_view_mode = string - - if old_mode != self.active_data_view_mode and self.primary_channel is not None: + @QtCore.pyqtSlot(str) + def synchronize_data_view_mode(self, mode: str): + if self.active_data_view_mode != mode: self.probe_view.clear() self.update_probe_view() - - def synchronize_primary_channel(self): - self.primary_channel = self.gui.data_view_box.primary_channel + self.active_data_view_mode = mode def change_sorting_status(self, status_dict): self.sorting_status = status_dict From 76cd184ae723095252523e92bbefd368efa5a739 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 18:39:11 -0400 Subject: [PATCH 35/42] remove redundant function --- pykilosort/gui/data_view_box.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index ff4079e..be164c0 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -404,12 +404,6 @@ def change_plot_scaling(self, direction): self.update_plot() - def change_channel_display(self, direction): - if self.traces_button.isChecked(): - self.shift_primary_channel(direction) - else: - self.change_displayed_channel_count(direction) - def scene_clicked(self, ev): if self.context_set(): if self.traces_mode_active(): From f8f3451375af1d074086f9bb22231e2d5c1ae8ed Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 19:04:50 -0400 Subject: [PATCH 36/42] introduce chanMapBackup as a property of Probe * preserves original chanMap, which is overwritten during preprocess * use chanMapBackup to generate default good_channels arrays * remove risky usages of chanMap --- pykilosort/gui/data_view_box.py | 10 +++- pykilosort/gui/main.py | 10 ++-- pykilosort/gui/minor_gui_elements.py | 1 + pykilosort/gui/probe_view_box.py | 13 +++-- pykilosort/gui/sorter.py | 4 +- pykilosort/utils.py | 75 ++++++++++++++++++++++++++++ 6 files changed, 100 insertions(+), 13 deletions(-) diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index be164c0..adc94eb 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -630,11 +630,16 @@ def get_whitened_traces( raw_data=raw_data, params=params, probe=probe, nSkipCov=nSkipCov ) + good_channels = intermediate.igood.ravel() \ + if "igood" in intermediate \ + else np.ones_like(probe.chanMapBackup) + whitened_traces = filter_and_whiten( raw_traces=raw_traces, params=params, probe=probe, whitening_matrix=self.whitening_matrix, + good_channels=good_channels, ) return whitened_traces @@ -652,7 +657,10 @@ def update_plot(self, context=None): probe = context.probe raw_data = context.raw_data intermediate = context.intermediate - good_channels = intermediate.igood.ravel() + try: + good_channels = intermediate.igood.ravel() + except AttributeError: + good_channels = np.ones_like(probe.chanMapBackup, dtype=bool) sample_rate = raw_data.sample_rate diff --git a/pykilosort/gui/main.py b/pykilosort/gui/main.py index 4df7025..0c94a8c 100644 --- a/pykilosort/gui/main.py +++ b/pykilosort/gui/main.py @@ -16,7 +16,7 @@ ) from pykilosort.gui.logger import setup_logger from pykilosort.params import KilosortParams -from pykilosort.utils import Context, copy_bunch +from pykilosort.utils import Context, copy_bunch, extend_probe from PyQt5 import QtCore, QtGui, QtWidgets logger = setup_logger(__name__) @@ -260,14 +260,12 @@ def setup_context(self): self.context = Context(context_path=context_path) probe_layout = self.probe_layout - probe_layout.Nchan = len(probe_layout.chanMap) - self.context.probe = probe_layout - self.context.raw_probe = copy_bunch(probe_layout) + probe_layout.Nchan = self.num_channels + self.context.probe = extend_probe(probe_layout) + self.context.raw_probe = extend_probe(copy_bunch(probe_layout)) self.context.params = self.params self.context.raw_data = self.raw_data - self.context.intermediate.igood = np.ones_like(probe_layout.chanMap, dtype=bool) - self.context.load() @QtCore.pyqtSlot(object) diff --git a/pykilosort/gui/minor_gui_elements.py b/pykilosort/gui/minor_gui_elements.py index b4dd080..b6fd29d 100644 --- a/pykilosort/gui/minor_gui_elements.py +++ b/pykilosort/gui/minor_gui_elements.py @@ -199,6 +199,7 @@ def construct_probe(self): probe.yc = self.y_coords probe.kcoords = self.k_coords probe.chanMap = self.channel_map + probe.chanMapBackup = probe.chanMap.copy() probe.bad_channels = self.bad_channels probe.NchanTOT = len(self.x_coords) diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index c2eb5ba..eb3c8af 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -96,7 +96,7 @@ def set_layout(self, context): self.update_probe_view() - def set_active_layout(self, probe, good_channels): + def set_active_layout(self, probe, good_channels=None): self.active_layout = probe self.kcoords = self.active_layout.kcoords self.xc, self.yc = self.active_layout.xc, self.active_layout.yc @@ -105,7 +105,13 @@ def set_active_layout(self, probe, good_channels): self.channel_map_dict[(xc, yc)] = ind self.total_channels = self.active_layout.NchanTOT self.channel_map = self.active_layout.chanMap - self.good_channels = good_channels + if good_channels is None: + self.good_channels = np.ones_like( + self.active_layout.chanMapBackup, + dtype=bool + ) + else: + self.good_channels = good_channels def on_points_clicked(self, points): selected_point = points.ptsClicked[0] @@ -163,8 +169,7 @@ def update_probe_view(self, primary_channel=None, channels_displayed=None): @QtCore.pyqtSlot(object) def preview_probe(self, probe): self.probe_view.clear() - good_channels_dummy = np.ones_like(probe.chanMap, dtype=bool) - self.set_active_layout(probe, good_channels_dummy) + self.set_active_layout(probe) self.create_plot(connect=False) def create_plot(self, connect=True): diff --git a/pykilosort/gui/sorter.py b/pykilosort/gui/sorter.py index 644b118..9437f61 100644 --- a/pykilosort/gui/sorter.py +++ b/pykilosort/gui/sorter.py @@ -6,7 +6,7 @@ from PyQt5 import QtCore -def filter_and_whiten(raw_traces, params, probe, whitening_matrix): +def filter_and_whiten(raw_traces, params, probe, whitening_matrix, good_channels): sample_rate = params.fs high_pass_freq = params.fshigh low_pass_freq = params.fslow @@ -18,7 +18,7 @@ def filter_and_whiten(raw_traces, params, probe, whitening_matrix): filtered_data = gpufilter( buff=cp.asarray(raw_traces, dtype=np.float32), - chanMap=probe.chanMap, + chanMap=probe.chanMapBackup[good_channels], fs=sample_rate, fslow=low_pass_freq, fshigh=high_pass_freq, diff --git a/pykilosort/utils.py b/pykilosort/utils.py index df5f7d2..a16eea8 100644 --- a/pykilosort/utils.py +++ b/pykilosort/utils.py @@ -394,6 +394,7 @@ def load_probe(probe_path): probe.yc.append([pos[c][1] for c in ch]) probe.kcoords.append([cg for c in ch]) probe.chanMap = np.concatenate(probe.chanMap).ravel().astype(np.int32) + probe.chanMapBackup = probe.chanMap.copy() probe.xc = np.concatenate(probe.xc) probe.yc = np.concatenate(probe.yc) probe.kcoords = np.concatenate(probe.kcoords) @@ -406,6 +407,7 @@ def load_probe(probe_path): probe.yc = mat['ycoords'].ravel().astype(np.float64) probe.kcoords = mat.get('kcoords', np.zeros(nc)).ravel().astype(np.float64) probe.chanMap = (mat['chanMap'] - 1).ravel().astype(np.int32) # NOTE: 0-indexing in Python + probe.chanMapBackup = probe.chanMap.copy() probe.NchanTOT = len(probe.chanMap) # NOTE: should match the # of columns in the raw data for n in _required_keys: @@ -444,6 +446,79 @@ def create_prb(probe): return probe_prb +def extend_probe( + probe_layout: Bunch +) -> Bunch: + """ + Extend probe layout to account for extra num_channels. + + The probe layout selected by the user may have a different + number of channels than the requested number of channels in + the dataset. The function attempts to smartly extend the + layout of the probe to match the requested number of + channels. + + In case the requested number of channels is less than the + total channels on the layout, the original probe layout + is returned. + + Parameters + ---------- + probe_layout : Bunch + Input probe layout which might have to be extended. + + Returns + ------- + probe_layout : Bunch + Possibly extended probe layout. + + """ + if len(probe_layout["xc"]) >= probe_layout["Nchan"]: + # if the requested number of channels is less than the number of + # channels on the probe, return original probe + return probe_layout + else: + n_channels = probe_layout["Nchan"] + xc = probe_layout["xc"] + yc = probe_layout["yc"] + + # the assumption here is that the probe layout has a repetitive pattern + unique_x = np.sort(np.unique(probe_layout["xc"])) # unique values for the x-axis + unique_y = np.sort(np.unique(probe_layout["yc"])) # unique values for the y-axis + + kcoords = probe_layout["kcoords"] + chan_map = probe_layout["chanMap"] + + new_channels = n_channels - len(xc) # number of new channels to be added + + # create new properties + # this will probably break if new_channels > len(unique_x/y) + append_x = unique_x[-new_channels] + append_y = unique_y[-new_channels] + + append_chan_map = np.array([ + i for i in np.arange(chan_map[-1], + chan_map[-1] + new_channels) + ]) + 1 # to account for 0-ordering + append_kcoords = np.zeros(new_channels) + + # append new properties to existing properties + new_xc = np.append(xc, append_x) + new_yc = np.append(yc, append_y) + new_chan_map = np.append(chan_map, append_chan_map) + new_kcoords = np.append(kcoords, append_kcoords) + + # save new properties into probe layout + probe_layout["xc"] = new_xc + probe_layout["yc"] = new_yc + probe_layout["chanMap"] = new_chan_map + probe_layout["chanMapBackup"] = new_chan_map.copy() + probe_layout["kcoords"] = new_kcoords + probe_layout["NchanTOT"] = len(new_chan_map) + + return probe_layout + + def plot_dissimilarity_matrices(ccb, ccbsort, plot_widget): ccb = cp.asnumpy(ccb) ccbsort = cp.asnumpy(ccbsort) From c8565ef1f870712adb357b272cb32f758b2a872f Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 19:06:08 -0400 Subject: [PATCH 37/42] refactor add_traces_to_plot_items() --- pykilosort/gui/data_view_box.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index adc94eb..d37c89b 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -557,22 +557,17 @@ def add_traces_to_plot_items( view : str One of "raw", "whitened", "prediction" and "residual" views """ - for c, channel in enumerate( - range( - self.primary_channel, - self.primary_channel + self.channels_displayed_traces, - ) - ): + for c, good in enumerate(good_channels): try: curve = self.traces_plot_items[view][c] color = ( self.traces_curve_color[view] - if good_channels[channel] + if good else self.bad_channel_color ) curve.setPen(color=color, width=1) curve.setData( - traces.T[channel] * + traces[:, c] * self.scale_factor * self.traces_scaling_factor[view] ) From 63dca73f7982cd3add1a1e076062a4dfa9e7d710 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 19:08:35 -0400 Subject: [PATCH 38/42] same behaviour for ctrl+up/down & ctrl+scroll --- pykilosort/gui/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pykilosort/gui/main.py b/pykilosort/gui/main.py index 0c94a8c..98cf7e6 100644 --- a/pykilosort/gui/main.py +++ b/pykilosort/gui/main.py @@ -167,7 +167,7 @@ def setup(self): def change_channel_display(self, direction): if self.context is not None: - self.data_view_box.change_channel_display(direction) + self.data_view_box.shift_primary_channel(direction) def shift_data(self, time_shift): if self.context is not None: From f375c77119339b55ee0e5d31f85d9a5a84b7393a Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 19:10:05 -0400 Subject: [PATCH 39/42] code refactor --- pykilosort/gui/probe_view_box.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index eb3c8af..582dede 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -46,7 +46,7 @@ def __init__(self, parent): } self.active_data_view_mode = "colormap" - self.primary_channel = None + self.primary_channel = 0 self.active_channels = [] def setup(self): @@ -62,20 +62,13 @@ def setup(self): layout.addWidget(self.probe_view, 95) self.setLayout(layout) - def set_active_channels(self): - if self.active_data_view_mode == "traces": - displayed_channels = self.gui.data_view_box.channels_displayed_traces - else: - displayed_channels = self.gui.data_view_box.channels_displayed_colormap - if displayed_channels is None: - displayed_channels = self.total_channels - + def set_active_channels(self, channels_displayed): primary_channel = self.primary_channel channel_map = np.array(self.channel_map) primary_channel_position = int(np.where(channel_map == primary_channel)[0]) end_channel_position = np.where( - channel_map == primary_channel + displayed_channels + channel_map == primary_channel + channels_displayed )[0] # prevent the last displayed channel would be set as the end channel in the case that # `primary_channel + displayed_channels` exceeds the total number of channels in the channel map @@ -90,7 +83,10 @@ def set_active_channels(self): def set_layout(self, context): self.probe_view.clear() probe = context.raw_probe - good_channels = context.intermediate.igood + try: + good_channels = context.intermediate.igood + except AttributeError: + good_channels = None self.set_active_layout(probe, good_channels) @@ -184,7 +180,7 @@ def reset(self): self.clear_plot() self.reset_current_probe_layout() self.reset_active_data_view_mode() - self.primary_channel = None + self.primary_channel = 0 def reset_active_data_view_mode(self): self.active_data_view_mode = "colormap" From 86b8d585e32bc7487e51d7ae3747761a4e36fe0a Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 19:13:57 -0400 Subject: [PATCH 40/42] major refactor of plotting code * correctly account for good channels and determine channels to be displayed --- pykilosort/gui/data_view_box.py | 93 ++++++++++++++++++++++++--------- 1 file changed, 68 insertions(+), 25 deletions(-) diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index d37c89b..34e0b55 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -643,7 +643,7 @@ def update_plot(self, context=None): if context is None: context = self.gui.context - if context is not None: + if context is not None: # since context may still be None if self.colormap_image is not None: self.plot_item.removeItem(self.colormap_image) self.colormap_image = None @@ -666,10 +666,27 @@ def update_plot(self, context=None): self.data_view_widget.setXRange(0, time_range, padding=0.0) self.data_view_widget.setLimits(xMin=0, xMax=time_range) - if self.traces_button.isChecked(): + orig_chan_map = probe.chanMapBackup + max_channels = np.size(orig_chan_map) + + start_channel = self.primary_channel + active_channels = self.get_currently_displayed_channel_count() + if active_channels is None: + active_channels = self.get_total_channels() + end_channel = start_channel + active_channels + + # good channels after start_channel to display + to_display = np.arange(start_channel, end_channel, dtype=int) + to_display = to_display[to_display < max_channels] + + raw_traces = raw_data[start_time:end_time] + + if self.traces_mode_active(): self._update_traces(params=params, probe=probe, raw_data=raw_data, + raw_traces=raw_traces, + to_display=to_display, intermediate=intermediate, good_channels=good_channels, start_time=start_time, @@ -680,6 +697,8 @@ def update_plot(self, context=None): self._update_colormap(params=params, probe=probe, raw_data=raw_data, + raw_traces=raw_traces, + to_display=to_display, intermediate=intermediate, good_channels=good_channels, start_time=start_time, @@ -697,15 +716,24 @@ def update_plot(self, context=None): self.data_view_widget.autoRange() - def _update_traces(self, params, probe, raw_data, intermediate, good_channels, start_time, end_time): + def _update_traces( + self, + params, + probe, + raw_data, + raw_traces, + to_display, + intermediate, + good_channels, + start_time, + end_time + ): self.hide_inactive_traces() - raw_traces = raw_data[start_time:end_time] - if self.raw_button.isChecked(): self.add_traces_to_plot_items( - traces=raw_traces, - good_channels=good_channels, + traces=raw_traces[:, to_display], + good_channels=good_channels[to_display], view="raw", ) @@ -724,9 +752,11 @@ def _update_traces(self, params, probe, raw_data, intermediate, good_channels, s else: whitened_traces = self.whitened_traces + processed_traces = np.zeros_like(raw_traces, dtype=np.int16) + processed_traces[:, good_channels] = whitened_traces self.add_traces_to_plot_items( - traces=whitened_traces, - good_channels=good_channels, + traces=processed_traces[:, to_display], + good_channels=good_channels[to_display], view="whitened", ) @@ -742,9 +772,11 @@ def _update_traces(self, params, probe, raw_data, intermediate, good_channels, s else: prediction_traces = self.prediction_traces + processed_traces = np.zeros_like(raw_traces, dtype=np.int16) + processed_traces[:, good_channels] = prediction_traces self.add_traces_to_plot_items( - traces=prediction_traces, - good_channels=good_channels, + traces=processed_traces[:, to_display], + good_channels=good_channels[to_display], view="prediction", ) @@ -784,27 +816,35 @@ def _update_traces(self, params, probe, raw_data, intermediate, good_channels, s else: residual_traces = self.residual_traces + processed_traces = np.zeros_like(raw_traces, dtype=np.int16) + processed_traces[:, good_channels] = residual_traces self.add_traces_to_plot_items( - traces=residual_traces, - good_channels=good_channels, + traces=processed_traces[:, to_display], + good_channels=good_channels[to_display], view="residual", ) - def _update_colormap(self, params, probe, raw_data, intermediate, good_channels, start_time, end_time): + def _update_colormap( + self, + params, + probe, + raw_data, + raw_traces, + to_display, + intermediate, + good_channels, + start_time, + end_time + ): self.hide_traces() - raw_traces = raw_data[start_time:end_time] - - start_channel = self.primary_channel - displayed_channels = self.channels_displayed_colormap - if displayed_channels is None: - displayed_channels = self.get_total_channels() - end_channel = start_channel + displayed_channels + processed_traces = np.zeros_like(raw_traces, dtype=np.int16) if self.raw_button.isChecked(): colormap_min, colormap_max = -32.0, 32.0 + self.add_image_to_plot( - raw_traces[:, start_channel:end_channel], + raw_traces[:, to_display], colormap_min, colormap_max, ) @@ -824,8 +864,9 @@ def _update_colormap(self, params, probe, raw_data, intermediate, good_channels, else: whitened_traces = self.whitened_traces colormap_min, colormap_max = -4.0, 4.0 + processed_traces[:, good_channels] = whitened_traces self.add_image_to_plot( - whitened_traces[:, start_channel:end_channel], + processed_traces[:, to_display], colormap_min, colormap_max, ) @@ -842,8 +883,9 @@ def _update_colormap(self, params, probe, raw_data, intermediate, good_channels, else: prediction_traces = self.prediction_traces colormap_min, colormap_max = -4.0, 4.0 + processed_traces[:, good_channels] = prediction_traces self.add_image_to_plot( - prediction_traces[:, start_channel:end_channel], + processed_traces[:, to_display], colormap_min, colormap_max, ) @@ -884,8 +926,9 @@ def _update_colormap(self, params, probe, raw_data, intermediate, good_channels, else: residual_traces = self.residual_traces colormap_min, colormap_max = -4.0, 4.0 + processed_traces[:, good_channels] = residual_traces self.add_image_to_plot( - residual_traces[:, start_channel:end_channel], + processed_traces[:, to_display], colormap_min, colormap_max, ) From e6160465883b209d4636b0ecc5ad5f01aa1c42e1 Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 22:04:15 -0400 Subject: [PATCH 41/42] minor bug fix --- pykilosort/gui/data_view_box.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index 34e0b55..57e6319 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -627,7 +627,7 @@ def get_whitened_traces( good_channels = intermediate.igood.ravel() \ if "igood" in intermediate \ - else np.ones_like(probe.chanMapBackup) + else np.ones_like(probe.chanMapBackup, dtype=bool) whitened_traces = filter_and_whiten( raw_traces=raw_traces, From 24975df54e22896b6e142fccf79a4f110bd878ff Mon Sep 17 00:00:00 2001 From: Shashwat S Date: Tue, 25 May 2021 22:12:29 -0400 Subject: [PATCH 42/42] synchronize displayed channels across probe view and data view --- pykilosort/gui/data_view_box.py | 10 +++++----- pykilosort/gui/probe_view_box.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pykilosort/gui/data_view_box.py b/pykilosort/gui/data_view_box.py index 57e6319..859a5bd 100644 --- a/pykilosort/gui/data_view_box.py +++ b/pykilosort/gui/data_view_box.py @@ -15,7 +15,7 @@ class DataViewBox(QtWidgets.QGroupBox): channelChanged = QtCore.pyqtSignal(int, int) - modeChanged = QtCore.pyqtSignal(str) + modeChanged = QtCore.pyqtSignal(str, int) def __init__(self, parent): QtWidgets.QGroupBox.__init__(self, parent=parent) @@ -250,25 +250,25 @@ def on_wheel_scroll_plus_alt(self, direction): @QtCore.pyqtSlot() def toggle_mode_from_click(self): if self.traces_mode_active(): - self.modeChanged.emit("traces") + self.modeChanged.emit("traces", self.get_currently_displayed_channel_count()) self.view_buttons_group.setExclusive(False) self.update_plot() if self.colormap_mode_active(): - self.modeChanged.emit("colormap") + self.modeChanged.emit("colormap", self.get_currently_displayed_channel_count()) self._traces_to_colormap_toggle() self.update_plot() def toggle_mode(self): if self.colormap_mode_active(): self.traces_button.toggle() - self.modeChanged.emit("traces") + self.modeChanged.emit("traces", self.get_currently_displayed_channel_count()) self.view_buttons_group.setExclusive(False) self.update_plot() elif self.traces_mode_active(): self.colormap_button.toggle() - self.modeChanged.emit("colormap") + self.modeChanged.emit("colormap", self.get_currently_displayed_channel_count()) self._traces_to_colormap_toggle() self.update_plot() diff --git a/pykilosort/gui/probe_view_box.py b/pykilosort/gui/probe_view_box.py index 582dede..960031e 100644 --- a/pykilosort/gui/probe_view_box.py +++ b/pykilosort/gui/probe_view_box.py @@ -118,11 +118,11 @@ def on_points_clicked(self, points): channel = self.channel_map[index] self.channelSelected.emit(channel) - @QtCore.pyqtSlot(str) - def synchronize_data_view_mode(self, mode: str): + @QtCore.pyqtSlot(str, int) + def synchronize_data_view_mode(self, mode: str, channels_displayed: int): if self.active_data_view_mode != mode: self.probe_view.clear() - self.update_probe_view() + self.update_probe_view(channels_displayed=channels_displayed) self.active_data_view_mode = mode def change_sorting_status(self, status_dict):