diff --git a/pyproject.toml b/pyproject.toml index 37fec96a..a5a62016 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -158,6 +158,8 @@ disallow_any_generics = false disallow_subclassing_any = false show_error_codes = true pretty = true +# sometimes experiencing a pydantic mypy plugin bug +# dealing with recursive models, like ndv.models._scene.nodes.Node plugins = ["pydantic.mypy"] [[tool.mypy.overrides]] diff --git a/src/ndv/_types.py b/src/ndv/_types.py index f39c68db..91ed06a5 100644 --- a/src/ndv/_types.py +++ b/src/ndv/_types.py @@ -7,6 +7,7 @@ from enum import Enum, IntFlag, auto from typing import TYPE_CHECKING, Annotated, Any, NamedTuple, cast +import numpy.typing as npt from pydantic import PlainSerializer, PlainValidator from typing_extensions import TypeAlias @@ -107,3 +108,27 @@ def to_qt(self) -> Qt.CursorShape: CursorType.BDIAG_ARROW: Qt.CursorShape.SizeBDiagCursor, CursorType.FDIAG_ARROW: Qt.CursorShape.SizeFDiagCursor, }[self] + + +class CameraType(str, Enum): + """Camera type.""" + + ARCBALL = "arcball" + PANZOOM = "panzoom" + + def __str__(self) -> str: + return self.value + + +ArrayLike: TypeAlias = npt.NDArray + + +class ImageInterpolation(str, Enum): + """Image interpolation options.""" + + LINEAR = "linear" + NEAREST = "nearest" + BICUBIC = "bicubic" + + def __str__(self) -> str: + return self.value diff --git a/src/ndv/controllers/_array_viewer.py b/src/ndv/controllers/_array_viewer.py index 605a20d7..3dbae5a8 100644 --- a/src/ndv/controllers/_array_viewer.py +++ b/src/ndv/controllers/_array_viewer.py @@ -66,7 +66,8 @@ def __init__( stacklevel=2, ) self._data_model = _ArrayDataDisplayModel( - data_wrapper=data, display=display_model or ArrayDisplayModel(**kwargs) + data_wrapper=data, + display=display_model or ArrayDisplayModel(**kwargs), ) app = _app.gui_frontend() diff --git a/src/ndv/models/_mapping.py b/src/ndv/models/_mapping.py index 340f24f4..c76ae83e 100644 --- a/src/ndv/models/_mapping.py +++ b/src/ndv/models/_mapping.py @@ -290,6 +290,8 @@ def _new(*args: Any, **kwargs: Any) -> ValidatedEventedDict[_KT, _VT]: def _get_schema(hint: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: # check if the hint already has a core schema attached to it. + # this helps to avoid `Definitions error: definition `___` was never filled` + # for recursive types. if hasattr(hint, "__pydantic_core_schema__"): return cast("core_schema.CoreSchema", hint.__pydantic_core_schema__) # otherwise, call the handler to get the core schema. diff --git a/src/ndv/models/_scene/__init__.py b/src/ndv/models/_scene/__init__.py new file mode 100644 index 00000000..4dad735e --- /dev/null +++ b/src/ndv/models/_scene/__init__.py @@ -0,0 +1,14 @@ +from ._transform import Transform +from .canvas import Canvas +from .nodes import Camera, Image, Node, Scene +from .view import View + +__all__ = [ + "Camera", + "Canvas", + "Image", + "Node", + "Scene", + "Transform", + "View", +] diff --git a/src/ndv/models/_scene/_transform.py b/src/ndv/models/_scene/_transform.py new file mode 100644 index 00000000..4a5f941d --- /dev/null +++ b/src/ndv/models/_scene/_transform.py @@ -0,0 +1,356 @@ +from __future__ import annotations + +import functools +import math +from functools import reduce +from typing import TYPE_CHECKING, Any, Callable, ClassVar, cast + +import numpy as np +from numpy.typing import ArrayLike, DTypeLike, NDArray +from pydantic import ConfigDict, Field, RootModel +from pydantic_core import core_schema + +if TYPE_CHECKING: + from collections.abc import Iterable, Sized + + from pydantic import GetCoreSchemaHandler + + +def _arg_to_vec4( + func: Callable[[Transform, ArrayLike], NDArray], +) -> Callable[[Transform, ArrayLike], NDArray]: + """Return method decorator that converts arg to vec4, suitable for 4x4 matrix mul. + + [x, y] => [[x, y, 0, 1]] + + [x, y, z] => [[x, y, z, 1]] + + [[x1, y1], [[x1, y1, 0, 1], + [x2, y2], => [x2, y2, 0, 1], + [x3, y3]] [x3, y3, 0, 1]] + + If 1D input is provided, then the return value will be flattened. + Accepts input of any dimension, as long as shape[-1] <= 4 + """ + + @functools.wraps(func) + def wrapper(self_: Transform, arg: ArrayLike) -> NDArray: + if not isinstance(arg, (tuple, list, np.ndarray)): + raise TypeError(f"Cannot convert argument to 4D vector: {arg!r}") + arg = np.array(arg) + flatten = arg.ndim == 1 + arg = as_vec4(arg) + + ret = func(self_, arg) + return np.copy(np.ravel(ret)) if flatten and ret is not None else ret + + return wrapper + + +def _validate_matrix(val: Any) -> np.ndarray: + if val is None: + return np.eye(4) + if not isinstance(val, np.ndarray): + val = np.asarray(val, dtype=float) + if val.shape != (4, 4): + raise ValueError(f"Matrix must be 4x4, not {val.shape}") + return val # type: ignore + + +class Matrix(np.ndarray): + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + # Serialize the matrix as a list of lists, unless it is the identity matrix + # (for brevity in the serialized output). + def _serialize(val: np.ndarray) -> list | None: + if np.allclose(val, np.eye(4)): + return None + return val.tolist() # type: ignore + + list_of_float = core_schema.list_schema(core_schema.float_schema()) + ser_schema = core_schema.plain_serializer_function_ser_schema( + _serialize, + return_schema=core_schema.union_schema( + [core_schema.list_schema(list_of_float), core_schema.none_schema()] + ), + ) + + return core_schema.no_info_before_validator_function( + _validate_matrix, + core_schema.any_schema(), + serialization=ser_schema, + ) + + +class Transform(RootModel): + """Transformation.""" + + root: Matrix = Field( + default_factory=lambda: np.eye(4), # type: ignore + description="Transformation matrix.", + ) + + model_config: ClassVar[ConfigDict] = ConfigDict(frozen=True, validate_default=True) + + def __array__(self, dtype: DTypeLike | None = None) -> np.ndarray: + return self.root.astype(dtype) + + def __repr__(self) -> str: + return repr(self.root).replace("array", "Transform") + + def is_null(self) -> bool: + return np.allclose(self.root, np.eye(4)) + + def __matmul__(self, other: Transform | ArrayLike) -> Transform: + """Return the dot product of this transform with another.""" + if isinstance(other, Transform): + other = other.root + return Transform(self.root @ other) # type: ignore + + def dot(self, other: Transform | ArrayLike) -> Transform: + """Return the dot product of this transform with another.""" + if isinstance(other, Transform): + other = other.root + return Transform(np.dot(self.root, other)) + + @property + def T(self) -> Transform: + """Return the transpose of the transform.""" + return Transform(self.root.T) + + def inv(self) -> Transform: + """Return the inverse of the transform.""" + return Transform(np.linalg.inv(self.root)) # type: ignore + + def translated(self, pos: ArrayLike) -> Transform: + """Return new transform, translated by pos. + + The translation is applied *after* the transformations already present + in the matrix. + + Parameters + ---------- + pos : ArrayLike + Position (x, y, z) to translate by. + """ + pos = as_vec4(np.array(pos)) + return self.dot(translate(pos[0, :3])) + + def rotated( + self, angle: float, axis: ArrayLike, about: ArrayLike | None = None + ) -> Transform: + """Return new transform, rotated some angle about a given axis. + + The rotation is applied *after* the transformations already present + in the matrix. + + Parameters + ---------- + angle : float + The angle of rotation, in degrees. + axis : array-like + The x, y and z coordinates of the axis vector to rotate around. + about : array-like or None + The x, y and z coordinates to rotate around. If None, will rotate around + the origin (0, 0, 0). + """ + if about is not None: + about = as_vec4(about)[0, :3] + return self.translated(-about).dot(rotate(angle, axis)).translated(about) + return self.dot(rotate(angle, axis)) + + def scaled( + self, scale_factor: ArrayLike, center: ArrayLike | None = None + ) -> Transform: + """Return new transform, scaled about a given origin. + + The scaling is applied *after* the transformations already present + in the matrix. + + Parameters + ---------- + scale_factor : array-like + Scale factors along x, y and z axes. + center : array-like or None + The x, y and z coordinates to scale around. If None, + (0, 0, 0) will be used. + """ + _scale = scale(as_vec4(scale_factor, default=(1, 1, 1, 1))[0, :3]) + if center is not None: + center = as_vec4(center)[0, :3] + _scale = np.dot(np.dot(translate(-center), _scale), translate(center)) + return self.dot(_scale) + + @_arg_to_vec4 + def map(self, coords: ArrayLike) -> NDArray: + """Map coordinates. + + Parameters + ---------- + coords : array-like + Coordinates to map. + + Returns + ------- + coords : ndarray + Coordinates. + """ + # looks backwards, but both matrices are transposed. + return cast(NDArray, np.dot(coords, self.root)) + + @_arg_to_vec4 + def imap(self, coords: ArrayLike) -> NDArray: + """Inverse map coordinates. + + Parameters + ---------- + coords : array-like + Coordinates to inverse map. + + Returns + ------- + coords : ndarray + Coordinates. + """ + return cast(NDArray, np.dot(coords, np.linalg.inv(self.root))) + + @classmethod + def chain(cls, *transforms: Transform) -> Transform: + """Chain multiple transforms together. + + Parameters + ---------- + transforms : Transform + Transforms to chain. + + Returns + ------- + transform : Transform + Chained transform. + """ + return reduce(lambda a, b: a @ b, transforms, cls()) + + def __eq__(self, value: object) -> bool: + if not isinstance(value, Transform): + return NotImplemented + return np.allclose(self.root, value.root) + + def __hash__(self) -> int: + return hash(self.root.tobytes()) + + +# from vispy ... + + +def rotate(angle: float, axis: ArrayLike) -> np.ndarray: + """Return 4x4 rotation matrix for rotation about a vector. + + Parameters + ---------- + angle : float + The angle of rotation, in degrees. + axis : ndarray + The x, y, z coordinates of the axis direction vector. + + Returns + ------- + M : ndarray + Transformation matrix describing the rotation. + """ + angle = np.radians(angle) + axis = np.array(axis, copy=False) + if len(axis) != 3: + raise ValueError("axis must be a 3-element vector") + x, y, z = axis / np.linalg.norm(axis) + c, s = math.cos(angle), math.sin(angle) + cx, cy, cz = (1 - c) * x, (1 - c) * y, (1 - c) * z + M = [ + [cx * x + c, cy * x - z * s, cz * x + y * s, 0.0], + [cx * y + z * s, cy * y + c, cz * y - x * s, 0.0], + [cx * z - y * s, cy * z + x * s, cz * z + c, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + return np.array(M).T + + +def translate(offset: Iterable[float]) -> np.ndarray: + """Translate by an offset (x, y, z) . + + Parameters + ---------- + offset : Iterable[float] + Must be length 3. Translation in x, y, z. + + Returns + ------- + M : ndarray + Transformation matrix describing the translation. + """ + _offset = tuple(offset) + if len(_offset) != 3: + raise ValueError("offset must be a length 3 sequence") + x, y, z = _offset + return np.array( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [x, y, z, 1.0], + ] + ) + + +def scale(s: Sized) -> np.ndarray: + """Non-uniform scaling along the x, y, and z axes. + + Parameters + ---------- + s : array-like, shape (3,) + Scaling in x, y, z. + + Returns + ------- + M : ndarray + Transformation matrix describing the scaling. + """ + if len(s) != 3: + raise ValueError("scale must be a length 3 sequence") + return np.array(np.diag(np.concatenate([s, (1.0,)]))) + + +def as_vec4(obj: ArrayLike, default: ArrayLike = (0, 0, 0, 1)) -> np.ndarray: + """Convert `obj` to 4-element vector (numpy array with shape[-1] == 4). + + Parameters + ---------- + obj : array-like + Original object. + default : array-like + The defaults to use if the object does not have 4 entries. + + Returns + ------- + obj : array-like + The object promoted to have 4 elements. + + Notes + ----- + `obj` will have at least two dimensions. + + If `obj` has < 4 elements, then new elements are added from `default`. + For inputs intended as a position or translation, use default=(0,0,0,1). + For inputs intended as scale factors, use default=(1,1,1,1). + + """ + obj = np.atleast_2d(obj) + # For multiple vectors, reshape to (..., 4) + if obj.shape[-1] < 4: + new = np.empty(obj.shape[:-1] + (4,), dtype=obj.dtype) + new[:] = default + new[..., : obj.shape[-1]] = obj + obj = new + elif obj.shape[-1] > 4: + raise TypeError(f"Array shape {obj.shape} cannot be converted to vec4") + return obj diff --git a/src/ndv/models/_scene/_vis_model.py b/src/ndv/models/_scene/_vis_model.py new file mode 100644 index 00000000..aa567e73 --- /dev/null +++ b/src/ndv/models/_scene/_vis_model.py @@ -0,0 +1,299 @@ +from __future__ import annotations + +import contextlib +import logging +from abc import abstractmethod +from contextlib import suppress +from importlib import import_module +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Generic, + Protocol, + TypeVar, + cast, +) + +from psygnal import EmissionInfo, SignalGroupDescriptor +from pydantic import BaseModel, ConfigDict +from pydantic.fields import Field + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator + +__all__ = ["Field", "ModelBase", "SupportsVisibility", "VisModel"] + +logger = logging.getLogger(__name__) +# logging.basicConfig(level=logging.DEBUG) +SETTER_METHOD = "_vis_set_{name}" + + +class ModelBase(BaseModel): + """Base class for all evented pydantic-style models.""" + + events: ClassVar[SignalGroupDescriptor] = SignalGroupDescriptor() + model_config: ClassVar[ConfigDict] = ConfigDict( + extra="ignore", + validate_default=True, + validate_assignment=True, + repr_exclude_defaults=True, # type: ignore [typeddict-unknown-key] + ) + + # repr that excludes default values + def __repr_args__(self) -> Iterable[tuple[str | None, Any]]: + super_args = super().__repr_args__() + if not self.model_config.get("repr_exclude_defaults"): + yield from super_args + return + + fields = self.model_fields + for key, val in super_args: + default = fields[key].get_default( # type: ignore + call_default_factory=True, validated_data={} + ) + with suppress(Exception): + if val == default: + continue + yield key, val + + +F = TypeVar("F", covariant=True, bound="VisModel") + + +class BackendAdaptor(Protocol[F]): + """Protocol for backend adaptor classes. + + An adaptor class is responsible for converting all of the ndv protocol methods + into native calls for the given backend. + """ + + @abstractmethod + def __init__(self, obj: F, **backend_kwargs: Any) -> None: + """All backend adaptor objects receive the object they are adapting.""" + ... + + @abstractmethod + def _vis_get_native(self) -> Any: + """Return the native object for the backend.""" + + # TODO: add a "detach" or "cleanup" method? + + +class SupportsVisibility(BackendAdaptor[F], Protocol): + """Protocol for objects that support visibility (show/hide).""" + + @abstractmethod + def _vis_set_visible(self, arg: bool) -> None: + """Set the visibility of the object.""" + + +AdaptorType = TypeVar("AdaptorType", bound=BackendAdaptor, covariant=True) + + +class VisModel(ModelBase, Generic[AdaptorType]): + """Front end object driving a backend interface. + + This is an important class. Most things subclass this. It provides the event + connection between the model object and a backend adaptor. + + A backend adaptor is a class that implements the BackendAdaptor protocol (of type + `T`... for which this class is a generic). The backend adaptor is an object + responsible for converting all of the ndv protocol methods (stuff like + "_vis_set_width", "_vis_set_visible", etc...) into the appropriate calls for + the given backend. + """ + + # Really, this should be `_backend_adaptors: ClassVar[dict[str, T]]``, + # but thats a type error. + # PEP 526 states that ClassVar cannot include any type variables... + # but there is discussion that this might be too limiting. + # dicsussion: https://github.com/python/mypy/issues/5144 + # _backend_adaptors: ClassVar[dict[str, BackendAdaptor]] = PrivateAttr({}) + + # This is the set of all field names that must have setters in the backend adaptor. + # set during the init + # _evented_fields: ClassVar[set[str]] = PrivateAttr(set()) + + # this is a cache of all adaptor classes that have been validated to implement + # the correct methods (via validate_adaptor_class). + _validated_adaptor_classes: ClassVar[set[type]] = set() + + def model_post_init(self, __context: Any) -> None: + # if using this in an EventedModel, connect to the events + self.events.connect(self._on_any_event) + # determine fields that need setter methods in the backend adaptor + # TODO: + # this really shouldn't need to be in the init. `__init_subclass__` would be + # better, but that unfortunately gets called after EventedModel.__new__. + # need to look into it + signal_names = set(self.events) + self._evented_fields = set(self.model_fields).intersection(signal_names) + self._backend_adaptors: dict[str, BackendAdaptor] = {} + + def has_backend_adaptor(self, backend: str | None = None) -> bool: + """Return True if the object has a backend adaptor. + + If None is passed, the returned bool indicates the presence of any + adaptor class. + """ + if backend is None: + return bool(self._backend_adaptors) + return backend in self._backend_adaptors + + def backend_adaptor(self, backend: str | None = None) -> AdaptorType: + """Get the backend adaptor for this object. Creates one if it doesn't exist. + + Parameters + ---------- + backend : str, optional + The name of the backend to use, by default None. If None, the default + backend will be used. + """ + backend = backend or _get_default_backend() + if backend not in self._backend_adaptors: + cls = self._get_adaptor_class(backend) + self._backend_adaptors[backend] = self._create_adaptor(cls) + return cast("AdaptorType", self._backend_adaptors[backend]) + + @property + def backend_adaptors(self) -> Iterable[AdaptorType]: + """Convenient, public iterator for backend adaptor objects.""" + yield from self._backend_adaptors.values() # type: ignore + + def dangerously_get_native_object(self, backend: str | None = None) -> Any: + """Return the native object for a backend. + + NOTE! Directly modifying the backend objects is not supported. This method + is here as a convenience for debugging, development, and experimentation. + Direct modification of the backend object may lead to desyncronization of + the model and the backend object, or other unexpected behavior. + """ + adaptor = self.backend_adaptor(backend=backend) + return adaptor._vis_get_native() + + def _get_adaptor_class( + self, + backend: str, + class_name: str | None = None, + ) -> type[AdaptorType]: + """Retrieve the adaptor class with the same name as the object class.""" + class_name = class_name or type(self).__name__ + backend_module = import_module(f"ndv.views._scene.{backend}") + adaptor_class = getattr(backend_module, class_name) + return self.validate_adaptor_class(adaptor_class) + + def _create_adaptor(self, cls: type[AdaptorType]) -> AdaptorType: + """Instantiate the backend adaptor object. + + The purpose of this method is to allow subclasses to override the + creation of the backend object. Or do something before/after. + """ + logger.debug(f"Attaching {type(self)} to backend {cls}") + adaptor = cls(self) + sync_adaptor(adaptor, self) + return adaptor + + def _sync_adaptors(self) -> None: + for adaptor in self.backend_adaptors: + sync_adaptor(adaptor, self) + + def _on_any_event(self, info: EmissionInfo) -> None: + signal_name = info.signal.name + if signal_name not in self._evented_fields: + return + + # NOTE: this loop runs anytime any attribute on any model is changed... + # so it has the potential to be a performance bottleneck. + # It is the the apparent cost, however, for allowing a model object to have + # multiple simultaneous backend adaptors. This should be re-evaluated often. + for adaptor in self.backend_adaptors: + try: + name = SETTER_METHOD.format(name=signal_name) + setter = getattr(adaptor, name) + except AttributeError as e: + logger.exception(e) + return + + event_name = f"{type(self).__name__}.{signal_name}" + logger.debug(f"{event_name}={info.args} emitting to backend") + + try: + setter(info.args[0]) + except Exception as e: + logger.exception(e) + breakpoint() + + # TODO: + # def detach(self) -> None: + # """Disconnect and destroy the backend adaptor from the object.""" + # self._backend = None + + def validate_adaptor_class(self, adaptor_class: Any) -> type[AdaptorType]: + """Validate that the adaptor class is appropriate for the core object.""" + # XXX: this could be a classmethod, but it's turning out to be difficult to + # set _evented_fields on that class (see note in __init__) + + cls = type(self) + if adaptor_class in cls._validated_adaptor_classes: + return cast("type[AdaptorType]", adaptor_class) + + # logger.debug(f"Validating adaptor class {adaptor_class} for {cls}") + if missing := { + SETTER_METHOD.format(name=field) + for field in self._evented_fields + if not hasattr(adaptor_class, SETTER_METHOD.format(name=field)) + }: + raise ValueError( + f"{adaptor_class} cannot be used as a backend object for " + f"{cls}: it is missing the following methods: {missing}" + ) + cls._validated_adaptor_classes.add(adaptor_class) + return cast("type[AdaptorType]", adaptor_class) + + +# XXX: the default behavior should be to +# pick the "right" backend for the current environment. +# i.e. ndv should work with no configuration in both jupyter and ipython desktop.) +def _get_default_backend() -> str: + """Stub function for the concept of picking a backend when none is specified. + + This will likely be context dependent. + """ + from ndv.views._app import canvas_backend + + return canvas_backend(None).value + + +def _update_blocker(adaptor: BackendAdaptor) -> contextlib.AbstractContextManager: + from ndv.models._scene.nodes.node import NodeAdaptorProtocol + + if isinstance(adaptor, NodeAdaptorProtocol): + + @contextlib.contextmanager + def blocker() -> Iterator[None]: + adaptor._vis_block_updates() + try: + yield + finally: + adaptor._vis_unblock_updates() + + return blocker() + return contextlib.nullcontext() + + +def sync_adaptor(adaptor: BackendAdaptor, model: VisModel) -> None: + """Decorator to validate and cache adaptor classes.""" + with _update_blocker(adaptor): + for field_name in model.model_fields: + method_name = SETTER_METHOD.format(name=field_name) + value = getattr(model, field_name) + try: + vis_set = getattr(adaptor, method_name) + vis_set(value) + except Exception as e: + logger.error( + "Failed to set field %r on adaptor %r: %s", field_name, adaptor, e + ) + force_update = getattr(adaptor, "_vis_force_update", lambda: None) + force_update() diff --git a/src/ndv/models/_scene/canvas.py b/src/ndv/models/_scene/canvas.py new file mode 100644 index 00000000..1968c90c --- /dev/null +++ b/src/ndv/models/_scene/canvas.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import warnings +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Protocol, TypeVar + +from cmap import Color # noqa: TC002 +from psygnal.containers import EventedList +from pydantic import Field + +from ._vis_model import SupportsVisibility, VisModel +from .view import View + +if TYPE_CHECKING: + import numpy as np + + +ViewType = TypeVar("ViewType", bound=View) + + +class CanvasAdaptorProtocol(SupportsVisibility["Canvas"], Protocol): + """Protocol defining the interface for a Canvas adaptor.""" + + @abstractmethod + def _vis_set_width(self, arg: int) -> None: ... + @abstractmethod + def _vis_set_height(self, arg: int) -> None: ... + @abstractmethod + def _vis_set_background_color(self, arg: Color | None) -> None: ... + @abstractmethod + def _vis_set_title(self, arg: str) -> None: ... + @abstractmethod + def _vis_close(self) -> None: ... + @abstractmethod + def _vis_render(self) -> np.ndarray: ... + @abstractmethod + def _vis_add_view(self, view: View) -> None: ... + def _vis_get_ipython_mimebundle( + self, *args: Any, **kwargs: Any + ) -> dict | tuple[dict, dict] | Any: + return NotImplemented + + def _vis_set_views(self, views: list[View]) -> None: + pass + + +class ViewList(EventedList[ViewType]): + def _pre_insert(self, value: ViewType) -> ViewType: + if not isinstance(value, View): # pragma: no cover + raise TypeError("Canvas views must be View objects") + return super()._pre_insert(value) + + +class Canvas(VisModel[CanvasAdaptorProtocol]): + """Canvas onto which views are rendered. + + In desktop applications, this will be a window. In web applications, this will be a + div. The canvas has one or more views, which are rendered onto it. For example, + an orthoviewer might be a single canvas with three views, one for each axis. + """ + + width: int = Field(default=500, description="The width of the canvas in pixels.") + height: int = Field(default=500, description="The height of the canvas in pixels.") + background_color: Color | None = Field( + default=None, + description="The background color. None implies transparent " + "(which is usually black)", + ) + visible: bool = Field(default=False, description="Whether the canvas is visible.") + title: str = Field(default="", description="The title of the canvas.") + views: ViewList[View] = Field(default_factory=lambda: ViewList(), frozen=True) + + @property + def size(self) -> tuple[int, int]: + """Return the size of the canvas.""" + return self.width, self.height + + @size.setter + def size(self, value: tuple[int, int]) -> None: + """Set the size of the canvas.""" + self.width, self.height = value + + def close(self, backend: str | None = None) -> None: + """Close the canvas.""" + if self.has_backend_adaptor(backend=backend): + for adaptor in self.backend_adaptors: + adaptor._vis_close() + + # show and render will trigger a backend connection + + def show(self, *, backend: str | None = None) -> None: + """Show the canvas. + + Parameters + ---------- + backend : str, optional + The backend to use. If not provided, the default backend will be used. + TODO: clarify how this is chosen. + """ + # Note: the canvas.show() method is THE primary place where we create a tree + # of backend objects. (None of the lower level Node objects actually *need* + # any backend representation until they need to be shown visually) + # So, this method really bootstraps the entire "hydration" of the backend tree. + # Here, we make sure that all of the views have a backend adaptor. + + # If you need to add any additional logic to handle the moment of backend + # creation in a specific Node subtype, you can override the `_create_backend` + # method (see, for example, the View._create_backend method) + + # ensure we have a backend adaptor + adapter = self.backend_adaptor(backend=backend) + for view in self.views: + if not view.has_backend_adaptor(): + # make sure all of the views have a backend adaptor + view.backend_adaptor(backend=backend) + + for view in self.views: + adapter._vis_add_view(view) + + self.visible = True + + def hide(self) -> None: + """Hide the canvas.""" + self.visible = False + + def render(self, backend: str | None = None) -> np.ndarray: + """Render canvas to offscren buffer and return as numpy array.""" + # TODO: do we need to set visible=True temporarily here? + return self.backend_adaptor(backend=backend)._vis_render() + + # consider using canvas.views.append? + def add_view(self, view: View | None = None, **kwargs: Any) -> View: + """Add a new view to the canvas.""" + # TODO: change kwargs to params + if view is None: + view = View(**kwargs) + elif kwargs: # pragma: no cover + warnings.warn("kwargs ignored when view is provided", stacklevel=2) + elif not isinstance(view, View): # pragma: no cover + raise TypeError("view must be an instance of View") + + self.views.append(view) + if self.has_backend_adaptor(): + for adaptor in self.backend_adaptors: + adaptor._vis_add_view(view) + return view + + def _repr_mimebundle_(self, *args: Any, **kwargs: Any) -> dict[str, Any] | Any: + """Return a mimebundle for the canvas. + + This defers to the native object's _vis_get_ipython_mimebundle method + if it exists. + Allowing different backends to support Jupyter or other rich display. + + https://ipython.readthedocs.io/en/stable/config/integrating.html#more-powerful-methods + """ + adaptor = self.backend_adaptor() + if hasattr(adaptor, "_vis_get_ipython_mimebundle"): + return adaptor._vis_get_ipython_mimebundle(*args, **kwargs) + return NotImplemented + + +class GridCanvas(Canvas): + """Subclass with numpy-style indexing.""" + + # def __getitem__(self, key: tuple[int, int]) -> View: + # """Get the View at the given row and column.""" diff --git a/src/ndv/models/_scene/nodes/__init__.py b/src/ndv/models/_scene/nodes/__init__.py new file mode 100644 index 00000000..420ed9f0 --- /dev/null +++ b/src/ndv/models/_scene/nodes/__init__.py @@ -0,0 +1,7 @@ +from .camera import Camera +from .image import Image +from .node import GenericNode, Node +from .points import Points +from .scene import Scene + +__all__ = ["Camera", "GenericNode", "Image", "Node", "Points", "Scene"] diff --git a/src/ndv/models/_scene/nodes/camera.py b/src/ndv/models/_scene/nodes/camera.py new file mode 100644 index 00000000..fb9529dc --- /dev/null +++ b/src/ndv/models/_scene/nodes/camera.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import Literal, Protocol + +from pydantic import Field + +from ndv._types import CameraType + +from .node import Node, NodeAdaptorProtocol + + +class CameraAdaptorProtocol(NodeAdaptorProtocol["Camera"], Protocol): + """Protocol for a backend camera adaptor object.""" + + @abstractmethod + def _vis_set_type(self, arg: CameraType) -> None: ... + @abstractmethod + def _vis_set_zoom(self, arg: float) -> None: ... + @abstractmethod + def _vis_set_center(self, arg: tuple[float, ...]) -> None: ... + @abstractmethod + def _vis_set_range(self, margin: float) -> None: ... + + +class Camera(Node["CameraAdaptorProtocol"]): + """A camera that defines the view of a scene.""" + + node_type: Literal["camera"] = "camera" + + type: CameraType = Field(default=CameraType.PANZOOM, description="Camera type.") + interactive: bool = Field( + default=True, + description="Whether the camera responds to user interaction, " + "such as mouse and keyboard events.", + ) + zoom: float = Field(default=1.0, description="Zoom factor of the camera.") + center: tuple[float, float, float] | tuple[float, float] = Field( + default=(0, 0, 0), description="Center position of the view." + ) + + def _set_range(self, margin: float = 0) -> None: + adaptor = self.backend_adaptor() + # TODO: this method should probably be pulled off of the backend, + # calculated directly in the core, and then applied as a change to the + # camera transform + adaptor._vis_set_range(margin=margin) diff --git a/src/ndv/models/_scene/nodes/image.py b/src/ndv/models/_scene/nodes/image.py new file mode 100644 index 00000000..b686fefe --- /dev/null +++ b/src/ndv/models/_scene/nodes/image.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Literal, Protocol + +from cmap import Colormap +from pydantic import Field + +from ndv._types import ImageInterpolation + +from .node import Node, NodeAdaptorProtocol + +if TYPE_CHECKING: + from numpy.typing import NDArray + + +class ImageBackend(NodeAdaptorProtocol["Image"], Protocol): + """Protocol for a backend Image adaptor object.""" + + @abstractmethod + def _vis_set_data(self, arg: NDArray) -> None: ... + @abstractmethod + def _vis_set_cmap(self, arg: Colormap) -> None: ... + @abstractmethod + def _vis_set_clims(self, arg: tuple[float, float] | None) -> None: ... + @abstractmethod + def _vis_set_gamma(self, arg: float) -> None: ... + @abstractmethod + def _vis_set_interpolation(self, arg: ImageInterpolation) -> None: ... + + +class Image(Node[ImageBackend]): + """A Image that can be placed in scene.""" + + node_type: Literal["image"] = "image" + + data: Any = Field(default=None, repr=False, exclude=True) + cmap: Colormap = Field( + default_factory=lambda: Colormap("gray"), + description="The colormap to use for the image.", + ) + clims: tuple[float, float] | None = Field( + default=None, + description="The contrast limits to use for the image.", + ) + gamma: float = Field(default=1.0, description="The gamma correction to use.") + interpolation: ImageInterpolation = Field( + default=ImageInterpolation.NEAREST, + description="The interpolation to use.", + ) diff --git a/src/ndv/models/_scene/nodes/node.py b/src/ndv/models/_scene/nodes/node.py new file mode 100644 index 00000000..806307f3 --- /dev/null +++ b/src/ndv/models/_scene/nodes/node.py @@ -0,0 +1,261 @@ +from __future__ import annotations + +import logging +from abc import abstractmethod +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Literal, + Protocol, + TypeVar, + Union, + cast, + runtime_checkable, +) + +from pydantic import ( + Field, + SerializerFunctionWrapHandler, + field_validator, + model_serializer, +) + +from ndv.models._scene._transform import Transform +from ndv.models._scene._vis_model import SupportsVisibility, VisModel +from ndv.models._sequence import ValidatedEventedList + +if TYPE_CHECKING: + from collections.abc import Iterator + +logger = logging.getLogger(__name__) +NodeTypeCoV = TypeVar("NodeTypeCoV", bound="Node", covariant=True) +NodeAdaptorProtocolTypeCoV = TypeVar( + "NodeAdaptorProtocolTypeCoV", bound="NodeAdaptorProtocol", covariant=True +) + + +@runtime_checkable +class NodeAdaptorProtocol(SupportsVisibility[NodeTypeCoV], Protocol): + """Backend interface for a Node.""" + + @abstractmethod + def _vis_set_name(self, arg: str) -> None: ... + @abstractmethod + def _vis_set_parent(self, arg: Node | None) -> None: ... + @abstractmethod + def _vis_set_children(self, arg: list[Node]) -> None: ... + @abstractmethod + def _vis_set_opacity(self, arg: float) -> None: ... + @abstractmethod + def _vis_set_order(self, arg: int) -> None: ... + @abstractmethod + def _vis_set_interactive(self, arg: bool) -> None: ... + @abstractmethod + def _vis_set_transform(self, arg: Transform) -> None: ... + @abstractmethod + def _vis_add_node(self, node: Node) -> None: ... + + def _vis_set_node_type(self, arg: str) -> None: + pass + + @abstractmethod + def _vis_block_updates(self) -> None: + """Block future updates until `unblock_updates` is called.""" + + @abstractmethod + def _vis_unblock_updates(self) -> None: + """Unblock updates after `block_updates` was called.""" + + @abstractmethod + def _vis_force_update(self) -> None: + """Force an update to the node.""" + + +# improve me... Read up on: https://docs.pydantic.dev/latest/concepts/unions/ +AnyNode = Annotated[ + Union["Image", "Scene", "GenericNode", "Points"], Field(discriminator="node_type") +] + + +class Node(VisModel[NodeAdaptorProtocolTypeCoV]): + """Base class for all nodes. + + Do not instantiate this class directly. Use a subclass. GenericNode may + be used in place of Node. + """ + + node_type: str # discriminator field defined in subclasses + + name: str | None = Field(default=None, description="Name of the node.") + parent: AnyNode | None = Field( + default=None, + description="Parent node. If None, this node is a root node.", + exclude=True, # prevents recursion in serialization. + repr=False, # recursion is just confusing + # TODO: maybe make children the derived field? + ) + + children: ValidatedEventedList[AnyNode] = Field( + default_factory=lambda: ValidatedEventedList(), frozen=True + ) + visible: bool = Field(default=True, description="Whether this node is visible.") + interactive: bool = Field( + default=False, description="Whether this node accepts mouse and touch events" + ) + opacity: float = Field(default=1.0, ge=0, le=1, description="Opacity of this node.") + order: int = Field( + default=0, + ge=0, + description="A value used to determine the order in which nodes are drawn. " + "Greater values are drawn later. Children are always drawn after their parent", + ) + transform: Transform = Field( + default_factory=Transform, + description="Transform that maps the local coordinate frame to the coordinate " + "frame of the parent.", + ) + + @model_serializer(mode="wrap") + def _serialize_with_node_type(self, handler: SerializerFunctionWrapHandler) -> Any: + # modified serializer that ensures node_type is included, + # (e.g. even if exclude_defaults=True) + return {**handler(self), "node_type": self.node_type} + + # prevent direct instantiation, which makes it easier to use NodeUnion without + # having to deal with self-reference. + def __init__(self, /, **data: Any) -> None: + if type(self) is Node: + raise TypeError("Node cannot be instantiated directly. Use a subclass.") + super().__init__(**data) + + def model_post_init(self, __context: Any) -> None: + super().model_post_init(__context) + for child in self.children: + child.parent = cast("AnyNode", self) + self.children.item_inserted.connect(self._on_child_inserted) + + def _on_child_inserted(self, index: int, obj: Node) -> None: + # ensure parent is set + self.add(obj) + + def __contains__(self, item: Node) -> bool: + """Return True if this node is an ancestor of item.""" + return item in self.children + + def add(self, node: Node) -> None: + """Add a child node.""" + node = cast("AnyNode", node) + node.parent = cast("AnyNode", self) + if self.has_backend_adaptor() and not node.has_backend_adaptor(): + node.backend_adaptor() + if node not in self.children: + nd = f"{node.__class__.__name__} {id(node)}" + slf = f"{self.__class__.__name__} {id(self)}" + logger.debug(f"Adding node {nd} to {slf}") + self.children.append(node) + if self.has_backend_adaptor(): + self.backend_adaptor()._vis_add_node(node) + + @field_validator("transform", mode="before") + @classmethod + def _validate_transform(cls, v: Any) -> Any: + return Transform() if v is None else v + + # below borrowed from vispy.scene.Node + + def transform_to_node(self, other: Node) -> Transform: + """Return Transform that maps from coordinate frame of `self` to `other`. + + Note that there must be a _single_ path in the scenegraph that connects + the two entities; otherwise an exception will be raised. + + Parameters + ---------- + other : instance of Node + The other node. + + Returns + ------- + transform : instance of ChainTransform + The transform. + """ + a, b = self.path_to_node(other) + tforms = [n.transform for n in a[:-1]] + [n.transform.inv() for n in b] + return Transform.chain(*tforms[::-1]) + + def path_to_node(self, other: Node) -> tuple[list[Node], list[Node]]: + """Return two lists describing the path from this node to another. + + Parameters + ---------- + other : instance of Node + The other node. + + Returns + ------- + p1 : list + First path (see below). + p2 : list + Second path (see below). + + Notes + ----- + The first list starts with this node and ends with the common parent + between the endpoint nodes. The second list contains the remainder of + the path from the common parent to the specified ending node. + + For example, consider the following scenegraph:: + + A --- B --- C --- D + \ + --- E --- F + + Calling `D.node_path(F)` will return:: + + ([D, C, B], [E, F]) + + """ + my_parents = list(self.iter_parents()) + their_parents = list(other.iter_parents()) + common_parent = next((p for p in my_parents if p in their_parents), None) + if common_parent is None: + slf = f"{self.__class__.__name__} {id(self)}" + nd = f"{other.__class__.__name__} {id(other)}" + raise RuntimeError(f"No common parent between nodes {slf} and {nd}.") + + up = my_parents[: my_parents.index(common_parent) + 1] + down = their_parents[: their_parents.index(common_parent)][::-1] + return (up, down) + + def iter_parents(self) -> Iterator[Node]: + """Return list of parents starting from this node. + + The chain ends at the first node with no parents. + """ + yield self + + x = cast("AnyNode", self) + while True: + try: + parent = x.parent + except Exception: + break + if parent is None: + break + yield parent + x = parent + + +class GenericNode(Node[NodeAdaptorProtocol]): + """A generic node that can be placed in a scene.""" + + node_type: Literal["node"] = "node" + + +# TODO: gotta be a better pattern to populate AnyNode above... +from .image import Image # noqa: E402, TC001 +from .points import Points # noqa: E402, TC001 +from .scene import Scene # noqa: E402, TC001 + +Node.model_rebuild() diff --git a/src/ndv/models/_scene/nodes/points.py b/src/ndv/models/_scene/nodes/points.py new file mode 100644 index 00000000..49867cbc --- /dev/null +++ b/src/ndv/models/_scene/nodes/points.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Literal, Protocol + +from cmap import Color +from pydantic import Field + +from .node import Node, NodeAdaptorProtocol + +if TYPE_CHECKING: + from numpy.typing import NDArray + + +class PointsBackend(NodeAdaptorProtocol["Points"], Protocol): + """Protocol for a backend Image adaptor object.""" + + @abstractmethod + def _vis_set_coords(self, coords: NDArray) -> None: ... + @abstractmethod + def _vis_set_size(self, size: float) -> None: ... + @abstractmethod + def _vis_set_face_color(self, face_color: Color) -> None: ... + @abstractmethod + def _vis_set_edge_color(self, edge_color: Color) -> None: ... + @abstractmethod + def _vis_set_edge_width(self, edge_width: float) -> None: ... + @abstractmethod + def _vis_set_symbol(self, symbol: str) -> None: ... + @abstractmethod + def _vis_set_scaling(self, scaling: str) -> None: ... + @abstractmethod + def _vis_set_antialias(self, antialias: float) -> None: ... + @abstractmethod + def _vis_set_opacity(self, opacity: float) -> None: ... + + +SymbolName = Literal[ + "disc", + "arrow", + "ring", + "clobber", + "square", + "x", + "diamond", + "vbar", + "hbar", + "cross", + "tailed_arrow", + "triangle_up", + "triangle_down", + "star", + "cross_lines", +] +ScalingMode = Literal[True, False, "fixed", "scene", "visual"] + + +class Points(Node[PointsBackend]): + """Points that can be placed in scene.""" + + node_type: Literal["points"] = "points" + + # numpy array of 2D/3D point centers, shape (N, 2) or (N, 3) + coords: Any = Field(default=None, repr=False, exclude=True) + size: float = Field(default=10.0, description="The size of the points.") + face_color: Color | None = Field( + default=Color("white"), description="The color of the faces." + ) + edge_color: Color | None = Field( + default=Color("black"), description="The color of the edges." + ) + edge_width: float | None = Field(default=1.0, description="The width of the edges.") + symbol: SymbolName = Field( + default="disc", description="The symbol to use for the points." + ) + # TODO: these are vispy-specific names. Determine more general names + scaling: ScalingMode = Field( + default=True, description="Determines how points scale when zooming." + ) + + antialias: float = Field(default=1, description="Anti-aliasing factor, in px.") + opacity: float = Field(default=1.0, description="The opacity of the points.") diff --git a/src/ndv/models/_scene/nodes/scene.py b/src/ndv/models/_scene/nodes/scene.py new file mode 100644 index 00000000..78c941dc --- /dev/null +++ b/src/ndv/models/_scene/nodes/scene.py @@ -0,0 +1,13 @@ +from typing import Literal + +from .node import Node, NodeAdaptorProtocol + + +class Scene(Node[NodeAdaptorProtocol]): + """A Root node for a scene graph. + + This really isn't anything more than a regular Node, but it's an explicit + marker that this node is the root of a scene graph. + """ + + node_type: Literal["scene"] = "scene" diff --git a/src/ndv/models/_scene/view.py b/src/ndv/models/_scene/view.py new file mode 100644 index 00000000..ac4aaa42 --- /dev/null +++ b/src/ndv/models/_scene/view.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import logging +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Protocol, TypeVar + +from cmap import Color +from pydantic import ConfigDict, Field, PrivateAttr, computed_field + +from ._vis_model import ModelBase, SupportsVisibility, VisModel +from .nodes import Camera, Scene +from .nodes.node import Node + +if TYPE_CHECKING: + from psygnal import EmissionInfo + + from .canvas import Canvas + +NodeType = TypeVar("NodeType", bound=Node) +logger = logging.getLogger(__name__) + + +class ViewAdaptorProtocol(SupportsVisibility["View"], Protocol): + """Protocol defining the interface for a View adaptor.""" + + @abstractmethod + def _vis_set_camera(self, arg: Camera) -> None: ... + @abstractmethod + def _vis_set_scene(self, arg: Scene) -> None: ... + @abstractmethod + def _vis_set_position(self, arg: tuple[float, float]) -> None: ... + @abstractmethod + def _vis_set_size(self, arg: tuple[float, float] | None) -> None: ... + @abstractmethod + def _vis_set_background_color(self, arg: Color | None) -> None: ... + @abstractmethod + def _vis_set_border_width(self, arg: float) -> None: ... + @abstractmethod + def _vis_set_border_color(self, arg: Color | None) -> None: ... + @abstractmethod + def _vis_set_padding(self, arg: int) -> None: ... + @abstractmethod + def _vis_set_margin(self, arg: int) -> None: ... + + def _vis_set_layout(self, arg: Layout) -> None: + pass + + +class Layout(ModelBase): + """Rectangular layout model. + + y + | + v + x-> +--------------------------------+ ^ + | margin | | + | +--------------------------+ | | + | | border | | | + | | +--------------------+ | | | + | | | padding | | | | + | | | +--------------+ | | | height + | | | | content | | | | | + | | | | | | | | | + | | | +--------------+ | | | | + | | +--------------------+ | | | + | +--------------------------+ | | + +--------------------------------+ v + + <------------ width -------------> + """ + + x: float = Field( + default=0, description="The x-coordinate of the object (wrt parent)." + ) + y: float = Field( + default=0, description="The y-coordinate of the object (wrt parent)." + ) + width: float = Field(default=0, description="The width of the object.") + height: float = Field(default=0, description="The height of the object.") + background_color: Color | None = Field( + default=Color("black"), + description="The background color (inside of the border). " + "None implies transparent.", + ) + border_width: float = Field( + default=0, description="The width of the border in pixels." + ) + border_color: Color | None = Field( + default=Color("black"), description="The color of the border." + ) + padding: int = Field( + default=0, + description="The amount of padding in the widget " + "(i.e. the space reserved between the contents and the border).", + ) + margin: int = Field( + default=0, description="he margin to keep outside the widget's border" + ) + + @property + def position(self) -> tuple[float, float]: + return self.x, self.y + + @property + def size(self) -> tuple[float, float]: + return self.width, self.height + + +class View(VisModel[ViewAdaptorProtocol]): + """A rectangular area on a canvas that displays a scene, with a camera. + + A canvas can have one or more views. Each view has a single scene (i.e. a + scene graph of nodes) and a single camera. The camera defines the view + transformation. This class just exists to associate a single scene and + camera. + """ + + scene: Scene = Field(default_factory=Scene) + camera: Camera = Field(default_factory=Camera) + layout: Layout = Field(default_factory=Layout, frozen=True) + + model_config = ConfigDict(repr_exclude_defaults=False) # type: ignore + + _canvas: Canvas | None = PrivateAttr(None) + + def model_post_init(self, __context: Any) -> None: + super().model_post_init(__context) + self.camera.parent = self.scene + self.layout.events.connect(self._on_layout_event) + + @computed_field # type: ignore + @property + def canvas(self) -> Canvas: + """The canvas that the view is on. + + If one hasn't been created/assigned, a new one is created. + """ + if (canvas := self._canvas) is None: + from .canvas import Canvas + + self.canvas = canvas = Canvas() + return canvas + + @canvas.setter + def canvas(self, value: Canvas) -> None: + self._canvas = value + self._canvas.add_view(self) + + def _on_layout_event(self, info: EmissionInfo) -> None: + _signal_name = info.signal.name + ... + + def show(self) -> Canvas: + """Show the view. + + Convenience method for showing the canvas that the view is on. + If no canvas exists, a new one is created. + """ + canvas = self.canvas + canvas.show() + return self.canvas + + def add_node(self, node: NodeType) -> NodeType: + """Add any node to the scene.""" + self.scene.add(node) + self.camera._set_range(margin=0) + return node + + def _create_adaptor(self, cls: type[ViewAdaptorProtocol]) -> ViewAdaptorProtocol: + adaptor = super()._create_adaptor(cls) + logger.debug("VIEW Setting scene %r and camera %r", self.scene, self.camera) + adaptor._vis_set_scene(self.scene) + adaptor._vis_set_camera(self.camera) + return adaptor diff --git a/src/ndv/models/_sequence.py b/src/ndv/models/_sequence.py new file mode 100644 index 00000000..42901ca1 --- /dev/null +++ b/src/ndv/models/_sequence.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from collections.abc import Iterable, Mapping, MutableSequence +from functools import cached_property +from typing import ( + TYPE_CHECKING, + Any, + Callable, + SupportsIndex, + TypeVar, + get_args, + overload, +) + +from psygnal import Signal +from pydantic import ( + TypeAdapter, +) +from pydantic_core import core_schema + +if TYPE_CHECKING: + from pydantic import GetCoreSchemaHandler + +_T = TypeVar("_T") + + +class ValidatedEventedList(MutableSequence[_T]): + item_inserted = Signal(int, object) # (idx, value) + item_removed = Signal(int, object) # (idx, value) + item_changed = Signal(object, object, object) # (int | slice, new, old) + items_reordered = Signal() + + @overload + def __init__(self) -> None: ... + @overload + def __init__( + self, + iterable: Iterable[_T], + *, + _item_adaptor: TypeAdapter | None = ..., + ) -> None: ... + def __init__( + self, + iterable: Iterable[_T] = (), + *, + _item_adaptor: TypeAdapter | None = None, + ) -> None: + self._item_adaptor = _item_adaptor + if self._item_adaptor is not None: + iterable = (self._item_adaptor.validate_python(v) for v in iterable) + self._list = list(iterable) + + # ---------------- abstract interface ---------------- + + @overload + def __getitem__(self, i: SupportsIndex) -> _T: ... + @overload + def __getitem__(self, i: slice) -> list[_T]: ... + def __getitem__(self, i: SupportsIndex | slice) -> _T | list[_T]: + return self._list[i] + + @overload + def __setitem__(self, key: SupportsIndex, value: _T) -> None: ... + @overload + def __setitem__(self, key: slice, value: Iterable[_T]) -> None: ... + def __setitem__(self, key: slice | SupportsIndex, value: _T | Iterable[_T]) -> None: + if isinstance(value, Iterable): + value = (self._validate_item(v) for v in value) + else: + value = self._validate_item(value) + + # no-op if value is identical + old = self._list[key] + if value is old: + return + + self._list[key] = value # type: ignore [index,assignment] + self.item_changed.emit(key, value, old) + + def __delitem__(self, key: SupportsIndex | slice) -> None: + item = self._list[key] + del self._list[key] + self.item_removed.emit(key, item) + + def insert(self, index: SupportsIndex, obj: _T) -> None: + obj = self._validate_item(obj) + self._list.insert(index, obj) + self.item_inserted.emit(index, obj) + + def __len__(self) -> int: + return len(self._list) + + def __eq__(self, value: object) -> bool: + # TODO: this can cause recursion errors for recursive models + if isinstance(value, ValidatedEventedList): + return self._list == value._list + return NotImplemented + + # ----------------------------------------------------- + + def __repr__(self) -> str: + return repr(self._list) + # return f"{type(self).__name__}({self._list!r})" + + @cached_property + def _validate_item(self) -> Callable[[Any], _T]: + if self._item_adaptor is None: + # __orig_class__ is not available during __init__ + # https://discuss.python.org/t/runtime-access-to-type-parameters/37517 + cls = getattr(self, "__orig_class__", None) or type(self) + if args := get_args(cls): + self._item_adaptor = TypeAdapter(args[0]) + + if self._item_adaptor is not None: + return self._item_adaptor.validate_python + + return lambda x: x + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> Mapping[str, Any]: + """Return the Pydantic core schema for this object.""" + item_type = args[0] if (args := get_args(source_type)) else Any + + def _validate(obj: Any, _item_type: Any = item_type) -> Any: + # delayed instantiation of TypeAdapter to allow recursive models + # time to rebuild + adapter = TypeAdapter(_item_type) + return cls(obj, _item_adaptor=adapter) + + def _serialize(obj: ValidatedEventedList[_T]) -> Any: + return obj._list + + items_schema = handler.generate_schema(item_type) + list_schema = core_schema.list_schema(items_schema=items_schema) + return core_schema.no_info_plain_validator_function( + function=_validate, + json_schema_input_schema=list_schema, + serialization=core_schema.plain_serializer_function_ser_schema( + _serialize, + return_schema=list_schema, + ), + ) diff --git a/src/ndv/views/_scene/__init__.py b/src/ndv/views/_scene/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ndv/views/_scene/pygfx/__init__.py b/src/ndv/views/_scene/pygfx/__init__.py new file mode 100644 index 00000000..eadae64f --- /dev/null +++ b/src/ndv/views/_scene/pygfx/__init__.py @@ -0,0 +1,9 @@ +from ._camera import Camera +from ._canvas import Canvas +from ._image import Image +from ._node import Node +from ._points import Points +from ._scene import Scene +from ._view import View + +__all__ = ["Camera", "Canvas", "Image", "Node", "Points", "Scene", "View"] diff --git a/src/ndv/views/_scene/pygfx/_camera.py b/src/ndv/views/_scene/pygfx/_camera.py new file mode 100644 index 00000000..750dbd97 --- /dev/null +++ b/src/ndv/views/_scene/pygfx/_camera.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Any, cast + +import numpy as np +import pygfx + +from ndv._types import CameraType +from ndv.models._scene.nodes import camera + +from ._node import Node + + +class Camera(Node, camera.CameraAdaptorProtocol): + """Adaptor for pygfx camera.""" + + _pygfx_node: pygfx.PerspectiveCamera + pygfx_controller: pygfx.Controller + + def __init__(self, camera: camera.Camera, **backend_kwargs: Any) -> None: + self._camera_model = camera + if camera.type == CameraType.PANZOOM: + self._pygfx_node = pygfx.OrthographicCamera() + self.pygfx_controller = pygfx.PanZoomController(self._pygfx_node) + elif camera.type == CameraType.ARCBALL: + self._pygfx_node = pygfx.PerspectiveCamera(70, 4 / 3) + self.pygfx_controller = pygfx.OrbitController(self._pygfx_node) + + self._pygfx_node.local.scale_y = -1 # don't think this is working... + + def _vis_set_zoom(self, zoom: float) -> None: + raise NotImplementedError + + def _vis_set_center(self, arg: tuple[float, ...]) -> None: + raise NotImplementedError + + def _vis_set_type(self, arg: CameraType) -> None: + raise NotImplementedError + + def _view_size(self) -> tuple[float, float] | None: + """Return the size of first parent viewbox in pixels.""" + raise NotImplementedError + + def update_controller(self) -> None: + # This is called by the View Adaptor in the `_visit` method + # ... which is in turn called by the Canvas backend adaptor's `_animate` method + # i.e. the main render loop. + self.pygfx_controller.update_camera(self._pygfx_node) + + def set_viewport(self, viewport: pygfx.Viewport) -> None: + # This is used by the Canvas backend adaptor... + # and should perhaps be moved to the View Adaptor + self.pygfx_controller.add_default_event_handlers(viewport, self._pygfx_node) + + def _vis_set_range(self, margin: float) -> None: + # reset camera to fit all objects + if not (scene := self._camera_model.parent): + print("No scene found for camera") + return + + py_scene = cast("pygfx.Scene", scene.backend_adaptor("pygfx")._vis_get_native()) + cam = self._pygfx_node + cam.show_object(py_scene) + + if (bb := py_scene.get_world_bounding_box()) is not None: + width, height, _depth = np.ptp(bb, axis=0) + if width < 0.01: + width = 1 + if height < 0.01: + height = 1 + cam.width = width + cam.height = height + cam.zoom = 1 - margin diff --git a/src/ndv/views/_scene/pygfx/_canvas.py b/src/ndv/views/_scene/pygfx/_canvas.py new file mode 100644 index 00000000..463ca572 --- /dev/null +++ b/src/ndv/views/_scene/pygfx/_canvas.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, cast + +from ndv.models import _scene as core + +if TYPE_CHECKING: + import numpy as np + from cmap import Color + from rendercanvas.auto import RenderCanvas + + from ._view import View + + +class Canvas(core.canvas.CanvasAdaptorProtocol): + """Canvas interface for pygfx Backend.""" + + def __init__(self, canvas: core.Canvas, **backend_kwargs: Any) -> None: + from rendercanvas.auto import RenderCanvas + + self._wgpu_canvas = RenderCanvas() + # Qt RenderCanvas calls show() in its __init__ method, so we need to hide it + if hasattr(self._wgpu_canvas, "hide"): + self._wgpu_canvas.hide() + + self._wgpu_canvas.set_logical_size(canvas.width, canvas.height) + self._wgpu_canvas.set_title(canvas.title) + self._views = canvas.views + + def _vis_get_native(self) -> RenderCanvas: + return self._wgpu_canvas + + def _vis_set_visible(self, arg: bool) -> None: + # show the qt canvas we patched earlier in __init__ + if hasattr(self._wgpu_canvas, "show"): + self._wgpu_canvas.show() + self._wgpu_canvas.request_draw(self._draw) + + def _draw(self) -> None: + for view in self._views: + adaptor = cast("View", view.backend_adaptor("pygfx")) + adaptor._draw() + + def _vis_add_view(self, view: core.View) -> None: + pass + # adaptor = cast("View", view.backend_adaptor()) + # adaptor._pygfx_cam.set_viewport(self._viewport) + # self._views.append(adaptor) + + def _vis_set_width(self, arg: int) -> None: + _, height = self._wgpu_canvas.get_logical_size() + self._wgpu_canvas.set_logical_size(arg, height) + + def _vis_set_height(self, arg: int) -> None: + width, _ = self._wgpu_canvas.get_logical_size() + self._wgpu_canvas.set_logical_size(width, arg) + + def _vis_set_background_color(self, arg: Color) -> None: + # not sure if pygfx has both a canavs and view background color... + pass + + def _vis_set_title(self, arg: str) -> None: + self._wgpu_canvas.set_title(arg) + + def _vis_close(self) -> None: + """Close canvas.""" + self._wgpu_canvas.close() + + def _vis_render( + self, + region: tuple[int, int, int, int] | None = None, + size: tuple[int, int] | None = None, + bgcolor: Color | None = None, + crop: np.ndarray | tuple[int, int, int, int] | None = None, + alpha: bool = True, + ) -> np.ndarray: + """Render to screenshot.""" + from rendercanvas.offscreen import OffscreenRenderCanvas + + # not sure about this... + w, h = self._wgpu_canvas.get_logical_size() + canvas = OffscreenRenderCanvas(width=w, height=h, pixel_ratio=1) + canvas.request_draw(self._draw) + return cast("np.ndarray", canvas.draw()) diff --git a/src/ndv/views/_scene/pygfx/_image.py b/src/ndv/views/_scene/pygfx/_image.py new file mode 100644 index 00000000..3c027a2e --- /dev/null +++ b/src/ndv/views/_scene/pygfx/_image.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Any + +import pygfx + +from ndv._types import ImageInterpolation + +from ._node import Node + +if TYPE_CHECKING: + from cmap import Colormap + + from ndv._types import ArrayLike + from ndv.models._scene import nodes + + +class Image(Node): + """pygfx backend adaptor for an Image node.""" + + _pygfx_node: pygfx.Image + _material: pygfx.ImageBasicMaterial + _geometry: pygfx.Geometry + + def __init__(self, image: nodes.Image, **backend_kwargs: Any) -> None: + self._vis_set_data(image.data) + self._material = pygfx.ImageBasicMaterial(clim=image.clims) + self._pygfx_node = pygfx.Image(self._geometry, self._material) + + def _vis_set_cmap(self, arg: Colormap) -> None: + self._material.map = arg.to_pygfx() + + def _vis_set_clims(self, arg: tuple[float, float] | None) -> None: + self._material.clim = arg + + def _vis_set_gamma(self, arg: float) -> None: + warnings.warn( + "Gamma correction not supported by pygfx", RuntimeWarning, stacklevel=2 + ) + + def _vis_set_interpolation(self, arg: ImageInterpolation) -> None: + if arg is ImageInterpolation.BICUBIC: + warnings.warn( + "Bicubic interpolation not supported by pygfx", + RuntimeWarning, + stacklevel=2, + ) + arg = ImageInterpolation.LINEAR + self._material.interpolation = arg.value + + def _create_texture(self, data: ArrayLike) -> pygfx.Texture: + if data is not None: + dim = data.ndim + if dim > 2 and data.shape[-1] <= 4: + dim -= 1 # last array dim is probably (a subset of) rgba + else: + dim = 2 + # TODO: unclear whether get_view() is better here... + return pygfx.Texture(data, dim=dim) + + def _vis_set_data(self, data: ArrayLike) -> None: + self._texture = self._create_texture(data) + self._geometry = pygfx.Geometry(grid=self._texture) diff --git a/src/ndv/views/_scene/pygfx/_node.py b/src/ndv/views/_scene/pygfx/_node.py new file mode 100644 index 00000000..265754cd --- /dev/null +++ b/src/ndv/views/_scene/pygfx/_node.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Any, cast + +from ndv.models._scene.nodes import node as core_node + +if TYPE_CHECKING: + from pygfx.geometries import Geometry + from pygfx.materials import Material + from pygfx.objects import WorldObject + + from ndv.models._scene import Transform + + +class Node(core_node.NodeAdaptorProtocol): + """Node adaptor for pygfx Backend.""" + + _pygfx_node: WorldObject + _material: Material + _geometry: Geometry + _name: str + + def _vis_get_native(self) -> Any: + return self._pygfx_node + + def _vis_set_name(self, arg: str) -> None: + # not sure pygfx has a name attribute... + # TODO: for that matter... do we need a name attribute? + # Could this be entirely managed on the model side/ + self._name = arg + + def _vis_set_parent(self, arg: core_node.Node | None) -> None: + warnings.warn("Parenting not implemented in pygfx backend", stacklevel=2) + + def _vis_set_children(self, arg: list[core_node.Node]) -> None: + # This is probably redundant with _vis_add_node + # could maybe be a clear then add *arg + warnings.warn("Parenting not implemented in pygfx backend", stacklevel=2) + + def _vis_set_visible(self, arg: bool) -> None: + self._pygfx_node.visible = arg + + def _vis_set_opacity(self, arg: float) -> None: + if material := getattr(self, "_material", None): + material.opacity = arg + + def _vis_set_order(self, arg: int) -> None: + self._pygfx_node.render_order = arg + + def _vis_set_interactive(self, arg: bool) -> None: + # this one requires knowledge of the controller + warnings.warn("interactive not implemented in pygfx backend", stacklevel=2) + + def _vis_set_transform(self, arg: Transform) -> None: + self._pygfx_node.local.matrix = arg.root + + def _vis_add_node(self, node: core_node.Node) -> None: + # create if it doesn't exist + adaptor = cast("Node", node.backend_adaptor("pygfx")) + self._pygfx_node.add(adaptor._vis_get_native()) + + def _vis_force_update(self) -> None: + pass + + def _vis_block_updates(self) -> None: + pass + + def _vis_unblock_updates(self) -> None: + pass diff --git a/src/ndv/views/_scene/pygfx/_points.py b/src/ndv/views/_scene/pygfx/_points.py new file mode 100644 index 00000000..d5cd0638 --- /dev/null +++ b/src/ndv/views/_scene/pygfx/_points.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal + +import numpy as np +import pygfx + +from ._node import Node + +if TYPE_CHECKING: + from collections.abc import Mapping + + import numpy.typing as npt + from cmap import Color + + from ndv.models._scene import nodes + from ndv.models._scene.nodes.points import ScalingMode + +SPACE_MAP: Mapping[ScalingMode, Literal["model", "screen", "world"]] = { + True: "world", + False: "screen", + "fixed": "screen", + "scene": "world", + "visual": "model", +} + + +class Points(Node): + """Vispy backend adaptor for an Points node.""" + + _pygfx_node: pygfx.Points + + def __init__(self, points: nodes.Points, **backend_kwargs: Any) -> None: + # TODO: unclear whether get_view() is better here... + coords = np.asarray(points.coords) + n_coords = len(coords) + + # ensure (N, 3) + if coords.shape[1] == 2: + coords = np.column_stack((coords, np.zeros(coords.shape[0]))) + + geo_kwargs = {} + if points.face_color is not None: + colors = np.tile(np.asarray(points.face_color), (n_coords, 1)) + geo_kwargs["colors"] = colors.astype(np.float32) + + # TODO: not sure whether/how pygfx implements all the other properties + + self._geometry = pygfx.Geometry( + positions=coords.astype(np.float32), + sizes=np.full(n_coords, points.size, dtype=np.float32), + **geo_kwargs, + ) + self._material = pygfx.PointsMaterial( + size=points.size, + size_space=SPACE_MAP[points.scaling], + aa=points.antialias > 0, + opacity=points.opacity, + color_mode="vertex", + size_mode="vertex", + ) + self._pygfx_node = pygfx.Points(self._geometry, self._material) + + def _vis_set_coords(self, coords: npt.NDArray) -> None: ... + + def _vis_set_size(self, size: float) -> None: ... + + def _vis_set_face_color(self, face_color: Color) -> None: ... + + def _vis_set_edge_color(self, edge_color: Color) -> None: ... + + def _vis_set_edge_width(self, edge_width: float) -> None: ... + + def _vis_set_symbol(self, symbol: str) -> None: ... + + def _vis_set_scaling(self, scaling: str) -> None: ... + + def _vis_set_antialias(self, antialias: float) -> None: ... + + def _vis_set_opacity(self, opacity: float) -> None: ... diff --git a/src/ndv/views/_scene/pygfx/_scene.py b/src/ndv/views/_scene/pygfx/_scene.py new file mode 100644 index 00000000..f49672d2 --- /dev/null +++ b/src/ndv/views/_scene/pygfx/_scene.py @@ -0,0 +1,21 @@ +from typing import Any + +import pygfx + +from ndv.models import _scene as core + +from ._node import Node + + +class Scene(Node): + _pygfx_node: pygfx.Scene + + def __init__(self, scene: core.Scene, **backend_kwargs: Any) -> None: + self._pygfx_node = pygfx.Scene(visible=scene.visible, **backend_kwargs) + self._pygfx_node.render_order = scene.order + + # Almar does this in Display.show... + self._pygfx_node.add(pygfx.AmbientLight()) + + for node in scene.children: + self._vis_add_node(node) diff --git a/src/ndv/views/_scene/pygfx/_view.py b/src/ndv/views/_scene/pygfx/_view.py new file mode 100644 index 00000000..f0a3cacc --- /dev/null +++ b/src/ndv/views/_scene/pygfx/_view.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Any, cast + +import pygfx + +from ndv.models import _scene as core + +if TYPE_CHECKING: + from cmap import Color + + from . import _camera, _canvas, _scene + + +class View(core.view.ViewAdaptorProtocol): + """View interface for pygfx Backend. + + A view combines a scene and a camera to render a scene (onto a canvas). + """ + + _pygfx_scene: pygfx.Scene + _pygfx_cam: pygfx.Camera + + def __init__(self, view: core.View, **backend_kwargs: Any) -> None: + canvas_adaptor = cast("_canvas.Canvas", view.canvas.backend_adaptor("pygfx")) + wgpu_canvas = canvas_adaptor._vis_get_native() + self._renderer = pygfx.renderers.WgpuRenderer(wgpu_canvas) + + self._vis_set_scene(view.scene) + self._vis_set_camera(view.camera) + + def _vis_get_native(self) -> pygfx.Viewport: + return pygfx.Viewport(self._renderer) + + def _vis_set_visible(self, arg: bool) -> None: + pass + + def _vis_set_scene(self, scene: core.Scene) -> None: + self._scene_adaptor = cast("_scene.Scene", scene.backend_adaptor("pygfx")) + self._pygfx_scene = self._scene_adaptor._pygfx_node + + def _vis_set_camera(self, cam: core.Camera) -> None: + self._cam_adaptor = cast("_camera.Camera", cam.backend_adaptor("pygfx")) + self._pygfx_cam = self._cam_adaptor._pygfx_node + self._cam_adaptor.pygfx_controller.register_events(self._renderer) + + def _draw(self) -> None: + renderer = self._renderer + renderer.render(self._pygfx_scene, self._pygfx_cam) + renderer.request_draw() + + def _vis_set_position(self, arg: tuple[float, float]) -> None: + warnings.warn( + "set_position not implemented for pygfx", RuntimeWarning, stacklevel=2 + ) + + def _vis_set_size(self, arg: tuple[float, float] | None) -> None: + warnings.warn( + "set_size not implemented for pygfx", RuntimeWarning, stacklevel=2 + ) + + def _vis_set_background_color(self, color: Color | None) -> None: + colors = (color.rgba,) if color is not None else () + background = pygfx.Background(None, material=pygfx.BackgroundMaterial(*colors)) + self._pygfx_scene.add(background) + + def _vis_set_border_width(self, arg: float) -> None: + warnings.warn( + "set_border_width not implemented for pygfx", RuntimeWarning, stacklevel=2 + ) + + def _vis_set_border_color(self, arg: Color | None) -> None: + warnings.warn( + "set_border_color not implemented for pygfx", RuntimeWarning, stacklevel=2 + ) + + def _vis_set_padding(self, arg: int) -> None: + warnings.warn( + "set_padding not implemented for pygfx", RuntimeWarning, stacklevel=2 + ) + + def _vis_set_margin(self, arg: int) -> None: + warnings.warn( + "set_margin not implemented for pygfx", RuntimeWarning, stacklevel=2 + ) diff --git a/src/ndv/views/_scene/vispy/__init__.py b/src/ndv/views/_scene/vispy/__init__.py new file mode 100644 index 00000000..eadae64f --- /dev/null +++ b/src/ndv/views/_scene/vispy/__init__.py @@ -0,0 +1,9 @@ +from ._camera import Camera +from ._canvas import Canvas +from ._image import Image +from ._node import Node +from ._points import Points +from ._scene import Scene +from ._view import View + +__all__ = ["Camera", "Canvas", "Image", "Node", "Points", "Scene", "View"] diff --git a/src/ndv/views/_scene/vispy/_camera.py b/src/ndv/views/_scene/vispy/_camera.py new file mode 100644 index 00000000..eb453e98 --- /dev/null +++ b/src/ndv/views/_scene/vispy/_camera.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, cast + +import numpy as np +from vispy import scene + +from ndv.models._scene.nodes import camera + +from ._node import Node + +if TYPE_CHECKING: + from ndv._types import CameraType + from ndv.models._scene._transform import Transform + + +class Camera(Node, camera.CameraAdaptorProtocol): + """Adaptor for vispy camera.""" + + _vispy_node: scene.cameras.BaseCamera + + def __init__(self, camera: camera.Camera, **backend_kwargs: Any) -> None: + backend_kwargs.setdefault("flip", (0, 1, 0)) # Add to core schema? + # backend_kwargs.setdefault("aspect", 1) + cam = scene.cameras.make_camera(str(camera.type), **backend_kwargs) + self._vispy_node = cam + + def _vis_set_zoom(self, zoom: float) -> None: + if (view_size := self._view_size()) is None: + return + scale = np.array(view_size) / zoom + if hasattr(self._vispy_node, "scale_factor"): + self._vispy_node.scale_factor = np.min(scale) + else: + # Set view rectangle, as left, right, width, height + corner = np.subtract(self._vispy_node.center[:2], scale / 2) + self._vispy_node.rect = tuple(corner) + tuple(scale) + + def _vis_set_center(self, arg: tuple[float, ...]) -> None: + self._vispy_node.center = arg[::-1] # TODO + self._vispy_node.view_changed() + + def _vis_set_type(self, arg: CameraType) -> None: + if isinstance(self._vispy_node.parent, scene.ViewBox): + self._vispy_node.parent.camera = str(arg) + # else: + # raise TypeError("Camera must be attached to a ViewBox") + + def _view_size(self) -> tuple[float, float] | None: + """Return the size of first parent viewbox in pixels.""" + obj = self._vispy_node + while (obj := obj.parent) is not None: + if isinstance(obj, scene.ViewBox): + return cast("tuple[float, float]", obj.size) + return None + + def _vis_set_range(self, margin: float) -> None: + self._vispy_node.set_range(margin=margin) + + def _vis_set_transform(self, arg: Transform) -> None: + # TODO: camera transform needs special handling + # return super()._vis_set_transform(arg) + pass diff --git a/src/ndv/views/_scene/vispy/_canvas.py b/src/ndv/views/_scene/vispy/_canvas.py new file mode 100644 index 00000000..8199c997 --- /dev/null +++ b/src/ndv/views/_scene/vispy/_canvas.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, cast + +from vispy import scene + +from ndv.models import _scene as core + +from ._util import pyd_color_to_vispy + +if TYPE_CHECKING: + import numpy as np + from cmap import Color + + +class Canvas(core.canvas.CanvasAdaptorProtocol): + """Canvas interface for Vispy Backend.""" + + def __init__(self, canvas: core.Canvas, **backend_kwargs: Any) -> None: + backend_kwargs.setdefault("keys", "interactive") + self._vispy_canvas = scene.SceneCanvas( + size=(canvas.width, canvas.height), + title=canvas.title, + show=canvas.visible, + bgcolor=pyd_color_to_vispy(canvas.background_color), + **backend_kwargs, + ) + + def _vis_get_native(self) -> scene.SceneCanvas: + return self._vispy_canvas + + def _vis_set_visible(self, arg: bool) -> None: + self._vispy_canvas.show(visible=arg) + + def _vis_add_view(self, view: core.View) -> None: + vispy_view = view.backend_adaptor("vispy")._vis_get_native() + if not isinstance(vispy_view, scene.ViewBox): + raise TypeError("View must be a Vispy ViewBox") + self._vispy_canvas.central_widget.add_widget(vispy_view) + + def _vis_set_width(self, arg: int) -> None: + _height = self._vispy_canvas.size[1] + self._vispy_canvas.size = (int(arg), int(_height)) + + def _vis_set_height(self, arg: int) -> None: + _width = self._vispy_canvas.size[0] + self._vispy_canvas.size = (int(_width), int(arg)) + + def _vis_set_background_color(self, arg: Color | None) -> None: + self._vispy_canvas.bgcolor = pyd_color_to_vispy(arg) + + def _vis_set_title(self, arg: str) -> None: + self._vispy_canvas.title = arg + + def _vis_close(self) -> None: + """Close canvas.""" + self._vispy_canvas.close() + + def _vis_render( + self, + region: tuple[int, int, int, int] | None = None, + size: tuple[int, int] | None = None, + bgcolor: Color | None = None, + crop: np.ndarray | tuple[int, int, int, int] | None = None, + alpha: bool = True, + ) -> np.ndarray: + """Render to screenshot.""" + data = self._vispy_canvas.render( + region=region, size=size, bgcolor=bgcolor, crop=crop, alpha=alpha + ) + return cast("np.ndarray", data) + + def _vis_get_ipython_mimebundle( + self, *args: Any, **kwargs: Any + ) -> dict | tuple[dict, dict]: + return self._vis_get_native()._repr_mimebundle_(*args, **kwargs) # type: ignore diff --git a/src/ndv/views/_scene/vispy/_image.py b/src/ndv/views/_scene/vispy/_image.py new file mode 100644 index 00000000..86215d65 --- /dev/null +++ b/src/ndv/views/_scene/vispy/_image.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from vispy import scene + +from ._node import Node + +if TYPE_CHECKING: + from cmap import Colormap + + from ndv._types import ArrayLike, ImageInterpolation + from ndv.models._scene import nodes + + +class Image(Node): + """Vispy backend adaptor for an Image node.""" + + _vispy_node: scene.Image + + def __init__(self, image: nodes.Image, **backend_kwargs: Any) -> None: + backend_kwargs.update( + { + "cmap": image.cmap.to_vispy(), + # "clim": image.clim_applied(), + "gamma": image.gamma, + "interpolation": image.interpolation.value, + } + ) + try: + backend_kwargs.setdefault("texture_format", "auto") + self._vispy_node = scene.Image(image.data, **backend_kwargs) + except Exception: + backend_kwargs.pop("texture_format") + self._vispy_node = scene.Image(image.data, **backend_kwargs) + + def _vis_set_cmap(self, arg: Colormap) -> None: + self._vispy_node.cmap = arg.to_vispy() + + def _vis_set_clims(self, arg: tuple[float, float] | None) -> None: + if arg is not None: + self._vispy_node.clim = arg + + def _vis_set_gamma(self, arg: float) -> None: + self._vispy_node.gamma = arg + + def _vis_set_interpolation(self, arg: ImageInterpolation) -> None: + self._vispy_node.interpolation = arg.value + + def _vis_set_data(self, arg: ArrayLike) -> None: + self._vispy_node.set_data(arg) + self._vispy_node.update() diff --git a/src/ndv/views/_scene/vispy/_node.py b/src/ndv/views/_scene/vispy/_node.py new file mode 100644 index 00000000..e084bbfd --- /dev/null +++ b/src/ndv/views/_scene/vispy/_node.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from vispy import scene +from vispy.visuals.transforms import MatrixTransform, NullTransform + +from ndv.models._scene.nodes import node as core_node + +if TYPE_CHECKING: + from ndv.models._scene import Transform + + +class Node(core_node.NodeAdaptorProtocol): + """Node adaptor for Vispy Backend.""" + + _vispy_node: scene.VisualNode + _update: Any = None + + def _vis_force_update(self) -> None: + self._vispy_node.update() + + def _vis_block_updates(self) -> None: + self._update, self._vispy_node.update = self._vispy_node.update, lambda: None + + def _vis_unblock_updates(self) -> None: + self._vispy_node.update, self._update = self._update, None + + def _vis_get_native(self) -> Any: + return self._vispy_node + + def _vis_set_name(self, arg: str) -> None: + self._vispy_node.name = arg + + def _vis_set_parent(self, arg: core_node.Node | None) -> None: + if arg is None: + self._vispy_node.parent = None + else: + return # this causes recursion error # FIXME + vispy_node = arg.backend_adaptor("vispy")._vis_get_native() + if not isinstance(vispy_node, scene.Node): + raise TypeError("Parent must be a Node") + self._vispy_node.parent = vispy_node + + def _vis_set_children(self, arg: list[core_node.Node]) -> None: + pass + + def _vis_set_visible(self, arg: bool) -> None: + self._vispy_node.visible = arg + + def _vis_set_opacity(self, arg: float) -> None: + self._vispy_node.opacity = arg + + def _vis_set_order(self, arg: int) -> None: + self._vispy_node.order = arg + + def _vis_set_interactive(self, arg: bool) -> None: + self._vispy_node.interactive = arg + + def _vis_set_transform(self, arg: Transform) -> None: + T = NullTransform() if arg.is_null() else MatrixTransform(arg.root) + self._vispy_node.transform = T + + def _vis_add_node(self, node: core_node.Node) -> None: + vispy_node = node.backend_adaptor("vispy")._vis_get_native() + if not isinstance(vispy_node, scene.Node): + raise TypeError("Node must be a Vispy Node") + vispy_node.parent = self._vispy_node diff --git a/src/ndv/views/_scene/vispy/_points.py b/src/ndv/views/_scene/vispy/_points.py new file mode 100644 index 00000000..13c853b1 --- /dev/null +++ b/src/ndv/views/_scene/vispy/_points.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from vispy import scene + +from ._node import Node + +if TYPE_CHECKING: + import numpy.typing as npt + from cmap import Color + + from ndv.models._scene import nodes + + +class Points(Node): + """Vispy backend adaptor for an Points node.""" + + _vispy_node: scene.Markers + + def __init__(self, points: nodes.Points, **backend_kwargs: Any) -> None: + backend_kwargs.update( + { + "scaling": points.scaling, + "alpha": points.opacity, + "antialias": points.antialias, + "pos": points.coords, + "size": points.size, + "edge_width": points.edge_width, + "face_color": points.face_color, + "edge_color": points.edge_color, + "symbol": points.symbol, + } + ) + self._vispy_node = scene.Markers(**backend_kwargs) + + # TODO: + # vispy has an odd way of selectively setting individual markers properties + # without altering the rest of the properties (you generally have to include + # most of the state each time or you will overwrite the rest of the state) + # this goes for size, face/edge color, edge width, symbol + def _vis_set_coords(self, coords: npt.NDArray) -> None: + if self._vispy_node._data is None: + self._vispy_node.set_data(coords) + + def _vis_set_size(self, size: float) -> None: + self._vispy_node.set_data(size=size) + + def _vis_set_face_color(self, face_color: Color) -> None: + self._vispy_node.set_data(face_color=face_color) + + def _vis_set_edge_color(self, edge_color: Color) -> None: + self._vispy_node.set_data(edge_color=edge_color) + + def _vis_set_edge_width(self, edge_width: float) -> None: + return + self._vispy_node.set_data(edge_width=edge_width) + + def _vis_set_symbol(self, symbol: str) -> None: + self._vispy_node.symbol = symbol + + def _vis_set_scaling(self, scaling: str) -> None: + self._vispy_node.scaling = scaling + + def _vis_set_antialias(self, antialias: float) -> None: + self._vispy_node.antialias = antialias + + def _vis_set_opacity(self, opacity: float) -> None: + self._vispy_node.alpha = opacity diff --git a/src/ndv/views/_scene/vispy/_scene.py b/src/ndv/views/_scene/vispy/_scene.py new file mode 100644 index 00000000..4072b6dc --- /dev/null +++ b/src/ndv/views/_scene/vispy/_scene.py @@ -0,0 +1,22 @@ +from typing import Any + +from vispy.scene.subscene import SubScene +from vispy.visuals.filters import Clipper + +from ndv.models import _scene as core + +from ._node import Node + + +class Scene(Node): + _vispy_node: SubScene + + def __init__(self, scene: core.Scene, **backend_kwargs: Any) -> None: + self._vispy_node = SubScene(**backend_kwargs) + self._vispy_node._clipper = Clipper() + self._vispy_node.clip_children = True + + # XXX: this logic should be moved to the model + for node in scene.children: + node.backend_adaptor("vispy") # create backend adaptor if it doesn't exist + self._vis_add_node(node) diff --git a/src/ndv/views/_scene/vispy/_util.py b/src/ndv/views/_scene/vispy/_util.py new file mode 100644 index 00000000..46e67c2c --- /dev/null +++ b/src/ndv/views/_scene/vispy/_util.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from cmap import Color + + +def pyd_color_to_vispy(color: Color | None) -> str: + """Convert a color to a hex string.""" + return color.hex if color is not None else "black" diff --git a/src/ndv/views/_scene/vispy/_view.py b/src/ndv/views/_scene/vispy/_view.py new file mode 100644 index 00000000..65abe70c --- /dev/null +++ b/src/ndv/views/_scene/vispy/_view.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from vispy import scene +from vispy.scene import subscene + +from ndv.models import _scene as core + +from ._node import Node +from ._util import pyd_color_to_vispy + +if TYPE_CHECKING: + from cmap import Color + +logger = logging.getLogger(__name__) + + +# TODO: originally, View did need to be a Node, but not anymore, +# so this could be refactored +class View(Node, core.view.ViewAdaptorProtocol): + """View interface for Vispy Backend.""" + + _vispy_node: scene.ViewBox + + def __init__(self, view: core.View, **backend_kwargs: Any) -> None: + backend_kwargs.update( + { + "pos": view.layout.position, + "border_color": pyd_color_to_vispy(view.layout.border_color), + "border_width": view.layout.border_width, + "bgcolor": pyd_color_to_vispy(view.layout.background_color), + "padding": view.layout.padding, + "margin": view.layout.margin, + } + ) + if (size := view.layout.size) is not None: + backend_kwargs["size"] = size + self._vispy_node = scene.ViewBox(**backend_kwargs) + + def _vis_set_camera(self, cam: core.Camera) -> None: + vispy_cam = cam.backend_adaptor("vispy")._vis_get_native() + if not isinstance(vispy_cam, scene.cameras.BaseCamera): + raise TypeError("Camera must be a Vispy Camera") + try: + # hitting singular matrix here... probably a bad order of operations + self._vispy_node.camera = vispy_cam + # vispy_cam.set_range(margin=0) # TODO: put this elsewhere + except Exception as e: + logger.error("Error setting camera: %s", e) + + def _vis_set_scene(self, scene: core.Scene) -> None: + vispy_scene = scene.backend_adaptor("vispy")._vis_get_native() + if not isinstance(vispy_scene, subscene.SubScene): + raise TypeError("Scene must be a Vispy SubScene") + + self._vispy_node._scene = vispy_scene + vispy_scene.parent = self._vispy_node + + def _vis_set_position(self, arg: tuple[float, float]) -> None: + self._vispy_node.pos = arg + + def _vis_set_size(self, arg: tuple[float, float] | None) -> None: + self._vispy_node.size = arg + + def _vis_set_background_color(self, arg: Color | None) -> None: + self._vispy_node.bgcolor = pyd_color_to_vispy(arg) + + def _vis_set_border_width(self, arg: float) -> None: + self._vispy_node._border_width = arg + self._vispy_node._update_line() + self._vispy_node.update() + + def _vis_set_border_color(self, arg: Color | None) -> None: + self._vispy_node.border_color = pyd_color_to_vispy(arg) + + def _vis_set_padding(self, arg: int) -> None: + self._vispy_node.padding = arg + + def _vis_set_margin(self, arg: int) -> None: + self._vispy_node.margin = arg diff --git a/x.py b/x.py new file mode 100644 index 00000000..ae6dff57 --- /dev/null +++ b/x.py @@ -0,0 +1,54 @@ +from contextlib import suppress + +import numpy as np +from rich import print + +from ndv import run_app +from ndv.models._scene import Transform +from ndv.models._scene.nodes import Image, Points, Scene +from ndv.models._scene.view import View +from ndv.views import _app + +_app.ndv_app() +img1 = Image( + name="Some Image", + data=np.random.randint(0, 255, (100, 100)).astype(np.uint8), + clims=(0, 255), +) + +img2 = Image( + data=np.random.randint(0, 255, (200, 200)).astype(np.uint8), + cmap="viridis", + transform=Transform().scaled((0.7, 0.5)).translated((-10, 20)), + clims=(0, 255), +) + +scene = Scene(children=[img1, img2]) +with suppress(Exception): + points = Points( + coords=np.random.randint(0, 200, (100, 2)).astype(np.uint8), + size=5, + face_color="coral", + edge_color="blue", + opacity=0.8, + ) + scene.children.insert(0, points) +view = View(scene=scene) + + +print(view) +view.show() +view.camera._set_range(margin=0.05) +run_app() + +# sys.exit() + +# print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>") +json = view.model_dump_json(indent=2, exclude_unset=True) +print(json) +# print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>") +obj = View.model_validate_json(json) +print(obj) + + +assert View.model_json_schema()