diff --git a/ome_zarr/reader.py b/ome_zarr/reader.py index ef79845b..a2a04f77 100644 --- a/ome_zarr/reader.py +++ b/ome_zarr/reader.py @@ -7,7 +7,6 @@ import dask.array as da import numpy as np -from dask import delayed from .axes import Axes from .format import format_from_version @@ -420,39 +419,34 @@ def __init__(self, node: Node) -> None: self.img_metadata = image_node.metadata self.img_pyramid_shapes = [d.shape for d in image_node.data] - def get_field(tile_name: str, level: int) -> np.ndarray: + def get_field(row: int, col: int, level: int) -> da.core.Array: - """tile_name is 'row,col'""" + """Return the field at grid position (row, col) for the given level, or zeros if it is missing""" - row, col = (int(n) for n in tile_name.split(",")) field_index = (column_count * row) + col - image_path = image_paths[field_index] - path = f"{image_path}/{level}" - LOGGER.debug("LOADING tile... %s", path) + data = None try: - data = self.zarr.load(path) + # handle e.g. 2x2 grid with only 3 images/fields + if field_index < len(image_paths): + image_path = image_paths[field_index] + path = f"{image_path}/{level}" + data = self.zarr.load(path) except ValueError: LOGGER.error("Failed to load %s", path) - data = np.zeros(self.img_pyramid_shapes[level], dtype=self.numpy_type) + if data is None: + data = da.zeros(self.img_pyramid_shapes[level], dtype=self.numpy_type) return data - lazy_reader = delayed(get_field) - def get_lazy_well(level: int, tile_shape: tuple) -> da.Array: lazy_rows = [] for row in range(row_count): lazy_row: List[da.Array] = [] for col in range(column_count): - tile_name = f"{row},{col}" LOGGER.debug( "creating lazy_reader. 
row: %s col: %s level: %s", row, col, level, ) - lazy_tile = da.from_delayed( - lazy_reader(tile_name, level), - shape=tile_shape, - dtype=self.numpy_type, - ) + lazy_tile = get_field(row, col, level) lazy_row.append(lazy_tile) lazy_rows.append(da.concatenate(lazy_row, axis=x_index)) return da.concatenate(lazy_rows, axis=y_index) @@ -535,31 +529,25 @@ def get_tile_path(self, level: int, row: int, col: int) -> str: def get_stitched_grid(self, level: int, tile_shape: tuple) -> da.core.Array: LOGGER.debug("get_stitched_grid() level: %s, tile_shape: %s", level, tile_shape) - def get_tile(tile_name: str) -> np.ndarray: + def get_tile(row: int, col: int) -> da.core.Array: - """tile_name is 'level,z,c,t,row,col'""" + """Return the tile at grid position (row, col), or zeros of tile_shape if loading fails""" - row, col = (int(n) for n in tile_name.split(",")) path = self.get_tile_path(level, row, col) - LOGGER.debug("LOADING tile... %s with shape: %s", path, tile_shape) + LOGGER.debug("creating tile... %s with shape: %s", path, tile_shape) try: + # this is a dask array - data not loaded from source yet data = self.zarr.load(path) except ValueError: LOGGER.exception("Failed to load %s", path) - data = np.zeros(tile_shape, dtype=self.numpy_type) + data = da.zeros(tile_shape, dtype=self.numpy_type) return data - lazy_reader = delayed(get_tile) - lazy_rows = [] # For level 0, return whole image for each tile for row in range(self.row_count): lazy_row: List[da.Array] = [] for col in range(self.column_count): - tile_name = f"{row},{col}" - lazy_tile = da.from_delayed( - lazy_reader(tile_name), shape=tile_shape, dtype=self.numpy_type - ) - lazy_row.append(lazy_tile) + lazy_row.append(get_tile(row, col)) lazy_rows.append(da.concatenate(lazy_row, axis=len(self.axes) - 1)) return da.concatenate(lazy_rows, axis=len(self.axes) - 2) diff --git a/tests/test_reader.py b/tests/test_reader.py index 3520eabc..16d9a422 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -1,3 +1,4 @@ +import dask.array as da import numpy as np import pytest import zarr @@ -5,7 +6,7 @@ from 
ome_zarr.data import create_zarr from ome_zarr.io import parse_url -from ome_zarr.reader import Node, Plate, Reader +from ome_zarr.reader import Node, Plate, Reader, Well from ome_zarr.writer import write_image, write_plate_metadata, write_well_metadata @@ -95,16 +96,32 @@ def test_multiwells_plate(self, field_paths): nodes = list(reader()) # currently reading plate labels disabled. Only 1 node assert len(nodes) == 1 + plate_node = nodes[0] assert len(plate_node.specs) == 1 assert isinstance(plate_node.specs[0], Plate) + # data should be a Dask array + pyramid = plate_node.data + assert isinstance(pyramid[0], da.Array) + # if we compute(), expect to get numpy array + result = pyramid[0].compute() + assert isinstance(result, np.ndarray) + # Get the plate node's array. It should be fused from the first field of all # well arrays (which in this test are non-zero), with zero values for wells # that failed to load (not expected) or the surplus area not filled by a well. expected_num_pixels = ( len(well_paths) * len(field_paths[:1]) * np.prod((1, 1, 1, 256, 256)) ) - pyramid_0 = plate_node.data[0] + pyramid_0 = pyramid[0] assert np.asarray(pyramid_0).sum() == expected_num_pixels - # assert len(nodes[1].specs) == 1 + # assert isinstance(nodes[1].specs[0], PlateLabels) + + reader = Reader(parse_url(f"{self.path}/{well_paths[0]}")) + nodes = list(reader()) + assert isinstance(nodes[0].specs[0], Well) + pyramid = nodes[0].data + assert isinstance(pyramid[0], da.Array) + result = pyramid[0].compute() + assert isinstance(result, np.ndarray)