Skip to content

Commit 24bec3e

Browse files
gselzerpre-commit-ci[bot]tlambert03
authored
feat: Rectangular ROIs (#114)
* ROI Model * WIP * Always provide a ROIModel Probably want to refactor this later * Work with all gui frontends * Note bug with Wx+PyGFX * style(pre-commit.ci): auto fixes [...] * Make test pass * Minor cleanup * Weak Ref to last roi created Necessary for the "click roi button then click and drag" functionality. Still not happy with this pattern, though... * Rename RectangularROI views * Complete roi controller test * Remove old Mouseable class * Fix pygfx+wx * Wx: Prevent repeat sliders This can happen when you call ArrayViewer.fully_synchronize_view multiple times. * Remove FIXME * Shorten docstring * WIP: Fix tests * Use cache over lru_cache * Explicitly validate bounding box after Co-authored-by: Talley Lambert <[email protected]> * style(pre-commit.ci): auto fixes [...] * Correctly type bounding box docstring * sync roi in separate method * Check for roi_model not None * Move button selection to mouse down * Intercept jupyter mouse events as needed * Move Qt import to top * Patch ArrayCanvas.elements_at in roi test * Actually check around jupyter mouse button * Add roi interaction test * Skip test on Jupyter+PyGFX * Fix RectangularROIModel docstring * Update src/ndv/models/_roi_model.py * style(pre-commit.ci): auto fixes [...] --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Talley Lambert <[email protected]>
1 parent c3cf013 commit 24bec3e

24 files changed

+1166
-668
lines changed

src/ndv/_types.py

+45-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from collections.abc import Hashable, Sequence
66
from contextlib import suppress
77
from enum import Enum, IntFlag, auto
8+
from functools import cache
89
from typing import TYPE_CHECKING, Annotated, Any, NamedTuple, cast
910

1011
from pydantic import PlainSerializer, PlainValidator
@@ -13,6 +14,7 @@
1314
if TYPE_CHECKING:
1415
from qtpy.QtCore import Qt
1516
from qtpy.QtWidgets import QWidget
17+
from wx import Cursor
1618

1719
from ndv.views.bases import Viewable
1820

@@ -56,33 +58,36 @@ class MouseButton(IntFlag):
5658
LEFT = auto()
5759
MIDDLE = auto()
5860
RIGHT = auto()
61+
NONE = auto()
5962

6063

6164
class MouseMoveEvent(NamedTuple):
6265
"""Event emitted when the user moves the cursor."""
6366

6467
x: float
6568
y: float
69+
btn: MouseButton = MouseButton.NONE
6670

6771

6872
class MousePressEvent(NamedTuple):
6973
"""Event emitted when mouse button is pressed."""
7074

7175
x: float
7276
y: float
73-
btn: MouseButton = MouseButton.LEFT
77+
btn: MouseButton
7478

7579

7680
class MouseReleaseEvent(NamedTuple):
7781
"""Event emitted when mouse button is released."""
7882

7983
x: float
8084
y: float
81-
btn: MouseButton = MouseButton.LEFT
85+
btn: MouseButton
8286

8387

8488
class CursorType(Enum):
8589
DEFAULT = "default"
90+
CROSS = "cross"
8691
V_ARROW = "v_arrow"
8792
H_ARROW = "h_arrow"
8893
ALL_ARROW = "all_arrow"
@@ -101,9 +106,47 @@ def to_qt(self) -> Qt.CursorShape:
101106

102107
return {
103108
CursorType.DEFAULT: Qt.CursorShape.ArrowCursor,
109+
CursorType.CROSS: Qt.CursorShape.CrossCursor,
104110
CursorType.V_ARROW: Qt.CursorShape.SizeVerCursor,
105111
CursorType.H_ARROW: Qt.CursorShape.SizeHorCursor,
106112
CursorType.ALL_ARROW: Qt.CursorShape.SizeAllCursor,
107113
CursorType.BDIAG_ARROW: Qt.CursorShape.SizeBDiagCursor,
108114
CursorType.FDIAG_ARROW: Qt.CursorShape.SizeFDiagCursor,
109115
}[self]
116+
117+
def to_jupyter(self) -> str:
118+
"""Converts CursorType to jupyter cursor strings."""
119+
return {
120+
CursorType.DEFAULT: "default",
121+
CursorType.CROSS: "crosshair",
122+
CursorType.V_ARROW: "ns-resize",
123+
CursorType.H_ARROW: "ew-resize",
124+
CursorType.ALL_ARROW: "move",
125+
CursorType.BDIAG_ARROW: "nesw-resize",
126+
CursorType.FDIAG_ARROW: "nwse-resize",
127+
}[self]
128+
129+
# Note a new object must be created every time. We should cache it!
130+
@cache
131+
def to_wx(self) -> Cursor:
132+
"""Converts CursorType to jupyter cursor strings."""
133+
from wx import (
134+
CURSOR_ARROW,
135+
CURSOR_CROSS,
136+
CURSOR_SIZENESW,
137+
CURSOR_SIZENS,
138+
CURSOR_SIZENWSE,
139+
CURSOR_SIZEWE,
140+
CURSOR_SIZING,
141+
Cursor,
142+
)
143+
144+
return {
145+
CursorType.DEFAULT: Cursor(CURSOR_ARROW),
146+
CursorType.CROSS: Cursor(CURSOR_CROSS),
147+
CursorType.V_ARROW: Cursor(CURSOR_SIZENS),
148+
CursorType.H_ARROW: Cursor(CURSOR_SIZEWE),
149+
CursorType.ALL_ARROW: Cursor(CURSOR_SIZING),
150+
CursorType.BDIAG_ARROW: Cursor(CURSOR_SIZENESW),
151+
CursorType.FDIAG_ARROW: Cursor(CURSOR_SIZENWSE),
152+
}[self]

src/ndv/controllers/_array_viewer.py

+102-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from ndv.controllers._channel_controller import ChannelController
1010
from ndv.models import ArrayDisplayModel, ChannelMode, DataWrapper, LUTModel
1111
from ndv.models._data_display_model import DataResponse, _ArrayDataDisplayModel
12+
from ndv.models._roi_model import RectangularROIModel
13+
from ndv.models._viewer_model import ArrayViewerModel, InteractionMode
1214
from ndv.views import _app
1315

1416
if TYPE_CHECKING:
@@ -21,6 +23,7 @@
2123
from ndv._types import MouseMoveEvent
2224
from ndv.models._array_display_model import ArrayDisplayModelKwargs
2325
from ndv.views.bases import HistogramCanvas
26+
from ndv.views.bases._graphics._canvas_elements import RectangularROIHandle
2427

2528
LutKey: TypeAlias = int | None
2629

@@ -69,6 +72,11 @@ def __init__(
6972
self._data_model = _ArrayDataDisplayModel(
7073
data_wrapper=data, display=display_model or ArrayDisplayModel(**kwargs)
7174
)
75+
self._viewer_model = ArrayViewerModel()
76+
self._viewer_model.events.interaction_mode.connect(
77+
self._on_interaction_mode_changed
78+
)
79+
self._roi_model: RectangularROIModel | None = None
7280

7381
app = _app.gui_frontend()
7482

@@ -88,10 +96,14 @@ def __init__(
8896
# get and create the front-end and canvas classes
8997
frontend_cls = _app.get_array_view_class()
9098
canvas_cls = _app.get_array_canvas_class()
91-
self._canvas = canvas_cls()
99+
self._canvas = canvas_cls(self._viewer_model)
92100

93101
self._histogram: HistogramCanvas | None = None
94-
self._view = frontend_cls(self._canvas.frontend_widget(), self._data_model)
102+
self._view = frontend_cls(
103+
self._canvas.frontend_widget(), self._data_model, self._viewer_model
104+
)
105+
106+
self._roi_view: RectangularROIHandle | None = None
95107

96108
self._set_model_connected(self._data_model.display)
97109
self._canvas.set_ndim(self.display_model.n_visible_axes)
@@ -163,6 +175,27 @@ def data(self, data: Any) -> None:
163175
self._data_model.data_wrapper = DataWrapper.create(data)
164176
self._fully_synchronize_view()
165177

178+
@property
179+
def roi(self) -> RectangularROIModel | None:
180+
"""Return ROI being displayed."""
181+
return self._roi_model
182+
183+
@roi.setter
184+
def roi(self, roi_model: RectangularROIModel | None) -> None:
185+
"""Set ROI being displayed."""
186+
# Disconnect old model
187+
if self._roi_model is not None:
188+
self._set_roi_model_connected(self._roi_model, False)
189+
190+
# Connect new model
191+
if isinstance(roi_model, tuple):
192+
self._roi_model = RectangularROIModel(bounding_box=roi_model)
193+
else:
194+
self._roi_model = roi_model
195+
if self._roi_model is not None:
196+
self._set_roi_model_connected(self._roi_model)
197+
self._synchronize_roi()
198+
166199
def show(self) -> None:
167200
"""Show the viewer."""
168201
self._view.set_visible(True)
@@ -239,6 +272,28 @@ def _set_model_connected(
239272
]:
240273
getattr(obj, _connect)(callback)
241274

275+
def _set_roi_model_connected(
276+
self, model: RectangularROIModel, connect: bool = True
277+
) -> None:
278+
"""Connect or disconnect the model to/from the viewer.
279+
280+
We do this in a single method so that we are sure to connect and disconnect
281+
the same events in the same order. (but it's kinda ugly)
282+
"""
283+
_connect = "connect" if connect else "disconnect"
284+
285+
for obj, callback in [
286+
(model.events.bounding_box, self._on_roi_model_bounding_box_changed),
287+
(model.events.visible, self._on_roi_model_visible_changed),
288+
]:
289+
getattr(obj, _connect)(callback)
290+
291+
if _connect:
292+
self._create_roi_view()
293+
else:
294+
if self._roi_view:
295+
self._roi_view.remove()
296+
242297
# ------------------ Model callbacks ------------------
243298

244299
def _fully_synchronize_view(self) -> None:
@@ -261,6 +316,13 @@ def _fully_synchronize_view(self) -> None:
261316
for lut_ctr in self._lut_controllers.values():
262317
lut_ctr.synchronize()
263318
self._update_hist_domain_for_dtype()
319+
self._synchronize_roi()
320+
321+
def _synchronize_roi(self) -> None:
322+
"""Fully re-synchronize the ROI view with the model."""
323+
if self.roi is not None:
324+
self._on_roi_model_bounding_box_changed(self.roi.bounding_box)
325+
self._on_roi_model_visible_changed(self.roi.visible)
264326

265327
def _on_model_visible_axes_changed(self) -> None:
266328
self._view.set_visible_axes(self._data_model.normed_visible_axes)
@@ -288,6 +350,38 @@ def _on_model_channel_mode_changed(self, mode: ChannelMode) -> None:
288350
self._clear_canvas()
289351
self._request_data()
290352

353+
def _on_roi_model_bounding_box_changed(
354+
self, bb: tuple[tuple[float, float], tuple[float, float]]
355+
) -> None:
356+
if self._roi_view is not None:
357+
self._roi_view.set_bounding_box(*bb)
358+
359+
def _on_roi_model_visible_changed(self, visible: bool) -> None:
360+
if self._roi_view is not None:
361+
self._roi_view.set_visible(visible)
362+
363+
def _on_interaction_mode_changed(self, mode: InteractionMode) -> None:
364+
if mode == InteractionMode.CREATE_ROI:
365+
# Create ROI model if needed to store ROI state
366+
if self.roi is None:
367+
self.roi = RectangularROIModel(visible=False)
368+
369+
# Create a new ROI
370+
self._create_roi_view()
371+
372+
def _create_roi_view(self) -> None:
373+
# Remove old ROI view
374+
# TODO: Enable multiple ROIs
375+
if self._roi_view:
376+
self._roi_view.remove()
377+
378+
# Create new ROI view
379+
self._roi_view = self._canvas.add_bounding_box()
380+
# Connect view signals
381+
self._roi_view.boundingBoxChanged.connect(
382+
self._on_roi_view_bounding_box_changed
383+
)
384+
291385
def _clear_canvas(self) -> None:
292386
for lut_ctrl in self._lut_controllers.values():
293387
# self._view.remove_lut_view(lut_ctrl.lut_view)
@@ -309,6 +403,12 @@ def _on_view_reset_zoom_clicked(self) -> None:
309403
"""Reset the zoom level of the canvas."""
310404
self._canvas.set_range()
311405

406+
def _on_roi_view_bounding_box_changed(
407+
self, bb: tuple[tuple[float, float], tuple[float, float]]
408+
) -> None:
409+
if self._roi_model:
410+
self._roi_model.bounding_box = bb
411+
312412
def _on_canvas_mouse_moved(self, event: MouseMoveEvent) -> None:
313413
"""Respond to a mouse move event in the view."""
314414
x, y, _z = self._canvas.canvas_to_world((event.x, event.y))

src/ndv/models/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
ClimsStdDev,
1212
LUTModel,
1313
)
14+
from ._roi_model import RectangularROIModel
1415

1516
__all__ = [
1617
"ArrayDisplayModel",
@@ -23,4 +24,5 @@
2324
"DataWrapper",
2425
"LUTModel",
2526
"NDVModel",
27+
"RectangularROIModel",
2628
]

src/ndv/models/_roi_model.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from __future__ import annotations
2+
3+
from pydantic import field_validator
4+
5+
from ndv.models._base_model import NDVModel
6+
7+
8+
class RectangularROIModel(NDVModel):
9+
"""Representation of an axis-aligned rectangular Region of Interest (ROI).
10+
11+
Attributes
12+
----------
13+
visible : bool
14+
Whether to display this roi.
15+
bounding_box : tuple[tuple[float, float], tuple[float, float]]
16+
The minimum (2D) point and the maximum (2D) point contained within the
17+
region. Using these two points, an axis-aligned bounding box can be
18+
constructed.
19+
"""
20+
21+
visible: bool = True
22+
bounding_box: tuple[tuple[float, float], tuple[float, float]] = ((0, 0), (0, 0))
23+
24+
@field_validator("bounding_box", mode="after")
25+
@classmethod
26+
def _validate_bounding_box(
27+
cls, bb: tuple[tuple[float, float], tuple[float, float]]
28+
) -> tuple[tuple[float, float], tuple[float, float]]:
29+
x1 = min(bb[0][0], bb[1][0])
30+
y1 = min(bb[0][1], bb[1][1])
31+
x2 = max(bb[0][0], bb[1][0])
32+
y2 = max(bb[0][1], bb[1][1])
33+
return ((x1, y1), (x2, y2))

src/ndv/models/_viewer_model.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from enum import Enum, auto
2+
3+
from ndv.models._base_model import NDVModel
4+
5+
6+
class InteractionMode(Enum):
7+
"""An enum defining graphical interaction mechanisms with an array Viewer."""
8+
9+
PAN_ZOOM = auto() # Mode allowing the user to pan and zoom
10+
CREATE_ROI = auto() # Mode where user clicks create ROIs
11+
12+
13+
class ArrayViewerModel(NDVModel):
14+
"""Representation of an array viewer.
15+
16+
TODO: This will likely contain other fields including:
17+
* Dimensionality
18+
* Camera position
19+
* Camera frustum
20+
21+
Parameters
22+
----------
23+
interaction_mode : InteractionMode
24+
Describes the current interaction mode of the Viewer.
25+
"""
26+
27+
interaction_mode: InteractionMode = InteractionMode.PAN_ZOOM

0 commit comments

Comments
 (0)