Skip to content

Commit

Permalink
Merge pull request #5 from shashwatsridhar/fix/#43_speed_up_whitening
Browse files Browse the repository at this point in the history
Fix/MouseLand#43 speed up whitening
  • Loading branch information
shashwatsridhar authored May 26, 2021
2 parents 5eee5d5 + d16ad50 commit 12afe21
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 42 deletions.
126 changes: 84 additions & 42 deletions pykilosort/gui/data_view_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
import typing as t
import pyqtgraph as pg
from cupy import asnumpy
from datetime import datetime
from pykilosort.gui.logger import setup_logger
from pykilosort.gui.minor_gui_elements import controls_popup_text
from pykilosort.gui.palettes import COLORMAP_COLORS
from pykilosort.gui.sorter import filter_and_whiten, get_predicted_traces
from pykilosort.preprocess import get_whitening_matrix
from pykilosort.preprocess import get_approx_whitening_matrix
from PyQt5 import QtCore, QtWidgets

logger = setup_logger(__name__)
Expand Down Expand Up @@ -111,6 +109,8 @@ def __init__(self, parent):
self.colormap_min, self.colormap_max
)

self.thread_pool = QtCore.QThreadPool()

self.setup()

def setup(self):
Expand Down Expand Up @@ -452,6 +452,7 @@ def change_sorting_status(self, status_dict):
self.sorting_status = status_dict
self.enable_view_buttons()

@QtCore.pyqtSlot()
def enable_view_buttons(self):
if self.colormap_mode_active():
if self.prediction_button.isChecked() or self.residual_button.isChecked():
Expand All @@ -462,6 +463,17 @@ def enable_view_buttons(self):
if self.residual_button.isChecked():
self.residual_button.click()

if self.whitening_matrix is not None:
self.whitened_button.setEnabled(True)
self.whitened_button.setStyleSheet(
"QPushButton {background-color: black; color: white;}"
)
else:
self.whitened_button.setDisabled(True)
self.whitened_button.setStyleSheet(
"QPushButton {background-color: black; color: gray;}"
)

