Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent dupe calls, alternative #546

Merged
merged 11 commits into from
Oct 4, 2023
137 changes: 78 additions & 59 deletions src/magicgui/type_map/_type_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import datetime
import inspect
import itertools
import os
import pathlib
import sys
Expand Down Expand Up @@ -366,6 +367,65 @@
_T = TypeVar("_T", bound=type)


def _register_type_callback(
resolved_type: _T,
return_callback: ReturnCallback | None = None,
) -> list[type]:
modified_callbacks = []
if return_callback is None:
return []
_validate_return_callback(return_callback)
# if the type is a Union, add the callback to all of the types in the union
# (except NoneType)
if get_origin(resolved_type) is Union:
for type_per in _generate_union_variants(resolved_type):
if return_callback not in _RETURN_CALLBACKS[type_per]:
_RETURN_CALLBACKS[type_per].append(return_callback)
modified_callbacks.append(type_per)

for t in get_args(resolved_type):
if not _is_none_type(t) and return_callback not in _RETURN_CALLBACKS[t]:
_RETURN_CALLBACKS[t].append(return_callback)
modified_callbacks.append(t)
elif return_callback not in _RETURN_CALLBACKS[resolved_type]:
_RETURN_CALLBACKS[resolved_type].append(return_callback)
modified_callbacks.append(resolved_type)
return modified_callbacks


def _register_widget(
resolved_type: _T,
widget_type: WidgetRef | None = None,
**options: Any,
) -> WidgetTuple | None:
_options = cast(dict, options)

previous_widget = _TYPE_DEFS.get(resolved_type)

if "choices" in _options:
_TYPE_DEFS[resolved_type] = (widgets.ComboBox, _options)
if widget_type is not None:
warnings.warn(
"Providing `choices` overrides `widget_type`. Categorical widget "
f"will be used for type {resolved_type}",
stacklevel=2,
)
elif widget_type is not None:
if not isinstance(widget_type, (str, WidgetProtocol)) and not (
inspect.isclass(widget_type) and issubclass(widget_type, widgets.Widget)
):
raise TypeError(

Check warning on line 417 in src/magicgui/type_map/_type_map.py

View check run for this annotation

Codecov / codecov/patch

src/magicgui/type_map/_type_map.py#L417

Added line #L417 was not covered by tests
'"widget_type" must be either a string, WidgetProtocol, or '
"Widget subclass"
)
_TYPE_DEFS[resolved_type] = (widget_type, _options)
elif "bind" in _options:
# if we're binding a value to this parameter, it doesn't matter what type
# of ValueWidget is used... it usually won't be shown
_TYPE_DEFS[resolved_type] = (widgets.EmptyWidget, _options)

Check warning on line 425 in src/magicgui/type_map/_type_map.py

View check run for this annotation

Codecov / codecov/patch

src/magicgui/type_map/_type_map.py#L425

Added line #L425 was not covered by tests
return previous_widget


@overload
def register_type(
type_: _T,
Expand Down Expand Up @@ -435,43 +495,11 @@
"must be provided."
)

def _deco(type_: _T) -> _T:
resolved_type = resolve_single_type(type_)
if return_callback is not None:
_validate_return_callback(return_callback)
# if the type is a Union, add the callback to all of the types in the union
# (except NoneType)
if get_origin(resolved_type) is Union:
for t in get_args(resolved_type):
if not _is_none_type(t):
_RETURN_CALLBACKS[t].append(return_callback)
else:
_RETURN_CALLBACKS[resolved_type].append(return_callback)

_options = cast(dict, options)

if "choices" in _options:
_TYPE_DEFS[resolved_type] = (widgets.ComboBox, _options)
if widget_type is not None:
warnings.warn(
"Providing `choices` overrides `widget_type`. Categorical widget "
f"will be used for type {resolved_type}",
stacklevel=2,
)
elif widget_type is not None:
if not isinstance(widget_type, (str, WidgetProtocol)) and not (
inspect.isclass(widget_type) and issubclass(widget_type, widgets.Widget)
):
raise TypeError(
'"widget_type" must be either a string, WidgetProtocol, or '
"Widget subclass"
)
_TYPE_DEFS[resolved_type] = (widget_type, _options)
elif "bind" in _options:
# if we're binding a value to this parameter, it doesn't matter what type
# of ValueWidget is used... it usually won't be shown
_TYPE_DEFS[resolved_type] = (widgets.EmptyWidget, _options)
return type_
def _deco(type__: _T) -> _T:
resolved_type = resolve_single_type(type__)
_register_type_callback(resolved_type, return_callback)
_register_widget(resolved_type, widget_type, **options)
return type__

return _deco if type_ is None else _deco(type_)

