Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: return 3d support to v2 #83

Merged
merged 41 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
fa9fb62
starting on base classes
tlambert03 Dec 11, 2024
23045d6
more removals
tlambert03 Dec 11, 2024
ea54a06
typing complete
tlambert03 Dec 11, 2024
06e6b08
histogram fixes
tlambert03 Dec 11, 2024
96ada85
remove protocols
tlambert03 Dec 11, 2024
ba857ac
make work with old viewer
tlambert03 Dec 11, 2024
0b986db
final?
tlambert03 Dec 11, 2024
20437f0
fix jupyter
tlambert03 Dec 11, 2024
8e22916
style(pre-commit.ci): auto fixes [...]
pre-commit-ci[bot] Dec 11, 2024
b54237a
fix pygfx
tlambert03 Dec 11, 2024
c2ac039
Merge branch 'v2-bases-instead-of-protocols' of https://github.com/tl…
tlambert03 Dec 11, 2024
45ab81c
fix tests and typing
tlambert03 Dec 11, 2024
ce9453e
move stuff
tlambert03 Dec 11, 2024
ffa085c
fix circle
tlambert03 Dec 11, 2024
5a55ea7
revert
tlambert03 Dec 11, 2024
7b06611
wip
tlambert03 Dec 11, 2024
f4f591c
Merge branch 'v2-mvc' into v2-3d
tlambert03 Dec 14, 2024
ab06356
fix merge
tlambert03 Dec 14, 2024
f958cad
kinda working
tlambert03 Dec 14, 2024
8f257e9
remove prints
tlambert03 Dec 14, 2024
22292d5
fix data overwrite
tlambert03 Dec 14, 2024
0060a2b
add to jupyter
tlambert03 Dec 14, 2024
480e2ad
Merge branch 'v2-mvc' into v2-3d
tlambert03 Dec 16, 2024
ce4389d
Merge branch 'v2-mvc' into v2-3d
tlambert03 Dec 16, 2024
4659f90
Merge branch 'v2-mvc' into v2-3d
tlambert03 Dec 17, 2024
d05c4cf
Merge branch 'v2-mvc' into v2-3d
tlambert03 Dec 19, 2024
6f73d42
fix for wx
tlambert03 Dec 19, 2024
a0837e7
Merge branch 'main' into v2-3d
tlambert03 Jan 9, 2025
c7279b6
Merge branch 'main' into v2-3d
tlambert03 Jan 11, 2025
1668830
3d working on all frontends
tlambert03 Jan 11, 2025
3d5078c
fix tests
tlambert03 Jan 11, 2025
90e5842
add note
tlambert03 Jan 11, 2025
2737b70
add tests
tlambert03 Jan 11, 2025
f414852
remove unneeded change
tlambert03 Jan 11, 2025
a7551a3
skip wx test
tlambert03 Jan 11, 2025
954339d
Merge branch 'main' into v2-3d
tlambert03 Jan 12, 2025
8214245
Merge branch 'main' into v2-3d
tlambert03 Jan 16, 2025
d5af7e1
fix test
tlambert03 Jan 16, 2025
9028d63
make hist work when switching to 3d
tlambert03 Jan 16, 2025
35ae838
fix 3d button in jupyter
tlambert03 Jan 16, 2025
d714dec
update notebook
tlambert03 Jan 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions examples/notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 5,
"id": "461399b0-e02d-43d9-9ede-c1aa6c180338",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8b582aca26e04cc4a59d755ea986c44f",
"model_id": "d931ccb769154ecaabedf58b566913aa",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -22,7 +23,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3141028c938243b59325ff7ab52e1b3f",
"model_id": "af68f584e2cc479492806c577e70d960",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -45,6 +46,7 @@
{
"cell_type": "code",
"execution_count": 2,
"id": "455ebabe-c2c1-4366-9784-65e45def5aa2",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -53,11 +55,12 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "fb96f975-aa05-4d8a-9f85-f4316293e05d",
"metadata": {},
"outputs": [],
"source": [
"viewer.display_model.default_lut.cmap = \"cubehelix\"\n",
"viewer.display_model.default_lut.cmap = \"magma\"\n",
"viewer.display_model.channel_mode = \"grayscale\""
]
}
Expand All @@ -82,5 +85,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 4
"nbformat_minor": 5
}
48 changes: 32 additions & 16 deletions src/ndv/controllers/_array_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def __init__(
"When display_model is provided, kwargs are be ignored.",
stacklevel=2,
)
self._data_model = _ArrayDataDisplayModel(
data_wrapper=data, display=display_model or ArrayDisplayModel(**kwargs)
)

