From 5cad3f636847882e035d823848219c826a25eff2 Mon Sep 17 00:00:00 2001 From: Ali Khan Date: Fri, 6 Dec 2024 10:06:07 -0500 Subject: [PATCH] WIP: prototype workflow for developing stitching alg --- dask-stitch/Snakefile | 16 ++++ dask-stitch/scripts/create_test_dataset.py | 85 ++++++++++++++++++++++ dask-stitch/scripts/get_tiles_as_nifti.py | 17 +++++ 3 files changed, 118 insertions(+) create mode 100644 dask-stitch/Snakefile create mode 100644 dask-stitch/scripts/create_test_dataset.py create mode 100644 dask-stitch/scripts/get_tiles_as_nifti.py diff --git a/dask-stitch/Snakefile b/dask-stitch/Snakefile new file mode 100644 index 0000000..d5ca63d --- /dev/null +++ b/dask-stitch/Snakefile @@ -0,0 +1,16 @@ + +rule create_test_dataset_ome_zarr: + output: + ome_zarr=directory('test_tiled.ome.zarr'), + translations_npy='test_translations.npy' + script: 'scripts/create_test_dataset.py' + +rule get_tiles_as_nifti: + input: + ome_zarr='test_tiled.ome.zarr', + output: + tiles_dir=directory('test_tiled_niftis') + script: + 'scripts/get_tiles_as_nifti.py' + + diff --git a/dask-stitch/scripts/create_test_dataset.py b/dask-stitch/scripts/create_test_dataset.py new file mode 100644 index 0000000..43aa2e7 --- /dev/null +++ b/dask-stitch/scripts/create_test_dataset.py @@ -0,0 +1,85 @@ +import numpy as np +from templateflow import api as tflow +import dask.array as da +from zarrnii import ZarrNii + + +out_ome_zarr=snakemake.output.ome_zarr +out_translations_npy=snakemake.output.translations_npy + +def create_test_dataset(template="MNI152NLin2009cAsym", res=2, tile_shape=(32,32, 32), final_chunks=(1,32,32,1),overlap=8, random_seed=42): + """ + Create a low-resolution test dataset for tile-based stitching. + + Parameters: + - template (str): TemplateFlow template name (default: MNI152NLin2009cAsym). + - res (int): Desired resolution in mm (default: 2mm). + - tile_shape (tuple): Shape of each tile in voxels (default: (64, 64, 64)). + - overlap (int): Overlap between tiles in voxels (default: 16). + - random_seed (int): Seed for reproducible random offsets. + + Returns: + - tiles (dask.array): TxZxYxX Dask array of tiles. + - translations (np.ndarray): Array of random offsets for each tile. + """ + # Seed the random number generator + np.random.seed(random_seed) + + # Download template and load as a Numpy array + template_path = tflow.get(template, resolution=res, desc=None,suffix="T1w") + znimg = ZarrNii.from_path(template_path) + print(znimg) + print(znimg.darr) + img_data = znimg.darr + + + # Determine number of tiles in each dimension + img_shape = img_data.shape + step = tuple(s - overlap for s in tile_shape) + + grid_shape = tuple( + max(((img_shape[dim] - tile_shape[dim]) // step[dim]) + 1, 1) + for dim in range(3) + ) + + print(f'img_shape {img_shape}, step: {step}, grid_shape: {grid_shape}') + # Create tiles + tiles = [] + translations = [] + for z in range(grid_shape[0]): + for y in range(grid_shape[1]): + for x in range(grid_shape[2]): + # Extract tile + z_start, y_start, x_start = z * step[0], y * step[1], x * step[2] + tile = img_data[z_start:z_start+tile_shape[0], + y_start:y_start+tile_shape[1], + x_start:x_start+tile_shape[2]] + + # Add to list + tiles.append(tile) + + # Add random offset + offset = np.random.uniform(-5, 5, size=3) # Random 3D offsets + translations.append((z_start, y_start, x_start) + offset) + + print(tiles) + # Convert to a Dask array + print(tile_shape) + tiles = da.concatenate([tile.rechunk(chunks=final_chunks) for tile in tiles]) + translations = np.array(translations) + + znimg.darr = tiles + + return znimg, translations + + + + +if __name__ == '__main__': + test_znimg, test_translations = create_test_dataset() + print(test_znimg) + + print(test_translations.shape) + test_znimg.to_ome_zarr(out_ome_zarr) + np.save(out_translations_npy,test_translations) + diff --git a/dask-stitch/scripts/get_tiles_as_nifti.py b/dask-stitch/scripts/get_tiles_as_nifti.py new file mode 100644 index 0000000..784bad5 --- /dev/null +++ b/dask-stitch/scripts/get_tiles_as_nifti.py @@ -0,0 +1,17 @@ +import nibabel as nib +from zarrnii import ZarrNii +from pathlib import Path + + +znimg_example_tile= ZarrNii.from_path(snakemake.input.ome_zarr) +print(znimg_full.darr.shape) + +out_dir = Path(snakemake.output.tiles_dir) +out_dir.mkdir(exist_ok=True, parents=True) + +for tile in range(znimg_full.darr.shape[0]): + print(f'reading tile {tile} and writing to nifti') + ZarrNii.from_path(snakemake.input.ome_zarr,channels=[tile]).to_nifti(out_dir / f'tile_{tile:02d}.nii') + + +