Skip to content

Commit

Permalink
up to global optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
akhanf committed Dec 6, 2024
1 parent abe66df commit bf2fcdb
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 9 deletions.
25 changes: 24 additions & 1 deletion dask-stitch/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,31 @@ rule create_test_dataset_single_ome_zarr:
script: 'scripts/create_test_dataset_singletile.py'


rule find_overlapping_pairs:
input:
ome_zarr=get_tile_targets()
output:
txt='test_grid-{gridx}by{gridy}/overlapping_pairs.txt'
script: 'scripts/find_overlapping_pairs.py'

rule compute_pairwise_correlation:
input:
ome_zarr=get_tile_targets(),
pairs='test_grid-{gridx}by{gridy}/overlapping_pairs.txt'
output:
offsets='test_grid-{gridx}by{gridy}/pairwise_offsets.txt'
script: 'scripts/compute_pairwise_correlation.py'


rule global_optimization:
input:
ome_zarr=get_tile_targets(),
pairs='test_grid-{gridx}by{gridy}/overlapping_pairs.txt',
offsets='test_grid-{gridx}by{gridy}/pairwise_offsets.txt'
output:
optimized_translations='test_grid-{gridx}by{gridy}/optimized_translations.txt'
script:
'scripts/global_optimization.py'


#-- unused below:

Expand Down
78 changes: 78 additions & 0 deletions dask-stitch/scripts/compute_pairwise_correlation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import numpy as np
from zarrnii import ZarrNii
from scipy.fft import fftn, ifftn
from scipy.ndimage import center_of_mass


def phase_correlation(img1, img2):
"""
Compute the phase correlation between two 3D images to find the translation offset.
Parameters:
- img1 (np.ndarray): First image (3D array).
- img2 (np.ndarray): Second image (3D array).
Returns:
- np.ndarray: Offset vector [z_offset, y_offset, x_offset].
"""
# Compute the Fourier transforms
fft1 = fftn(img1)
fft2 = fftn(img2)

# Compute the cross-power spectrum
cross_power = fft1 * np.conj(fft2)
cross_power /= np.abs(cross_power) # Normalize

# Inverse Fourier transform to get correlation map
correlation = np.abs(ifftn(cross_power))

# Find the peak in the correlation map
peak = np.unravel_index(np.argmax(correlation), correlation.shape)

# Convert peak index to an offset
shifts = np.array(peak, dtype=float)
for dim, size in enumerate(correlation.shape):
if shifts[dim] > size // 2:
shifts[dim] -= size

return shifts


def compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs):
"""
Compute the optimal offset for each pair of overlapping tiles.
Parameters:
- ome_zarr_paths (list of str): List of paths to OME-Zarr datasets.
- overlapping_pairs (list of tuples): List of overlapping tile indices.
Returns:
- np.ndarray: Array of offsets for each pair (N, 3) where N is the number of pairs.
"""
offsets = []

for i, j in overlapping_pairs:
# Load the two images
znimg1 = ZarrNii.from_path(ome_zarr_paths[i])
znimg2 = ZarrNii.from_path(ome_zarr_paths[j])

img1 = znimg1.darr.squeeze().compute()
img2 = znimg2.darr.squeeze().compute()

# Compute phase correlation
offset = phase_correlation(img1, img2)
offsets.append(offset)

return np.array(offsets)


# Example usage
overlapping_pairs = np.loadtxt(snakemake.input.pairs, dtype=int).tolist() # Overlapping pairs
ome_zarr_paths = snakemake.input.ome_zarr # List of OME-Zarr paths

# Compute pairwise offsets
offsets = compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs)

# Save results
np.savetxt(snakemake.output.offsets, offsets, fmt="%.6f")

