diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b6a5ecb0..12a96fcf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,8 +2,7 @@ name: CI on: push: - branches: - - main + branches: [main] tags: - "v*" pull_request: @@ -24,48 +23,23 @@ jobs: - run: pipx run check-manifest test: - name: ${{ matrix.platform }} (${{ matrix.python-version }}) - runs-on: ${{ matrix.platform }} + uses: pyapp-kit/workflows/.github/workflows/test-pyrepo.yml@v2 + with: + os: ${{ matrix.os }} + python-version: ${{ matrix.python-version }} + coverage-upload: artifact strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] - platform: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.9", "3.10", "3.11", "3.12"] - steps: - - uses: actions/checkout@v4 - - - name: ๐Ÿ Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - cache-dependency-path: "pyproject.toml" - cache: "pip" - - - name: Install Dependencies - run: | - python -m pip install -U pip - # if running a cron job, we add the --pre flag to test against pre-releases - python -m pip install .[test] ${{ github.event_name == 'schedule' && '--pre' || '' }} - - - name: ๐Ÿงช Run Tests - run: pytest --color=yes --cov --cov-report=xml --cov-report=term-missing - - - name: ๐Ÿ“ Report --pre Failures - if: failure() && github.event_name == 'schedule' - uses: JasonEtco/create-an-issue@v2 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - PLATFORM: ${{ matrix.platform }} - PYTHON: ${{ matrix.python-version }} - RUN_ID: ${{ github.run_id }} - TITLE: "[test-bot] pip install --pre is failing" - with: - filename: .github/TEST_FAIL_TEMPLATE.md - update_existing: true - - - name: Coverage - uses: codecov/codecov-action@v3 + upload_coverage: + if: always() + needs: [test] + uses: pyapp-kit/workflows/.github/workflows/upload-coverage.yml@v2 + secrets: + codecov_token: ${{ secrets.CODECOV_TOKEN }} deploy: name: Deploy @@ -83,7 +57,7 @@ jobs: fetch-depth: 0 - name: ๐Ÿ Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.x" @@ -98,4 +72,4 @@ jobs: - uses: softprops/action-gh-release@v1 with: generate_release_notes: true - files: './dist/*' + files: "./dist/*" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d0b547e0..f8b79eef 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,3 @@ -# enable pre-commit.ci at https://pre-commit.ci/ -# it adds: -# 1. auto fixing pull requests -# 2. auto updating the pre-commit configuration ci: autoupdate_schedule: monthly autofix_commit_msg: "style(pre-commit.ci): auto fixes [...]" @@ -17,13 +13,13 @@ repos: rev: typos-dict-v0.11.20 hooks: - id: typos - args: [--force-exclude] # omitting --write-changes + args: [--force-exclude] # omitting --write-changes - repo: https://github.com/charliermarsh/ruff-pre-commit rev: v0.4.8 hooks: - id: ruff - args: [--fix] # may also add '--unsafe-fixes' + args: [--fix, --unsafe-fixes] - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy @@ -31,6 +27,5 @@ repos: hooks: - id: mypy files: "^src/" - # # you have to add the things you want to type check against here - # additional_dependencies: - # - numpy + additional_dependencies: + - numpy diff --git a/examples/dask_arr.py b/examples/dask_arr.py new file mode 100644 index 00000000..a9c514e1 --- /dev/null +++ b/examples/dask_arr.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import numpy as np + +try: + from dask.array.core import map_blocks +except ImportError: + raise ImportError("Please `pip install dask[array]` to run this example.") + +frame_size = (1024, 1024) + + +def _dask_block(block_id: tuple[int, int, int, int, int]) -> np.ndarray | None: + if isinstance(block_id, np.ndarray): + return None + data = np.random.randint(0, 255, size=frame_size, dtype=np.uint8) + return data[(None,) * 3] + + +chunks = [(1,) * x for x in (1000, 64, 3)] +chunks += [(x,) for x in frame_size] +dask_arr = map_blocks(_dask_block, chunks=chunks, dtype=np.uint8) + +if __name__ == "__main__": + from qtpy import QtWidgets + + from ndv import NDViewer + + qapp = QtWidgets.QApplication([]) + v = NDViewer(dask_arr) + v.show() + qapp.exec() diff --git a/examples/jax_arr.py b/examples/jax_arr.py new file mode 100644 index 00000000..ed0e3208 --- /dev/null +++ b/examples/jax_arr.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +try: + import jax.numpy as jnp +except ImportError: + raise ImportError("Please install jax to run this example") +from numpy_arr import generate_5d_sine_wave +from qtpy import QtWidgets + +from ndv import NDViewer + +# Example usage +array_shape = (10, 3, 5, 512, 512) # Specify the desired dimensions +sine_wave_5d = jnp.asarray(generate_5d_sine_wave(array_shape)) + +if __name__ == "__main__": + qapp = QtWidgets.QApplication([]) + v = NDViewer(sine_wave_5d, channel_axis=1) + v.show() + qapp.exec() diff --git a/examples/numpy_arr.py b/examples/numpy_arr.py new file mode 100644 index 00000000..d9b0fb86 --- /dev/null +++ b/examples/numpy_arr.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import numpy as np + + +def generate_5d_sine_wave( + shape: tuple[int, int, int, int, int], + amplitude: float = 240, + base_frequency: float = 5, +) -> np.ndarray: + """5D dataset.""" + # Unpack the dimensions + angle_dim, freq_dim, phase_dim, ny, nx = shape + + # Create an empty array to hold the data + output = np.zeros(shape) + + # Define spatial coordinates for the last two dimensions + half_per = base_frequency * np.pi + x = np.linspace(-half_per, half_per, nx) + y = np.linspace(-half_per, half_per, ny) + y, x = np.meshgrid(y, x) + + # Iterate through each parameter in the higher dimensions + for phase_idx in range(phase_dim): + for freq_idx in range(freq_dim): + for angle_idx in range(angle_dim): + # Calculate phase and frequency + phase = np.pi / phase_dim * phase_idx + frequency = 1 + (freq_idx * 0.1) # Increasing frequency with each step + + # Calculate angle + angle = np.pi / angle_dim * angle_idx + # Rotate x and y coordinates + xr = np.cos(angle) * x - np.sin(angle) * y + np.sin(angle) * x + np.cos(angle) * y + + # Compute the sine wave + sine_wave = (amplitude * 0.5) * np.sin(frequency * xr + phase) + sine_wave += amplitude * 0.5 + + # Assign to the output array + output[angle_idx, freq_idx, phase_idx] = sine_wave + + return output + + +try: + from skimage import data + + img = data.cells3d() +except Exception: + img = generate_5d_sine_wave((10, 3, 8, 512, 512)) + + +if __name__ == "__main__": + from qtpy import QtWidgets + + from ndv import NDViewer + + qapp = QtWidgets.QApplication([]) + v = NDViewer(img) + v.show() + qapp.exec() diff --git a/examples/tensorstore_arr.py b/examples/tensorstore_arr.py new file mode 100644 index 00000000..9ac30a90 --- /dev/null +++ b/examples/tensorstore_arr.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import numpy as np +import tensorstore as ts +from qtpy import QtWidgets + +from ndv import NDViewer + +shape = (10, 4, 3, 512, 512) +ts_array = ts.open( + {"driver": "zarr", "kvstore": {"driver": "memory"}}, + create=True, + shape=shape, + dtype=ts.uint8, +).result() +ts_array[:] = np.random.randint(0, 255, size=shape, dtype=np.uint8) +ts_array = ts_array[ts.d[:].label["t", "c", "z", "y", "x"]] + +if __name__ == "__main__": + qapp = QtWidgets.QApplication([]) + v = NDViewer(ts_array) + v.show() + qapp.exec() diff --git a/examples/xarray_arr.py b/examples/xarray_arr.py new file mode 100644 index 00000000..05eaac09 --- /dev/null +++ b/examples/xarray_arr.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +import xarray as xr +from qtpy import QtWidgets + +from ndv import NDViewer + +da = xr.tutorial.open_dataset("air_temperature").air + +if __name__ == "__main__": + qapp = QtWidgets.QApplication([]) + v = NDViewer(da, colormaps=["thermal"], channel_mode="composite") + v.show() + qapp.exec() diff --git a/examples/zarr_arr.py b/examples/zarr_arr.py new file mode 100644 index 00000000..ab31d9c8 --- /dev/null +++ b/examples/zarr_arr.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +import zarr +import zarr.storage +from qtpy import QtWidgets + +from ndv import NDViewer + +URL = "https://s3.embl.de/i2k-2020/ngff-example-data/v0.4/tczyx.ome.zarr" +zarr_arr = zarr.open(URL, mode="r") + +if __name__ == "__main__": + qapp = QtWidgets.QApplication([]) + v = NDViewer(zarr_arr["s0"]) + v.show() + qapp.exec() diff --git a/pyproject.toml b/pyproject.toml index 4c76cc5b..db9e30c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,33 +17,35 @@ sources = ["src"] [project] name = "ndv" dynamic = ["version"] -description = "simple ndviewer" +description = "simple nd image viewer" readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = { text = "BSD-3-Clause" } -authors = [{ name = "Talley Lambert", email = "talley.lambert@example.com" }] +authors = [{ name = "Talley Lambert", email = "talley.lambert@gmail.com" }] classifiers = [ "Development Status :: 3 - Alpha", "License :: OSI Approved :: BSD License", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Typing :: Typed", ] -dependencies = [] +dependencies = ["qtpy", "numpy", "superqt[cmap,iconify]"] # https://peps.python.org/pep-0621/#dependencies-optional-dependencies [project.optional-dependencies] -test = ["pytest", "pytest-cov"] +pyqt = ["pyqt6"] +vispy = ["vispy", "pyopengl"] +pyside = ["pyside6"] +test = ["pytest", "pytest-cov", "pytest-qt", "dask"] dev = [ "ipython", "mypy", - "pdbpp", # https://github.com/pdbpp/pdbpp + "pdbpp", # https://github.com/pdbpp/pdbpp "pre-commit", - "rich", # https://github.com/Textualize/rich + "rich", # https://github.com/Textualize/rich "ruff", ] @@ -81,11 +83,13 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "tests/*.py" = ["D", "S"] +"examples/*.py" = ["D", "B9"] + # https://docs.astral.sh/ruff/formatter/ [tool.ruff.format] docstring-code-format = true -skip-magic-trailing-comma = false # default is false +skip-magic-trailing-comma = false # default is false # https://mypy.readthedocs.io/en/stable/config_file.html [tool.mypy] @@ -119,8 +123,7 @@ exclude_lines = [ source = ["ndv"] [tool.check-manifest] -ignore = [ - ".pre-commit-config.yaml", - ".ruff_cache/**/*", - "tests/**/*", -] +ignore = [".pre-commit-config.yaml", ".ruff_cache/**/*", "tests/**/*"] + +[tool.typos.default] +extend-ignore-identifiers-re = ["(?i)nd2?.*", "(?i)ome", ".*ser_schema"] \ No newline at end of file diff --git a/src/ndv/__init__.py b/src/ndv/__init__.py index 76e7bac7..305e64e7 100644 --- a/src/ndv/__init__.py +++ b/src/ndv/__init__.py @@ -8,3 +8,8 @@ __version__ = "uninstalled" __author__ = "Talley Lambert" __email__ = "talley.lambert@example.com" + +from .viewer._indexing import DataWrapper +from .viewer._stack_viewer import NDViewer + +__all__ = ["NDViewer", "DataWrapper"] diff --git a/src/ndv/viewer/__init__.py b/src/ndv/viewer/__init__.py new file mode 100644 index 00000000..09c94709 --- /dev/null +++ b/src/ndv/viewer/__init__.py @@ -0,0 +1 @@ +"""viewer source.""" diff --git a/src/ndv/viewer/_backends/__init__.py b/src/ndv/viewer/_backends/__init__.py new file mode 100644 index 00000000..5aa7ce80 --- /dev/null +++ b/src/ndv/viewer/_backends/__init__.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import importlib +import importlib.util +import os +import sys +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ndv.viewer._protocols import PCanvas + + +def get_canvas(backend: str | None = None) -> type[PCanvas]: + backend = backend or os.getenv("CANVAS_BACKEND", None) + if backend == "vispy" or (backend is None and "vispy" in sys.modules): + from ._vispy import VispyViewerCanvas + + return VispyViewerCanvas + + if backend == "pygfx" or (backend is None and "pygfx" in sys.modules): + from ._pygfx import PyGFXViewerCanvas + + return PyGFXViewerCanvas + + if backend is None: + if importlib.util.find_spec("vispy") is not None: + from ._vispy import VispyViewerCanvas + + return VispyViewerCanvas + + if importlib.util.find_spec("pygfx") is not None: + from ._pygfx import PyGFXViewerCanvas + + return PyGFXViewerCanvas + + raise RuntimeError("No canvas backend found") diff --git a/src/ndv/viewer/_backends/_pygfx.py b/src/ndv/viewer/_backends/_pygfx.py new file mode 100644 index 00000000..813ff8c9 --- /dev/null +++ b/src/ndv/viewer/_backends/_pygfx.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, cast + +import numpy as np +import pygfx +from qtpy.QtCore import QSize +from wgpu.gui.qt import QWgpuCanvas + +if TYPE_CHECKING: + import cmap + from pygfx.materials import ImageBasicMaterial + from pygfx.resources import Texture + from qtpy.QtWidgets import QWidget + + +class PyGFXImageHandle: + def __init__(self, image: pygfx.Image, render: Callable) -> None: + self._image = image + self._render = render + self._grid = cast("Texture", image.geometry.grid) + self._material = cast("ImageBasicMaterial", image.material) + + @property + def data(self) -> np.ndarray: + return self._grid.data # type: ignore [no-any-return] + + @data.setter + def data(self, data: np.ndarray) -> None: + self._grid.data[:] = data + self._grid.update_range((0, 0, 0), self._grid.size) + + @property + def visible(self) -> bool: + return bool(self._image.visible) + + @visible.setter + def visible(self, visible: bool) -> None: + self._image.visible = visible + self._render() + + @property + def clim(self) -> Any: + return self._material.clim + + @clim.setter + def clim(self, clims: tuple[float, float]) -> None: + self._material.clim = clims + self._render() + + @property + def cmap(self) -> cmap.Colormap: + return self._cmap + + @cmap.setter + def cmap(self, cmap: cmap.Colormap) -> None: + self._cmap = cmap + self._material.map = cmap.to_pygfx() + self._render() + + def remove(self) -> None: + if (par := self._image.parent) is not None: + par.remove(self._image) + + +class _QWgpuCanvas(QWgpuCanvas): + def sizeHint(self) -> QSize: + return QSize(512, 512) + + +class PyGFXViewerCanvas: + """pygfx-based canvas wrapper.""" + + def __init__(self, set_info: Callable[[str], None]) -> None: + self._set_info = set_info + + self._canvas = _QWgpuCanvas(size=(512, 512)) + self._renderer = pygfx.renderers.WgpuRenderer(self._canvas) + # requires https://github.com/pygfx/pygfx/pull/752 + self._renderer.blend_mode = "additive" + self._scene = pygfx.Scene() + self._camera = cam = pygfx.OrthographicCamera(512, 512) + cam.local.scale_y = -1 + + cam.local.position = (256, 256, 0) + self._controller = pygfx.PanZoomController(cam, register_events=self._renderer) + # increase zoom wheel gain + self._controller.controls.update({"wheel": ("zoom_to_point", "push", -0.005)}) + + def qwidget(self) -> QWidget: + return cast("QWidget", self._canvas) + + def refresh(self) -> None: + self._canvas.update() + self._canvas.request_draw(self._animate) + + def _animate(self) -> None: + self._renderer.render(self._scene, self._camera) + + def set_ndim(self, ndim: int) -> None: + """Set the number of dimensions of the displayed data.""" + if ndim != 2: + raise NotImplementedError("Volume rendering is not supported by pygfx yet.") + + def add_volume( + self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None + ) -> PyGFXImageHandle: + raise NotImplementedError("Volume rendering is not supported by pygfx yet.") + + def add_image( + self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None + ) -> PyGFXImageHandle: + """Add a new Image node to the scene.""" + image = pygfx.Image( + pygfx.Geometry(grid=pygfx.Texture(data, dim=2)), + # depth_test=False for additive-like blending + pygfx.ImageBasicMaterial(depth_test=False), + ) + self._scene.add(image) + # FIXME: I suspect there are more performant ways to refresh the canvas + # look into it. + handle = PyGFXImageHandle(image, self.refresh) + if cmap is not None: + handle.cmap = cmap + return handle + + def set_range( + self, + x: tuple[float, float] | None = None, + y: tuple[float, float] | None = None, + z: tuple[float, float] | None = None, + margin: float = 0.05, + ) -> None: + """Update the range of the PanZoomCamera. + + When called with no arguments, the range is set to the full extent of the data. + """ + if not self._scene.children: + return + + cam = self._camera + cam.show_object(self._scene) + + width, height, depth = np.ptp(self._scene.get_world_bounding_box(), axis=0) + if width < 0.01: + width = 1 + if height < 0.01: + height = 1 + cam.width = width + cam.height = height + cam.zoom = 1 - margin + self.refresh() + + # def _on_mouse_move(self, event: SceneMouseEvent) -> None: + # """Mouse moved on the canvas, display the pixel value and position.""" + # images = [] + # # Get the images the mouse is over + # seen = set() + # while visual := self._canvas.visual_at(event.pos): + # if isinstance(visual, scene.visuals.Image): + # images.append(visual) + # visual.interactive = False + # seen.add(visual) + # for visual in seen: + # visual.interactive = True + # if not images: + # return + + # tform = images[0].get_transform("canvas", "visual") + # px, py, *_ = (int(x) for x in tform.map(event.pos)) + # text = f"[{py}, {px}]" + # for c, img in enumerate(images): + # with suppress(IndexError): + # text += f" c{c}: {img._data[py, px]}" + # self._set_info(text) diff --git a/src/ndv/viewer/_backends/_vispy.py b/src/ndv/viewer/_backends/_vispy.py new file mode 100644 index 00000000..c7404785 --- /dev/null +++ b/src/ndv/viewer/_backends/_vispy.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +from contextlib import suppress +from typing import TYPE_CHECKING, Any, Callable, Literal, cast + +import numpy as np +import vispy +import vispy.scene +import vispy.visuals +from superqt.utils import qthrottled +from vispy import scene +from vispy.util.quaternion import Quaternion + +if TYPE_CHECKING: + import cmap + from qtpy.QtWidgets import QWidget + from vispy.scene.events import SceneMouseEvent + +turn = np.sin(np.pi / 4) +DEFAULT_QUATERNION = Quaternion(turn, turn, 0, 0) + + +class VispyImageHandle: + def __init__(self, visual: scene.visuals.Image | scene.visuals.Volume) -> None: + self._visual = visual + + @property + def data(self) -> np.ndarray: + try: + return self._visual._data # type: ignore [no-any-return] + except AttributeError: + return self._visual._last_data # type: ignore [no-any-return] + + @data.setter + def data(self, data: np.ndarray) -> None: + self._visual.set_data(data) + + @property + def visible(self) -> bool: + return bool(self._visual.visible) + + @visible.setter + def visible(self, visible: bool) -> None: + self._visual.visible = visible + + @property + def clim(self) -> Any: + return self._visual.clim + + @clim.setter + def clim(self, clims: tuple[float, float]) -> None: + with suppress(ZeroDivisionError): + self._visual.clim = clims + + @property + def cmap(self) -> cmap.Colormap: + return self._cmap + + @cmap.setter + def cmap(self, cmap: cmap.Colormap) -> None: + self._cmap = cmap + self._visual.cmap = cmap.to_vispy() + + @property + def transform(self) -> np.ndarray: + raise NotImplementedError + + @transform.setter + def transform(self, transform: np.ndarray) -> None: + raise NotImplementedError + + def remove(self) -> None: + self._visual.parent = None + + +class VispyViewerCanvas: + """Vispy-based viewer for data. + + All vispy-specific code is encapsulated in this class (and non-vispy canvases + could be swapped in if needed as long as they implement the same interface). + """ + + def __init__(self, set_info: Callable[[str], None]) -> None: + self._set_info = set_info + self._canvas = scene.SceneCanvas() + self._canvas.events.mouse_move.connect(qthrottled(self._on_mouse_move, 60)) + self._current_shape: tuple[int, ...] = () + self._last_state: dict[Literal[2, 3], Any] = {} + + central_wdg: scene.Widget = self._canvas.central_widget + self._view: scene.ViewBox = central_wdg.add_view() + self._ndim: Literal[2, 3] | None = None + + @property + def _camera(self) -> vispy.scene.cameras.BaseCamera: + return self._view.camera + + def set_ndim(self, ndim: Literal[2, 3]) -> None: + """Set the number of dimensions of the displayed data.""" + if ndim == self._ndim: + return + elif self._ndim is not None: + # remember the current state before switching to the new camera + self._last_state[self._ndim] = self._camera.get_state() + + self._ndim = ndim + if ndim == 3: + cam = scene.ArcballCamera(fov=0) + # this sets the initial view similar to what the panzoom view would have. + cam._quaternion = DEFAULT_QUATERNION + else: + cam = scene.PanZoomCamera(aspect=1, flip=(0, 1)) + + # restore the previous state if it exists + if state := self._last_state.get(ndim): + cam.set_state(state) + self._view.camera = cam + + def qwidget(self) -> QWidget: + return cast("QWidget", self._canvas.native) + + def refresh(self) -> None: + self._canvas.update() + + def add_image( + self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None + ) -> VispyImageHandle: + """Add a new Image node to the scene.""" + img = scene.visuals.Image(data, parent=self._view.scene) + img.set_gl_state("additive", depth_test=False) + img.interactive = True + if data is not None: + self._current_shape, prev_shape = data.shape, self._current_shape + if not prev_shape: + self.set_range() + handle = VispyImageHandle(img) + if cmap is not None: + handle.cmap = cmap + return handle + + def add_volume( + self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None + ) -> VispyImageHandle: + vol = scene.visuals.Volume( + data, parent=self._view.scene, interpolation="nearest" + ) + vol.set_gl_state("additive", depth_test=False) + vol.interactive = True + if data is not None: + self._current_shape, prev_shape = data.shape, self._current_shape + if len(prev_shape) != 3: + self.set_range() + handle = VispyImageHandle(vol) + if cmap is not None: + handle.cmap = cmap + return handle + + def set_range( + self, + x: tuple[float, float] | None = None, + y: tuple[float, float] | None = None, + z: tuple[float, float] | None = None, + margin: float = 0.01, + ) -> None: + """Update the range of the PanZoomCamera. + + When called with no arguments, the range is set to the full extent of the data. + """ + if len(self._current_shape) >= 2: + if x is None: + x = (0, self._current_shape[-1]) + if y is None: + y = (0, self._current_shape[-2]) + if z is None and len(self._current_shape) == 3: + z = (0, self._current_shape[-3]) + is_3d = isinstance(self._camera, scene.ArcballCamera) + if is_3d: + self._camera._quaternion = DEFAULT_QUATERNION + self._view.camera.set_range(x=x, y=y, z=z, margin=margin) + if is_3d: + max_size = max(self._current_shape) + self._camera.scale_factor = max_size + 6 + + def _on_mouse_move(self, event: SceneMouseEvent) -> None: + """Mouse moved on the canvas, display the pixel value and position.""" + images = [] + # Get the images the mouse is over + # FIXME: this is narsty ... there must be a better way to do this + seen = set() + try: + while visual := self._canvas.visual_at(event.pos): + if isinstance(visual, scene.visuals.Image): + images.append(visual) + visual.interactive = False + seen.add(visual) + except Exception: + return + for visual in seen: + visual.interactive = True + if not images: + return + + tform = images[0].get_transform("canvas", "visual") + px, py, *_ = (int(x) for x in tform.map(event.pos)) + text = f"[{py}, {px}]" + for c, img in enumerate(reversed(images)): + with suppress(IndexError): + value = img._data[py, px] + if isinstance(value, (np.floating, float)): + value = f"{value:.2f}" + text += f" {c}: {value}" + self._set_info(text) diff --git a/src/ndv/viewer/_dims_slider.py b/src/ndv/viewer/_dims_slider.py new file mode 100644 index 00000000..21967943 --- /dev/null +++ b/src/ndv/viewer/_dims_slider.py @@ -0,0 +1,528 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, cast +from warnings import warn + +from qtpy.QtCore import QPoint, QPointF, QSize, Qt, Signal +from qtpy.QtGui import QCursor, QResizeEvent +from qtpy.QtWidgets import ( + QDialog, + QDoubleSpinBox, + QFormLayout, + QFrame, + QHBoxLayout, + QLabel, + QPushButton, + QSizePolicy, + QSlider, + QSpinBox, + QVBoxLayout, + QWidget, +) +from superqt import QElidingLabel, QLabeledRangeSlider +from superqt.iconify import QIconifyIcon +from superqt.utils import signals_blocked + +if TYPE_CHECKING: + from typing import Hashable, Mapping, TypeAlias + + from PyQt6.QtGui import QResizeEvent + + # any hashable represent a single dimension in a AND array + DimKey: TypeAlias = Hashable + # any object that can be used to index a single dimension in an AND array + Index: TypeAlias = int | slice + # a mapping from dimension keys to indices (eg. {"x": 0, "y": slice(5, 10)}) + # this object is used frequently to query or set the currently displayed slice + Indices: TypeAlias = Mapping[DimKey, Index] + # mapping of dimension keys to the maximum value for that dimension + Sizes: TypeAlias = Mapping[DimKey, int] + + +SS = """ +QSlider::groove:horizontal { + height: 15px; + background: qlineargradient( + x1:0, y1:0, x2:0, y2:1, + stop:0 rgba(128, 128, 128, 0.25), + stop:1 rgba(128, 128, 128, 0.1) + ); + border-radius: 3px; +} + +QSlider::handle:horizontal { + width: 38px; + background: #999999; + border-radius: 3px; +} + +QLabel { font-size: 12px; } + +QRangeSlider { qproperty-barColor: qlineargradient( + x1:0, y1:0, x2:0, y2:1, + stop:0 rgba(100, 80, 120, 0.2), + stop:1 rgba(100, 80, 120, 0.4) + )} + +SliderLabel { + font-size: 12px; + color: white; +} +""" + + +class QtPopup(QDialog): + """A generic popup window.""" + + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + self.setModal(False) # if False, then clicking anywhere else closes it + self.setWindowFlags(Qt.WindowType.Popup | Qt.WindowType.FramelessWindowHint) + + self.frame = QFrame(self) + layout = QVBoxLayout(self) + layout.addWidget(self.frame) + layout.setContentsMargins(0, 0, 0, 0) + + def show_above_mouse(self, *args: Any) -> None: + """Show popup dialog above the mouse cursor position.""" + pos = QCursor().pos() # mouse position + szhint = self.sizeHint() + pos -= QPoint(szhint.width() // 2, szhint.height() + 14) + self.move(pos) + self.resize(self.sizeHint()) + self.show() + + +class PlayButton(QPushButton): + """Just a styled QPushButton that toggles between play and pause icons.""" + + fpsChanged = Signal(float) + + PLAY_ICON = "bi:play-fill" + PAUSE_ICON = "bi:pause-fill" + + def __init__(self, fps: float = 20, parent: QWidget | None = None) -> None: + icn = QIconifyIcon(self.PLAY_ICON, color="#888888") + icn.addKey(self.PAUSE_ICON, state=QIconifyIcon.State.On, color="#4580DD") + super().__init__(icn, "", parent) + self.spin = QDoubleSpinBox(self) + self.spin.setRange(0.5, 100) + self.spin.setValue(fps) + self.spin.valueChanged.connect(self.fpsChanged) + self.setCheckable(True) + self.setFixedSize(14, 18) + self.setIconSize(QSize(16, 16)) + self.setStyleSheet("border: none; padding: 0; margin: 0;") + + self._popup = QtPopup(self) + form = QFormLayout(self._popup.frame) + form.setContentsMargins(6, 6, 6, 6) + form.addRow("FPS", self.spin) + + def mousePressEvent(self, e: Any) -> None: + if e and e.button() == Qt.MouseButton.RightButton: + self._show_fps_dialog(e.globalPosition()) + else: + super().mousePressEvent(e) + + def _show_fps_dialog(self, pos: QPointF) -> None: + self._popup.show_above_mouse() + + +class LockButton(QPushButton): + LOCK_ICON = "uis:unlock" + UNLOCK_ICON = "uis:lock" + + def __init__(self, text: str = "", parent: QWidget | None = None) -> None: + icn = QIconifyIcon(self.LOCK_ICON, color="#888888") + icn.addKey(self.UNLOCK_ICON, state=QIconifyIcon.State.On, color="red") + super().__init__(icn, text, parent) + self.setCheckable(True) + self.setFixedSize(20, 20) + self.setIconSize(QSize(14, 14)) + self.setStyleSheet("border: none; padding: 0; margin: 0;") + + +class DimsSlider(QWidget): + """A single slider in the DimsSliders widget. + + Provides a play/pause button that toggles animation of the slider value. + Has a QLabeledSlider for the actual value. + Adds a label for the maximum value (e.g. "3 / 10") + """ + + valueChanged = Signal(object, object) # where object is int | slice + + def __init__(self, dimension_key: DimKey, parent: QWidget | None = None) -> None: + super().__init__(parent) + self.setStyleSheet(SS) + self._slice_mode = False + self._dim_key = dimension_key + + self._timer_id: int | None = None # timer for play button + self._play_btn = PlayButton(parent=self) + self._play_btn.fpsChanged.connect(self.set_fps) + self._play_btn.toggled.connect(self._toggle_animation) + + self._dim_key = dimension_key + self._dim_label = QElidingLabel(str(dimension_key).upper()) + self._dim_label.setToolTip("Double-click to toggle slice mode") + + # note, this lock button only prevents the slider from updating programmatically + # using self.setValue, it doesn't prevent the user from changing the value. + self._lock_btn = LockButton(parent=self) + + self._pos_label = QSpinBox(self) + self._pos_label.valueChanged.connect(self._on_pos_label_edited) + self._pos_label.setButtonSymbols(QSpinBox.ButtonSymbols.NoButtons) + self._pos_label.setAlignment(Qt.AlignmentFlag.AlignRight) + self._pos_label.setStyleSheet( + "border: none; padding: 0; margin: 0; background: transparent" + ) + self._out_of_label = QLabel(self) + + self._int_slider = QSlider(Qt.Orientation.Horizontal) + self._int_slider.rangeChanged.connect(self._on_range_changed) + self._int_slider.valueChanged.connect(self._on_int_value_changed) + + self._slice_slider = slc = QLabeledRangeSlider(Qt.Orientation.Horizontal) + slc.setHandleLabelPosition(QLabeledRangeSlider.LabelPosition.LabelsOnHandle) + slc.setEdgeLabelMode(QLabeledRangeSlider.EdgeLabelMode.NoLabel) + slc.setVisible(False) + slc.rangeChanged.connect(self._on_range_changed) + slc.valueChanged.connect(self._on_slice_value_changed) + + self.installEventFilter(self) + layout = QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(2) + layout.addWidget(self._play_btn) + layout.addWidget(self._dim_label) + layout.addWidget(self._int_slider) + layout.addWidget(self._slice_slider) + layout.addWidget(self._pos_label) + layout.addWidget(self._out_of_label) + layout.addWidget(self._lock_btn) + self.setMinimumHeight(22) + + def resizeEvent(self, a0: QResizeEvent | None) -> None: + if isinstance(par := self.parent(), DimsSliders): + par.resizeEvent(None) + + def mouseDoubleClickEvent(self, a0: Any) -> None: + self._set_slice_mode(not self._slice_mode) + super().mouseDoubleClickEvent(a0) + + def containMaximum(self, max_val: int) -> None: + if max_val > self._int_slider.maximum(): + self._int_slider.setMaximum(max_val) + if max_val > self._slice_slider.maximum(): + self._slice_slider.setMaximum(max_val) + + def setMaximum(self, max_val: int) -> None: + self._int_slider.setMaximum(max_val) + self._slice_slider.setMaximum(max_val) + + def setMinimum(self, min_val: int) -> None: + self._int_slider.setMinimum(min_val) + self._slice_slider.setMinimum(min_val) + + def containMinimum(self, min_val: int) -> None: + if min_val < self._int_slider.minimum(): + self._int_slider.setMinimum(min_val) + if min_val < self._slice_slider.minimum(): + self._slice_slider.setMinimum(min_val) + + def setRange(self, min_val: int, max_val: int) -> None: + self._int_slider.setRange(min_val, max_val) + self._slice_slider.setRange(min_val, max_val) + + def value(self) -> Index: + if not self._slice_mode: + return self._int_slider.value() # type: ignore + start, *_, stop = cast("tuple[int, ...]", self._slice_slider.value()) + if start == stop: + return start + return slice(start, stop) + + def setValue(self, val: Index) -> None: + # variant of setValue that always updates the maximum + self._set_slice_mode(isinstance(val, slice)) + if self._lock_btn.isChecked(): + return + if isinstance(val, slice): + start = int(val.start) if val.start is not None else 0 + stop = ( + int(val.stop) if val.stop is not None else self._slice_slider.maximum() + ) + self._slice_slider.setValue((start, stop)) + else: + self._int_slider.setValue(val) + # self._slice_slider.setValue((val, val + 1)) + + def forceValue(self, val: Index) -> None: + """Set value and increase range if necessary.""" + if isinstance(val, slice): + if isinstance(val.start, int): + self.containMinimum(val.start) + if isinstance(val.stop, int): + self.containMaximum(val.stop) + else: + self.containMinimum(val) + self.containMaximum(val) + self.setValue(val) + + def _set_slice_mode(self, mode: bool = True) -> None: + if mode == self._slice_mode: + return + self._slice_mode = bool(mode) + self._slice_slider.setVisible(self._slice_mode) + self._int_slider.setVisible(not self._slice_mode) + # self._pos_label.setVisible(not self._slice_mode) + self.valueChanged.emit(self._dim_key, self.value()) + + def set_fps(self, fps: float) -> None: + self._play_btn.spin.setValue(fps) + self._toggle_animation(self._play_btn.isChecked()) + + def _toggle_animation(self, checked: bool) -> None: + if checked: + if self._timer_id is not None: + self.killTimer(self._timer_id) + interval = int(1000 / self._play_btn.spin.value()) + self._timer_id = self.startTimer(interval) + elif self._timer_id is not None: + self.killTimer(self._timer_id) + self._timer_id = None + + def timerEvent(self, event: Any) -> None: + """Handle timer event for play button, move to the next frame.""" + # TODO + # for now just increment the value by 1, but we should be able to + # take FPS into account better and skip additional frames if the timerEvent + # is delayed for some reason. + inc = 1 + if self._slice_mode: + val = cast(tuple[int, int], self._slice_slider.value()) + next_val = [v + inc for v in val] + if next_val[1] > self._slice_slider.maximum(): + # wrap around, without going below the min handle + next_val = [v - val[0] for v in val] + self._slice_slider.setValue(next_val) + else: + ival = self._int_slider.value() + ival = (ival + inc) % (self._int_slider.maximum() + 1) + self._int_slider.setValue(ival) + + def _on_pos_label_edited(self) -> None: + if self._slice_mode: + self._slice_slider.setValue( + (self._slice_slider.value()[0], self._pos_label.value()) + ) + else: + self._int_slider.setValue(self._pos_label.value()) + + def _on_range_changed(self, min: int, max: int) -> None: + self._out_of_label.setText(f"| {max}") + self._pos_label.setRange(min, max) + self.resizeEvent(None) + self.setVisible(min != max) + + def setVisible(self, visible: bool) -> None: + if self._has_no_range(): + visible = False + super().setVisible(visible) + + def _has_no_range(self) -> bool: + if self._slice_mode: + return bool(self._slice_slider.minimum() == self._slice_slider.maximum()) + return bool(self._int_slider.minimum() == self._int_slider.maximum()) + + def _on_int_value_changed(self, value: int) -> None: + self._pos_label.setValue(value) + if not self._slice_mode: + self.valueChanged.emit(self._dim_key, value) + + def _on_slice_value_changed(self, value: tuple[int, int]) -> None: + self._pos_label.setValue(int(value[1])) + with signals_blocked(self._int_slider): + self._int_slider.setValue(int(value[0])) + if self._slice_mode: + self.valueChanged.emit(self._dim_key, slice(*value)) + + +class DimsSliders(QWidget): + """A Collection of DimsSlider widgets for each dimension in the data. + + Maintains the global current index and emits a signal when it changes. + """ + + valueChanged = Signal(dict) # dict is of type Indices + + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + self._locks_visible: bool | Mapping[DimKey, bool] = False + self._sliders: dict[DimKey, DimsSlider] = {} + self._current_index: dict[DimKey, Index] = {} + self._invisible_dims: set[DimKey] = set() + + self.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Minimum) + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + def __contains__(self, key: DimKey) -> bool: + """Return True if the dimension key is present in the DimsSliders.""" + return key in self._sliders + + def slider(self, key: DimKey) -> DimsSlider: + """Return the DimsSlider widget for the given dimension key.""" + return self._sliders[key] + + def value(self) -> Indices: + """Return mapping of {dim_key -> current index} for each dimension.""" + return self._current_index.copy() + + def setValue(self, values: Indices) -> None: + """Set the current index for each dimension. + + Parameters + ---------- + values : Mapping[Hashable, int | slice] + Mapping of {dim_key -> index} for each dimension. If value is a slice, + the slider will be in slice mode. If the dimension is not present in the + DimsSliders, it will be added. + """ + if self._current_index == values: + return + with signals_blocked(self): + for dim, index in values.items(): + self.add_or_update_dimension(dim, index) + # FIXME: i don't know why this this is ever empty ... only happens on pyside6 + if val := self.value(): + self.valueChanged.emit(val) + + def minima(self) -> Sizes: + """Return mapping of {dim_key -> minimum value} for each dimension.""" + return {k: v._int_slider.minimum() for k, v in self._sliders.items()} + + def setMinima(self, values: Sizes) -> None: + """Set the minimum value for each dimension. + + Parameters + ---------- + values : Mapping[Hashable, int] + Mapping of {dim_key -> minimum value} for each dimension. + """ + for name, min_val in values.items(): + if name not in self._sliders: + self.add_dimension(name) + self._sliders[name].setMinimum(min_val) + + def maxima(self) -> Sizes: + """Return mapping of {dim_key -> maximum value} for each dimension.""" + return {k: v._int_slider.maximum() for k, v in self._sliders.items()} + + def setMaxima(self, values: Sizes) -> None: + """Set the maximum value for each dimension. + + Parameters + ---------- + values : Mapping[Hashable, int] + Mapping of {dim_key -> maximum value} for each dimension. + """ + for name, max_val in values.items(): + if name not in self._sliders: + self.add_dimension(name) + self._sliders[name].setMaximum(max_val) + + def set_locks_visible(self, visible: bool | Mapping[DimKey, bool]) -> None: + """Set the visibility of the lock buttons for all dimensions.""" + self._locks_visible = visible + for dim, slider in self._sliders.items(): + viz = visible if isinstance(visible, bool) else visible.get(dim, False) + slider._lock_btn.setVisible(viz) + + def add_dimension(self, key: DimKey, val: Index | None = None) -> None: + """Add a new dimension to the DimsSliders widget. + + Parameters + ---------- + key : Hashable + The name of the dimension. + val : int | slice, optional + The initial value for the dimension. If a slice, the slider will be in + slice mode. + """ + self._sliders[key] = slider = DimsSlider(dimension_key=key, parent=self) + if isinstance(self._locks_visible, dict) and key in self._locks_visible: + slider._lock_btn.setVisible(self._locks_visible[key]) + else: + slider._lock_btn.setVisible(bool(self._locks_visible)) + + val_int = val.start if isinstance(val, slice) else val + slider.setVisible(key not in self._invisible_dims) + if isinstance(val_int, int): + slider.setRange(val_int, val_int) + elif isinstance(val_int, slice): + slider.setRange(val_int.start or 0, val_int.stop or 1) + + val = val if val is not None else 0 + self._current_index[key] = val + slider.forceValue(val) + slider.valueChanged.connect(self._on_dim_slider_value_changed) + cast("QVBoxLayout", self.layout()).addWidget(slider) + + def set_dimension_visible(self, key: DimKey, visible: bool) -> None: + """Set the visibility of a dimension in the DimsSliders widget. + + Once a dimension is hidden, it will not be shown again until it is explicitly + made visible again with this method. + """ + if visible: + self._invisible_dims.discard(key) + if key in self._sliders: + self._current_index[key] = self._sliders[key].value() + else: + self.add_dimension(key) + else: + self._invisible_dims.add(key) + self._current_index.pop(key, None) + if key in self._sliders: + self._sliders[key].setVisible(visible) + + def remove_dimension(self, key: DimKey) -> None: + """Remove a dimension from the DimsSliders widget.""" + try: + slider = self._sliders.pop(key) + except KeyError: + warn(f"Dimension {key} not found in DimsSliders", stacklevel=2) + return + cast("QVBoxLayout", self.layout()).removeWidget(slider) + slider.deleteLater() + + def _on_dim_slider_value_changed(self, key: DimKey, value: Index) -> None: + self._current_index[key] = value + self.valueChanged.emit(self.value()) + + def add_or_update_dimension(self, key: DimKey, value: Index) -> None: + """Add a dimension if it doesn't exist, otherwise update the value.""" + if key in self._sliders: + self._sliders[key].forceValue(value) + else: + self.add_dimension(key, value) + + def resizeEvent(self, a0: QResizeEvent | None) -> None: + # align all labels + if sliders := list(self._sliders.values()): + for lbl in ("_dim_label", "_pos_label", "_out_of_label"): + lbl_width = max(getattr(s, lbl).sizeHint().width() for s in sliders) + for s in sliders: + getattr(s, lbl).setFixedWidth(lbl_width) + + super().resizeEvent(a0) + + def sizeHint(self) -> QSize: + return super().sizeHint().boundedTo(QSize(9999, 0)) diff --git a/src/ndv/viewer/_indexing.py b/src/ndv/viewer/_indexing.py new file mode 100644 index 00000000..354f6be2 --- /dev/null +++ b/src/ndv/viewer/_indexing.py @@ -0,0 +1,305 @@ +"""In this module, we provide built-in support for many array types.""" + +from __future__ import annotations + +import sys +import warnings +from abc import abstractmethod +from concurrent.futures import Future, ThreadPoolExecutor +from contextlib import suppress +from typing import ( + TYPE_CHECKING, + Generic, + Hashable, + Iterable, + Mapping, + Sequence, + TypeVar, + cast, +) + +import numpy as np + +if TYPE_CHECKING: + from pathlib import Path + from typing import Any, Protocol, TypeGuard + + import dask.array as da + import numpy.typing as npt + import tensorstore as ts + import xarray as xr + from pymmcore_plus.mda.handlers import TensorStoreHandler + from pymmcore_plus.mda.handlers._5d_writer_base import _5DWriterBase + + from ._dims_slider import Index, Indices + + class SupportsIndexing(Protocol): + def __getitem__(self, key: Index | tuple[Index, ...]) -> npt.ArrayLike: ... + @property + def shape(self) -> tuple[int, ...]: ... + + +ArrayT = TypeVar("ArrayT") +MAX_CHANNELS = 16 +# Create a global executor +_EXECUTOR = ThreadPoolExecutor(max_workers=1) + + +class DataWrapper(Generic[ArrayT]): + """Interface for wrapping different array-like data types. + + If DataWrapper.create(your_obj) raises an exception, you can implement a new + DataWrapper subclass to handle your data type. + + It can be passed to NDViewer. + """ + + def __init__(self, data: ArrayT) -> None: + self._data = data + + @classmethod + def create(cls, data: ArrayT) -> DataWrapper[ArrayT]: + if isinstance(data, DataWrapper): + return data + if MMTensorStoreWrapper.supports(data): + return MMTensorStoreWrapper(data) + if MM5DWriter.supports(data): + return MM5DWriter(data) + if XarrayWrapper.supports(data): + return XarrayWrapper(data) + if DaskWrapper.supports(data): + return DaskWrapper(data) + if TensorstoreWrapper.supports(data): + return TensorstoreWrapper(data) + if ArrayLikeWrapper.supports(data): + return ArrayLikeWrapper(data) + raise NotImplementedError(f"Don't know how to wrap type {type(data)}") + + @abstractmethod + def isel(self, indexers: Indices) -> np.ndarray: + """Select a slice from a data store using (possibly) named indices. + + For xarray.DataArray, use the built-in isel method. + For any other duck-typed array, use numpy-style indexing, where indexers + is a mapping of axis to slice objects or indices. + """ + raise NotImplementedError + + def isel_async( + self, indexers: list[Indices] + ) -> Future[Iterable[tuple[Indices, np.ndarray]]]: + """Asynchronous version of isel.""" + return _EXECUTOR.submit(lambda: [(idx, self.isel(idx)) for idx in indexers]) + + @classmethod + @abstractmethod + def supports(cls, obj: Any) -> bool: + """Return True if this wrapper can handle the given object.""" + raise NotImplementedError + + def guess_channel_axis(self) -> Hashable | None: + """Return the (best guess) axis name for the channel dimension.""" + if isinstance(shp := getattr(self._data, "shape", None), Sequence): + # for numpy arrays, use the smallest dimension as the channel axis + if min(shp) <= MAX_CHANNELS: + return shp.index(min(shp)) + return None + + def save_as_zarr(self, save_loc: str | Path) -> None: + raise NotImplementedError("save_as_zarr not implemented for this data type.") + + def sizes(self) -> Mapping[Hashable, int]: + if (shape := getattr(self._data, "shape", None)) and isinstance(shape, tuple): + _sizes: dict[Hashable, int] = {} + for i, val in enumerate(shape): + if isinstance(val, int): + _sizes[i] = val + elif isinstance(val, Sequence) and len(val) == 2: + _sizes[val[0]] = int(val[1]) + else: + raise ValueError( + f"Invalid size: {val}. Must be an int or a 2-tuple." + ) + return _sizes + raise NotImplementedError(f"Cannot determine sizes for {type(self._data)}") + + def summary_info(self) -> str: + """Return info label with information about the data.""" + package = getattr(self._data, "__module__", "").split(".")[0] + info = f"{package}.{getattr(type(self._data), '__qualname__', '')}" + + if sizes := self.sizes(): + # if all of the dimension keys are just integers, omit them from size_str + if all(isinstance(x, int) for x in sizes): + size_str = repr(tuple(sizes.values())) + # otherwise, include the keys in the size_str + else: + size_str = ", ".join(f"{k}:{v}" for k, v in sizes.items()) + size_str = f"({size_str})" + info += f" {size_str}" + if dtype := getattr(self._data, "dtype", ""): + info += f", {dtype}" + if nbytes := getattr(self._data, "nbytes", 0) / 1e6: + info += f", {nbytes:.2f}MB" + return info + + +class MMTensorStoreWrapper(DataWrapper["TensorStoreHandler"]): + def sizes(self) -> Mapping[Hashable, int]: + with suppress(Exception): + return self._data.current_sequence.sizes # type: ignore [no-any-return] + return {} + + def guess_channel_axis(self) -> Hashable | None: + return "c" + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[TensorStoreHandler]: + with suppress(ImportError): + from pymmcore_plus.mda.handlers import TensorStoreHandler + + return isinstance(obj, TensorStoreHandler) + return False + + def isel(self, indexers: Indices) -> np.ndarray: + return self._data.isel(indexers) # type: ignore [no-any-return] + + def save_as_zarr(self, save_loc: str | Path) -> None: + if (store := self._data.store) is None: + return + import tensorstore as ts + + new_spec = store.spec().to_json() + new_spec["kvstore"] = {"driver": "file", "path": str(save_loc)} + new_ts = ts.open(new_spec, create=True).result() + new_ts[:] = store.read().result() + + +class MM5DWriter(DataWrapper["_5DWriterBase"]): + def guess_channel_axis(self) -> Hashable | None: + return "c" + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[_5DWriterBase]: + with suppress(ImportError): + try: + from pymmcore_plus.mda.handlers._5d_writer_base import _5DWriterBase + except ImportError: + from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter + + _5DWriterBase = (OMETiffWriter, OMEZarrWriter) + if isinstance(obj, _5DWriterBase): + return True + return False + + def save_as_zarr(self, save_loc: str | Path) -> None: + import zarr + from pymmcore_plus.mda.handlers import OMEZarrWriter + + if isinstance(self._data, OMEZarrWriter): + zarr.copy_store(self._data.group.store, zarr.DirectoryStore(save_loc)) + raise NotImplementedError(f"Cannot save {type(self._data)} data to Zarr.") + + def isel(self, indexers: Indices) -> np.ndarray: + p_index = indexers.get("p", 0) + if isinstance(p_index, slice): + warnings.warn("Cannot slice over position index", stacklevel=2) # TODO + p_index = p_index.start + p_index = cast(int, p_index) + + try: + sizes = [*list(self._data.position_sizes[p_index]), "y", "x"] + except IndexError as e: + raise IndexError( + f"Position index {p_index} out of range for " + f"{len(self._data.position_sizes)}" + ) from e + + data = self._data.position_arrays[self._data.get_position_key(p_index)] + full = slice(None, None) + index = tuple(indexers.get(k, full) for k in sizes) + return data[index] # type: ignore [no-any-return] + + +class XarrayWrapper(DataWrapper["xr.DataArray"]): + def isel(self, indexers: Indices) -> np.ndarray: + return np.asarray(self._data.isel(indexers)) + + def sizes(self) -> Mapping[Hashable, int]: + return {k: int(v) for k, v in self._data.sizes.items()} + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[xr.DataArray]: + if (xr := sys.modules.get("xarray")) and isinstance(obj, xr.DataArray): + return True + return False + + def guess_channel_axis(self) -> Hashable | None: + for d in self._data.dims: + if str(d).lower() in ("channel", "ch", "c"): + return cast("Hashable", d) + return None + + def save_as_zarr(self, save_loc: str | Path) -> None: + self._data.to_zarr(save_loc) + + +class DaskWrapper(DataWrapper["da.Array"]): + def isel(self, indexers: Indices) -> np.ndarray: + idx = tuple(indexers.get(k, slice(None)) for k in range(len(self._data.shape))) + return np.asarray(self._data[idx].compute()) + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[da.Array]: + if (da := sys.modules.get("dask.array")) and isinstance(obj, da.Array): + return True + return False + + def save_as_zarr(self, save_loc: str | Path) -> None: + self._data.to_zarr(url=str(save_loc)) + + +class TensorstoreWrapper(DataWrapper["ts.TensorStore"]): + def __init__(self, data: Any) -> None: + super().__init__(data) + import tensorstore as ts + + self._ts = ts + + def sizes(self) -> Mapping[Hashable, int]: + return {dim.label: dim.size for dim in self._data.domain} + + def isel(self, indexers: Indices) -> np.ndarray: + result = self._data[self._ts.d[*indexers][*indexers.values()]].read().result() + return np.asarray(result) + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[ts.TensorStore]: + if (ts := sys.modules.get("tensorstore")) and isinstance(obj, ts.TensorStore): + return True + return False + + +class ArrayLikeWrapper(DataWrapper): + def isel(self, indexers: Indices) -> np.ndarray: + idx = tuple(indexers.get(k, slice(None)) for k in range(len(self._data.shape))) + return np.asarray(self._data[idx]) + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[SupportsIndexing]: + if ( + isinstance(obj, np.ndarray) + or hasattr(obj, "__array_function__") + or hasattr(obj, "__array_namespace__") + or (hasattr(obj, "__getitem__") and hasattr(obj, "__array__")) + ): + return True + return False + + def save_as_zarr(self, save_loc: str | Path) -> None: + import zarr + + if isinstance(self._data, zarr.Array): + self._data.store = zarr.DirectoryStore(save_loc) + else: + zarr.save(str(save_loc), self._data) diff --git a/src/ndv/viewer/_lut_control.py b/src/ndv/viewer/_lut_control.py new file mode 100644 index 00000000..65ba6977 --- /dev/null +++ b/src/ndv/viewer/_lut_control.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Iterable, cast + +import numpy as np +from qtpy.QtCore import Qt +from qtpy.QtWidgets import QCheckBox, QFrame, QHBoxLayout, QPushButton, QWidget +from superqt import QLabeledRangeSlider +from superqt.cmap import QColormapComboBox +from superqt.utils import signals_blocked + +from ._dims_slider import SS + +if TYPE_CHECKING: + import cmap + + from ._protocols import PImageHandle + + +class CmapCombo(QColormapComboBox): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent, allow_user_colormaps=True, add_colormap_text="Add...") + self.setMinimumSize(120, 21) + # self.setStyleSheet("background-color: transparent;") + + def showPopup(self) -> None: + super().showPopup() + popup = self.findChild(QFrame) + popup.setMinimumWidth(self.width() + 100) + popup.move(popup.x(), popup.y() - self.height() - popup.height()) + + +class LutControl(QWidget): + def __init__( + self, + name: str = "", + handles: Iterable[PImageHandle] = (), + parent: QWidget | None = None, + cmaplist: Iterable[Any] = (), + ) -> None: + super().__init__(parent) + self._handles = handles + self._name = name + + self._visible = QCheckBox(name) + self._visible.setChecked(True) + self._visible.toggled.connect(self._on_visible_changed) + + self._cmap = CmapCombo() + self._cmap.currentColormapChanged.connect(self._on_cmap_changed) + for handle in handles: + self._cmap.addColormap(handle.cmap) + for color in cmaplist: + self._cmap.addColormap(color) + + self._clims = QLabeledRangeSlider(Qt.Orientation.Horizontal) + self._clims.setStyleSheet(SS) + self._clims.setHandleLabelPosition( + QLabeledRangeSlider.LabelPosition.LabelsOnHandle + ) + self._clims.setEdgeLabelMode(QLabeledRangeSlider.EdgeLabelMode.NoLabel) + self._clims.setRange(0, 2**8) + self._clims.valueChanged.connect(self._on_clims_changed) + + self._auto_clim = QPushButton("Auto") + self._auto_clim.setMaximumWidth(42) + self._auto_clim.setCheckable(True) + self._auto_clim.setChecked(True) + self._auto_clim.toggled.connect(self.update_autoscale) + + layout = QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._visible) + layout.addWidget(self._cmap) + layout.addWidget(self._clims) + layout.addWidget(self._auto_clim) + + self.update_autoscale() + + def autoscaleChecked(self) -> bool: + return cast("bool", self._auto_clim.isChecked()) + + def _on_clims_changed(self, clims: tuple[float, float]) -> None: + self._auto_clim.setChecked(False) + for handle in self._handles: + handle.clim = clims + + def _on_visible_changed(self, visible: bool) -> None: + for handle in self._handles: + handle.visible = visible + if visible: + self.update_autoscale() + + def _on_cmap_changed(self, cmap: cmap.Colormap) -> None: + for handle in self._handles: + handle.cmap = cmap + + def update_autoscale(self) -> None: + if ( + not self._auto_clim.isChecked() + or not self._visible.isChecked() + or not self._handles + ): + return + + # find the min and max values for the current channel + clims = [np.inf, -np.inf] + for handle in self._handles: + clims[0] = min(clims[0], np.nanmin(handle.data)) + clims[1] = max(clims[1], np.nanmax(handle.data)) + + mi, ma = tuple(int(x) for x in clims) + if mi != ma: + for handle in self._handles: + handle.clim = (mi, ma) + + # set the slider values to the new clims + with signals_blocked(self._clims): + self._clims.setMinimum(min(mi, self._clims.minimum())) + self._clims.setMaximum(max(ma, self._clims.maximum())) + self._clims.setValue((mi, ma)) diff --git a/src/ndv/viewer/_protocols.py b/src/ndv/viewer/_protocols.py new file mode 100644 index 00000000..413038de --- /dev/null +++ b/src/ndv/viewer/_protocols.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol + +if TYPE_CHECKING: + import cmap + import numpy as np + from qtpy.QtWidgets import QWidget + + +class PImageHandle(Protocol): + @property + def data(self) -> np.ndarray: ... + @data.setter + def data(self, data: np.ndarray) -> None: ... + @property + def visible(self) -> bool: ... + @visible.setter + def visible(self, visible: bool) -> None: ... + @property + def clim(self) -> Any: ... + @clim.setter + def clim(self, clims: tuple[float, float]) -> None: ... + @property + def cmap(self) -> Any: ... + @cmap.setter + def cmap(self, cmap: Any) -> None: ... + def remove(self) -> None: ... + + +class PCanvas(Protocol): + def __init__(self, set_info: Callable[[str], None]) -> None: ... + def set_ndim(self, ndim: Literal[2, 3]) -> None: ... + def set_range( + self, + x: tuple[float, float] | None = None, + y: tuple[float, float] | None = None, + z: tuple[float, float] | None = None, + margin: float = ..., + ) -> None: ... + def refresh(self) -> None: ... + def qwidget(self) -> QWidget: ... + def add_image( + self, data: np.ndarray | None = ..., cmap: cmap.Colormap | None = ... + ) -> PImageHandle: ... + def add_volume( + self, data: np.ndarray | None = ..., cmap: cmap.Colormap | None = ... + ) -> PImageHandle: ... diff --git a/src/ndv/viewer/_save_button.py b/src/ndv/viewer/_save_button.py new file mode 100644 index 00000000..85520641 --- /dev/null +++ b/src/ndv/viewer/_save_button.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +from qtpy.QtWidgets import QFileDialog, QPushButton, QWidget +from superqt.iconify import QIconifyIcon + +if TYPE_CHECKING: + from ._indexing import DataWrapper + + +class SaveButton(QPushButton): + def __init__( + self, + data_wrapper: DataWrapper, + parent: QWidget | None = None, + ): + super().__init__(parent=parent) + self.setIcon(QIconifyIcon("mdi:content-save")) + self.clicked.connect(self._on_click) + + self._data_wrapper = data_wrapper + self._last_loc = str(Path.home()) + + def _on_click(self) -> None: + self._last_loc, _ = QFileDialog.getSaveFileName( + self, "Choose destination", str(self._last_loc), "" + ) + suffix = Path(self._last_loc).suffix + if suffix in (".zarr", ".ome.zarr", ""): + self._data_wrapper.save_as_zarr(self._last_loc) + else: + raise ValueError(f"Unsupported file format: {self._last_loc}") diff --git a/src/ndv/viewer/_stack_viewer.py b/src/ndv/viewer/_stack_viewer.py new file mode 100644 index 00000000..294e9025 --- /dev/null +++ b/src/ndv/viewer/_stack_viewer.py @@ -0,0 +1,591 @@ +from __future__ import annotations + +from collections import defaultdict +from enum import Enum +from itertools import cycle +from typing import TYPE_CHECKING, Iterable, Literal, Mapping, Sequence, cast + +import cmap +import numpy as np +from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget +from superqt import QCollapsible, QElidingLabel, QIconifyIcon, ensure_main_thread +from superqt.utils import qthrottled, signals_blocked + +from ._backends import get_canvas +from ._dims_slider import DimsSliders +from ._indexing import DataWrapper +from ._lut_control import LutControl + +if TYPE_CHECKING: + from concurrent.futures import Future + from typing import Any, Callable, Hashable, TypeAlias + + from qtpy.QtGui import QCloseEvent + + from ._dims_slider import DimKey, Indices, Sizes + from ._protocols import PCanvas, PImageHandle + + ImgKey: TypeAlias = Hashable + # any mapping of dimensions to sizes + SizesLike: TypeAlias = Sizes | Iterable[int | tuple[DimKey, int] | Sequence] + +MID_GRAY = "#888888" +GRAYS = cmap.Colormap("gray") +DEFAULT_COLORMAPS = [ + cmap.Colormap("green"), + cmap.Colormap("magenta"), + cmap.Colormap("cyan"), + cmap.Colormap("yellow"), + cmap.Colormap("red"), + cmap.Colormap("blue"), + cmap.Colormap("cubehelix"), + cmap.Colormap("gray"), +] +ALL_CHANNELS = slice(None) + + +class ChannelMode(str, Enum): + COMPOSITE = "composite" + MONO = "mono" + + def __str__(self) -> str: + return self.value + + +class ChannelModeButton(QPushButton): + def __init__(self, parent: QWidget | None = None): + super().__init__(parent) + self.setCheckable(True) + self.toggled.connect(self.next_mode) + + # set minimum width to the width of the larger string 'composite' + self.setMinimumWidth(92) # FIXME: magic number + + def next_mode(self) -> None: + if self.isChecked(): + self.setMode(ChannelMode.MONO) + else: + self.setMode(ChannelMode.COMPOSITE) + + def mode(self) -> ChannelMode: + return ChannelMode.MONO if self.isChecked() else ChannelMode.COMPOSITE + + def setMode(self, mode: ChannelMode) -> None: + # we show the name of the next mode, not the current one + other = ChannelMode.COMPOSITE if mode is ChannelMode.MONO else ChannelMode.MONO + self.setText(str(other)) + self.setChecked(mode == ChannelMode.MONO) + + +class DimToggleButton(QPushButton): + def __init__(self, parent: QWidget | None = None): + icn = QIconifyIcon("f7:view-2d", color="#333333") + icn.addKey("f7:view-3d", state=QIconifyIcon.State.On, color="white") + super().__init__(icn, "", parent) + self.setCheckable(True) + self.setChecked(True) + + +# @dataclass +# class LutModel: +# name: str = "" +# autoscale: bool = True +# min: float = 0.0 +# max: float = 1.0 +# colormap: cmap.Colormap = GRAYS +# visible: bool = True + + +# @dataclass +# class ViewerModel: +# data: Any = None +# # dimensions of the data that will *not* be sliced. +# visualized_dims: Container[DimKey] = (-2, -1) +# # the axis that represents the channels in the data +# channel_axis: DimKey | None = None +# # the mode for displaying the channels +# # if MONO, only the current selection of channel_axis is displayed +# # if COMPOSITE, the full channel_axis is sliced, and luts determine display +# channel_mode: ChannelMode = ChannelMode.MONO +# # map of index in the channel_axis to LutModel +# luts: Mapping[int, LutModel] = {} + + +class NDViewer(QWidget): + """A viewer for ND arrays. + + This widget displays a single slice from an ND array (or a composite of slices in + different colormaps). The widget provides sliders to select the slice to display, + and buttons to control the display mode of the channels. + + An important concept in this widget is the "index". The index is a mapping of + dimensions to integers or slices that define the slice of the data to display. For + example, a numpy slice of `[0, 1, 5:10]` would be represented as + `{0: 0, 1: 1, 2: slice(5, 10)}`, but dimensions can also be named, e.g. + `{'t': 0, 'c': 1, 'z': slice(5, 10)}`. The index is used to select the data from + the datastore, and to determine the position of the sliders. + + The flow of data is as follows: + + - The user sets the data using the `set_data` method. This will set the number + and range of the sliders to the shape of the data, and display the first slice. + - The user can then use the sliders to select the slice to display. The current + slice is defined as a `Mapping` of `{dim -> int|slice}` and can be retrieved + with the `_dims_sliders.value()` method. To programmatically set the current + position, use the `setIndex` method. This will set the values of the sliders, + which in turn will trigger the display of the new slice via the + `_update_data_for_index` method. + - `_update_data_for_index` is an asynchronous method that retrieves the data for + the given index from the datastore (using `_isel`) and queues the + `_on_data_slice_ready` method to be called when the data is ready. The logic + for extracting data from the datastore is defined in `_indexing.py`, which handles + idiosyncrasies of different datastores (e.g. xarray, tensorstore, etc). + - `_on_data_slice_ready` is called when the data is ready, and updates the image. + Note that if the slice is multidimensional, the data will be reduced to 2D using + max intensity projection (and double-clicking on any given dimension slider will + turn it into a range slider allowing a projection to be made over that dimension). + - The image is displayed on the canvas, which is an object that implements the + `PCanvas` protocol (mostly, it has an `add_image` method that returns a handle + to the added image that can be used to update the data and display). This + small abstraction allows for various backends to be used (e.g. vispy, pygfx, etc). + + Parameters + ---------- + data : Any + The data to display. This can be an ND array, an xarray DataArray, or any + object that supports numpy-style indexing. + parent : QWidget, optional + The parent widget of this widget. + channel_axis : Hashable, optional + The axis that represents the channels in the data. If not provided, this will + be guessed from the data. + channel_mode : ChannelMode, optional + The initial mode for displaying the channels. If not provided, this will be + set to ChannelMode.MONO. + """ + + def __init__( + self, + data: Any, + *, + colormaps: Iterable[cmap._colormap.ColorStopsLike] | None = None, + parent: QWidget | None = None, + channel_axis: DimKey | None = None, + channel_mode: ChannelMode | str = ChannelMode.MONO, + ): + super().__init__(parent=parent) + + # ATTRIBUTES ---------------------------------------------------- + + # dimensions of the data in the datastore + self._sizes: Sizes = {} + # mapping of key to a list of objects that control image nodes in the canvas + self._img_handles: defaultdict[ImgKey, list[PImageHandle]] = defaultdict(list) + # mapping of same keys to the LutControl objects control image display props + self._lut_ctrls: dict[ImgKey, LutControl] = {} + # the set of dimensions we are currently visualizing (e.g. XY) + # this is used to control which dimensions have sliders and the behavior + # of isel when selecting data from the datastore + self._visualized_dims: set[DimKey] = set() + # the axis that represents the channels in the data + self._channel_axis = channel_axis + self._channel_mode: ChannelMode = None # type: ignore # set in set_channel_mode + # colormaps that will be cycled through when displaying composite images + # TODO: allow user to set this + if colormaps is not None: + self._cmaps = [cmap.Colormap(c) for c in colormaps] + else: + self._cmaps = DEFAULT_COLORMAPS + self._cmap_cycle = cycle(self._cmaps) + # the last future that was created by _update_data_for_index + self._last_future: Future | None = None + + # number of dimensions to display + self._ndims: Literal[2, 3] = 2 + + # WIDGETS ---------------------------------------------------- + + # the button that controls the display mode of the channels + self._channel_mode_btn = ChannelModeButton(self) + self._channel_mode_btn.clicked.connect(self.set_channel_mode) + # button to reset the zoom of the canvas + self._set_range_btn = QPushButton( + QIconifyIcon("fluent:full-screen-maximize-24-filled"), "", self + ) + self._set_range_btn.clicked.connect(self._on_set_range_clicked) + + # button to change number of displayed dimensions + self._ndims_btn = DimToggleButton(self) + self._ndims_btn.clicked.connect(self.toggle_3d) + + # place to display dataset summary + self._data_info_label = QElidingLabel("", parent=self) + # place to display arbitrary text + self._hover_info_label = QLabel("", self) + # the canvas that displays the images + self._canvas: PCanvas = get_canvas()(self._hover_info_label.setText) + self._canvas.set_ndim(self._ndims) + + # the sliders that control the index of the displayed image + self._dims_sliders = DimsSliders(self) + self._dims_sliders.valueChanged.connect( + qthrottled(self._update_data_for_index, 20, leading=True) + ) + + self._lut_drop = QCollapsible("LUTs", self) + self._lut_drop.setCollapsedIcon(QIconifyIcon("bi:chevron-down", color=MID_GRAY)) + self._lut_drop.setExpandedIcon(QIconifyIcon("bi:chevron-up", color=MID_GRAY)) + lut_layout = cast("QVBoxLayout", self._lut_drop.layout()) + lut_layout.setContentsMargins(0, 1, 0, 1) + lut_layout.setSpacing(0) + if ( + hasattr(self._lut_drop, "_content") + and (layout := self._lut_drop._content.layout()) is not None + ): + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + # LAYOUT ----------------------------------------------------- + + self._btns = btns = QHBoxLayout() + btns.setContentsMargins(0, 0, 0, 0) + btns.setSpacing(0) + btns.addStretch() + btns.addWidget(self._channel_mode_btn) + btns.addWidget(self._ndims_btn) + btns.addWidget(self._set_range_btn) + + layout = QVBoxLayout(self) + layout.setSpacing(2) + layout.setContentsMargins(6, 6, 6, 6) + layout.addWidget(self._data_info_label) + layout.addWidget(self._canvas.qwidget(), 1) + layout.addWidget(self._hover_info_label) + layout.addWidget(self._dims_sliders) + layout.addWidget(self._lut_drop) + layout.addLayout(btns) + + # SETUP ------------------------------------------------------ + + self.set_channel_mode(channel_mode) + if data is not None: + self.set_data(data) + + # ------------------- PUBLIC API ---------------------------- + @property + def data(self) -> Any: + """Return the data backing the view.""" + return self._data_wrapper._data + + @data.setter + def data(self, data: Any) -> None: + """Set the data backing the view.""" + raise AttributeError("Cannot set data directly. Use `set_data` method.") + + @property + def dims_sliders(self) -> DimsSliders: + """Return the DimsSliders widget.""" + return self._dims_sliders + + @property + def sizes(self) -> Sizes: + """Return sizes {dimkey: int} of the dimensions in the datastore.""" + return self._sizes + + def set_data( + self, + data: Any, + sizes: SizesLike | None = None, + channel_axis: int | None = None, + visualized_dims: Iterable[DimKey] | None = None, + ) -> None: + """Set the datastore, and, optionally, the sizes of the data.""" + # store the data + self._data_wrapper = DataWrapper.create(data) + + # determine sizes of the data + self._sizes = self._data_wrapper.sizes() if sizes is None else _to_sizes(sizes) + + # set channel axis + if channel_axis is not None: + self._channel_axis = channel_axis + elif self._channel_axis is None: + self._channel_axis = self._data_wrapper.guess_channel_axis() + + # update the dimensions we are visualizing + if visualized_dims is None: + visualized_dims = list(self._sizes)[-self._ndims :] + self.set_visualized_dims(visualized_dims) + + # update the range of all the sliders to match the sizes we set above + with signals_blocked(self._dims_sliders): + self.update_slider_ranges() + # redraw + self.setIndex({}) + # update the data info label + self._data_info_label.setText(self._data_wrapper.summary_info()) + + def set_visualized_dims(self, dims: Iterable[DimKey]) -> None: + """Set the dimensions that will be visualized. + + This dims will NOT have sliders associated with them. + """ + self._visualized_dims = set(dims) + for d in self._dims_sliders._sliders: + self._dims_sliders.set_dimension_visible(d, d not in self._visualized_dims) + for d in self._visualized_dims: + self._dims_sliders.set_dimension_visible(d, False) + + def update_slider_ranges( + self, mins: SizesLike | None = None, maxes: SizesLike | None = None + ) -> None: + """Set the maximum values of the sliders. + + If `sizes` is not provided, sizes will be inferred from the datastore. + This is mostly here as a public way to reset the + """ + if maxes is None: + maxes = self._sizes + maxes = _to_sizes(maxes) + self._dims_sliders.setMaxima({k: v - 1 for k, v in maxes.items()}) + if mins is not None: + self._dims_sliders.setMinima(_to_sizes(mins)) + + # FIXME: this needs to be moved and made user-controlled + for dim in list(maxes.keys())[-self._ndims :]: + self._dims_sliders.set_dimension_visible(dim, False) + + def toggle_3d(self) -> None: + self.set_ndim(3 if self._ndims == 2 else 2) + + def set_ndim(self, ndim: Literal[2, 3]) -> None: + """Set the number of dimensions to display.""" + self._ndims = ndim + self._canvas.set_ndim(ndim) + + # set the visibility of the last non-channel dimension + sizes = list(self._sizes) + if self._channel_axis is not None: + sizes = [x for x in sizes if x != self._channel_axis] + if len(sizes) >= 3: + dim3 = sizes[-3] + self._dims_sliders.set_dimension_visible(dim3, True if ndim == 2 else False) + + # clear image handles and redraw + if self._img_handles: + self._clear_images() + self._update_data_for_index(self._dims_sliders.value()) + + def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None: + """Set the mode for displaying the channels. + + In "composite" mode, the channels are displayed as a composite image, using + self._channel_axis as the channel axis. In "grayscale" mode, each channel is + displayed separately. (If mode is None, the current value of the + channel_mode_picker button is used) + """ + if mode is None or isinstance(mode, bool): + mode = self._channel_mode_btn.mode() + else: + mode = ChannelMode(mode) + self._channel_mode_btn.setMode(mode) + if mode == getattr(self, "_channel_mode", None): + return + + self._channel_mode = mode + self._cmap_cycle = cycle(self._cmaps) # reset the colormap cycle + if self._channel_axis is not None: + # set the visibility of the channel slider + self._dims_sliders.set_dimension_visible( + self._channel_axis, mode != ChannelMode.COMPOSITE + ) + + if self._img_handles: + self._clear_images() + self._update_data_for_index(self._dims_sliders.value()) + + def setIndex(self, index: Indices) -> None: + """Set the index of the displayed image.""" + self._dims_sliders.setValue(index) + + # ------------------- PRIVATE METHODS ---------------------------- + + def _on_set_range_clicked(self) -> None: + # using method to swallow the parameter passed by _set_range_btn.clicked + self._canvas.set_range() + + def _image_key(self, index: Indices) -> ImgKey: + """Return the key for image handle(s) corresponding to `index`.""" + if self._channel_mode == ChannelMode.COMPOSITE: + val = index.get(self._channel_axis, 0) + if isinstance(val, slice): + return (val.start, val.stop) + return val + return 0 + + def _update_data_for_index(self, index: Indices) -> None: + """Retrieve data for `index` from datastore and update canvas image(s). + + This will pull the data from the datastore using the given index, and update + the image handle(s) with the new data. This method is *asynchronous*. It + makes a request for the new data slice and queues _on_data_future_done to be + called when the data is ready. + """ + if ( + self._channel_axis is not None + and self._channel_mode == ChannelMode.COMPOSITE + ): + indices: list[Indices] = [ + {**index, self._channel_axis: i} + for i in range(self._sizes[self._channel_axis]) + ] + else: + indices = [index] + + if self._last_future: + self._last_future.cancel() + + # don't request any dimensions that are not visualized + indices = [ + {k: v for k, v in idx.items() if k not in self._visualized_dims} + for idx in indices + ] + self._last_future = f = self._isel(indices) + f.add_done_callback(self._on_data_slice_ready) + + def closeEvent(self, a0: QCloseEvent | None) -> None: + if self._last_future is not None: + self._last_future.cancel() + self._last_future = None + super().closeEvent(a0) + + def _isel( + self, indices: list[Indices] + ) -> Future[Iterable[tuple[Indices, np.ndarray]]]: + """Select data from the datastore using the given index.""" + try: + return self._data_wrapper.isel_async(indices) + except Exception as e: + raise type(e)(f"Failed to index data with {indices}: {e}") from e + + @ensure_main_thread # type: ignore + def _on_data_slice_ready( + self, future: Future[Iterable[tuple[Indices, np.ndarray]]] + ) -> None: + """Update the displayed image for the given index. + + Connected to the future returned by _isel. + """ + # NOTE: removing the reference to the last future here is important + # because the future has a reference to this widget in its _done_callbacks + # which will prevent the widget from being garbage collected if the future + self._last_future = None + if future.cancelled(): + return + + data = future.result() + # FIXME: + # `self._channel_axis: i` is a bug; we assume channel indices start at 0 + # but the actual values used for indices are up to the user. + for idx, datum in data: + self._update_canvas_data(datum, idx) + self._canvas.refresh() + + def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: + """Actually update the image handle(s) with the (sliced) data. + + By this point, data should be sliced from the underlying datastore. Any + dimensions remaining that are more than the number of visualized dimensions + (currently just 2D) will be reduced using max intensity projection (currently). + """ + imkey = self._image_key(index) + datum = self._reduce_data_for_display(data) + if handles := self._img_handles[imkey]: + for handle in handles: + handle.data = datum + if ctrl := self._lut_ctrls.get(imkey, None): + ctrl.update_autoscale() + else: + cm = ( + next(self._cmap_cycle) + if self._channel_mode == ChannelMode.COMPOSITE + else GRAYS + ) + if datum.ndim == 2: + handles.append(self._canvas.add_image(datum, cmap=cm)) + elif datum.ndim == 3: + handles.append(self._canvas.add_volume(datum, cmap=cm)) + if imkey not in self._lut_ctrls: + channel_name = self._get_channel_name(index) + self._lut_ctrls[imkey] = c = LutControl( + channel_name, + handles, + self, + cmaplist=self._cmaps + DEFAULT_COLORMAPS, + ) + self._lut_drop.addWidget(c) + + def _get_channel_name(self, index: Indices) -> str: + c = index.get(self._channel_axis, 0) + return f"Ch {c}" # TODO: get name from user + + def _reduce_data_for_display( + self, data: np.ndarray, reductor: Callable[..., np.ndarray] = np.max + ) -> np.ndarray: + """Reduce the number of dimensions in the data for display. + + This function takes a data array and reduces the number of dimensions to + the max allowed for display. The default behavior is to reduce the smallest + dimensions, using np.max. This can be improved in the future. + + This also coerces 64-bit data to 32-bit data. + """ + # TODO + # - allow dimensions to control how they are reduced (as opposed to just max) + # - for better way to determine which dims need to be reduced (currently just + # the smallest dims) + data = data.squeeze() + visualized_dims = self._ndims + if extra_dims := data.ndim - visualized_dims: + shapes = sorted(enumerate(data.shape), key=lambda x: x[1]) + smallest_dims = tuple(i for i, _ in shapes[:extra_dims]) + data = reductor(data, axis=smallest_dims) + + if data.dtype.itemsize > 4: # More than 32 bits + if np.issubdtype(data.dtype, np.integer): + data = data.astype(np.int32) + else: + data = data.astype(np.float32) + return data + + def _clear_images(self) -> None: + """Remove all images from the canvas.""" + for handles in self._img_handles.values(): + for handle in handles: + handle.remove() + self._img_handles.clear() + + # clear the current LutControls as well + for c in self._lut_ctrls.values(): + cast("QVBoxLayout", self.layout()).removeWidget(c) + c.deleteLater() + self._lut_ctrls.clear() + + +def _to_sizes(sizes: SizesLike | None) -> Sizes: + """Coerce `sizes` to a {dimKey -> int} mapping.""" + if sizes is None: + return {} + if isinstance(sizes, Mapping): + return {k: int(v) for k, v in sizes.items()} + if not isinstance(sizes, Iterable): + raise TypeError(f"SizeLike must be an iterable or mapping, not: {type(sizes)}") + _sizes: dict[Hashable, int] = {} + for i, val in enumerate(sizes): + if isinstance(val, int): + _sizes[i] = val + elif isinstance(val, Sequence) and len(val) == 2: + _sizes[val[0]] = int(val[1]) + else: + raise ValueError(f"Invalid size: {val}. Must be an int or a 2-tuple.") + return _sizes diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..3e27fd83 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,45 @@ +import gc +from typing import TYPE_CHECKING, Iterator + +import pytest + +if TYPE_CHECKING: + from pytest import FixtureRequest + from qtpy.QtWidgets import QApplication + + +@pytest.fixture(autouse=True) +def _find_leaks(request: "FixtureRequest", qapp: "QApplication") -> Iterator[None]: + """Run after each test to ensure no widgets have been left around. + + When this test fails, it means that a widget being tested has an issue closing + cleanly. Perhaps a strong reference has leaked somewhere. Look for + `functools.partial(self._method)` or `lambda: self._method` being used in that + widget's code. + """ + nbefore = len(qapp.topLevelWidgets()) + failures_before = request.session.testsfailed + yield + # if the test failed, don't worry about checking widgets + if request.session.testsfailed - failures_before: + return + remaining = qapp.topLevelWidgets() + if len(remaining) > nbefore: + test_node = request.node + + test = f"{test_node.path.name}::{test_node.originalname}" + msg = f"{len(remaining)} topLevelWidgets remaining after {test!r}:" + + for widget in remaining: + try: + obj_name = widget.objectName() + except Exception: + obj_name = None + msg += f"\n{widget!r} {obj_name!r}" + # Get the referrers of the widget + referrers = gc.get_referrers(widget) + msg += "\n Referrers:" + for ref in referrers: + msg += f"\n - {ref}, {id(ref):#x}" + + raise AssertionError(msg) diff --git a/tests/test_nd_viewer.py b/tests/test_nd_viewer.py new file mode 100644 index 00000000..68a8e4e4 --- /dev/null +++ b/tests/test_nd_viewer.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import dask.array as da +import numpy as np + +from ndv import NDViewer + +if TYPE_CHECKING: + from pytestqt.qtbot import QtBot + + +def make_lazy_array(shape: tuple[int, ...]) -> da.Array: + rest_shape = shape[:-2] + frame_shape = shape[-2:] + + def _dask_block(block_id: tuple[int, int, int, int, int]) -> np.ndarray | None: + if isinstance(block_id, np.ndarray): + return None + size = (1,) * len(rest_shape) + frame_shape + return np.random.randint(0, 255, size=size, dtype=np.uint8) + + chunks = [(1,) * x for x in rest_shape] + [(x,) for x in frame_shape] + return da.map_blocks(_dask_block, chunks=chunks, dtype=np.uint8) # type: ignore + + +def test_stack_viewer2(qtbot: QtBot) -> None: + dask_arr = make_lazy_array((1000, 64, 3, 256, 256)) + v = NDViewer(dask_arr) + qtbot.addWidget(v) + v.show() + + # wait until there are no running jobs, because the callbacks + # in the futures hold a strong reference to the viewer + qtbot.waitUntil(lambda: v._last_future is None, timeout=1000) diff --git a/tests/test_ndv.py b/tests/test_ndv.py deleted file mode 100644 index 363b3e20..00000000 --- a/tests/test_ndv.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_something(): - pass