From d94446a9b637fd8db41ad31f5a93e73d2f008782 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Mon, 13 Mar 2023 17:45:07 +0100 Subject: [PATCH 1/5] use variant generation to simplify discover callbacks --- src/magicgui/type_map/_type_map.py | 30 +++++++++++++++++------------- tests/test_magicgui.py | 21 +++++++++++++++++++++ 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/src/magicgui/type_map/_type_map.py b/src/magicgui/type_map/_type_map.py index 59bc43c54..66b037a2f 100644 --- a/src/magicgui/type_map/_type_map.py +++ b/src/magicgui/type_map/_type_map.py @@ -3,6 +3,7 @@ import datetime import inspect +import itertools import os import pathlib import sys @@ -426,11 +427,19 @@ def _deco(type_: _T) -> _T: # if the type is a Union, add the callback to all of the types in the union # (except NoneType) if get_origin(resolved_type) is Union: + for type_per in _generate_union_variants(resolved_type): + if return_callback not in _RETURN_CALLBACKS[type_per]: + _RETURN_CALLBACKS[type_per].append(return_callback) + for t in get_args(resolved_type): - if not _is_none_type(t): + if ( + not _is_none_type(t) + and return_callback not in _RETURN_CALLBACKS[t] + ): _RETURN_CALLBACKS[t].append(return_callback) else: - _RETURN_CALLBACKS[resolved_type].append(return_callback) + if return_callback not in _RETURN_CALLBACKS[resolved_type]: + _RETURN_CALLBACKS[resolved_type].append(return_callback) _options = cast(dict, options) @@ -521,9 +530,6 @@ def type_registered( def type2callback(type_: type) -> list[ReturnCallback]: """Return any callbacks that have been registered for ``type_``. - Note that if the return type is X, then the callbacks registered for Optional[X] - will be returned also be returned. - Parameters ---------- type_ : type @@ -539,7 +545,7 @@ def type2callback(type_: type) -> list[ReturnCallback]: # look for direct hits ... # if it's an Optional, we need to look for the type inside the Optional - _, type_ = _is_optional(resolve_single_type(type_)) + type_ = resolve_single_type(type_) if type_ in _RETURN_CALLBACKS: return _RETURN_CALLBACKS[type_] @@ -550,10 +556,8 @@ def type2callback(type_: type) -> list[ReturnCallback]: return [] -def _is_optional(type_: Any) -> tuple[bool, type]: - # TODO: this function is too similar to _type_optional above... need to combine - if get_origin(type_) is Union: - args = get_args(type_) - if len(args) == 2 and any(_is_none_type(i) for i in args): - return True, next(i for i in args if not _is_none_type(i)) - return False, type_ +def _generate_union_variants(type_: Any) -> Iterator[type]: + type_args = get_args(type_) + for i in range(2, len(type_args) + 1): + for per in itertools.permutations(type_args, i): + yield cast(type, Union[per]) diff --git a/tests/test_magicgui.py b/tests/test_magicgui.py index 784eb540a..3344ded40 100644 --- a/tests/test_magicgui.py +++ b/tests/test_magicgui.py @@ -901,3 +901,24 @@ def func_optional(a: bool) -> ReturnType: mock.reset_mock() func_optional(a=False) mock.assert_called_once_with(func_optional, None, ReturnType) + + +@pytest.mark.parametrize("optional", [True, False]) +def test_no_duplication_call(optional): + mock = Mock() + mock2 = Mock() + + NewInt = NewType("NewInt", int) + register_type(Optional[NewInt], return_callback=mock) + register_type(NewInt, return_callback=mock) + register_type(NewInt, return_callback=mock2) + ReturnType = Optional[NewInt] if optional else NewInt + + @magicgui + def func() -> ReturnType: + return NewInt(1) + + func() + + mock.assert_called_once() + assert mock2.call_count == (not optional) From 49a1928000755696bc40337a9b680805ef4d2d4c Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Mon, 13 Mar 2023 18:12:47 +0100 Subject: [PATCH 2/5] add test checking that order is not important --- src/magicgui/type_map/_type_map.py | 2 +- tests/test_magicgui.py | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/magicgui/type_map/_type_map.py b/src/magicgui/type_map/_type_map.py index 66b037a2f..50c9d22f9 100644 --- a/src/magicgui/type_map/_type_map.py +++ b/src/magicgui/type_map/_type_map.py @@ -559,5 +559,5 @@ def type2callback(type_: type) -> list[ReturnCallback]: def _generate_union_variants(type_: Any) -> Iterator[type]: type_args = get_args(type_) for i in range(2, len(type_args) + 1): - for per in itertools.permutations(type_args, i): + for per in itertools.combinations(type_args, i): yield cast(type, Union[per]) diff --git a/tests/test_magicgui.py b/tests/test_magicgui.py index 3344ded40..8e98b0a73 100644 --- a/tests/test_magicgui.py +++ b/tests/test_magicgui.py @@ -4,7 +4,7 @@ import inspect from enum import Enum -from typing import NewType, Optional +from typing import NewType, Optional, Union from unittest.mock import Mock import pytest @@ -922,3 +922,16 @@ def func() -> ReturnType: mock.assert_called_once() assert mock2.call_count == (not optional) + + +def test_no_order(): + mock = Mock() + + register_type(Union[int, None], return_callback=mock) + + @magicgui + def func() -> Union[None, int]: + return 1 + + func() + mock.assert_called_once() From 8fe85574b463623da38b47e03cf34c7dd6eeada8 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Mon, 13 Mar 2023 18:13:17 +0100 Subject: [PATCH 3/5] remove obsolete file --- magicgui/_version.py | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 magicgui/_version.py diff --git a/magicgui/_version.py b/magicgui/_version.py deleted file mode 100644 index b921eea20..000000000 --- a/magicgui/_version.py +++ /dev/null @@ -1,4 +0,0 @@ -# file generated by setuptools_scm -# don't change, don't track in version control -__version__ = version = "0.5.2.dev30+ga5e272f.d20220824" -__version_tuple__ = version_tuple = (0, 5, 2, "dev30", "ga5e272f.d20220824") From 3d6f164e9fbaedf23106791c75b182f34dcbb412 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Mon, 27 Mar 2023 11:40:29 +0200 Subject: [PATCH 4/5] fix tests --- tests/conftest.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index df4ad5936..c45364c53 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,3 +17,12 @@ def always_qapp(qapp): for w in qapp.topLevelWidgets(): w.close() w.deleteLater() + + +@pytest.fixture(autouse=True, scope="function") +def _clean_return_callbacks(): + from magicgui.type_map._type_map import _RETURN_CALLBACKS + + yield + + _RETURN_CALLBACKS.clear() From d0e0429ded1f17f1fc96426198e90443f8a046a4 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Mon, 27 Mar 2023 13:05:51 +0200 Subject: [PATCH 5/5] fix type_registered --- src/magicgui/type_map/_type_map.py | 127 ++++++++++++++++------------- tests/test_types.py | 35 ++++++++ 2 files changed, 106 insertions(+), 56 deletions(-) diff --git a/src/magicgui/type_map/_type_map.py b/src/magicgui/type_map/_type_map.py index 50c9d22f9..8c1d5691c 100644 --- a/src/magicgui/type_map/_type_map.py +++ b/src/magicgui/type_map/_type_map.py @@ -351,6 +351,65 @@ def _validate_return_callback(func: Callable) -> None: _T = TypeVar("_T", bound=type) +def _register_type_callback( + resolved_type: _T, + return_callback: ReturnCallback | None = None, +) -> list[type]: + modified_callbacks = [] + if return_callback is None: + return [] + _validate_return_callback(return_callback) + # if the type is a Union, add the callback to all of the types in the union + # (except NoneType) + if get_origin(resolved_type) is Union: + for type_per in _generate_union_variants(resolved_type): + if return_callback not in _RETURN_CALLBACKS[type_per]: + _RETURN_CALLBACKS[type_per].append(return_callback) + modified_callbacks.append(type_per) + + for t in get_args(resolved_type): + if not _is_none_type(t) and return_callback not in _RETURN_CALLBACKS[t]: + _RETURN_CALLBACKS[t].append(return_callback) + modified_callbacks.append(t) + elif return_callback not in _RETURN_CALLBACKS[resolved_type]: + _RETURN_CALLBACKS[resolved_type].append(return_callback) + modified_callbacks.append(resolved_type) + return modified_callbacks + + +def _register_widget( + resolved_type: _T, + widget_type: WidgetRef | None = None, + **options: Any, +) -> WidgetTuple | None: + _options = cast(dict, options) + + previous_widget = _TYPE_DEFS.get(resolved_type) + + if "choices" in _options: + _TYPE_DEFS[resolved_type] = (widgets.ComboBox, _options) + if widget_type is not None: + warnings.warn( + "Providing `choices` overrides `widget_type`. Categorical widget " + f"will be used for type {resolved_type}", + stacklevel=2, + ) + elif widget_type is not None: + if not isinstance(widget_type, (str, WidgetProtocol)) and not ( + inspect.isclass(widget_type) and issubclass(widget_type, widgets.Widget) + ): + raise TypeError( + '"widget_type" must be either a string, WidgetProtocol, or ' + "Widget subclass" + ) + _TYPE_DEFS[resolved_type] = (widget_type, _options) + elif "bind" in _options: + # if we're binding a value to this parameter, it doesn't matter what type + # of ValueWidget is used... it usually won't be shown + _TYPE_DEFS[resolved_type] = (widgets.EmptyWidget, _options) + return previous_widget + + @overload def register_type( type_: _T, @@ -420,51 +479,11 @@ def register_type( "must be provided." ) - def _deco(type_: _T) -> _T: - resolved_type = resolve_single_type(type_) - if return_callback is not None: - _validate_return_callback(return_callback) - # if the type is a Union, add the callback to all of the types in the union - # (except NoneType) - if get_origin(resolved_type) is Union: - for type_per in _generate_union_variants(resolved_type): - if return_callback not in _RETURN_CALLBACKS[type_per]: - _RETURN_CALLBACKS[type_per].append(return_callback) - - for t in get_args(resolved_type): - if ( - not _is_none_type(t) - and return_callback not in _RETURN_CALLBACKS[t] - ): - _RETURN_CALLBACKS[t].append(return_callback) - else: - if return_callback not in _RETURN_CALLBACKS[resolved_type]: - _RETURN_CALLBACKS[resolved_type].append(return_callback) - - _options = cast(dict, options) - - if "choices" in _options: - _TYPE_DEFS[resolved_type] = (widgets.ComboBox, _options) - if widget_type is not None: - warnings.warn( - "Providing `choices` overrides `widget_type`. Categorical widget " - f"will be used for type {resolved_type}", - stacklevel=2, - ) - elif widget_type is not None: - if not isinstance(widget_type, (str, WidgetProtocol)) and not ( - inspect.isclass(widget_type) and issubclass(widget_type, widgets.Widget) - ): - raise TypeError( - '"widget_type" must be either a string, WidgetProtocol, or ' - "Widget subclass" - ) - _TYPE_DEFS[resolved_type] = (widget_type, _options) - elif "bind" in _options: - # if we're binding a value to this parameter, it doesn't matter what type - # of ValueWidget is used... it usually won't be shown - _TYPE_DEFS[resolved_type] = (widgets.EmptyWidget, _options) - return type_ + def _deco(type__: _T) -> _T: + resolved_type = resolve_single_type(type__) + _register_type_callback(resolved_type, return_callback) + _register_widget(resolved_type, widget_type, **options) + return type__ return _deco if type_ is None else _deco(type_) @@ -500,23 +519,19 @@ def type_registered( """ resolved_type = resolve_single_type(type_) - # check if return_callback is already registered - rc_was_present = return_callback in _RETURN_CALLBACKS.get(resolved_type, []) # store any previous widget_type and options for this type - prev_type_def: WidgetTuple | None = _TYPE_DEFS.get(resolved_type, None) - resolved_type = register_type( - resolved_type, - widget_type=widget_type, - return_callback=return_callback, - **options, - ) + + revert_list = _register_type_callback(resolved_type, return_callback) + prev_type_def = _register_widget(resolved_type, widget_type, **options) + new_type_def: WidgetTuple | None = _TYPE_DEFS.get(resolved_type, None) try: yield finally: # restore things to before the context - if return_callback is not None and not rc_was_present: - _RETURN_CALLBACKS[resolved_type].remove(return_callback) + if return_callback is not None: # this if is only for mypy + for return_callback_type in revert_list: + _RETURN_CALLBACKS[return_callback_type].remove(return_callback) if _TYPE_DEFS.get(resolved_type, None) is not new_type_def: warnings.warn("Type definition changed during context", stacklevel=2) diff --git a/tests/test_types.py b/tests/test_types.py index 12f489b94..6eedaa886 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -184,3 +184,38 @@ def test_type_registered_warns(): register_type(Path, widget_type=widgets.TextEdit) assert isinstance(widgets.create_widget(annotation=Path), widgets.TextEdit) assert isinstance(widgets.create_widget(annotation=Path), widgets.FileEdit) + + +def test_type_registered_optional_callbacks(): + assert not _RETURN_CALLBACKS[int] + assert not _RETURN_CALLBACKS[Optional[int]] + + @magicgui + def func1(a: int) -> int: + return a + + @magicgui + def func2(a: int) -> Optional[int]: + return a + + mock1 = Mock() + mock2 = Mock() + mock3 = Mock() + + register_type(int, return_callback=mock2) + + with type_registered(Optional[int], return_callback=mock1): + func1(1) + mock1.assert_called_once_with(func1, 1, int) + mock1.reset_mock() + func2(2) + mock1.assert_called_once_with(func2, 2, Optional[int]) + mock1.reset_mock() + mock2.assert_called_once_with(func1, 1, int) + assert _RETURN_CALLBACKS[int] == [mock2, mock1] + assert _RETURN_CALLBACKS[Optional[int]] == [mock1] + register_type(Optional[int], return_callback=mock3) + assert _RETURN_CALLBACKS[Optional[int]] == [mock1, mock3] + + assert _RETURN_CALLBACKS[Optional[int]] == [mock3] + assert _RETURN_CALLBACKS[int] == [mock2, mock3]