From 15c2e5d2dd27a994e65f257d49cccd8d637151d1 Mon Sep 17 00:00:00 2001 From: getzze Date: Thu, 1 Feb 2024 16:22:00 +0000 Subject: [PATCH] add options for a signal aliases adapt to SignalRelay --- src/psygnal/__init__.py | 2 + src/psygnal/_dataclass_utils.py | 265 ++++++++++++++- src/psygnal/_evented_decorator.py | 47 ++- src/psygnal/_group.py | 23 +- src/psygnal/_group_descriptor.py | 207 ++++++++++-- src/psygnal/_signal.py | 6 +- src/psygnal/containers/_evented_set.py | 4 +- src/psygnal/utils.py | 28 +- tests/test_custom_fields.py | 436 +++++++++++++++++++++++++ tests/test_evented_model.py | 5 +- tests/test_group.py | 15 +- tests/test_group_descriptor.py | 2 +- 12 files changed, 962 insertions(+), 78 deletions(-) create mode 100644 tests/test_custom_fields.py diff --git a/src/psygnal/__init__.py b/src/psygnal/__init__.py index f976e061..36891edf 100644 --- a/src/psygnal/__init__.py +++ b/src/psygnal/__init__.py @@ -29,6 +29,7 @@ "EventedModel", "get_evented_namespace", "is_evented", + "PSYGNAL_METADATA", "Signal", "SignalGroup", "SignalGroupDescriptor", @@ -48,6 +49,7 @@ stacklevel=2, ) +from ._dataclass_utils import PSYGNAL_METADATA from ._evented_decorator import evented from ._exceptions import EmitLoopError from ._group import EmissionInfo, SignalGroup diff --git a/src/psygnal/_dataclass_utils.py b/src/psygnal/_dataclass_utils.py index 5b74be47..7468bb10 100644 --- a/src/psygnal/_dataclass_utils.py +++ b/src/psygnal/_dataclass_utils.py @@ -4,15 +4,32 @@ import dataclasses import sys import types -from typing import TYPE_CHECKING, Any, Iterator, List, Protocol, cast, overload +from dataclasses import dataclass, fields +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterator, + List, + Mapping, + Protocol, + cast, + overload, +) if TYPE_CHECKING: + from dataclasses import Field + import attrs import msgspec from pydantic import BaseModel from typing_extensions import TypeGuard # py310 +EqOperator = Callable[[Any, Any], bool] +PSYGNAL_METADATA = "__psygnal_metadata" + + class _DataclassParams(Protocol): init: bool repr: bool @@ -29,6 +46,9 @@ class AttrsType: __attrs_attrs__: tuple[attrs.Attribute, ...] +KW_ONLY = object() +with contextlib.suppress(ImportError): + from dataclasses import KW_ONLY _DATACLASS_PARAMS = "__dataclass_params__" with contextlib.suppress(ImportError): from dataclasses import _DATACLASS_PARAMS # type: ignore @@ -171,8 +191,8 @@ def iter_fields( yield field_name, p_field.annotation else: for p_field in cls.__fields__.values(): # type: ignore [attr-defined] - if p_field.field_info.allow_mutation or not exclude_frozen: # type: ignore - yield p_field.name, p_field.outer_type_ # type: ignore + if p_field.field_info.allow_mutation or not exclude_frozen: # type: ignore [attr-defined] + yield p_field.name, p_field.outer_type_ # type: ignore [attr-defined] return if (attrs_fields := getattr(cls, "__attrs_attrs__", None)) is not None: @@ -185,3 +205,242 @@ def iter_fields( type_ = cls.__annotations__.get(m_field, None) yield m_field, type_ return + + +@dataclass +class FieldOptions: + name: str + type_: type | None = None + # set KW_ONLY value for compatibility with python < 3.10 + _: KW_ONLY = KW_ONLY # type: ignore [valid-type] + alias: str | None = None + skip: bool | None = None + eq: EqOperator | None = None + disable_setattr: bool | None = None + + +def is_kw_only(f: Field) -> bool: + if hasattr(f, "kw_only"): + return cast(bool, f.kw_only) + # for python < 3.10 + if f.name not in ["name", "type_"]: + return True + return False + + +def sanitize_field_options_dict(d: Mapping) -> dict[str, Any]: + field_options_kws = [f.name for f in fields(FieldOptions) if is_kw_only(f)] + return {k: v for k, v in d.items() if k in field_options_kws} + + +def get_msgspec_metadata( + cls: type[msgspec.Struct], + m_field: str, +) -> tuple[type | None, dict[str, Any]]: + # Look for type in cls and super classes + type_: type | None = None + for super_cls in cls.__mro__: + if not hasattr(super_cls, "__annotations__"): + continue + type_ = super_cls.__annotations__.get(m_field, None) + if type_ is not None: + break + + msgspec = sys.modules.get("msgspec", None) + if msgspec is None: + return type_, {} + + metadata_list = getattr(type_, "__metadata__", []) + + metadata: dict[str, Any] = {} + for meta in metadata_list: + if not isinstance(meta, msgspec.Meta): + continue + single_meta: dict[str, Any] = getattr(meta, "extra", {}).get( + PSYGNAL_METADATA, {} + ) + metadata.update(single_meta) + + return type_, metadata + + +def iter_fields_with_options( + cls: type, exclude_frozen: bool = True +) -> Iterator[FieldOptions]: + """Iterate over all fields in the class, return a field description. + + This function recognizes dataclasses, attrs classes, msgspec Structs, and pydantic + models. + + Parameters + ---------- + cls : type + The class to iterate over. + exclude_frozen : bool, optional + If True, frozen fields will be excluded. By default True. + + Yields + ------ + FieldOptions + A dataclass instance with the name, type and metadata of each field. + """ + # Add metadata for dataclasses.dataclass + dclass_fields = getattr(cls, "__dataclass_fields__", None) + if dclass_fields is not None: + """ + Example + ------- + from dataclasses import dataclass, field + + + @dataclass + class Foo: + bar: int = field(metadata={"alias": "bar_alias"}) + + assert ( + Foo.__dataclass_fields__["bar"].metadata == + {"__psygnal_metadata": {"alias": "bar_alias"}} + ) + + """ + for d_field in dclass_fields.values(): + if d_field._field_type is dataclasses._FIELD: # type: ignore [attr-defined] + metadata = getattr(d_field, "metadata", {}).get(PSYGNAL_METADATA, {}) + metadata = sanitize_field_options_dict(metadata) + options = FieldOptions(d_field.name, d_field.type, **metadata) + yield options + return + + # Add metadata for pydantic dataclass + if is_pydantic_model(cls): + """ + Example + ------- + from typing import Annotated + + from pydantic import BaseModel, Field + + + # Only works with Pydantic v2 + class Foo(BaseModel): + bar: Annotated[ + str, + {'__psygnal_metadata': {"alias": "bar_alias"}} + ] = Field(...) + + # Working with Pydantic v2 and partially with v1 + # Alternative, using Field `json_schema_extra` keyword argument + class Bar(BaseModel): + bar: str = Field( + json_schema_extra={PSYGNAL_METADATA: {"alias": "bar_alias"}} + ) + + + assert ( + Foo.model_fields["bar"].metadata[0] == + {"__psygnal_metadata": {"alias": "bar_alias"}} + ) + assert ( + Bar.model_fields["bar"].json_schema_extra == + {"__psygnal_metadata": {"alias": "bar_alias"}} + ) + + """ + if hasattr(cls, "model_fields"): + # Pydantic v2 + for field_name, p_field in cls.model_fields.items(): + # skip frozen field + if exclude_frozen and p_field.frozen: + continue + metadata_list = getattr(p_field, "metadata", []) + metadata = {} + for field in metadata_list: + metadata.update(field.get(PSYGNAL_METADATA, {})) + # Compat with using Field `json_schema_extra` keyword argument + if isinstance(getattr(p_field, "json_schema_extra", None), Mapping): + meta_dict = cast(Mapping, p_field.json_schema_extra) + metadata.update(meta_dict.get(PSYGNAL_METADATA, {})) + metadata = sanitize_field_options_dict(metadata) + options = FieldOptions(field_name, p_field.annotation, **metadata) + yield options + return + + else: + # Pydantic v1, metadata is not always working + for pv1_field in cls.__fields__.values(): # type: ignore [attr-defined] + # skip frozen field + if exclude_frozen and not pv1_field.field_info.allow_mutation: + continue + meta_dict = getattr(pv1_field.field_info, "extra", {}).get( + "json_schema_extra", {} + ) + metadata = meta_dict.get(PSYGNAL_METADATA, {}) + + metadata = sanitize_field_options_dict(metadata) + options = FieldOptions( + pv1_field.name, + pv1_field.outer_type_, + **metadata, + ) + yield options + return + + # Add metadata for attrs dataclass + attrs_fields = getattr(cls, "__attrs_attrs__", None) + if attrs_fields is not None: + """ + Example + ------- + from attrs import define, field + + + @define + class Foo: + bar: int = field(metadata={"alias": "bar_alias"}) + + assert ( + Foo.__attrs_attrs__.bar.metadata == + {"__psygnal_metadata": {"alias": "bar_alias"}} + ) + + """ + for a_field in attrs_fields: + metadata = getattr(a_field, "metadata", {}).get(PSYGNAL_METADATA, {}) + metadata = sanitize_field_options_dict(metadata) + options = FieldOptions(a_field.name, a_field.type, **metadata) + yield options + return + + # Add metadata for attrs dataclass + if is_msgspec_struct(cls): + """ + Example + ------- + from typing import Annotated + + from msgspec import Meta, Struct + + + class Foo(Struct): + bar: Annotated[ + str, + Meta(extra={"__psygnal_metadata": {"alias": "bar_alias"})) + ] = "" + + + print(Foo.__annotations__["bar"].__metadata__[0].extra) + # {"__psygnal_metadata": {"alias": "bar_alias"}} + + """ + for m_field in cls.__struct_fields__: + try: + type_, metadata = get_msgspec_metadata(cls, m_field) + metadata = sanitize_field_options_dict(metadata) + except AttributeError: + msg = f"Cannot parse field metadata for {m_field}: {type_}" + # logger.exception(msg) + print(msg) + type_, metadata = None, {} + options = FieldOptions(m_field, type_, **metadata) + yield options + return diff --git a/src/psygnal/_evented_decorator.py b/src/psygnal/_evented_decorator.py index 236015c8..7e09d452 100644 --- a/src/psygnal/_evented_decorator.py +++ b/src/psygnal/_evented_decorator.py @@ -1,24 +1,19 @@ +from __future__ import annotations + from typing import ( - Any, Callable, - Dict, Literal, - Optional, - Type, + Mapping, + Sequence, TypeVar, - Union, overload, ) -from psygnal._group_descriptor import SignalGroupDescriptor +from psygnal._group_descriptor import EqOperator, SignalGroupDescriptor __all__ = ["evented"] -T = TypeVar("T", bound=Type) - -EqOperator = Callable[[Any, Any], bool] -PSYGNAL_GROUP_NAME = "_psygnal_group_" -_NULL = object() +T = TypeVar("T", bound=type) @overload @@ -26,31 +21,37 @@ def evented( cls: T, *, events_namespace: str = "events", - equality_operators: Optional[Dict[str, EqOperator]] = None, + equality_operators: dict[str, EqOperator] | None = None, warn_on_no_fields: bool = ..., cache_on_instance: bool = ..., + signal_aliases: Mapping[str, str | None] | Callable[[str], str] | None = ..., + skip_signals: Sequence[str] | None = ..., ) -> T: ... @overload def evented( - cls: "Optional[Literal[None]]" = None, + cls: Literal[None] | None = None, *, events_namespace: str = "events", - equality_operators: Optional[Dict[str, EqOperator]] = None, + equality_operators: dict[str, EqOperator] | None = None, warn_on_no_fields: bool = ..., cache_on_instance: bool = ..., + signal_aliases: Mapping[str, str | None] | Callable[[str], str] | None = ..., + skip_signals: Sequence[str] | None = ..., ) -> Callable[[T], T]: ... def evented( - cls: Optional[T] = None, + cls: T | None = None, *, events_namespace: str = "events", - equality_operators: Optional[Dict[str, EqOperator]] = None, + equality_operators: dict[str, EqOperator] | None = None, warn_on_no_fields: bool = True, cache_on_instance: bool = True, -) -> Union[Callable[[T], T], T]: + signal_aliases: Mapping[str, str | None] | Callable[[str], str] | None = None, + skip_signals: Sequence[str] | None = None, +) -> Callable[[T], T] | T: """A decorator to add events to a dataclass. See also the documentation for @@ -71,7 +72,7 @@ def evented( The class to decorate. events_namespace : str The name of the namespace to add the events to, by default `"events"` - equality_operators : Optional[Dict[str, Callable]] + equality_operators : dict[str, Callable] | None A dictionary mapping field names to equality operators (a function that takes two values and returns `True` if they are equal). These will be used to determine if a field has changed when setting a new value. By default, this @@ -89,6 +90,14 @@ def evented( access, but means that the owner instance will no longer be pickleable. If `False`, the SignalGroup instance will *still* be cached, but not on the instance itself. + signal_aliases: Mapping[str, str | None] | None + If defined, a mapping between field name and signal name. Field that are not + defined as keys alias to the same name. If the value is None, the field does not + emit a changed signal when mutated. If None, defaults to an empty dict. + Default to None + skip_signals : Sequence[str], optional + A list of field names for which the creation of an associated Signal is skipped. + Default to [] Returns ------- @@ -126,6 +135,8 @@ def _decorate(cls: T) -> T: equality_operators=equality_operators, warn_on_no_fields=warn_on_no_fields, cache_on_instance=cache_on_instance, + signal_aliases=signal_aliases, + skip_signals=skip_signals, ) # as a decorator, this will have already been called descriptor.__set_name__(cls, events_namespace) diff --git a/src/psygnal/_group.py b/src/psygnal/_group.py index 983e52fe..a5c07abb 100644 --- a/src/psygnal/_group.py +++ b/src/psygnal/_group.py @@ -260,6 +260,7 @@ class MySignals(SignalGroup): _psygnal_signals: ClassVar[Mapping[str, Signal]] _psygnal_uniform: ClassVar[bool] = False _psygnal_name_conflicts: ClassVar[set[str]] + _psygnal_aliases: ClassVar[Mapping[str, str | None]] _psygnal_instances: dict[str, SignalInstance] @@ -270,6 +271,7 @@ def __init__(self, instance: Any = None) -> None: "Cannot instantiate `SignalGroup` directly. Use a subclass instead." ) + # Attach SignalInstance to this SignalGroup instance self._psygnal_instances = { name: ( sig._create_signal_instance(self) @@ -278,9 +280,14 @@ def __init__(self, instance: Any = None) -> None: ) for name, sig in cls._psygnal_signals.items() } + # Attach SignalRelay to the object instance self._psygnal_relay = SignalRelay(self._psygnal_instances, instance) - def __init_subclass__(cls, strict: bool = False) -> None: + def __init_subclass__( + cls, + strict: bool = False, + signal_aliases: Mapping[str, str | None] = {}, + ) -> None: """Collects all Signal instances on the class under `cls._psygnal_signals`.""" # Collect Signals and remove from class attributes # Use dir(cls) instead of cls.__dict__ to get attributes from super() @@ -328,6 +335,8 @@ def __init_subclass__(cls, strict: bool = False) -> None: stacklevel=2, ) + cls._psygnal_aliases = {**signal_aliases} + cls._psygnal_uniform = _is_uniform(cls._psygnal_signals.values()) if strict and not cls._psygnal_uniform: raise TypeError( @@ -372,7 +381,7 @@ def signals(self) -> Mapping[str, SignalInstance]: def __len__(self) -> int: """Return the number of signals in the group (not including the relay).""" - return len(self._psygnal_signals) + return len(self._psygnal_instances) def __getitem__(self, item: str) -> SignalInstance: """Get a signal instance by name.""" @@ -390,19 +399,25 @@ def __getattr__(self, __name: str) -> SignalInstance: def __iter__(self) -> Iterator[str]: """Yield the names of all signals in the group.""" - return iter(self._psygnal_signals) + return iter(self._psygnal_instances) def __contains__(self, item: str) -> bool: """Return True if the group contains a signal with the given name.""" # this is redundant with __iter__ and can be removed, but only after # removing the deprecation warning in __getattr__ - return item in self._psygnal_signals + return item in self._psygnal_instances def __repr__(self) -> str: """Return repr(self).""" name = self.__class__.__name__ return f"" + def get_signal_by_alias(self, name: str) -> SignalInstance | None: + sig_name = self._psygnal_aliases.get(name, name) + if sig_name is None or sig_name not in self: + return None + return self[sig_name] + @classmethod def psygnals_uniform(cls) -> bool: """Return true if all signals in the group have the same signature.""" diff --git a/src/psygnal/_group_descriptor.py b/src/psygnal/_group_descriptor.py index 5a836de5..444029d4 100644 --- a/src/psygnal/_group_descriptor.py +++ b/src/psygnal/_group_descriptor.py @@ -1,11 +1,11 @@ from __future__ import annotations import contextlib +import copy import operator import sys import warnings import weakref -from functools import lru_cache from typing import ( TYPE_CHECKING, Any, @@ -13,13 +13,15 @@ ClassVar, Iterable, Literal, + Mapping, + Sequence, Type, TypeVar, cast, overload, ) -from ._dataclass_utils import iter_fields +from ._dataclass_utils import EqOperator, iter_fields_with_options from ._group import SignalGroup from ._signal import Signal, SignalInstance @@ -29,12 +31,11 @@ from psygnal._weak_callback import RefErrorChoice, WeakCallback -__all__ = ["is_evented", "get_evented_namespace", "SignalGroupDescriptor"] +__all__ = ["is_evented", "get_evented_namespace", "EqOperator", "SignalGroupDescriptor"] T = TypeVar("T", bound=Type) S = TypeVar("S") -EqOperator = Callable[[Any, Any], bool] _EQ_OPERATORS: dict[type, dict[str, EqOperator]] = {} _EQ_OPERATOR_NAME = "__eq_operators__" PSYGNAL_GROUP_NAME = "_psygnal_group_" @@ -141,28 +142,113 @@ def connect_setattr( ) -@lru_cache(maxsize=None) +def _identity(x: str) -> str: + return x + + def _build_dataclass_signal_group( - cls: type, equality_operators: Iterable[tuple[str, EqOperator]] | None = None + cls: type, + signal_group_class: type[SignalGroup] = SignalGroup, + equality_operators: Iterable[tuple[str, EqOperator]] | None = None, + skip_signals: Sequence[str] | None = None, + signal_aliases: Mapping[str, str | None] | Callable[[str], str] = {}, ) -> type[SignalGroup]: - """Build a SignalGroup with events for each field in a dataclass.""" + """Build a SignalGroup with events for each field in a dataclass. + + Parameters + ---------- + cls : type + the dataclass to look for the fields to connect with signals. + signal_group_class: type[SignalGroup] + SignalGroup or a subclass of it, to use as a super class. + Default to SignalGroup + equality_operators: Iterable[tuple[str, EqOperator]] | None + If defined, a mapping of field name and equality operator to use to compare if + each field was modified after being set. + Default to None + skip_signals : Sequence[str] | None, optional + A list of field names for which the creation of an associated Signal is skipped. + Default to None + signal_aliases: Mapping[str, str | None] + A mapping between field name and signal name. Field that are not defined as keys + alias to the same name. If the value is None, the field does not emit a changed + signal when mutated. + Default to {} + + """ + if skip_signals is None: + skip_signals = [] + group_name = f"{cls.__name__}_{signal_group_class.__name__}" _equality_operators = dict(equality_operators) if equality_operators else {} + if callable(signal_aliases): + transform = signal_aliases + _signal_aliases = {} + else: + transform = _identity + _signal_aliases = dict(signal_aliases) if signal_aliases else {} signals = {} eq_map = _get_eq_operator_map(cls) # create a Signal for each field in the dataclass - for name, type_ in iter_fields(cls): - if name in _equality_operators: - if not callable(_equality_operators[name]): # pragma: no cover + for f in iter_fields_with_options(cls): + # Convert and validate private fields + if f.skip or f.name in skip_signals: + # Skip this field + continue + + # Equality operator + if f.eq is not None: + if not callable(f.eq): # pragma: no cover + raise TypeError("`eq` field metadata must be callable") + eq_map[f.name] = f.eq + elif f.name in _equality_operators: + if not callable(_equality_operators[f.name]): # pragma: no cover raise TypeError("EqOperator must be callable") - eq_map[name] = _equality_operators[name] + eq_map[f.name] = _equality_operators[f.name] + else: + eq_map[f.name] = _pick_equality_operator(f.type_) + + # Signal name + if f.alias is not None: + sig_name = f.alias + elif f.disable_setattr: + sig_name = f.name + elif f.name in _signal_aliases: + sig_name_or_none = _signal_aliases[f.name] + if sig_name_or_none is None: + # original name, but it will be ignored when setattr + sig_name = f.name + else: + sig_name = sig_name_or_none else: - eq_map[name] = _pick_equality_operator(type_) - field_type = object if type_ is None else type_ - signals[name] = sig = Signal(field_type, field_type) + sig_name = transform(f.name) + + # Add the field and signal name to the table of signals, to emit with `setattr` + if f.disable_setattr: + _signal_aliases[f.name] = None + elif f.name not in _signal_aliases: + _signal_aliases[f.name] = sig_name + + # Repeated signal + if sig_name in signals: + key = next((k for k, v in _signal_aliases.items() if v == sig_name), None) + warnings.warn( + f"Signal {sig_name} was already created in {group_name}, " + f"from field {key}", + UserWarning, + stacklevel=2, + ) + continue + + # Create the Signal + field_type = object if f.type_ is None else f.type_ + signals[sig_name] = sig = Signal(field_type, field_type) # patch in our custom SignalInstance class with maxargs=1 on connect_setattr sig._signal_instance_class = _DataclassFieldSignalInstance - return type(f"{cls.__name__}SignalGroup", (SignalGroup,), signals) + # Create SignalGroup subclass with the attached signals and signal_aliases + return type( + group_name, (signal_group_class,), signals, signal_aliases=_signal_aliases + ) def is_evented(obj: object) -> bool: @@ -216,17 +302,22 @@ def __exit__(self, *args: Any) -> None: @overload -def evented_setattr(signal_group_name: str, super_setattr: SetAttr) -> SetAttr: ... +def evented_setattr( + signal_group_name: str, + super_setattr: SetAttr, +) -> SetAttr: ... @overload def evented_setattr( - signal_group_name: str, super_setattr: Literal[None] | None = None + signal_group_name: str, + super_setattr: Literal[None] | None = None, ) -> Callable[[SetAttr], SetAttr]: ... def evented_setattr( - signal_group_name: str, super_setattr: SetAttr | None = None + signal_group_name: str, + super_setattr: SetAttr | None = None, ) -> SetAttr | Callable[[SetAttr], SetAttr]: """Create a new __setattr__ method that emits events when fields change. @@ -269,12 +360,12 @@ def _setattr_and_emit_(self: object, name: str, value: Any) -> None: return super_setattr(self, name, value) group: SignalGroup | None = getattr(self, signal_group_name, None) - if not isinstance(group, SignalGroup) or name not in group: + if not isinstance(group, SignalGroup): return super_setattr(self, name, value) # don't emit if the signal doesn't exist or has no listeners - signal: SignalInstance = group[name] - if len(signal) < 1: + signal: SignalInstance | None = group.get_signal_by_alias(name) + if signal is None or len(signal) < 1: return super_setattr(self, name, value) with _changes_emitted(self, name, signal): @@ -352,6 +443,14 @@ def __setattr__(self, name: str, value: Any) -> None: events when fields change. If `False`, no `__setattr__` method will be created. (This will prevent signal emission, and assumes you are using a different mechanism to emit signals when fields change.) + signal_aliases: Mapping[str, str | None] | None + If defined, a mapping between field name and signal name. Field that are not + defined as keys alias to the same name. If the value is None, the field does not + emit a changed signal when mutated. If None, defaults to an empty dict. + Default to None + skip_signals : Sequence[str], optional + A list of field names for which the creation of an associated Signal is skipped. + Default to [] Examples -------- @@ -377,18 +476,30 @@ class Person: def __init__( self, *, - equality_operators: dict[str, EqOperator] | None = None, - signal_group_class: type[SignalGroup] | None = None, + equality_operators: Mapping[str, EqOperator] | None = None, warn_on_no_fields: bool = True, cache_on_instance: bool = True, patch_setattr: bool = True, + signal_group_class: type[SignalGroup] = SignalGroup, + collect_fields: bool = True, + signal_aliases: Mapping[str, str | None] | Callable[[str], str] | None = None, + skip_signals: Sequence[str] | None = None, ): - self._signal_group = signal_group_class self._name: str | None = None self._eqop = tuple(equality_operators.items()) if equality_operators else None self._warn_on_no_fields = warn_on_no_fields self._cache_on_instance = cache_on_instance self._patch_setattr = patch_setattr + self._signal_group_class = signal_group_class + self._collect_fields = collect_fields + self._skip_signals = list(skip_signals) if skip_signals else [] + self._signal_aliases: dict[str, str | None] | Callable[[str], str] = ( + signal_aliases + if callable(signal_aliases) + else (dict(signal_aliases) if signal_aliases else {}) + ) + + self._signal_groups: dict[int, type[SignalGroup]] = {} def __set_name__(self, owner: type, name: str) -> None: """Called when this descriptor is added to class `owner` as attribute `name`.""" @@ -434,13 +545,15 @@ def __get__( if instance is None: return self + signal_group = self._get_signal_group(owner) + # if we haven't yet instantiated a SignalGroup for this instance, # do it now and cache it. Note that we cache it here in addition to # the instance (in case the instance is not modifiable). obj_id = id(instance) if obj_id not in self._instance_map: # cache it - self._instance_map[obj_id] = self._create_group(owner)(instance) + self._instance_map[obj_id] = signal_group(instance) # also *try* to set it on the instance as well, since it will skip all the # __get__ logic in the future, but if it fails, no big deal. if self._name and self._cache_on_instance: @@ -453,13 +566,55 @@ def __get__( return self._instance_map[obj_id] + def _get_signal_group(self, owner: type) -> type[SignalGroup]: + type_id = id(owner) + if type_id not in self._signal_groups: + self._signal_groups[type_id] = self._create_group(owner) + return self._signal_groups[type_id] + def _create_group(self, owner: type) -> type[SignalGroup]: - Group = self._signal_group or _build_dataclass_signal_group(owner, self._eqop) + # Do not collect fields from owner class + if not self._collect_fields: + Group = copy.deepcopy(self._signal_group_class) + if not hasattr(Group, "_psygnal_signals"): + raise ValueError( + "SignalGroupDescriptor signal_group_class argument must be a " + "subclass of SignalGroup if collect_field_to_signals is False." + ) + + # Remove skipped signals + for sig in self._skip_signals: + if sig in Group._psygnal_signals: + del Group._psygnal_signals[sig] # type: ignore [attr-defined] + + # Add aliases + if callable(self._signal_aliases): + warnings.warn( + "Skip signal aliases, cannot use a callable `signal_aliases` with " + "`collect_fields = False`", + UserWarning, + stacklevel=2, + ) + Group._psygnal_aliases = {} + else: + Group._psygnal_aliases = {**self._signal_aliases} + + # Collect fields and create SignalGroup subclass + else: + Group = _build_dataclass_signal_group( + owner, + signal_group_class=self._signal_group_class, + equality_operators=self._eqop, + skip_signals=self._skip_signals, + signal_aliases=self._signal_aliases, + ) + if self._warn_on_no_fields and not Group._psygnal_signals: warnings.warn( f"No mutable fields found on class {owner}: no events will be " "emitted. (Is this a dataclass, attrs, msgspec, or pydantic model?)", stacklevel=2, ) + self._do_patch_setattr(owner) return Group diff --git a/src/psygnal/_signal.py b/src/psygnal/_signal.py index e53d6a19..29fcd2f8 100644 --- a/src/psygnal/_signal.py +++ b/src/psygnal/_signal.py @@ -150,12 +150,14 @@ def __set_name__(self, owner: type[Any], name: str) -> None: self._name = name @overload - def __get__(self, instance: None, owner: type[Any] | None = None) -> Signal: ... + def __get__( + self, instance: None, owner: type[Any] | None = None + ) -> Signal: ... # pragma: no cover @overload def __get__( self, instance: Any, owner: type[Any] | None = None - ) -> SignalInstance: ... + ) -> SignalInstance: ... # pragma: no cover def __get__( self, instance: Any, owner: type[Any] | None = None diff --git a/src/psygnal/containers/_evented_set.py b/src/psygnal/containers/_evented_set.py index 4a312d35..73770a7a 100644 --- a/src/psygnal/containers/_evented_set.py +++ b/src/psygnal/containers/_evented_set.py @@ -84,12 +84,12 @@ def __repr__(self) -> str: def _pre_add_hook(self, item: _T) -> _T | BailType: return item # pragma: no cover - def _post_add_hook(self, item: _T) -> None: ... + def _post_add_hook(self, item: _T) -> None: ... # pragma: no cover def _pre_discard_hook(self, item: _T) -> _T | BailType: return item # pragma: no cover - def _post_discard_hook(self, item: _T) -> None: ... + def _post_discard_hook(self, item: _T) -> None: ... # pragma: no cover def _do_add(self, item: _T) -> None: self._data.add(item) diff --git a/src/psygnal/utils.py b/src/psygnal/utils.py index 88149168..c84f6482 100644 --- a/src/psygnal/utils.py +++ b/src/psygnal/utils.py @@ -8,7 +8,7 @@ from typing import Any, Callable, Generator, Iterator from warnings import warn -from ._group import EmissionInfo, SignalGroup, SignalRelay +from ._group import EmissionInfo, SignalGroup from ._signal import SignalInstance __all__ = ["monitor_events", "iter_signal_instances"] @@ -59,10 +59,6 @@ def monitor_events( ) disconnectors = set() for siginst in iter_signal_instances(obj, include_private_attrs): - if isinstance(siginst, SignalRelay): - # TODO: ... but why? - continue - if _old_api: def _report(*args: Any, signal: SignalInstance = siginst) -> None: @@ -103,14 +99,22 @@ def iter_signal_instances( SignalInstance SignalInstances (and SignalGroups) found as attributes on `obj`. """ + # SignalGroup + if isinstance(obj, SignalGroup): + for sig in obj: + yield obj[sig] + return + + # Signal attached to Class for n in dir(obj): - if include_private_attrs or not n.startswith("_"): - with suppress(AttributeError, FutureWarning): - attr = getattr(obj, n) - if isinstance(attr, SignalInstance): - yield attr - if isinstance(attr, SignalGroup): - yield attr._psygnal_relay + if not include_private_attrs and n.startswith("_"): + continue + with suppress(AttributeError, FutureWarning): + attr = getattr(obj, n) + if isinstance(attr, SignalInstance): + yield attr + if isinstance(attr, SignalGroup): + yield attr._psygnal_relay _COMPILED_EXTS = (".so", ".pyd") diff --git a/tests/test_custom_fields.py b/tests/test_custom_fields.py new file mode 100644 index 00000000..4581428f --- /dev/null +++ b/tests/test_custom_fields.py @@ -0,0 +1,436 @@ +# from __future__ import annotations # breaks msgspec Annotated + +import contextlib +import sys +from typing import ClassVar +from unittest.mock import Mock + +import pytest + +from psygnal import ( + PSYGNAL_METADATA, + EmissionInfo, + Signal, + SignalGroup, + SignalGroupDescriptor, + is_evented, +) + +Annotated = None +with contextlib.suppress(ImportError): + from typing import Annotated # type: ignore + + +min_py_version = pytest.mark.skipif( + sys.version_info < (3, 9), reason="needs typing.Annotated" +) + + +def get_signal_aliases(obj: object) -> dict[str, str | None]: + if not is_evented(obj): + return {} + return obj.events._psygnal_aliases + + +@pytest.mark.parametrize( + "type_", + [ + "dataclass", + "attrs", + pytest.param("pydantic", marks=min_py_version), + pytest.param("msgspec", marks=min_py_version), + ], +) +def test_field_metadata(type_: str) -> None: + a_metadata = {PSYGNAL_METADATA: {"alias": "a_changed"}} + b_metadata = {PSYGNAL_METADATA: {"eq": lambda s1, s2: s1.lower() == s2.lower()}} + c_metadata = {PSYGNAL_METADATA: {"skip": True}} + d_metadata = {PSYGNAL_METADATA: {"disable_setattr": True}} + + if type_ == "dataclass": + from dataclasses import dataclass, field + + @dataclass + class Base: + a: int = field(metadata=a_metadata) + events: ClassVar = SignalGroupDescriptor() + + @dataclass + class Foo(Base): + b: str = field(metadata=b_metadata) + + @dataclass + class Bar(Foo): + c: float = field(metadata=c_metadata) + + @dataclass + class Baz(Bar): + d: float = field(metadata=d_metadata) + + elif type_ == "attrs": + from attrs import define, field + + @define + class Base: + a: int = field(metadata=a_metadata) + events: ClassVar = SignalGroupDescriptor() + + @define + class Foo(Base): + b: str = field(metadata=b_metadata) + + @define + class Bar(Foo): + c: float = field(metadata=c_metadata) + + @define + class Baz(Bar): + d: float = field(metadata=d_metadata) + + elif type_ == "pydantic": + pytest.importorskip("pydantic", minversion="2") + from pydantic import BaseModel, Field + + class Base(BaseModel): + a: Annotated[int, a_metadata] + events: ClassVar = SignalGroupDescriptor() + + # Alternative, using Field `json_schema_extra` keyword argument + class Foo(Base): + b: str = Field(json_schema_extra=b_metadata) + + class Bar(Foo): + c: Annotated[float, c_metadata] + + class Baz(Bar): + d: Annotated[float, d_metadata] + + elif type_ == "msgspec": + msgspec = pytest.importorskip("msgspec") + + class Base(msgspec.Struct): # type: ignore + a: Annotated[int, msgspec.Meta(extra=a_metadata)] + events: ClassVar = SignalGroupDescriptor() + + class Foo(Base): + b: Annotated[str, msgspec.Meta(extra=b_metadata)] + + class Bar(Foo): + c: Annotated[float, msgspec.Meta(extra=c_metadata)] + + class Baz(Bar): + d: Annotated[float, msgspec.Meta(extra=d_metadata)] + + assert Bar.events is Base.events + + # Instantiate objects + base = Base(a=1) + foo = Foo(a=1, b="b") + bar = Bar(a=1, b="b", c=3.0) + bar2 = Bar(a=1, b="b", c=3.0) + baz = Baz(a=1, b="b", c=3.0, d=4.0) + + # the patching of __setattr__ should only happen once + # and it will happen only on the first access of .events + assert set(base.events) == {"a_changed"} + assert set(foo.events) == {"a_changed", "b"} + assert set(bar.events) == {"a_changed", "b"} + assert set(bar2.events) == {"a_changed", "b"} + assert set(baz.events) == {"a_changed", "b", "d"} + + assert get_signal_aliases(base) == {"a": "a_changed"} + assert get_signal_aliases(foo) == {"a": "a_changed", "b": "b"} + assert get_signal_aliases(bar) == {"a": "a_changed", "b": "b"} + assert get_signal_aliases(bar2) == {"a": "a_changed", "b": "b"} + assert get_signal_aliases(baz) == {"a": "a_changed", "b": "b", "d": None} + + mock = Mock() + assert not hasattr(foo.events, "a") + foo.events.a_changed.connect(mock) + foo.events.b.connect(mock) + baz.events.a_changed.connect(mock) + baz.events.b.connect(mock) + baz.events.d.connect(mock) + + # base doesn't affect subclass + assert not hasattr(base.events, "a") + base.events.a_changed.emit(1) + mock.assert_not_called() + + base.events.a_changed.emit(2) + mock.assert_not_called() + + # `alias` works + assert hasattr(foo.events, "a_changed") + foo.a = 2 + mock.assert_called_once_with(2, 1) + mock.reset_mock() + + baz.a = 2 + mock.assert_called_once_with(2, 1) + mock.reset_mock() + + # `eq` works + foo.b = "B" + mock.assert_not_called() + baz.b = "B" + mock.assert_not_called() + + foo.b = "C" + mock.assert_called_once_with("C", "B") + mock.reset_mock() + + # `skip` works + assert not hasattr(baz.events, "c") + + # `disable_setattr` works + baz.d = 5.0 + mock.assert_not_called() + + # Check all + mock1 = Mock() + baz.events.all.connect(mock1) + baz.c = 4.0 + mock1.assert_not_called() + baz.d = 6.0 + mock1.assert_not_called() + baz.a = 3 + assert hasattr(baz.events, "a_changed") + mock1.assert_called_once_with(EmissionInfo(baz.events.a_changed, (3, 2))) + + +@pytest.mark.parametrize( + "type_", + [ + "dataclass", + "attrs", + "pydantic", + "msgspec", + ], +) +def test_alias_parameters(type_: str) -> None: + foo_options = {"skip_signals": ["_b"]} + bar_options = {"signal_aliases": lambda x: f"{x}_changed"} + baz_options = {"signal_aliases": {"a": "a_changed", "_b": "b_changed", "b": None}} + + if type_ == "dataclass": + from dataclasses import dataclass, field + + @dataclass + class Foo: + events: ClassVar = SignalGroupDescriptor(**foo_options) + a: int + _b: str + + @dataclass + class Bar: + events: ClassVar = SignalGroupDescriptor(**bar_options) + a: int + b: str + + @dataclass + class Baz: + events: ClassVar = SignalGroupDescriptor(**baz_options) + a: int + _b: str = field(default="b") + + @property + def b(self) -> str: + return self._b + + @b.setter + def b(self, value: str): + self._b = value + + elif type_ == "attrs": + from attrs import define, field + + @define + class Foo: + events: ClassVar = SignalGroupDescriptor(**foo_options) + a: int + _b: str = field(alias="_b") + + @define + class Bar: + events: ClassVar = SignalGroupDescriptor(**bar_options) + a: int + b: str + + @define + class Baz: + events: ClassVar = SignalGroupDescriptor(**baz_options) + a: int + _b: str = field(alias="_b", default="b") + + @property + def b(self) -> str: + return self._b + + @b.setter + def b(self, value: str): + self._b = value + + elif type_ == "pydantic": + pytest.importorskip("pydantic", minversion="2") + from pydantic import BaseModel + + class Foo(BaseModel): + events: ClassVar = SignalGroupDescriptor(**foo_options) + a: int + _b: str # not a field anyway + + class Bar(BaseModel): + events: ClassVar = SignalGroupDescriptor(**bar_options) + a: int + b: str + + class Baz(BaseModel): + events: ClassVar = SignalGroupDescriptor(**baz_options) + a: int + _b: str = "b" # not defining a field, signal will not be created + + @property + def b(self) -> str: + return self._b + + @b.setter + def b(self, value: str): + self._b = value + + elif type_ == "msgspec": + msgspec = pytest.importorskip("msgspec") + + class Foo(msgspec.Struct): # type: ignore + events: ClassVar = SignalGroupDescriptor(**foo_options) + a: int + _b: str + + class Bar(msgspec.Struct): # type: ignore + events: ClassVar = SignalGroupDescriptor(**bar_options) + a: int + b: str + + class Baz(msgspec.Struct): # type: ignore + events: ClassVar = SignalGroupDescriptor(**baz_options) + a: int + _b: str = "b" + + @property + def b(self) -> str: + return self._b + + @b.setter + def b(self, value: str): + self._b = value + + # Instantiate objects + foo = Foo(a=1, _b="b") + bar = Bar(a=1, b="b") + baz = Baz(a=1) + + # Check signals + assert set(foo.events) == {"a"} + assert set(bar.events) == {"a_changed", "b_changed"} + if type_.startswith("pydantic"): + assert set(baz.events) == {"a_changed"} + else: + assert set(baz.events) == {"a_changed", "b_changed"} + assert get_signal_aliases(baz) == { + "a": "a_changed", + "_b": "b_changed", + "b": None, + } + + mock = Mock() + baz.events.a_changed.connect(mock) + if not type_.startswith("pydantic"): + baz.events.b_changed.connect(mock) + + baz.a = 1 + mock.assert_not_called() + baz.a = 2 + mock.assert_called_once_with(2, 1) + mock.reset_mock() + + # pydantic v1 does not support properties + if type_ != "pydantic_v1": + baz.b = "b" + mock.assert_not_called() + baz.b = "c" + if not type_.startswith("pydantic"): + mock.assert_called_once_with("c", "b") + + +@pytest.mark.parametrize("collect", [False, True]) +def test_direct_signal_group(collect) -> None: + """Test directly using evented_setattr on a class""" + + class FooSignalGroup(SignalGroup): + a = Signal(int, int) + b_changed = Signal(float, float) + c = Signal(str, str) + d = Signal(str, str) + + class Foo: + events: ClassVar = SignalGroupDescriptor( + signal_group_class=FooSignalGroup, + collect_fields=collect, + signal_aliases={"b": "b_changed", "c": None, "_c": "c"}, + ) + a: int + b: float + _c: str + _d: str + + def __init__(self, a: int = 1, b: float = 2.0, c: str = "c", d: str = "d"): + self.a = a + self.b = b + self.c = c + self.d = d + + @property + def c(self) -> str: + return self._c + + @c.setter + def c(self, value: str): + self._c = value + + @property + def d(self) -> str: + return self._d.lower() + + @d.setter + def d(self, value: str): + self._d = value + + foo = Foo() + mock = Mock() + foo.events.a.connect(mock) + foo.events.b_changed.connect(mock) + foo.events.c.connect(mock) + foo.events.d.connect(mock) + + foo.a = 2 + mock.assert_called_once_with(2, 1) + mock.reset_mock() + + foo.b = 3.0 + mock.assert_called_once_with(3.0, 2.0) + mock.reset_mock() + + foo.c = "c" + mock.assert_not_called() + foo.c = "cc" + mock.assert_called_once_with("cc", "c") + mock.reset_mock() + foo._c = "ccc" + mock.assert_called_once_with("ccc", "cc") + mock.reset_mock() + + foo.d = "D" + mock.assert_not_called() + foo.d = "DD" + mock.assert_called_once_with("dd", "d") + mock.reset_mock() diff --git a/tests/test_evented_model.py b/tests/test_evented_model.py index 9d3571e9..1c938e18 100644 --- a/tests/test_evented_model.py +++ b/tests/test_evented_model.py @@ -69,9 +69,8 @@ class User(EventedModel): # test event system assert isinstance(user.events, SignalGroup) - # with pytest.warns(FutureWarning): - assert "id" in user.events.signals - assert "name" in user.events.signals + assert "id" in user.events + assert "name" in user.events # ClassVars are excluded from events assert "age" not in user.events diff --git a/tests/test_group.py b/tests/test_group.py index f5eacd6e..2aa23d7b 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -121,7 +121,7 @@ def test_signal_group_connect_no_args() -> None: def my_slot() -> None: count.append(1) - group.all.connect(my_slot) + group.connect(my_slot) group.sig1.emit(1) group.sig2.emit("hi") assert len(count) == 2 @@ -133,7 +133,7 @@ def test_group_blocked() -> None: mock1 = Mock() mock2 = Mock() - group.all.connect(mock1) + group.connect(mock1) group.sig1.connect(mock2) group.sig1.emit(1) @@ -185,7 +185,7 @@ def test_group_disconnect_single_slot() -> None: group.sig1.connect(mock1) group.sig2.connect(mock2) - group.all.disconnect(mock1) + group.disconnect(mock1) group.sig1.emit() mock1.assert_not_called() @@ -203,7 +203,7 @@ def test_group_disconnect_all_slots() -> None: group.sig1.connect(mock1) group.sig2.connect(mock2) - group.all.disconnect() + group.disconnect() group.sig1.emit() group.sig2.emit() @@ -253,20 +253,21 @@ def test_group_deepcopy( """ class T: - def method(self) -> None: ... + def method(self): ... obj = T() group = Group(obj) assert deepcopy(group) is not group # but no warning group.all.connect(obj.method) + group2 = deepcopy(group) assert not len(group2.all) mock = Mock() mock2 = Mock() - group.all.connect(mock) - group2.all.connect(mock2) + group.connect(mock) + group2.connect(mock2) # test that we can access signalinstances (either using getattr or __getitem__) siginst1 = get_sig(group, signame) diff --git a/tests/test_group_descriptor.py b/tests/test_group_descriptor.py index 1b36f032..a4e80f06 100644 --- a/tests/test_group_descriptor.py +++ b/tests/test_group_descriptor.py @@ -110,8 +110,8 @@ def test_no_patching(patch_setattr: bool) -> None: # sourcery skip: extract-duplicate-method @dataclass class Foo: - a: int _events: ClassVar = SignalGroupDescriptor(patch_setattr=patch_setattr) + a: int with patch.object( _group_descriptor, "evented_setattr", wraps=_group_descriptor.evented_setattr