Skip to content

Commit

Permalink
Added ability to fuse zarr starting at downscaled resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
IgorTatarnikov committed Jul 25, 2024
1 parent f0b3bc5 commit 6172a42
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 28 deletions.
2 changes: 1 addition & 1 deletion brainglobe_stitch/big_stitcher_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def calculate_pairwise_links(

fiji_path = resolve_fiji_path(fiji_path)

if selected_channel == "All channels":
if selected_channel == "All Channels":
stitch_macro_path = (
macro_directory / "calculate_pairwise_all_channel.ijm"
)
Expand Down
88 changes: 65 additions & 23 deletions brainglobe_stitch/image_mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,11 +490,11 @@ def normalise_intensity(
# Adjust the intensity of each tile based on the scale factors
for tile in self.tiles:
if self.scale_factors[tile.id] != 1.0:
tile.data_pyramid[resolution_level] = np.multiply(
tile.data_pyramid[resolution_level] = da.multiply(
tile.data_pyramid[resolution_level],
self.scale_factors[tile.id],
dtype=np.float16,
).astype(np.uint16)
).astype(tile.data_pyramid[resolution_level].dtype)

self.intensity_adjusted[resolution_level] = True

Expand All @@ -516,6 +516,7 @@ def calculate_intensity_scale_factors(

for tile_i in self.tiles:
# Iterate through the neighbours of each tile
print(f"Calculating scale factors for tile {tile_i.id}")
for neighbour_id in tile_i.neighbours:
tile_j = self.tiles[neighbour_id]
overlap = self.overlaps[(tile_i.id, tile_j.id)]
Expand All @@ -526,18 +527,18 @@ def calculate_intensity_scale_factors(
)

# Calculate the percentile intensity of the overlapping data
median_i = np.percentile(i_overlap.ravel(), percentile)
median_j = np.percentile(j_overlap.ravel(), percentile)
median_i = da.percentile(i_overlap.ravel(), percentile)
median_j = da.percentile(j_overlap.ravel(), percentile)

curr_scale_factor = (median_i / median_j).compute()
scale_factors[tile_i.id][tile_j.id] = curr_scale_factor[0]

# Adjust the tile intensity based on the scale factor
tile_j.data_pyramid[resolution_level] = np.multiply(
tile_j.data_pyramid[resolution_level] = da.multiply(
tile_j.data_pyramid[resolution_level],
curr_scale_factor,
dtype=np.float16,
).astype(np.uint16)
).astype(tile_j.data_pyramid[resolution_level].dtype)

self.intensity_adjusted[resolution_level] = True
# Calculate the product of the scale factors for each tile's neighbours
Expand Down Expand Up @@ -580,6 +581,7 @@ def fuse(
output_file_name: str = "fused.zarr",
normalise_intensity: bool = False,
interpolate: bool = False,
resolution_level: int = 3,
) -> None:
"""
Fuse the tiles into a single image and save it to the output file.
Expand All @@ -592,25 +594,52 @@ def fuse(
Whether to normalise the intensity of the image before fusing.
interpolate: bool
Whether to interpolate the overlaps before fusing.
resolution_level: int
The resolution level to fuse the tiles at.
"""
output_path = self.directory / output_file_name
downsample_z, downsample_y, downsample_x = self.tiles[
0
].resolution_pyramid[resolution_level]

z_size, y_size, x_size = self.tiles[0].data_pyramid[0].shape
z_size, y_size, x_size = (
self.tiles[resolution_level].data_pyramid[0].shape
)
# Calculate the shape of the fused image
fused_image_shape: Tuple[int, ...] = (
max([tile.position[0] for tile in self.tiles]) + z_size,
max([tile.position[1] for tile in self.tiles]) + y_size,
max([tile.position[2] for tile in self.tiles]) + x_size,
(
max([tile.position[0] for tile in self.tiles])
+ z_size
+ int(resolution_level > 0)
)
// downsample_z,
(
max([tile.position[1] for tile in self.tiles])
+ y_size
+ int(resolution_level > 0)
)
// downsample_y,
(
max([tile.position[2] for tile in self.tiles])
+ x_size
+ int(resolution_level > 0)
)
// downsample_x,
)

if normalise_intensity:
self.normalise_intensity(0, 80)
self.normalise_intensity(resolution_level, 80)

if interpolate:
self.interpolate_overlaps(0)
self.interpolate_overlaps(resolution_level)

if output_path.suffix == ".zarr":
self._fuse_to_zarr(output_path, fused_image_shape)
self._fuse_to_zarr(
output_path,
fused_image_shape,
pyramid_depth=4,
resolution_level=resolution_level,
)
elif output_path.suffix == ".h5":
self._fuse_to_bdv_h5(output_path, fused_image_shape)
elif output_path.suffix in [".tif", ".tiff"]:
Expand All @@ -620,7 +649,8 @@ def _fuse_to_zarr(
self,
output_path: Path,
fused_image_shape: Tuple[int, ...],
pyramid_depth: int = 6,
pyramid_depth: int = 5,
resolution_level: int = 0,
) -> None:
"""
Fuse the tiles in the ImageMosaic into a single image and save it as a
Expand All @@ -632,8 +662,16 @@ def _fuse_to_zarr(
The path of the output file.
fused_image_shape: Tuple[int, ...]
The shape of the fused image.
pyramid_depth: int
The depth of the image pyramid.
Default is 6.
resolution_level: int
The resolution level to fuse the tiles at.
Default is 0.
"""
z_size, y_size, x_size = self.tiles[0].data_pyramid[0].shape
z_size, y_size, x_size = (
self.tiles[0].data_pyramid[resolution_level].shape
)

# Default chunk shape is (256, 256, 256) for the highest resolution
chunk_shape: Tuple[int, ...] = (256, 256, 256)
Expand Down Expand Up @@ -663,19 +701,23 @@ def _fuse_to_zarr(

# Place the tiles in reverse order of acquisition
for tile in self.tiles[-1::-1]:
scaled_translation = np.round(
np.array(tile.position)
/ tile.resolution_pyramid[resolution_level]
).astype(np.int32)
if self.num_channels > 1:
fused_image_store[
tile.channel_id,
tile.position[0] : tile.position[0] + z_size,
tile.position[1] : tile.position[1] + y_size,
tile.position[2] : tile.position[2] + x_size,
] = tile.data_pyramid[0].compute()
scaled_translation[0] : scaled_translation[0] + z_size,
scaled_translation[1] : scaled_translation[1] + y_size,
scaled_translation[2] : scaled_translation[2] + x_size,
] = tile.data_pyramid[resolution_level].compute()
else:
fused_image_store[
tile.position[0] : tile.position[0] + z_size,
tile.position[1] : tile.position[1] + y_size,
tile.position[2] : tile.position[2] + x_size,
] = tile.data_pyramid[0].compute()
scaled_translation[0] : scaled_translation[0] + z_size,
scaled_translation[1] : scaled_translation[1] + y_size,
scaled_translation[2] : scaled_translation[2] + x_size,
] = tile.data_pyramid[resolution_level].compute()

print(f"Done tile {tile.id}")

Expand Down
8 changes: 4 additions & 4 deletions brainglobe_stitch/stitching_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ def _on_add_tiles_button_clicked(self):
worker.yielded.connect(self._set_tile_layers)
worker.start()

self.fuse_button.setEnabled(True)
self.adjust_intensity_button.setEnabled(True)
self.interpolate_button.setEnabled(True)

def _set_tile_layers(self, tile_layer: napari.layers.Image):
tile_layer = self._viewer.add_layer(tile_layer)
self.tile_layers.append(tile_layer)
Expand All @@ -285,10 +289,6 @@ def _on_stitch_button_clicked(self):

self.update_tiles_from_mosaic(napari_data)

self.fuse_button.setEnabled(True)
self.adjust_intensity_button.setEnabled(True)
self.interpolate_button.setEnabled(True)

def _on_adjust_intensity_button_clicked(self):
self.image_mosaic.normalise_intensity(
resolution_level=self.resolution_to_display,
Expand Down

0 comments on commit 6172a42

Please sign in to comment.