diff --git a/examples/notebook.ipynb b/examples/notebook.ipynb index 3791a05..e4aed8a 100644 --- a/examples/notebook.ipynb +++ b/examples/notebook.ipynb @@ -2,13 +2,14 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 5, + "id": "461399b0-e02d-43d9-9ede-c1aa6c180338", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "8b582aca26e04cc4a59d755ea986c44f", + "model_id": "d931ccb769154ecaabedf58b566913aa", "version_major": 2, "version_minor": 0 }, @@ -22,7 +23,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "3141028c938243b59325ff7ab52e1b3f", + "model_id": "af68f584e2cc479492806c577e70d960", "version_major": 2, "version_minor": 0 }, @@ -45,6 +46,7 @@ { "cell_type": "code", "execution_count": 2, + "id": "455ebabe-c2c1-4366-9784-65e45def5aa2", "metadata": {}, "outputs": [], "source": [ @@ -53,11 +55,12 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, + "id": "fb96f975-aa05-4d8a-9f85-f4316293e05d", "metadata": {}, "outputs": [], "source": [ - "viewer.display_model.default_lut.cmap = \"cubehelix\"\n", + "viewer.display_model.default_lut.cmap = \"magma\"\n", "viewer.display_model.channel_mode = \"grayscale\"" ] } @@ -82,5 +85,5 @@ } }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 5 } diff --git a/src/ndv/controllers/_array_viewer.py b/src/ndv/controllers/_array_viewer.py index 0043c10..5b3efad 100644 --- a/src/ndv/controllers/_array_viewer.py +++ b/src/ndv/controllers/_array_viewer.py @@ -63,6 +63,9 @@ def __init__( "When display_model is provided, kwargs are be ignored.", stacklevel=2, ) + self._data_model = _ArrayDataDisplayModel( + data_wrapper=data, display=display_model or ArrayDisplayModel(**kwargs) + ) # mapping of channel keys to their respective controllers # where None is the default channel @@ -72,21 +75,19 @@ def __init__( frontend_cls = _app.get_array_view_class() canvas_cls = _app.get_array_canvas_class() self._canvas = canvas_cls() - self._canvas.set_ndim(2) self._histogram: HistogramCanvas | None = None - self._view = frontend_cls(self._canvas.frontend_widget()) + self._view = frontend_cls(self._canvas.frontend_widget(), self._data_model) - display_model = display_model or ArrayDisplayModel(**kwargs) - self._data_model = _ArrayDataDisplayModel( - data_wrapper=data, display=display_model - ) self._set_model_connected(self._data_model.display) + self._canvas.set_ndim(self.display_model.n_visible_axes) self._view.currentIndexChanged.connect(self._on_view_current_index_changed) self._view.resetZoomClicked.connect(self._on_view_reset_zoom_clicked) self._view.histogramRequested.connect(self._add_histogram) self._view.channelModeChanged.connect(self._on_view_channel_mode_changed) + self._view.visibleAxesChanged.connect(self._on_view_visible_axes_changed) + self._canvas.mouseMoved.connect(self._on_canvas_mouse_moved) if self._data_model.data_wrapper is not None: @@ -119,7 +120,7 @@ def display_model(self) -> ArrayDisplayModel: @display_model.setter def display_model(self, model: ArrayDisplayModel) -> None: """Set the ArrayDisplayModel.""" - if not isinstance(model, ArrayDisplayModel): + if not isinstance(model, ArrayDisplayModel): # pragma: no cover raise TypeError("model must be an ArrayDisplayModel") self._set_model_connected(self._data_model.display, False) self._data_model.display = model @@ -233,6 +234,7 @@ def _fully_synchronize_view(self) -> None: self._view.create_sliders(self._data_model.normed_data_coords) self._view.set_channel_mode(display_model.channel_mode) if self.data is not None: + self._view.set_visible_axes(self._data_model.normed_visible_axes) self._update_visible_sliders() if cur_index := display_model.current_index: self._view.set_current_index(cur_index) @@ -248,8 +250,11 @@ def _fully_synchronize_view(self) -> None: self._update_hist_domain_for_dtype() def _on_model_visible_axes_changed(self) -> None: + self._view.set_visible_axes(self._data_model.normed_visible_axes) self._update_visible_sliders() + self._clear_canvas() self._update_canvas() + self._canvas.set_ndim(self.display_model.n_visible_axes) def _on_model_current_index_changed(self) -> None: value = self._data_model.display.current_index @@ -283,6 +288,10 @@ def _on_view_current_index_changed(self) -> None: """Update the model when slider value changes.""" self._data_model.display.current_index.update(self._view.current_index()) + def _on_view_visible_axes_changed(self) -> None: + """Update the model when the visible axes change.""" + self.display_model.visible_axes = self._view.visible_axes() # type: ignore [assignment] + def _on_view_reset_zoom_clicked(self) -> None: """Reset the zoom level of the canvas.""" self._canvas.set_range() @@ -353,17 +362,24 @@ def _update_canvas(self) -> None: if not lut_ctrl.handles: # we don't yet have any handles for this channel - lut_ctrl.add_handle(self._canvas.add_image(data)) + if response.n_visible_axes == 2: + handle = self._canvas.add_image(data) + lut_ctrl.add_handle(handle) + elif response.n_visible_axes == 3: + handle = self._canvas.add_volume(data) + lut_ctrl.add_handle(handle) + else: lut_ctrl.update_texture_data(data) - if self._histogram is not None: - # TODO: once data comes in in chunks, we'll need a proper stateful - # stats object that calculates the histogram incrementally - counts, bin_edges = _calc_hist_bins(data) - # TODO: currently this is updating the histogram on *any* - # channel index... so it doesn't work with composite mode - self._histogram.set_data(counts, bin_edges) - self._histogram.set_range() + + if self._histogram is not None: + # TODO: once data comes in in chunks, we'll need a proper stateful + # stats object that calculates the histogram incrementally + counts, bin_edges = _calc_hist_bins(data) + # FIXME: currently this is updating the histogram on *any* + # channel index... so it doesn't work with composite mode + self._histogram.set_data(counts, bin_edges) + self._histogram.set_range() self._canvas.refresh() diff --git a/src/ndv/data.py b/src/ndv/data.py index 59d649e..d8a0a6e 100644 --- a/src/ndv/data.py +++ b/src/ndv/data.py @@ -70,8 +70,8 @@ def cells3d() -> np.ndarray: # this data has been stretched to 16 bit, and lacks certain intensity values # add a small random integer to each pixel ... so the histogram is not silly - data = (data + np.random.randint(-24, 24, data.shape)).astype(np.uint16) - return data + data = (data + np.random.randint(-24, 24, data.shape)).clip(0, 65535) + return data.astype(np.uint16) def cat() -> np.ndarray: diff --git a/src/ndv/models/_data_display_model.py b/src/ndv/models/_data_display_model.py index 591e23d..1ac2207 100644 --- a/src/ndv/models/_data_display_model.py +++ b/src/ndv/models/_data_display_model.py @@ -17,7 +17,7 @@ class DataRequest: """Request object for data slicing.""" - wrapper: DataWrapper + wrapper: DataWrapper = field(repr=False) index: Mapping[int, Union[int, slice]] visible_axes: tuple[int, ...] channel_axis: Optional[int] @@ -28,10 +28,19 @@ class DataResponse: """Response object for data requests.""" data: np.ndarray = field(repr=False) + shape: tuple[int, ...] = field(init=False) + dtype: np.dtype = field(init=False) channel_key: Optional[int] + n_visible_axes: int request: Optional[DataRequest] = None + def __post_init__(self) -> None: + self.shape = self.data.shape + self.dtype = self.data.dtype + +# NOTE: nobody particularly likes this class. It does important stuff, but we're +# not yet sure where this logic belongs. class _ArrayDataDisplayModel(NDVModel): """Utility class combining ArrayDisplayModel model with a DataWrapper. @@ -154,14 +163,13 @@ def current_slice_requests(self) -> list[DataRequest]: if isinstance(val, int): requested_slice[ax] = slice(val, val + 1) - return [ - DataRequest( - wrapper=self.data_wrapper, - index=requested_slice, - visible_axes=self.normed_visible_axes, - channel_axis=c_ax, - ) - ] + request = DataRequest( + wrapper=self.data_wrapper, + index=requested_slice, + visible_axes=self.normed_visible_axes, + channel_axis=c_ax, + ) + return [request] # TODO: make async def request_sliced_data(self) -> list[Future[DataResponse]]: @@ -192,13 +200,13 @@ def request_sliced_data(self) -> list[Future[DataResponse]]: ch_keepdims = (slice(None),) * cast(int, ch_ax) + (i,) + (None,) ch_data = data[ch_keepdims] future = Future[DataResponse]() - future.set_result( - DataResponse( - data=ch_data.transpose(*t_dims).squeeze(), - channel_key=i, - request=req, - ) + response = DataResponse( + data=ch_data.transpose(*t_dims).squeeze(), + channel_key=i, + n_visible_axes=len(vis_ax), + request=req, ) + future.set_result(response) futures.append(future) return futures diff --git a/src/ndv/models/_data_wrapper.py b/src/ndv/models/_data_wrapper.py index c981910..af9c9db 100644 --- a/src/ndv/models/_data_wrapper.py +++ b/src/ndv/models/_data_wrapper.py @@ -71,6 +71,8 @@ class DataWrapper(Generic[ArrayT], ABC): PRIORITY: ClassVar[int] = 50 # These names will be checked when looking for a channel axis COMMON_CHANNEL_NAMES: ClassVar[Container[str]] = ("channel", "ch", "c") + COMMON_Z_AXIS_NAMES: ClassVar[Container[str]] = ("z", "depth", "focus") + # Maximum dimension size consider when guessing the channel axis MAX_CHANNELS = 16 @@ -159,6 +161,9 @@ def sizes(self) -> Mapping[Hashable, int]: """Return the sizes of the dimensions.""" return {dim: len(self.coords[dim]) for dim in self.dims} + # these guess_x methods may change in the future to become more agnostic to the + # dimension name/semantics that they are guessing. + def guess_channel_axis(self) -> Hashable | None: """Return the (best guess) axis name for the channel dimension.""" # for arrays with labeled dimensions, @@ -172,6 +177,22 @@ def guess_channel_axis(self) -> Hashable | None: # otherwise use the smallest dimension as the channel axis return min(sizes, key=sizes.get) # type: ignore [arg-type] + def guess_z_axis(self) -> Hashable | None: + """Return the (best guess) axis name for the z (3rd spatial) dimension.""" + sizes = self.sizes() + ch = self.guess_channel_axis() + for dimkey in sizes: + if str(dimkey).lower() in self.COMMON_Z_AXIS_NAMES: + if (normed := self.normalized_axis_key(dimkey)) != ch: + return normed + + # otherwise return the LAST axis that is neither in the last two dimensions + # or the channel axis guess + return next( + (self.normalized_axis_key(x) for x in reversed(self.dims[:-2]) if x != ch), + None, + ) + def summary_info(self) -> str: """Return info label with information about the data.""" package = getattr(self._data, "__module__", "").split(".")[0] diff --git a/src/ndv/views/_jupyter/_array_view.py b/src/ndv/views/_jupyter/_array_view.py index 48e0e87..ec6f238 100644 --- a/src/ndv/views/_jupyter/_array_view.py +++ b/src/ndv/views/_jupyter/_array_view.py @@ -15,6 +15,7 @@ from vispy.app.backends import _jupyter_rfb from ndv._types import AxisKey + from ndv.models._data_display_model import _ArrayDataDisplayModel # not entirely sure why it's necessary to specifically annotat signals as : PSignal # i think it has to do with type variance? @@ -125,10 +126,14 @@ def frontend_widget(self) -> Any: class JupyterArrayView(ArrayView): def __init__( - self, canvas_widget: _jupyter_rfb.CanvasBackend, **kwargs: Any + self, + canvas_widget: _jupyter_rfb.CanvasBackend, + data_model: _ArrayDataDisplayModel, ) -> None: # WIDGETS + self._data_model = data_model self._canvas_widget = canvas_widget + self._visible_axes: Sequence[AxisKey] = [] self._sliders: dict[Hashable, widgets.IntSlider] = {} self._slider_box = widgets.VBox([], layout=widgets.Layout(width="100%")) @@ -146,6 +151,16 @@ def __init__( self._channel_mode_combo.layout.align_self = "flex-end" self._channel_mode_combo.observe(self._on_channel_mode_changed, names="value") + self._ndims_btn = widgets.ToggleButton( + value=False, + description="3D", + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="View in 3D", + icon="check", + layout=widgets.Layout(width="60px"), + ) + self._ndims_btn.observe(self._on_ndims_toggled, names="value") + # LAYOUT try: @@ -153,6 +168,11 @@ def __init__( width = f"{int(width) + 4}px" except Exception: width = "604px" + + btns = widgets.HBox( + [self._channel_mode_combo, self._ndims_btn], + layout=widgets.Layout(justify_content="flex-end"), + ) self.layout = widgets.VBox( [ self._data_info_label, @@ -160,7 +180,7 @@ def __init__( self._hover_info_label, self._slider_box, self._luts_box, - self._channel_mode_combo, + btns, ], layout=widgets.Layout(width=width), ) @@ -279,5 +299,30 @@ def set_visible(self, visible: bool) -> None: else: display.clear_output() # type: ignore [no-untyped-call] + def visible_axes(self) -> Sequence[AxisKey]: + return self._visible_axes + + def set_visible_axes(self, axes: Sequence[AxisKey]) -> None: + self._visible_axes = tuple(axes) + self._ndims_btn.value = len(axes) == 3 + + def _on_ndims_toggled(self, change: dict[str, Any]) -> None: + if len(self._visible_axes) > 2: + if not change["new"]: # is now 2D + self._visible_axes = self._visible_axes[-2:] + else: + z_ax = None + if wrapper := self._data_model.data_wrapper: + z_ax = wrapper.guess_z_axis() + if z_ax is None: + # get the last slider that is not in visible axes + z_ax = next( + ax for ax in reversed(self._sliders) if ax not in self._visible_axes + ) + self._visible_axes = (z_ax, *self._visible_axes) + # TODO: a future PR may decide to set this on the model directly... + # since we now have access to it. + self.visibleAxesChanged.emit() + def close(self) -> None: self.layout.close() diff --git a/src/ndv/views/_pygfx/_array_canvas.py b/src/ndv/views/_pygfx/_array_canvas.py index 3cc0a47..87471c8 100755 --- a/src/ndv/views/_pygfx/_array_canvas.py +++ b/src/ndv/views/_pygfx/_array_canvas.py @@ -463,6 +463,11 @@ def set_ndim(self, ndim: Literal[2, 3]) -> None: if ndim == 3: self._camera = cam = pygfx.PerspectiveCamera(0, 1) with suppress(ValueError): + # if the scene has no children yet, this will raise a ValueErrors + # FIXME: there's a bit of order-of-call problem here: + # this method needs to be called *after* the scene is constructed... + # that's what controller._on_model_visible_axes_changed does, but + # it seems fragile and should be fixed. cam.show_object(self._scene, up=(0, -1, 0), view_dir=(0, 0, 1)) controller = pygfx.OrbitController(cam, register_events=self._renderer) zoom = "zoom" @@ -487,6 +492,11 @@ def set_ndim(self, ndim: Literal[2, 3]) -> None: def add_image(self, data: np.ndarray | None = None) -> PyGFXImageHandle: """Add a new Image node to the scene.""" + if data is not None: + # pygfx uses a view of the data without copy, so if we don't + # copy it here, the original data will be modified when the + # texture changes. + data = data.copy() tex = pygfx.Texture(data, dim=2) image = pygfx.Image( pygfx.Geometry(grid=tex), @@ -507,6 +517,11 @@ def add_image(self, data: np.ndarray | None = None) -> PyGFXImageHandle: return handle def add_volume(self, data: np.ndarray | None = None) -> PyGFXImageHandle: + if data is not None: + # pygfx uses a view of the data without copy, so if we don't + # copy it here, the original data will be modified when the + # texture changes. + data = data.copy() tex = pygfx.Texture(data, dim=3) vol = pygfx.Volume( pygfx.Geometry(grid=tex), diff --git a/src/ndv/views/_qt/_array_view.py b/src/ndv/views/_qt/_array_view.py index 6652a8f..8664349 100644 --- a/src/ndv/views/_qt/_array_view.py +++ b/src/ndv/views/_qt/_array_view.py @@ -30,6 +30,7 @@ from qtpy.QtGui import QIcon from ndv._types import AxisKey + from ndv.models._data_display_model import _ArrayDataDisplayModel SLIDER_STYLE = """ QSlider::groove:horizontal { @@ -89,6 +90,14 @@ def setCurrentColormap(self, cmap_: cmap.Colormap) -> None: self.addColormap(cmap_) +class _DimToggleButton(QPushButton): + def __init__(self, parent: QWidget | None = None): + icn = QIconifyIcon("f7:view-2d", color="#333333") + icn.addKey("f7:view-3d") + super().__init__(icn, "", parent) + self.setCheckable(True) + + class _QLUTWidget(QWidget): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) @@ -302,8 +311,11 @@ def __init__(self, canvas_widget: QWidget, parent: QWidget | None = None): self._btn_layout.setParent(None) self.luts.expand() + # button to change number of displayed dimensions + self.ndims_btn = _DimToggleButton(self) + self._btn_layout.addWidget(self.channel_mode_combo) - # self._btns.addWidget(self._ndims_btn) + self._btn_layout.addWidget(self.ndims_btn) self._btn_layout.addWidget(self.histogram_btn) self._btn_layout.addWidget(self.set_range_btn) # self._btns.addWidget(self._add_roi_btn) @@ -337,7 +349,10 @@ def __init__(self, canvas_widget: QWidget, parent: QWidget | None = None): class QtArrayView(ArrayView): - def __init__(self, canvas_widget: QWidget) -> None: + def __init__( + self, canvas_widget: QWidget, data_model: _ArrayDataDisplayModel + ) -> None: + self._data_model = data_model self._qwidget = qwdg = _QArrayViewer(canvas_widget) qwdg.histogram_btn.clicked.connect(self._on_add_histogram_clicked) @@ -347,6 +362,9 @@ def __init__(self, canvas_widget: QWidget) -> None: self._on_channel_mode_changed ) qwdg.set_range_btn.clicked.connect(self.resetZoomClicked.emit) + qwdg.ndims_btn.toggled.connect(self._on_ndims_toggled) + + self._visible_axes: Sequence[AxisKey] = [] def add_lut_view(self) -> QLutView: view = QLutView() @@ -398,6 +416,30 @@ def set_current_index(self, value: Mapping[AxisKey, int | slice]) -> None: """Set the current value of the sliders.""" self._qwidget.dims_sliders.set_current_index(value) + def _on_ndims_toggled(self, is_3d: bool) -> None: + if len(self._visible_axes) > 2: + if not is_3d: # is now 2D + self._visible_axes = self._visible_axes[-2:] + else: + z_ax = None + if wrapper := self._data_model.data_wrapper: + z_ax = wrapper.guess_z_axis() + if z_ax is None: + # get the last slider that is not in visible axes + sld = reversed(self._qwidget.dims_sliders._sliders) + z_ax = next(ax for ax in sld if ax not in self._visible_axes) + self._visible_axes = (z_ax, *self._visible_axes) + # TODO: a future PR may decide to set this on the model directly... + # since we now have access to it. + self.visibleAxesChanged.emit() + + def visible_axes(self) -> Sequence[AxisKey]: + return self._visible_axes # no widget to control this yet + + def set_visible_axes(self, axes: Sequence[AxisKey]) -> None: + self._visible_axes = tuple(axes) + self._qwidget.ndims_btn.setChecked(len(axes) > 2) + def set_data_info(self, text: str) -> None: """Set the data info text, above the canvas.""" self._qwidget.data_info_label.setText(text) diff --git a/src/ndv/views/_wx/_array_view.py b/src/ndv/views/_wx/_array_view.py index 71028c0..0f4de62 100644 --- a/src/ndv/views/_wx/_array_view.py +++ b/src/ndv/views/_wx/_array_view.py @@ -19,6 +19,7 @@ import cmap from ndv._types import AxisKey + from ndv.models._data_display_model import _ArrayDataDisplayModel # mostly copied from _qt.qt_view._QLUTWidget @@ -207,6 +208,9 @@ def __init__(self, canvas_widget: wx.Window, parent: wx.Window = None): # Reset zoom button self.reset_zoom_btn = wx.Button(self, label="Reset Zoom") + # Reset zoom button + self.ndims_btn = wx.ToggleButton(self, label="3D") + # LUT layout (simple vertical grouping for LUT widgets) self.luts = wx.BoxSizer(wx.VERTICAL) @@ -214,6 +218,7 @@ def __init__(self, canvas_widget: wx.Window, parent: wx.Window = None): btns.AddStretchSpacer() btns.Add(self.channel_mode_combo, 0, wx.ALL, 5) btns.Add(self.reset_zoom_btn, 0, wx.ALL, 5) + btns.Add(self.ndims_btn, 0, wx.ALL, 5) # Layout for the panel inner = wx.BoxSizer(wx.VERTICAL) @@ -232,13 +237,21 @@ def __init__(self, canvas_widget: wx.Window, parent: wx.Window = None): class WxArrayView(ArrayView): - def __init__(self, canvas_widget: wx.Window, parent: wx.Window = None) -> None: + def __init__( + self, + canvas_widget: wx.Window, + data_model: _ArrayDataDisplayModel, + parent: wx.Window = None, + ) -> None: + self._data_model = data_model self._wxwidget = wdg = _WxArrayViewer(canvas_widget, parent) + self._visible_axes: Sequence[AxisKey] = [] # TODO: use emit_fast wdg.dims_sliders.currentIndexChanged.connect(self.currentIndexChanged.emit) wdg.channel_mode_combo.Bind(wx.EVT_COMBOBOX, self._on_channel_mode_changed) wdg.reset_zoom_btn.Bind(wx.EVT_BUTTON, self._on_reset_zoom_clicked) + wdg.ndims_btn.Bind(wx.EVT_TOGGLEBUTTON, self._on_ndims_toggled) def _on_channel_mode_changed(self, event: wx.CommandEvent) -> None: mode = self._wxwidget.channel_mode_combo.GetValue() @@ -247,6 +260,31 @@ def _on_channel_mode_changed(self, event: wx.CommandEvent) -> None: def _on_reset_zoom_clicked(self, event: wx.CommandEvent) -> None: self.resetZoomClicked.emit() + def _on_ndims_toggled(self, event: wx.CommandEvent) -> None: + is_3d = self._wxwidget.ndims_btn.GetValue() + if len(self._visible_axes) > 2: + if not is_3d: # is now 2D + self._visible_axes = self._visible_axes[-2:] + else: + z_ax = None + if wrapper := self._data_model.data_wrapper: + z_ax = wrapper.guess_z_axis() + if z_ax is None: + # get the last slider that is not in visible axes + sld = reversed(self._wxwidget.dims_sliders._sliders) + z_ax = next(ax for ax in sld if ax not in self._visible_axes) + self._visible_axes = (z_ax, *self._visible_axes) + # TODO: a future PR may decide to set this on the model directly... + # since we now have access to it. + self.visibleAxesChanged.emit() + + def visible_axes(self) -> Sequence[AxisKey]: + return self._visible_axes # no widget to control this yet + + def set_visible_axes(self, axes: Sequence[AxisKey]) -> None: + self._visible_axes = tuple(axes) + self._wxwidget.ndims_btn.SetValue(len(axes) == 3) + def frontend_widget(self) -> wx.Window: return self._wxwidget diff --git a/src/ndv/views/bases/_array_view.py b/src/ndv/views/bases/_array_view.py index 997cce0..abe6990 100644 --- a/src/ndv/views/bases/_array_view.py +++ b/src/ndv/views/bases/_array_view.py @@ -14,6 +14,7 @@ from collections.abc import Container, Hashable, Mapping, Sequence from ndv._types import AxisKey + from ndv.models._data_display_model import _ArrayDataDisplayModel from ndv.views.bases import LutView @@ -28,16 +29,26 @@ class ArrayView(Viewable): currentIndexChanged = Signal() resetZoomClicked = Signal() histogramRequested = Signal() + visibleAxesChanged = Signal() channelModeChanged = Signal(ChannelMode) + # model: _ArrayDataDisplayModel is likely a temporary parameter @abstractmethod - def __init__(self, canvas_widget: Any, **kwargs: Any) -> None: ... + def __init__( + self, canvas_widget: Any, model: _ArrayDataDisplayModel, **kwargs: Any + ) -> None: ... @abstractmethod def create_sliders(self, coords: Mapping[int, Sequence]) -> None: ... @abstractmethod def current_index(self) -> Mapping[AxisKey, int | slice]: ... @abstractmethod def set_current_index(self, value: Mapping[AxisKey, int | slice]) -> None: ... + + @abstractmethod + def visible_axes(self) -> Sequence[AxisKey]: ... + @abstractmethod + def set_visible_axes(self, axes: Sequence[AxisKey]) -> None: ... + @abstractmethod def set_channel_mode(self, mode: ChannelMode) -> None: ... @abstractmethod diff --git a/tests/test_controller.py b/tests/test_controller.py index 2488198..8be82fe 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -23,9 +23,13 @@ def _get_mock_canvas() -> ArrayCanvas: mock = MagicMock(spec=ArrayCanvas) - handle = MagicMock(spec=ImageHandle) - handle.data.return_value = np.zeros((10, 10)).astype(np.uint8) - mock.add_image.return_value = handle + img_handle = MagicMock(spec=ImageHandle) + img_handle.data.return_value = np.zeros((10, 10)).astype(np.uint8) + mock.add_image.return_value = img_handle + + vol_handle = MagicMock(spec=ImageHandle) + vol_handle.data.return_value = np.zeros((10, 10, 10)).astype(np.uint8) + mock.add_volume.return_value = vol_handle return mock @@ -96,6 +100,11 @@ def test_controller() -> None: ctrl._on_view_current_index_changed() assert model.current_index == idx + # when the view sets 3 dimensions, the model is updated + mock_view.visible_axes.return_value = (0, -2, -1) + ctrl._on_view_visible_axes_changed() + assert model.visible_axes == (0, -2, -1) + # when the view changes the channel mode, the model is updated assert model.channel_mode == ChannelMode.GRAYSCALE ctrl._on_view_channel_mode_changed(ChannelMode.COMPOSITE) @@ -175,7 +184,21 @@ def test_array_viewer_with_app() -> None: index_mock.assert_called_once() for k, v in index.items(): assert viewer.display_model.current_index[k] == v + # setting again should not trigger the signal index_mock.reset_mock() viewer._view.set_current_index(index) index_mock.assert_not_called() + + # test_setting 3D + assert viewer.display_model.visible_axes == (-2, -1) + visax_mock = Mock() + viewer.display_model.events.visible_axes.connect(visax_mock) + viewer._view.set_visible_axes((0, -2, -1)) + + # FIXME: + # calling set_visible_axes on wx during testing is not triggering the + # _on_ndims_toggled callback... and I don't know enough about wx yet to know why. + if gui_frontend() != _app.GuiFrontend.WX: + visax_mock.assert_called_once() + assert viewer.display_model.visible_axes == (0, -2, -1)