Skip to content

Commit

Permalink
feat: generate pyramid without pooling specific axis
Browse files Browse the repository at this point in the history
  • Loading branch information
calvinchai committed Nov 15, 2024
1 parent 6d98272 commit eb88d93
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 15 deletions.
2 changes: 1 addition & 1 deletion linc_convert/modalities/psoct/single_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def convert(
j * inp_chunk[1]: j * inp_chunk[1] + loaded_chunk.shape[1],
i * inp_chunk[2]: i * inp_chunk[2] + loaded_chunk.shape[2],
] = loaded_chunk
# TODO: no_pool is ignored for now, should add back

generate_pyramid(omz, nblevels - 1, mode="mean")

print("")
Expand Down
46 changes: 32 additions & 14 deletions linc_convert/modalities/psoct/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def generate_pyramid(
ndim: int = 3,
max_load: int = 512,
mode: Literal["mean", "median"] = "median",
no_pyramid_axis: int|str|None = None
) -> list[list[int]]:
"""
Generate the levels of a pyramid in an existing Zarr.
Expand Down Expand Up @@ -143,9 +144,16 @@ def generate_pyramid(
Shapes of all levels, from finest to coarsest, including the
existing top level.
"""

# Read properties from base level
shape = list(omz["0"].shape)
chunk_size = omz["0"].chunks
opt = {
"dimension_separator": omz["0"]._dimension_separator,
"order": omz["0"]._order,
"dtype": omz["0"]._dtype,
"fill_value": omz["0"]._fill_value,
"compressor": omz["0"]._compressor,
}

# Select windowing function
if mode == "median":
Expand All @@ -158,19 +166,16 @@ def generate_pyramid(
batch, shape = shape[:-ndim], shape[-ndim:]
allshapes = [shape]

opt = {
"dimension_separator": omz["0"]._dimension_separator,
"order": omz["0"]._order,
"dtype": omz["0"]._dtype,
"fill_value": omz["0"]._fill_value,
"compressor": omz["0"]._compressor,
}

while True:
level += 1

# Compute downsampled shape
prev_shape, shape = shape, [max(1, x // 2) for x in shape]
prev_shape, shape = shape, []
for i, length in enumerate(prev_shape):
if i == no_pyramid_axis:
shape.append(length)
else:
shape.append(max(1, length // 2))

# Stop if seen enough levels or level shape smaller than chunk size
if levels is None:
Expand Down Expand Up @@ -198,16 +203,24 @@ def generate_pyramid(
dat = omz[str(level - 1)][tuple(slicer)]

# Discard the last voxel along odd dimensions
crop = [0 if x == 1 else x % 2 for x in dat.shape[-3:]]
crop = [0 if x == 1 else x % 2 for x in dat.shape[-ndim:]]
# Don't crop the axis not down-sampling
# cannot do if not no_pyramid_axis since it could be 0
if no_pyramid_axis is not None:
crop[no_pyramid_axis] = 0
slcr = [slice(-1) if x else slice(None) for x in crop]
dat = dat[tuple([Ellipsis, *slcr])]

patch_shape = dat.shape[-3:]
patch_shape = dat.shape[-ndim:]

# Reshape into patches of shape 2x2x2
windowed_shape = [
x for n in patch_shape for x in (max(n // 2, 1), min(n, 2))
]
if no_pyramid_axis is not None:
windowed_shape[2*no_pyramid_axis] = patch_shape[no_pyramid_axis]
windowed_shape[2 * no_pyramid_axis+1] = 1

dat = dat.reshape(batch + windowed_shape)
# -> last `ndim`` dimensions have shape 2x2x2
dat = dat.transpose(
Expand All @@ -217,6 +230,9 @@ def generate_pyramid(
)
# -> flatten patches
smaller_shape = [max(n // 2, 1) for n in patch_shape]
if no_pyramid_axis is not None:
smaller_shape[2 * no_pyramid_axis] = patch_shape[no_pyramid_axis]

dat = dat.reshape(batch + smaller_shape + [-1])

# Compute the median/mean of each patch
Expand All @@ -227,7 +243,9 @@ def generate_pyramid(
# Write output
slicer = [Ellipsis] + [
slice(i * max_load // 2, min((i + 1) * max_load // 2, n))
for i, n in zip(chunk_index, shape)
if axis_index != no_pyramid_axis else
slice(i * max_load, min((i + 1) * max_load, n))
for i, axis_index, n in zip(chunk_index, range(ndim), shape)
]

omz[str(level)][tuple(slicer)] = dat
Expand Down Expand Up @@ -357,7 +375,7 @@ def write_ome_metadata(
}
]

shape = shape0 = shapes[0]
shape0 = shapes[0]
for n in range(len(shapes)):
shape = shapes[n]
multiscales[0]["datasets"].append({})
Expand Down

0 comments on commit eb88d93

Please sign in to comment.