# mapping of channel keys to their respective controllers
# where None is the default channel
Expand All @@ -72,21 +75,19 @@ def __init__(
frontend_cls = _app.get_array_view_class()
canvas_cls = _app.get_array_canvas_class()
self._canvas = canvas_cls()
self._canvas.set_ndim(2)

self._histogram: HistogramCanvas | None = None
self._view = frontend_cls(self._canvas.frontend_widget())
self._view = frontend_cls(self._canvas.frontend_widget(), self._data_model)

display_model = display_model or ArrayDisplayModel(**kwargs)
self._data_model = _ArrayDataDisplayModel(
data_wrapper=data, display=display_model
)
self._set_model_connected(self._data_model.display)
self._canvas.set_ndim(self.display_model.n_visible_axes)

self._view.currentIndexChanged.connect(self._on_view_current_index_changed)
self._view.resetZoomClicked.connect(self._on_view_reset_zoom_clicked)
self._view.histogramRequested.connect(self._add_histogram)
self._view.channelModeChanged.connect(self._on_view_channel_mode_changed)
self._view.visibleAxesChanged.connect(self._on_view_visible_axes_changed)

self._canvas.mouseMoved.connect(self._on_canvas_mouse_moved)

if self._data_model.data_wrapper is not None:
Expand Down Expand Up @@ -119,7 +120,7 @@ def display_model(self) -> ArrayDisplayModel:
@display_model.setter
def display_model(self, model: ArrayDisplayModel) -> None:
"""Set the ArrayDisplayModel."""
if not isinstance(model, ArrayDisplayModel):
if not isinstance(model, ArrayDisplayModel): # pragma: no cover
raise TypeError("model must be an ArrayDisplayModel")
self._set_model_connected(self._data_model.display, False)
self._data_model.display = model
Expand Down Expand Up @@ -233,6 +234,7 @@ def _fully_synchronize_view(self) -> None:
self._view.create_sliders(self._data_model.normed_data_coords)
self._view.set_channel_mode(display_model.channel_mode)
if self.data is not None:
self._view.set_visible_axes(self._data_model.normed_visible_axes)
self._update_visible_sliders()
if cur_index := display_model.current_index:
self._view.set_current_index(cur_index)
Expand All @@ -248,8 +250,11 @@ def _fully_synchronize_view(self) -> None:
self._update_hist_domain_for_dtype()

def _on_model_visible_axes_changed(self) -> None:
self._view.set_visible_axes(self._data_model.normed_visible_axes)
self._update_visible_sliders()
self._clear_canvas()
self._update_canvas()
self._canvas.set_ndim(self.display_model.n_visible_axes)

def _on_model_current_index_changed(self) -> None:
value = self._data_model.display.current_index
Expand Down Expand Up @@ -283,6 +288,10 @@ def _on_view_current_index_changed(self) -> None:
"""Update the model when slider value changes."""
self._data_model.display.current_index.update(self._view.current_index())

def _on_view_visible_axes_changed(self) -> None:
"""Update the model when the visible axes change."""
self.display_model.visible_axes = self._view.visible_axes() # type: ignore [assignment]

def _on_view_reset_zoom_clicked(self) -> None:
"""Reset the zoom level of the canvas."""
self._canvas.set_range()
Expand Down Expand Up @@ -353,17 +362,24 @@ def _update_canvas(self) -> None:

if not lut_ctrl.handles:
# we don't yet have any handles for this channel
lut_ctrl.add_handle(self._canvas.add_image(data))
if response.n_visible_axes == 2:
handle = self._canvas.add_image(data)
lut_ctrl.add_handle(handle)
elif response.n_visible_axes == 3:
handle = self._canvas.add_volume(data)
lut_ctrl.add_handle(handle)

else:
lut_ctrl.update_texture_data(data)
if self._histogram is not None:
# TODO: once data comes in in chunks, we'll need a proper stateful
# stats object that calculates the histogram incrementally
counts, bin_edges = _calc_hist_bins(data)
# TODO: currently this is updating the histogram on *any*
# channel index... so it doesn't work with composite mode
self._histogram.set_data(counts, bin_edges)
self._histogram.set_range()

if self._histogram is not None:
# TODO: once data comes in in chunks, we'll need a proper stateful
# stats object that calculates the histogram incrementally
counts, bin_edges = _calc_hist_bins(data)
# FIXME: currently this is updating the histogram on *any*
# channel index... so it doesn't work with composite mode
self._histogram.set_data(counts, bin_edges)
self._histogram.set_range()

self._canvas.refresh()

Expand Down
4 changes: 2 additions & 2 deletions src/ndv/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def cells3d() -> np.ndarray:

# this data has been stretched to 16 bit, and lacks certain intensity values
# add a small random integer to each pixel ... so the histogram is not silly
data = (data + np.random.randint(-24, 24, data.shape)).astype(np.uint16)
return data
data = (data + np.random.randint(-24, 24, data.shape)).clip(0, 65535)
return data.astype(np.uint16)


def cat() -> np.ndarray:
Expand Down
38 changes: 23 additions & 15 deletions src/ndv/models/_data_display_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
class DataRequest:
"""Request object for data slicing."""

wrapper: DataWrapper
wrapper: DataWrapper = field(repr=False)
index: Mapping[int, Union[int, slice]]
visible_axes: tuple[int, ...]
channel_axis: Optional[int]
Expand All @@ -28,10 +28,19 @@ class DataResponse:
"""Response object for data requests."""

data: np.ndarray = field(repr=False)
shape: tuple[int, ...] = field(init=False)
dtype: np.dtype = field(init=False)
channel_key: Optional[int]
n_visible_axes: int
request: Optional[DataRequest] = None

def __post_init__(self) -> None:
self.shape = self.data.shape
self.dtype = self.data.dtype


# NOTE: nobody particularly likes this class. It does important stuff, but we're
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Guilty as charged 🙈

# not yet sure where this logic belongs.
class _ArrayDataDisplayModel(NDVModel):
"""Utility class combining ArrayDisplayModel model with a DataWrapper.

Expand Down Expand Up @@ -154,14 +163,13 @@ def current_slice_requests(self) -> list[DataRequest]:
if isinstance(val, int):
requested_slice[ax] = slice(val, val + 1)

return [
DataRequest(
wrapper=self.data_wrapper,
index=requested_slice,
visible_axes=self.normed_visible_axes,
channel_axis=c_ax,
)
]
request = DataRequest(
wrapper=self.data_wrapper,
index=requested_slice,
visible_axes=self.normed_visible_axes,
channel_axis=c_ax,
)
return [request]

# TODO: make async
def request_sliced_data(self) -> list[Future[DataResponse]]:
Expand Down Expand Up @@ -192,13 +200,13 @@ def request_sliced_data(self) -> list[Future[DataResponse]]:
ch_keepdims = (slice(None),) * cast(int, ch_ax) + (i,) + (None,)
ch_data = data[ch_keepdims]
future = Future[DataResponse]()
future.set_result(
DataResponse(
data=ch_data.transpose(*t_dims).squeeze(),
channel_key=i,
request=req,
)
response = DataResponse(
data=ch_data.transpose(*t_dims).squeeze(),
channel_key=i,
n_visible_axes=len(vis_ax),
request=req,
)
future.set_result(response)
futures.append(future)

return futures
21 changes: 21 additions & 0 deletions src/ndv/models/_data_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@
PRIORITY: ClassVar[int] = 50
# These names will be checked when looking for a channel axis
COMMON_CHANNEL_NAMES: ClassVar[Container[str]] = ("channel", "ch", "c")
COMMON_Z_AXIS_NAMES: ClassVar[Container[str]] = ("z", "depth", "focus")

# Maximum dimension size consider when guessing the channel axis
MAX_CHANNELS = 16

Expand Down Expand Up @@ -159,6 +161,9 @@
"""Return the sizes of the dimensions."""
return {dim: len(self.coords[dim]) for dim in self.dims}

# these guess_x methods may change in the future to become more agnostic to the
# dimension name/semantics that they are guessing.

def guess_channel_axis(self) -> Hashable | None:
"""Return the (best guess) axis name for the channel dimension."""
# for arrays with labeled dimensions,
Expand All @@ -172,6 +177,22 @@
# otherwise use the smallest dimension as the channel axis
return min(sizes, key=sizes.get) # type: ignore [arg-type]

def guess_z_axis(self) -> Hashable | None:
"""Return the (best guess) axis name for the z (3rd spatial) dimension."""
sizes = self.sizes()
ch = self.guess_channel_axis()
for dimkey in sizes:
if str(dimkey).lower() in self.COMMON_Z_AXIS_NAMES:
if (normed := self.normalized_axis_key(dimkey)) != ch:
return normed

Check warning on line 187 in src/ndv/models/_data_wrapper.py

View check run for this annotation

Codecov / codecov/patch

src/ndv/models/_data_wrapper.py#L182-L187

Added lines #L182 - L187 were not covered by tests

# otherwise return the LAST axis that is neither in the last two dimensions
# or the channel axis guess
return next(

Check warning on line 191 in src/ndv/models/_data_wrapper.py

View check run for this annotation

Codecov / codecov/patch

src/ndv/models/_data_wrapper.py#L191

Added line #L191 was not covered by tests
(self.normalized_axis_key(x) for x in reversed(self.dims[:-2]) if x != ch),
None,
)

def summary_info(self) -> str:
"""Return info label with information about the data."""
package = getattr(self._data, "__module__", "").split(".")[0]
Expand Down
49 changes: 47 additions & 2 deletions src/ndv/views/_jupyter/_array_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from vispy.app.backends import _jupyter_rfb

from ndv._types import AxisKey
from ndv.models._data_display_model import _ArrayDataDisplayModel

# not entirely sure why it's necessary to specifically annotat signals as : PSignal
# i think it has to do with type variance?
Expand Down Expand Up @@ -125,10 +126,14 @@

class JupyterArrayView(ArrayView):
def __init__(
self, canvas_widget: _jupyter_rfb.CanvasBackend, **kwargs: Any
self,
canvas_widget: _jupyter_rfb.CanvasBackend,
data_model: _ArrayDataDisplayModel,
) -> None:
# WIDGETS
self._data_model = data_model
self._canvas_widget = canvas_widget
self._visible_axes: Sequence[AxisKey] = []

self._sliders: dict[Hashable, widgets.IntSlider] = {}
self._slider_box = widgets.VBox([], layout=widgets.Layout(width="100%"))
Expand All @@ -146,21 +151,36 @@
self._channel_mode_combo.layout.align_self = "flex-end"
self._channel_mode_combo.observe(self._on_channel_mode_changed, names="value")

self._ndims_btn = widgets.ToggleButton(
value=False,
description="3D",
button_style="", # 'success', 'info', 'warning', 'danger' or ''
tooltip="View in 3D",
icon="check",
layout=widgets.Layout(width="60px"),
)
self._ndims_btn.observe(self._on_ndims_toggled, names="value")

# LAYOUT

try:
width = getattr(canvas_widget, "css_width", "600px").replace("px", "")
width = f"{int(width) + 4}px"
except Exception:
width = "604px"

btns = widgets.HBox(
[self._channel_mode_combo, self._ndims_btn],
layout=widgets.Layout(justify_content="flex-end"),
)
self.layout = widgets.VBox(
[
self._data_info_label,
self._canvas_widget,
self._hover_info_label,
self._slider_box,
self._luts_box,
self._channel_mode_combo,
btns,
],
layout=widgets.Layout(width=width),
)
Expand Down Expand Up @@ -279,5 +299,30 @@
else:
display.clear_output() # type: ignore [no-untyped-call]

def visible_axes(self) -> Sequence[AxisKey]:
return self._visible_axes

def set_visible_axes(self, axes: Sequence[AxisKey]) -> None:
self._visible_axes = tuple(axes)
self._ndims_btn.value = len(axes) == 3

def _on_ndims_toggled(self, change: dict[str, Any]) -> None:
if len(self._visible_axes) > 2:
if not change["new"]: # is now 2D
self._visible_axes = self._visible_axes[-2:]

Check warning on line 312 in src/ndv/views/_jupyter/_array_view.py

View check run for this annotation

Codecov / codecov/patch

src/ndv/views/_jupyter/_array_view.py#L312

Added line #L312 was not covered by tests
else:
z_ax = None
if wrapper := self._data_model.data_wrapper:
z_ax = wrapper.guess_z_axis()
if z_ax is None:

Check warning on line 317 in src/ndv/views/_jupyter/_array_view.py

View check run for this annotation

Codecov / codecov/patch

src/ndv/views/_jupyter/_array_view.py#L314-L317

Added lines #L314 - L317 were not covered by tests
# get the last slider that is not in visible axes
z_ax = next(

Check warning on line 319 in src/ndv/views/_jupyter/_array_view.py

View check run for this annotation

Codecov / codecov/patch

src/ndv/views/_jupyter/_array_view.py#L319

Added line #L319 was not covered by tests
ax for ax in reversed(self._sliders) if ax not in self._visible_axes
)
self._visible_axes = (z_ax, *self._visible_axes)

Check warning on line 322 in src/ndv/views/_jupyter/_array_view.py

View check run for this annotation

Codecov / codecov/patch

src/ndv/views/_jupyter/_array_view.py#L322

Added line #L322 was not covered by tests
# TODO: a future PR may decide to set this on the model directly...
# since we now have access to it.
self.visibleAxesChanged.emit()

def close(self) -> None:
self.layout.close()
Loading
Loading