generated from snakemake-workflows/snakemake-workflow-template
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
242 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import numpy as np | ||
from zarrnii import ZarrNii | ||
from scipy.fft import fftn, ifftn | ||
from scipy.ndimage import center_of_mass | ||
|
||
|
||
def phase_correlation(img1, img2): | ||
""" | ||
Compute the phase correlation between two 3D images to find the translation offset. | ||
Parameters: | ||
- img1 (np.ndarray): First image (3D array). | ||
- img2 (np.ndarray): Second image (3D array). | ||
Returns: | ||
- np.ndarray: Offset vector [z_offset, y_offset, x_offset]. | ||
""" | ||
# Compute the Fourier transforms | ||
fft1 = fftn(img1) | ||
fft2 = fftn(img2) | ||
|
||
# Compute the cross-power spectrum | ||
cross_power = fft1 * np.conj(fft2) | ||
cross_power /= np.abs(cross_power) # Normalize | ||
|
||
# Inverse Fourier transform to get correlation map | ||
correlation = np.abs(ifftn(cross_power)) | ||
|
||
# Find the peak in the correlation map | ||
peak = np.unravel_index(np.argmax(correlation), correlation.shape) | ||
|
||
# Convert peak index to an offset | ||
shifts = np.array(peak, dtype=float) | ||
for dim, size in enumerate(correlation.shape): | ||
if shifts[dim] > size // 2: | ||
shifts[dim] -= size | ||
|
||
return shifts | ||
|
||
|
||
def compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs): | ||
""" | ||
Compute the optimal offset for each pair of overlapping tiles. | ||
Parameters: | ||
- ome_zarr_paths (list of str): List of paths to OME-Zarr datasets. | ||
- overlapping_pairs (list of tuples): List of overlapping tile indices. | ||
Returns: | ||
- np.ndarray: Array of offsets for each pair (N, 3) where N is the number of pairs. | ||
""" | ||
offsets = [] | ||
|
||
for i, j in overlapping_pairs: | ||
# Load the two images | ||
znimg1 = ZarrNii.from_path(ome_zarr_paths[i]) | ||
znimg2 = ZarrNii.from_path(ome_zarr_paths[j]) | ||
|
||
img1 = znimg1.darr.squeeze().compute() | ||
img2 = znimg2.darr.squeeze().compute() | ||
|
||
# Compute phase correlation | ||
offset = phase_correlation(img1, img2) | ||
offsets.append(offset) | ||
|
||
return np.array(offsets) | ||
|
||
|
||
# Example usage | ||
overlapping_pairs = np.loadtxt(snakemake.input.pairs, dtype=int).tolist() # Overlapping pairs | ||
ome_zarr_paths = snakemake.input.ome_zarr # List of OME-Zarr paths | ||
|
||
# Compute pairwise offsets | ||
offsets = compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs) | ||
|
||
# Save results | ||
np.savetxt(snakemake.output.offsets, offsets, fmt="%.6f") | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import numpy as np | ||
|
||
def find_overlapping_pairs(ome_zarr_paths): | ||
""" | ||
Identify overlapping tile pairs based on their physical offsets. | ||
Parameters: | ||
- ome_zarr_paths (list of str): List of paths to OME-Zarr datasets. | ||
Returns: | ||
- List of tuples: Each tuple is a pair of overlapping tile indices ((i, j)). | ||
""" | ||
from zarrnii import ZarrNii | ||
|
||
# Read physical transformations and calculate bounding boxes | ||
bounding_boxes = [] | ||
for path in ome_zarr_paths: | ||
znimg = ZarrNii.from_path(path) | ||
affine = znimg.vox2ras.affine # 4x4 matrix | ||
tile_shape = znimg.darr.shape[1:] | ||
|
||
# Compute physical bounding box using affine | ||
corners = [ | ||
np.array([0, 0, 0, 1]), | ||
np.array([tile_shape[2], 0, 0, 1]), | ||
np.array([0, tile_shape[1], 0, 1]), | ||
np.array([tile_shape[2], tile_shape[1], 0, 1]), | ||
np.array([0, 0, tile_shape[0], 1]), | ||
np.array([tile_shape[2], 0, tile_shape[0], 1]), | ||
np.array([0, tile_shape[1], tile_shape[0], 1]), | ||
np.array([tile_shape[2], tile_shape[1], tile_shape[0], 1]), | ||
] | ||
corners_physical = np.dot(affine, np.array(corners).T).T[:, :3] # Drop homogeneous coordinate | ||
bbox_min = corners_physical.min(axis=0) | ||
bbox_max = corners_physical.max(axis=0) | ||
|
||
bounding_boxes.append((bbox_min, bbox_max)) | ||
|
||
# Find overlapping pairs | ||
overlapping_pairs = [] | ||
for i, (bbox1_min, bbox1_max) in enumerate(bounding_boxes): | ||
for j, (bbox2_min, bbox2_max) in enumerate(bounding_boxes): | ||
if i >= j: | ||
continue # Avoid duplicate pairs and self-comparison | ||
|
||
# Check for overlap in all dimensions | ||
overlap = all( | ||
bbox1_min[d] < bbox2_max[d] and bbox1_max[d] > bbox2_min[d] | ||
for d in range(3) | ||
) | ||
if overlap: | ||
overlapping_pairs.append((i, j)) | ||
|
||
return overlapping_pairs | ||
|
||
|
||
overlapping_pairs = find_overlapping_pairs(snakemake.input) | ||
np.savetxt(snakemake.output.txt,np.array(overlapping_pairs),fmt='%d') | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import numpy as np | ||
from scipy.optimize import least_squares | ||
from zarrnii import ZarrNii | ||
|
||
|
||
def global_optimization(ome_zarr_paths, overlapping_pairs, pairwise_offsets): | ||
""" | ||
Perform global optimization to adjust translations for all tiles. | ||
Parameters: | ||
- ome_zarr_paths (list of str): List of paths to OME-Zarr datasets. | ||
- overlapping_pairs (list of tuples): List of overlapping tile indices ((i, j)). | ||
- pairwise_offsets (np.ndarray): Array of pairwise offsets (N, 3), where N is the number of pairs. | ||
Returns: | ||
- np.ndarray: Optimized global translations of shape (T, 3). | ||
""" | ||
# Number of tiles is the number of OME-Zarr paths | ||
num_tiles = len(ome_zarr_paths) | ||
|
||
# Initial translations (start with identity translation: no offsets) | ||
initial_translations = np.zeros((num_tiles, 3)) | ||
|
||
# Flatten initial translations for optimization | ||
x0 = initial_translations.flatten() | ||
|
||
def objective(x): | ||
""" | ||
Compute the residuals for global optimization. | ||
Parameters: | ||
- x (np.ndarray): Flattened translations array (T * 3,). | ||
Returns: | ||
- np.ndarray: Residuals for least-squares optimization. | ||
""" | ||
translations = x.reshape((num_tiles, 3)) | ||
residuals = [] | ||
|
||
for (i, j), offset in zip(overlapping_pairs, pairwise_offsets): | ||
# Residual is the difference between the predicted and actual offset | ||
predicted_offset = translations[j] - translations[i] | ||
residuals.append(predicted_offset - offset) | ||
|
||
return np.concatenate(residuals) | ||
|
||
# Perform least-squares optimization | ||
result = least_squares(objective, x0) | ||
|
||
# Reshape result back to (T, 3) | ||
optimized_translations = result.x.reshape((num_tiles, 3)) | ||
|
||
return optimized_translations | ||
|
||
|
||
# Example usage | ||
overlapping_pairs = np.loadtxt(snakemake.input.pairs, dtype=int).tolist() # Overlapping pairs | ||
pairwise_offsets = np.loadtxt(snakemake.input.offsets, dtype=float) # Pairwise offsets | ||
ome_zarr_paths = snakemake.input.ome_zarr # List of OME-Zarr paths | ||
|
||
# Perform global optimization | ||
optimized_translations = global_optimization(ome_zarr_paths, overlapping_pairs, pairwise_offsets) | ||
|
||
# Save results | ||
np.savetxt(snakemake.output.optimized_translations, optimized_translations, fmt="%.6f") | ||
|