Skip to content

Commit

Permalink
feat: return 3d support to v2 (#83)
Browse files Browse the repository at this point in the history
* starting on base classes

* more removals

* typing complete

* histogram fixes

* remove protocols

* make work with old viewer

* final?

* fix jupyter

* style(pre-commit.ci): auto fixes [...]

* fix pygfx

* fix tests and typing

* move stuff

* fix circle

* revert

* wip

* fix merge

* kinda working

* remove prints

* fix data overwrite

* add to jupyter

* fix for wx

* 3d working on all frontends

* fix tests

* add note

* add tests

* remove unneeded change

* skip wx test

* fix test

* make hist work when switching to 3d

* fix 3d button in jupyter

* update notebook

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
tlambert03 and pre-commit-ci[bot] authored Jan 16, 2025
1 parent dc59f7b commit c1be08a
Show file tree
Hide file tree
Showing 11 changed files with 270 additions and 48 deletions.
15 changes: 9 additions & 6 deletions examples/notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 5,
"id": "461399b0-e02d-43d9-9ede-c1aa6c180338",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8b582aca26e04cc4a59d755ea986c44f",
"model_id": "d931ccb769154ecaabedf58b566913aa",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -22,7 +23,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3141028c938243b59325ff7ab52e1b3f",
"model_id": "af68f584e2cc479492806c577e70d960",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -45,6 +46,7 @@
{
"cell_type": "code",
"execution_count": 2,
"id": "455ebabe-c2c1-4366-9784-65e45def5aa2",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -53,11 +55,12 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "fb96f975-aa05-4d8a-9f85-f4316293e05d",
"metadata": {},
"outputs": [],
"source": [
"viewer.display_model.default_lut.cmap = \"cubehelix\"\n",
"viewer.display_model.default_lut.cmap = \"magma\"\n",
"viewer.display_model.channel_mode = \"grayscale\""
]
}
Expand All @@ -82,5 +85,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 4
"nbformat_minor": 5
}
48 changes: 32 additions & 16 deletions src/ndv/controllers/_array_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def __init__(
"When display_model is provided, kwargs are be ignored.",
stacklevel=2,
)
self._data_model = _ArrayDataDisplayModel(
data_wrapper=data, display=display_model or ArrayDisplayModel(**kwargs)
)

