From abe66df562a89a585fb0520203f6ffc8bc13f426 Mon Sep 17 00:00:00 2001
From: Ali Khan
Date: Fri, 6 Dec 2024 12:03:34 -0500
Subject: [PATCH] make multiple ome zarr

---
 dask-stitch/Snakefile                          |  44 ++++++--
 dask-stitch/config.yml                         |   3 +
 dask-stitch/scripts/create_test_dataset.py     |  85 ---------------
 .../scripts/create_test_dataset_combined.py    |  94 ++++++++++++++++
 .../scripts/create_test_dataset_singletile.py  | 101 ++++++++++++++++++
 dask-stitch/scripts/get_tiles_as_nifti.py      |   5 +-
 6 files changed, 237 insertions(+), 95 deletions(-)
 create mode 100644 dask-stitch/config.yml
 delete mode 100644 dask-stitch/scripts/create_test_dataset.py
 create mode 100644 dask-stitch/scripts/create_test_dataset_combined.py
 create mode 100644 dask-stitch/scripts/create_test_dataset_singletile.py

diff --git a/dask-stitch/Snakefile b/dask-stitch/Snakefile
index d5ca63d..0ee335a 100644
--- a/dask-stitch/Snakefile
+++ b/dask-stitch/Snakefile
@@ -1,15 +1,47 @@
+configfile: 'config.yml'
 
-rule create_test_dataset_ome_zarr:
+def get_tile_targets():
+    gridx=config['gridx']
+    gridy=config['gridy']
+    tile_path=f'test_grid-{gridx}by{gridy}/tile-{{tile}}_SPIM.ome.zarr'
+
+    return expand(tile_path, tile=range(gridx*gridy))
+
+rule all:
+    input:
+        ome_zarr=get_tile_targets()
+
+rule create_test_dataset_single_ome_zarr:
+    params:
+        grid_shape=lambda wildcards: (int(wildcards.gridx),int(wildcards.gridy)),
+        tile_index=lambda wildcards: int(wildcards.tile)
     output:
-        ome_zarr=directory('test_tiled.ome.zarr'),
-        translations_npy='test_translations.npy'
-    script: 'scripts/create_test_dataset.py'
+        ome_zarr=directory('test_grid-{gridx}by{gridy}/tile-{tile}_SPIM.ome.zarr'),
+        nifti='test_grid-{gridx}by{gridy}/tile-{tile}_SPIM.nii',
+    script: 'scripts/create_test_dataset_singletile.py'
+
+
+
+
+
+#-- unused below:
+
+rule create_test_dataset_combined_ome_zarr:
+    params:
+        grid_shape=lambda wildcards: (int(wildcards.gridx),int(wildcards.gridy))
+    output:
+        ome_zarr=directory('testcombined_grid-{gridx}by{gridy}_SPIM.ome.zarr'),
+        translations_npy='testcombined_grid-{gridx}by{gridy}_translations.npy'
+    script: 'scripts/create_test_dataset_combined.py'
+
 
 rule get_tiles_as_nifti:
     input:
-        ome_zarr='test_tiled.ome.zarr',
+        ome_zarr='testcombined_grid-{gridx}by{gridy}_SPIM.ome.zarr'
+    params:
+        n_tiles = lambda wildcards: int(wildcards.gridx) * int(wildcards.gridy)
     output:
-        tiles_dir=directory('test_tiled_niftis')
+        tiles_dir=directory('testcombined_grid-{gridx}by{gridy}_SPIM.niftis')
     script: 'scripts/get_tiles_as_nifti.py'
 
diff --git a/dask-stitch/config.yml b/dask-stitch/config.yml
new file mode 100644
index 0000000..829936f
--- /dev/null
+++ b/dask-stitch/config.yml
@@ -0,0 +1,3 @@
+gridx: 3
+gridy: 4
+
diff --git a/dask-stitch/scripts/create_test_dataset.py b/dask-stitch/scripts/create_test_dataset.py
deleted file mode 100644
index 43aa2e7..0000000
--- a/dask-stitch/scripts/create_test_dataset.py
+++ /dev/null
@@ -1,85 +0,0 @@
-import numpy as np
-from templateflow import api as tflow
-import dask.array as da
-from zarrnii import ZarrNii
-
-
-out_ome_zarr=snakemake.output.ome_zarr
-out_translations_npy=snakemake.output.translations_npy
-
-def create_test_dataset(template="MNI152NLin2009cAsym", res=2, tile_shape=(32,32, 32), final_chunks=(1,32,32,1),overlap=8, random_seed=42):
-    """
-    Create a low-resolution test dataset for tile-based stitching.
-
-    Parameters:
-    - template (str): TemplateFlow template name (default: MNI152NLin2009cAsym).
-    - res (int): Desired resolution in mm (default: 2mm).
-    - tile_shape (tuple): Shape of each tile in voxels (default: (64, 64, 64)).
-    - overlap (int): Overlap between tiles in voxels (default: 16).
-    - random_seed (int): Seed for reproducible random offsets.
-
-    Returns:
-    - tiles (dask.array): TxZxYxX Dask array of tiles.
-    - translations (np.ndarray): Array of random offsets for each tile.
-    """
-    # Seed the random number generator
-    np.random.seed(random_seed)
-
-    # Download template and load as a Numpy array
-    template_path = tflow.get(template, resolution=res, desc=None,suffix="T1w")
-    znimg = ZarrNii.from_path(template_path)
-    print(znimg)
-    print(znimg.darr)
-    img_data = znimg.darr
-
-
-    # Determine number of tiles in each dimension
-    img_shape = img_data.shape
-    step = tuple(s - overlap for s in tile_shape)
-
-    grid_shape = tuple(
-        max(((img_shape[dim] - tile_shape[dim]) // step[dim]) + 1, 1)
-        for dim in range(3)
-    )
-
-    print(f'img_shape {img_shape}, step: {step}, grid_shape: {grid_shape}')
-    # Create tiles
-    tiles = []
-    translations = []
-    for z in range(grid_shape[0]):
-        for y in range(grid_shape[1]):
-            for x in range(grid_shape[2]):
-                # Extract tile
-                z_start, y_start, x_start = z * step[0], y * step[1], x * step[2]
-                tile = img_data[z_start:z_start+tile_shape[0],
-                                y_start:y_start+tile_shape[1],
-                                x_start:x_start+tile_shape[2]]
-
-                # Add to list
-                tiles.append(tile)
-
-                # Add random offset
-                offset = np.random.uniform(-5, 5, size=3) # Random 3D offsets
-                translations.append((z_start, y_start, x_start) + offset)
-
-    print(tiles)
-    # Convert to a Dask array
-    print(tile_shape)
-    tiles = da.concatenate([tile.rechunk(chunks=final_chunks) for tile in tiles])
-    translations = np.array(translations)
-
-    znimg.darr = tiles
-
-    return znimg, translations
-
-
-
-
-if __name__ == '__main__':
-    test_znimg, test_translations = create_test_dataset()
-    print(test_znimg)
-
-    print(test_translations.shape)
-    test_znimg.to_ome_zarr(out_ome_zarr)
-    np.save(out_translations_npy,test_translations)
-
diff --git a/dask-stitch/scripts/create_test_dataset_combined.py b/dask-stitch/scripts/create_test_dataset_combined.py
new file mode 100644
index 0000000..c443f4b
--- /dev/null
+++ b/dask-stitch/scripts/create_test_dataset_combined.py
@@ -0,0 +1,94 @@
+import numpy as np
+from templateflow import api as tflow
+import dask.array as da
+import math
+from zarrnii import ZarrNii
+
+def create_test_dataset(template="MNI152NLin2009cAsym", res=2, grid_shape=(3, 4), overlap=8, random_seed=42, final_chunks=(32, 32, 1)):
+    """
+    Create a low-resolution test dataset for tile-based stitching.
+
+    Parameters:
+    - template (str): TemplateFlow template name (default: MNI152NLin2009cAsym).
+    - res (int): Desired resolution in mm (default: 2mm).
+    - grid_shape (tuple): Shape of the tiling grid (e.g., (3, 4) for 3x4 grid in X-Y).
+    - overlap (int): Overlap between tiles in X-Y plane in voxels (default: 8).
+    - random_seed (int): Seed for reproducible random offsets.
+    - final_chunks (tuple): Desired chunks for final tiles.
+
+    Returns:
+    - znimg (ZarrNii): ZarrNii object containing the tiles.
+    - translations (np.ndarray): Array of random offsets for each tile.
+    """
+
+    # Seed the random number generator
+    np.random.seed(random_seed)
+
+    # Download template and load as a ZarrNii object
+    template_path = tflow.get(template, resolution=res, desc=None, suffix="T1w")
+    znimg = ZarrNii.from_path(template_path)
+    img_data = znimg.darr.squeeze()
+
+    # Original image shape
+    img_shape = np.array(img_data.shape)  # (Z, Y, X)
+
+    # Keep Z dimension intact, calculate X and Y tile sizes
+    z_dim, y_dim, x_dim = img_shape
+    x_tile_size = math.ceil((x_dim + overlap * (grid_shape[1] - 1)) / grid_shape[1])
+    y_tile_size = math.ceil((y_dim + overlap * (grid_shape[0] - 1)) / grid_shape[0])
+    tile_shape = (z_dim, y_tile_size, x_tile_size)
+
+    # Calculate required padding to ensure X and Y dimensions are divisible by grid shape
+    padded_x = x_tile_size * grid_shape[1] - overlap * (grid_shape[1] - 1)
+    padded_y = y_tile_size * grid_shape[0] - overlap * (grid_shape[0] - 1)
+
+    padding = (
+        (0, 0),  # No padding in Z
+        (0, int(max(padded_y - y_dim, 0))),
+        (0, int(max(padded_x - x_dim, 0))),
+    )
+
+    print('padding')
+    print(padding)
+    print(img_data.shape)
+
+    # Pad image if needed
+    if any(p[1] > 0 for p in padding):
+        img_data = da.pad(img_data, padding, mode="constant", constant_values=0)
+
+    # Create tiles
+    tiles = []
+    translations = []
+    for y in range(grid_shape[0]):
+        for x in range(grid_shape[1]):
+            # Calculate tile start indices
+            y_start = y * (y_tile_size - overlap)
+            x_start = x * (x_tile_size - overlap)
+
+            # Extract tile
+            tile = img_data[:, y_start:y_start + y_tile_size, x_start:x_start + x_tile_size]
+            tiles.append(tile)
+
+            # Add random offset -- NOT ACTUALLY BEING APPLIED TO SAMPLING HERE!
+            offset = np.random.uniform(-5, 5, size=3)  # Random 3D offsets
+            translations.append((0, y_start, x_start) + offset)
+
+    # Convert tiles to a Dask array
+    tiles = da.stack([tile.rechunk(chunks=final_chunks) for tile in tiles])
+
+    # Save back into ZarrNii object
+    znimg.darr = tiles
+    translations = np.array(translations)
+
+    return znimg, translations
+
+
+
+
+test_znimg, test_translations = create_test_dataset(grid_shape=snakemake.params.grid_shape)
+print(test_znimg.darr.shape)
+test_znimg.to_ome_zarr(snakemake.output.ome_zarr)
+np.save(snakemake.output.translations_npy, test_translations)
+
+
diff --git a/dask-stitch/scripts/create_test_dataset_singletile.py b/dask-stitch/scripts/create_test_dataset_singletile.py
new file mode 100644
index 0000000..b8d1d8e
--- /dev/null
+++ b/dask-stitch/scripts/create_test_dataset_singletile.py
@@ -0,0 +1,101 @@
+import nibabel as nib
+import numpy as np
+from templateflow import api as tflow
+import dask.array as da
+import math
+from zarrnii import ZarrNii
+
+def create_test_dataset_single(tile_index, template="MNI152NLin2009cAsym", res=2, grid_shape=(3, 4), overlap=8, random_seed=42, final_chunks=(1, 32, 32, 1)):
+    """
+    Create a single low-resolution test tile for tile-based stitching.
+
+    Parameters:
+    - tile_index (int): Index of the tile to create.
+    - template (str): TemplateFlow template name (default: MNI152NLin2009cAsym).
+    - res (int): Desired resolution in mm (default: 2mm).
+    - grid_shape (tuple): Shape of the tiling grid (e.g., (3, 4) for 3x4 grid in X-Y).
+    - overlap (int): Overlap between tiles in X-Y plane in voxels (default: 8).
+    - random_seed (int): Seed for reproducible random offsets.
+    - final_chunks (tuple): Desired chunks for the final tile.
+
+    Returns:
+    - znimg (ZarrNii): ZarrNii object containing the single tile, with its
+      random offset stored in the vox2ras affine.
+    """
+
+    # Seed the random number generator
+    np.random.seed(random_seed)
+
+    # Download template and load the image data with nibabel
+    template_path = tflow.get(template, resolution=res, desc=None, suffix="T1w")
+    img_data = nib.load(template_path).get_fdata()
+
+    # Original image shape
+    img_shape = np.array(img_data.shape)  # (X, Y, Z) for NIfTI data
+
+    # Keep Z dimension intact, calculate X and Y tile sizes
+    x_dim, y_dim, z_dim = img_shape
+    x_tile_size = math.ceil((x_dim + overlap * (grid_shape[0] - 1)) / grid_shape[0])
+    y_tile_size = math.ceil((y_dim + overlap * (grid_shape[1] - 1)) / grid_shape[1])
+
+    # Calculate required padding to ensure X and Y dimensions are divisible by grid shape
+    padded_x = x_tile_size * grid_shape[0] - overlap * (grid_shape[0] - 1)
+    padded_y = y_tile_size * grid_shape[1] - overlap * (grid_shape[1] - 1)
+
+    padding = (
+        (0, int(max(padded_x - x_dim, 0))),
+        (0, int(max(padded_y - y_dim, 0))),
+        (0, 0),  # No padding in Z
+    )
+
+
+    # Pad image if needed
+    if any(p[1] > 0 for p in padding):
+        img_data = np.pad(img_data, padding, mode="constant", constant_values=0)
+
+    # Locate this tile in the grid
+    x, y = np.unravel_index(tile_index, grid_shape)
+
+
+    # Calculate tile start indices
+    x_start = x * (x_tile_size - overlap)
+    y_start = y * (y_tile_size - overlap)
+
+    # Extract tile
+    tile = img_data[x_start:x_start + x_tile_size, y_start:y_start + y_tile_size, :]
+
+    # Add random offset - ensure that the offset generated for the same tile is reproducible
+    # by generating offsets for the whole grid from the fixed seed, then indexing by (x, y)
+    offset = np.random.uniform(-5, 5, size=(grid_shape[0], grid_shape[1], 3))  # Random 3D offsets
+    translation = ((x_start, y_start, 0) + offset[x, y, :])
+
+    print((x_start, y_start, 0))
+    print(translation)
+
+    tile_shape = (1, x_tile_size, y_tile_size, z_dim)
+
+
+    # save translation into vox2ras
+    vox2ras = np.eye(4)
+    vox2ras[:3, 3] = np.array(translation)
+
+
+    # Save back into ZarrNii object
+    darr = da.from_array(tile.reshape(tile_shape), chunks=final_chunks)
+
+    znimg = ZarrNii.from_darr(darr, vox2ras=vox2ras, axes_nifti=True)
+
+    return znimg
+
+
+
+
+test_znimg = create_test_dataset_single(tile_index=snakemake.params.tile_index,
+                                        grid_shape=snakemake.params.grid_shape)
+test_znimg.to_ome_zarr(snakemake.output.ome_zarr)
+test_znimg.to_nifti(snakemake.output.nifti)
+
+
diff --git a/dask-stitch/scripts/get_tiles_as_nifti.py b/dask-stitch/scripts/get_tiles_as_nifti.py
index 784bad5..d41b579 100644
--- a/dask-stitch/scripts/get_tiles_as_nifti.py
+++ b/dask-stitch/scripts/get_tiles_as_nifti.py
@@ -3,13 +3,10 @@
 from pathlib import Path
 
 
-znimg_example_tile= ZarrNii.from_path(snakemake.input.ome_zarr)
-print(znimg_full.darr.shape)
-
 out_dir = Path(snakemake.output.tiles_dir)
 out_dir.mkdir(exist_ok=True, parents=True)
 
-for tile in range(znimg_full.darr.shape[0]):
+for tile in range(snakemake.params.n_tiles):
     print(f'reading tile {tile} and writing to nifti')
     ZarrNii.from_path(snakemake.input.ome_zarr,channels=[tile]).to_nifti(out_dir / f'tile_{tile:02d}.nii')
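
For reference, a minimal sketch (not part of the patch) of the tiling arithmetic both new scripts share: per axis, the tile size is chosen so that grid_shape tiles with `overlap` voxels of overlap cover the (zero-padded) X-Y extent, and each tile starts at index * (tile_size - overlap). The extent used below is a made-up example, and tile_grid is an illustrative helper, not a function in the repository:

import math

import numpy as np


def tile_grid(extent_xy=(96, 114), grid_shape=(3, 4), overlap=8):
    """Per-axis tile size, padded extent, and (x, y) start index of every tile."""
    sizes, padded = [], []
    for dim, n in zip(extent_xy, grid_shape):
        size = math.ceil((dim + overlap * (n - 1)) / n)   # same formula as the scripts
        sizes.append(size)
        padded.append(size * n - overlap * (n - 1))       # extent after zero-padding
    starts = [
        (gx * (sizes[0] - overlap), gy * (sizes[1] - overlap))
        for gx in range(grid_shape[0]) for gy in range(grid_shape[1])
    ]
    return sizes, padded, starts


# extent_xy is an assumed example extent, not the actual template shape
sizes, padded, starts = tile_grid()
print(sizes, padded)                # per-axis tile size and padded extent
print(starts)                       # start voxel of each of the 12 tiles
print(np.unravel_index(5, (3, 4)))  # tile index -> grid position, as in the single-tile script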