Skip to content

Commit

Permalink
Merge pull request #299 from will-moore/plate_dask_compute
Browse files Browse the repository at this point in the history
Fix compute() on plate grid
  • Loading branch information
will-moore authored Sep 29, 2023
2 parents ea18dd4 + 5417764 commit 03de064
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 30 deletions.
42 changes: 15 additions & 27 deletions ome_zarr/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import dask.array as da
import numpy as np
from dask import delayed

from .axes import Axes
from .format import format_from_version
Expand Down Expand Up @@ -420,39 +419,34 @@ def __init__(self, node: Node) -> None:
self.img_metadata = image_node.metadata
self.img_pyramid_shapes = [d.shape for d in image_node.data]

def get_field(tile_name: str, level: int) -> np.ndarray:
def get_field(row: int, col: int, level: int) -> da.core.Array:
"""Return the field at grid position (row, col) for pyramid ``level``, or zeros if it cannot be loaded."""
row, col = (int(n) for n in tile_name.split(","))
field_index = (column_count * row) + col
image_path = image_paths[field_index]
path = f"{image_path}/{level}"
LOGGER.debug("LOADING tile... %s", path)
data = None
try:
data = self.zarr.load(path)
# handle e.g. 2x2 grid with only 3 images/fields
if field_index < len(image_paths):
image_path = image_paths[field_index]
path = f"{image_path}/{level}"
data = self.zarr.load(path)
except ValueError:
LOGGER.error("Failed to load %s", path)
data = np.zeros(self.img_pyramid_shapes[level], dtype=self.numpy_type)
if data is None:
data = da.zeros(self.img_pyramid_shapes[level], dtype=self.numpy_type)
return data

lazy_reader = delayed(get_field)

def get_lazy_well(level: int, tile_shape: tuple) -> da.Array:
lazy_rows = []
for row in range(row_count):
lazy_row: List[da.Array] = []
for col in range(column_count):
tile_name = f"{row},{col}"
LOGGER.debug(
"creating lazy_reader. row: %s col: %s level: %s",
row,
col,
level,
)
lazy_tile = da.from_delayed(
lazy_reader(tile_name, level),
shape=tile_shape,
dtype=self.numpy_type,
)
lazy_tile = get_field(row, col, level)
lazy_row.append(lazy_tile)
lazy_rows.append(da.concatenate(lazy_row, axis=x_index))
return da.concatenate(lazy_rows, axis=y_index)
Expand Down Expand Up @@ -535,31 +529,25 @@ def get_tile_path(self, level: int, row: int, col: int) -> str:
def get_stitched_grid(self, level: int, tile_shape: tuple) -> da.core.Array:
LOGGER.debug("get_stitched_grid() level: %s, tile_shape: %s", level, tile_shape)

def get_tile(tile_name: str) -> np.ndarray:
def get_tile(row: int, col: int) -> da.core.Array:
"""Return the tile at grid position (row, col), or zeros of ``tile_shape`` if loading fails."""
row, col = (int(n) for n in tile_name.split(","))
path = self.get_tile_path(level, row, col)
LOGGER.debug("LOADING tile... %s with shape: %s", path, tile_shape)
LOGGER.debug("creating tile... %s with shape: %s", path, tile_shape)

try:
# this is a dask array - data not loaded from source yet
data = self.zarr.load(path)
except ValueError:
LOGGER.exception("Failed to load %s", path)
data = np.zeros(tile_shape, dtype=self.numpy_type)
data = da.zeros(tile_shape, dtype=self.numpy_type)
return data

lazy_reader = delayed(get_tile)

lazy_rows = []
# For level 0, return whole image for each tile
for row in range(self.row_count):
lazy_row: List[da.Array] = []
for col in range(self.column_count):
tile_name = f"{row},{col}"
lazy_tile = da.from_delayed(
lazy_reader(tile_name), shape=tile_shape, dtype=self.numpy_type
)
lazy_row.append(lazy_tile)
lazy_row.append(get_tile(row, col))
lazy_rows.append(da.concatenate(lazy_row, axis=len(self.axes) - 1))
return da.concatenate(lazy_rows, axis=len(self.axes) - 2)

Expand Down
23 changes: 20 additions & 3 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import dask.array as da
import numpy as np
import pytest
import zarr
from numpy import ones, zeros

from ome_zarr.data import create_zarr
from ome_zarr.io import parse_url
from ome_zarr.reader import Node, Plate, Reader
from ome_zarr.reader import Node, Plate, Reader, Well
from ome_zarr.writer import write_image, write_plate_metadata, write_well_metadata


Expand Down Expand Up @@ -95,16 +96,32 @@ def test_multiwells_plate(self, field_paths):
nodes = list(reader())
# Reading plate labels is currently disabled, so only 1 node is expected.
assert len(nodes) == 1

plate_node = nodes[0]
assert len(plate_node.specs) == 1
assert isinstance(plate_node.specs[0], Plate)
# data should be a Dask array
pyramid = plate_node.data
assert isinstance(pyramid[0], da.Array)
# if we compute(), expect to get numpy array
result = pyramid[0].compute()
assert isinstance(result, np.ndarray)

# Get the plate node's array. It should be fused from the first field of all
# well arrays (which in this test are non-zero), with zero values for wells
# that failed to load (not expected) or the surplus area not filled by a well.
expected_num_pixels = (
len(well_paths) * len(field_paths[:1]) * np.prod((1, 1, 1, 256, 256))
)
pyramid_0 = plate_node.data[0]
pyramid_0 = pyramid[0]
assert np.asarray(pyramid_0).sum() == expected_num_pixels
# assert len(nodes[1].specs) == 1

# assert isinstance(nodes[1].specs[0], PlateLabels)

reader = Reader(parse_url(f"{self.path}/{well_paths[0]}"))
nodes = list(reader())
assert isinstance(nodes[0].specs[0], Well)
pyramid = nodes[0].data
assert isinstance(pyramid[0], da.Array)
result = pyramid[0].compute()
assert isinstance(result, np.ndarray)

0 comments on commit 03de064

Please sign in to comment.