Multicore code commented out
IgorTatarnikov committed Jan 5, 2024
1 parent 8dc8e2e commit ebfdf1c
Showing 1 changed file with 200 additions and 63 deletions.
263 changes: 200 additions & 63 deletions mesospim_stitcher/fuse.py
@@ -7,7 +7,6 @@
import h5py
import numpy as np
import zarr
from mpi4py import MPI
from numcodecs import Blosc
from ome_zarr.io import parse_url
from ome_zarr.writer import write_image
@@ -77,12 +76,13 @@ def fuse_image(

input_file = h5py.File(input_path, "r")
group = input_file["t00000"]
tiles = [da.from_array(group[f"{child}/0/cells"]) for child in group]

z_size = tiles[0].shape[0]
x_y_size = tiles[0].shape[1]
tile_names = list(group.keys())
max_delta = [max([abs(delta[i]) for delta in deltas]) for i in range(3)]

tile = da.from_array(group[f"{tile_names[0]}/0/cells"])
z_size = tile.shape[0]
x_y_size = tile.shape[1]

translations = []
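# Each entry holds the destination bounds of one tile in the fused volume:
# [x_start, x_end, y_start, y_end, z_start, z_end] (matches the unpacking below).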

for i in range(len(deltas)):
@@ -98,27 +98,78 @@

translations.append([x_start, x_end, y_start, y_end, z_start, z_end])

new_image = da.zeros(
(
max(translation[5] for translation in translations),
max(translation[3] for translation in translations),
max(translation[1] for translation in translations),
),
dtype="uint16",
# new_image = da.zeros(
# (
# max(translation[5] for translation in translations),
# max(translation[3] for translation in translations),
# max(translation[1] for translation in translations),
# ),
# dtype="int16",
# )

fused_image_shape = (
max(translation[5] for translation in translations),
max(translation[3] for translation in translations),
max(translation[1] for translation in translations),
)
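# Fused volume shape in (z, y, x): the largest end coordinate of any tile along each axis.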

for i in range(len(tiles) - 1, -1, -1):
curr_tile = tiles[i]
num_tiles = len(tile_names)

output_file = h5py.File(output_path, mode="w", compression="lzf")
ds = output_file.require_dataset(
"t00000/s00/0/cells", shape=fused_image_shape, dtype="i2"
)
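# The fused result is streamed tile-by-tile into a BigDataViewer-style HDF5 dataset
# ("t00000/s00/0/cells", int16) rather than being assembled in a dask array first.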

square_root_cpu = 4

x_y_split_size = x_y_size // square_root_cpu

x_y_borders = [0]

for j in range(1, square_root_cpu):
x_y_borders.append(x_y_borders[j - 1] + x_y_split_size)

x_y_borders.append(x_y_size)
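# x_y_borders splits the tile plane into square_root_cpu strips per axis
# (square_root_cpu ** 2 blocks in total); only the commented-out MPI path below uses it.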

for i in range(num_tiles - 1, -1, -1):
# for rank in range(square_root_cpu ** 2):
# x_tile_s = x_y_borders[rank % square_root_cpu]
# x_tile_e = x_y_borders[rank % square_root_cpu + 1]
# y_tile_s = x_y_borders[rank // square_root_cpu]
# y_tile_e = x_y_borders[(rank // square_root_cpu + 1)]
# curr_tile = group[f"{tile_names[i]}/0/cells"]
# [:, y_tile_s:y_tile_e, x_tile_s:x_tile_e]
#
# x_s, x_e, y_s, y_e, z_s, z_e = translations[i]
# x_e = x_s + x_tile_e
# x_s = x_s + x_tile_s
# y_e = y_s + y_tile_e
# y_s = y_s + y_tile_s
#
# with ds.collective:
# ds[z_s:z_e, y_s:y_e, x_s:x_e] = curr_tile
#
# print(f"Done tile {tile_names[i]} part {rank}")

x_s, x_e, y_s, y_e, z_s, z_e = translations[i]
new_image[z_s:z_e, y_s:y_e, x_s:x_e] = curr_tile

try:
# write_ome_zarr(output_path, new_image, overwrite)
write_hdf5(output_path, new_image, overwrite)
except Exception as e:
raise e
finally:
input_file.close()
curr_tile = group[f"{tile_names[i]}/0/cells"]

# new_image[z_s:z_e, y_s:y_e, x_s:x_e] = curr_tile
ds[z_s:z_e, y_s:y_e, x_s:x_e] = curr_tile
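# Tiles are written in reverse index order, so lower-numbered tiles are written
# last and take precedence in overlapping regions.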

output_file.close()
input_file.close()

write_bdv_xml(Path("testing.xml"), xml_path, output_path, ds.shape)

# try:
# write_ome_zarr(output_path, new_image, overwrite)
# # write_hdf5(output_path, new_image, overwrite)
# except Exception as e:
# raise e
# finally:
# input_file.close()


def write_ome_zarr(output_path: Path, image: da, overwrite: bool):
@@ -149,7 +200,7 @@ def write_ome_zarr(output_path: Path, image: da, overwrite: bool):
[{"type": "scale", "scale": [10.0, 41.6, 41.6]}],
],
storage_options=dict(
chunks=(2, image.shape[1], image.shape[2]),
chunks=(4, image.shape[1], image.shape[2]),
compressor=compressor,
),
)
@@ -173,7 +224,7 @@ def write_hdf5(output_path: Path, image: da, overwrite: bool):
[[1, 1, 1], [2, 2, 2], [4, 4, 4], [8, 8, 8], [16, 16, 16]]
)
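# Per-level downsampling factors for the resolution pyramid; the loop that built
# levels 1+ is commented out further down, so only the full-resolution level is written.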

rank = MPI.COMM_WORLD.rank
# rank = MPI.COMM_WORLD.rank
#
# print(f"Rank {rank} starting to write")
f = h5py.File(output_path, "w")
@@ -193,63 +244,149 @@ def write_hdf5(output_path: Path, image: da, overwrite: bool):

data_group: h5py.Group = f.create_group("t00000/s00")

orig_image = data_group.create_dataset(
data_group.create_dataset(
"0/cells",
data=image,
chunks=(16, 32, 32),
dtype="uint16",
shape=image.shape,
)
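# Full-resolution volume is stored as "0/cells" with (16, 32, 32) chunks, the same
# chunking the commented-out downsampled levels used.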

image_shape = image.shape
x_split_size = image_shape[2] // 4
y_split_size = image_shape[1] // 3

x_borders = [
0,
x_split_size,
x_split_size * 2,
x_split_size * 3,
image_shape[2],
]
y_borders = [0, y_split_size, y_split_size * 2, image_shape[1]]

x_start = x_borders[rank % 4]
x_end = x_borders[rank % 4 + 1]
y_start = y_borders[rank // 4]
y_end = y_borders[(rank // 4 + 1)]

orig_image[:, y_start:y_end, x_start:x_end] = image[
:, y_start:y_end, x_start:x_end
]

for i in range(1, resolutions.shape[0]):
prev_resolution = data_group[f"{i-1}/cells"][
:, y_start:y_end, x_start:x_end
]
data_group.require_dataset(
f"{i}/cells",
data=prev_resolution[::2, ::2, ::2],
chunks=(16, 32, 32),
compression="gzip",
dtype="uint16",
shape=prev_resolution[::2, ::2, ::2].shape,
)
# image_shape = image.shape
# x_split_size = image_shape[2] // 4
# y_split_size = image_shape[1] // 3
#
# x_borders = [
# 0,
# x_split_size,
# x_split_size * 2,
# x_split_size * 3,
# image_shape[2],
# ]
# y_borders = [0, y_split_size, y_split_size * 2, image_shape[1]]
#
# x_start = x_borders[rank % 4]
# x_end = x_borders[rank % 4 + 1]
# y_start = y_borders[rank // 4]
# y_end = y_borders[(rank // 4 + 1)]
#
# orig_image[:, y_start:y_end, x_start:x_end] = image[
# :, y_start:y_end, x_start:x_end
# ]
#
# for i in range(1, resolutions.shape[0]):
# prev_resolution = data_group[f"{i-1}/cells"][
# :, y_start:y_end, x_start:x_end
# ]
# data_group.require_dataset(
# f"{i}/cells",
# data=prev_resolution[::2, ::2, ::2],
# chunks=(16, 32, 32),
# compression="gzip",
# dtype="uint16",
# shape=prev_resolution[::2, ::2, ::2].shape,
# )

f.close()

# print(f"Rank {rank} finished writing")

def write_bdv_xml(
output_xml_path: Path,
input_xml_path: Path,
hdf5_path: Path,
image_size: tuple,
):
input_tree = ET.parse(input_xml_path)
input_root = input_tree.getroot()

generated_by = input_root.find(".//generatedBy")
base_path = input_root.find(".//BasePath")

root = ET.Element("SpimData", version="0.2")
assert (
generated_by is not None
), "No generatedBy tag found in the input XML file"
assert base_path is not None, "No BasePath tag found in the input XML file"
root.append(generated_by)
root.append(base_path)

sequence_desc = ET.SubElement(root, "SequenceDescription")
image_loader = input_root.find(".//ImageLoader")
assert (
image_loader is not None
), "No ImageLoader tag found in the input XML file"

hdf5_path_node = image_loader.find(".//hdf5")
assert (
hdf5_path_node is not None
), "No hdf5 tag found in the input XML file"
hdf5_path_node.text = str(hdf5_path)
sequence_desc.append(image_loader)

view_setup = input_root.find(".//ViewSetup")
assert (
view_setup is not None
), "No ViewSetup tag found in the input XML file"
view_setup[3].text = f"{image_size[2]} {image_size[1]} {image_size[0]}"
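# Assuming the fourth ViewSetup child is its size element, it is updated to the
# fused dimensions written as "x y z" (image_size itself is (z, y, x)).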

view_setups = ET.SubElement(sequence_desc, "ViewSetups")
view_setups.append(view_setup)

attributes_illumination = input_root.find(
".//Attributes[@name='illumination']"
)
assert (
attributes_illumination is not None
), "No illumination attributes found in the input XML file"
view_setups.append(attributes_illumination)

attributes_channel = input_root.find(".//Attributes[@name='channel']")
assert (
attributes_channel is not None
), "No channel attributes found in the input XML file"
view_setups.append(attributes_channel)

attributes_tiles = ET.SubElement(view_setups, "Attributes", name="tile")
tile = input_root.find(".//Tile/[id='0']")
assert tile is not None, "No Tile tag found in the input XML file"
attributes_tiles.append(tile)

attributes_angles = input_root.find(".//Attributes[@name='angle']")
assert (
attributes_angles is not None
), "No angle attributes found in the input XML file"
view_setups.append(attributes_angles)

timepoints = input_root.find(".//Timepoints")
assert (
timepoints is not None
), "No Timepoints tag found in the input XML file"
missing_views = input_root.find(".//MissingViews")
assert (
missing_views is not None
), "No MissingViews tag found in the input XML file"

sequence_desc.append(timepoints)
sequence_desc.append(missing_views)

tree = ET.ElementTree(root)
ET.indent(tree, space=" ")
tree.write(output_xml_path, encoding="utf-8", xml_declaration=True)

return


if __name__ == "__main__":
xml_path = Path(
"/mnt/Data/TiledDataset/2.5x_tile/"
"/home/igor/NIU-dev/stitching_dataset/"
"One_Channel/2.5x_tile_igor_rightonly_Mag2.5x_"
"ch488_ch561_ch647_bdv.xml"
)
input_path = Path("/mnt/Data/TiledDataset/2.5x_tile/One_Channel/test.h5")
input_path = Path(
"/home/igor/NIU-dev/stitching_dataset/One_Channel/test.h5"
)
output_path = Path(
"/mnt/Data/TiledDataset/2.5x_tile/One_Channel/test_out.h5"
"/home/igor/NIU-dev/stitching_dataset/One_Channel/test_out.zarr"
)

fuse_image(xml_path, input_path, output_path)
