Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Rectangular ROIs #114

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions src/ndv/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections.abc import Hashable, Sequence
from contextlib import suppress
from enum import Enum, IntFlag, auto
from functools import lru_cache
from typing import TYPE_CHECKING, Annotated, Any, NamedTuple, cast

from pydantic import PlainSerializer, PlainValidator
Expand All @@ -13,6 +14,7 @@
if TYPE_CHECKING:
from qtpy.QtCore import Qt
from qtpy.QtWidgets import QWidget
from wx import Cursor

from ndv.views.bases import Viewable

Expand Down Expand Up @@ -56,33 +58,36 @@ class MouseButton(IntFlag):
LEFT = auto()
MIDDLE = auto()
RIGHT = auto()
NONE = auto()


class MouseMoveEvent(NamedTuple):
"""Event emitted when the user moves the cursor."""

x: float
y: float
btn: MouseButton = MouseButton.NONE


class MousePressEvent(NamedTuple):
"""Event emitted when mouse button is pressed."""

x: float
y: float
btn: MouseButton = MouseButton.LEFT
btn: MouseButton


class MouseReleaseEvent(NamedTuple):
"""Event emitted when mouse button is released."""

x: float
y: float
btn: MouseButton = MouseButton.LEFT
btn: MouseButton


class CursorType(Enum):
DEFAULT = "default"
CROSS = "cross"
V_ARROW = "v_arrow"
H_ARROW = "h_arrow"
ALL_ARROW = "all_arrow"
Expand All @@ -101,9 +106,46 @@ def to_qt(self) -> Qt.CursorShape:

return {
CursorType.DEFAULT: Qt.CursorShape.ArrowCursor,
CursorType.CROSS: Qt.CursorShape.CrossCursor,
CursorType.V_ARROW: Qt.CursorShape.SizeVerCursor,
CursorType.H_ARROW: Qt.CursorShape.SizeHorCursor,
CursorType.ALL_ARROW: Qt.CursorShape.SizeAllCursor,
CursorType.BDIAG_ARROW: Qt.CursorShape.SizeBDiagCursor,
CursorType.FDIAG_ARROW: Qt.CursorShape.SizeFDiagCursor,
}[self]

def to_jupyter(self) -> str:
"""Converts CursorType to jupyter cursor strings."""
return {
CursorType.DEFAULT: "default",
CursorType.CROSS: "crosshair",
CursorType.V_ARROW: "ns-resize",
CursorType.H_ARROW: "ew-resize",
CursorType.ALL_ARROW: "move",
CursorType.BDIAG_ARROW: "nesw-resize",
CursorType.FDIAG_ARROW: "nwse-resize",
}[self]

@lru_cache
def to_wx(self) -> Cursor:
"""Converts CursorType to jupyter cursor strings."""
from wx import (
CURSOR_ARROW,
CURSOR_CROSS,
CURSOR_SIZENESW,
CURSOR_SIZENS,
CURSOR_SIZENWSE,
CURSOR_SIZEWE,
CURSOR_SIZING,
Cursor,
)

return {
CursorType.DEFAULT: Cursor(CURSOR_ARROW),
CursorType.CROSS: Cursor(CURSOR_CROSS),
CursorType.V_ARROW: Cursor(CURSOR_SIZENS),
CursorType.H_ARROW: Cursor(CURSOR_SIZEWE),
CursorType.ALL_ARROW: Cursor(CURSOR_SIZING),
CursorType.BDIAG_ARROW: Cursor(CURSOR_SIZENESW),
CursorType.FDIAG_ARROW: Cursor(CURSOR_SIZENWSE),
}[self]
88 changes: 86 additions & 2 deletions src/ndv/controllers/_array_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from ndv.controllers._channel_controller import ChannelController
from ndv.models import ArrayDisplayModel, ChannelMode, DataWrapper, LUTModel
from ndv.models._data_display_model import DataResponse, _ArrayDataDisplayModel
from ndv.models._roi_model import RectangularROIModel
from ndv.models._viewer_model import ArrayViewerModel, InteractionMode
from ndv.views import _app

