diff --git a/src/ndv/_types.py b/src/ndv/_types.py index f39c68db..26c33827 100644 --- a/src/ndv/_types.py +++ b/src/ndv/_types.py @@ -5,6 +5,7 @@ from collections.abc import Hashable, Sequence from contextlib import suppress from enum import Enum, IntFlag, auto +from functools import cache from typing import TYPE_CHECKING, Annotated, Any, NamedTuple, cast from pydantic import PlainSerializer, PlainValidator @@ -13,6 +14,7 @@ if TYPE_CHECKING: from qtpy.QtCore import Qt from qtpy.QtWidgets import QWidget + from wx import Cursor from ndv.views.bases import Viewable @@ -56,6 +58,7 @@ class MouseButton(IntFlag): LEFT = auto() MIDDLE = auto() RIGHT = auto() + NONE = auto() class MouseMoveEvent(NamedTuple): @@ -63,6 +66,7 @@ class MouseMoveEvent(NamedTuple): x: float y: float + btn: MouseButton = MouseButton.NONE class MousePressEvent(NamedTuple): @@ -70,7 +74,7 @@ class MousePressEvent(NamedTuple): x: float y: float - btn: MouseButton = MouseButton.LEFT + btn: MouseButton class MouseReleaseEvent(NamedTuple): @@ -78,11 +82,12 @@ class MouseReleaseEvent(NamedTuple): x: float y: float - btn: MouseButton = MouseButton.LEFT + btn: MouseButton class CursorType(Enum): DEFAULT = "default" + CROSS = "cross" V_ARROW = "v_arrow" H_ARROW = "h_arrow" ALL_ARROW = "all_arrow" @@ -101,9 +106,47 @@ def to_qt(self) -> Qt.CursorShape: return { CursorType.DEFAULT: Qt.CursorShape.ArrowCursor, + CursorType.CROSS: Qt.CursorShape.CrossCursor, CursorType.V_ARROW: Qt.CursorShape.SizeVerCursor, CursorType.H_ARROW: Qt.CursorShape.SizeHorCursor, CursorType.ALL_ARROW: Qt.CursorShape.SizeAllCursor, CursorType.BDIAG_ARROW: Qt.CursorShape.SizeBDiagCursor, CursorType.FDIAG_ARROW: Qt.CursorShape.SizeFDiagCursor, }[self] + + def to_jupyter(self) -> str: + """Converts CursorType to jupyter cursor strings.""" + return { + CursorType.DEFAULT: "default", + CursorType.CROSS: "crosshair", + CursorType.V_ARROW: "ns-resize", + CursorType.H_ARROW: "ew-resize", + CursorType.ALL_ARROW: "move", + CursorType.BDIAG_ARROW: "nesw-resize", + CursorType.FDIAG_ARROW: "nwse-resize", + }[self] + + # Note a new object must be created every time. We should cache it! + @cache + def to_wx(self) -> Cursor: + """Converts CursorType to jupyter cursor strings.""" + from wx import ( + CURSOR_ARROW, + CURSOR_CROSS, + CURSOR_SIZENESW, + CURSOR_SIZENS, + CURSOR_SIZENWSE, + CURSOR_SIZEWE, + CURSOR_SIZING, + Cursor, + ) + + return { + CursorType.DEFAULT: Cursor(CURSOR_ARROW), + CursorType.CROSS: Cursor(CURSOR_CROSS), + CursorType.V_ARROW: Cursor(CURSOR_SIZENS), + CursorType.H_ARROW: Cursor(CURSOR_SIZEWE), + CursorType.ALL_ARROW: Cursor(CURSOR_SIZING), + CursorType.BDIAG_ARROW: Cursor(CURSOR_SIZENESW), + CursorType.FDIAG_ARROW: Cursor(CURSOR_SIZENWSE), + }[self] diff --git a/src/ndv/controllers/_array_viewer.py b/src/ndv/controllers/_array_viewer.py index cdc1a559..eb3e8456 100644 --- a/src/ndv/controllers/_array_viewer.py +++ b/src/ndv/controllers/_array_viewer.py @@ -9,6 +9,8 @@ from ndv.controllers._channel_controller import ChannelController from ndv.models import ArrayDisplayModel, ChannelMode, DataWrapper, LUTModel from ndv.models._data_display_model import DataResponse, _ArrayDataDisplayModel +from ndv.models._roi_model import RectangularROIModel +from ndv.models._viewer_model import ArrayViewerModel, InteractionMode from ndv.views import _app if TYPE_CHECKING: @@ -21,6 +23,7 @@ from ndv._types import MouseMoveEvent from ndv.models._array_display_model import ArrayDisplayModelKwargs from ndv.views.bases import HistogramCanvas + from ndv.views.bases._graphics._canvas_elements import RectangularROIHandle LutKey: TypeAlias = int | None @@ -69,6 +72,11 @@ def __init__( self._data_model = _ArrayDataDisplayModel( data_wrapper=data, display=display_model or ArrayDisplayModel(**kwargs) ) + self._viewer_model = ArrayViewerModel() + self._viewer_model.events.interaction_mode.connect( + self._on_interaction_mode_changed + ) + self._roi_model: RectangularROIModel | None = None app = _app.gui_frontend() @@ -88,10 +96,14 @@ def __init__( # get and create the front-end and canvas classes frontend_cls = _app.get_array_view_class() canvas_cls = _app.get_array_canvas_class() - self._canvas = canvas_cls() + self._canvas = canvas_cls(self._viewer_model) self._histogram: HistogramCanvas | None = None - self._view = frontend_cls(self._canvas.frontend_widget(), self._data_model) + self._view = frontend_cls( + self._canvas.frontend_widget(), self._data_model, self._viewer_model + ) + + self._roi_view: RectangularROIHandle | None = None self._set_model_connected(self._data_model.display) self._canvas.set_ndim(self.display_model.n_visible_axes) @@ -163,6 +175,27 @@ def data(self, data: Any) -> None: self._data_model.data_wrapper = DataWrapper.create(data) self._fully_synchronize_view() + @property + def roi(self) -> RectangularROIModel | None: + """Return ROI being displayed.""" + return self._roi_model + + @roi.setter + def roi(self, roi_model: RectangularROIModel | None) -> None: + """Set ROI being displayed.""" + # Disconnect old model + if self._roi_model is not None: + self._set_roi_model_connected(self._roi_model, False) + + # Connect new model + if isinstance(roi_model, tuple): + self._roi_model = RectangularROIModel(bounding_box=roi_model) + else: + self._roi_model = roi_model + if self._roi_model is not None: + self._set_roi_model_connected(self._roi_model) + self._synchronize_roi() + def show(self) -> None: """Show the viewer.""" self._view.set_visible(True) @@ -239,6 +272,28 @@ def _set_model_connected( ]: getattr(obj, _connect)(callback) + def _set_roi_model_connected( + self, model: RectangularROIModel, connect: bool = True + ) -> None: + """Connect or disconnect the model to/from the viewer. + + We do this in a single method so that we are sure to connect and disconnect + the same events in the same order. (but it's kinda ugly) + """ + _connect = "connect" if connect else "disconnect" + + for obj, callback in [ + (model.events.bounding_box, self._on_roi_model_bounding_box_changed), + (model.events.visible, self._on_roi_model_visible_changed), + ]: + getattr(obj, _connect)(callback) + + if _connect: + self._create_roi_view() + else: + if self._roi_view: + self._roi_view.remove() + # ------------------ Model callbacks ------------------ def _fully_synchronize_view(self) -> None: @@ -261,6 +316,13 @@ def _fully_synchronize_view(self) -> None: for lut_ctr in self._lut_controllers.values(): lut_ctr.synchronize() self._update_hist_domain_for_dtype() + self._synchronize_roi() + + def _synchronize_roi(self) -> None: + """Fully re-synchronize the ROI view with the model.""" + if self.roi is not None: + self._on_roi_model_bounding_box_changed(self.roi.bounding_box) + self._on_roi_model_visible_changed(self.roi.visible) def _on_model_visible_axes_changed(self) -> None: self._view.set_visible_axes(self._data_model.normed_visible_axes) @@ -288,6 +350,38 @@ def _on_model_channel_mode_changed(self, mode: ChannelMode) -> None: self._clear_canvas() self._request_data() + def _on_roi_model_bounding_box_changed( + self, bb: tuple[tuple[float, float], tuple[float, float]] + ) -> None: + if self._roi_view is not None: + self._roi_view.set_bounding_box(*bb) + + def _on_roi_model_visible_changed(self, visible: bool) -> None: + if self._roi_view is not None: + self._roi_view.set_visible(visible) + + def _on_interaction_mode_changed(self, mode: InteractionMode) -> None: + if mode == InteractionMode.CREATE_ROI: + # Create ROI model if needed to store ROI state + if self.roi is None: + self.roi = RectangularROIModel(visible=False) + + # Create a new ROI + self._create_roi_view() + + def _create_roi_view(self) -> None: + # Remove old ROI view + # TODO: Enable multiple ROIs + if self._roi_view: + self._roi_view.remove() + + # Create new ROI view + self._roi_view = self._canvas.add_bounding_box() + # Connect view signals + self._roi_view.boundingBoxChanged.connect( + self._on_roi_view_bounding_box_changed + ) + def _clear_canvas(self) -> None: for lut_ctrl in self._lut_controllers.values(): # self._view.remove_lut_view(lut_ctrl.lut_view) @@ -309,6 +403,12 @@ def _on_view_reset_zoom_clicked(self) -> None: """Reset the zoom level of the canvas.""" self._canvas.set_range() + def _on_roi_view_bounding_box_changed( + self, bb: tuple[tuple[float, float], tuple[float, float]] + ) -> None: + if self._roi_model: + self._roi_model.bounding_box = bb + def _on_canvas_mouse_moved(self, event: MouseMoveEvent) -> None: """Respond to a mouse move event in the view.""" x, y, _z = self._canvas.canvas_to_world((event.x, event.y)) diff --git a/src/ndv/models/__init__.py b/src/ndv/models/__init__.py index 40c4dd5a..18fc4ddf 100644 --- a/src/ndv/models/__init__.py +++ b/src/ndv/models/__init__.py @@ -11,6 +11,7 @@ ClimsStdDev, LUTModel, ) +from ._roi_model import RectangularROIModel __all__ = [ "ArrayDisplayModel", @@ -23,4 +24,5 @@ "DataWrapper", "LUTModel", "NDVModel", + "RectangularROIModel", ] diff --git a/src/ndv/models/_roi_model.py b/src/ndv/models/_roi_model.py new file mode 100644 index 00000000..3e5dc5c1 --- /dev/null +++ b/src/ndv/models/_roi_model.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from pydantic import field_validator + +from ndv.models._base_model import NDVModel + + +class RectangularROIModel(NDVModel): + """Representation of an axis-aligned rectangular Region of Interest (ROI). + + Attributes + ---------- + visible : bool + Whether to display this roi. + bounding_box : tuple[tuple[float, float], tuple[float, float]] + The minimum (2D) point and the maximum (2D) point contained within the + region. Using these two points, an axis-aligned bounding box can be + constructed. + """ + + visible: bool = True + bounding_box: tuple[tuple[float, float], tuple[float, float]] = ((0, 0), (0, 0)) + + @field_validator("bounding_box", mode="after") + @classmethod + def _validate_bounding_box( + cls, bb: tuple[tuple[float, float], tuple[float, float]] + ) -> tuple[tuple[float, float], tuple[float, float]]: + x1 = min(bb[0][0], bb[1][0]) + y1 = min(bb[0][1], bb[1][1]) + x2 = max(bb[0][0], bb[1][0]) + y2 = max(bb[0][1], bb[1][1]) + return ((x1, y1), (x2, y2)) diff --git a/src/ndv/models/_viewer_model.py b/src/ndv/models/_viewer_model.py new file mode 100644 index 00000000..f56227bb --- /dev/null +++ b/src/ndv/models/_viewer_model.py @@ -0,0 +1,27 @@ +from enum import Enum, auto + +from ndv.models._base_model import NDVModel + + +class InteractionMode(Enum): + """An enum defining graphical interaction mechanisms with an array Viewer.""" + + PAN_ZOOM = auto() # Mode allowing the user to pan and zoom + CREATE_ROI = auto() # Mode where user clicks create ROIs + + +class ArrayViewerModel(NDVModel): + """Representation of an array viewer. + + TODO: This will likely contain other fields including: + * Dimensionality + * Camera position + * Camera frustum + + Parameters + ---------- + interaction_mode : InteractionMode + Describes the current interaction mode of the Viewer. + """ + + interaction_mode: InteractionMode = InteractionMode.PAN_ZOOM diff --git a/src/ndv/v1/_old_viewer.py b/src/ndv/v1/_old_viewer.py index 8d73ea24..1d1ff2a3 100755 --- a/src/ndv/v1/_old_viewer.py +++ b/src/ndv/v1/_old_viewer.py @@ -12,6 +12,8 @@ from superqt import QCollapsible, QElidingLabel, QIconifyIcon, ensure_main_thread from superqt.utils import qthrottled, signals_blocked +from ndv._types import MouseButton, MouseMoveEvent, MousePressEvent +from ndv.models._viewer_model import ArrayViewerModel, InteractionMode from ndv.views import get_array_canvas_class from ._old_data_wrapper import DataWrapper @@ -37,7 +39,7 @@ from ndv.views.bases._graphics._canvas_elements import ( CanvasElement, ImageHandle, - RoiHandle, + RectangularROIHandle, ) DimKey = int @@ -129,6 +131,10 @@ def __init__( channel_mode: ChannelMode | str = ChannelMode.MONO, ): super().__init__(parent=parent) + self._array_viewer_model = ArrayViewerModel() + self._array_viewer_model.events.interaction_mode.connect( + self._on_view_model_mode_changed + ) # ATTRIBUTES ---------------------------------------------------- self._data_wrapper: DataWrapper | None = None @@ -160,7 +166,7 @@ def __init__( # Canvas selection self._selection: CanvasElement | None = None # ROI - self._roi: RoiHandle | None = None + self._roi: RectangularROIHandle | None = None # WIDGETS ---------------------------------------------------- @@ -187,7 +193,7 @@ def __init__( # place to display arbitrary text self._hover_info_label = QLabel("", self) # the canvas that displays the images - self._canvas: ArrayCanvas = get_array_canvas_class()() + self._canvas: ArrayCanvas = get_array_canvas_class()(self._array_viewer_model) self._canvas.set_ndim(self._ndims) self._qcanvas = self._canvas.frontend_widget() @@ -343,14 +349,30 @@ def set_roi( border_color : str, tuple, list, array, Color, or int The border color. Can be any "ColorLike". """ - # Remove the old ROI + roi = self._canvas.add_bounding_box() + if color: + roi.set_fill(cmap.Color(color)) + if border_color: + roi.set_border(cmap.Color(border_color)) + if vertices: + roi.set_bounding_box(vertices[0], vertices[2]) + + # Assert vertices represent axis-aligned rectangle + if vertices is not None: + if len(vertices) != 4: + raise ValueError("Only rectangles are currently supported") + if ( + vertices[0][1] != vertices[1][1] + or vertices[1][0] != vertices[2][0] + or vertices[2][1] != vertices[3][1] + or vertices[3][0] != vertices[0][0] + ): + raise ValueError("Only axis-aligned rectangles are currently supported") + + # Remove the old ROI and add the new one if self._roi: self._roi.remove() - color = cmap.Color(color) if color is not None else None - border_color = cmap.Color(border_color) if border_color is not None else None - self._roi = self._canvas.add_roi( - vertices=vertices, color=color, border_color=border_color - ) + self._roi = roi def set_visualized_dims(self, dims: Iterable[DimKey]) -> None: """Set the dimensions that will be visualized. @@ -474,10 +496,23 @@ def _on_set_range_clicked(self) -> None: # using method to swallow the parameter passed by _set_range_btn.clicked self._canvas.set_range() - def _on_add_roi_clicked(self, checked: bool) -> None: - if checked: - # Add new roi + def _on_view_model_mode_changed(self, mode: InteractionMode) -> None: + if mode == InteractionMode.CREATE_ROI: self.set_roi() + # Since we have no ROIModel, HACK the view to listen to itself + if (roi := self._roi) is not None: + + def on_bounding_box_edited( + bb: tuple[tuple[float, float], tuple[float, float]], + ) -> None: + roi.set_bounding_box(bb[0], bb[1]) + + roi.boundingBoxChanged.connect(on_bounding_box_edited) + + def _on_add_roi_clicked(self, checked: bool) -> None: + self._array_viewer_model.interaction_mode = ( + InteractionMode.CREATE_ROI if checked else InteractionMode.PAN_ZOOM + ) def _image_key(self, index: Indices) -> ImgKey: """Return the key for image handle(s) corresponding to `index`.""" @@ -683,35 +718,39 @@ def _canvas_mouse_event(self, ev: QMouseEvent) -> bool: # FIXME: This is ugly def _begin_roi(self, event: QMouseEvent) -> bool: if self._roi: - ev_pos = event.position() - pos = self._canvas.canvas_to_world((ev_pos.x(), ev_pos.y())) - self._roi.move(pos) + canvas_pos = (event.pos().x(), event.pos().y()) + world_pos = self._canvas.canvas_to_world(canvas_pos)[:2] + # HACK: Provide a non-zero starting size so that if the user clicks + # and immediately releases, it's visible and can be selected again + _min = world_pos + _max = (world_pos[0] + 1, world_pos[1] + 1) + # Put the ROI where the user clicked + self._roi.set_bounding_box(_min, _max) self._roi.set_visible(True) return False def _press_element(self, event: QMouseEvent) -> bool: ev_pos = (event.position().x(), event.position().y()) - pos = self._canvas.canvas_to_world(ev_pos) - # TODO why does the canvas need this point untransformed?? elements = self._canvas.elements_at(ev_pos) # Deselect prior selection before editing new selection if self._selection: self._selection.set_selected(False) + # Select a new element, if one is under the mouse for e in elements: if e.can_select(): - e.start_move(pos) - # Select new selection self._selection = e - self._selection.set_selected(True) + self._selection.on_mouse_press( + MousePressEvent(ev_pos[0], ev_pos[1], MouseButton.LEFT) + ) return False return False def _move_selection(self, event: QMouseEvent) -> bool: if event.buttons() == Qt.MouseButton.LeftButton: if self._selection and self._selection.selected(): - ev_pos = event.pos() - pos = self._canvas.canvas_to_world((ev_pos.x(), ev_pos.y())) - self._selection.move(pos) + self._selection.on_mouse_move( + MouseMoveEvent(event.pos().x(), event.pos().y(), MouseButton.LEFT) + ) # If we are moving the object, we don't want to move the camera return True return False @@ -725,9 +764,10 @@ def _update_cursor(self, event: QMouseEvent) -> bool: self._qcanvas.setCursor(Qt.CursorShape.CrossCursor) return False # If any local elements have a preference, use it + mme = MouseMoveEvent(event.pos().x(), event.pos().y()) pos = (event.pos().x(), event.pos().y()) for e in self._canvas.elements_at(pos): - if (pref := e.cursor_at(pos)) is not None: + if (pref := e.get_cursor(mme)) is not None: self._qcanvas.setCursor(pref.to_qt()) return False # Otherwise, normal cursor diff --git a/src/ndv/views/_jupyter/_app.py b/src/ndv/views/_jupyter/_app.py index e622d59d..3d52ccdd 100644 --- a/src/ndv/views/_jupyter/_app.py +++ b/src/ndv/views/_jupyter/_app.py @@ -6,7 +6,12 @@ from jupyter_rfb import RemoteFrameBuffer -from ndv._types import MouseMoveEvent, MousePressEvent, MouseReleaseEvent +from ndv._types import ( + MouseButton, + MouseMoveEvent, + MousePressEvent, + MouseReleaseEvent, +) from ndv.views.bases._app import NDVApp if TYPE_CHECKING: @@ -41,6 +46,19 @@ def array_view_class(self) -> type[ArrayView]: return JupyterArrayView + @staticmethod + def mouse_btn(btn: Any) -> MouseButton: + if btn == 0: + return MouseButton.NONE + if btn == 1: + return MouseButton.LEFT + if btn == 2: + return MouseButton.RIGHT + if btn == 3: + return MouseButton.MIDDLE + + raise Exception(f"Jupyter mouse button {btn} is unknown") + def filter_mouse_events( self, canvas: Any, receiver: Mouseable ) -> Callable[[], None]: @@ -52,22 +70,35 @@ def filter_mouse_events( # patch the handle_event from _jupyter_rfb.CanvasBackend # to intercept various mouse events. super_handle_event = canvas.handle_event + active_btn: MouseButton = MouseButton.NONE def handle_event(self: RemoteFrameBuffer, ev: dict) -> None: + nonlocal active_btn + + intercepted = False etype = ev["event_type"] if etype == "pointer_move": - mme = MouseMoveEvent(x=ev["x"], y=ev["y"]) - receiver.on_mouse_move(mme) + mme = MouseMoveEvent(x=ev["x"], y=ev["y"], btn=active_btn) + intercepted |= receiver.on_mouse_move(mme) + if cursor := receiver.get_cursor(mme): + canvas.cursor = cursor.to_jupyter() receiver.mouseMoved.emit(mme) elif etype == "pointer_down": - mpe = MousePressEvent(x=ev["x"], y=ev["y"]) - receiver.on_mouse_press(mpe) + if "button" in ev: + active_btn = JupyterAppWrap.mouse_btn(ev["button"]) + else: + active_btn = MouseButton.NONE + mpe = MousePressEvent(x=ev["x"], y=ev["y"], btn=active_btn) + intercepted |= receiver.on_mouse_press(mpe) receiver.mousePressed.emit(mpe) elif etype == "pointer_up": - mre = MouseReleaseEvent(x=ev["x"], y=ev["y"]) - receiver.on_mouse_release(mre) + mre = MouseReleaseEvent(x=ev["x"], y=ev["y"], btn=active_btn) + active_btn = MouseButton.NONE + intercepted |= receiver.on_mouse_release(mre) receiver.mouseReleased.emit(mre) - super_handle_event(ev) + + if not intercepted: + super_handle_event(ev) canvas.handle_event = MethodType(handle_event, canvas) return lambda: setattr(canvas, "handle_event", super_handle_event) diff --git a/src/ndv/views/_jupyter/_array_view.py b/src/ndv/views/_jupyter/_array_view.py index d4afd6b4..3ffd4375 100644 --- a/src/ndv/views/_jupyter/_array_view.py +++ b/src/ndv/views/_jupyter/_array_view.py @@ -9,6 +9,7 @@ from ndv.models._array_display_model import ChannelMode from ndv.models._lut_model import ClimPolicy, ClimsManual, ClimsMinMax +from ndv.models._viewer_model import ArrayViewerModel, InteractionMode from ndv.views.bases import ArrayView, LutView if TYPE_CHECKING: @@ -152,7 +153,10 @@ def __init__( self, canvas_widget: _jupyter_rfb.CanvasBackend, data_model: _ArrayDataDisplayModel, + viewer_model: ArrayViewerModel, ) -> None: + self._viewer_model = viewer_model + self._viewer_model.events.interaction_mode.connect(self._on_model_mode_changed) # WIDGETS self._data_model = data_model self._canvas_widget = canvas_widget @@ -198,6 +202,17 @@ def __init__( ) self._ndims_btn.observe(self._on_ndims_toggled, names="value") + # Add ROI button + self._add_roi_btn = widgets.ToggleButton( + value=False, + description="New ROI", + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Adds a new Rectangular ROI.", + icon="square", + ) + + self._add_roi_btn.observe(self._on_add_roi_button_toggle, names="value") + # LAYOUT top_row = widgets.HBox( @@ -216,7 +231,12 @@ def __init__( width = "604px" btns = widgets.HBox( - [self._channel_mode_combo, self._ndims_btn, self._reset_zoom_btn], + [ + self._channel_mode_combo, + self._ndims_btn, + self._add_roi_btn, + self._reset_zoom_btn, + ], layout=widgets.Layout(justify_content="flex-end"), ) self.layout = widgets.VBox( @@ -326,6 +346,19 @@ def _on_channel_mode_changed(self, change: dict[str, Any]) -> None: """Emit signal when the channel mode changes.""" self.channelModeChanged.emit(ChannelMode(change["new"])) + def _on_add_roi_button_toggle(self, change: dict[str, Any]) -> None: + """Emit signal when the channel mode changes.""" + self._viewer_model.interaction_mode = ( + InteractionMode.CREATE_ROI if change["new"] else InteractionMode.PAN_ZOOM + ) + + def _on_model_mode_changed( + self, new: InteractionMode, old: InteractionMode + ) -> None: + # If leaving CanvasMode.CREATE_ROI, uncheck the ROI button + if old == InteractionMode.CREATE_ROI: + self._add_roi_btn.value = False + def add_histogram(self, widget: Any) -> None: """Add a histogram widget to the viewer.""" warnings.warn("Histograms are not supported in Jupyter frontend", stacklevel=2) diff --git a/src/ndv/views/_pygfx/_array_canvas.py b/src/ndv/views/_pygfx/_array_canvas.py index f09fbeb2..555d0681 100755 --- a/src/ndv/views/_pygfx/_array_canvas.py +++ b/src/ndv/views/_pygfx/_array_canvas.py @@ -3,16 +3,24 @@ import warnings from contextlib import suppress from typing import TYPE_CHECKING, Any, Callable, Literal, cast -from weakref import WeakKeyDictionary +from weakref import ReferenceType, WeakKeyDictionary, ref import cmap as _cmap import numpy as np import pygfx import pylinalg as la -from ndv._types import CursorType +from ndv._types import ( + CursorType, + MouseButton, + MouseMoveEvent, + MousePressEvent, + MouseReleaseEvent, +) +from ndv.models._viewer_model import ArrayViewerModel, InteractionMode from ndv.views._app import filter_mouse_events from ndv.views.bases import ArrayCanvas, CanvasElement, ImageHandle +from ndv.views.bases._graphics._canvas_elements import RectangularROIHandle, ROIMoveMode if TYPE_CHECKING: from collections.abc import Sequence @@ -22,8 +30,9 @@ from pygfx.resources import Texture from wgpu.gui.jupyter import JupyterWgpuCanvas from wgpu.gui.qt import QWgpuCanvas + from wgpu.gui.wx import WxWgpuCanvas - WgpuCanvas: TypeAlias = QWgpuCanvas | JupyterWgpuCanvas + WgpuCanvas: TypeAlias = QWgpuCanvas | JupyterWgpuCanvas | WxWgpuCanvas def _is_inside(bounding_box: np.ndarray, pos: Sequence[float]) -> bool: @@ -97,215 +106,86 @@ def remove(self) -> None: if (par := self._image.parent) is not None: par.remove(self._image) - def cursor_at(self, pos: Sequence[float]) -> CursorType | None: + def get_cursor(self, mme: MouseMoveEvent) -> CursorType | None: return None -class PyGFXRoiHandle(pygfx.WorldObject): - _render: Callable = lambda _: None +class PyGFXRectangle(RectangularROIHandle): + def __init__( + self, + render: Callable, + canvas_to_world: Callable, + parent: pygfx.WorldObject | None = None, + *args: Any, + **kwargs: Any, + ) -> None: + # Positional array backing visual objects + # NB we need five points for the outline + # The first and last rows should be identical + self._positions: np.ndarray = np.zeros((5, 3), dtype=np.float32) - def __init__(self, render: Callable, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, *kwargs) + # Visual objects self._fill = self._create_fill() - if self._fill: - self.add(self._fill) self._outline = self._create_outline() - if self._outline: - self.add(self._outline) - self._handles = self._create_handles() - if self._handles: - self.add(self._handles) - - self._render = render - - def _create_fill(self) -> pygfx.Mesh | None: - # To be implemented by subclasses needing a fill - return None - - def _create_outline(self) -> pygfx.Line | None: - # To be implemented by subclasses needing an outline - return None - - def _create_handles(self) -> pygfx.Points | None: - # To be implemented by subclasses needing handles - return None - - @property - def vertices(self) -> Sequence[Sequence[float]]: - # To be implemented by subclasses - raise NotImplementedError("Must be implemented in subclasses") - - @vertices.setter - def vertices(self, data: Sequence[Sequence[float]]) -> None: - # To be implemented by subclasses - raise NotImplementedError("Must be implemented in subclasses") - - @property - def visible(self) -> bool: - if self._outline: - return bool(self._outline.visible) - if self._fill: - return bool(self._fill.visible) - # Nothing to see - return False - - @visible.setter - def visible(self, visible: bool) -> None: - if fill := getattr(self, "_fill", None): - fill.visible = visible - if outline := getattr(self, "_outline", None): - outline.visible = visible - if handles := getattr(self, "_handles", None): - handles.visible = self.selected - self._render() - def can_select(self) -> bool: - return True - - def selected(self) -> bool: - if self._handles: - return bool(self._handles.visible) - # Can't be selected without handles - return False + # Handles used for ROI manipulation + self._handle_rad = 5 # PIXELS + self._handles = self._create_handles() - def set_selected(self, selected: bool) -> None: - if self._handles: - self._handles.visible = selected + # containing all ROI objects makes selection easier. + self._container = pygfx.WorldObject(*args, **kwargs) + self._container.add(self._fill, self._outline, self._handles) + if parent: + parent.add(self._container) + + # Utilities for moving ROI + self._selected = False + self._move_mode: ROIMoveMode | None = None + # NB _move_anchor has different meanings depending on _move_mode + self._move_anchor: tuple[float, float] = (0, 0) + self._render: Callable = render + self._canvas_to_world: Callable = canvas_to_world + + # Initialize + self.set_fill(_cmap.Color("transparent")) + self.set_border(_cmap.Color("yellow")) + self.set_handles(_cmap.Color("white")) + self.set_visible(False) + + # -- BoundingBox methods -- # + + def set_bounding_box( + self, minimum: tuple[float, float], maximum: tuple[float, float] + ) -> None: + # NB: Support two diagonal points, not necessarily true min/max + x1 = float(min(minimum[0], maximum[0])) + y1 = float(min(minimum[1], maximum[1])) + x2 = float(max(minimum[0], maximum[0])) + y2 = float(max(minimum[1], maximum[1])) - def color(self) -> Any: - if self._fill: - return _cmap.Color(self._fill.material.color) - return _cmap.Color("transparent") + # Update each handle + self._positions[0, :2] = [x1, y1] + self._positions[1, :2] = [x2, y1] + self._positions[2, :2] = [x2, y2] + self._positions[3, :2] = [x1, y2] + self._positions[4, :2] = [x1, y1] + self._refresh() - def set_color(self, color: _cmap.Color | None = None) -> None: + def set_fill(self, color: _cmap.Color) -> None: if self._fill: - if color is None: - color = _cmap.Color("transparent") - if not isinstance(color, _cmap.Color): - color = _cmap.Color(color) self._fill.material.color = color.rgba self._render() - def border_color(self) -> Any: - if self._outline: - return _cmap.Color(self._outline.material.color) - return _cmap.Color("transparent") - - def set_border_color(self, color: _cmap.Color | None = None) -> None: + def set_border(self, color: _cmap.Color) -> None: if self._outline: - if color is None: - color = _cmap.Color("yellow") - if not isinstance(color, _cmap.Color): - color = _cmap.Color(color) self._outline.material.color = color.rgba self._render() - def start_move(self, pos: Sequence[float]) -> None: - # To be implemented by subclasses - raise NotImplementedError("Must be implemented in subclasses") - - def move(self, pos: Sequence[float]) -> None: - # To be implemented by subclasses - raise NotImplementedError("Must be implemented in subclasses") - - def remove(self) -> None: - if (par := self.parent) is not None: - par.remove(self) - - def cursor_at(self, pos: Sequence[float]) -> CursorType | None: - # To be implemented by subclasses - raise NotImplementedError("Must be implemented in subclasses") - - -class RectangularROIHandle(PyGFXRoiHandle): - def __init__( - self, render: Callable, canvas_to_world: Callable, *args: Any, **kwargs: Any - ) -> None: - self._point_rad = 5 # PIXELS - self._positions: np.ndarray = np.zeros((5, 3), dtype=np.float32) - - super().__init__(render, *args, *kwargs) - self._canvas_to_world = canvas_to_world - - # drag_reference defines the offset between where the user clicks and the center - # of the rectangle - self._drag_idx: int | None = None - self._offset = np.zeros((5, 2)) - self._on_drag = [ - self._move_handle_0, - self._move_handle_1, - self._move_handle_2, - self._move_handle_3, - ] - - @property - def vertices(self) -> Sequence[Sequence[float]]: - # Buffer object - return [p[:2] for p in self._positions] - - @vertices.setter - def vertices(self, vertices: Sequence[Sequence[float]]) -> None: - if len(vertices) != 4 or any(len(v) != 2 for v in vertices): - raise Exception("Only 2D rectangles are currently supported") - is_aligned = ( - vertices[0][1] == vertices[1][1] - and vertices[1][0] == vertices[2][0] - and vertices[2][1] == vertices[3][1] - and vertices[3][0] == vertices[0][0] - ) - if not is_aligned: - raise Exception( - "Only rectangles aligned with the axes are currently supported" - ) - - # Update each handle - self._positions[:-1, :2] = vertices - self._positions[-1, :2] = vertices[0] - self._refresh() - - def start_move(self, pos: Sequence[float]) -> None: - self._drag_idx = self._handle_hover_idx(pos) - - if self._drag_idx is None: - self._offset[:, :] = self._positions[:, :2] - pos[:2] - - def move(self, pos: Sequence[float]) -> None: - if self._drag_idx is not None: - self._on_drag[self._drag_idx](pos) - else: - # TODO: We could potentially do this smarter via transforms - self._positions[:, :2] = self._offset[:, :2] + pos[:2] - self._refresh() - - def _move_handle_0(self, pos: Sequence[float]) -> None: - # NB pygfx requires (idx 0) = (idx 4) - self._positions[0, :2] = pos[:2] - self._positions[4, :2] = pos[:2] - - self._positions[3, 0] = pos[0] - self._positions[1, 1] = pos[1] - - def _move_handle_1(self, pos: Sequence[float]) -> None: - self._positions[1, :2] = pos[:2] - - self._positions[2, 0] = pos[0] - # NB pygfx requires (idx 0) = (idx 4) - self._positions[0, 1] = pos[1] - self._positions[4, 1] = pos[1] - - def _move_handle_2(self, pos: Sequence[float]) -> None: - self._positions[2, :2] = pos[:2] - - self._positions[1, 0] = pos[0] - self._positions[3, 1] = pos[1] - - def _move_handle_3(self, pos: Sequence[float]) -> None: - self._positions[3, :2] = pos[:2] - - # NB pygfx requires (idx 0) = (idx 4) - self._positions[0, 0] = pos[0] - self._positions[4, 0] = pos[0] - self._positions[2, 1] = pos[1] + # TODO: Misleading name? + def set_handles(self, color: _cmap.Color) -> None: + if self._handles: + self._handles.material.color = color.rgba + self._render() def _create_fill(self) -> pygfx.Mesh | None: fill = pygfx.Mesh( @@ -313,7 +193,7 @@ def _create_fill(self) -> pygfx.Mesh | None: positions=self._positions, indices=np.array([[0, 1, 2, 3]], dtype=np.int32), ), - material=pygfx.MeshBasicMaterial(color=(0, 0, 0, 0)), + material=pygfx.MeshBasicMaterial(), ) return fill @@ -323,7 +203,7 @@ def _create_outline(self) -> pygfx.Line | None: positions=self._positions, indices=np.array([[0, 1, 2, 3]], dtype=np.int32), ), - material=pygfx.LineMaterial(thickness=1, color=(0, 0, 0, 0)), + material=pygfx.LineMaterial(thickness=1), ) return outline @@ -333,7 +213,7 @@ def _create_handles(self) -> pygfx.Points | None: geometry=geometry, # FIXME Size in pixels is not ideal for selection. # TODO investigate what size_mode = vertex does... - material=pygfx.PointsMaterial(color=(1, 1, 1), size=1.5 * self._point_rad), + material=pygfx.PointsMaterial(size=1.5 * self._handle_rad), ) # NB: Default bounding box for points does not consider the radius of @@ -341,8 +221,8 @@ def _create_handles(self) -> pygfx.Points | None: def get_handle_bb(old: Callable[[], np.ndarray]) -> Callable[[], np.ndarray]: def new_get_bb() -> np.ndarray: bb = old().copy() - bb[0, :2] -= self._point_rad - bb[1, :2] += self._point_rad + bb[0, :2] -= self._handle_rad + bb[1, :2] += self._handle_rad return bb return new_get_bb @@ -350,6 +230,17 @@ def new_get_bb() -> np.ndarray: geometry.get_bounding_box = get_handle_bb(geometry.get_bounding_box) return handles + def can_select(self) -> bool: + return True + + def selected(self) -> bool: + return self._selected + + def set_selected(self, selected: bool) -> None: + self._selected = selected + if self._handles: + self._handles.visible = selected + def _refresh(self) -> None: if self._fill: self._fill.geometry.positions.data[:, :] = self._positions @@ -362,32 +253,99 @@ def _refresh(self) -> None: self._handles.geometry.positions.update_range() self._render() - def _handle_hover_idx(self, pos: Sequence[float]) -> int | None: + def on_mouse_move(self, event: MouseMoveEvent) -> bool: + # Convert canvas -> world + world_pos = tuple(self._canvas_to_world((event.x, event.y))[:2]) + # moving a handle + if self._move_mode == ROIMoveMode.HANDLE: + # The anchor is set to the opposite handle, which never moves. + self.boundingBoxChanged.emit((world_pos, self._move_anchor)) + # translating the whole roi + elif self._move_mode == ROIMoveMode.TRANSLATE: + # The anchor is the mouse position reported in the previous mouse event. + dx = world_pos[0] - self._move_anchor[0] + dy = world_pos[1] - self._move_anchor[1] + # If the mouse moved (dx, dy) between events, the whole ROI needs to be + # translated that amount. + new_min = (self._positions[0, 0] + dx, self._positions[0, 1] + dy) + new_max = (self._positions[2, 0] + dx, self._positions[2, 1] + dy) + self.boundingBoxChanged.emit((new_min, new_max)) + self._move_anchor = world_pos + + return False + + def on_mouse_press(self, event: MousePressEvent) -> bool: + self.set_selected(True) + # Convert canvas -> world + world_pos = self._canvas_to_world((event.x, event.y)) + drag_idx = self._handle_under(world_pos) + # If a marker is pressed + if drag_idx is not None: + opposite_idx = (drag_idx + 2) % 4 + self._move_mode = ROIMoveMode.HANDLE + self._move_anchor = tuple(self._positions[opposite_idx, :2].copy()) + # If the rectangle is pressed + else: + self._move_mode = ROIMoveMode.TRANSLATE + self._move_anchor = world_pos + return False + + def on_mouse_release(self, event: MouseReleaseEvent) -> bool: + return False + + def visible(self) -> bool: + if self._outline: + return bool(self._outline.visible) + if self._fill: + return bool(self._fill.visible) + # Nothing to see + return False + + def set_visible(self, visible: bool) -> None: + if fill := getattr(self, "_fill", None): + fill.visible = visible + if outline := getattr(self, "_outline", None): + outline.visible = visible + if handles := getattr(self, "_handles", None): + handles.visible = visible and self.selected() + self._render() + + def _handle_under(self, pos: Sequence[float]) -> int | None: + """Returns an int in [0, 3], or None. + + If an int i, means that the handle at self._positions[i] is at pos. + If None, there is no handle at pos. + """ # FIXME: Ideally, Renderer.get_pick_info would do this for us. But it # seems broken. for i, p in enumerate(self._positions[:-1]): - if (p[0] - pos[0]) ** 2 + (p[1] - pos[1]) ** 2 <= self._point_rad**2: + if (p[0] - pos[0]) ** 2 + (p[1] - pos[1]) ** 2 <= self._handle_rad**2: return i return None - def cursor_at(self, canvas_pos: Sequence[float]) -> CursorType | None: - # Convert canvas -> world - world_pos = self._canvas_to_world(canvas_pos) - # Step 1: Check if over handle - if (idx := self._handle_hover_idx(world_pos)) is not None: - if np.array_equal( - self._positions[idx], self._positions.min(axis=0) - ) or np.array_equal(self._positions[idx], self._positions.max(axis=0)): + def get_cursor(self, mme: MouseMoveEvent) -> CursorType | None: + # Convert event pos (on canvas) to world pos + world_pos = self._canvas_to_world((mme.x, mme.y)) + # Step 1: Handles + # Preferred over the rectangle + # Can only be moved if ROI is selected + if (idx := self._handle_under(world_pos)) is not None and self.selected(): + # Idx 0 is top left, 2 is bottom right + if idx % 2 == 0: return CursorType.FDIAG_ARROW + # Idx 1 is bottom left, 3 is top right return CursorType.BDIAG_ARROW - - # Step 2: Check if over ROI + # Step 2: Entire ROI if self._outline: roi_bb = self._outline.geometry.get_bounding_box() if _is_inside(roi_bb, world_pos): return CursorType.ALL_ARROW return None + def remove(self) -> None: + if (par := self._container.parent) is not None: + par.remove(self._container) + def get_canvas_class() -> WgpuCanvas: from ndv.views._app import GuiFrontend, gui_frontend @@ -418,7 +376,9 @@ def sizeHint(self) -> QSize: class GfxArrayCanvas(ArrayCanvas): """pygfx-based canvas wrapper.""" - def __init__(self) -> None: + def __init__(self, viewer_model: ArrayViewerModel) -> None: + self._viewer = viewer_model + self._current_shape: tuple[int, ...] = () self._last_state: dict[Literal[2, 3], Any] = {} @@ -446,6 +406,9 @@ def __init__(self) -> None: self._ndim: Literal[2, 3] | None = None self._elements = WeakKeyDictionary[pygfx.WorldObject, CanvasElement]() + self._selection: CanvasElement | None = None + # Maintain a weak reference to the last ROI created. + self._last_roi_created: ReferenceType[PyGFXRectangle] | None = None def frontend_widget(self) -> Any: return self._canvas @@ -541,23 +504,17 @@ def add_volume(self, data: np.ndarray | None = None) -> PyGFXImageHandle: self._elements[vol] = handle return handle - def add_roi( - self, - vertices: Sequence[tuple[float, float]] | None = None, - color: _cmap.Color | None = None, - border_color: _cmap.Color | None = None, - ) -> PyGFXRoiHandle: + def add_bounding_box(self) -> PyGFXRectangle: """Add a new Rectangular ROI node to the scene.""" - handle = RectangularROIHandle(self.refresh, self.canvas_to_world) - handle.visible = False - self._scene.add(handle) - if vertices: - handle.vertices = vertices - handle.set_color(color) - handle.set_border_color(border_color) - - self._elements[handle] = handle - return handle + roi = PyGFXRectangle( + render=self.refresh, + canvas_to_world=self.canvas_to_world, + parent=self._scene, + ) + roi.set_visible(False) + self._elements[roi._container] = roi + self._last_roi_created = ref(roi) + return roi def set_range( self, @@ -637,8 +594,8 @@ def elements_at(self, pos_xy: tuple[float, float]) -> list[CanvasElement]: pos = self.canvas_to_world((pos_xy[0], pos_xy[1])) for c in self._scene.children: bb = c.get_bounding_box() - if _is_inside(bb, pos) and (element := self._elements.get(c)): - elements.append(element) + if _is_inside(bb, pos): + elements.append(self._elements[c]) return elements def set_visible(self, visible: bool) -> None: @@ -648,3 +605,61 @@ def set_visible(self, visible: bool) -> None: def close(self) -> None: self._disconnect_mouse_events() self._canvas.close() + + def on_mouse_press(self, event: MousePressEvent) -> bool: + if self._selection: + self._selection.set_selected(False) + self._selection = None + canvas_pos = (event.x, event.y) + world_pos = self.canvas_to_world(canvas_pos)[:2] + + # If in CREATE_ROI mode, the new ROI should "start" here. + if self._viewer.interaction_mode == InteractionMode.CREATE_ROI: + if self._last_roi_created is None: + raise ValueError("No ROI to create!") + if new_roi := self._last_roi_created(): + self._last_roi_created = None + # HACK: Provide a non-zero starting size so that if the user clicks + # and immediately releases, it's visible and can be selected again + _min = world_pos + _max = (world_pos[0] + 1, world_pos[1] + 1) + # Put the ROI where the user clicked + new_roi.boundingBoxChanged.emit((_min, _max)) + # Make it visible + new_roi.set_visible(True) + # Select it so the mouse press event below triggers ROIMoveMode.HANDLE + # TODO: Make behavior more direct + new_roi.set_selected(True) + + # All done - exit the mode + self._viewer.interaction_mode = InteractionMode.PAN_ZOOM + + # Select first selectable object at clicked point + for vis in self.elements_at(canvas_pos): + if vis.can_select(): + self._selection = vis + self._selection.on_mouse_press(event) + return False + + return False + + def on_mouse_move(self, event: MouseMoveEvent) -> bool: + if event.btn == MouseButton.LEFT: + if self._selection and self._selection.selected(): + self._selection.on_mouse_move(event) + # If we are moving the object, we don't want to move the camera + return True + return False + + def on_mouse_release(self, event: MouseReleaseEvent) -> bool: + if self._selection: + self._selection.on_mouse_release(event) + return False + + def get_cursor(self, event: MouseMoveEvent) -> CursorType: + if self._viewer.interaction_mode == InteractionMode.CREATE_ROI: + return CursorType.CROSS + for vis in self.elements_at((event.x, event.y)): + if cursor := vis.get_cursor(event): + return cursor + return CursorType.DEFAULT diff --git a/src/ndv/views/_pygfx/_histogram.py b/src/ndv/views/_pygfx/_histogram.py index b39ad0a6..63003ad0 100644 --- a/src/ndv/views/_pygfx/_histogram.py +++ b/src/ndv/views/_pygfx/_histogram.py @@ -405,7 +405,8 @@ def _generate_clim_colors(self, npoints: int) -> np.ndarray: return color - def get_cursor(self, pos: tuple[float, float]) -> CursorType: + def get_cursor(self, mme: MouseMoveEvent) -> CursorType: + pos = (mme.x, mme.y) nearby = self._find_nearby_node(pos) if nearby in [Grabbable.LEFT_CLIM, Grabbable.RIGHT_CLIM]: @@ -460,7 +461,7 @@ def on_mouse_move(self, event: MouseMoveEvent) -> bool: self.model.clims = ClimsManual(min=newlims[0], max=newlims[1]) return False - self.get_cursor(pos).apply_to(self) + self.get_cursor(event).apply_to(self) return False def _find_nearby_node( diff --git a/src/ndv/views/_qt/_app.py b/src/ndv/views/_qt/_app.py index 644f26dc..f1d0b4be 100644 --- a/src/ndv/views/_qt/_app.py +++ b/src/ndv/views/_qt/_app.py @@ -3,11 +3,17 @@ import sys from typing import TYPE_CHECKING, Any, Callable, ClassVar -from qtpy.QtCore import QEvent, QObject +from qtpy.QtCore import QEvent, QObject, Qt from qtpy.QtGui import QMouseEvent -from qtpy.QtWidgets import QApplication - -from ndv._types import MouseMoveEvent, MousePressEvent, MouseReleaseEvent +from qtpy.QtWidgets import QApplication, QWidget + +from ndv._types import ( + CursorType, + MouseButton, + MouseMoveEvent, + MousePressEvent, + MouseReleaseEvent, +) from ndv.views.bases._app import NDVApp if TYPE_CHECKING: @@ -67,8 +73,8 @@ def array_view_class(self) -> type[ArrayView]: def filter_mouse_events( self, canvas: Any, receiver: Mouseable ) -> Callable[[], None]: - if not isinstance(canvas, QObject): - raise TypeError(f"Expected canvas to be QObject, got {type(canvas)}") + if not isinstance(canvas, QWidget): + raise TypeError(f"Expected canvas to be QWidget, got {type(canvas)}") f = MouseEventFilter(canvas, receiver) canvas.installEventFilter(f) @@ -76,10 +82,24 @@ def filter_mouse_events( class MouseEventFilter(QObject): - def __init__(self, canvas: QObject, receiver: Mouseable): + def __init__(self, canvas: QWidget, receiver: Mouseable): super().__init__() self.canvas = canvas self.receiver = receiver + self.active_button = MouseButton.NONE + + def mouse_btn(self, btn: Any) -> MouseButton: + if btn == Qt.MouseButton.LeftButton: + return MouseButton.LEFT + if btn == Qt.MouseButton.RightButton: + return MouseButton.RIGHT + if btn == Qt.MouseButton.NoButton: + return MouseButton.NONE + + raise Exception(f"Qt mouse button {btn} is unknown") + + def set_cursor(self, type: CursorType) -> None: + self.canvas.setCursor(type.to_qt()) def eventFilter(self, obj: QObject | None, qevent: QEvent | None) -> bool: """Event filter installed on the canvas to handle mouse events. @@ -104,16 +124,23 @@ def eventFilter(self, obj: QObject | None, qevent: QEvent | None) -> bool: if isinstance(qevent, QMouseEvent): pos = qevent.pos() etype = qevent.type() + btn = self.mouse_btn(qevent.button()) if etype == QEvent.Type.MouseMove: - mme = MouseMoveEvent(x=pos.x(), y=pos.y()) + mme = MouseMoveEvent(x=pos.x(), y=pos.y(), btn=self.active_button) intercept |= receiver.on_mouse_move(mme) + if cursor := receiver.get_cursor(mme): + self.set_cursor(cursor) receiver.mouseMoved.emit(mme) elif etype == QEvent.Type.MouseButtonPress: - mpe = MousePressEvent(x=pos.x(), y=pos.y()) + self.active_button = btn + mpe = MousePressEvent(x=pos.x(), y=pos.y(), btn=self.active_button) intercept |= receiver.on_mouse_press(mpe) receiver.mousePressed.emit(mpe) elif etype == QEvent.Type.MouseButtonRelease: - mre = MouseReleaseEvent(x=pos.x(), y=pos.y()) + mre = MouseReleaseEvent( + x=pos.x(), y=pos.y(), btn=self.active_button + ) + self.active_button = MouseButton.NONE intercept |= receiver.on_mouse_release(mre) receiver.mouseReleased.emit(mre) return intercept diff --git a/src/ndv/views/_qt/_array_view.py b/src/ndv/views/_qt/_array_view.py index 78ae0f53..f678fe35 100644 --- a/src/ndv/views/_qt/_array_view.py +++ b/src/ndv/views/_qt/_array_view.py @@ -25,8 +25,10 @@ from superqt.iconify import QIconifyIcon from superqt.utils import signals_blocked +from ndv._types import AxisKey from ndv.models._array_display_model import ChannelMode from ndv.models._lut_model import ClimPolicy, ClimsManual, ClimsMinMax +from ndv.models._viewer_model import ArrayViewerModel, InteractionMode from ndv.views.bases import ArrayView, LutView if TYPE_CHECKING: @@ -37,6 +39,10 @@ from ndv._types import AxisKey from ndv.models._data_display_model import _ArrayDataDisplayModel + from ndv.views.bases._graphics._canvas_elements import ( + CanvasElement, + RectangularROIHandle, + ) SLIDER_STYLE = """ QSlider::groove:horizontal { @@ -220,6 +226,14 @@ def _on_q_auto_changed(self, autoscale: bool) -> None: self._model.clims = ClimsManual(min=clims[0], max=clims[1]) +class ROIButton(QPushButton): + def __init__(self, parent: QWidget | None = None): + super().__init__(parent) + self.setCheckable(True) + self.setToolTip("Add ROI") + self.setIcon(QIconifyIcon("mdi:vector-rectangle")) + + class _QDimsSliders(QWidget): currentIndexChanged = Signal() @@ -237,15 +251,20 @@ def create_sliders(self, coords: Mapping[Hashable, Sequence]) -> None: """Update sliders with the given coordinate ranges.""" layout = cast("QFormLayout", self.layout()) for axis, _coords in coords.items(): - sld = QLabeledSlider(Qt.Orientation.Horizontal) - sld.valueChanged.connect(self.currentIndexChanged) + # Create a slider for axis if necessary + if axis not in self._sliders: + sld = QLabeledSlider(Qt.Orientation.Horizontal) + sld.valueChanged.connect(self.currentIndexChanged) + layout.addRow(str(axis), sld) + self._sliders[axis] = sld + + # Update axis slider with coordinates + sld = self._sliders[axis] if isinstance(_coords, range): sld.setRange(_coords.start, _coords.stop - 1) sld.setSingleStep(_coords.step) else: sld.setRange(0, len(_coords) - 1) - layout.addRow(str(axis), sld) - self._sliders[axis] = sld self.currentIndexChanged.emit() def hide_dimensions( @@ -356,6 +375,11 @@ def __init__(self, canvas_widget: QWidget, parent: QWidget | None = None): add_histogram_icon = QIconifyIcon("foundation:graph-bar") self.histogram_btn = QPushButton(add_histogram_icon, "", self) + # button to draw ROIs + self._roi_handle: RectangularROIHandle | None = None + self._selection: CanvasElement | None = None + self.add_roi_btn = ROIButton() + self.luts = _UpCollapsible( "LUTs", parent=self, @@ -372,8 +396,8 @@ def __init__(self, canvas_widget: QWidget, parent: QWidget | None = None): self._btn_layout.addWidget(self.channel_mode_combo) self._btn_layout.addWidget(self.ndims_btn) self._btn_layout.addWidget(self.histogram_btn) + self._btn_layout.addWidget(self.add_roi_btn) self._btn_layout.addWidget(self.set_range_btn) - # self._btns.addWidget(self._add_roi_btn) # above the canvas info_widget = QWidget() @@ -417,11 +441,17 @@ def closeEvent(self, a0: Any) -> None: class QtArrayView(ArrayView): def __init__( - self, canvas_widget: QWidget, data_model: _ArrayDataDisplayModel + self, + canvas_widget: QWidget, + data_model: _ArrayDataDisplayModel, + viewer_model: ArrayViewerModel, ) -> None: self._data_model = data_model + self._viewer_model = viewer_model self._qwidget = qwdg = _QArrayViewer(canvas_widget) qwdg.histogram_btn.clicked.connect(self._on_add_histogram_clicked) + qwdg.add_roi_btn.toggled.connect(self._on_add_roi_clicked) + self._viewer_model.events.interaction_mode.connect(self._on_model_mode_changed) # TODO: use emit_fast qwdg.dims_sliders.currentIndexChanged.connect(self.currentIndexChanged.emit) @@ -454,6 +484,13 @@ def _on_add_histogram_clicked(self) -> None: else: self.histogramRequested.emit() + def _on_model_mode_changed( + self, new: InteractionMode, old: InteractionMode + ) -> None: + # If leaving CanvasMode.CREATE_ROI, uncheck the ROI button + if old == InteractionMode.CREATE_ROI: + self._qwidget.add_roi_btn.setChecked(False) + def add_histogram(self, widget: QWidget) -> None: if hasattr(self, "_hist"): raise RuntimeError("Only one histogram can be added at a time") @@ -530,3 +567,8 @@ def frontend_widget(self) -> QWidget: def set_progress_spinner_visible(self, visible: bool) -> None: self._qwidget._progress_spinner.setVisible(visible) + + def _on_add_roi_clicked(self, checked: bool) -> None: + self._viewer_model.interaction_mode = ( + InteractionMode.CREATE_ROI if checked else InteractionMode.PAN_ZOOM + ) diff --git a/src/ndv/views/_vispy/_array_canvas.py b/src/ndv/views/_vispy/_array_canvas.py index d6e8d2f2..b99e453c 100755 --- a/src/ndv/views/_vispy/_array_canvas.py +++ b/src/ndv/views/_vispy/_array_canvas.py @@ -3,253 +3,43 @@ import warnings from contextlib import suppress from typing import TYPE_CHECKING, Any, Literal, cast -from weakref import WeakKeyDictionary +from weakref import ReferenceType, WeakKeyDictionary import cmap as _cmap import numpy as np import vispy +import vispy.app +import vispy.color import vispy.scene import vispy.visuals from vispy import scene -from vispy.color import Color from vispy.util.quaternion import Quaternion -from ndv._types import CursorType +from ndv._types import ( + CursorType, + MouseButton, + MouseMoveEvent, + MousePressEvent, + MouseReleaseEvent, +) +from ndv.models._viewer_model import ArrayViewerModel, InteractionMode from ndv.views._app import filter_mouse_events from ndv.views.bases import ArrayCanvas from ndv.views.bases._graphics._canvas_elements import ( CanvasElement, ImageHandle, - RoiHandle, + RectangularROIHandle, + ROIMoveMode, ) if TYPE_CHECKING: from collections.abc import Sequence - from typing import Callable - - import vispy.app turn = np.sin(np.pi / 4) DEFAULT_QUATERNION = Quaternion(turn, turn, 0, 0) -class Handle(scene.visuals.Markers): - """A Marker that allows specific ROI alterations.""" - - def __init__( - self, - parent: RectangularROI, - on_move: Callable[[Sequence[float]], None] | None = None, - cursor: CursorType - | Callable[[Sequence[float]], CursorType] = CursorType.ALL_ARROW, - ) -> None: - super().__init__(parent=parent) - self.unfreeze() - self.parent = parent - # on_move function(s) - self.on_move: list[Callable[[Sequence[float]], None]] = [] - if on_move: - self.on_move.append(on_move) - # cusror preference function - if not callable(cursor): - - def cursor(_: Any) -> CursorType: - return cursor - - self._cursor_at = cursor - self._selected = False - # NB VisPy asks that the data is a 2D array - self._pos = np.array([[0, 0]], dtype=np.float32) - self.interactive = True - self.freeze() - - def start_move(self, pos: Sequence[float]) -> None: - pass - - def move(self, pos: Sequence[float]) -> None: - for func in self.on_move: - func(pos) - - @property - def pos(self) -> Sequence[float]: - return cast("Sequence[float]", self._pos[0, :]) - - @pos.setter - def pos(self, pos: Sequence[float]) -> None: - self._pos[:] = pos[:2] - self.set_data(self._pos) - - @property - def selected(self) -> bool: - return self._selected - - @selected.setter - def selected(self, selected: bool) -> None: - self._selected = selected - self.parent.selected = selected - - def cursor_at(self, pos: Sequence[float]) -> CursorType | None: - return self._cursor_at(self.pos) - - -class RectangularROI(scene.visuals.Rectangle): - """A VisPy Rectangle visual whose attributes can be edited.""" - - def __init__( - self, - parent: scene.visuals.Visual, - center: list[float] | None = None, - width: float = 1e-6, - height: float = 1e-6, - ) -> None: - if center is None: - center = [0, 0] - scene.visuals.Rectangle.__init__( - self, center=center, width=width, height=height, radius=0, parent=parent - ) - self.unfreeze() - self.parent = parent - self.interactive = True - - self._handles = [ - Handle( - self, - on_move=self.move_top_left, - cursor=self._handle_cursor_pref, - ), - Handle( - self, - on_move=self.move_top_right, - cursor=self._handle_cursor_pref, - ), - Handle( - self, - on_move=self.move_bottom_right, - cursor=self._handle_cursor_pref, - ), - Handle( - self, - on_move=self.move_bottom_left, - cursor=self._handle_cursor_pref, - ), - ] - - # drag_reference defines the offset between where the user clicks and the center - # of the rectangle - self.drag_reference = [0.0, 0.0] - self.interactive = True - self._selected = False - self.freeze() - - def _handle_cursor_pref(self, handle_pos: Sequence[float]) -> CursorType: - # Bottom left handle - if handle_pos[0] < self.center[0] and handle_pos[1] < self.center[1]: - return CursorType.FDIAG_ARROW - # Top right handle - if handle_pos[0] > self.center[0] and handle_pos[1] > self.center[1]: - return CursorType.FDIAG_ARROW - # Top left, bottom right - return CursorType.BDIAG_ARROW - - def move_top_left(self, pos: Sequence[float]) -> None: - self._handles[3].pos = [pos[0], self._handles[3].pos[1]] - self._handles[0].pos = pos - self._handles[1].pos = [self._handles[1].pos[0], pos[1]] - self.redraw() - - def move_top_right(self, pos: Sequence[float]) -> None: - self._handles[0].pos = [self._handles[0].pos[0], pos[1]] - self._handles[1].pos = pos - self._handles[2].pos = [pos[0], self._handles[2].pos[1]] - self.redraw() - - def move_bottom_right(self, pos: Sequence[float]) -> None: - self._handles[1].pos = [pos[0], self._handles[1].pos[1]] - self._handles[2].pos = pos - self._handles[3].pos = [self._handles[3].pos[0], pos[1]] - self.redraw() - - def move_bottom_left(self, pos: Sequence[float]) -> None: - self._handles[2].pos = [self._handles[2].pos[0], pos[1]] - self._handles[3].pos = pos - self._handles[0].pos = [pos[0], self._handles[0].pos[1]] - self.redraw() - - def redraw(self) -> None: - left, top, *_ = self._handles[0].pos - right, bottom, *_ = self._handles[2].pos - - self.center = [(left + right) / 2, (top + bottom) / 2] - self.width = max(abs(left - right), 1e-6) - self.height = max(abs(top - bottom), 1e-6) - - # --------------------- EditableROI interface -------------------------- - # In the future, if any other objects implement these same methods, this - # could be extracted into an ABC. - - @property - def vertices(self) -> Sequence[Sequence[float]]: - return [h.pos for h in self._handles] - - @vertices.setter - def vertices(self, vertices: Sequence[Sequence[float]]) -> None: - if len(vertices) != 4 or any(len(v) != 2 for v in vertices): - raise Exception("Only 2D rectangles are currently supported") - is_aligned = ( - vertices[0][1] == vertices[1][1] - and vertices[1][0] == vertices[2][0] - and vertices[2][1] == vertices[3][1] - and vertices[3][0] == vertices[0][0] - ) - if not is_aligned: - raise Exception( - "Only rectangles aligned with the axes are currently supported" - ) - - # Update each handle - for i, handle in enumerate(self._handles): - handle.pos = vertices[i] - # Redraw - self.redraw() - - @property - def selected(self) -> bool: - return self._selected - - @selected.setter - def selected(self, selected: bool) -> None: - self._selected = selected - for h in self._handles: - h.visible = selected - - def start_move(self, pos: Sequence[float]) -> None: - self.drag_reference = [ - pos[0] - self.center[0], - pos[1] - self.center[1], - ] - - def move(self, pos: Sequence[float]) -> None: - new_center = [ - pos[0] - self.drag_reference[0], - pos[1] - self.drag_reference[1], - ] - old_center = self.center - # TODO: Simplify - for h in self._handles: - existing_pos = h.pos - h.pos = [ - existing_pos[0] + new_center[0] - old_center[0], - existing_pos[1] + new_center[1] - old_center[1], - ] - self.center = new_center - - def cursor_at(self, pos: Sequence[float]) -> CursorType | None: - return CursorType.ALL_ARROW - - # ------------------- End EditableROI interface ------------------------- - - class VispyImageHandle(ImageHandle): def __init__(self, visual: scene.Image | scene.Volume) -> None: self._visual = visual @@ -322,99 +112,174 @@ def move(self, pos: Sequence[float]) -> None: def remove(self) -> None: self._visual.parent = None - def cursor_at(self, pos: Sequence[float]) -> CursorType | None: + def get_cursor(self, mme: MouseMoveEvent) -> CursorType | None: return None -# FIXME: Unfortunate naming :) -class VispyHandleHandle(CanvasElement): - def __init__(self, handle: Handle, parent: CanvasElement) -> None: - self._handle = handle - self._parent = parent +class VispyRectangle(RectangularROIHandle): + def __init__(self, parent: Any) -> None: + self._selected = False + self._move_mode: ROIMoveMode | None = None + # NB _move_anchor has different meanings depending on _move_mode + self._move_anchor: tuple[float, float] = (0, 0) + + # Rectangle handles both fill and border + self._rect = scene.Rectangle(center=[0, 0], width=1, height=1, parent=parent) + # NB: Should be greater than image orders BUT NOT handle order + self._rect.order = 10 + self._rect.interactive = True + + self._handle_data = np.zeros((4, 2)) + self._handle_size = 10 # px + self._handles = scene.Markers( + pos=self._handle_data, + size=self._handle_size, + scaling="fixed", + parent=parent, + ) + # NB: Should be greater than image orders and rect order + self._handles.order = 100 + self._handles.interactive = True - def visible(self) -> bool: - return cast("bool", self._handle.visible) + self.set_fill(_cmap.Color("transparent")) + self.set_border(_cmap.Color("yellow")) + self.set_handles(_cmap.Color("white")) + self.set_visible(False) - def set_visible(self, visible: bool) -> None: - self._handle.visible = visible + def _tform(self) -> scene.transforms.BaseTransform: + return self._rect.transforms.get_transform("canvas", "scene") def can_select(self) -> bool: return True def selected(self) -> bool: - return self._handle.selected + return self._selected def set_selected(self, selected: bool) -> None: - self._handle.selected = selected - - def start_move(self, pos: Sequence[float]) -> None: - self._handle.start_move(pos) - - def move(self, pos: Sequence[float]) -> None: - self._handle.move(pos) - - def remove(self) -> None: - self._parent.remove() - - def cursor_at(self, pos: Sequence[float]) -> CursorType | None: - return self._handle.cursor_at(pos) + self._selected = selected + self._handles.visible = selected and self.visible() + + def set_fill(self, color: _cmap.Color) -> None: + _vis_color = vispy.color.Color(color.hex) + # NB We need alpha>0 for selection + _vis_color.alpha = max(color.alpha, 1e-6) + self._rect.color = _vis_color + + def set_border(self, color: _cmap.Color) -> None: + _vis_color = vispy.color.Color(color.hex) + _vis_color.alpha = color.alpha + self._rect.border_color = _vis_color + + # TODO: Misleading name? + def set_handles(self, color: _cmap.Color) -> None: + _vis_color = vispy.color.Color(color.hex) + _vis_color.alpha = color.alpha + self._handles.set_data(face_color=_vis_color) + + def set_bounding_box( + self, mi: tuple[float, float], ma: tuple[float, float] + ) -> None: + # NB: Support two diagonal points, not necessarily true min/max + x1 = float(min(mi[0], ma[0])) + y1 = float(min(mi[1], ma[1])) + x2 = float(max(mi[0], ma[0])) + y2 = float(max(mi[1], ma[1])) + + # Update rectangle + self._rect.center = [(x1 + x2) / 2, (y1 + y2) / 2] + self._rect.width = max(float(x2 - x1), 1e-30) + self._rect.height = max(float(y2 - y1), 1e-30) + + # Update handles + self._handle_data[0] = x1, y1 + self._handle_data[1] = x2, y1 + self._handle_data[2] = x2, y2 + self._handle_data[3] = x1, y2 + self._handles.set_data(pos=self._handle_data) + + # FIXME: These should be called internally upon set_data, right? + # Looks like https://github.com/vispy/vispy/issues/1899 + self._rect._bounds_changed() + for v in self._rect._subvisuals: + v._bounds_changed() + self._handles._bounds_changed() + + def on_mouse_move(self, event: MouseMoveEvent) -> bool: + # Convert canvas -> world + canvas_pos = (event.x, event.y) + world_pos = self._tform().map(canvas_pos)[:2] + # moving a handle + if self._move_mode == ROIMoveMode.HANDLE: + # The anchor is set to the opposite handle, which never moves. + self.boundingBoxChanged.emit((world_pos, self._move_anchor)) + # translating the whole roi + elif self._move_mode == ROIMoveMode.TRANSLATE: + # The anchor is the mouse position reported in the previous mouse event. + dx = world_pos[0] - self._move_anchor[0] + dy = world_pos[1] - self._move_anchor[1] + # If the mouse moved (dx, dy) between events, the whole ROI needs to be + # translated that amount. + new_min = (self._handle_data[0, 0] + dx, self._handle_data[0, 1] + dy) + new_max = (self._handle_data[2, 0] + dx, self._handle_data[2, 1] + dy) + self.boundingBoxChanged.emit((new_min, new_max)) + self._move_anchor = world_pos + return False -class VispyRoiHandle(RoiHandle): - def __init__(self, roi: RectangularROI) -> None: - self._roi = roi + def on_mouse_press(self, event: MousePressEvent) -> bool: + self.set_selected(True) + # Convert canvas -> world + canvas_pos = (event.x, event.y) + world_pos = self._tform().map(canvas_pos)[:2] + drag_idx = self._handle_under(world_pos) + # If a marker is pressed + if drag_idx is not None: + opposite_idx = (drag_idx + 2) % 4 + self._move_mode = ROIMoveMode.HANDLE + self._move_anchor = tuple(self._handle_data[opposite_idx].copy()) + # If the rectangle is pressed + else: + self._move_mode = ROIMoveMode.TRANSLATE + self._move_anchor = world_pos + return False - def vertices(self) -> Sequence[Sequence[float]]: - return self._roi.vertices + def on_mouse_release(self, event: MouseReleaseEvent) -> bool: + return False - def set_vertices(self, vertices: Sequence[Sequence[float]]) -> None: - self._roi.vertices = vertices + def get_cursor(self, mme: MouseMoveEvent) -> CursorType | None: + canvas_pos = (mme.x, mme.y) + pos = self._tform().map(canvas_pos)[:2] + if self._handle_under(pos) is not None: + center = self._rect.center + if pos[0] < center[0] and pos[1] < center[1]: + return CursorType.FDIAG_ARROW + if pos[0] > center[0] and pos[1] > center[1]: + return CursorType.FDIAG_ARROW + return CursorType.BDIAG_ARROW + return CursorType.ALL_ARROW def visible(self) -> bool: - return bool(self._roi.visible) + return bool(self._rect.visible) def set_visible(self, visible: bool) -> None: - self._roi.visible = visible - - def can_select(self) -> bool: - return True - - def selected(self) -> bool: - return self._roi.selected - - def set_selected(self, selected: bool) -> None: - self._roi.selected = selected - - def start_move(self, pos: Sequence[float]) -> None: - self._roi.start_move(pos) - - def move(self, pos: Sequence[float]) -> None: - self._roi.move(pos) - - def color(self) -> Any: - return self._roi.color - - def set_color(self, color: _cmap.Color | None) -> None: - if color is None: - color = _cmap.Color("transparent") - # NB: To enable dragging the shape within the border, - # we require a positive alpha. - alpha = max(color.alpha, 1e-6) - self._roi.color = Color(color.hex, alpha=alpha) - - def border_color(self) -> _cmap.Color: - return _cmap.Color(self._roi.border_color.rgba) - - def set_border_color(self, color: _cmap.Color | None) -> None: - if color is None: - color = _cmap.Color("yellow") - self._roi.border_color = Color(color.hex, alpha=color.alpha) + self._rect.visible = visible + self._handles.visible = visible and self.selected() def remove(self) -> None: - self._roi.parent = None + self._rect.parent = None + self._handles.parent = None + + def _handle_under(self, pos: Sequence[float]) -> int | None: + """Returns an int in [0, 3], or None. - def cursor_at(self, pos: Sequence[float]) -> CursorType | None: - return self._roi.cursor_at(pos) + If an int i, means that the handle at self._positions[i] is at pos. + If None, there is no handle at pos. + """ + rad2 = (self._handle_size / 2) ** 2 + for i, p in enumerate(self._handle_data): + if (p[0] - pos[0]) ** 2 + (p[1] - pos[1]) ** 2 <= rad2: + return i + return None class VispyArrayCanvas(ArrayCanvas): @@ -424,7 +289,9 @@ class VispyArrayCanvas(ArrayCanvas): could be swapped in if needed as long as they implement the same interface). """ - def __init__(self) -> None: + def __init__(self, viewer_model: ArrayViewerModel) -> None: + self._viewer = viewer_model + self._canvas = scene.SceneCanvas(size=(600, 600)) # this filter needs to remain in scope for the lifetime of the canvas @@ -440,6 +307,9 @@ def __init__(self) -> None: self._ndim: Literal[2, 3] | None = None self._elements: WeakKeyDictionary = WeakKeyDictionary() + self._selection: CanvasElement | None = None + # Maintain weak reference to last ROI created + self._last_roi_created: ReferenceType[VispyRectangle] | None = None @property def _camera(self) -> vispy.scene.cameras.BaseCamera: @@ -520,24 +390,14 @@ def add_volume(self, data: np.ndarray | None = None) -> VispyImageHandle: self.set_range() return handle - def add_roi( - self, - vertices: Sequence[tuple[float, float]] | None = None, - color: _cmap.Color | None = None, - border_color: _cmap.Color | None = None, - ) -> VispyRoiHandle: + def add_bounding_box(self) -> VispyRectangle: """Add a new Rectangular ROI node to the scene.""" - roi = RectangularROI(parent=self._view.scene) - handle = VispyRoiHandle(roi) - self._elements[roi] = handle - for h in roi._handles: - self._elements[h] = VispyHandleHandle(h, handle) - if vertices: - handle.set_vertices(vertices) - self.set_range() - handle.set_color(color) - handle.set_border_color(border_color) - return handle + roi = VispyRectangle(parent=self._view.scene) + roi.set_visible(False) + self._elements[roi._handles] = roi + self._elements[roi._rect] = roi + self._last_roi_created = ReferenceType(roi) + return roi def set_range( self, @@ -564,7 +424,7 @@ def set_range( _y[1] = max(_y[1], shape[1]) if len(shape) > 2: _z[1] = max(_z[1], shape[2]) - elif isinstance(handle, VispyRoiHandle): + elif isinstance(handle, VispyRectangle): for v in handle.vertices: _x[0] = min(_x[0], v[0]) _x[1] = max(_x[1], v[0]) @@ -600,6 +460,65 @@ def elements_at(self, pos_xy: tuple[float, float]) -> list[CanvasElement]: elements.append(handle) return elements + def on_mouse_press(self, event: MousePressEvent) -> bool: + if self._selection: + self._selection.set_selected(False) + self._selection = None + canvas_pos = (event.x, event.y) + world_pos = self.canvas_to_world(canvas_pos)[:2] + + # If in CREATE_ROI mode, the new ROI should "start" here. + if self._viewer.interaction_mode == InteractionMode.CREATE_ROI: + if self._last_roi_created is None: + raise ValueError("No ROI to create!") + if new_roi := self._last_roi_created(): + self._last_roi_created = None + # HACK: Provide a non-zero starting size so that if the user clicks + # and immediately releases, it's visible and can be selected again + _min = world_pos + _max = (world_pos[0] + 1, world_pos[1] + 1) + # Put the ROI where the user clicked + new_roi.boundingBoxChanged.emit((_min, _max)) + # new_roi.set_bounding_box(_min, _max) + # Make it visible + new_roi.set_visible(True) + # Select it so the mouse press event below triggers ROIMoveMode.HANDLE + # TODO: Make behavior more direct + new_roi.set_selected(True) + + # All done - exit the mode + self._viewer.interaction_mode = InteractionMode.PAN_ZOOM + + # Select first selectable object at clicked point + for vis in self.elements_at(canvas_pos): + if vis.can_select(): + self._selection = vis + self._selection.on_mouse_press(event) + return False + + return False + + def on_mouse_move(self, event: MouseMoveEvent) -> bool: + if event.btn == MouseButton.LEFT: + if self._selection and self._selection.selected(): + self._selection.on_mouse_move(event) + # If we are moving the object, we don't want to move the camera + return True + return False + + def on_mouse_release(self, event: MouseReleaseEvent) -> bool: + if self._selection: + self._selection.on_mouse_release(event) + return False + + def get_cursor(self, mme: MouseMoveEvent) -> CursorType: + if self._viewer.interaction_mode == InteractionMode.CREATE_ROI: + return CursorType.CROSS + for vis in self.elements_at((mme.x, mme.y)): + if cursor := vis.get_cursor(mme): + return cursor + return CursorType.DEFAULT + def _downcast(data: np.ndarray | None) -> np.ndarray | None: """Downcast >32bit data to 32bit.""" diff --git a/src/ndv/views/_vispy/_histogram.py b/src/ndv/views/_vispy/_histogram.py index 4e09f49a..9b3b1538 100644 --- a/src/ndv/views/_vispy/_histogram.py +++ b/src/ndv/views/_vispy/_histogram.py @@ -273,7 +273,8 @@ def _update_lut_lines(self, npoints: int = 256) -> None: v._bounds_changed() self._gamma_handle._bounds_changed() - def get_cursor(self, pos: tuple[float, float]) -> CursorType: + def get_cursor(self, event: MouseMoveEvent) -> CursorType: + pos = (event.x, event.y) nearby = self._find_nearby_node(pos) if nearby in [Grabbable.LEFT_CLIM, Grabbable.RIGHT_CLIM]: @@ -335,7 +336,7 @@ def on_mouse_move(self, event: MouseMoveEvent) -> bool: self.model.gamma = -np.log2(y / y1) return False - self.get_cursor(pos).apply_to(self) + self.get_cursor(event).apply_to(self) return False def _find_nearby_node( diff --git a/src/ndv/views/_wx/_app.py b/src/ndv/views/_wx/_app.py index bc67a8f7..9ad1d7ab 100644 --- a/src/ndv/views/_wx/_app.py +++ b/src/ndv/views/_wx/_app.py @@ -5,7 +5,7 @@ import wx from wx import EVT_LEFT_DOWN, EVT_LEFT_UP, EVT_MOTION, EvtHandler, MouseEvent -from ndv._types import MouseMoveEvent, MousePressEvent, MouseReleaseEvent +from ndv._types import MouseButton, MouseMoveEvent, MousePressEvent, MouseReleaseEvent from ndv.views.bases._app import NDVApp from ._main_thread import call_in_main_thread @@ -59,33 +59,51 @@ def filter_mouse_events( f"Expected vispy canvas to be wx EvtHandler, got {type(canvas)}" ) + if hasattr(canvas, "_subwidget"): + canvas = canvas._subwidget + # TIP: event.Skip() allows the event to propagate to other handlers. + active_button: MouseButton = MouseButton.NONE + def on_mouse_move(event: MouseEvent) -> None: - mme = MouseMoveEvent(x=event.GetX(), y=event.GetY()) + nonlocal active_button + nonlocal canvas + + mme = MouseMoveEvent(x=event.GetX(), y=event.GetY(), btn=active_button) if not receiver.on_mouse_move(mme): receiver.mouseMoved.emit(mme) event.Skip() + # FIXME: get_cursor is VERY slow, unsure why. + if cursor := receiver.get_cursor(mme): + canvas.SetCursor(cursor.to_wx()) def on_mouse_press(event: MouseEvent) -> None: - mpe = MousePressEvent(x=event.GetX(), y=event.GetY()) + nonlocal active_button + + # NB This function is bound to the left mouse button press + active_button = MouseButton.LEFT + mpe = MousePressEvent(x=event.GetX(), y=event.GetY(), btn=active_button) if not receiver.on_mouse_press(mpe): receiver.mousePressed.emit(mpe) event.Skip() def on_mouse_release(event: MouseEvent) -> None: - mre = MouseReleaseEvent(x=event.GetX(), y=event.GetY()) + nonlocal active_button + + mre = MouseReleaseEvent(x=event.GetX(), y=event.GetY(), btn=active_button) + active_button = MouseButton.NONE if not receiver.on_mouse_release(mre): receiver.mouseReleased.emit(mre) event.Skip() - canvas.Bind(EVT_MOTION, on_mouse_move) - canvas.Bind(EVT_LEFT_DOWN, on_mouse_press) - canvas.Bind(EVT_LEFT_UP, on_mouse_release) + canvas.Bind(EVT_MOTION, handler=on_mouse_move) + canvas.Bind(EVT_LEFT_DOWN, handler=on_mouse_press) + canvas.Bind(EVT_LEFT_UP, handler=on_mouse_release) def _unbind() -> None: - canvas.Unbind(EVT_MOTION, on_mouse_move) - canvas.Unbind(EVT_LEFT_DOWN, on_mouse_press) - canvas.Unbind(EVT_LEFT_UP, on_mouse_release) + canvas.Unbind(EVT_MOTION, handler=on_mouse_move) + canvas.Unbind(EVT_LEFT_DOWN, handler=on_mouse_press) + canvas.Unbind(EVT_LEFT_UP, handler=on_mouse_release) return _unbind diff --git a/src/ndv/views/_wx/_array_view.py b/src/ndv/views/_wx/_array_view.py index 86b3af49..477d14c1 100644 --- a/src/ndv/views/_wx/_array_view.py +++ b/src/ndv/views/_wx/_array_view.py @@ -11,6 +11,7 @@ from ndv.models._array_display_model import ChannelMode from ndv.models._lut_model import ClimPolicy, ClimsManual, ClimsMinMax +from ndv.models._viewer_model import ArrayViewerModel, InteractionMode from ndv.views._wx._labeled_slider import WxLabeledSlider from ndv.views.bases import ArrayView, LutView @@ -162,18 +163,22 @@ def __init__(self, parent: wx.Window) -> None: def create_sliders(self, coords: Mapping[Hashable, Sequence]) -> None: """Update sliders with the given coordinate ranges.""" for axis, _coords in coords.items(): - slider = WxLabeledSlider(self) - slider.label.SetLabel(str(axis)) - slider.slider.Bind(wx.EVT_SLIDER, self._on_slider_changed) - + # Create a slider for axis if necessary + if axis not in self._sliders: + slider = WxLabeledSlider(self) + slider.slider.Bind(wx.EVT_SLIDER, self._on_slider_changed) + slider.label.SetLabel(str(axis)) + self.layout.Add(slider, 0, wx.EXPAND | wx.ALL, 5) + self._sliders[axis] = slider + + # Update axis slider with coordinates + slider = self._sliders[axis] if isinstance(_coords, range): slider.setRange(_coords.start, _coords.stop - 1) slider.setSingleStep(_coords.step) else: slider.setRange(0, len(_coords) - 1) - self.layout.Add(slider, 0, wx.EXPAND | wx.ALL, 5) - self._sliders[axis] = slider self.currentIndexChanged.emit() def hide_dimensions( @@ -252,6 +257,9 @@ def __init__(self, canvas_widget: wx.Window, parent: wx.Window = None): # 3d view button self.ndims_btn = wx.ToggleButton(self, label="3D") + # Add ROI button + self.add_roi_btn = wx.ToggleButton(self, label="Add ROI") + # LUT layout (simple vertical grouping for LUT widgets) self.luts = wx.BoxSizer(wx.VERTICAL) @@ -260,6 +268,7 @@ def __init__(self, canvas_widget: wx.Window, parent: wx.Window = None): btns.Add(self.channel_mode_combo, 0, wx.ALL, 5) btns.Add(self.reset_zoom_btn, 0, wx.ALL, 5) btns.Add(self.ndims_btn, 0, wx.ALL, 5) + btns.Add(self.add_roi_btn, 0, wx.ALL, 5) self._top_info = top_info = wx.BoxSizer(wx.HORIZONTAL) top_info.Add(self._data_info_label, 0, wx.EXPAND | wx.BOTTOM, 0) @@ -286,9 +295,11 @@ def __init__( self, canvas_widget: wx.Window, data_model: _ArrayDataDisplayModel, + viewer_model: ArrayViewerModel, parent: wx.Window = None, ) -> None: self._data_model = data_model + self._viewer_model = viewer_model self._wxwidget = wdg = _WxArrayViewer(canvas_widget, parent) self._visible_axes: Sequence[AxisKey] = [] @@ -297,6 +308,7 @@ def __init__( wdg.channel_mode_combo.Bind(wx.EVT_COMBOBOX, self._on_channel_mode_changed) wdg.reset_zoom_btn.Bind(wx.EVT_BUTTON, self._on_reset_zoom_clicked) wdg.ndims_btn.Bind(wx.EVT_TOGGLEBUTTON, self._on_ndims_toggled) + wdg.add_roi_btn.Bind(wx.EVT_TOGGLEBUTTON, self._on_add_roi_toggled) def _on_channel_mode_changed(self, event: wx.CommandEvent) -> None: mode = self._wxwidget.channel_mode_combo.GetValue() @@ -323,6 +335,12 @@ def _on_ndims_toggled(self, event: wx.CommandEvent) -> None: # since we now have access to it. self.visibleAxesChanged.emit() + def _on_add_roi_toggled(self, event: wx.CommandEvent) -> None: + create_roi = self._wxwidget.add_roi_btn.GetValue() + self._viewer_model.interaction_mode = ( + InteractionMode.CREATE_ROI if create_roi else InteractionMode.PAN_ZOOM + ) + def visible_axes(self) -> Sequence[AxisKey]: return self._visible_axes # no widget to control this yet diff --git a/src/ndv/views/bases/__init__.py b/src/ndv/views/bases/__init__.py index edfdd213..9528272d 100644 --- a/src/ndv/views/bases/__init__.py +++ b/src/ndv/views/bases/__init__.py @@ -3,7 +3,7 @@ from ._app import NDVApp from ._array_view import ArrayView from ._graphics._canvas import ArrayCanvas, HistogramCanvas -from ._graphics._canvas_elements import CanvasElement, ImageHandle, RoiHandle +from ._graphics._canvas_elements import CanvasElement, ImageHandle, RectangularROIHandle from ._graphics._mouseable import Mouseable from ._lut_view import LutView from ._view_base import Viewable @@ -17,6 +17,6 @@ "LutView", "Mouseable", "NDVApp", - "RoiHandle", + "RectangularROIHandle", "Viewable", ] diff --git a/src/ndv/views/bases/_array_view.py b/src/ndv/views/bases/_array_view.py index 386d6d69..4d73a6f3 100644 --- a/src/ndv/views/bases/_array_view.py +++ b/src/ndv/views/bases/_array_view.py @@ -15,6 +15,7 @@ from ndv._types import AxisKey from ndv.models._data_display_model import _ArrayDataDisplayModel + from ndv.models._viewer_model import ArrayViewerModel from ndv.views.bases import LutView @@ -35,7 +36,11 @@ class ArrayView(Viewable): # model: _ArrayDataDisplayModel is likely a temporary parameter @abstractmethod def __init__( - self, canvas_widget: Any, model: _ArrayDataDisplayModel, **kwargs: Any + self, + canvas_widget: Any, + model: _ArrayDataDisplayModel, + viewer_model: ArrayViewerModel, + **kwargs: Any, ) -> None: ... @abstractmethod def create_sliders(self, coords: Mapping[Hashable, Sequence]) -> None: ... diff --git a/src/ndv/views/bases/_graphics/__init__.py b/src/ndv/views/bases/_graphics/__init__.py index 163d8c9c..37e488c6 100644 --- a/src/ndv/views/bases/_graphics/__init__.py +++ b/src/ndv/views/bases/_graphics/__init__.py @@ -1,7 +1,7 @@ """Base classes for graphics elements.""" from ._canvas import ArrayCanvas, HistogramCanvas -from ._canvas_elements import CanvasElement, ImageHandle, RoiHandle +from ._canvas_elements import CanvasElement, ImageHandle, RectangularROIHandle from ._mouseable import Mouseable __all__ = [ @@ -10,5 +10,5 @@ "HistogramCanvas", "ImageHandle", "Mouseable", - "RoiHandle", + "RectangularROIHandle", ] diff --git a/src/ndv/views/bases/_graphics/_canvas.py b/src/ndv/views/bases/_graphics/_canvas.py index 6e1e1e2e..95fddd57 100644 --- a/src/ndv/views/bases/_graphics/_canvas.py +++ b/src/ndv/views/bases/_graphics/_canvas.py @@ -11,12 +11,11 @@ from ._mouseable import Mouseable if TYPE_CHECKING: - from collections.abc import Sequence - - import cmap import numpy as np - from ._canvas_elements import CanvasElement, ImageHandle, RoiHandle + from ndv.models._viewer_model import ArrayViewerModel + + from ._canvas_elements import CanvasElement, ImageHandle, RectangularROIHandle class GraphicsCanvas(Viewable, Mouseable): @@ -48,6 +47,8 @@ def elements_at(self, pos_xy: tuple[float, float]) -> list[CanvasElement]: ... class ArrayCanvas(GraphicsCanvas): """ABC for canvases that show array data.""" + @abstractmethod + def __init__(self, viewer_model: ArrayViewerModel | None = ...) -> None: ... @abstractmethod def set_ndim(self, ndim: Literal[2, 3]) -> None: ... @abstractmethod @@ -56,12 +57,7 @@ def add_image(self, data: np.ndarray | None = ...) -> ImageHandle: ... @abstractmethod def add_volume(self, data: np.ndarray | None = ...) -> ImageHandle: ... @abstractmethod - def add_roi( - self, - vertices: Sequence[tuple[float, float]] | None = None, - color: cmap.Color | None = None, - border_color: cmap.Color | None = None, - ) -> RoiHandle: ... + def add_bounding_box(self) -> RectangularROIHandle: ... class HistogramCanvas(GraphicsCanvas, LutView): diff --git a/src/ndv/views/bases/_graphics/_canvas_elements.py b/src/ndv/views/bases/_graphics/_canvas_elements.py index a12eee22..294baf5d 100644 --- a/src/ndv/views/bases/_graphics/_canvas_elements.py +++ b/src/ndv/views/bases/_graphics/_canvas_elements.py @@ -1,19 +1,21 @@ from __future__ import annotations from abc import abstractmethod -from typing import TYPE_CHECKING, Any +from enum import Enum, auto +from typing import TYPE_CHECKING + +from psygnal import Signal from ndv.views.bases._lut_view import LutView from ._mouseable import Mouseable if TYPE_CHECKING: - from collections.abc import Sequence + from typing import Any import cmap as _cmap import numpy as np - from ndv._types import CursorType from ndv.models._lut_model import ClimPolicy @@ -40,25 +42,6 @@ def selected(self) -> bool: def set_selected(self, selected: bool) -> None: """Sets element selection status.""" - def cursor_at(self, pos: Sequence[float]) -> CursorType | None: - """Returns the element's cursor preference at the provided position.""" - - def start_move(self, pos: Sequence[float]) -> None: - """ - Behavior executed at the beginning of a "move" operation. - - In layman's terms, this is the behavior executed during the the "click" - of a "click-and-drag". - """ - - def move(self, pos: Sequence[float]) -> None: - """ - Behavior executed throughout a "move" operation. - - In layman's terms, this is the behavior executed during the "drag" - of a "click-and-drag". - """ - def remove(self) -> None: """Removes the element from the canvas.""" @@ -98,16 +81,28 @@ def set_channel_visible(self, visible: bool) -> None: self.set_visible(visible) -class RoiHandle(CanvasElement): - @abstractmethod - def vertices(self) -> Sequence[Sequence[float]]: ... - @abstractmethod - def set_vertices(self, data: Sequence[Sequence[float]]) -> None: ... - @abstractmethod - def color(self) -> Any: ... - @abstractmethod - def set_color(self, color: _cmap.Color | None) -> None: ... - @abstractmethod - def border_color(self) -> Any: ... - @abstractmethod - def set_border_color(self, color: _cmap.Color | None) -> None: ... +class ROIMoveMode(Enum): + """Describes graphical mechanisms for ROI translation.""" + + HANDLE = auto() # Moving one handle (but not all) + TRANSLATE = auto() # Translating everything + + +class RectangularROIHandle(CanvasElement): + """An axis-aligned rectanglular ROI.""" + + boundingBoxChanged = Signal(tuple[tuple[float, float], tuple[float, float]]) + + def set_bounding_box( + self, minimum: tuple[float, float], maximum: tuple[float, float] + ) -> None: + """Sets the bounding box.""" + + def set_fill(self, color: _cmap.Color) -> None: + """Sets the fill color.""" + + def set_border(self, color: _cmap.Color) -> None: + """Sets the border color.""" + + def set_handles(self, color: _cmap.Color) -> None: + """Sets the handle face color.""" diff --git a/src/ndv/views/bases/_graphics/_mouseable.py b/src/ndv/views/bases/_graphics/_mouseable.py index d8f540e7..87802354 100644 --- a/src/ndv/views/bases/_graphics/_mouseable.py +++ b/src/ndv/views/bases/_graphics/_mouseable.py @@ -2,7 +2,7 @@ from psygnal import Signal -from ndv._types import MouseMoveEvent, MousePressEvent, MouseReleaseEvent +from ndv._types import CursorType, MouseMoveEvent, MousePressEvent, MouseReleaseEvent class Mouseable: @@ -28,3 +28,6 @@ def on_mouse_press(self, event: MousePressEvent) -> bool: def on_mouse_release(self, event: MouseReleaseEvent) -> bool: return False + + def get_cursor(self, event: MouseMoveEvent) -> CursorType | None: + return None diff --git a/tests/test_controller.py b/tests/test_controller.py index 722a8350..fb97c00d 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -9,11 +9,19 @@ import numpy as np import pytest -from ndv._types import MouseMoveEvent +from ndv._types import ( + CursorType, + MouseButton, + MouseMoveEvent, + MousePressEvent, + MouseReleaseEvent, +) from ndv.controllers import ArrayViewer from ndv.controllers._channel_controller import ChannelController from ndv.models._array_display_model import ArrayDisplayModel, ChannelMode from ndv.models._lut_model import ClimsManual, ClimsMinMax, LUTModel +from ndv.models._roi_model import RectangularROIModel +from ndv.models._viewer_model import InteractionMode from ndv.views import _app, gui_frontend from ndv.views.bases import ArrayView, LutView from ndv.views.bases._graphics._canvas import ArrayCanvas, HistogramCanvas @@ -29,7 +37,7 @@ IS_PYGFX = _app.canvas_backend(None) == "pygfx" -def _get_mock_canvas() -> ArrayCanvas: +def _get_mock_canvas(*_: Any) -> ArrayCanvas: mock = MagicMock(spec=ArrayCanvas) img_handle = MagicMock(spec=ImageHandle) img_handle.data.return_value = np.zeros((10, 10)).astype(np.uint8) @@ -268,3 +276,118 @@ def test_array_viewer_histogram() -> None: counts = np.bincount(data.flatten(), minlength=maxval + 1) bin_edges = np.arange(maxval + 2) - 0.5 viewer._histogram.set_data(counts, bin_edges) + + +@no_type_check +@pytest.mark.usefixtures("any_app") +def test_roi_controller() -> None: + ctrl = ArrayViewer() + roi = RectangularROIModel() + viewer = ctrl._viewer_model + + # Until a user interacts with ctrl.roi, there is no ROI model + assert ctrl._roi_model is None + ctrl.roi = roi + assert ctrl._roi_model is not None + + # Clicking the ROI button and then clicking the canvas creates a ROI + viewer.interaction_mode = InteractionMode.CREATE_ROI + canvas_pos = (5, 5) + mpe = MousePressEvent(canvas_pos[0], canvas_pos[1], MouseButton.LEFT) + + # Note - avoid diving into rendering logic here - just identify view + with patch.object(ctrl._canvas, "elements_at", return_value=[ctrl._roi_view]): + ctrl._canvas.on_mouse_press(mpe) + world_pos = ctrl._canvas.canvas_to_world(canvas_pos) + + assert roi.bounding_box == ( + (world_pos[0], world_pos[1]), + (world_pos[0] + 1, world_pos[1] + 1), + ) + assert viewer.interaction_mode == InteractionMode.PAN_ZOOM + + +@no_type_check +@pytest.mark.usefixtures("any_app") +def test_roi_interaction() -> None: + if _app.gui_frontend() == _app.GuiFrontend.JUPYTER and IS_PYGFX: + pytest.skip("Invalid canvas size on CI") + return + + ctrl = ArrayViewer() + roi = RectangularROIModel() + ctrl.roi = roi + roi_view = ctrl._roi_view + assert roi_view is not None + + # FIXME: We need a large world space on the canvas, but + # VispyArrayCanvas.set_range is not implemented yet. This workaround + # sets the range to the extent of the data i.e. the extent of the ROI + roi.bounding_box = ((0, 0), (500, 500)) + ctrl._canvas.set_range() + # Note that these positions are far apart to satisfy sufficient distance + # in world space + canvas_roi_start = (200, 200) + world_roi_start = tuple(ctrl._canvas.canvas_to_world(canvas_roi_start)[:2]) + canvas_new_start = (100, 100) + world_new_start = tuple(ctrl._canvas.canvas_to_world(canvas_new_start)[:2]) + canvas_roi_end = (300, 300) + world_roi_end = tuple(ctrl._canvas.canvas_to_world(canvas_roi_end)[:2]) + roi.bounding_box = (world_roi_start, world_roi_end) + + # Note - avoid diving into rendering logic here - just identify view + with patch.object(ctrl._canvas, "elements_at", return_value=[ctrl._roi_view]): + # Test moving handle + assert not roi_view.selected() + mpe = MousePressEvent( + canvas_roi_start[0], canvas_roi_start[1], MouseButton.LEFT + ) + ctrl._canvas.on_mouse_press(mpe) + assert roi_view.selected() + mme = MouseMoveEvent(canvas_new_start[0], canvas_new_start[1], MouseButton.LEFT) + ctrl._canvas.on_mouse_move(mme) + assert roi.bounding_box[0] == pytest.approx(world_new_start, 1e-6) + assert roi.bounding_box[1] == pytest.approx(world_roi_end, 1e-6) + mre = MouseReleaseEvent( + canvas_new_start[0], canvas_new_start[1], MouseButton.LEFT + ) + ctrl._canvas.on_mouse_release(mre) + + # Test translation + roi.bounding_box = (world_roi_start, world_roi_end) + mpe = MousePressEvent( + (canvas_roi_start[0] + canvas_roi_end[0] / 2), + (canvas_roi_start[1] + canvas_roi_end[1] / 2), + MouseButton.LEFT, + ) + ctrl._canvas.on_mouse_press(mpe) + assert roi_view.selected() + mme = MouseMoveEvent( + (canvas_roi_start[0] + canvas_new_start[0] / 2), + (canvas_roi_start[1] + canvas_new_start[1] / 2), + MouseButton.LEFT, + ) + ctrl._canvas.on_mouse_move(mme) + assert roi.bounding_box[0] == pytest.approx(world_new_start, 1e-6) + assert roi.bounding_box[1] == pytest.approx(world_roi_start, 1e-6) + mre = MouseReleaseEvent( + (canvas_roi_start[0] + canvas_new_start[0] / 2), + (canvas_roi_start[1] + canvas_new_start[1] / 2), + MouseButton.LEFT, + ) + ctrl._canvas.on_mouse_release(mre) + + # Test cursors + roi.bounding_box = (world_roi_start, world_roi_end) + # Top-Left corner + mme = MouseMoveEvent(canvas_roi_start[0], canvas_roi_start[1]) + assert roi_view.get_cursor(mme) == CursorType.FDIAG_ARROW + # Top-Right corner + mme = MouseMoveEvent(canvas_roi_start[0], canvas_roi_end[1]) + assert roi_view.get_cursor(mme) == CursorType.BDIAG_ARROW + # Middle + mme = MouseMoveEvent( + (canvas_roi_start[0] + canvas_roi_end[0]) / 2, + (canvas_roi_start[1] + canvas_roi_end[1]) / 2, + ) + assert roi_view.get_cursor(mme) == CursorType.ALL_ARROW diff --git a/tests/test_models.py b/tests/test_models.py index 6f4094e0..8c8d4e1e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,6 +1,7 @@ from unittest.mock import Mock from ndv.models._array_display_model import ArrayDisplayModel +from ndv.models._roi_model import RectangularROIModel def test_array_display_model() -> None: @@ -23,3 +24,28 @@ def test_array_display_model() -> None: assert ArrayDisplayModel.model_json_schema(mode="validation") assert ArrayDisplayModel.model_json_schema(mode="serialization") + + +def test_rectangular_roi_model() -> None: + m = RectangularROIModel() + + mock = Mock() + m.events.bounding_box.connect(mock) + m.events.visible.connect(mock) + + m.bounding_box = ((10, 10), (20, 20)) + mock.assert_called_once_with( + ((10, 10), (20, 20)), # New bounding box value + ((0, 0), (0, 0)), # Initial bounding box on construction + ) + mock.reset_mock() + + m.visible = False + mock.assert_called_once_with( + False, # New visibility + True, # Initial visibility on construction + ) + mock.reset_mock() + + assert RectangularROIModel.model_json_schema(mode="validation") + assert RectangularROIModel.model_json_schema(mode="serialization")