diff --git a/linc_convert/modalities/lsm/mosaic.py b/linc_convert/modalities/lsm/mosaic.py index f3e7f81..54743a1 100644 --- a/linc_convert/modalities/lsm/mosaic.py +++ b/linc_convert/modalities/lsm/mosaic.py @@ -11,7 +11,8 @@ import re from glob import glob from typing import Literal -from concurrent.futures import ThreadPoolExecutor, as_completed +from multiprocessing import Pool +from functools import partial # externals import cyclopts @@ -31,8 +32,8 @@ mosaic = cyclopts.App(name="mosaic", help_format="markdown") lsm.command(mosaic) -def write_plane(tswriter, subc, zstart, subz, ystart, yx_shape, dat): - """Write a single plane of data into the Zarr file.""" +def write_plane_multiprocess(tswriter, subc, zstart, subz, ystart, yx_shape, dat): + """Write a single plane of data into the Zarr file (multiprocessing).""" try: with ts.Transaction() as txn: tswriter.with_transaction(txn)[ @@ -45,6 +46,16 @@ def write_plane(tswriter, subc, zstart, subz, ystart, yx_shape, dat): print(f"Error writing plane: {e}") raise +def monitor_and_wait(futures): + """Monitor resource usage and wait for completion of a batch of futures.""" + print("\nProcessing batch...") + for future in as_completed(futures): + try: + future.result() # Raise any exceptions from the task + except Exception as e: + print(f"Error in parallel write: {e}") + futures.clear() + @mosaic.default def convert( inp: str, @@ -262,42 +273,34 @@ def convert( tswriter = ts.open(wconfig).result() - with ThreadPoolExecutor() as executor: - futures = [] - - for i, dirname in enumerate(all_chunks_info["dirname"]): - chunkz = all_chunks_info["z"][i] - 1 - chunky = all_chunks_info["y"][i] - 1 - planes = all_chunks_info["planes"][i] - - for j, fname in enumerate(planes["fname"]): - subz = planes["z"][j] - 1 - subc = planes["c"][j] - 1 - yx_shape = planes["yx_shape"][j] - - zstart = sum(shape[0][0] for shape in allshapes[:chunkz]) - ystart = sum( - shape[1] for subshapes in allshapes for shape in subshapes[:chunky] - ) - - print( - f"Queueing write plane ({subc:4d}, {zstart + subz:4d}, " - f"{ystart:4d}:{ystart + yx_shape[0]:4d})", - end="\r", - ) - - # Load data and submit the write task - dat = TiffFile(fname).asarray() - futures.append( - executor.submit(write_plane, tswriter, subc, zstart, subz, ystart, yx_shape, dat) - ) - - # Wait for all tasks to complete - for future in as_completed(futures): - try: - future.result() # Raise any exceptions from the task - except Exception as e: - print(f"Error in parallel write: {e}") + tasks = [] + for i, dirname in enumerate(all_chunks_info["dirname"]): + chunkz = all_chunks_info["z"][i] - 1 + chunky = all_chunks_info["y"][i] - 1 + planes = all_chunks_info["planes"][i] + + for j, fname in enumerate(planes["fname"]): + subz = planes["z"][j] - 1 + subc = planes["c"][j] - 1 + yx_shape = planes["yx_shape"][j] + + zstart = sum(shape[0][0] for shape in allshapes[:chunkz]) + ystart = sum( + shape[1] for subshapes in allshapes for shape in subshapes[:chunky] + ) + + print( + f"Queueing write plane ({subc:4d}, {zstart + subz:4d}, " + f"{ystart:4d}:{ystart + yx_shape[0]:4d})", + end="\r", + ) + + dat = TiffFile(fname).asarray() + tasks.append((subc, zstart, subz, ystart, yx_shape, dat)) + + write_func = partial(write_plane_multiprocess, tswriter) + with Pool(processes=8) as pool: + pool.starmap(write_func, tasks) print("")