Skip to content

Commit

Permalink
WIP export to bdv h5
Browse files Browse the repository at this point in the history
  • Loading branch information
IgorTatarnikov committed Feb 6, 2024
1 parent 0438d6e commit cdcdd8f
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 25 deletions.
10 changes: 3 additions & 7 deletions mesospim_stitcher/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,11 @@ def write_bdv_xml(
input_tree = ET.parse(input_xml_path)
input_root = input_tree.getroot()

generated_by = input_root.find(".//generatedBy")
base_path = input_root.find(".//BasePath")

root = ET.Element("SpimData", version="0.2")
assert (
generated_by is not None
), "No generatedBy tag found in the input XML file"

assert base_path is not None, "No BasePath tag found in the input XML file"
root.append(generated_by)
root.append(base_path)

sequence_desc = ET.SubElement(root, "SequenceDescription")
Expand All @@ -123,14 +119,14 @@ def write_bdv_xml(
assert (
hdf5_path_node is not None
), "No hdf5 tag found in the input XML file"
hdf5_path_node.text = str(hdf5_path)
hdf5_path_node.text = str(hdf5_path.name)
sequence_desc.append(image_loader)

view_setup = input_root.find(".//ViewSetup")
assert (
view_setup is not None
), "No ViewSetup tag found in the input XML file"
view_setup[3].text = f"{image_size[2]} {image_size[1]} {image_size[0]}"
view_setup[2].text = f"{image_size[2]} {image_size[1]} {image_size[0]}"

view_setups = ET.SubElement(sequence_desc, "ViewSetups")
view_setups.append(view_setup)
Expand Down
115 changes: 98 additions & 17 deletions mesospim_stitcher/image_mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
create_pyramid_bdv_h5,
get_slice_attributes,
parse_mesospim_metadata,
write_bdv_xml,
)
from mesospim_stitcher.fuse import get_big_stitcher_transforms
from mesospim_stitcher.tile import Overlap, Tile
Expand Down Expand Up @@ -452,25 +453,31 @@ def fuse(
) -> None:
output_path = self.directory / output_file_name

z_size, y_size, x_size = self.tiles[0].data_pyramid[0].shape
fused_image_shape: Tuple[int, ...] = (
max([tile.position[0] for tile in self.tiles]) + z_size,
max([tile.position[1] for tile in self.tiles]) + y_size,
max([tile.position[2] for tile in self.tiles]) + x_size,
)

if normalise_intensity:
self.normalise_intensity(0, 80)

if interpolate:
self.interpolate_overlaps(0)

if output_path.suffix == ".zarr":
self.fuse_to_zarr(output_path, normalise_intensity, interpolate)
self.fuse_to_zarr(output_path, fused_image_shape)
elif output_path.suffix == ".h5":
self.fuse_to_bdv_h5(output_path, fused_image_shape)

def fuse_to_zarr(
self,
output_path: Path,
normalise_intensity: bool = False,
interpolate: bool = False,
self, output_path: Path, fused_image_shape: Tuple[int, ...]
) -> None:
z_size, y_size, x_size = self.tiles[0].data_pyramid[0].shape

output_slice_axis = 1

fused_image_shape: Tuple[int, ...] = (
max([tile.position[0] for tile in self.tiles]) + z_size,
max([tile.position[1] for tile in self.tiles]) + y_size,
max([tile.position[2] for tile in self.tiles]) + x_size,
)

chunk_shape_list = list(fused_image_shape)
chunk_shape_list[output_slice_axis] = 1
chunk_shape = tuple(chunk_shape_list)
Expand All @@ -483,12 +490,6 @@ def fuse_to_zarr(
fused_image_shape = (self.num_channels, *fused_image_shape)
chunk_shape = (self.num_channels, *chunk_shape)

if normalise_intensity:
self.normalise_intensity(0, 80)

if interpolate:
self.interpolate_overlaps(0)

store = zarr.NestedDirectoryStore(str(output_path))
root = zarr.group(store=store)
compressor = Blosc(cname="zstd", clevel=1, shuffle=Blosc.SHUFFLE)
Expand Down Expand Up @@ -588,6 +589,86 @@ def fuse_to_zarr(

root.attrs["omero"] = {"channels": channels}

def fuse_to_bdv_h5(
self, output_path: Path, fused_image_shape: Tuple[int, ...]
) -> None:
z_size, y_size, x_size = self.tiles[0].data_pyramid[0].shape
output_file = h5py.File(output_path, mode="w")

subdivisions = np.array(
[
[32, 32, 16],
[32, 32, 16],
[32, 32, 16],
[32, 32, 16],
[32, 32, 16],
],
dtype=np.int16,
)
resolutions = np.array(
[[1, 1, 1], [2, 2, 1], [4, 4, 1], [8, 8, 1], [16, 16, 1]],
dtype=np.int16,
)

ds_list = []
for i in range(self.num_channels):
output_file.require_dataset(
f"s{i:02}/resolutions",
data=resolutions,
dtype="i2",
shape=resolutions.shape,
)
output_file.require_dataset(
f"s{i:02}/subdivisions",
data=subdivisions,
dtype="i2",
shape=subdivisions.shape,
)
ds = output_file.require_dataset(
f"t00000/s{i:02}/0/cells",
shape=fused_image_shape,
dtype="i2",
)
ds_list.append(ds)

for tile in self.tiles[-1::-1]:
ds_list[tile.channel_id][
tile.position[0] : tile.position[0] + z_size,
tile.position[1] : tile.position[1] + y_size,
tile.position[2] : tile.position[2] + x_size,
] = tile.data_pyramid[0].compute()
print(f"Done tile {tile.id}")

for i in range(1, len(resolutions)):
for j in range(self.num_channels):
prev_resolution = da.from_array(
output_file[f"t00000/s{j:02}/{i - 1}/cells"]
)

downsampled_image = downscale_nearest(
prev_resolution, (1, 2, 2)
)

downsampled_shape = downsampled_image.shape
downsampled_dataset = output_file.require_dataset(
f"t00000/s{j:02}/{i}/cells",
shape=downsampled_shape,
dtype="i2",
)
downsampled_dataset[...] = downsampled_image.compute()

print(f"Done resolution {i}")

assert self.xml_path is not None

write_bdv_xml(
output_path.with_suffix(".xml"),
self.xml_path,
output_path,
fused_image_shape,
)
output_file.close()

def get_metadata_for_zarr(self, pyramid_depth: int = 5):
axes = [
{"name": "z", "type": "space", "unit": "micrometer"},
Expand Down
19 changes: 18 additions & 1 deletion mesospim_stitcher/stitching_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,17 @@ def __init__(self, napari_viewer: Viewer):
self.fuse_option_widget.setLayout(QFormLayout())
self.normalise_intensity_toggle = QCheckBox()
self.interpolate_toggle = QCheckBox()
self.output_file_name_field = QLineEdit()

self.fuse_option_widget.layout().addRow(
"Normalise intensity:", self.normalise_intensity_toggle
)
self.fuse_option_widget.layout().addRow(
"Interpolate overlaps:", self.interpolate_toggle
)
self.fuse_option_widget.layout().addRow(
"Output file name:", self.output_file_name_field
)

self.layout().addWidget(self.fuse_option_widget)

Expand Down Expand Up @@ -308,9 +312,22 @@ def _on_interpolation_button_clicked(self):
return

def _on_fuse_button_clicked(self):
if not self.output_file_name_field.text():
show_warning("Output file name not specified")
return

if not (
self.output_file_name_field.text().endswith(".zarr")
or self.output_file_name_field.text().endswith(".h5")
):
show_warning(
"Output file name should either end with .zarr or .h5"
)
return

fuse(
self.image_mosaic,
"fused.zarr",
self.output_file_name_field.text(),
self.normalise_intensity_toggle.isChecked(),
self.interpolate_toggle.isChecked(),
)
Expand Down

0 comments on commit cdcdd8f

Please sign in to comment.