diff --git a/caiman/base/rois.py b/caiman/base/rois.py index fddd1471f..d1c11c374 100644 --- a/caiman/base/rois.py +++ b/caiman/base/rois.py @@ -20,7 +20,7 @@ from typing import Any, Optional import zipfile -from caiman.motion_correction import tile_and_correct +from caiman.motion_correction import tile_and_correct, get_patch_centers, interpolate_shifts try: cv2.setNumThreads(0) @@ -318,7 +318,8 @@ def register_ROIs(A1, print_assignment=False, plot_results=False, Cn=None, - cmap='viridis'): + cmap='viridis', + align_options: Optional[dict] = None): """ Register ROIs across different sessions using an intersection over union metric and the Hungarian algorithm for optimal matching @@ -372,6 +373,9 @@ def register_ROIs(A1, cmap: string colormap for background image + + align_options: Optional[dict] + mcorr options to override defaults when align_flag is True and use_opt_flow is False Returns: matched_ROIs1: list @@ -414,25 +418,50 @@ def register_ROIs(A1, if use_opt_flow: template1_norm = np.uint8(template1 * (template1 > 0) * 255) template2_norm = np.uint8(template2 * (template2 > 0) * 255) - flow = cv2.calcOpticalFlowFarneback(np.uint8(template1_norm * 255), np.uint8(template2_norm * 255), None, + flow = cv2.calcOpticalFlowFarneback(template1_norm, template2_norm, None, 0.5, 3, 128, 3, 7, 1.5, 0) x_remap = (flow[:, :, 0] + x_grid).astype(np.float32) y_remap = (flow[:, :, 1] + y_grid).astype(np.float32) else: - template2, shifts, _, xy_grid = tile_and_correct(template2, - template1 - template1.min(), - [int(dims[0] / 4), int(dims[1] / 4)], [16, 16], [10, 10], - add_to_movie=template2.min(), - shifts_opencv=True) - - dims_grid = tuple(np.max(np.stack(xy_grid, axis=0), axis=0) - np.min(np.stack(xy_grid, axis=0), axis=0) + 1) - _sh_ = np.stack(shifts, axis=0) - shifts_x = np.reshape(_sh_[:, 1], dims_grid, order='C').astype(np.float32) - shifts_y = np.reshape(_sh_[:, 0], dims_grid, order='C').astype(np.float32) - - x_remap = (-np.resize(shifts_x, dims) + x_grid).astype(np.float32) - y_remap = (-np.resize(shifts_y, dims) + y_grid).astype(np.float32) + align_defaults = { + "strides": (int(dims[0] / 4), int(dims[1] / 4)), + "overlaps": (16, 16), + "max_shifts": (10, 10), + "shifts_opencv": True, + "upsample_factor_grid": 4, + "shifts_interpolate": True, + "max_deviation_rigid": 2 + # any other argument to tile_and_correct can also be used in align_options + } + + if align_options: + # override defaults with input options + align_defaults.update(align_options) + align_options = align_defaults + + template2, shifts, _, _ = tile_and_correct(template2, template1 - template1.min(), + add_to_movie=template2.min(), **align_options) + + if align_options["max_deviation_rigid"] == 0: + # repeat rigid shifts to size of the image + shifts_x_full = np.full(dims, -shifts[1]) + shifts_y_full = np.full(dims, -shifts[0]) + else: + # piecewise - interpolate from patches to get shifts per pixel + patch_centers = get_patch_centers(dims, overlaps=align_options["overlaps"], strides=align_options["strides"], + shifts_opencv=align_options["shifts_opencv"], + upsample_factor_grid=align_options["upsample_factor_grid"]) + patch_grid = tuple(len(centers) for centers in patch_centers) + _sh_ = np.stack(shifts, axis=0) + shifts_x = np.reshape(_sh_[:, 1], patch_grid, order='C').astype(np.float32) + shifts_y = np.reshape(_sh_[:, 0], patch_grid, order='C').astype(np.float32) + + shifts_x_full = interpolate_shifts(-shifts_x, patch_centers, tuple(range(d) for d in dims)) + shifts_y_full = interpolate_shifts(-shifts_y, patch_centers, tuple(range(d) for d in dims)) + + x_remap = (shifts_x_full + x_grid).astype(np.float32) + y_remap = (shifts_y_full + y_grid).astype(np.float32) A_2t = np.reshape(A2, dims + (-1,), order='F').transpose(2, 0, 1) A2 = np.stack([cv2.remap(img.astype(np.float32), x_remap, y_remap, cv2.INTER_NEAREST) for img in A_2t], axis=0) diff --git a/caiman/motion_correction.py b/caiman/motion_correction.py index 0ef916b49..26ce39389 100644 --- a/caiman/motion_correction.py +++ b/caiman/motion_correction.py @@ -46,12 +46,12 @@ from numpy.fft import ifftshift import os import scipy -from skimage.transform import resize as resize_sk -from skimage.transform import warp as warp_sk +import scipy.interpolate +from skimage.transform import resize as resize_sk, rescale as rescale_sk, warp as warp_sk import sys import tifffile from tqdm import tqdm -from typing import Optional +from typing import Optional, Literal, Union import caiman import caiman.base.movies @@ -78,7 +78,7 @@ def __init__(self, fname, min_mov=None, dview=None, max_shifts=(6, 6), niter_rig strides=(96, 96), overlaps=(32, 32), splits_els=14, num_splits_to_process_els=None, upsample_factor_grid=4, max_deviation_rigid=3, shifts_opencv=True, nonneg_movie=True, gSig_filt=None, use_cuda=False, border_nan=True, pw_rigid=False, num_frames_split=80, var_name_hdf5='mov', is3D=False, - indices=(slice(None), slice(None))): + indices=(slice(None), slice(None)), shifts_interpolate=False): """ Constructor class for motion correction operations @@ -156,6 +156,9 @@ def __init__(self, fname, min_mov=None, dview=None, max_shifts=(6, 6), niter_rig indices: tuple(slice), default: (slice(None), slice(None)) Use that to apply motion correction only on a part of the FOV + + shifts_interpolate: bool, default: False + use patch locations to interpolate shifts rather than just upscaling to size of image (for pw_rigid only) Returns: self @@ -197,6 +200,7 @@ def __init__(self, fname, min_mov=None, dview=None, max_shifts=(6, 6), niter_rig self.var_name_hdf5 = var_name_hdf5 self.is3D = bool(is3D) self.indices = indices + self.shifts_interpolate = shifts_interpolate if self.use_cuda: logger.warn("cuda is no longer supported; this kwarg will be removed in a future version of caiman") @@ -253,7 +257,7 @@ def motion_correct(self, template=None, save_movie=False): self.mmap_file = self.fname_tot_els if self.pw_rigid else self.fname_tot_rig return self - def motion_correct_rigid(self, template=None, save_movie=False) -> None: + def motion_correct_rigid(self, template: Optional[np.ndarray] = None, save_movie=False) -> None: """ Perform rigid motion correction @@ -300,7 +304,8 @@ def motion_correct_rigid(self, template=None, save_movie=False) -> None: border_nan=self.border_nan, var_name_hdf5=self.var_name_hdf5, is3D=self.is3D, - indices=self.indices) + indices=self.indices, + shifts_interpolate=self.shifts_interpolate) if template is None: self.total_template_rig = _total_template_rig @@ -308,7 +313,7 @@ def motion_correct_rigid(self, template=None, save_movie=False) -> None: self.fname_tot_rig += [_fname_tot_rig] self.shifts_rig += _shifts_rig - def motion_correct_pwrigid(self, save_movie:bool=True, template:np.ndarray=None, show_template:bool=False) -> None: + def motion_correct_pwrigid(self, save_movie:bool=True, template: Optional[np.ndarray] = None, show_template:bool=False) -> None: """Perform pw-rigid motion correction Args: @@ -361,7 +366,7 @@ def motion_correct_pwrigid(self, save_movie:bool=True, template:np.ndarray=None, num_splits_to_process=None, num_iter=num_iter, template=self.total_template_els, shifts_opencv=self.shifts_opencv, save_movie=save_movie, nonneg_movie=self.nonneg_movie, gSig_filt=self.gSig_filt, use_cuda=self.use_cuda, border_nan=self.border_nan, var_name_hdf5=self.var_name_hdf5, is3D=self.is3D, - indices=self.indices) + indices=self.indices, shifts_interpolate=self.shifts_interpolate) if not self.is3D: if show_template: plt.imshow(new_template_els) @@ -381,8 +386,8 @@ def motion_correct_pwrigid(self, save_movie:bool=True, template:np.ndarray=None, self.z_shifts_els += _z_shifts_els self.coord_shifts_els += _coord_shifts_els - def apply_shifts_movie(self, fname, rigid_shifts:bool=None, save_memmap:bool=False, - save_base_name:str='MC', order:str='F', remove_min:bool=True): + def apply_shifts_movie(self, fname, rigid_shifts: Optional[bool] = None, save_memmap:bool=False, + save_base_name:str='MC', order: Literal['C', 'F'] = 'F', remove_min:bool=True): """ Applies shifts found by registering one file to a different file. Useful for cases when shifts computed from a structural channel are applied to a @@ -441,81 +446,27 @@ def apply_shifts_movie(self, fname, rigid_shifts:bool=None, save_memmap:bool=Fal sh[0], sh[1]), 0, is_freq=False, border_nan=self.border_nan) for img, sh in zip( Y, self.shifts_rig)] else: + # take potential upsampling into account when recreating patch grid + dims = Y.shape[1:] + patch_centers = get_patch_centers(dims, overlaps=self.overlaps, strides=self.strides, + shifts_opencv=self.shifts_opencv, upsample_factor_grid=self.upsample_factor_grid) if self.is3D: - xyz_grid = [(it[0], it[1], it[2]) for it in sliding_window_3d( - Y[0], self.overlaps, self.strides)] - dims_grid = tuple(np.add(xyz_grid[-1], 1)) - shifts_x = np.stack([np.reshape(_sh_, dims_grid, order='C').astype( - np.float32) for _sh_ in self.x_shifts_els], axis=0) - shifts_y = np.stack([np.reshape(_sh_, dims_grid, order='C').astype( - np.float32) for _sh_ in self.y_shifts_els], axis=0) - shifts_z = np.stack([np.reshape(_sh_, dims_grid, order='C').astype( - np.float32) for _sh_ in self.z_shifts_els], axis=0) - dims = Y.shape[1:] - x_grid, y_grid, z_grid = np.meshgrid(np.arange(0., dims[1]).astype( - np.float32), np.arange(0., dims[0]).astype(np.float32), - np.arange(0., dims[2]).astype(np.float32)) - if self.border_nan is not False: - if self.border_nan is True: - m_reg = [warp_sk(img, np.stack((resize_sk(shiftX.astype(np.float32), dims) + y_grid, - resize_sk(shiftY.astype(np.float32), dims) + x_grid, - resize_sk(shiftZ.astype(np.float32), dims) + z_grid), axis=0), - order=3, mode='constant', cval=np.nan) - for img, shiftX, shiftY, shiftZ in zip(Y, shifts_x, shifts_y, shifts_z)] - elif self.border_nan == 'min': - m_reg = [warp_sk(img, np.stack((resize_sk(shiftX.astype(np.float32), dims) + y_grid, - resize_sk(shiftY.astype(np.float32), dims) + x_grid, - resize_sk(shiftZ.astype(np.float32), dims) + z_grid), axis=0), - order=3, mode='constant', cval=np.min(img)) - for img, shiftX, shiftY, shiftZ in zip(Y, shifts_x, shifts_y, shifts_z)] - elif self.border_nan == 'copy': - m_reg = [warp_sk(img, np.stack((resize_sk(shiftX.astype(np.float32), dims) + y_grid, - resize_sk(shiftY.astype(np.float32), dims) + x_grid, - resize_sk(shiftZ.astype(np.float32), dims) + z_grid), axis=0), - order=3, mode='edge') - for img, shiftX, shiftY, shiftZ in zip(Y, shifts_x, shifts_y, shifts_z)] - else: - m_reg = [warp_sk(img, np.stack((resize_sk(shiftX.astype(np.float32), dims) + y_grid, - resize_sk(shiftY.astype(np.float32), dims) + x_grid, - resize_sk(shiftZ.astype(np.float32), dims) + z_grid), axis=0), - order=3, mode='constant') - for img, shiftX, shiftY, shiftZ in zip(Y, shifts_x, shifts_y, shifts_z)] + # x_shifts_els and y_shifts_els are switched intentionally + m_reg = [ + apply_pw_shifts_remap_3d(img, shifts_y=-x_shifts, shifts_x=-y_shifts, shifts_z=-z_shifts, + patch_centers=patch_centers, border_nan=self.border_nan, + shifts_interpolate=self.shifts_interpolate) + for img, x_shifts, y_shifts, z_shifts in zip(Y, self.x_shifts_els, self.y_shifts_els, self.z_shifts_els) + ] + else: - xy_grid = [(it[0], it[1]) for it in sliding_window(Y[0], self.overlaps, self.strides)] - dims_grid = tuple(np.max(np.stack(xy_grid, axis=1), axis=1) - np.min( - np.stack(xy_grid, axis=1), axis=1) + 1) - shifts_x = np.stack([np.reshape(_sh_, dims_grid, order='C').astype( - np.float32) for _sh_ in self.x_shifts_els], axis=0) - shifts_y = np.stack([np.reshape(_sh_, dims_grid, order='C').astype( - np.float32) for _sh_ in self.y_shifts_els], axis=0) - dims = Y.shape[1:] - x_grid, y_grid = np.meshgrid(np.arange(0., dims[1]).astype( - np.float32), np.arange(0., dims[0]).astype(np.float32)) - if self.border_nan is not False: - if self.border_nan is True: - m_reg = [cv2.remap(img, -cv2.resize(shiftY, dims[::-1]) + x_grid, - -cv2.resize(shiftX, dims[::-1]) + y_grid, - cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT, - borderValue=np.nan) - for img, shiftX, shiftY in zip(Y, shifts_x, shifts_y)] - - elif self.border_nan == 'min': - m_reg = [cv2.remap(img, -cv2.resize(shiftY, dims[::-1]) + x_grid, - -cv2.resize(shiftX, dims[::-1]) + y_grid, - cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT, - borderValue=np.min(img)) - for img, shiftX, shiftY in zip(Y, shifts_x, shifts_y)] - elif self.border_nan == 'copy': - m_reg = [cv2.remap(img, -cv2.resize(shiftY, dims[::-1]) + x_grid, - -cv2.resize(shiftX, dims[::-1]) + y_grid, - cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE) - for img, shiftX, shiftY in zip(Y, shifts_x, shifts_y)] - else: - m_reg = [cv2.remap(img, -cv2.resize(shiftY, dims[::-1]) + x_grid, - -cv2.resize(shiftX, dims[::-1]) + y_grid, - cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT, - borderValue=0.0) - for img, shiftX, shiftY in zip(Y, shifts_x, shifts_y)] + # x_shifts_els and y_shifts_els are switched intentionally + m_reg = [ + apply_pw_shifts_remap_2d(img, shifts_y=-x_shifts, shifts_x=-y_shifts, patch_centers=patch_centers, + border_nan=self.border_nan, shifts_interpolate=self.shifts_interpolate) + for img, x_shifts, y_shifts in zip(Y, self.x_shifts_els, self.y_shifts_els) + ] + del Y m_reg = np.stack(m_reg, axis=0) if save_memmap: @@ -530,12 +481,12 @@ def apply_shifts_movie(self, fname, rigid_shifts:bool=None, save_memmap:bool=Fal else: return caiman.movie(m_reg) -def apply_shift_iteration(img, shift, border_nan:bool=False, border_type=cv2.BORDER_REFLECT): +def apply_shift_iteration(img, shift, border_nan=False, border_type=cv2.BORDER_REFLECT): # todo todocument sh_x_n, sh_y_n = shift w_i, h_i = img.shape - M = np.float32([[1, 0, sh_y_n], [0, 1, sh_x_n]]) + M = np.array([[1, 0, sh_y_n], [0, 1, sh_x_n]], dtype=np.float32) min_, max_ = np.nanmin(img), np.nanmax(img) img = np.clip(cv2.warpAffine(img, M, (h_i, w_i), flags=cv2.INTER_CUBIC, borderMode=border_type), min_, max_) @@ -568,10 +519,12 @@ def apply_shift_iteration(img, shift, border_nan:bool=False, border_type=cv2.BOR img[:, :max_w] = img[:, max_w, np.newaxis] if min_w < 0: img[:, min_w:] = img[:, min_w-1, np.newaxis] + else: + logging.warning(f'Unknown value of border_nan ({border_nan}); treating as False') return img -def apply_shift_online(movie_iterable, xy_shifts, save_base_name=None, order='F'): +def apply_shift_online(movie_iterable, xy_shifts, save_base_name=None, order: Literal['C', 'F'] = 'F'): """ Applies rigid shifts to a loaded movie. Useful when processing a dataset with CaImAn online and you want to obtain the registered movie after @@ -777,7 +730,7 @@ def motion_correct_online_multifile(list_files, add_to_movie, order='C', **kwarg return all_names, all_shifts, all_xcorrs, all_templates -def motion_correct_online(movie_iterable, add_to_movie, max_shift_w=25, max_shift_h=25, save_base_name=None, order='C', +def motion_correct_online(movie_iterable, add_to_movie, max_shift_w=25, max_shift_h=25, save_base_name=None, order: Literal['F', 'C']='C', init_frames_template=100, show_movie=False, bilateral_blur=False, template=None, min_count=1000, border_to_0=0, n_iter=1, remove_blanks=False, show_template=False, return_mov=False, use_median_as_template=False): @@ -1410,7 +1363,7 @@ def register_translation_3d(src_image, target_image, upsample_factor = 1, del cross_correlation if (shifts_lb is not None) or (shifts_ub is not None): - + # TODO this will fail if only one is not none - should this be an "and"? if (shifts_lb[0] < 0) and (shifts_ub[0] >= 0): new_cross_corr[shifts_ub[0]:shifts_lb[0], :, :] = 0 else: @@ -1796,71 +1749,150 @@ def apply_shifts_dft(src_freq, shifts, diffphase, is_freq=True, border_nan=True) new_img[:, :, :max_d] = new_img[:, :, max_d, np.newaxis] if min_d < 0: new_img[:, :, min_d:] = new_img[:, :, min_d-1, np.newaxis] + else: + logging.warning(f'Unknown value of border_nan ({border_nan}); treating as False') return new_img -def sliding_window(image, overlaps, strides): +def get_patch_edges(dims: tuple[int, ...], overlaps: tuple[int, ...], strides: tuple[int, ...], + ) -> tuple[list[int], ...]: + """For each dimension, return a vector of pixels along that dimension where patches start""" + windowSizes = np.add(overlaps, strides) + return tuple( + list(range(0, dim - windowSize, stride)) + [dim - windowSize] + for dim, windowSize, stride in zip(dims, windowSizes, strides) + ) + + +def get_patch_centers(dims: tuple[int, ...], overlaps: tuple[int, ...], strides: tuple[int, ...], + shifts_opencv=False, upsample_factor_grid=1) -> tuple[list[float], ...]: + """ + For each dimension, return a vector of patch center locations for pw_rigid correction + shifts_opencv just overrides upsample_factor_grid (forces it to 1), this is an easy way to + get the correct values by just providing the motion correction parameters, but + by default no extra upsampling is done. + """ + if not shifts_opencv and upsample_factor_grid != 1: + # account for upsampling step + strides = tuple(np.round(np.divide(strides, upsample_factor_grid)).astype(int)) + + patch_edges = get_patch_edges(dims, overlaps, strides) + windowSizes = np.add(overlaps, strides) + return tuple([edge + (sz - 1) / 2 for edge in edges] for edges, sz in zip(patch_edges, windowSizes)) + + +def sliding_window_dims(dims: tuple[int, ...], overlaps: tuple[int, ...], strides: tuple[int, ...]): + """ computes dimensions for a sliding window with given image dims, overlaps, and strides + + Args: + dims: tuple + the dimensions of the image + + overlaps: tuple + overlap of patches in each dimension, except that the last patch will be all the way + at the bottom/right regardless of overlap + + strides: tuple + stride in each dimension, except that the last patch will be all the way + at the bottom/right regardless of stride + + Returns: + iterator containing 3 items: + inds: (x, y, ...) coordinates in the patch grid + corner: (x, y, ...) location of corner of the patch in the original matrix + size: (x, y, ...) size of patch in pixels + """ + windowSizes = np.add(overlaps, strides) + edge_ranges = get_patch_edges(dims, overlaps, strides) + + for patch in itertools.product(*[enumerate(r) for r in edge_ranges]): + inds, corners = zip(*patch) + yield (tuple(inds), tuple(corners), windowSizes) + + +def sliding_window(image: np.ndarray, overlaps: tuple[int, int], strides: tuple[int, int]): """ efficiently and lazily slides a window across the image Args: - img:ndarray 2D - image that needs to be slices + img: ndarray 2D + image that needs to be sliced overlaps: tuple - dimension of the patch + overlap of patches in each dimension, except that the last patch will be all the way + at the bottom/right regardless of overlap strides: tuple - stride in each dimension + stride in each dimension, except that the last patch will be all the way + at the bottom/right regardless of stride Returns: iterator containing five items dim_1, dim_2 coordinates in the patch grid x, y: bottom border of the patch in the original matrix - patch: the patch - """ - windowSize = np.add(overlaps, strides) - range_1 = list(range( - 0, image.shape[0] - windowSize[0], strides[0])) + [image.shape[0] - windowSize[0]] - range_2 = list(range( - 0, image.shape[1] - windowSize[1], strides[1])) + [image.shape[1] - windowSize[1]] - for dim_1, x in enumerate(range_1): - for dim_2, y in enumerate(range_2): - # yield the current window - yield (dim_1, dim_2, x, y, image[x:x + windowSize[0], y:y + windowSize[1]]) - -def sliding_window_3d(image, overlaps, strides): + """ + if image.ndim != 2: + raise ValueError('Input to sliding_window must be 2D') + + for inds, corner, size in sliding_window_dims(image.shape, overlaps, strides): + patch = image[corner[0]:corner[0] + size[0], corner[1]:corner[1] + size[1]] + yield inds + corner + (patch,) + + +def sliding_window_3d(image: np.ndarray, overlaps: tuple[int, int, int], strides: tuple[int, int, int]): """ efficiently and lazily slides a window across the image Args: - img:ndarray 3D - image that needs to be slices + img: ndarray 3D + image that needs to be sliced overlaps: tuple - dimension of the patch + overlap of patches in each dimension, except that the last patch will be all the way + at the bottom/right regardless of overlap strides: tuple - stride in each dimension + stride in each dimension, except that the last patch will be all the way + at the bottom/right regardless of stride Returns: iterator containing seven items dim_1, dim_2, dim_3 coordinates in the patch grid x, y, z: bottom border of the patch in the original matrix - patch: the patch """ - windowSize = np.add(overlaps, strides) - range_1 = list(range( - 0, image.shape[0] - windowSize[0], strides[0])) + [image.shape[0] - windowSize[0]] - range_2 = list(range( - 0, image.shape[1] - windowSize[1], strides[1])) + [image.shape[1] - windowSize[1]] - range_3 = list(range( - 0, image.shape[2] - windowSize[2], strides[2])) + [image.shape[2] - windowSize[2]] - for dim_1, x in enumerate(range_1): - for dim_2, y in enumerate(range_2): - for dim_3, z in enumerate(range_3): - # yield the current window - yield (dim_1, dim_2, dim_3, x, y, z, image[x:x + windowSize[0], y:y + windowSize[1], z:z + windowSize[2]]) + if image.ndim != 3: + raise ValueError('Input to sliding_window_3d must be 3D') + + for inds, corner, size in sliding_window_dims(image.shape, overlaps, strides): + patch = image[corner[0]:corner[0] + size[0], + corner[1]:corner[1] + size[1], + corner[2]:corner[2] + size[2]] + yield inds + corner + (patch,) + + +def interpolate_shifts(shifts, coords_orig: tuple, coords_new: tuple) -> np.ndarray: + """ + Interpolate piecewise shifts onto new coordinates. Pixels outside the original coordinates will be filled with edge values. + + Args: + shifts: ndarray or other array-like + shifts to interpolate; must have the same number of elements as the outer product of coords_orig + + coords_orig: tuple of float vectors + patch center coordinates along each dimension (e.g. outputs of get_patch_centers) + + coords_new: tuple of float vectors + coordinates along each dimension at which to output interpolated shifts + + Returns: + ndarray of interpolated shifts, of shape tuple(len(coords) for coords in coords_new) + """ + # clip new coordinates to avoid extrapolation + coords_new_clipped = [np.clip(coord, min(coord_orig), max(coord_orig)) for coord, coord_orig in zip(coords_new, coords_orig)] + coords_new_stacked = np.stack(np.meshgrid(*coords_new_clipped, indexing='ij'), axis=-1) + shifts_grid = np.reshape(shifts, tuple(len(coord) for coord in coords_orig)) + return scipy.interpolate.interpn(coords_orig, shifts_grid, coords_new_stacked, method="cubic") + def iqr(a): return np.percentile(a, 75) - np.percentile(a, 25) @@ -1951,7 +1983,7 @@ def high_pass_filter_space(img_orig, gSig_filt=None, freq=None, order=None): def tile_and_correct(img, template, strides, overlaps, max_shifts, newoverlaps=None, newstrides=None, upsample_factor_grid=4, upsample_factor_fft=10, show_movie=False, max_deviation_rigid=2, add_to_movie=0, shifts_opencv=False, gSig_filt=None, - use_cuda=False, border_nan=True): + use_cuda=False, border_nan=True, shifts_interpolate=False): """ perform piecewise rigid motion correction iteration, by 1) dividing the FOV in patches 2) motion correcting each patch separately @@ -2006,6 +2038,9 @@ def tile_and_correct(img, template, strides, overlaps, max_shifts, newoverlaps=N border_nan : bool or string, optional specifies how to deal with borders. (True, False, 'copy', 'min') + + shifts_interpolate: bool + use patch locations to interpolate shifts rather than just upscaling to size of image. Default: False Returns: (new_img, total_shifts, start_step, xy_grid) @@ -2049,24 +2084,23 @@ def tile_and_correct(img, template, strides, overlaps, max_shifts, newoverlaps=N else: # extract patches logger.info("Extracting patches") - templates = [ - it[-1] for it in sliding_window(template, overlaps=overlaps, strides=strides)] - xy_grid = [(it[0], it[1]) for it in sliding_window( - template, overlaps=overlaps, strides=strides)] - num_tiles = np.prod(np.add(xy_grid[-1], 1)) - imgs = [it[-1] - for it in sliding_window(img, overlaps=overlaps, strides=strides)] + xy_grid: list[tuple[int, int]] = [] + templates: list[np.ndarray] = [] + + for (xind, yind, _, _, patch) in sliding_window(template, overlaps=overlaps, strides=strides): + xy_grid.append((xind, yind)) + templates.append(patch) + + imgs = [it[-1] for it in sliding_window(img, overlaps=overlaps, strides=strides)] dim_grid = tuple(np.add(xy_grid[-1], 1)) + num_tiles = len(xy_grid) if max_deviation_rigid is not None: - lb_shifts = np.ceil(np.subtract( rigid_shts, max_deviation_rigid)).astype(int) ub_shifts = np.floor( np.add(rigid_shts, max_deviation_rigid)).astype(int) - else: - lb_shifts = None ub_shifts = None @@ -2087,11 +2121,11 @@ def tile_and_correct(img, template, strides, overlaps, max_shifts, newoverlaps=N img = img_orig dims = img.shape - x_grid, y_grid = np.meshgrid(np.arange(0., dims[1]).astype( - np.float32), np.arange(0., dims[0]).astype(np.float32)) - m_reg = cv2.remap(img, cv2.resize(shift_img_y.astype(np.float32), dims[::-1]) + x_grid, - cv2.resize(shift_img_x.astype(np.float32), dims[::-1]) + y_grid, - cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE) + patch_centers = get_patch_centers(dims, strides=strides, overlaps=overlaps) + # shift_img_x and shift_img_y are switched intentionally + m_reg = apply_pw_shifts_remap_2d(img, shifts_y=shift_img_x, shifts_x=shift_img_y, + patch_centers=patch_centers, border_nan=border_nan, + shifts_interpolate=shifts_interpolate) total_shifts = [ (-x, -y) for x, y in zip(shift_img_x.reshape(num_tiles), shift_img_y.reshape(num_tiles))] return m_reg - add_to_movie, total_shifts, None, None @@ -2105,25 +2139,32 @@ def tile_and_correct(img, template, strides, overlaps, max_shifts, newoverlaps=N newshapes = np.add(newstrides, newoverlaps) - imgs = [it[-1] - for it in sliding_window(img, overlaps=newoverlaps, strides=newstrides)] - - xy_grid = [(it[0], it[1]) for it in sliding_window( - img, overlaps=newoverlaps, strides=newstrides)] + xy_grid: list[tuple[int, int]] = [] + start_step: list[tuple[int, int]] = [] + imgs: list[np.ndarray] = [] - start_step = [(it[2], it[3]) for it in sliding_window( - img, overlaps=newoverlaps, strides=newstrides)] + for (xind, yind, xstart, ystart, patch) in sliding_window(img, overlaps=newoverlaps, strides=newstrides): + xy_grid.append((xind, yind)) + start_step.append((xstart, ystart)) + imgs.append(patch) dim_new_grid = tuple(np.add(xy_grid[-1], 1)) + num_tiles = len(xy_grid) + + if shifts_interpolate: + patch_centers_orig = get_patch_centers(img.shape, strides=strides, overlaps=overlaps) + patch_centers_new = get_patch_centers(img.shape, strides=newstrides, overlaps=newoverlaps) + shift_img_x = interpolate_shifts(shift_img_x, patch_centers_orig, patch_centers_new) + shift_img_y = interpolate_shifts(shift_img_y, patch_centers_orig, patch_centers_new) + diffs_phase_grid_us = interpolate_shifts(diffs_phase_grid, patch_centers_orig, patch_centers_new) + else: + shift_img_x = cv2.resize( + shift_img_x, dim_new_grid[::-1], interpolation=cv2.INTER_CUBIC) + shift_img_y = cv2.resize( + shift_img_y, dim_new_grid[::-1], interpolation=cv2.INTER_CUBIC) + diffs_phase_grid_us = cv2.resize( + diffs_phase_grid, dim_new_grid[::-1], interpolation=cv2.INTER_CUBIC) - shift_img_x = cv2.resize( - shift_img_x, dim_new_grid[::-1], interpolation=cv2.INTER_CUBIC) - shift_img_y = cv2.resize( - shift_img_y, dim_new_grid[::-1], interpolation=cv2.INTER_CUBIC) - diffs_phase_grid_us = cv2.resize( - diffs_phase_grid, dim_new_grid[::-1], interpolation=cv2.INTER_CUBIC) - - num_tiles = np.prod(dim_new_grid) max_shear = np.percentile( [np.max(np.abs(np.diff(ssshh, axis=xxsss))) for ssshh, xxsss in itertools.product( @@ -2200,7 +2241,7 @@ def tile_and_correct(img, template, strides, overlaps, max_shifts, newoverlaps=N def tile_and_correct_3d(img:np.ndarray, template:np.ndarray, strides:tuple, overlaps:tuple, max_shifts:tuple, newoverlaps:Optional[tuple]=None, newstrides:Optional[tuple]=None, upsample_factor_grid:int=4, upsample_factor_fft:int=10, show_movie:bool=False, max_deviation_rigid:int=2, add_to_movie:int=0, shifts_opencv:bool=True, gSig_filt=None, - use_cuda:bool=False, border_nan:bool=True): + use_cuda:bool=False, border_nan:bool=True, shifts_interpolate:bool=False): """ perform piecewise rigid motion correction iteration, by 1) dividing the FOV in patches 2) motion correcting each patch separately @@ -2254,6 +2295,9 @@ def tile_and_correct_3d(img:np.ndarray, template:np.ndarray, strides:tuple, over border_nan : bool or string, optional specifies how to deal with borders. (True, False, 'copy', 'min') + + shifts_interpolate: bool + use patch locations to interpolate shifts rather than just upscaling to size of image. Default: False Returns: (new_img, total_shifts, start_step, xyz_grid) @@ -2288,14 +2332,16 @@ def tile_and_correct_3d(img:np.ndarray, template:np.ndarray, strides:tuple, over return new_img - add_to_movie, (-rigid_shts[0], -rigid_shts[1], -rigid_shts[2]), None, None else: # extract patches - templates = [ - it[-1] for it in sliding_window_3d(template, overlaps=overlaps, strides=strides)] - xyz_grid = [(it[0], it[1], it[2]) for it in sliding_window_3d( - template, overlaps=overlaps, strides=strides)] - num_tiles = np.prod(np.add(xyz_grid[-1], 1)) - imgs = [it[-1] - for it in sliding_window_3d(img, overlaps=overlaps, strides=strides)] + xyz_grid: list[tuple[int, int, int]] = [] + templates: list[np.ndarray] = [] + + for (xind, yind, zind, _, _, _, patch) in sliding_window_3d(template, overlaps=overlaps, strides=strides): + xyz_grid.append((xind, yind, zind)) + templates.append(patch) + + imgs = [it[-1] for it in sliding_window_3d(img, overlaps=overlaps, strides=strides)] dim_grid = tuple(np.add(xyz_grid[-1], 1)) + num_tiles = len(xyz_grid) if max_deviation_rigid is not None: lb_shifts = np.ceil(np.subtract( @@ -2322,32 +2368,14 @@ def tile_and_correct_3d(img:np.ndarray, template:np.ndarray, strides:tuple, over if shifts_opencv: if gSig_filt is not None: img = img_orig + + patch_centers = get_patch_centers(img.shape, strides=strides, overlaps=overlaps) + # shift_img_x and shift_img_y are switched intentionally + m_reg = apply_pw_shifts_remap_3d( + img, shifts_y=shift_img_x, shifts_x=shift_img_y, shifts_z=shift_img_z, + patch_centers=patch_centers, border_nan=border_nan, + shifts_interpolate=shifts_interpolate) - dims = img.shape - x_grid, y_grid, z_grid = np.meshgrid(np.arange(0., dims[1]).astype( - np.float32), np.arange(0., dims[0]).astype(np.float32), - np.arange(0., dims[2]).astype(np.float32)) - if border_nan is not False: - if border_nan is True: - m_reg = warp_sk(img, np.stack((resize_sk(shift_img_x.astype(np.float32), dims) + y_grid, - resize_sk(shift_img_y.astype(np.float32), dims) + x_grid, - resize_sk(shift_img_z.astype(np.float32), dims) + z_grid), axis=0), - order=3, mode='constant', cval=np.nan) - elif border_nan == 'min': - m_reg = warp_sk(img, np.stack((resize_sk(shift_img_x.astype(np.float32), dims) + y_grid, - resize_sk(shift_img_y.astype(np.float32), dims) + x_grid, - resize_sk(shift_img_z.astype(np.float32), dims) + z_grid), axis=0), - order=3, mode='constant', cval=np.min(img)) - elif border_nan == 'copy': - m_reg = warp_sk(img, np.stack((resize_sk(shift_img_x.astype(np.float32), dims) + y_grid, - resize_sk(shift_img_y.astype(np.float32), dims) + x_grid, - resize_sk(shift_img_z.astype(np.float32), dims) + z_grid), axis=0), - order=3, mode='edge') - else: - m_reg = warp_sk(img, np.stack((resize_sk(shift_img_x.astype(np.float32), dims) + y_grid, - resize_sk(shift_img_y.astype(np.float32), dims) + x_grid, - resize_sk(shift_img_z.astype(np.float32), dims) + z_grid), axis=0), - order=3, mode='constant') total_shifts = [ (-x, -y, -z) for x, y, z in zip(shift_img_x.reshape(num_tiles), shift_img_y.reshape(num_tiles), shift_img_z.reshape(num_tiles))] return m_reg - add_to_movie, total_shifts, None, None @@ -2361,27 +2389,35 @@ def tile_and_correct_3d(img:np.ndarray, template:np.ndarray, strides:tuple, over newshapes = np.add(newstrides, newoverlaps) - imgs = [it[-1] - for it in sliding_window_3d(img, overlaps=newoverlaps, strides=newstrides)] - - xyz_grid = [(it[0], it[1], it[2]) for it in sliding_window_3d( - img, overlaps=newoverlaps, strides=newstrides)] + xyz_grid: list[tuple[int, int, int]] = [] + start_step: list[tuple[int, int, int]] = [] + imgs: list[np.ndarray] = [] - start_step = [(it[3], it[4], it[5]) for it in sliding_window_3d( - img, overlaps=newoverlaps, strides=newstrides)] + for (xind, yind, zind, xstart, ystart, zstart, patch) in sliding_window_3d( + img, overlaps=newoverlaps, strides=newstrides): + xyz_grid.append((xind, yind, zind)) + start_step.append((xstart, ystart, zstart)) + imgs.append(patch) dim_new_grid = tuple(np.add(xyz_grid[-1], 1)) - - shift_img_x = resize_sk( - shift_img_x, dim_new_grid[::-1], order=3) - shift_img_y = resize_sk( - shift_img_y, dim_new_grid[::-1], order=3) - shift_img_z = resize_sk( - shift_img_z, dim_new_grid[::-1], order=3) - diffs_phase_grid_us = resize_sk( - diffs_phase_grid, dim_new_grid[::-1], order=3) - - num_tiles = np.prod(dim_new_grid) + num_tiles = len(xyz_grid) + + if shifts_interpolate: + patch_centers_orig = get_patch_centers(img.shape, strides=strides, overlaps=overlaps) + patch_centers_new = get_patch_centers(img.shape, strides=newstrides, overlaps=newoverlaps) + shift_img_x = interpolate_shifts(shift_img_x, patch_centers_orig, patch_centers_new) + shift_img_y = interpolate_shifts(shift_img_y, patch_centers_orig, patch_centers_new) + shift_img_z = interpolate_shifts(shift_img_z, patch_centers_orig, patch_centers_new) + diffs_phase_grid_us = interpolate_shifts(diffs_phase_grid, patch_centers_orig, patch_centers_new) + else: + shift_img_x = resize_sk( + shift_img_x, dim_new_grid[::-1], order=3) + shift_img_y = resize_sk( + shift_img_y, dim_new_grid[::-1], order=3) + shift_img_z = resize_sk( + shift_img_z, dim_new_grid[::-1], order=3) + diffs_phase_grid_us = resize_sk( + diffs_phase_grid, dim_new_grid[::-1], order=3) # what dimension shear should be looked at? shearing for 3d point scanning happens in y and z but not for plane-scanning max_shear = np.percentile( @@ -2453,7 +2489,7 @@ def tile_and_correct_3d(img:np.ndarray, template:np.ndarray, strides:tuple, over sfr_freq, (-rigid_shts[0], -rigid_shts[1], -rigid_shts[2]), diffphase, border_nan=border_nan) img_show = np.vstack([new_img, img]) - img_show = resize_sk(img_show, None, fx=1, fy=1, fz=1) + img_show = rescale_sk(img_show, 1) # TODO does this actually do anything?? cv2.imshow('frame', img_show / np.percentile(template, 99)) cv2.waitKey(int(1. / 500 * 1000)) @@ -2464,6 +2500,117 @@ def tile_and_correct_3d(img:np.ndarray, template:np.ndarray, strides:tuple, over except: pass return new_img - add_to_movie, total_shifts, start_step, xyz_grid + +def apply_pw_shifts_remap_2d(img: np.ndarray, shifts_y: np.ndarray, shifts_x: np.ndarray, + patch_centers: tuple[list[float], ...], border_nan: Union[bool, Literal['copy', 'min']], + shifts_interpolate=False) -> np.ndarray: + """ + Use OpenCV remap to apply 2D piecewise shifts + Inputs: + img: the 2D image to apply shifts to + shifts_y: array of y shifts for each patch (C order) (this is the actual Y i.e. the first dimension of the image) + shifts_x: array of x shifts for each patch (C order) + patch_centers: tuple of patch locations in each dimension. + border_nan: how to deal with borders when remapping + shifts_interpolate: if true, uses interpn to upsample shifts based on patch centers instead of resize + Outputs: + img_remapped: the remapped image + """ + # reshape shifts_y and shifts_x based on patch grid + patch_grid = tuple(len(centers) for centers in patch_centers) + shift_img_y = np.reshape(shifts_y, patch_grid) + shift_img_x = np.reshape(shifts_x, patch_grid) + + # get full image shape/coordinates + dims = img.shape + x_coords, y_coords = [np.arange(0., dims[dim]).astype(np.float32) for dim in (1, 0)] + x_grid, y_grid = np.meshgrid(x_coords, y_coords) + + # up-sample shifts + if shifts_interpolate: + shifts_y = interpolate_shifts(shift_img_y, patch_centers, (y_coords, x_coords)).astype(np.float32) + shifts_x = interpolate_shifts(shift_img_x, patch_centers, (y_coords, x_coords)).astype(np.float32) + else: + shifts_y = cv2.resize(shift_img_y.astype(np.float32), dims[::-1]) + shifts_x = cv2.resize(shift_img_x.astype(np.float32), dims[::-1]) + + # apply to image + if border_nan is False: + mode = cv2.BORDER_CONSTANT + value = 0.0 + elif border_nan is True: + mode = cv2.BORDER_CONSTANT + value = np.nan + elif border_nan == 'min': + mode = cv2.BORDER_CONSTANT + value = np.min(img) + elif border_nan == 'copy': + mode = cv2.BORDER_REPLICATE + value = 0.0 + else: + raise ValueError(f'Unknown value of border_nan ({border_nan})') + + return cv2.remap(img, shifts_x + x_grid, shifts_y + y_grid, cv2.INTER_CUBIC, + borderMode=mode, borderValue=value) + +def apply_pw_shifts_remap_3d(img: np.ndarray, shifts_y: np.ndarray, shifts_x: np.ndarray, shifts_z: np.ndarray, + patch_centers: tuple[list[float], ...], border_nan: Union[bool, Literal['copy', 'min']], + shifts_interpolate=False) -> np.ndarray: + """ + Use skimage warp to apply 3D piecewise shifts + Inputs: + img: the 3D image to apply shifts to + shifts_y: array of y shifts for each patch (C order) (this is the actual Y i.e. the first dimension of the image) + shifts_x: array of x shifts for each patch (C order) + shifts_z: array of z shifts for each patch (C order) + patch_centers: tuple of patch locations in each dimension. + border_nan: how to deal with borders when remapping + shifts_interpolate: if true, uses interpn to upsample shifts based on patch centers instead of resize + Outputs: + img_remapped: the remapped image + """ + # reshape shifts based on patch grid + patch_grid = tuple(len(centers) for centers in patch_centers) + shift_img_y = np.reshape(shifts_y, patch_grid) + shift_img_x = np.reshape(shifts_x, patch_grid) + shift_img_z = np.reshape(shifts_z, patch_grid) + + # get full image shape/coordinates + dims = img.shape + x_coords, y_coords, z_coords = [np.arange(0., dims[dim]).astype(np.float32) for dim in (1, 0, 2)] + x_grid, y_grid, z_grid = np.meshgrid(x_coords, y_coords, z_coords) + + # up-sample shifts + if shifts_interpolate: + coords_new = (y_coords, x_coords, z_coords) + shifts_y = interpolate_shifts(shift_img_y, patch_centers, coords_new).astype(np.float32) + shifts_x = interpolate_shifts(shift_img_x, patch_centers, coords_new).astype(np.float32) + shifts_z = interpolate_shifts(shift_img_z, patch_centers, coords_new).astype(np.float32) + else: + shifts_y = resize_sk(shift_img_y.astype(np.float32), dims) + shifts_x = resize_sk(shift_img_x.astype(np.float32), dims) + shifts_z = resize_sk(shift_img_z.astype(np.float32), dims) + + shift_map = np.stack((shifts_y + y_grid, shifts_x + x_grid, shifts_z + z_grid), axis=0) + + # apply to image + if border_nan is False: + mode = 'constant' + value = 0.0 + elif border_nan is True: + mode = 'constant' + value = np.nan + elif border_nan == 'min': + mode = 'constant' + value = np.min(img) + elif border_nan == 'copy': + mode = 'edge' + value = 0.0 + else: + raise ValueError(f'Unknown value of border_nan ({border_nan})') + + return warp_sk(img, shift_map, order=3, mode=mode, cval=value) + def compute_flow_single_frame(frame, templ, pyr_scale=.5, levels=3, winsize=100, iterations=15, poly_n=5, poly_sigma=1.2 / 5, flags=0): @@ -2606,7 +2753,8 @@ def compute_metrics_motion_correction(fname, final_size_x, final_size_y, swap_di def motion_correct_batch_rigid(fname, max_shifts, dview=None, splits=56, num_splits_to_process=None, num_iter=1, template=None, shifts_opencv=False, save_movie_rigid=False, add_to_movie=None, nonneg_movie=False, gSig_filt=None, subidx=slice(None, None, 1), use_cuda=False, - border_nan=True, var_name_hdf5='mov', is3D=False, indices=(slice(None), slice(None))): + border_nan=True, var_name_hdf5='mov', is3D=False, indices=(slice(None), slice(None)), + shifts_interpolate=False): """ Function that perform memory efficient hyper parallelized rigid motion corrections while also saving a memory mappable file @@ -2726,7 +2874,7 @@ def motion_correct_batch_rigid(fname, max_shifts, dview=None, splits=56, num_spl dview=dview, save_movie=save_movie, base_name=base_name, subidx = subidx, num_splits=num_splits_to_process, shifts_opencv=shifts_opencv, nonneg_movie=nonneg_movie, gSig_filt=gSig_filt, use_cuda=use_cuda, border_nan=border_nan, var_name_hdf5=var_name_hdf5, is3D=is3D, - indices=indices) + indices=indices, shifts_interpolate=shifts_interpolate) if is3D: new_templ = np.nanmedian(np.stack([r[-1] for r in res_rig]), 0) else: @@ -2749,7 +2897,7 @@ def motion_correct_batch_pwrigid(fname, max_shifts, strides, overlaps, add_to_mo splits=56, num_splits_to_process=None, num_iter=1, template=None, shifts_opencv=False, save_movie=False, nonneg_movie=False, gSig_filt=None, use_cuda=False, border_nan=True, var_name_hdf5='mov', is3D=False, - indices=(slice(None), slice(None))): + indices=(slice(None), slice(None)), shifts_interpolate=False): """ Function that perform memory efficient hyper parallelized rigid motion corrections while also saving a memory mappable file @@ -2854,7 +3002,7 @@ def motion_correct_batch_pwrigid(fname, max_shifts, strides, overlaps, add_to_mo base_name=base_name, num_splits=num_splits_to_process, shifts_opencv=shifts_opencv, nonneg_movie=nonneg_movie, gSig_filt=gSig_filt, use_cuda=use_cuda, border_nan=border_nan, var_name_hdf5=var_name_hdf5, is3D=is3D, - indices=indices) + indices=indices, shifts_interpolate=shifts_interpolate) if is3D: new_templ = np.nanmedian(np.stack([r[-1] for r in res_el]), 0) else: @@ -2909,7 +3057,7 @@ def tile_and_correct_wrapper(params): img_name, out_fname, idxs, shape_mov, template, strides, overlaps, max_shifts,\ add_to_movie, max_deviation_rigid, upsample_factor_grid, newoverlaps, newstrides, \ shifts_opencv, nonneg_movie, gSig_filt, is_fiji, use_cuda, border_nan, var_name_hdf5, \ - is3D, indices = params + is3D, indices, shifts_interpolate = params if isinstance(img_name, tuple): @@ -2935,7 +3083,8 @@ def tile_and_correct_wrapper(params): upsample_factor_fft=10, show_movie=False, max_deviation_rigid=max_deviation_rigid, shifts_opencv=shifts_opencv, gSig_filt=gSig_filt, - use_cuda=use_cuda, border_nan=border_nan) + use_cuda=use_cuda, border_nan=border_nan, + shifts_interpolate=shifts_interpolate) shift_info.append([total_shift, start_step, xyz_grid]) else: @@ -2946,7 +3095,8 @@ def tile_and_correct_wrapper(params): upsample_factor_fft=10, show_movie=False, max_deviation_rigid=max_deviation_rigid, shifts_opencv=shifts_opencv, gSig_filt=gSig_filt, - use_cuda=use_cuda, border_nan=border_nan) + use_cuda=use_cuda, border_nan=border_nan, + shifts_interpolate=shifts_interpolate) shift_info.append([total_shift, start_step, xy_grid]) if out_fname is not None: @@ -2967,7 +3117,7 @@ def motion_correction_piecewise(fname, splits, strides, overlaps, add_to_movie=0 upsample_factor_grid=4, order='F', dview=None, save_movie=True, base_name=None, subidx = None, num_splits=None, shifts_opencv=False, nonneg_movie=False, gSig_filt=None, use_cuda=False, border_nan=True, var_name_hdf5='mov', is3D=False, - indices=(slice(None), slice(None))): + indices=(slice(None), slice(None)), shifts_interpolate=False): """ """ @@ -3022,7 +3172,7 @@ def motion_correction_piecewise(fname, splits, strides, overlaps, add_to_movie=0 pars.append([fname, fname_tot, idx, shape_mov, template, strides, overlaps, max_shifts, np.array( add_to_movie, dtype=np.float32), max_deviation_rigid, upsample_factor_grid, newoverlaps, newstrides, shifts_opencv, nonneg_movie, gSig_filt, is_fiji, - use_cuda, border_nan, var_name_hdf5, is3D, indices]) + use_cuda, border_nan, var_name_hdf5, is3D, indices, shifts_interpolate]) if dview is not None: logger.info('** Starting parallel motion correction **') diff --git a/caiman/source_extraction/cnmf/params.py b/caiman/source_extraction/cnmf/params.py index 4ad83eaac..4bbfd83b8 100644 --- a/caiman/source_extraction/cnmf/params.py +++ b/caiman/source_extraction/cnmf/params.py @@ -588,6 +588,9 @@ def __init__(self, fnames=None, dims=None, dxy=(1, 1), pw_rigid: bool, default: False flag for performing pw-rigid motion correction. + shifts_interpolate: bool, default: False + use patch locations to interpolate shifts rather than just upscaling to size of image (for pw_rigid only) + shifts_opencv: bool, default: True flag for applying shifts using cubic interpolation (otherwise FFT) @@ -869,6 +872,7 @@ def __init__(self, fnames=None, dims=None, dxy=(1, 1), 'num_splits_to_process_rig': None, # DO NOT MODIFY 'overlaps': (32, 32), # overlap between patches in pw-rigid motion correction 'pw_rigid': False, # flag for performing pw-rigid motion correction + 'shifts_interpolate': False, # interpolate shifts based on patch locations instead of resizing 'shifts_opencv': True, # flag for applying shifts using cubic interpolation (otherwise FFT) 'splits_els': 14, # number of splits across time for pw-rigid registration 'splits_rig': 14, # number of splits across time for rigid registration