Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix and add more options for non-optical flow register_ROIs; add feature to interpolate shifts based on patch locations in tile_and_correct #1444

Merged
merged 12 commits into from
Jan 7, 2025
Merged
Prev Previous commit
Next Next commit
Add interp_shifts_precisely option for tile_and_correct
ethanbb committed Jan 6, 2025
commit 776c7362f9a5273b14dbc73f649ceb50371ca149
68 changes: 45 additions & 23 deletions caiman/motion_correction.py
Original file line number Diff line number Diff line change
@@ -46,6 +46,7 @@
from numpy.fft import ifftshift
import os
import scipy
import scipy.interpolate
from skimage.transform import resize as resize_sk
from skimage.transform import warp as warp_sk
import sys
@@ -1799,6 +1800,28 @@ def apply_shifts_dft(src_freq, shifts, diffphase, is_freq=True, border_nan=True)

return new_img

def get_patch_edges(dims: tuple[int, ...], overlaps: tuple[int, ...], strides: tuple[int, ...],
) -> tuple[list[int], ...]:
"""For each dimension, return a vector of pixels along that dimension where patches start"""
windowSizes = np.add(overlaps, strides)
return tuple(
list(range(0, dim - windowSize, stride)) + [dim - windowSize]
for dim, windowSize, stride in zip(dims, windowSizes, strides)
)


def get_patch_centers(dims: tuple[int, ...], overlaps: tuple[int, ...], strides: tuple[int, ...],
shifts_opencv: bool, upsample_factor_grid=1) -> tuple[list[int], ...]:
"""For each dimension, return a vector of patch center locations for pw_rigid correction"""
if not shifts_opencv:
# account for upsampling step
strides = tuple(np.round(np.divide(strides, upsample_factor_grid)).astype(int))

patch_edges = get_patch_edges(dims, overlaps, strides)
windowSizes = np.add(overlaps, strides)
return tuple([edge + (sz - 1) / 2 for edge in edges] for edges, sz in zip(patch_edges, windowSizes))


def sliding_window_dims(dims: tuple[int, ...], overlaps: tuple[int, ...], strides: tuple[int, ...]):
""" computes dimensions for a sliding window with given image dims, overlaps, and strides
@@ -1819,12 +1842,9 @@ def sliding_window_dims(dims: tuple[int, ...], overlaps: tuple[int, ...], stride
size: (x, y, ...) size of patch in pixels
"""
windowSizes = np.add(overlaps, strides)
ranges = [
list(range(0, dim - windowSize, stride)) + [dim - windowSize]
for dim, windowSize, stride in zip(dims, windowSizes, strides)
]
edge_ranges = get_patch_edges(dims, overlaps, strides)

for patch in itertools.product(*[enumerate(r) for r in ranges]):
for patch in itertools.product(*[enumerate(r) for r in edge_ranges]):
inds, corners = zip(*patch)
yield (tuple(inds), tuple(corners), windowSizes)

@@ -1884,21 +1904,6 @@ def sliding_window_3d(image: np.ndarray, overlaps: tuple[int, int, int], strides
corner[2]:corner[2] + size[2]]
yield inds + corner + (patch,)

def get_patch_centers(dims: tuple[int, ...], strides: tuple[int, ...], overlaps: tuple[int, ...],
upsample_factor_grid: int, shifts_opencv: bool) -> np.ndarray:
"""
Infer x/y[/z] centers of patches in pixel units for pw_rigid correction
Returns n_patches x n_dims ndarray.
"""
if not shifts_opencv:
# account for upsampling step
strides = tuple(np.round(np.divide(strides, upsample_factor_grid)).astype(int))

return np.stack([
[c + (sz-1) / 2 for c, sz in zip(corner, patch_size)] # from first pixel to center = (width-1)/2
for _, corner, patch_size in sliding_window_dims(dims, overlaps, strides)
])


def iqr(a):
return np.percentile(a, 75) - np.percentile(a, 25)
@@ -1989,7 +1994,7 @@ def high_pass_filter_space(img_orig, gSig_filt=None, freq=None, order=None):

def tile_and_correct(img, template, strides, overlaps, max_shifts, newoverlaps=None, newstrides=None, upsample_factor_grid=4,
upsample_factor_fft=10, show_movie=False, max_deviation_rigid=2, add_to_movie=0, shifts_opencv=False, gSig_filt=None,
use_cuda=False, border_nan=True):
use_cuda=False, border_nan=True, interp_shifts_precisely=False):
""" perform piecewise rigid motion correction iteration, by
1) dividing the FOV in patches
2) motion correcting each patch separately
@@ -2044,6 +2049,10 @@ def tile_and_correct(img, template, strides, overlaps, max_shifts, newoverlaps=N
border_nan : bool or string, optional
specifies how to deal with borders. (True, False, 'copy', 'min')
interp_shifts_precisely: bool
use patch locations to interpolate shifts rather than just upscaling to size of image. Default: False
currently only implemented for shifts_opencv.
Returns:
(new_img, total_shifts, start_step, xy_grid)
@@ -2127,8 +2136,21 @@ def tile_and_correct(img, template, strides, overlaps, max_shifts, newoverlaps=N
dims = img.shape
x_grid, y_grid = np.meshgrid(np.arange(0., dims[1]).astype(
np.float32), np.arange(0., dims[0]).astype(np.float32))
m_reg = cv2.remap(img, cv2.resize(shift_img_y.astype(np.float32), dims[::-1]) + x_grid,
cv2.resize(shift_img_x.astype(np.float32), dims[::-1]) + y_grid,

if interp_shifts_precisely:
# get locations of patches
patch_centers = get_patch_centers(dims, strides=strides, overlaps=overlaps, shifts_opencv=True)
# clip destination pixels to avoid extrapolation
y_grid_clipped = np.clip(y_grid, min(patch_centers[0]), max(patch_centers[0]))
x_grid_clipped = np.clip(x_grid, min(patch_centers[1]), max(patch_centers[1]))
dest_grid = np.dstack((y_grid_clipped, x_grid_clipped))
shifts_x = scipy.interpolate.interpn(patch_centers, shift_img_y.astype(np.float32), dest_grid)
shifts_y = scipy.interpolate.interpn(patch_centers, shift_img_x.astype(np.float32), dest_grid)
else:
shifts_x = cv2.resize(shift_img_y.astype(np.float32), dims[::-1])
shifts_y = cv2.resize(shift_img_x.astype(np.float32), dims[::-1])

m_reg = cv2.remap(img, shifts_x + x_grid, shifts_y + y_grid,
cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)
total_shifts = [
(-x, -y) for x, y in zip(shift_img_x.reshape(num_tiles), shift_img_y.reshape(num_tiles))]