Skip to content

Commit

Permalink
Merge pull request #68 from thewtex/cucim
Browse files Browse the repository at this point in the history
DOC: Document CUDA acceleration
  • Loading branch information
thewtex authored Mar 4, 2024
2 parents 63da7c5 + 6e66321 commit c99e033
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 76 deletions.
19 changes: 17 additions & 2 deletions docs/methods.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ intensity images.

The default method.

To use an NVIDIA CUDA GPU-accelerated version, first
[install cuCIM](https://github.com/rapidsai/cucim?tab=readme-ov-file#install-cucim),
then install the `itkwasm-downsample-cucim` package:

```sh
pip install itkwasm-downsample-cucim
```

Once the package is installed, GPU-accelerated filtering is applied by default.

## `ITKWASM_BIN_SHRINK`

Uses the [local mean] for the output value. [WebAssembly] build.
Expand All @@ -34,6 +46,9 @@ Fast but generates more artifacts than gaussian-based methods.

Appropriate for intensity images.

An NVIDIA CUDA GPU-accelerated version can be installed in the same way as
for `ITKWASM_GAUSSIAN` above.

## `ITKWASM_LABEL_IMAGE`

A sample is the mode of the linearly weighted [local labels] in the image.
Expand All @@ -56,7 +71,8 @@ To use a GPU-accelerated version, install the `itk-vkfft` package:
pip install itk-vkfft
```

And GPU-accelerated, FFT-based filtering is applied by default after installation.
And GPU-accelerated, FFT-based filtering is applied by default after
installation.

## `ITK_BIN_SHRINK`

Expand Down Expand Up @@ -105,7 +121,6 @@ Install required dependencies with:
pip install "ngff-zarr[dask-image]"
```


[aliasing artifacts]:
https://en.wikipedia.org/wiki/Nyquist%E2%80%93Shannon_sampling_theorem
[dask-image]: https://image.dask.org/
Expand Down
2 changes: 1 addition & 1 deletion ngff_zarr/methods/_itkwasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _downsample_itkwasm_bin_shrink(
origin = [previous_image.translation[d] for d in spatial_dims]
block_input.origin = origin
block_output = downsample_bin_shrink(
block_input, shrink_factors, information_only=False
block_input, shrink_factors, information_only=True
)
scale = {_image_dims[i]: s for (i, s) in enumerate(block_output.spacing)}
translation = {_image_dims[i]: s for (i, s) in enumerate(block_output.origin)}
Expand Down
34 changes: 34 additions & 0 deletions ngff_zarr/to_ngff_zarr.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import sys
from collections.abc import MutableMapping
from dataclasses import asdict
from pathlib import Path, PurePosixPath
from typing import Optional, Union

if sys.version_info < (3, 10):
import importlib_metadata
else:
import importlib.metadata as importlib_metadata

import dask.array
import numpy as np
import zarr
from itkwasm import array_like_to_numpy_array
from zarr.storage import BaseStore

from .config import config
Expand All @@ -27,6 +34,31 @@ def _pop_metadata_optionals(metadata_dict):
return metadata_dict


def _prep_for_to_zarr(
    store: Union[MutableMapping, str, Path, BaseStore], arr: dask.array.Array
) -> dask.array.Array:
    """Prepare a dask array for being written to ``store``.

    Each block of ``arr`` is mapped through ``array_like_to_numpy_array``,
    except when the ``kvikio`` distribution is installed and ``store`` is a
    ``kvikio.zarr.GDSStore``; in that case ``arr`` is returned untouched.

    Parameters
    ----------
    store
        Destination the array is about to be written to.
    arr
        Dask array that is about to be written.

    Returns
    -------
    dask.array.Array
        ``arr`` itself for a ``GDSStore`` destination, otherwise a new dask
        array whose blocks are converted by ``array_like_to_numpy_array``.
    """
    # Probe the installed distribution metadata so that kvikio is only
    # imported when it is actually present.
    try:
        importlib_metadata.distribution("kvikio")
    except importlib_metadata.PackageNotFoundError:
        have_kvikio = False
    else:
        have_kvikio = True

    if have_kvikio:
        from kvikio.zarr import GDSStore

        # NOTE(review): presumably a GDSStore can consume the blocks without
        # conversion to NumPy -- confirm against the kvikio documentation.
        if isinstance(store, GDSStore):
            return arr

    return dask.array.map_blocks(
        array_like_to_numpy_array,
        arr,
        dtype=arr.dtype,
        meta=np.empty(()),
    )


def to_ngff_zarr(
store: Union[MutableMapping, str, Path, BaseStore],
multiscales: Multiscales,
Expand Down Expand Up @@ -221,6 +253,7 @@ def to_ngff_zarr(
f"[green]Writing scale {index+1} of {nscales}, region {region_index+1} of {len(regions)}"
)
arr_region = arr[region]
arr_region = _prep_for_to_zarr(store, arr_region)
optimized = dask.array.Array(
dask.array.optimize(
arr_region.__dask_graph__(), arr_region.__dask_keys__()
Expand All @@ -245,6 +278,7 @@ def to_ngff_zarr(
progress.add_callback_task(
f"[green]Writing scale {index+1} of {nscales}"
)
arr = _prep_for_to_zarr(store, arr)
dask.array.to_zarr(
arr,
store,
Expand Down
121 changes: 57 additions & 64 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ readme = "README.md"
requires-python = ">=3.8"
license = "MIT"
keywords = []
authors = [
{ name = "Matt McCormick", email = "[email protected]" },
]
authors = [{ name = "Matt McCormick", email = "[email protected]" }]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
Expand All @@ -30,15 +28,15 @@ classifiers = [
"Topic :: Scientific/Engineering",
]
dependencies = [
"dask[array]",
"itkwasm >= 1.0b167",
"itkwasm-downsample >= 1.1.0",
"numpy",
"platformdirs",
"psutil; sys_platform != \"emscripten\"",
"rich",
"typing_extensions",
"zarr",
"dask[array]",
"itkwasm >= 1.0b168",
"itkwasm-downsample >= 1.1.0",
"numpy",
"platformdirs",
"psutil; sys_platform != \"emscripten\"",
"rich",
"typing_extensions",
"zarr",
]
dynamic = ["version"]

Expand All @@ -55,32 +53,27 @@ ngff-zarr = "ngff_zarr.cli:main"
path = "ngff_zarr/__about__.py"

[project.optional-dependencies]
dask-image = [
"dask-image",
]
itk = [
"itk-filtering>=5.3.0",
]
dask-image = ["dask-image"]
itk = ["itk-filtering>=5.3.0"]
cli = [
"dask-image",
"dask[distributed]",
"itk-filtering>=5.3.0",
"itk-io>=5.3.0",
"itkwasm-image-io",
"imageio",
"tifffile",
"imagecodecs",
"dask-image",
"dask[distributed]",
"itk-filtering>=5.3.0",
"itk-io>=5.3.0",
"itkwasm-image-io",
"imageio",
"tifffile",
"imagecodecs",
]
test = [
"pytest >=6",
"pre-commit",
"pooch",
"itkwasm",
"itk-io>=5.3.0",
"itkwasm-image-io",
"itk-filtering>=5.3.0",
"tifffile",
"jsonschema",
"pytest >=6",
"pre-commit",
"pooch",
"itk-io>=5.3.0",
"itkwasm-image-io",
"itk-filtering>=5.3.0",
"tifffile",
"jsonschema",
]


Expand All @@ -100,40 +93,40 @@ filterwarnings = [
"ignore:(ast.Str|Attribute s|ast.NameConstant|ast.Num) is deprecated:DeprecationWarning:_pytest",
]
log_cli_level = "INFO"
testpaths = [
"tests",
]
testpaths = ["tests"]


[tool.ruff]
select = [
"E", "F", "W", # flake8
"B", # flake8-bugbear
"I", # isort
"ARG", # flake8-unused-arguments
"C4", # flake8-comprehensions
"EM", # flake8-errmsg
"ICN", # flake8-import-conventions
"ISC", # flake8-implicit-str-concat
"G", # flake8-logging-format
"PGH", # pygrep-hooks
"PIE", # flake8-pie
"PL", # pylint
"PT", # flake8-pytest-style
"PTH", # flake8-use-pathlib
"RET", # flake8-return
"RUF", # Ruff-specific
"SIM", # flake8-simplify
"T20", # flake8-print
"UP", # pyupgrade
"YTT", # flake8-2020
"EXE", # flake8-executable
"NPY", # NumPy specific rules
"PD", # pandas-vet
"E",
"F",
"W", # flake8
"B", # flake8-bugbear
"I", # isort
"ARG", # flake8-unused-arguments
"C4", # flake8-comprehensions
"EM", # flake8-errmsg
"ICN", # flake8-import-conventions
"ISC", # flake8-implicit-str-concat
"G", # flake8-logging-format
"PGH", # pygrep-hooks
"PIE", # flake8-pie
"PL", # pylint
"PT", # flake8-pytest-style
"PTH", # flake8-use-pathlib
"RET", # flake8-return
"RUF", # Ruff-specific
"SIM", # flake8-simplify
"T20", # flake8-print
"UP", # pyupgrade
"YTT", # flake8-2020
"EXE", # flake8-executable
"NPY", # NumPy specific rules
"PD", # pandas-vet
]
extend-ignore = [
"PLR", # Design related pylint codes
"E501", # Line too long
"PLR", # Design related pylint codes
"E501", # Line too long
]
src = ["src"]
unfixable = [
Expand All @@ -150,7 +143,7 @@ line-length = 88

[tool.pylint]
py-version = "3.8"
ignore-paths= ["src/ngff_zarr/__about__.py"]
ignore-paths = ["src/ngff_zarr/__about__.py"]
reports.output-format = "colorized"
similarities.ignore-imports = "yes"
messages_control.disable = [
Expand Down
7 changes: 4 additions & 3 deletions test/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from ngff_zarr import itk_image_to_ngff_image, to_ngff_zarr
from zarr.storage import DirectoryStore, MemoryStore

test_data_ipfs_cid = "bafybeiawyalfemcmlfbizetoqilpmbk6coowu7cqr7av6aff4dpjwlsk6m"
test_data_sha256 = "3f32e9e8fac84de3fbe63d0a6142b2eb65cadd8c9e1c3ba7f93080a6bc2150ef"
test_data_ipfs_cid = "bafybeiaskr5fxg6rbcwlxl6ibzqhubdleacenrpbnymc6oblwoi7ceqzta"
test_data_sha256 = "95e1f3864267dd9e0bd9ba7c99515d5952ca721b9dbbf282271e696fdab48f65"

test_dir = Path(__file__).resolve().parent
extract_dir = "data"
Expand All @@ -21,7 +21,8 @@ def input_images():
pooch.retrieve(
fname="data.tar.gz",
path=test_dir,
url=f"https://{test_data_ipfs_cid}.ipfs.w3s.link/ipfs/{test_data_ipfs_cid}/data.tar.gz",
url=f"https://itk.mypinata.cloud/ipfs/{test_data_ipfs_cid}/data.tar.gz",
# url=f"https://{test_data_ipfs_cid}.ipfs.w3s.link/ipfs/{test_data_ipfs_cid}/data.tar.gz",
known_hash=f"sha256:{test_data_sha256}",
processor=untar,
)
Expand Down
38 changes: 32 additions & 6 deletions test/test_to_ngff_zarr_itkwasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,65 @@

from ._data import verify_against_baseline

# Record whether the optional itkwasm_downsample_cucim package can be
# imported; the tests below use this flag to choose which baseline to
# compare against.
try:
    import itkwasm_downsample_cucim  # noqa: F401
except ImportError:
    _HAVE_CUCIM = False
else:
    _HAVE_CUCIM = True


def test_bin_shrink_isotropic_scale_factors(input_images):
    """Check ITKWASM_BIN_SHRINK multiscales of "cthead1" against baselines.

    Two cases are verified: explicit scale factors ``[2, 4]`` and a call with
    no scale factors passed (the ``auto/`` baselines). The ``_CUCIM`` baseline
    is selected whenever ``_HAVE_CUCIM`` is true.
    """
    dataset_name = "cthead1"
    image = input_images[dataset_name]

    # Explicit scale factors.
    baseline_name = (
        "2_4/ITKWASM_BIN_SHRINK_CUCIM.zarr"
        if _HAVE_CUCIM
        else "2_4/ITKWASM_BIN_SHRINK.zarr"
    )
    multiscales = to_multiscales(image, [2, 4], method=Methods.ITKWASM_BIN_SHRINK)
    verify_against_baseline(dataset_name, baseline_name, multiscales)

    # No scale factors passed.
    baseline_name = (
        "auto/ITKWASM_BIN_SHRINK_CUCIM.zarr"
        if _HAVE_CUCIM
        else "auto/ITKWASM_BIN_SHRINK.zarr"
    )
    multiscales = to_multiscales(image, method=Methods.ITKWASM_BIN_SHRINK)
    verify_against_baseline(dataset_name, baseline_name, multiscales)


def test_gaussian_isotropic_scale_factors(input_images):
dataset_name = "cthead1"
image = input_images[dataset_name]
baseline_name = "2_4/ITKWASM_GAUSSIAN.zarr"
if _HAVE_CUCIM:
baseline_name = "2_4/ITKWASM_GAUSSIAN_CUCIM.zarr"
else:
baseline_name = "2_4/ITKWASM_GAUSSIAN.zarr"
multiscales = to_multiscales(image, [2, 4], method=Methods.ITKWASM_GAUSSIAN)
verify_against_baseline(dataset_name, baseline_name, multiscales)

baseline_name = "auto/ITKWASM_GAUSSIAN.zarr"
if _HAVE_CUCIM:
baseline_name = "auto/ITKWASM_GAUSSIAN_CUCIM.zarr"
else:
baseline_name = "auto/ITKWASM_GAUSSIAN.zarr"
multiscales = to_multiscales(image, method=Methods.ITKWASM_GAUSSIAN)
verify_against_baseline(dataset_name, baseline_name, multiscales)

dataset_name = "cthead1"
image = input_images[dataset_name]
baseline_name = "2_3/ITKWASM_GAUSSIAN.zarr"
if _HAVE_CUCIM:
baseline_name = "2_3/ITKWASM_GAUSSIAN_CUCIM.zarr"
else:
baseline_name = "2_3/ITKWASM_GAUSSIAN.zarr"
multiscales = to_multiscales(image, [2, 3], method=Methods.ITKWASM_GAUSSIAN)
verify_against_baseline(dataset_name, baseline_name, multiscales)

dataset_name = "MR-head"
image = input_images[dataset_name]
baseline_name = "2_3_4/ITKWASM_GAUSSIAN.zarr"
if _HAVE_CUCIM:
baseline_name = "2_3_4/ITKWASM_GAUSSIAN_CUCIM.zarr"
else:
baseline_name = "2_3_4/ITKWASM_GAUSSIAN.zarr"
multiscales = to_multiscales(image, [2, 3, 4], method=Methods.ITKWASM_GAUSSIAN)
verify_against_baseline(dataset_name, baseline_name, multiscales)

Expand Down
Loading

0 comments on commit c99e033

Please sign in to comment.