diff --git a/mesospim_stitcher/big_stitcher_bridge.py b/mesospim_stitcher/big_stitcher_bridge.py index df801f4..3d66e28 100644 --- a/mesospim_stitcher/big_stitcher_bridge.py +++ b/mesospim_stitcher/big_stitcher_bridge.py @@ -13,7 +13,9 @@ def run_big_stitcher( downsample_y: int = 4, downsample_z: int = 4, ): - stitch_macro_path = Path(__file__).resolve().parent / "stitch_macro.ijm" + stitch_macro_path = ( + Path(__file__).resolve().parent / "bigstitcher_macro.ijm" + ) if platform.startswith("darwin"): imagej_path = imagej_path / "Contents/MacOS/ImageJ-macosx" diff --git a/mesospim_stitcher/stitch_macro.ijm b/mesospim_stitcher/bigstitcher_macro.ijm similarity index 100% rename from mesospim_stitcher/stitch_macro.ijm rename to mesospim_stitcher/bigstitcher_macro.ijm diff --git a/mesospim_stitcher/file_utils.py b/mesospim_stitcher/file_utils.py index 93079b6..b554e4f 100644 --- a/mesospim_stitcher/file_utils.py +++ b/mesospim_stitcher/file_utils.py @@ -4,7 +4,7 @@ import dask.array as da import h5py -import numpy.typing as npt +import numpy as np import zarr from tifffile import imwrite @@ -18,10 +18,16 @@ def create_pyramid_bdv_h5( input_file: Path, - resolutions_array: npt.NDArray, - subdivisions_array: npt.NDArray, yield_progress: bool = False, ): + resolutions_array = np.array( + [[1, 1, 1], [2, 2, 2], [4, 4, 4], [8, 8, 8], [16, 16, 16]] + ) + + subdivisions_array = np.array( + [[32, 32, 16], [32, 32, 16], [32, 32, 16], [32, 32, 16], [32, 32, 16]] + ) + with h5py.File(input_file, "r+") as f: data_group = f["t00000"] num_done = 0 diff --git a/mesospim_stitcher/image_mosaic.py b/mesospim_stitcher/image_mosaic.py index eb925cb..a4a6eff 100644 --- a/mesospim_stitcher/image_mosaic.py +++ b/mesospim_stitcher/image_mosaic.py @@ -23,14 +23,6 @@ ) from mesospim_stitcher.tile import Overlap, Tile -DOWNSAMPLE_ARRAY = np.array( - [[1, 1, 1], [2, 2, 2], [4, 4, 4], [8, 8, 8], [16, 16, 16]] -) - -SUBDIVISION_ARRAY = np.array( - [[32, 32, 16], [32, 32, 16], [32, 32, 16], [32, 32, 16], [32, 32, 16]] -) - class ImageMosaic: def __init__(self, directory: Path): @@ -86,8 +78,6 @@ def load_mesospim_directory(self) -> None: for update in create_pyramid_bdv_h5( self.h5_path, - DOWNSAMPLE_ARRAY, - SUBDIVISION_ARRAY, yield_progress=True, ): progress.update(task, advance=update) @@ -611,7 +601,7 @@ def fuse_to_bdv_h5( dtype=np.int16, ) - ds_list = [] + channel_ds_list = [] for i in range(self.num_channels): output_file.require_dataset( f"s{i:02}/resolutions", @@ -625,58 +615,60 @@ def fuse_to_bdv_h5( dtype="i2", shape=subdivisions.shape, ) + + ds_list = [] ds = output_file.require_dataset( f"t00000/s{i:02}/0/cells", shape=fused_image_shape, - chunks=(128, 128, 128), + chunks=(256, 256, 256), dtype="i2", ) ds_list.append(ds) - for tile in self.tiles[-1::-1]: - ds_list[tile.channel_id][ - tile.position[0] : tile.position[0] + z_size, - tile.position[1] : tile.position[1] + y_size, - tile.position[2] : tile.position[2] + x_size, - ] = tile.data_pyramid[0].compute() - - print(f"Done tile {tile.id}") - - for i in range(self.num_channels): - output_file.require_dataset( - f"s{i:02}/resolutions", - data=resolutions, - dtype="i2", - shape=resolutions.shape, - ) - print(f"s{i:02}/resolutions") - output_file.require_dataset( - f"s{i:02}/subdivisions", - data=subdivisions, - dtype="i2", - shape=subdivisions.shape, - ) - - for i in range(1, len(resolutions)): - for j in range(self.num_channels): - prev_resolution = da.from_array( - output_file[f"t00000/s{j:02}/{i - 1}/cells"] + for j in range(1, len(resolutions)): + new_shape = ( + fused_image_shape[0], + (fused_image_shape[1] + 1) // 2**j, + (fused_image_shape[2] + 1) // 2**j, ) - downsampled_image = downscale_nearest( - prev_resolution, (1, 2, 2) + down_ds = output_file.require_dataset( + f"t00000/s{i:02}/{j}/cells", + shape=new_shape, + chunks=(256, 256, 256), + dtype="i2", ) - downsampled_shape = downsampled_image.shape - output_file.require_dataset( - f"t00000/s{j:02}/{i}/cells", - data=downsampled_image.compute(), - shape=downsampled_shape, - chunks=downsampled_image.chunks, - dtype="i2", + ds_list.append(down_ds) + + channel_ds_list.append(ds_list) + + for tile in self.tiles[-1::-1]: + current_tile_data = tile.data_pyramid[0].compute() + channel_ds_list[tile.channel_id][0][ + tile.position[0] : tile.position[0] + z_size, + tile.position[1] : tile.position[1] + y_size, + tile.position[2] : tile.position[2] + x_size, + ] = current_tile_data + + for i in range(1, len(resolutions)): + scaled_position = tile.position // resolutions[i, -1::-1] + scaled_size = ( + z_size // resolutions[i][2], + (y_size + 1) // resolutions[i][1], + (x_size + 1) // resolutions[i][0], ) + channel_ds_list[tile.channel_id][i][ + scaled_position[0] : scaled_position[0] + scaled_size[0], + scaled_position[1] : scaled_position[1] + scaled_size[1], + scaled_position[2] : scaled_position[2] + scaled_size[2], + ] = current_tile_data[ + :: resolutions[i][2], + :: resolutions[i][1], + :: resolutions[i][0], + ] - print(f"Done resolution {i}") + print(f"Done tile {tile.id}") assert self.xml_path is not None @@ -686,6 +678,7 @@ def fuse_to_bdv_h5( output_path, fused_image_shape, ) + output_file.close() def get_metadata_for_zarr(self, pyramid_depth: int = 5): diff --git a/mesospim_stitcher/stitching_widget.py b/mesospim_stitcher/stitching_widget.py index 3000b2e..60314e4 100644 --- a/mesospim_stitcher/stitching_widget.py +++ b/mesospim_stitcher/stitching_widget.py @@ -4,7 +4,6 @@ import dask.array as da import h5py import napari.layers -import numpy as np from brainglobe_utils.qtpy.logo import header_widget from napari.qt.threading import create_worker from napari.utils.notifications import show_warning @@ -39,14 +38,6 @@ from mesospim_stitcher.image_mosaic import ImageMosaic from mesospim_stitcher.tile import Tile -DOWNSAMPLE_ARRAY = np.array( - [[1, 1, 1], [2, 2, 2], [4, 4, 4], [8, 8, 8], [16, 16, 16]] -) - -SUBDIVISION_ARRAY = np.array( - [[32, 32, 16], [32, 32, 16], [32, 32, 16], [32, 32, 16], [32, 32, 16]] -) - class StitchingWidget(QWidget): def __init__(self, napari_viewer: Viewer): @@ -244,8 +235,6 @@ def _on_create_pyramid_button_clicked(self): worker = create_worker( create_pyramid_bdv_h5, self.h5_path, - DOWNSAMPLE_ARRAY, - SUBDIVISION_ARRAY, yield_progress=True, ) worker.yielded.connect(self.progress_bar.setValue) @@ -375,8 +364,3 @@ def update_tiles_from_mosaic(self, napari_data): tile_data, tile_position = data tile_layer.data = tile_data tile_layer.translate = tile_position - - # def hideEvent(self, a0, QHideEvent=None): - # super().hideEvent(a0) - # if self.h5_file: - # self.h5_file.close()