diff --git a/src/magicgui/type_map/_type_map.py b/src/magicgui/type_map/_type_map.py index 33571c73d..cc422846e 100644 --- a/src/magicgui/type_map/_type_map.py +++ b/src/magicgui/type_map/_type_map.py @@ -3,6 +3,7 @@ import datetime import inspect +import itertools import os import pathlib import sys @@ -366,6 +367,65 @@ def _validate_return_callback(func: Callable) -> None: _T = TypeVar("_T", bound=type) +def _register_type_callback( + resolved_type: _T, + return_callback: ReturnCallback | None = None, +) -> list[type]: + modified_callbacks = [] + if return_callback is None: + return [] + _validate_return_callback(return_callback) + # if the type is a Union, add the callback to all of the types in the union + # (except NoneType) + if get_origin(resolved_type) is Union: + for type_per in _generate_union_variants(resolved_type): + if return_callback not in _RETURN_CALLBACKS[type_per]: + _RETURN_CALLBACKS[type_per].append(return_callback) + modified_callbacks.append(type_per) + + for t in get_args(resolved_type): + if not _is_none_type(t) and return_callback not in _RETURN_CALLBACKS[t]: + _RETURN_CALLBACKS[t].append(return_callback) + modified_callbacks.append(t) + elif return_callback not in _RETURN_CALLBACKS[resolved_type]: + _RETURN_CALLBACKS[resolved_type].append(return_callback) + modified_callbacks.append(resolved_type) + return modified_callbacks + + +def _register_widget( + resolved_type: _T, + widget_type: WidgetRef | None = None, + **options: Any, +) -> WidgetTuple | None: + _options = cast(dict, options) + + previous_widget = _TYPE_DEFS.get(resolved_type) + + if "choices" in _options: + _TYPE_DEFS[resolved_type] = (widgets.ComboBox, _options) + if widget_type is not None: + warnings.warn( + "Providing `choices` overrides `widget_type`. Categorical widget " + f"will be used for type {resolved_type}", + stacklevel=2, + ) + elif widget_type is not None: + if not isinstance(widget_type, (str, WidgetProtocol)) and not ( + inspect.isclass(widget_type) and issubclass(widget_type, widgets.Widget) + ): + raise TypeError( + '"widget_type" must be either a string, WidgetProtocol, or ' + "Widget subclass" + ) + _TYPE_DEFS[resolved_type] = (widget_type, _options) + elif "bind" in _options: + # if we're binding a value to this parameter, it doesn't matter what type + # of ValueWidget is used... it usually won't be shown + _TYPE_DEFS[resolved_type] = (widgets.EmptyWidget, _options) + return previous_widget + + @overload def register_type( type_: _T, @@ -435,43 +495,11 @@ def register_type( "must be provided." ) - def _deco(type_: _T) -> _T: - resolved_type = resolve_single_type(type_) - if return_callback is not None: - _validate_return_callback(return_callback) - # if the type is a Union, add the callback to all of the types in the union - # (except NoneType) - if get_origin(resolved_type) is Union: - for t in get_args(resolved_type): - if not _is_none_type(t): - _RETURN_CALLBACKS[t].append(return_callback) - else: - _RETURN_CALLBACKS[resolved_type].append(return_callback) - - _options = cast(dict, options) - - if "choices" in _options: - _TYPE_DEFS[resolved_type] = (widgets.ComboBox, _options) - if widget_type is not None: - warnings.warn( - "Providing `choices` overrides `widget_type`. Categorical widget " - f"will be used for type {resolved_type}", - stacklevel=2, - ) - elif widget_type is not None: - if not isinstance(widget_type, (str, WidgetProtocol)) and not ( - inspect.isclass(widget_type) and issubclass(widget_type, widgets.Widget) - ): - raise TypeError( - '"widget_type" must be either a string, WidgetProtocol, or ' - "Widget subclass" - ) - _TYPE_DEFS[resolved_type] = (widget_type, _options) - elif "bind" in _options: - # if we're binding a value to this parameter, it doesn't matter what type - # of ValueWidget is used... it usually won't be shown - _TYPE_DEFS[resolved_type] = (widgets.EmptyWidget, _options) - return type_ + def _deco(type__: _T) -> _T: + resolved_type = resolve_single_type(type__) + _register_type_callback(resolved_type, return_callback) + _register_widget(resolved_type, widget_type, **options) + return type__ return _deco if type_ is None else _deco(type_) @@ -507,23 +535,19 @@ def type_registered( """ resolved_type = resolve_single_type(type_) - # check if return_callback is already registered - rc_was_present = return_callback in _RETURN_CALLBACKS.get(resolved_type, []) # store any previous widget_type and options for this type - prev_type_def: WidgetTuple | None = _TYPE_DEFS.get(resolved_type, None) - resolved_type = register_type( - resolved_type, - widget_type=widget_type, - return_callback=return_callback, - **options, - ) + + revert_list = _register_type_callback(resolved_type, return_callback) + prev_type_def = _register_widget(resolved_type, widget_type, **options) + new_type_def: WidgetTuple | None = _TYPE_DEFS.get(resolved_type, None) try: yield finally: # restore things to before the context - if return_callback is not None and not rc_was_present: - _RETURN_CALLBACKS[resolved_type].remove(return_callback) + if return_callback is not None: # this if is only for mypy + for return_callback_type in revert_list: + _RETURN_CALLBACKS[return_callback_type].remove(return_callback) if _TYPE_DEFS.get(resolved_type, None) is not new_type_def: warnings.warn("Type definition changed during context", stacklevel=2) @@ -537,9 +561,6 @@ def type_registered( def type2callback(type_: type) -> list[ReturnCallback]: """Return any callbacks that have been registered for ``type_``. - Note that if the return type is X, then the callbacks registered for Optional[X] - will be returned also be returned. - Parameters ---------- type_ : type @@ -555,7 +576,7 @@ def type2callback(type_: type) -> list[ReturnCallback]: # look for direct hits ... # if it's an Optional, we need to look for the type inside the Optional - _, type_ = _is_optional(resolve_single_type(type_)) + type_ = resolve_single_type(type_) if type_ in _RETURN_CALLBACKS: return _RETURN_CALLBACKS[type_] @@ -566,10 +587,8 @@ def type2callback(type_: type) -> list[ReturnCallback]: return [] -def _is_optional(type_: Any) -> tuple[bool, type]: - # TODO: this function is too similar to _type_optional above... need to combine - if get_origin(type_) is Union: - args = get_args(type_) - if len(args) == 2 and any(_is_none_type(i) for i in args): - return True, next(i for i in args if not _is_none_type(i)) - return False, type_ +def _generate_union_variants(type_: Any) -> Iterator[type]: + type_args = get_args(type_) + for i in range(2, len(type_args) + 1): + for per in itertools.combinations(type_args, i): + yield cast(type, Union[per]) diff --git a/tests/conftest.py b/tests/conftest.py index df4ad5936..c45364c53 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,3 +17,12 @@ def always_qapp(qapp): for w in qapp.topLevelWidgets(): w.close() w.deleteLater() + + +@pytest.fixture(autouse=True, scope="function") +def _clean_return_callbacks(): + from magicgui.type_map._type_map import _RETURN_CALLBACKS + + yield + + _RETURN_CALLBACKS.clear() diff --git a/tests/test_magicgui.py b/tests/test_magicgui.py index d96cec7fd..c20d00316 100644 --- a/tests/test_magicgui.py +++ b/tests/test_magicgui.py @@ -4,7 +4,7 @@ import inspect from enum import Enum -from typing import NewType, Optional +from typing import NewType, Optional, Union from unittest.mock import Mock import pytest @@ -901,3 +901,37 @@ def func_optional(a: bool) -> ReturnType: mock.reset_mock() func_optional(a=False) mock.assert_called_once_with(func_optional, None, ReturnType) + + +@pytest.mark.parametrize("optional", [True, False]) +def test_no_duplication_call(optional): + mock = Mock() + mock2 = Mock() + + NewInt = NewType("NewInt", int) + register_type(Optional[NewInt], return_callback=mock) + register_type(NewInt, return_callback=mock) + register_type(NewInt, return_callback=mock2) + ReturnType = Optional[NewInt] if optional else NewInt + + @magicgui + def func() -> ReturnType: + return NewInt(1) + + func() + + mock.assert_called_once() + assert mock2.call_count == (not optional) + + +def test_no_order(): + mock = Mock() + + register_type(Union[int, None], return_callback=mock) + + @magicgui + def func() -> Union[None, int]: + return 1 + + func() + mock.assert_called_once() diff --git a/tests/test_types.py b/tests/test_types.py index 83d89a62b..b2d24cb05 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -189,6 +189,41 @@ def test_type_registered_warns(): assert isinstance(widgets.create_widget(annotation=Path), widgets.FileEdit) +def test_type_registered_optional_callbacks(): + assert not _RETURN_CALLBACKS[int] + assert not _RETURN_CALLBACKS[Optional[int]] + + @magicgui + def func1(a: int) -> int: + return a + + @magicgui + def func2(a: int) -> Optional[int]: + return a + + mock1 = Mock() + mock2 = Mock() + mock3 = Mock() + + register_type(int, return_callback=mock2) + + with type_registered(Optional[int], return_callback=mock1): + func1(1) + mock1.assert_called_once_with(func1, 1, int) + mock1.reset_mock() + func2(2) + mock1.assert_called_once_with(func2, 2, Optional[int]) + mock1.reset_mock() + mock2.assert_called_once_with(func1, 1, int) + assert _RETURN_CALLBACKS[int] == [mock2, mock1] + assert _RETURN_CALLBACKS[Optional[int]] == [mock1] + register_type(Optional[int], return_callback=mock3) + assert _RETURN_CALLBACKS[Optional[int]] == [mock1, mock3] + + assert _RETURN_CALLBACKS[Optional[int]] == [mock3] + assert _RETURN_CALLBACKS[int] == [mock2, mock3] + + def test_pick_widget_literal(): from typing import Literal