Skip to content

Commit

Permalink
FIX(psoct): propagate no_pool to pyramid generator + FIX(generate_pyr…
Browse files Browse the repository at this point in the history
…amid): do not crash if last chunk in a row only has a single voxel
  • Loading branch information
balbasty committed Nov 22, 2024
1 parent 7ecfc08 commit 41cbd1f
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
12 changes: 10 additions & 2 deletions linc_convert/modalities/psoct/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,17 +203,25 @@ def generate_pyramid(
slice(i * max_load, min((i + 1) * max_load, n))
for i, n in zip(chunk_index, prev_shape)
]
fullshape = omz[str(level - 1)].shape
dat = omz[str(level - 1)][tuple(slicer)]

# Discard the last voxel along odd dimensions
crop = [0 if x == 1 else x % 2 for x in dat.shape[-ndim:]]
crop = [
0 if y == 1 else x % 2
for x, y in zip(dat.shape[-ndim:], fullshape)
]
# Don't crop the axis not down-sampling
# cannot do if not no_pyramid_axis since it could be 0
if no_pyramid_axis is not None:
crop[no_pyramid_axis] = 0
slcr = [slice(-1) if x else slice(None) for x in crop]
dat = dat[tuple([Ellipsis, *slcr])]

if any(n == 0 for n in dat.shape):
# last strip had a single voxel, nothing to do
continue

patch_shape = dat.shape[-ndim:]

# Reshape into patches of shape 2x2x2
Expand All @@ -234,7 +242,7 @@ def generate_pyramid(
# -> flatten patches
smaller_shape = [max(n // 2, 1) for n in patch_shape]
if no_pyramid_axis is not None:
smaller_shape[2 * no_pyramid_axis] = patch_shape[no_pyramid_axis]
smaller_shape[no_pyramid_axis] = patch_shape[no_pyramid_axis]

dat = dat.reshape(batch + smaller_shape + [-1])

Expand Down
2 changes: 1 addition & 1 deletion linc_convert/modalities/psoct/multi_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def convert(

inp[i] = None # no ref count -> delete array

generate_pyramid(omz, nblevels - 1, mode="mean")
generate_pyramid(omz, nblevels - 1, mode="mean", no_pyramid_axis=no_pool)

print("")

Expand Down
2 changes: 1 addition & 1 deletion linc_convert/modalities/psoct/single_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def convert(
i * inp_chunk[2]: i * inp_chunk[2] + loaded_chunk.shape[2],
] = loaded_chunk

generate_pyramid(omz, nblevels - 1, mode="mean")
generate_pyramid(omz, nblevels - 1, mode="mean", no_pyramid_axis=no_pool)

print("")

Expand Down

0 comments on commit 41cbd1f

Please sign in to comment.