if TYPE_CHECKING:
Expand All @@ -20,6 +22,7 @@
from ndv._types import MouseMoveEvent
from ndv.models._array_display_model import ArrayDisplayModelKwargs
from ndv.views.bases import HistogramCanvas
from ndv.views.bases._graphics._canvas_elements import RectangularROI

LutKey: TypeAlias = int | None

Expand Down Expand Up @@ -68,6 +71,11 @@ def __init__(
self._data_model = _ArrayDataDisplayModel(
data_wrapper=data, display=display_model or ArrayDisplayModel(**kwargs)
)
self._viewer_model = ArrayViewerModel()
self._viewer_model.events.interaction_mode.connect(
self._on_interaction_mode_changed
)
self._roi_model: RectangularROIModel | None = RectangularROIModel(visible=False)

app = _app.gui_frontend()

Expand All @@ -87,12 +95,17 @@ def __init__(
# get and create the front-end and canvas classes
frontend_cls = _app.get_array_view_class()
canvas_cls = _app.get_array_canvas_class()
self._canvas = canvas_cls()
self._canvas = canvas_cls(self._viewer_model)

self._histogram: HistogramCanvas | None = None
self._view = frontend_cls(self._canvas.frontend_widget(), self._data_model)
self._view = frontend_cls(
self._canvas.frontend_widget(), self._data_model, self._viewer_model
)

self._roi_view: RectangularROI | None = None

self._set_model_connected(self._data_model.display)
self._set_roi_model_connected(self._roi_model)
self._canvas.set_ndim(self.display_model.n_visible_axes)

self._view.currentIndexChanged.connect(self._on_view_current_index_changed)
Expand Down Expand Up @@ -162,6 +175,19 @@ def data(self, data: Any) -> None:
self._data_model.data_wrapper = DataWrapper.create(data)
self._fully_synchronize_view()

@property
def roi(self) -> RectangularROIModel | None:
return self._roi_model

@roi.setter
def roi(self, roi_model: RectangularROIModel | None) -> None:
if self._roi_model is not None:
self._set_roi_model_connected(self._roi_model, False)
self._roi_model = roi_model
if self._roi_model is not None:
self._set_roi_model_connected(self._roi_model)
self._fully_synchronize_view()

def show(self) -> None:
"""Show the viewer."""
self._view.set_visible(True)
Expand Down Expand Up @@ -238,6 +264,22 @@ def _set_model_connected(
]:
getattr(obj, _connect)(callback)

def _set_roi_model_connected(
self, model: RectangularROIModel, connect: bool = True
) -> None:
"""Connect or disconnect the model to/from the viewer.

We do this in a single method so that we are sure to connect and disconnect
the same events in the same order. (but it's kinda ugly)
"""
_connect = "connect" if connect else "disconnect"

for obj, callback in [
(model.events.bounding_box, self._on_roi_model_bounding_box_changed),
(model.events.visible, self._on_roi_model_visible_changed),
]:
getattr(obj, _connect)(callback)

# ------------------ Model callbacks ------------------

def _fully_synchronize_view(self) -> None:
Expand All @@ -261,6 +303,9 @@ def _fully_synchronize_view(self) -> None:
for lut_ctr in self._lut_controllers.values():
lut_ctr._update_view_from_model()
self._update_hist_domain_for_dtype()
if self.roi is not None:
self._on_roi_model_bounding_box_changed(self.roi.bounding_box)
self._on_roi_model_visible_changed(self.roi.visible)

def _on_model_visible_axes_changed(self) -> None:
self._view.set_visible_axes(self._data_model.normed_visible_axes)
Expand Down Expand Up @@ -288,6 +333,39 @@ def _on_model_channel_mode_changed(self, mode: ChannelMode) -> None:
self._clear_canvas()
self._request_data()

def _on_roi_model_bounding_box_changed(
self, bb: tuple[tuple[float, float], tuple[float, float]]
) -> None:
if self._roi_view is None:
self._roi_view = self._canvas.add_bounding_box()
# HACK
self._roi_view.set_visible(True)
self._roi_view.boundingBoxChanged.connect(
self._on_roi_view_bounding_box_changed
)
self._roi_view.set_bounding_box(*bb)

def _on_roi_model_visible_changed(self, visible: bool) -> None:
if self._roi_view is None:
self._roi_view = self._canvas.add_bounding_box()
# HACK
self._roi_view.set_visible(True)
self._roi_view.boundingBoxChanged.connect(
self._on_roi_view_bounding_box_changed
)
self._roi_view.set_visible(visible)

def _on_interaction_mode_changed(self, mode: InteractionMode) -> None:
# TODO: Unify with _on_roi_model_bounding_box_changed
if mode == InteractionMode.CREATE_ROI:
if self._roi_view:
self._roi_view.remove()
self._roi_view = self._canvas.add_bounding_box()
# HACK
self._roi_view.boundingBoxChanged.connect(
self._on_roi_view_bounding_box_changed
)

def _clear_canvas(self) -> None:
for lut_ctrl in self._lut_controllers.values():
# self._view.remove_lut_view(lut_ctrl.lut_view)
Expand All @@ -309,6 +387,12 @@ def _on_view_reset_zoom_clicked(self) -> None:
"""Reset the zoom level of the canvas."""
self._canvas.set_range()

def _on_roi_view_bounding_box_changed(
self, bb: tuple[tuple[float, float], tuple[float, float]]
) -> None:
if self._roi_model:
self._roi_model.bounding_box = bb

def _on_canvas_mouse_moved(self, event: MouseMoveEvent) -> None:
"""Respond to a mouse move event in the view."""
x, y, _z = self._canvas.canvas_to_world((event.x, event.y))
Expand Down
2 changes: 2 additions & 0 deletions src/ndv/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ClimsStdDev,
LUTModel,
)
from ._roi_model import RectangularROIModel

__all__ = [
"ArrayDisplayModel",
Expand All @@ -23,4 +24,5 @@
"DataWrapper",
"LUTModel",
"NDVModel",
"RectangularROIModel",
]
39 changes: 39 additions & 0 deletions src/ndv/models/_roi_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from pydantic import field_validator

from ndv.models._base_model import NDVModel

if TYPE_CHECKING:
from typing import Any


class RectangularROIModel(NDVModel):
"""Representation of how to display an axis-aligned rectangular Region of Interest (ROI).

Parameters
----------
visible : bool
Whether to display this roi.
bounding_box: tuple[Sequence[float], Sequence[float]]
The minimum point and the maximum point contained within the region.
Using these two points, an axis-aligned bounding box can be constructed.
"""

visible: bool = True
bounding_box: tuple[tuple[float, float], tuple[float, float]] = ((0, 0), (0, 0))

@field_validator("bounding_box")
@classmethod
def _validate_bounding_box(
cls, bb: Any
) -> tuple[tuple[float, float], tuple[float, float]]:
if not isinstance(bb, tuple):
raise ValueError(f"{bb} not a tuple of points!")
x1 = min(bb[0][0], bb[1][0])
y1 = min(bb[0][1], bb[1][1])
x2 = max(bb[0][0], bb[1][0])
y2 = max(bb[0][1], bb[1][1])
return ((x1, y1), (x2, y2))
27 changes: 27 additions & 0 deletions src/ndv/models/_viewer_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from enum import Enum, auto

from ndv.models._base_model import NDVModel


class InteractionMode(Enum):
"""An enum defining graphical interaction mechanisms with an array Viewer."""

PAN_ZOOM = auto() # Mode allowing the user to pan and zoom
CREATE_ROI = auto() # Mode where user clicks create ROIs


class ArrayViewerModel(NDVModel):
"""Representation of an array viewer.

TODO: This will likely contain other fields including:
* Dimensionality
* Camera position
* Camera frustum

Parameters
----------
interaction_mode : InteractionMode
Describes the current interaction mode of the Viewer.
"""

interaction_mode: InteractionMode = InteractionMode.PAN_ZOOM
Loading
Loading