From eb88d93af57d8382cef60a064ab3cac72f440ed6 Mon Sep 17 00:00:00 2001 From: Kaidong Chai Date: Fri, 15 Nov 2024 15:46:18 -0500 Subject: [PATCH] feat: generate pyramid without pooling specific axis --- .../modalities/psoct/single_volume.py | 2 +- linc_convert/modalities/psoct/utils.py | 46 +++++++++++++------ 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/linc_convert/modalities/psoct/single_volume.py b/linc_convert/modalities/psoct/single_volume.py index ecf8915..a1e545f 100644 --- a/linc_convert/modalities/psoct/single_volume.py +++ b/linc_convert/modalities/psoct/single_volume.py @@ -170,7 +170,7 @@ def convert( j * inp_chunk[1]: j * inp_chunk[1] + loaded_chunk.shape[1], i * inp_chunk[2]: i * inp_chunk[2] + loaded_chunk.shape[2], ] = loaded_chunk - # TODO: no_pool is ignored for now, should add back + generate_pyramid(omz, nblevels - 1, mode="mean") print("") diff --git a/linc_convert/modalities/psoct/utils.py b/linc_convert/modalities/psoct/utils.py index c6c3fed..7a5daff 100644 --- a/linc_convert/modalities/psoct/utils.py +++ b/linc_convert/modalities/psoct/utils.py @@ -112,6 +112,7 @@ def generate_pyramid( ndim: int = 3, max_load: int = 512, mode: Literal["mean", "median"] = "median", + no_pyramid_axis: int|str|None = None ) -> list[list[int]]: """ Generate the levels of a pyramid in an existing Zarr. @@ -143,9 +144,16 @@ def generate_pyramid( Shapes of all levels, from finest to coarsest, including the existing top level. """ - + # Read properties from base level shape = list(omz["0"].shape) chunk_size = omz["0"].chunks + opt = { + "dimension_separator": omz["0"]._dimension_separator, + "order": omz["0"]._order, + "dtype": omz["0"]._dtype, + "fill_value": omz["0"]._fill_value, + "compressor": omz["0"]._compressor, + } # Select windowing function if mode == "median": @@ -158,19 +166,16 @@ def generate_pyramid( batch, shape = shape[:-ndim], shape[-ndim:] allshapes = [shape] - opt = { - "dimension_separator": omz["0"]._dimension_separator, - "order": omz["0"]._order, - "dtype": omz["0"]._dtype, - "fill_value": omz["0"]._fill_value, - "compressor": omz["0"]._compressor, - } - while True: level += 1 # Compute downsampled shape - prev_shape, shape = shape, [max(1, x // 2) for x in shape] + prev_shape, shape = shape, [] + for i, length in enumerate(prev_shape): + if i == no_pyramid_axis: + shape.append(length) + else: + shape.append(max(1, length // 2)) # Stop if seen enough levels or level shape smaller than chunk size if levels is None: @@ -198,16 +203,24 @@ def generate_pyramid( dat = omz[str(level - 1)][tuple(slicer)] # Discard the last voxel along odd dimensions - crop = [0 if x == 1 else x % 2 for x in dat.shape[-3:]] + crop = [0 if x == 1 else x % 2 for x in dat.shape[-ndim:]] + # Don't crop the axis not down-sampling + # cannot do if not no_pyramid_axis since it could be 0 + if no_pyramid_axis is not None: + crop[no_pyramid_axis] = 0 slcr = [slice(-1) if x else slice(None) for x in crop] dat = dat[tuple([Ellipsis, *slcr])] - patch_shape = dat.shape[-3:] + patch_shape = dat.shape[-ndim:] # Reshape into patches of shape 2x2x2 windowed_shape = [ x for n in patch_shape for x in (max(n // 2, 1), min(n, 2)) ] + if no_pyramid_axis is not None: + windowed_shape[2*no_pyramid_axis] = patch_shape[no_pyramid_axis] + windowed_shape[2 * no_pyramid_axis+1] = 1 + dat = dat.reshape(batch + windowed_shape) # -> last `ndim`` dimensions have shape 2x2x2 dat = dat.transpose( @@ -217,6 +230,9 @@ def generate_pyramid( ) # -> flatten patches smaller_shape = [max(n // 2, 1) for n in patch_shape] + if no_pyramid_axis is not None: + smaller_shape[2 * no_pyramid_axis] = patch_shape[no_pyramid_axis] + dat = dat.reshape(batch + smaller_shape + [-1]) # Compute the median/mean of each patch @@ -227,7 +243,9 @@ def generate_pyramid( # Write output slicer = [Ellipsis] + [ slice(i * max_load // 2, min((i + 1) * max_load // 2, n)) - for i, n in zip(chunk_index, shape) + if axis_index != no_pyramid_axis else + slice(i * max_load, min((i + 1) * max_load, n)) + for i, axis_index, n in zip(chunk_index, range(ndim), shape) ] omz[str(level)][tuple(slicer)] = dat @@ -357,7 +375,7 @@ def write_ome_metadata( } ] - shape = shape0 = shapes[0] + shape0 = shapes[0] for n in range(len(shapes)): shape = shapes[n] multiscales[0]["datasets"].append({})