# mapping of channel keys to their respective controllers
# where None is the default channel
Expand All @@ -72,21 +75,19 @@ def __init__(
frontend_cls = _app.get_array_view_class()
canvas_cls = _app.get_array_canvas_class()
self._canvas = canvas_cls()
self._canvas.set_ndim(2)

self._histogram: HistogramCanvas | None = None
self._view = frontend_cls(self._canvas.frontend_widget())
self._view = frontend_cls(self._canvas.frontend_widget(), self._data_model)

display_model = display_model or ArrayDisplayModel(**kwargs)
self._data_model = _ArrayDataDisplayModel(
data_wrapper=data, display=display_model
)
self._set_model_connected(self._data_model.display)
self._canvas.set_ndim(self.display_model.n_visible_axes)

self._view.currentIndexChanged.connect(self._on_view_current_index_changed)
self._view.resetZoomClicked.connect(self._on_view_reset_zoom_clicked)
self._view.histogramRequested.connect(self._add_histogram)
self._view.channelModeChanged.connect(self._on_view_channel_mode_changed)
self._view.visibleAxesChanged.connect(self._on_view_visible_axes_changed)

self._canvas.mouseMoved.connect(self._on_canvas_mouse_moved)

if self._data_model.data_wrapper is not None:
Expand Down Expand Up @@ -119,7 +120,7 @@ def display_model(self) -> ArrayDisplayModel:
@display_model.setter
def display_model(self, model: ArrayDisplayModel) -> None:
"""Set the ArrayDisplayModel."""
if not isinstance(model, ArrayDisplayModel):
if not isinstance(model, ArrayDisplayModel): # pragma: no cover
raise TypeError("model must be an ArrayDisplayModel")
self._set_model_connected(self._data_model.display, False)
self._data_model.display = model
Expand Down Expand Up @@ -233,6 +234,7 @@ def _fully_synchronize_view(self) -> None:
self._view.create_sliders(self._data_model.normed_data_coords)
self._view.set_channel_mode(display_model.channel_mode)
if self.data is not None:
self._view.set_visible_axes(self._data_model.normed_visible_axes)
self._update_visible_sliders()
if cur_index := display_model.current_index:
self._view.set_current_index(cur_index)
Expand All @@ -248,8 +250,11 @@ def _fully_synchronize_view(self) -> None:
self._update_hist_domain_for_dtype()

def _on_model_visible_axes_changed(self) -> None:
self._view.set_visible_axes(self._data_model.normed_visible_axes)
self._update_visible_sliders()
self._clear_canvas()
self._update_canvas()
self._canvas.set_ndim(self.display_model.n_visible_axes)

def _on_model_current_index_changed(self) -> None:
value = self._data_model.display.current_index
Expand Down Expand Up @@ -283,6 +288,10 @@ def _on_view_current_index_changed(self) -> None:
"""Update the model when slider value changes."""
self._data_model.display.current_index.update(self._view.current_index())

def _on_view_visible_axes_changed(self) -> None:
"""Update the model when the visible axes change."""
self.display_model.visible_axes = self._view.visible_axes() # type: ignore [assignment]

def _on_view_reset_zoom_clicked(self) -> None:
"""Reset the zoom level of the canvas."""
self._canvas.set_range()
Expand Down Expand Up @@ -353,17 +362,24 @@ def _update_canvas(self) -> None:

if not lut_ctrl.handles:
# we don't yet have any handles for this channel
lut_ctrl.add_handle(self._canvas.add_image(data))
if response.n_visible_axes == 2:
handle = self._canvas.add_image(data)
lut_ctrl.add_handle(handle)
elif response.n_visible_axes == 3:
handle = self._canvas.add_volume(data)
lut_ctrl.add_handle(handle)

else:
lut_ctrl.update_texture_data(data)
if self._histogram is not None:
# TODO: once data comes in in chunks, we'll need a proper stateful
# stats object that calculates the histogram incrementally
counts, bin_edges = _calc_hist_bins(data)
# TODO: currently this is updating the histogram on *any*
# channel index... so it doesn't work with composite mode
self._histogram.set_data(counts, bin_edges)
self._histogram.set_range()

if self._histogram is not None:
# TODO: once data comes in in chunks, we'll need a proper stateful
# stats object that calculates the histogram incrementally
counts, bin_edges = _calc_hist_bins(data)
# FIXME: currently this is updating the histogram on *any*
# channel index... so it doesn't work with composite mode
self._histogram.set_data(counts, bin_edges)
self._histogram.set_range()

self._canvas.refresh()

Expand Down
4 changes: 2 additions & 2 deletions src/ndv/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def cells3d() -> np.ndarray:

# this data has been stretched to 16 bit, and lacks certain intensity values
# add a small random integer to each pixel ... so the histogram is not silly
data = (data + np.random.randint(-24, 24, data.shape)).astype(np.uint16)
return data
data = (data + np.random.randint(-24, 24, data.shape)).clip(0, 65535)
return data.astype(np.uint16)


def cat() -> np.ndarray:
Expand Down
38 changes: 23 additions & 15 deletions src/ndv/models/_data_display_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
class DataRequest:
"""Request object for data slicing."""

wrapper: DataWrapper
wrapper: DataWrapper = field(repr=False)
index: Mapping[int, Union[int, slice]]
visible_axes: tuple[int, ...]
channel_axis: Optional[int]
Expand All @@ -28,10 +28,19 @@ class DataResponse:
"""Response object for data requests."""

data: np.ndarray = field(repr=False)
shape: tuple[int, ...] = field(init=False)
dtype: np.dtype = field(init=False)
channel_key: Optional[int]
n_visible_axes: int
request: Optional[DataRequest] = None

def __post_init__(self) -> None:
self.shape = self.data.shape
self.dtype = self.data.dtype


# NOTE: nobody particularly likes this class. It does important stuff, but we're
# not yet sure where this logic belongs.
class _ArrayDataDisplayModel(NDVModel):
"""Utility class combining ArrayDisplayModel model with a DataWrapper.
Expand Down Expand Up @@ -154,14 +163,13 @@ def current_slice_requests(self) -> list[DataRequest]:
if isinstance(val, int):
requested_slice[ax] = slice(val, val + 1)

return [
DataRequest(
wrapper=self.data_wrapper,
index=requested_slice,
visible_axes=self.normed_visible_axes,
channel_axis=c_ax,
)
]
request = DataRequest(
wrapper=self.data_wrapper,
index=requested_slice,
visible_axes=self.normed_visible_axes,
channel_axis=c_ax,
)
return [request]

# TODO: make async
def request_sliced_data(self) -> list[Future[DataResponse]]:
Expand Down Expand Up @@ -192,13 +200,13 @@ def request_sliced_data(self) -> list[Future[DataResponse]]:
ch_keepdims = (slice(None),) * cast(int, ch_ax) + (i,) + (None,)
ch_data = data[ch_keepdims]
future = Future[DataResponse]()
future.set_result(
DataResponse(
data=ch_data.transpose(*t_dims).squeeze(),
channel_key=i,
request=req,
)
response = DataResponse(
data=ch_data.transpose(*t_dims).squeeze(),
channel_key=i,
n_visible_axes=len(vis_ax),
request=req,
)
future.set_result(response)
futures.append(future)

return futures
21 changes: 21 additions & 0 deletions src/ndv/models/_data_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class DataWrapper(Generic[ArrayT], ABC):
PRIORITY: ClassVar[int] = 50
# These names will be checked when looking for a channel axis
COMMON_CHANNEL_NAMES: ClassVar[Container[str]] = ("channel", "ch", "c")
COMMON_Z_AXIS_NAMES: ClassVar[Container[str]] = ("z", "depth", "focus")

# Maximum dimension size consider when guessing the channel axis
MAX_CHANNELS = 16

Expand Down Expand Up @@ -159,6 +161,9 @@ def sizes(self) -> Mapping[Hashable, int]:
"""Return the sizes of the dimensions."""
return {dim: len(self.coords[dim]) for dim in self.dims}

# these guess_x methods may change in the future to become more agnostic to the
# dimension name/semantics that they are guessing.

def guess_channel_axis(self) -> Hashable | None:
"""Return the (best guess) axis name for the channel dimension."""
# for arrays with labeled dimensions,
Expand All @@ -172,6 +177,22 @@ def guess_channel_axis(self) -> Hashable | None:
# otherwise use the smallest dimension as the channel axis
return min(sizes, key=sizes.get) # type: ignore [arg-type]

def guess_z_axis(self) -> Hashable | None:
"""Return the (best guess) axis name for the z (3rd spatial) dimension."""
sizes = self.sizes()
ch = self.guess_channel_axis()
for dimkey in sizes:
if str(dimkey).lower() in self.COMMON_Z_AXIS_NAMES:
if (normed := self.normalized_axis_key(dimkey)) != ch:
return normed

# otherwise return the LAST axis that is neither in the last two dimensions
# or the channel axis guess
return next(
(self.normalized_axis_key(x) for x in reversed(self.dims[:-2]) if x != ch),
None,
)

def summary_info(self) -> str:
"""Return info label with information about the data."""
package = getattr(self._data, "__module__", "").split(".")[0]
Expand Down
49 changes: 47 additions & 2 deletions src/ndv/views/_jupyter/_array_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from vispy.app.backends import _jupyter_rfb

from ndv._types import AxisKey
from ndv.models._data_display_model import _ArrayDataDisplayModel

# not entirely sure why it's necessary to specifically annotat signals as : PSignal
# i think it has to do with type variance?
Expand Down Expand Up @@ -125,10 +126,14 @@ def frontend_widget(self) -> Any:

class JupyterArrayView(ArrayView):
def __init__(
self, canvas_widget: _jupyter_rfb.CanvasBackend, **kwargs: Any
self,
canvas_widget: _jupyter_rfb.CanvasBackend,
data_model: _ArrayDataDisplayModel,
) -> None:
# WIDGETS
self._data_model = data_model
self._canvas_widget = canvas_widget
self._visible_axes: Sequence[AxisKey] = []

self._sliders: dict[Hashable, widgets.IntSlider] = {}
self._slider_box = widgets.VBox([], layout=widgets.Layout(width="100%"))
Expand All @@ -146,21 +151,36 @@ def __init__(
self._channel_mode_combo.layout.align_self = "flex-end"
self._channel_mode_combo.observe(self._on_channel_mode_changed, names="value")

self._ndims_btn = widgets.ToggleButton(
value=False,
description="3D",
button_style="", # 'success', 'info', 'warning', 'danger' or ''
tooltip="View in 3D",
icon="check",
layout=widgets.Layout(width="60px"),
)
self._ndims_btn.observe(self._on_ndims_toggled, names="value")

# LAYOUT

try:
width = getattr(canvas_widget, "css_width", "600px").replace("px", "")
width = f"{int(width) + 4}px"
except Exception:
width = "604px"

btns = widgets.HBox(
[self._channel_mode_combo, self._ndims_btn],
layout=widgets.Layout(justify_content="flex-end"),
)
self.layout = widgets.VBox(
[
self._data_info_label,
self._canvas_widget,
self._hover_info_label,
self._slider_box,
self._luts_box,
self._channel_mode_combo,
btns,
],
layout=widgets.Layout(width=width),
)
Expand Down Expand Up @@ -279,5 +299,30 @@ def set_visible(self, visible: bool) -> None:
else:
display.clear_output() # type: ignore [no-untyped-call]

def visible_axes(self) -> Sequence[AxisKey]:
return self._visible_axes

def set_visible_axes(self, axes: Sequence[AxisKey]) -> None:
self._visible_axes = tuple(axes)
self._ndims_btn.value = len(axes) == 3

def _on_ndims_toggled(self, change: dict[str, Any]) -> None:
if len(self._visible_axes) > 2:
if not change["new"]: # is now 2D
self._visible_axes = self._visible_axes[-2:]
else:
z_ax = None
if wrapper := self._data_model.data_wrapper:
z_ax = wrapper.guess_z_axis()
if z_ax is None:
# get the last slider that is not in visible axes
z_ax = next(
ax for ax in reversed(self._sliders) if ax not in self._visible_axes
)
self._visible_axes = (z_ax, *self._visible_axes)
# TODO: a future PR may decide to set this on the model directly...
# since we now have access to it.
self.visibleAxesChanged.emit()

def close(self) -> None:
self.layout.close()
Loading

0 comments on commit c1be08a

Please sign in to comment.