Expand Down Expand Up @@ -507,23 +535,19 @@
"""
resolved_type = resolve_single_type(type_)

# check if return_callback is already registered
rc_was_present = return_callback in _RETURN_CALLBACKS.get(resolved_type, [])
# store any previous widget_type and options for this type
prev_type_def: WidgetTuple | None = _TYPE_DEFS.get(resolved_type, None)
resolved_type = register_type(
resolved_type,
widget_type=widget_type,
return_callback=return_callback,
**options,
)

revert_list = _register_type_callback(resolved_type, return_callback)
prev_type_def = _register_widget(resolved_type, widget_type, **options)

new_type_def: WidgetTuple | None = _TYPE_DEFS.get(resolved_type, None)
try:
yield
finally:
# restore things to before the context
if return_callback is not None and not rc_was_present:
_RETURN_CALLBACKS[resolved_type].remove(return_callback)
if return_callback is not None: # this if is only for mypy
for return_callback_type in revert_list:
_RETURN_CALLBACKS[return_callback_type].remove(return_callback)

if _TYPE_DEFS.get(resolved_type, None) is not new_type_def:
warnings.warn("Type definition changed during context", stacklevel=2)
Expand All @@ -537,9 +561,6 @@
def type2callback(type_: type) -> list[ReturnCallback]:
"""Return any callbacks that have been registered for ``type_``.

Note that if the return type is X, then the callbacks registered for Optional[X]
will be returned also be returned.

Parameters
----------
type_ : type
Expand All @@ -555,7 +576,7 @@

# look for direct hits ...
# if it's an Optional, we need to look for the type inside the Optional
_, type_ = _is_optional(resolve_single_type(type_))
type_ = resolve_single_type(type_)
if type_ in _RETURN_CALLBACKS:
return _RETURN_CALLBACKS[type_]

Expand All @@ -566,10 +587,8 @@
return []


def _is_optional(type_: Any) -> tuple[bool, type]:
# TODO: this function is too similar to _type_optional above... need to combine
if get_origin(type_) is Union:
args = get_args(type_)
if len(args) == 2 and any(_is_none_type(i) for i in args):
return True, next(i for i in args if not _is_none_type(i))
return False, type_
def _generate_union_variants(type_: Any) -> Iterator[type]:
tlambert03 marked this conversation as resolved.
Show resolved Hide resolved
type_args = get_args(type_)
for i in range(2, len(type_args) + 1):
for per in itertools.combinations(type_args, i):
yield cast(type, Union[per])
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,12 @@ def always_qapp(qapp):
for w in qapp.topLevelWidgets():
w.close()
w.deleteLater()


@pytest.fixture(autouse=True, scope="function")
def _clean_return_callbacks():
from magicgui.type_map._type_map import _RETURN_CALLBACKS

yield

_RETURN_CALLBACKS.clear()
36 changes: 35 additions & 1 deletion tests/test_magicgui.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import inspect
from enum import Enum
from typing import NewType, Optional
from typing import NewType, Optional, Union
from unittest.mock import Mock

import pytest
Expand Down Expand Up @@ -901,3 +901,37 @@ def func_optional(a: bool) -> ReturnType:
mock.reset_mock()
func_optional(a=False)
mock.assert_called_once_with(func_optional, None, ReturnType)


@pytest.mark.parametrize("optional", [True, False])
def test_no_duplication_call(optional):
mock = Mock()
mock2 = Mock()

NewInt = NewType("NewInt", int)
register_type(Optional[NewInt], return_callback=mock)
register_type(NewInt, return_callback=mock)
register_type(NewInt, return_callback=mock2)
ReturnType = Optional[NewInt] if optional else NewInt

@magicgui
def func() -> ReturnType:
return NewInt(1)

func()

mock.assert_called_once()
assert mock2.call_count == (not optional)


def test_no_order():
mock = Mock()

register_type(Union[int, None], return_callback=mock)

@magicgui
def func() -> Union[None, int]:
return 1

func()
mock.assert_called_once()
35 changes: 35 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,41 @@ def test_type_registered_warns():
assert isinstance(widgets.create_widget(annotation=Path), widgets.FileEdit)


def test_type_registered_optional_callbacks():
assert not _RETURN_CALLBACKS[int]
assert not _RETURN_CALLBACKS[Optional[int]]

@magicgui
def func1(a: int) -> int:
return a

@magicgui
def func2(a: int) -> Optional[int]:
return a

mock1 = Mock()
mock2 = Mock()
mock3 = Mock()

register_type(int, return_callback=mock2)

with type_registered(Optional[int], return_callback=mock1):
func1(1)
mock1.assert_called_once_with(func1, 1, int)
mock1.reset_mock()
func2(2)
mock1.assert_called_once_with(func2, 2, Optional[int])
mock1.reset_mock()
mock2.assert_called_once_with(func1, 1, int)
assert _RETURN_CALLBACKS[int] == [mock2, mock1]
assert _RETURN_CALLBACKS[Optional[int]] == [mock1]
register_type(Optional[int], return_callback=mock3)
assert _RETURN_CALLBACKS[Optional[int]] == [mock1, mock3]

assert _RETURN_CALLBACKS[Optional[int]] == [mock3]
assert _RETURN_CALLBACKS[int] == [mock2, mock3]


def test_pick_widget_literal():
from typing import Literal

Expand Down