diff --git a/instruct/__init__.py b/instruct/__init__.py
index f9d7f81..0864be6 100644
--- a/instruct/__init__.py
+++ b/instruct/__init__.py
@@ -9,6 +9,7 @@
 import sys
 import types
 import typing
+from contextlib import suppress
 
 from base64 import urlsafe_b64encode
 from collections.abc import (
@@ -41,6 +42,8 @@
     Type,
     TYPE_CHECKING,
     Union,
+    TypeVar,
+    Generic,
 )
 
 from weakref import WeakValueDictionary
@@ -59,7 +62,7 @@
 )
 from .typing import T, CellType, CoerceMapping, NoneType
 from .types import FrozenMapping, ReadOnly, AttrsDict, ClassOrInstanceFuncsDescriptor
-from .utils import flatten_fields
+from .utils import flatten_fields, invert_mapping
 from .subtype import wrapper_for_type
 from .exceptions import (
     OrphanedListenersError,
@@ -1139,12 +1142,13 @@ class Atomic(type):
     if TYPE_CHECKING:
         _data_class: Atomic
         _columns: Mapping[str, Any]
+        _slots: Mapping[str, type]
         _column_types: Mapping[str, Union[Type, Tuple[Type, ...]]]
         _all_coercions: Mapping[str, Tuple[Union[Type, Tuple[Type, ...]], Callable]]
         _support_columns: Tuple[str, ...]
-        _properties: typing.collections.KeysView[str]
+        _properties: typing.KeysView[str]
         _configuration: AttrsDict
-        _all_accessible_fields: typing.collections.KeysView[str]
+        _all_accessible_fields: typing.KeysView[str]
         # i.e. key -> List[Union[AtomicDerived, bool]] means key can hold an Atomic derived type.
         _nested_atomic_collection_keys: Mapping[str, Tuple[Atomic, ...]]
 
@@ -1183,12 +1187,49 @@ def __and__(
             )
         return self - skip_fields
 
+    def __getitem__(self, key):
+        if not isinstance(key, tuple):
+            new_params = (key,)
+        else:
+            new_params = key
+        if self.__parameters__:
+            assert len(new_params) == len(self.__parameters__)
+            param_mapping = {p: n for p, n in zip(self.__parameters__, new_params)}
+            new_annotations = {}
+            complex_fields = []
+            typehints = get_type_hints(self)
+            for field_name in self.__parameters_by_field__:
+                typevar_params = self.__parameters_by_field__[field_name]
+                typehint = typehints[field_name]
+                if isinstance(typehint, TypeVar):
+                    new_annotations[field_name] = param_mapping[typehint]
+                else:
+                    complex_fields.append(field_name)
+
+            for attr_name in self.__parameters_by_field__.keys() & frozenset(complex_fields):
+                typevar_params = self.__parameters_by_field__[attr_name]
+                attr_type_args = []
+                for new_type, typevar in zip(new_params, typevar_params):
+                    attr_type_args.append(new_type)
+                type_genericable = self._columns[attr_name]
+                new_annotations[attr_name] = type_genericable[tuple(attr_type_args)]
+            return type(
+                self.__name__,
+                (self - frozenset(new_annotations.keys()),),
+                {
+                    "__annotations__": new_annotations,
+                    "__args__": new_params,
+                    "__parameters__": self.__parameters__,
+                },
+            )
+        raise AttributeError(key)
+
     def __sub__(self: Atomic, skip_fields: Union[Mapping[str, Any], Iterable[Any]]) -> Atomic:
         assert isinstance(skip_fields, (list, frozenset, set, tuple, dict, str, FrozenMapping))
         debug_mode = is_debug_mode("skip")
         root_class = type(self)
-        cls = public_class(self)
+        cls: Type[Atomic] = public_class(self)
 
         if not skip_fields:
             return self
@@ -1294,7 +1335,7 @@ def __new__(
         concrete_class=False,
         skip_fields=FrozenMapping(),
         include_fields=FrozenMapping(),
-        **mixins,
+        **mixins: Any,
     ):
         if concrete_class:
             attrs["_is_data_class"] = ReadOnly(True)
@@ -1312,7 +1353,7 @@ def __new__(
         if include_fields and skip_fields:
             raise TypeError("Cannot specify both include_fields and skip_fields!")
         data_class_attrs = {}
-        base_class_functions = []
+        pending_base_class_functions = []
         # Move overrides to the data class,
         # so we call them first, then the codegen pieces.
         # Suitable for a single level override.
@@ -1328,15 +1369,20 @@ def __new__(
         ):
             if key in attrs:
                 if hasattr(attrs[key], "_instruct_base_cls"):
-                    base_class_functions.append(key)
+                    pending_base_class_functions.append(key)
                     continue
                 data_class_attrs[key] = attrs.pop(key)
-        base_class_functions = tuple(base_class_functions)
+        base_class_functions = tuple(pending_base_class_functions)
         support_cls_attrs = attrs
         del attrs
 
         if "__slots__" not in support_cls_attrs and "__annotations__" in support_cls_attrs:
-            module = import_module(support_cls_attrs["__module__"])
+            try:
+                module = import_module(support_cls_attrs["__module__"])
+            except KeyError:
+                globalns = typing.__dict__
+            else:
+                globalns = ChainMap(module.__dict__, typing.__dict__)
             kwargs = {}
             if _GET_TYPEHINTS_ALLOWS_EXTRA:
                 kwargs["include_extras"] = True
@@ -1345,7 +1391,7 @@ def __new__(
                 support_cls_attrs,
                 # First look in the module, then failsafe to the typing to support
                 # unimported 'Optional', et al
-                ChainMap(module.__dict__, typing.__dict__),
+                globalns,
                 **kwargs,
             )
             support_cls_attrs["__slots__"] = hints
@@ -1358,7 +1404,7 @@ def __new__(
         coerce_mappings: Optional[CoerceMapping] = None
         if "__coerce__" in support_cls_attrs:
             if support_cls_attrs["__coerce__"] is not None:
-                coerce_mappings: AbstractMapping = support_cls_attrs["__coerce__"]
+                coerce_mappings = support_cls_attrs["__coerce__"]
                 if isinstance(coerce_mappings, ReadOnly):
                     # Unwrap
                     coerce_mappings = coerce_mappings.value
@@ -1394,7 +1440,7 @@ def __new__(
         if fast is None:
             fast = not __debug__
 
-        combined_columns = {}
+        combined_columns: Dict[Type, Type] = {}
         combined_slots = {}
         nested_atomic_collections: Dict[str, Atomic] = {}
         # Mapping of public name -> custom type vector for `isinstance(...)` checks!
@@ -1524,17 +1570,32 @@ def __new__(
         derived_classes = {}
         current_class_columns = {}
         current_class_slots = {}
-        for key, value in support_cls_attrs["__slots__"].items():
-            if isinstance(value, dict):
-                value = type("{}".format(key.capitalize()), bases, {"__slots__": value})
-                derived_classes[key] = value
-            if not ismetasubclass(value, Atomic):
-                nested_atomics = tuple(find_class_in_definition(value, Atomic, metaclass=True))
+        avail_generics = ()
+        generics_by_field = {}
+        for key, typehint_or_anonymous_struct_decl in support_cls_attrs["__slots__"].items():
+            if isinstance(typehint_or_anonymous_struct_decl, dict):
+                anonymous_struct_decl = typehint_or_anonymous_struct_decl
+                typehint = type(
+                    "{}".format(key.capitalize()), bases, {"__slots__": anonymous_struct_decl}
+                )
+                derived_classes[key] = typehint
+            else:
+                typehint = typehint_or_anonymous_struct_decl
+            del typehint_or_anonymous_struct_decl
+            if not ismetasubclass(typehint, Atomic):
+                nested_atomics = tuple(find_class_in_definition(typehint, Atomic, metaclass=True))
                 if nested_atomics:
                     nested_atomic_collections[key] = nested_atomics
                 del nested_atomics
-            current_class_slots[key] = combined_slots[key] = value
-            current_class_columns[key] = combined_columns[key] = parse_typedef(value)
+                nested_generics = tuple(find_class_in_definition(typehint, TypeVar))
+                if nested_generics:
+                    for item in nested_generics:
+                        if item not in avail_generics:
+                            avail_generics = (*avail_generics, item)
+                    generics_by_field[key] = nested_generics
+                del nested_generics
+            current_class_slots[key] = combined_slots[key] = typehint
+            current_class_columns[key] = combined_columns[key] = parse_typedef(typehint)
 
         no_op_skip_keys = []
         if skip_fields:
@@ -1569,6 +1630,9 @@ def __new__(
                 no_op_skip_keys.append(key)
                 del current_class_slots[key]
                 del current_class_columns[key]
+        # ARJ: https://stackoverflow.com/a/54497260
+        # if avail_generics and Generic not in bases:
+        #     bases = (*bases[:-1], *bases[-1], Generic[avail_generics])
 
         # Gather listeners:
         listeners, post_coerce_failure_handlers = gather_listeners(
@@ -1818,6 +1882,10 @@ def __new__(
         support_cls_attrs["_listener_funcs"] = ReadOnly(listeners)
         # Ensure public class has zero slots!
         support_cls_attrs["__slots__"] = ()
+        if avail_generics:
+            support_cls_attrs["__parameters__"] = tuple(avail_generics)
+            support_cls_attrs["__parameters_by_field__"] = ReadOnly(generics_by_field)
+            support_cls_attrs["__parameter_fields__"] = ReadOnly(invert_mapping(generics_by_field))
 
         support_cls_attrs["_data_class"] = support_cls_attrs[f"_{class_name}"] = dc = ReadOnly(None)
         support_cls_attrs["_parent"] = parent_cell = ReadOnly(None)
diff --git a/instruct/typedef.py b/instruct/typedef.py
index 1c3351a..46febde 100644
--- a/instruct/typedef.py
+++ b/instruct/typedef.py
@@ -3,8 +3,15 @@
 from functools import wraps
 import types
 import sys
-from collections.abc import Mapping as AbstractMapping
-from typing import Union, Any, AnyStr, List, Tuple, cast, Optional, Callable, Type
+import typing
+from collections.abc import (
+    Mapping as AbstractMapping,
+    Sequence as AbstractSequence,
+    Iterable as AbstractIterable,
+)
+from typing import Union, Any, AnyStr, List, Tuple, cast, Optional, Callable, Type, TypeVar
+from contextlib import suppress
+from weakref import WeakKeyDictionary
 
 try:
     from typing import Literal
@@ -23,43 +30,98 @@
 from .constants import Range
 from .exceptions import RangeError
 
-LOWER_THAN_310 = sys.version_info < (3, 10)
-
-if LOWER_THAN_310:
-    get_origin = _get_origin
+if sys.version_info >= (3, 11):
+    from typing import TypeVarTuple, Unpack
 else:
+    from typing_extensions import TypeVarTuple, Unpack
+
+if sys.version_info >= (3, 10):
+    from typing import ParamSpec
+    from types import UnionType
+
+    UnionTypes = (Union, UnionType)
+
     # patch get_origin to always return a Union over a 'a | b'
     def get_origin(cls):
         t = _get_origin(cls)
-        if isinstance(t, type) and issubclass(t, types.UnionType):
-            return Union[cls.__args__]
+        if isinstance(t, type) and issubclass(t, UnionType):
+            return Union
         return t
+else:
+    UnionTypes = (Union,)
+    get_origin = _get_origin
+
+
+def is_union_typedef(t) -> bool:
+    return _get_origin(t) in UnionTypes
+
+
+_abstract_custom_types = WeakKeyDictionary()
+
+
+class AbstractWrappedMeta(type):
+    __slots__ = ()
+
+    def set_name(t, name):
+        return _abstract_custom_types[t][0](name)
+
+    def set_type(t, type):
+        return _abstract_custom_types[t][1](type)
+
 
-def make_custom_typecheck(func) -> Type[ICustomTypeCheck]:
-    """Create a custom type that will turn `isinstance(item, klass)` into `func(item)`"""
+def is_abstract_custom_typecheck(o: Type) -> bool:
+    cls_type = type(o)
+    return isinstance(cls_type, AbstractWrappedMeta)
+
+
+def make_custom_typecheck(typecheck_func, typedef=None) -> Type[ICustomTypeCheck]:
+    """Create a custom type that will turn `isinstance(item, klass)` into `typecheck_func(item)`"""
     typename = "WrappedType<{}>"
+    bound_type = None
+    registry = _abstract_custom_types
 
-    class WrappedType(type):
+    class WrappedTypeMeta(AbstractWrappedMeta):
         __slots__ = ()
 
         def __instancecheck__(self, instance):
-            return func(instance)
+            return typecheck_func(instance)
 
         def __repr__(self):
             return typename.format(super().__repr__())
 
-    class _WrappedType(metaclass=WrappedType):
+        def __getitem__(self, key):
+            if bound_type is None:
+                raise AttributeError(key)
+            return parse_typedef(bound_type[key])
+
+    class _WrappedType(metaclass=WrappedTypeMeta):
         __slots__ = ()
 
-        @staticmethod
-        def set_name(name):
-            nonlocal typename
-            typename = name
-            _WrappedType.__name__ = name
-            _WrappedType._name__ = name
-            return name
+        def __class_getitem__(self, key):
+            if bound_type is None:
+                raise AttributeError(key)
+            return parse_typedef(bound_type[key])
+
+        def __new__(cls, *args):
+            nonlocal bound_type
+            if bound_type is None:
+                raise TypeError("Unbound or abstract type!")
+            return bound_type(*args)
+
+    def set_name(name):
+        nonlocal typename
+        typename = name
+        _WrappedType.__name__ = name
+        _WrappedType._name__ = name
+        return name
+
+    def set_type(t):
+        nonlocal bound_type
+        bound_type = t
+        return t
+    registry.setdefault(_WrappedType, (set_name, set_type))
     return cast(Type[ICustomTypeCheck], _WrappedType)
@@ -128,6 +190,12 @@ def find_class_in_definition(
         or isinstance(type_hints, type)
     ), f"{type_hints} is a {type(type_hints)}"
 
+    test_func = lambda child: isinstance(child, type) and issubormetasubclass(
+        child, root_cls, metaclass=metaclass
+    )
+    if issubclass(root_cls, TypeVar):
+        test_func = lambda child: isinstance(child, TypeVar)
+
     if is_typing_definition(type_hints):
         type_cls: Type = cast(Type, type_hints)
         origin_cls = get_origin(type_cls)
@@ -140,9 +208,7 @@ def find_class_in_definition(
         type_cls_copied: bool = False
         if origin_cls is Union:
             for index, child in enumerate(args):
-                if isinstance(child, type) and issubormetasubclass(
-                    child, root_cls, metaclass=metaclass
-                ):
+                if test_func(child):
                     replacement = yield child
                 else:
                     replacement = yield from find_class_in_definition(
@@ -161,9 +227,7 @@ def find_class_in_definition(
         ):
             if issubclass(origin_cls, collections.abc.Mapping):
                 key_type, value_type = args
-                if isinstance(value_type, type) and issubormetasubclass(
-                    value_type, root_cls, metaclass=metaclass
-                ):
+                if test_func(value_type):
                     replacement = yield value_type
                 else:
                     replacement = yield from find_class_in_definition(
@@ -176,9 +240,7 @@ def find_class_in_definition(
                     type_cls_copied = True
             else:
                 for index, child in enumerate(args):
-                    if isinstance(child, type) and issubormetasubclass(
-                        child, root_cls, metaclass=metaclass
-                    ):
+                    if test_func(child):
                         replacement = yield child
                     else:
                         replacement = yield from find_class_in_definition(
@@ -190,35 +252,40 @@ def find_class_in_definition(
                 if args != get_args(type_cls):
                     type_cls = type_cls.copy_with(args)
                     type_cls_copied = True
+        elif test_func(type_cls):
+            replacement = yield type_cls
+            if replacement is not None:
+                return type_cls
+            return None
         if type_cls_copied:
             return type_cls
         return None
 
-    if isinstance(type_hints, type):
-        if issubormetasubclass(type_hints, root_cls, metaclass=metaclass):
-            replacement = yield type_hints
-            if replacement is not None:
-                return replacement
+    elif test_func(type_hints):
+        replacement = yield type_hints
+        if replacement is not None:
+            return replacement
         return None
 
-    for index, type_cls in enumerate(type_hints[:]):
-        if isinstance(type_cls, type) and issubormetasubclass(
-            type_cls, root_cls, metaclass=metaclass
-        ):
-            replacement = yield type_cls
-        else:
-            replacement = yield from find_class_in_definition(
-                type_cls, root_cls, metaclass=metaclass
-            )
-        if replacement is not None:
-            type_hints = type_hints[:index] + (replacement,) + type_hints[index + 1 :]
-    return type_hints
+    elif isinstance(type_hints, (tuple, list)):
+        for index, type_cls in enumerate(type_hints[:]):
+            if test_func(type_cls):
+                replacement = yield type_cls
+            else:
+                replacement = yield from find_class_in_definition(
+                    type_cls, root_cls, metaclass=metaclass
+                )
+            if replacement is not None:
+                type_hints = type_hints[:index] + (replacement,) + type_hints[index + 1 :]
+        return type_hints
 
 
 def create_custom_type(container_type, *args, check_ranges=()):
     if is_typing_definition(container_type):
+        origin_cls = get_origin(container_type)
         if hasattr(container_type, "_name") and container_type._name is None:
-            if container_type.__origin__ is Union:
+            if origin_cls is Union:
                 types = flatten(
                     (create_custom_type(arg) for arg in container_type.__args__), eager=True
                 )
@@ -251,7 +318,7 @@ def test_func(value) -> bool:
                 """
                 return isinstance(value, types)
 
-            elif container_type.__origin__ is Literal:
+            elif origin_cls is Literal:
                 from . import Atomic, public_class
 
                 def test_func(value) -> bool:
@@ -301,13 +368,18 @@ def test_func(value) -> bool:
             return (str, bytes)
         elif container_type is Any:
             return object
-        elif isinstance(getattr(container_type, "__origin__", None), type) and (
-            issubclass(container_type.__origin__, collections.abc.Iterable)
-            and issubclass(container_type.__origin__, collections.abc.Container)
+        elif origin_cls is not None and (
+            issubclass(origin_cls, collections.abc.Iterable)
+            and issubclass(origin_cls, collections.abc.Container)
         ):
             return parse_typedef(container_type)
+        elif isinstance(container_type, TypeVar):
+
+            def test_func(value) -> bool:
+                return False
+
         else:
-            raise NotImplementedError(container_type, container_type._name)
+            raise NotImplementedError(container_type, repr(container_type))
     elif isinstance(container_type, type) and (
         issubclass(container_type, collections.abc.Iterable)
         and issubclass(container_type, collections.abc.Container)
@@ -340,6 +412,7 @@ def test_func(value) -> bool:
     else:
         assert isinstance(container_type, tuple), f"container_type is {container_type}"
         if check_ranges:
+            # materialized = issubclass(container_type, collections.abc.Container) and issubclass(container_type, collections.abc.Iterable)
 
             def test_func(value):
                 if not isinstance(value, container_type):
@@ -365,10 +438,20 @@ def test_func(value):
             def test_func(value):
                 return isinstance(value, container_type)
 
-    return make_custom_typecheck(test_func)
+    t = make_custom_typecheck(test_func)
+    AbstractWrappedMeta.set_type(t, container_type)
+    return t
+
+
+def is_variable_tuple(decl) -> bool:
+    origin_cls = get_origin(decl)
+    with suppress(ValueError):
+        t, ellipsis = get_args(decl)
+        return origin_cls is tuple and ellipsis is Ellipsis
+    return False
 
 
-def create_typecheck_container(container_type, items: Tuple[Any]):
+def create_typecheck_container(container_type, items):
     test_types = []
     test_func: Optional[Callable[[Any], bool]] = None
 
@@ -425,12 +508,12 @@ def test_func(mapping) -> bool:
         if items:
             for some_type in items:
                 test_types.append(create_custom_type(some_type))
-            test_types = tuple(test_types)
+            types = tuple(test_types)
 
             def test_func(value):
                 if not isinstance(value, container_type):
                     return False
-                return all(isinstance(item, test_types) for item in value)
+                return all(isinstance(item, types) for item in value)
 
         else:
@@ -441,16 +524,26 @@ def test_func(value):
 
 
 def is_typing_definition(item):
-    module_name: str = getattr(item, "__module__", None)
-    if module_name in ("typing", "typing_extensions"):
-        return True
-    if module_name == "builtins":
-        origin = get_origin(item)
-        if origin is not None:
-            return is_typing_definition(origin)
-    if not LOWER_THAN_310:
-        if isinstance(item, (types.UnionType,)):
+    with suppress(AttributeError):
+        cls_module = type(item).__module__
+        if cls_module in ("typing", "typing_extensions") or cls_module.startswith(
+            ("typing.", "typing_extensions.")
+        ):
             return True
+    with suppress(AttributeError):
+        module = item.__module__
+        if module in ("typing", "typing_extensions") or module.startswith(
+            ("typing.", "typing_extensions.")
+        ):
+            return True
+    if isinstance(item, (TypeVar, TypeVarTuple, ParamSpec)):
+        return True
+    origin = get_origin(item)
+    args = get_args(item)
+    if origin is not None:
+        return True
+    elif args:
+        return True
     return False
 
 
@@ -482,8 +575,6 @@ def parse_typedef(
     metaclass that executes an embedded function for checking if all members of
     the collection is the right type. i.e all(isintance(item, int) for item in object)
     """
-    if type(typedef) is tuple or type(typedef) is list:
-        return tuple(parse_typedef(x) for x in typedef)
 
     if not is_typing_definition(typedef):
         # ARJ: Okay, we're not a typing module descendant.
@@ -491,11 +582,18 @@ def parse_typedef(
         if isinstance(typedef, type):
             if check_ranges:
                 return create_custom_type(typedef, check_ranges=check_ranges)
-            else:
-                return typedef
-        raise NotImplementedError(f"Unknown typedef definition {typedef!r} ({type(typedef)})!")
+            return typedef
+        elif isinstance(typedef, tuple):
+            with suppress(ValueError):
+                (single_type,) = typedef
+                return parse_typedef(single_type)
+            return tuple(parse_typedef(decl) for decl in typedef)
+        else:
+            raise NotImplementedError(f"Unknown typedef definition {typedef!r} ({type(typedef)})!")
 
     as_origin_cls = get_origin(typedef)
+    args = get_args(typedef)
+
     if typedef is AnyStr:
         return create_custom_type(typedef, check_ranges=check_ranges)
     elif typedef is Any:
@@ -503,16 +601,17 @@ def parse_typedef(
     elif typedef is Union:
         raise TypeError("A bare union means nothing!")
     elif as_origin_cls is Annotated:
-        typedef, *raw_metadata = get_args(typedef)
+        typedef, *raw_metadata = args
         # Skip to the internal type:
         # flags = []
-        check_ranges = []
+        p_check_ranges = []
         for annotation in raw_metadata:
             if isinstance(annotation, Range):
-                check_ranges.append(annotation)
+                p_check_ranges.append(annotation)
             # elif (getattr(annotation, '__module__', '') or '').startswith('instruct.constants'):
             #     flags.append(annotation)
-        check_ranges = tuple(check_ranges)
+        check_ranges = tuple(p_check_ranges)
+        del p_check_ranges
         new_type = parse_typedef(typedef, check_ranges=check_ranges)
         if check_ranges:
             if is_typing_definition(typedef):
@@ -523,7 +622,8 @@ def parse_typedef(
                     new_name = new_name[len("typing.") :]
             else:
                 new_name = typedef.__name__
-            new_type.set_name(new_name)
+            AbstractWrappedMeta.set_name(new_type, new_name)
+            AbstractWrappedMeta.set_type(new_type, typedef)
         return new_type
     elif as_origin_cls is Union:
         args = get_args(typedef)
@@ -543,26 +643,31 @@ def parse_typedef(
             new_type = create_custom_type(typedef, arg)
             if isinstance(arg, str):
                 arg = f'"{arg}"'
-            new_type.set_name(f"{arg!s}")
+            AbstractWrappedMeta.set_name(new_type, f"{arg!s}")
             items.append(new_type)
         return tuple(items)
-    elif as_origin_cls is not None:
+    elif as_origin_cls is not None or isinstance(typedef, TypeVar):
         if is_typing_definition(typedef) and hasattr(typedef, "_name") and typedef._name is None:
             # special cases!
             raise NotImplementedError(
                 f"The type definition for {typedef} is not supported, report as an issue."
             )
         args = get_args(typedef)
-        if args:
-            cls = create_custom_type(as_origin_cls, *args, check_ranges=check_ranges)
+        if args or as_origin_cls is None:
+            if as_origin_cls is not None:
+                cls = create_custom_type(as_origin_cls, *args, check_ranges=check_ranges)
+            else:
+                cls = create_custom_type(typedef, check_ranges=check_ranges)
             new_name = str(typedef)
             if new_name.startswith(("typing_extensions.")):
                 new_name = new_name[len("typing_extensions.") :]
             if new_name.startswith(("typing.")):
                 new_name = new_name[len("typing.") :]
-            cls.set_name(new_name)
+            AbstractWrappedMeta.set_name(cls, new_name)
+            AbstractWrappedMeta.set_type(cls, typedef)
             return cls
         return as_origin_cls
+
     raise NotImplementedError(
-        f"The type definition for {typedef!r} is not supported yet, report as an issue."
+        f"The type definition for {typedef!r} ({type(typedef)}) is not supported yet, report as an issue."
     )
diff --git a/instruct/utils.py b/instruct/utils.py
index 208a60d..7f7898c 100644
--- a/instruct/utils.py
+++ b/instruct/utils.py
@@ -6,6 +6,16 @@
 from .types import FrozenMapping
 
 
+def invert_mapping(mapping):
+    inverted = {}
+    for key, value in mapping.items():
+        try:
+            inverted[value].append(key)
+        except KeyError:
+            inverted[value] = [key]
+    return {key: tuple(value) for key, value in inverted.items()}
+
+
 def support_eager_eval(func):
     @functools.wraps(func)
     def wrapper(*args, eager=False, **kwargs):
@@ -57,7 +67,6 @@ def flatten_fields(item: Union[Mapping[str, Any], Iterable[Union[str, Iterable[A
     elif isinstance(item, (AbstractIterable, AbstractMapping)) and not isinstance(
         item, (bytearray, bytes)
     ):
-        is_mapping = False
         if isinstance(item, AbstractMapping):
             iterable = ((key, item[key]) for key in item)
 
diff --git a/tests/test_atomic.py b/tests/test_atomic.py
index c4d3906..de345cf 100644
--- a/tests/test_atomic.py
+++ b/tests/test_atomic.py
@@ -1,13 +1,15 @@
 import json
 import pprint
 import sys
-from typing import Union, List, Tuple, Optional, Dict, Any, Type
+from typing import Union, List, Tuple, Optional, Dict, Any, Type, TypeVar
 
 try:
     from typing import Annotated
 except ImportError:
     from typing_extensions import Annotated
 
+from typing_extensions import get_type_hints
+
 from enum import Enum
 import datetime
 import base64
@@ -36,6 +38,21 @@
 )
 
 
+def test_simple() -> None:
+    class Data(SimpleBase):
+        foo: int
+        bar: str
+        baz: Dict[str, Any]
+        cool: Annotated[int, "is this cool?", "yes"]
+
+    assert get_type_hints(Data) == {
+        "foo": int,
+        "bar": str,
+        "baz": Dict[str, Any],
+        "cool": int,
+    }
+
+
 class Data(Base, history=True):
     __slots__ = {"field": Union[str, int], "other": str}
 
@@ -1530,3 +1547,21 @@ class BreakChainBar(BarBar):
         ...
 
     assert len(Registry) == 2
+
+
+def test_simple_generics():
+    T = TypeVar("T")
+    assert isinstance(T, TypeVar)
+
+    class Foo(SimpleBase):
+        field: T
+
+    assert isinstance(Foo._slots["field"], TypeVar)
+    assert Foo.__parameters__ == (T,)
+
+    # any_instance = Foo(None)
+    # assert any_instance.field is None
+
+    cls = Foo[int]
+    assert isinstance(cls(1).field, int)
+    assert get_type_hints(cls)["field"] is int
diff --git a/tests/test_typedef.py b/tests/test_typedef.py
index 968d03c..c298480 100644
--- a/tests/test_typedef.py
+++ b/tests/test_typedef.py
@@ -5,6 +5,9 @@
     make_custom_typecheck,
     has_collect_class,
     find_class_in_definition,
+    is_typing_definition,
+    get_args,
+    issubormetasubclass,
 )
 from instruct import Base, Atomic
 from typing import List, Union, AnyStr, Any, Optional, Generic, TypeVar, Tuple, FrozenSet, Set, Dict
@@ -18,7 +21,6 @@
 
 
 def test_parse_typedef():
-
     custom_type = make_custom_typecheck(lambda val: val == 3)
     assert isinstance(3, custom_type)
     assert not isinstance("a", custom_type)
@@ -162,3 +164,35 @@ class Bar(Base):
     type_hints = Optional[Bar]
     items = tuple(find_class_in_definition(type_hints, Atomic, metaclass=True))
     assert items == (Bar,)
+
+
+def test_parse_typedef_generics():
+    T = TypeVar("T")
+    assert is_typing_definition(T)
+    assert tuple(find_class_in_definition((T,), TypeVar)) == (T,)
+    assert tuple(find_class_in_definition(T, TypeVar)) == (T,)
+    U = TypeVar("U")
+    ListOfT = List[T]
+    ListOfTint = ListOfT[int]
+
+    assert not isinstance("any", parse_typedef(T))
+    assert not isinstance([1, 2, 3], parse_typedef(ListOfT))
+    assert not isinstance("any", parse_typedef(ListOfTint))
+    assert not isinstance([None, 2, 3], parse_typedef(ListOfTint))
+
+    assert isinstance([1, 2, 3], parse_typedef(ListOfTint))
+
+    cls = parse_typedef(ListOfT)
+    assert callable(cls)
+    assert isinstance(cls, type)
+    cls_int = cls[int]
+    assert isinstance([1, 2, 3], cls_int)
+
+    SomeDictGeneric = Dict[T, U]
+    SomeGenericSeq = Tuple[T, U]
+    assert not isinstance({"1": 1}, parse_typedef(SomeDictGeneric))
+    DictStrInt = SomeDictGeneric[str, int]
+    assert isinstance({"1": 1}, parse_typedef(DictStrInt))
+    assert isinstance((1, "str"), parse_typedef(SomeGenericSeq[int, str]))
+    assert not isinstance((1, 1), parse_typedef(SomeGenericSeq[int, str]))
+    assert not isinstance((1, 1), parse_typedef(Tuple[int, str]))
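
Reviewer note (not part of the patch): a minimal usage sketch of the generics support this diff adds, distilled from the new test_simple_generics and test_parse_typedef_generics tests above. The Box / IntBox names are illustrative only; each assertion mirrors one made by the added tests.

    from typing import List, TypeVar
    from typing_extensions import get_type_hints

    from instruct import SimpleBase
    from instruct.typedef import parse_typedef

    T = TypeVar("T")


    class Box(SimpleBase):
        # While T is unbound the field stays abstract and rejects every value.
        item: T


    assert Box.__parameters__ == (T,)

    # Subscripting the public class binds T and yields a concrete, usable class.
    IntBox = Box[int]
    assert isinstance(IntBox(1).item, int)
    assert get_type_hints(IntBox)["item"] is int

    # parse_typedef now accepts TypeVars and generic aliases as well: the wrapped
    # checker produced for List[T] matches nothing until it is specialized.
    ListOfT = parse_typedef(List[T])
    assert not isinstance([1, 2, 3], ListOfT)
    assert isinstance([1, 2, 3], ListOfT[int])

The design point this illustrates, per create_custom_type and Atomic.__getitem__ above: an unbound TypeVar parses to a wrapped type whose isinstance check always fails, and subscripting (Box[int], ListOfT[int]) re-parses the now-bound alias into a concrete checker, so validation stays strict both before and after specialization.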