Skip to content

Commit

Permalink
Merge pull request #1444 from proektlab/register_ROIs-improvements
Browse files Browse the repository at this point in the history
Fix and add more options for non-optical flow register_ROIs; add feature to interpolate shifts based on patch locations in tile_and_correct
  • Loading branch information
pgunn authored Jan 7, 2025
2 parents 654bd0f + a479909 commit 857ae12
Show file tree
Hide file tree
Showing 3 changed files with 414 additions and 231 deletions.
61 changes: 45 additions & 16 deletions caiman/base/rois.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Any, Optional
import zipfile

from caiman.motion_correction import tile_and_correct
from caiman.motion_correction import tile_and_correct, get_patch_centers, interpolate_shifts

try:
cv2.setNumThreads(0)
Expand Down Expand Up @@ -318,7 +318,8 @@ def register_ROIs(A1,
print_assignment=False,
plot_results=False,
Cn=None,
cmap='viridis'):
cmap='viridis',
align_options: Optional[dict] = None):
"""
Register ROIs across different sessions using an intersection over union
metric and the Hungarian algorithm for optimal matching
Expand Down Expand Up @@ -372,6 +373,9 @@ def register_ROIs(A1,
cmap: string
colormap for background image
align_options: Optional[dict]
mcorr options to override defaults when align_flag is True and use_opt_flow is False
Returns:
matched_ROIs1: list
Expand Down Expand Up @@ -414,25 +418,50 @@ def register_ROIs(A1,
if use_opt_flow:
template1_norm = np.uint8(template1 * (template1 > 0) * 255)
template2_norm = np.uint8(template2 * (template2 > 0) * 255)
flow = cv2.calcOpticalFlowFarneback(np.uint8(template1_norm * 255), np.uint8(template2_norm * 255), None,
flow = cv2.calcOpticalFlowFarneback(template1_norm, template2_norm, None,
0.5, 3, 128, 3, 7, 1.5, 0)
x_remap = (flow[:, :, 0] + x_grid).astype(np.float32)
y_remap = (flow[:, :, 1] + y_grid).astype(np.float32)

else:
template2, shifts, _, xy_grid = tile_and_correct(template2,
template1 - template1.min(),
[int(dims[0] / 4), int(dims[1] / 4)], [16, 16], [10, 10],
add_to_movie=template2.min(),
shifts_opencv=True)

dims_grid = tuple(np.max(np.stack(xy_grid, axis=0), axis=0) - np.min(np.stack(xy_grid, axis=0), axis=0) + 1)
_sh_ = np.stack(shifts, axis=0)
shifts_x = np.reshape(_sh_[:, 1], dims_grid, order='C').astype(np.float32)
shifts_y = np.reshape(_sh_[:, 0], dims_grid, order='C').astype(np.float32)

x_remap = (-np.resize(shifts_x, dims) + x_grid).astype(np.float32)
y_remap = (-np.resize(shifts_y, dims) + y_grid).astype(np.float32)
align_defaults = {
"strides": (int(dims[0] / 4), int(dims[1] / 4)),
"overlaps": (16, 16),
"max_shifts": (10, 10),
"shifts_opencv": True,
"upsample_factor_grid": 4,
"shifts_interpolate": True,
"max_deviation_rigid": 2
# any other argument to tile_and_correct can also be used in align_options
}

if align_options:
# override defaults with input options
align_defaults.update(align_options)
align_options = align_defaults

template2, shifts, _, _ = tile_and_correct(template2, template1 - template1.min(),
add_to_movie=template2.min(), **align_options)

if align_options["max_deviation_rigid"] == 0:
# repeat rigid shifts to size of the image
shifts_x_full = np.full(dims, -shifts[1])
shifts_y_full = np.full(dims, -shifts[0])
else:
# piecewise - interpolate from patches to get shifts per pixel
patch_centers = get_patch_centers(dims, overlaps=align_options["overlaps"], strides=align_options["strides"],
shifts_opencv=align_options["shifts_opencv"],
upsample_factor_grid=align_options["upsample_factor_grid"])
patch_grid = tuple(len(centers) for centers in patch_centers)
_sh_ = np.stack(shifts, axis=0)
shifts_x = np.reshape(_sh_[:, 1], patch_grid, order='C').astype(np.float32)
shifts_y = np.reshape(_sh_[:, 0], patch_grid, order='C').astype(np.float32)

shifts_x_full = interpolate_shifts(-shifts_x, patch_centers, tuple(range(d) for d in dims))
shifts_y_full = interpolate_shifts(-shifts_y, patch_centers, tuple(range(d) for d in dims))

x_remap = (shifts_x_full + x_grid).astype(np.float32)
y_remap = (shifts_y_full + y_grid).astype(np.float32)

A_2t = np.reshape(A2, dims + (-1,), order='F').transpose(2, 0, 1)
A2 = np.stack([cv2.remap(img.astype(np.float32), x_remap, y_remap, cv2.INTER_NEAREST) for img in A_2t], axis=0)
Expand Down
Loading

0 comments on commit 857ae12

Please sign in to comment.