23 changes: 15 additions & 8 deletions dask-stitch/scripts/create_test_dataset_singletile.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import nibabel as nib
import numpy as np
from scipy.ndimage import affine_transform
from templateflow import api as tflow
import nibabel as nib
import dask.array as da
Expand Down Expand Up @@ -64,21 +65,27 @@ def create_test_dataset_single(tile_index, template="MNI152NLin2009cAsym", res=2
x_start = x * (x_tile_size - overlap)
y_start = y * (y_tile_size - overlap)


# TODO: Simulate error by applying a transformation to the image before

# initially lets just do a random jitter:
offset = np.random.uniform(-10, 10, size=(grid_shape[0],grid_shape[1],3)) # Random 3D offsets for each tile

xfm_img_data = affine_transform(img_data,matrix=np.eye(3,3),offset=offset[x,y,:],order=1)

# Extract tile
tile = img_data[x_start:x_start + x_tile_size, y_start:y_start + y_tile_size, :]
tile = xfm_img_data[x_start:x_start + x_tile_size, y_start:y_start + y_tile_size, :]

# Add random offset - ensure that the random offset generated for the same tile is the same
# do this by gneerating
offset = np.random.uniform(-5, 5, size=(grid_shape[0],grid_shape[1],3)) # Random 3D offsets
translation = ((x_start, y_start, 0) + offset[x,y,:])

print((x_start, y_start, 0))
print(translation)

translation = ((x_start, y_start, 0))



tile_shape = (1,x_tile_size, y_tile_size, z_dim)


#save translation into vox2ras
#save tiling coordinate translation into vox2ras (not the random jitter)
vox2ras = np.eye(4)
vox2ras[:3,3] = np.array(translation)

Expand Down
59 changes: 59 additions & 0 deletions dask-stitch/scripts/find_overlapping_pairs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import numpy as np

def find_overlapping_pairs(ome_zarr_paths):
"""
Identify overlapping tile pairs based on their physical offsets.
Parameters:
- ome_zarr_paths (list of str): List of paths to OME-Zarr datasets.
Returns:
- List of tuples: Each tuple is a pair of overlapping tile indices ((i, j)).
"""
from zarrnii import ZarrNii

# Read physical transformations and calculate bounding boxes
bounding_boxes = []
for path in ome_zarr_paths:
znimg = ZarrNii.from_path(path)
affine = znimg.vox2ras.affine # 4x4 matrix
tile_shape = znimg.darr.shape[1:]

# Compute physical bounding box using affine
corners = [
np.array([0, 0, 0, 1]),
np.array([tile_shape[2], 0, 0, 1]),
np.array([0, tile_shape[1], 0, 1]),
np.array([tile_shape[2], tile_shape[1], 0, 1]),
np.array([0, 0, tile_shape[0], 1]),
np.array([tile_shape[2], 0, tile_shape[0], 1]),
np.array([0, tile_shape[1], tile_shape[0], 1]),
np.array([tile_shape[2], tile_shape[1], tile_shape[0], 1]),
]
corners_physical = np.dot(affine, np.array(corners).T).T[:, :3] # Drop homogeneous coordinate
bbox_min = corners_physical.min(axis=0)
bbox_max = corners_physical.max(axis=0)

bounding_boxes.append((bbox_min, bbox_max))

# Find overlapping pairs
overlapping_pairs = []
for i, (bbox1_min, bbox1_max) in enumerate(bounding_boxes):
for j, (bbox2_min, bbox2_max) in enumerate(bounding_boxes):
if i >= j:
continue # Avoid duplicate pairs and self-comparison

# Check for overlap in all dimensions
overlap = all(
bbox1_min[d] < bbox2_max[d] and bbox1_max[d] > bbox2_min[d]
for d in range(3)
)
if overlap:
overlapping_pairs.append((i, j))

return overlapping_pairs


overlapping_pairs = find_overlapping_pairs(snakemake.input)
np.savetxt(snakemake.output.txt,np.array(overlapping_pairs),fmt='%d')

66 changes: 66 additions & 0 deletions dask-stitch/scripts/global_optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import numpy as np
from scipy.optimize import least_squares
from zarrnii import ZarrNii


def global_optimization(ome_zarr_paths, overlapping_pairs, pairwise_offsets):
"""
Perform global optimization to adjust translations for all tiles.
Parameters:
- ome_zarr_paths (list of str): List of paths to OME-Zarr datasets.
- overlapping_pairs (list of tuples): List of overlapping tile indices ((i, j)).
- pairwise_offsets (np.ndarray): Array of pairwise offsets (N, 3), where N is the number of pairs.
Returns:
- np.ndarray: Optimized global translations of shape (T, 3).
"""
# Number of tiles is the number of OME-Zarr paths
num_tiles = len(ome_zarr_paths)

# Initial translations (start with identity translation: no offsets)
initial_translations = np.zeros((num_tiles, 3))

# Flatten initial translations for optimization
x0 = initial_translations.flatten()

def objective(x):
"""
Compute the residuals for global optimization.
Parameters:
- x (np.ndarray): Flattened translations array (T * 3,).
Returns:
- np.ndarray: Residuals for least-squares optimization.
"""
translations = x.reshape((num_tiles, 3))
residuals = []

for (i, j), offset in zip(overlapping_pairs, pairwise_offsets):
# Residual is the difference between the predicted and actual offset
predicted_offset = translations[j] - translations[i]
residuals.append(predicted_offset - offset)

return np.concatenate(residuals)

# Perform least-squares optimization
result = least_squares(objective, x0)

# Reshape result back to (T, 3)
optimized_translations = result.x.reshape((num_tiles, 3))

return optimized_translations


# Example usage
overlapping_pairs = np.loadtxt(snakemake.input.pairs, dtype=int).tolist() # Overlapping pairs
pairwise_offsets = np.loadtxt(snakemake.input.offsets, dtype=float) # Pairwise offsets
ome_zarr_paths = snakemake.input.ome_zarr # List of OME-Zarr paths

# Perform global optimization
optimized_translations = global_optimization(ome_zarr_paths, overlapping_pairs, pairwise_offsets)

# Save results
np.savetxt(snakemake.output.optimized_translations, optimized_translations, fmt="%.6f")

0 comments on commit bf2fcdb

Please sign in to comment.