if self.sorting_status["preprocess"] and self.sorting_status["spikesort"]:
self.prediction_button.setEnabled(True)
self.prediction_button.setStyleSheet(
Expand Down Expand Up @@ -614,30 +626,36 @@ def add_image_to_plot(self, raw_traces, level_min, level_max):
self.colormap_image = image_item
self.plot_item.addItem(image_item)

def get_whitened_traces(
self, raw_data, raw_traces, intermediate, params, probe, nSkipCov=None
):
@QtCore.pyqtSlot(object)
def set_whitening_matrix(self, array):
self.whitening_matrix = array

def calculate_approx_whitening_matrix(self, context):
raw_data = context.raw_data
params = context.params
probe = context.probe
intermediate = context.intermediate

@QtCore.pyqtSlot()
def _call_enable_buttons():
self.enable_view_buttons()

if "Wrot" in intermediate and self.whitening_matrix is None:
self.whitening_matrix = intermediate.Wrot

elif self.whitening_matrix is None:
self.whitening_matrix = get_whitening_matrix(
raw_data=raw_data, params=params, probe=probe, nSkipCov=nSkipCov
logger.info("Approx. whitening matrix loaded from existing context.")
_call_enable_buttons()

elif (self.whitening_matrix is None) and not (self.thread_pool.activeThreadCount() > 0):
whitening_worker = WhiteningMatrixCalculator(
raw_data=raw_data,
params=params,
probe=probe
)

good_channels = intermediate.igood.ravel() \
if "igood" in intermediate \
else np.ones_like(probe.chanMapBackup, dtype=bool)

whitened_traces = filter_and_whiten(
raw_traces=raw_traces,
params=params,
probe=probe,
whitening_matrix=self.whitening_matrix,
good_channels=good_channels,
)
whitening_worker.signals.result.connect(self.set_whitening_matrix)
whitening_worker.signals.finished.connect(_call_enable_buttons)

return whitened_traces
self.thread_pool.start(whitening_worker)

def update_plot(self, context=None):
if context is None:
Expand Down Expand Up @@ -684,7 +702,6 @@ def update_plot(self, context=None):
if self.traces_mode_active():
self._update_traces(params=params,
probe=probe,
raw_data=raw_data,
raw_traces=raw_traces,
to_display=to_display,
intermediate=intermediate,
Expand All @@ -696,7 +713,6 @@ def update_plot(self, context=None):
if self.colormap_mode_active():
self._update_colormap(params=params,
probe=probe,
raw_data=raw_data,
raw_traces=raw_traces,
to_display=to_display,
intermediate=intermediate,
Expand All @@ -720,7 +736,6 @@ def _update_traces(
self,
params,
probe,
raw_data,
raw_traces,
to_display,
intermediate,
Expand All @@ -739,13 +754,12 @@ def _update_traces(

if self.whitened_button.isChecked():
if self.whitened_traces is None:
whitened_traces = self.get_whitened_traces(
raw_data=raw_data,
whitened_traces = filter_and_whiten(
raw_traces=raw_traces,
intermediate=intermediate,
params=params,
probe=probe,
nSkipCov=100,
whitening_matrix=self.whitening_matrix,
good_channels=good_channels,
)

self.whitened_traces = whitened_traces
Expand Down Expand Up @@ -783,13 +797,12 @@ def _update_traces(
if self.residual_button.isChecked():
if self.residual_traces is None:
if self.whitened_traces is None:
whitened_traces = self.get_whitened_traces(
raw_data=raw_data,
whitened_traces = filter_and_whiten(
raw_traces=raw_traces,
intermediate=intermediate,
params=params,
probe=probe,
nSkipCov=100,
whitening_matrix=self.whitening_matrix,
good_channels=good_channels,
)

self.whitened_traces = whitened_traces
Expand Down Expand Up @@ -828,7 +841,6 @@ def _update_colormap(
self,
params,
probe,
raw_data,
raw_traces,
to_display,
intermediate,
Expand All @@ -851,13 +863,12 @@ def _update_colormap(

elif self.whitened_button.isChecked():
if self.whitened_traces is None:
whitened_traces = self.get_whitened_traces(
raw_data=raw_data,
whitened_traces = filter_and_whiten(
raw_traces=raw_traces,
intermediate=intermediate,
params=params,
probe=probe,
nSkipCov=100,
whitening_matrix=self.whitening_matrix,
good_channels=good_channels,
)

self.whitened_traces = whitened_traces
Expand Down Expand Up @@ -893,13 +904,12 @@ def _update_colormap(
elif self.residual_button.isChecked():
if self.residual_traces is None:
if self.whitened_traces is None:
whitened_traces = self.get_whitened_traces(
raw_data=raw_data,
whitened_traces = filter_and_whiten(
raw_traces=raw_traces,
intermediate=intermediate,
params=params,
probe=probe,
nSkipCov=100,
whitening_matrix=self.whitening_matrix,
good_channels=good_channels,
)

self.whitened_traces = whitened_traces
Expand Down Expand Up @@ -977,3 +987,35 @@ def wheelEvent(self, ev):

def mouseMoveEvent(self, ev):
pass


class WhiteningMatrixCalculator(QtCore.QRunnable):

def __init__(self, raw_data, probe, params):
super(WhiteningMatrixCalculator, self).__init__()
self.raw_data = raw_data
self.params = params
self.probe = probe

self.signals = CalculatorSignals()

def run(self):
try:
logger.info("Calculating approx. whitening matrix.")
whitening_matrix = get_approx_whitening_matrix(
raw_data=self.raw_data,
params=self.params,
probe=self.probe,
)
except Exception as e:
logger.error(e)
else:
logger.info("Approx. whitening matrix calculated.")
self.signals.result.emit(whitening_matrix)
finally:
self.signals.finished.emit()


class CalculatorSignals(QtCore.QObject):
finished = QtCore.pyqtSignal()
result = QtCore.pyqtSignal(object)
1 change: 1 addition & 0 deletions pykilosort/gui/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def load_raw_data(self):
self.raw_data = raw_data

def setup_data_view(self):
self.data_view_box.calculate_approx_whitening_matrix(self.context)
self.data_view_box.setup_seek(self.context)
self.data_view_box.create_plot_items()
self.data_view_box.update_plot(self.context)
Expand Down
3 changes: 3 additions & 0 deletions pykilosort/gui/probe_view_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def synchronize_data_view_mode(self, mode: str, channels_displayed: int):
def change_sorting_status(self, status_dict):
self.sorting_status = status_dict

def change_sorting_status(self, status_dict):
self.sorting_status = status_dict

def generate_spots_list(self):
spots = []
size = 10
Expand Down
61 changes: 61 additions & 0 deletions pykilosort/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,67 @@ def get_whitening_matrix(raw_data=None, probe=None, params=None, nSkipCov=None):
return Wrot


def get_approx_whitening_matrix(raw_data, params, probe):
Nbatch = get_Nbatch(raw_data, params)
max_batches = 2
if Nbatch < max_batches:
batches = np.arange(Nbatch)
else:
batches = np.random.choice(Nbatch, size=max_batches)

ntbuff = params.ntbuff
NTbuff = params.NTbuff
whiteningRange = params.whiteningRange
scaleproc = params.scaleproc
NT = params.NT
fs = params.fs
fshigh = params.fshigh

xc = probe.xc
yc = probe.yc
chanMap = probe.chanMap
Nchan = probe.Nchan

# Nchan is obtained after the bad channels have been removed
CC = cp.zeros((Nchan, Nchan))

for ibatch in batches:
i = max(0, (NT - ntbuff) * ibatch - 2 * ntbuff)
# WARNING: we no longer use Fortran order, so raw_data is nsamples x NchanTOT
buff = raw_data[i:i + NT - ntbuff]
assert buff.shape[0] > buff.shape[1]
assert buff.flags.c_contiguous

nsampcurr = buff.shape[0]
if nsampcurr < NTbuff:
buff = np.concatenate(
(buff, np.tile(buff[nsampcurr - 1], (NTbuff, 1))), axis=0)

buff_g = cp.asarray(buff, dtype=np.float32)

# apply filters and median subtraction
datr = gpufilter(buff_g, fs=fs, fshigh=fshigh, chanMap=chanMap)
assert datr.flags.c_contiguous

CC = CC + cp.dot(datr.T, datr) / NT # sample covariance

CC = CC / batches.size

if whiteningRange < np.inf:
# if there are too many channels, a finite whiteningRange is more robust to noise
# in the estimation of the covariance
whiteningRange = min(whiteningRange, Nchan)
# this function performs the same matrix inversions as below, just on subsets of
# channels around each channel
Wrot = whiteningLocal(CC, yc, xc, whiteningRange)
else:
Wrot = whiteningFromCovariance(CC)

Wrot = Wrot * scaleproc

return Wrot


def get_good_channels(raw_data=None, probe=None, params=None):
"""
of the channels indicated by the user as good (chanMap)
Expand Down

0 comments on commit 12afe21

Please sign in to comment.