From c9588c90cc677cf32713c270ae851b6108eea63d Mon Sep 17 00:00:00 2001 From: Autumn Date: Mon, 17 Jun 2024 12:30:56 -0700 Subject: [PATCH] [*] all backport --- instruct/__init__.py | 440 ++++++++------ instruct/templates/data_class.jinja | 3 +- instruct/typedef.py | 896 +++++++++++++++++++++------- instruct/typing.py | 29 +- tests/test_atomic.py | 86 ++- tests/test_typedef.py | 33 + 6 files changed, 1078 insertions(+), 409 deletions(-) diff --git a/instruct/__init__.py b/instruct/__init__.py index d1cee3d..ab61244 100644 --- a/instruct/__init__.py +++ b/instruct/__init__.py @@ -739,7 +739,7 @@ def _set_defaults(self): """.strip() -def make_defaults(fields, defaults_var_template): +def make_defaults(fields: Tuple[str, ...], defaults_var_template: str): defaults_var_template = env.from_string(defaults_var_template).render(fields=fields) code = env.from_string(DEFAULTS_FRAGMENT).render(item=defaults_var_template) return code @@ -1472,7 +1472,7 @@ def apply_skip_keys( def list_callables(cls): for key in dir(cls): - value = inspect.getattr_static(cls, key) + value = getattr(cls, key) if isinstance(value, type): continue elif not callable(value): @@ -1711,17 +1711,18 @@ def __sub__(self: IAtomic, skip: Union[Mapping[str, Any], Iterable[Any]]) -> Typ return new_cls def __new__( - klass, - class_name, - bases, - attrs, + klass: Type[AtomicMeta], + class_name: str, + bases: Tuple[Union[Type, Type[Atomic]], ...], + attrs: Dict[str, Any], *, - fast=None, - concrete_class=False, - skip_fields=FrozenMapping(), - include_fields=FrozenMapping(), - **mixins: Any, - ): + fast: Optional[bool] = None, + concrete_class: bool = False, + # metadata: + skip_fields: FrozenMapping = FrozenMapping(), + include_fields: FrozenMapping = FrozenMapping(), + **mixins: bool, + ) -> Type[Atomic]: if concrete_class: parent_has_hash = None if "__hash__" not in attrs: @@ -1753,7 +1754,7 @@ def __new__( if include_fields and skip_fields: raise TypeError("Cannot specify both include_fields and skip_fields!") data_class_attrs = {} - pending_base_class_functions = [] + pending_base_class_funcs = [] # Move overrides to the data class, # so we call them first, then the codegen pieces. # Suitable for a single level override. @@ -1770,10 +1771,11 @@ def __new__( ): if key in attrs: if hasattr(attrs[key], "_instruct_base_cls"): - pending_base_class_functions.append(key) + pending_base_class_funcs.append(key) continue data_class_attrs[key] = attrs.pop(key) - base_class_functions = tuple(pending_base_class_functions) + base_class_functions = tuple(pending_base_class_funcs) + assert isinstance(attrs, dict) support_cls_attrs = attrs del attrs @@ -1820,17 +1822,17 @@ def __new__( ) coerce_mappings = dict(unpack_coerce_mappings(coerce_mappings)) - if not isinstance(support_cls_attrs["__coerce__"], ImmutableMapping): - support_cls_attrs["__coerce__"] = ImmutableMapping(coerce_mappings) + if not isinstance(support_cls_attrs["__coerce__"], ImmutableMapping): + support_cls_attrs["__coerce__"] = ImmutableMapping(coerce_mappings) - coerce_mappings = cast(CoerceMapping, coerce_mappings) + coerce_mappings = cast(CoerceMapping, coerce_mappings) # A support column is a __slot__ element that is unmanaged. - support_columns = [] + pending_support_columns: List[str] = [] if isinstance(support_cls_attrs["__slots__"], tuple): # Classes with tuples in them are assumed to be # data class definitions (i.e. supporting things like a change log) - support_columns.extend(support_cls_attrs["__slots__"]) + pending_support_columns.extend(support_cls_attrs["__slots__"]) support_cls_attrs["__slots__"] = FrozenMapping() if not isinstance(support_cls_attrs["__slots__"], AbstractMapping): @@ -1843,19 +1845,25 @@ def __new__( if fast is None: fast = not __debug__ + combined_slots: Dict[str, TypeHint] + nested_atomic_collections: Dict[str, Union[Type[Atomic], Tuple[Type[Atomic], ...]]] combined_columns: Dict[Type, Type] = {} + combined_slots = {} - nested_atomic_collections: Dict[str, AtomicMeta] = {} + nested_atomic_collections = {} # Mapping of public name -> custom type vector for `isinstance(...)` checks! column_types: Dict[str, Union[Type, Tuple[Type, ...]]] = {} base_class_has_subclass_init = False + cls: Union[type, Type[Atomic]] + for cls in bases: if cls is object: break base_class_has_subclass_init = hasattr(cls, "__init_subclass__") if base_class_has_subclass_init: break + del cls init_subclass_kwargs = {} @@ -1895,9 +1903,12 @@ def __new__( if "__extra_slots__" in support_cls_attrs: pending_extra_slots.extend(support_cls_attrs["__extra_slots__"]) # Base class inherited items: + inherited_listeners: Dict[str, List[Callable]] + annotated_metadata = {} inherited_listeners = {} for cls in bases: + skipped_properties: Tuple[str, ...] skipped_properties = () if ( hasattr(cls, "__slots__") @@ -1922,38 +1933,40 @@ def __new__( else: inherited_listeners[key] = value if hasattr(cls, "__extra_slots__"): - support_columns.extend(list(cls.__extra_slots__)) + pending_support_columns.extend(list(cls.__extra_slots__)) if ismetasubclass(cls, AtomicMeta): + parent_atomic: Type[AtomicImpl] = cast(Type[AtomicImpl], cls) # Only AtomicMeta Descendants will merge in the helpers of # _columns: Dict[str, Type] - if cls._annotated_metadata: - annotated_metadata.update(cls._annotated_metadata) - if cls._column_types: - column_types.update(cls._column_types) - if cls._nested_atomic_collection_keys: - for key, value in cls._nested_atomic_collection_keys.items(): + if parent_atomic._annotated_metadata: + annotated_metadata.update(parent_atomic._annotated_metadata) + if parent_atomic._column_types: + column_types.update(parent_atomic._column_types) + if parent_atomic._nested_atomic_collection_keys: + for key, value in parent_atomic._nested_atomic_collection_keys.items(): # Override of this collection definition, so don't inherit! if key in combined_columns: continue nested_atomic_collections[key] = value - if cls._columns: - combined_columns.update(cls._columns) - if cls._slots: - combined_slots.update(cls._slots) - if cls._support_columns: - support_columns.extend(cls._support_columns) - skipped_properties = cls._no_op_properties - - if hasattr(cls, "setter_wrapper"): - setter_wrapper.append(cls.setter_wrapper) - if hasattr(cls, "__getter_template__"): - getter_templates.append(cls.__getter_template__) - if hasattr(cls, "__setter_template__"): - setter_templates.append(cls.__setter_template__) - if hasattr(cls, "__defaults__init__template__"): - defaults_templates.append(cls.__defaults__init__template__) + if parent_atomic._columns: + combined_columns.update(parent_atomic._columns) + if parent_atomic._slots: + combined_slots.update(parent_atomic._slots) + if parent_atomic._support_columns: + pending_support_columns.extend(parent_atomic._support_columns) + skipped_properties = parent_atomic._no_op_properties + + if hasattr(parent_atomic, "setter_wrapper"): + setter_wrapper.append(parent_atomic.setter_wrapper) + if hasattr(parent_atomic, "__getter_template__"): + getter_templates.append(parent_atomic.__getter_template__) + if hasattr(parent_atomic, "__setter_template__"): + setter_templates.append(parent_atomic.__setter_template__) + if hasattr(parent_atomic, "__defaults__init__template__"): + defaults_templates.append(parent_atomic.__defaults__init__template__) + del parent_atomic # Collect all publicly accessible properties: for key in dir(cls): value = getattr(cls, key) @@ -2082,7 +2095,7 @@ def __new__( all_coercions = {} # the `__class__` field of the generated functions will be incomplete, # so track them so we can replace them with a derived type made ``__class__`` - class_cell_fixups = [] + class_cell_fixups: List[Union[Tuple[str, Callable[..., Any]], Tuple[str, property]]] = [] for key, raw_typedef in tuple(current_class_slots.items()): disabled_derived = None if raw_typedef in klass.REGISTRY: @@ -2138,10 +2151,10 @@ def __new__( class_cell_fixups.append((key, new_property)) # Support columns are left as-is for slots - support_columns = tuple(_dedupe(support_columns)) + support_columns = tuple(_dedupe(pending_support_columns)) - ns_globals = {"NoneType": NoneType, "Flags": Flags, "typing": typing} - ns_globals[class_name] = ImmutableValue[Optional[Type[AtomicImpl]]](None) + dataclass_attrs = {"NoneType": NoneType, "Flags": Flags, "typing": typing} + dataclass_attrs[class_name] = ImmutableValue[Optional[Type[AtomicImpl]]](None) init_subclass = None @@ -2153,16 +2166,16 @@ def __new__( compile( make_fast_dumps(combined_columns, class_name), "", mode="exec" ), - ns_globals, - ns_globals, + dataclass_attrs, + dataclass_attrs, ) - class_cell_fixups.append(("_asdict", cast(FunctionType, ns_globals["_asdict"]))) - class_cell_fixups.append(("_astuple", cast(FunctionType, ns_globals["_astuple"]))) - class_cell_fixups.append(("_aslist", cast(FunctionType, ns_globals["_aslist"]))) + class_cell_fixups.append(("_asdict", cast(FunctionType, dataclass_attrs["_asdict"]))) + class_cell_fixups.append(("_astuple", cast(FunctionType, dataclass_attrs["_astuple"]))) + class_cell_fixups.append(("_aslist", cast(FunctionType, dataclass_attrs["_aslist"]))) exec( compile(make_fast_eq(combined_columns), "", mode="exec"), - ns_globals, - ns_globals, + dataclass_attrs, + dataclass_attrs, ) exec( compile( @@ -2170,10 +2183,10 @@ def __new__( "", mode="exec", ), - ns_globals, - ns_globals, + dataclass_attrs, + dataclass_attrs, ) - class_cell_fixups.append(("clear", cast(FunctionType, ns_globals["_clear"]))) + class_cell_fixups.append(("clear", cast(FunctionType, dataclass_attrs["_clear"]))) exec( compile( make_fast_getset_item( @@ -2186,8 +2199,8 @@ def __new__( "", mode="exec", ), - ns_globals, - ns_globals, + dataclass_attrs, + dataclass_attrs, ) iter_fields = [] for field in combined_columns: @@ -2196,8 +2209,8 @@ def __new__( iter_fields.append(field) exec( compile(make_fast_iter(iter_fields), "", mode="exec"), - ns_globals, - ns_globals, + dataclass_attrs, + dataclass_attrs, ) del iter_fields pickle_fields = [] @@ -2207,8 +2220,8 @@ def __new__( pickle_fields.append(field) exec( compile(make_set_get_states(pickle_fields), "", mode="exec"), - ns_globals, - ns_globals, + dataclass_attrs, + dataclass_attrs, ) exec( compile( @@ -2216,11 +2229,11 @@ def __new__( "", mode="exec", ), - ns_globals, - ns_globals, + dataclass_attrs, + dataclass_attrs, ) class_cell_fixups.append( - ("_set_defaults", cast(FunctionType, ns_globals["_set_defaults"])) + ("_set_defaults", cast(FunctionType, dataclass_attrs["_set_defaults"])) ) for key in ( @@ -2238,13 +2251,13 @@ def __new__( # Move the autogenerated functions into the support class # Any overrides that *may* call them will be assigned # to the concrete class instead - if key in ns_globals: + if key in dataclass_attrs: if key in base_class_functions: continue logger.debug(f"Copying {key} into {class_name} attributes") - support_cls_attrs[key] = ns_globals.pop(key) - if "_set_defaults" in ns_globals: - data_class_attrs["_set_defaults"] = ns_globals.pop("_set_defaults") + support_cls_attrs[key] = dataclass_attrs.pop(key) + if "_set_defaults" in dataclass_attrs: + data_class_attrs["_set_defaults"] = dataclass_attrs.pop("_set_defaults") # Any keys subtracted must have no-nop setters in order to # allow for subtype relationship will behaving as if those keys are fundamentally @@ -2254,9 +2267,9 @@ def __new__( __no_op_skip_get__, __no_op_skip_set__ ) - if ns_globals: + if dataclass_attrs: logger.debug( - f"Did not add the following to {class_name} attributes: {tuple(ns_globals.keys())}" + f"Did not add the following to {class_name} attributes: {tuple(dataclass_attrs.keys())}" ) support_cls_attrs["_columns"] = ImmutableMapping[str, CustomTypeCheck](combined_columns) @@ -2316,7 +2329,8 @@ def __new__( ImmutableValue[Type[AtomicImpl]], dc_parent ) support_cls = cast( - Type[Atomic], super().__new__(klass, class_name, bases, support_cls_attrs) + Type[Atomic], + super().__new__(klass, class_name, bases, support_cls_attrs, **init_subclass_kwargs), ) # type:ignore[misc] for prop_name, value in support_cls_attrs.items(): @@ -2331,7 +2345,7 @@ def __new__( continue setattr(support_cls, prop_name, value) - ns_globals["klass"] = support_cls + dataclass_attrs["klass"] = support_cls dataclass_slots = ( tuple("_{}_".format(key) for key in combined_columns) + support_columns + extra_slots ) @@ -2341,9 +2355,12 @@ def __new__( data_class_attrs=data_class_attrs, class_slots=current_class_slots, ) - ns_globals["_dataclass_attrs"] = data_class_attrs - exec(compile(dataclass_template, "", mode="exec"), ns_globals, ns_globals) - dc.value = data_class = ns_globals[f"_{class_name}"] + dataclass_attrs["_dataclass_attrs"] = data_class_attrs + exec(compile(dataclass_template, "", mode="exec"), dataclass_attrs, dataclass_attrs) + + data_class: Type[Atomic] + + dc.value = data_class = cast(Type[Atomic], dataclass_attrs[f"_{class_name}"]) data_class.__module__ = support_cls.__module__ for key, value in data_class_attrs.items(): if callable(value): @@ -2542,61 +2559,107 @@ def list_changes(self): AtomicMeta.register_mixin("history", History) -DEFAULTS = """{%- for field in fields %} -result._{{field}}_ = None -{%- endfor %} -""" +def _cls_keys( + cls: Type[Atomic], instance: Optional[Atomic] = None, *, all: bool = False +) -> Union[InstanceKeysView[Atomic, str], ClassKeysView[Atomic, str], ImmutableCollection[str]]: + if instance is not None: + return keys(instance, all=all) + class_keys: ClassKeysView[Atomic, str] = keys(cls, all=all) + return class_keys -class IMapping(metaclass=AtomicMeta): - """ - Allow an instruct class instance to have the `keys()` function which is - mandatory to support **item unpacking. +def _instance_keys( + self: Atomic, *, all: bool = False +) -> Union[InstanceKeysView[Atomic, str], ImmutableCollection[str]]: + return keys(self, all=all) - This will collide with any property that's already named keys. - """ - __slots__ = () +def _cls_values(cls: Type[Atomic], item: Atomic): + return values(item) + + +def _instance_values(self): + return values(self) - @ClassOrInstanceFuncsDescriptor - def keys(cls, instance=None, *, all=False) -> Set[str]: - if instance is not None: - return keys(instance, all=all) - return keys(cls, all=all) - @keys.instance_function - def keys(self, *, all=False) -> Set[str]: - return keys(self, all=all) +def _cls_items(cls: Type[Atomic], item: Atomic): + return items(item) - @ClassOrInstanceFuncsDescriptor - def values(cls, item): - return values(item) - @values.instance_function - def values(self): - return values(self) +def _instance_items(self: Atomic): + return items(self) - @ClassOrInstanceFuncsDescriptor - def items(cls, item): - return items(item) - @items.instance_function - def items(self): - return items(self) +def _cls_get(cls: Type[Atomic], instance: Atomic, key: str, default=None): + return get(instance, key, default) - @ClassOrInstanceFuncsDescriptor - def get(cls, instance, key, default=None): - return get(instance, key, default) - @get.instance_function - def get(self, key, default=None): - return get(self, key, default) +def _instance_get(self: Atomic, key: str, default=None): + return get(self, key, default) +class IMapping(Generic[Atomic], metaclass=AtomicMeta): + """ + Allow an instruct class instance to have the `keys()` function which is + mandatory to support **item unpacking. + + This will collide with any property that's already named keys. + """ + + __slots__ = () + + keys: ClassOrInstanceFuncsDescriptor[Atomic] = ClassOrInstanceFuncsDescriptor[Atomic]( + _cls_keys, cast(InstanceCallable[Atomic], _instance_keys) + ) + values: ClassOrInstanceFuncsDescriptor[Atomic] = ClassOrInstanceFuncsDescriptor[Atomic]( + _cls_values, cast(InstanceCallable[Atomic], _instance_values) + ) + items: ClassOrInstanceFuncsDescriptor[Atomic] = ClassOrInstanceFuncsDescriptor[Atomic]( + _cls_items, cast(InstanceCallable[Atomic], _instance_items) + ) + get: ClassOrInstanceFuncsDescriptor[Atomic] = ClassOrInstanceFuncsDescriptor[Atomic]( + _cls_get, cast(InstanceCallable[Atomic], _instance_get) + ) + + +del _instance_keys, _cls_values, _instance_values, _cls_items, _instance_items, _cls_get +del _instance_get AtomicMeta.register_mixin("mapping", IMapping) -def add_event_listener(*fields): +def add_event_listener(*fields: str): + """ + Event listeners are functions that are run when an attribute is set. + + Supports: + def listener(self, new): + ... + + def listener(self, old, new): + ... + + def listener(self, name: str, old, new): + ... + + >>> from instruct import SimpleBase + >>> class Foo(SimpleBase): + ... field_one: str + ... field_two: int + ... field_three: Union[str, int] + ... @add_event_listener('field_one', 'field_two') + ... def _on_field_change(self, name: str, old_value: Union[str, int, None], new_value: Union[int, str]): + ... if name == 'field_one': + ... if not new_value: + ... self.field_one = 'No empty!' + ... elif name == 'field_two': + ... if new_value < 0: + ... self.field_two = 0 + ... + >>> astuple(Foo('', -1)) + ('No empty!', 0, None) + >>> + """ + def wrapper(func): func._event_listener_funcs = getattr(func, "_event_listener_funcs", ()) + fields return func @@ -2604,7 +2667,31 @@ def wrapper(func): return wrapper -def handle_type_error(*fields): +def handle_type_error(*fields: str): + """ + Use this to call a function when an attempt to set a field fails due to a type + mismatch. If a true-ish value is returned, then the default TypeError will not + be thrown. + + >>> from instruct import SimpleBase + >>> class Foo(SimpleBase): + ... field_one: str + ... field_two: int + ... field_three: Union[str, int] + ... @handle_type_error('field_two') + ... def _try_cast_field_two(self, val): + ... try: + ... self.field_two = int(val, 10) + ... except Exception: + ... pass + ... else: + ... return True + ... + >>> f = Foo('My Foo', '255') + >>> astuple(f) + ('My Foo', 255, None) + """ + def wrapper(func): func._post_coerce_failure_funcs = getattr(func, "_post_coerce_failure_funcs", ()) + fields return func @@ -2613,77 +2700,36 @@ def wrapper(func): def load_cls(cls, args, kwargs, skip_fields: Optional[FrozenMapping] = None): + """ + :internal: interface for ``__reduce__`` to call. + """ if skip_fields: cls = cls - skip_fields return cls(*args, **kwargs) -def _encode_simple_nested_base(iterable, *, immutable=None): - """ - Handle an List[Base], Tuple[Base], Mapping[Any, Base] - and coerce to json. Does not deeply traverse by design. - """ - - # Empty items short circuit: - if not iterable: - return iterable - if isinstance(iterable, AbstractMapping): - destination = iterable - if immutable: - if hasattr(iterable, "copy"): - destination = iterable.copy() - else: - # initialize an empty form of an AbstractMapping: - destination = type(iterable)() - for key in iterable: - value = iterable[key] - if hasattr(value, "to_json"): - destination[key] = value.to_json() - else: - destination[key] = value - return iterable - elif isinstance(iterable, Sequence): - if immutable is None and isinstance(iterable, (tuple, frozenset)): - immutable = True - elif immutable is None and isinstance(iterable, (list, set)): - immutable = False - if immutable is None: - try: - iterable[0] = iterable[0] - except Exception: - immutable = True - else: - immutable = False - if immutable: - # Convert to a mutable list and replace items - iterable = list(iterable) - for index, item in enumerate(iterable): - if hasattr(item, "to_json"): - iterable[index] = item.to_json() - return iterable - return iterable - - class JSONSerializable(metaclass=AtomicMeta): __slots__ = () def to_json(self) -> Dict[str, Any]: - return AtomicMeta.to_json(self)[0] + return self.__json__() def __json__(self) -> Dict[str, Any]: - return self.to_json() + assert ismetasubclass(type(self), AtomicMeta) + return AtomicMeta.to_json(cast(AtomicMeta, self))[0] @classmethod - def from_json(cls: Type[T], data: Dict[str, Any]) -> T: + def from_json(cls: AtomicMeta, data: Dict[str, Any]) -> Type[Atomic]: return cls(**data) @classmethod - def from_many_json(cls: Type[T], iterable: Iterable[Dict[str, Any]]) -> Tuple[T, ...]: + def from_many_json(cls: AtomicMeta, iterable: Iterable[Dict[str, Any]]) -> Tuple[T, ...]: return tuple(cls.from_json(item) for item in iterable) AtomicMeta.register_mixin("json", JSONSerializable) + # ARJ: How we create the ``__init__`` body on the concrete class (i.e. one with ``__slots__``) # This is really opaque and needs to be rethought as almost no one uses it. It's meant # to allow codegen to a different access pattern @@ -2727,34 +2773,53 @@ def __reduce__(self): # Get the public, unmessed with class: cls = public_class(self) - skipped_fields = _dump_skipped_fields(type(self)) - return load_cls, (cls, (), {}, skipped_fields), self.__getstate__() + s = skipped_fields(self) + return load_cls, (cls, (), {}, s), self.__getstate__() @classmethod def _create_invalid_type(cls, field_name, val, val_type, types_required): - if len(types_required) > 1: - if len(types_required) == 2: - expects = "either an {.__name__} or {.__name__}".format(*types_required) - else: - expects = ( - f'either an {", ".join(x.__name__ for x in types_required[:-1])} ' - f"or a {types_required[-1].__name__}" + pending_types_required_names = [] + # ARJ: hmm, you know... the "types_required" field really can be made into a static + # string and stored on the type once... + for req_cls in types_required: + # ARJ: Handle Literal[1, 2, 3]-cases.. :/ + if issubclass(req_cls, CustomTypeCheck) and req_cls.__origin__ is Literal: + pending_types_required_names.extend( + f'"{arg}"' if isinstance(arg, str) else str(arg) for arg in req_cls.__args__ ) + continue + pending_types_required_names.append(req_cls.__name__) + types_required_names = tuple(pending_types_required_names) + if len(types_required_names) > 1: + if len(types_required_names) == 2: + left, right = types_required_names + expects = f"either an {left} or {right}" + else: + *rest_types, end = types_required_names + rest = ", ".join([x for x in rest_types]) + expects = f"either an {rest} or a {end}" else: - expects = f"a {types_required[0].__name__}" - return TypeError( + (expected_type,) = types_required_names + expects = f"a {expected_type}" + return InstructTypeError( f"Unable to set {field_name} to {val!r} ({val_type.__name__}). {field_name} expects " - f"{expects}" + f"{expects}", + field_name, + val, ) @classmethod def _create_invalid_value(cls, message, *args, **kwargs): - return ValueError(message, *args, **kwargs) + return InstructValueError(message, *args, **kwargs) def _handle_init_errors(self, errors, errored_keys, unrecognized_keys): if unrecognized_keys: fields = ", ".join(unrecognized_keys) - errors.append(self._create_invalid_value(f"Unrecognized fields {fields}")) + errors.append( + self._create_invalid_value( + f"Unrecognized fields {fields}", fields=unrecognized_keys + ) + ) if errors: typename = inflection.titleize(type(self).__name__[1:]) if len(errors) == 1: @@ -2786,6 +2851,7 @@ def __init__(self, *args, **kwargs): except Exception as e: errors.append(e) errored_keys.append(key) + class_keys = self._all_accessible_fields # Set by keywords for key in class_keys & kwargs.keys(): @@ -2800,9 +2866,10 @@ def __init__(self, *args, **kwargs): unrecognized_keys = kwargs.keys() - class_keys self._handle_init_errors(errors, errored_keys, unrecognized_keys) self._flags = Flags.INITIALIZED + self.__post_init__() - @mark(base_cls=True) - def _clear(self, fields: Iterable[str] = None): + # ARJ: Now you don't need to override __init__ just to do post init things + def __post_init__(self: Self): pass @mark(base_cls=True) @@ -2817,13 +2884,6 @@ def _astuple(self) -> Tuple[Any, ...]: def _aslist(self) -> List[Any]: return [] - @mark(base_cls=True) - def _set_defaults(self): - # ARJ: Override to set defaults instead of inside the `__init__` function - # Note: Always call ``super()._set_defaults()`` FIRST as if you - # call it afterwards, the inheritance tree will zero initialize it first - return self - @mark(base_cls=True) def __iter__(self): """ diff --git a/instruct/templates/data_class.jinja b/instruct/templates/data_class.jinja index 51fc961..2024dab 100644 --- a/instruct/templates/data_class.jinja +++ b/instruct/templates/data_class.jinja @@ -1,8 +1,9 @@ class _{{class_name}}(klass, concrete_class=True): __slots__ = {{slots}} + _is_data_class = True # Copy any custom __eq__, etc into the concrete class # so that we can allow the user to call super() # to defer to the autogenerated __eq__, etc {% for key in data_class_attrs %} {{key}} = _dataclass_attrs["{{key}}"] - {% endfor %} \ No newline at end of file + {% endfor %} diff --git a/instruct/typedef.py b/instruct/typedef.py index a9493df..c26944d 100644 --- a/instruct/typedef.py +++ b/instruct/typedef.py @@ -2,7 +2,6 @@ import collections.abc import inspect import sys -import typing import warnings from functools import wraps from types import FunctionType @@ -29,18 +28,12 @@ Iterable, overload, Generic, - Generator, + # Generator, Set, Dict, cast as cast_type, ) from weakref import WeakKeyDictionary - -if typing.TYPE_CHECKING: - from . import AtomicMeta - -from .typing import Protocol, Literal, Annotated, TypeGuard, is_typing_definition - from typing_extensions import ( get_origin as _get_origin, get_original_bases, @@ -49,16 +42,18 @@ ) from typing_extensions import get_args, get_type_hints -from .utils import flatten_restrict as flatten +from .constants import Range +from .typing import Protocol, Literal, Annotated, TypeGuard, is_typing_definition from .typing import ( - CustomTypeCheck as ICustomTypeCheck, TypingDefinition, EllipsisType, Atomic, TypeHint, + CustomTypeCheck, ) -from .constants import Range +from .utils import flatten_restrict as flatten from .exceptions import RangeError, TypeError as InstructTypeError +from .types import IAtomic T = TypeVar("T") U = TypeVar("U") @@ -96,7 +91,7 @@ def is_union_typedef(t) -> bool: _abstract_custom_types = WeakKeyDictionary() -class AbstractWrappedMeta(type): +class CustomTypeCheckMetaBase(type, Generic[T]): __slots__ = () def set_name(t, name): @@ -108,20 +103,100 @@ def set_type(t, type): def is_abstract_custom_typecheck(o: Type) -> bool: cls_type = type(o) - return isinstance(cls_type, AbstractWrappedMeta) + return isinstance(cls_type, CustomTypeCheckMetaBase) + + +def make_custom_typecheck(*args, is_abstract_type=False): + if len(args) == 1 and callable(args[0]): + caller = inspect.stack()[1] + warnings.warn( + f"{caller.filename}:{caller.function}:{caller.lineno}: change make_custom_typecheck(...) to have the type we're pretending to be!", # noqa:E501 + DeprecationWarning, + ) + args = (object, *args, ()) + return _make_custom_typecheck(*args, is_abstract_type=is_abstract_type) + + +@overload +def _make_custom_typecheck( + typehint: TypingDefinition, + func: Callable[[Union[Any, T]], bool], + type_args: Union[Tuple[Any, ...], Tuple[Type, ...]], + *, + is_abstract_type: bool = False, +) -> Type[CustomTypeCheck[T]]: + ... -def make_custom_typecheck(typecheck_func, typedef=None) -> Type[ICustomTypeCheck]: - """Create a custom type that will turn `isinstance(item, klass)` into `typecheck_func(item)`""" - typename = "WrappedType<{}>" +@overload +def _make_custom_typecheck( + typehint: Tuple[Type[T], ...], + func: Callable[[Union[Any, T]], bool], + type_args: Union[Tuple[Any, ...], Tuple[Type, ...]], + *, + is_abstract_type: bool = False, +) -> Type[CustomTypeCheck[T]]: + ... + + +@overload +def _make_custom_typecheck( + typehint: Type[T], + func: Callable[[Union[Any, T]], bool], + type_args: Union[Tuple[Any, ...], Tuple[Type, ...]], + *, + is_abstract_type: bool = False, +) -> Type[CustomTypeCheck[T]]: + ... + + +def _make_custom_typecheck( + typehint: Union[TypingDefinition, Tuple[Type[T], ...], Type[T]], + func: Callable[[Union[Any, T]], bool], + type_args, + *, + is_abstract_type=False, +) -> Type[CustomTypeCheck[T]]: + """ + Create a custom type that will turn `isinstance(item, klass)` into `func(item)` + """ + assert ( + is_typing_definition(typehint) + or isinstance(typehint, type) + or (isinstance(typehint, tuple) and all(isinstance(x, type) for x in typehint)) + ) + if is_typing_definition(typehint): + typehint_str = str(typehint) + origin_cls = get_origin(typehint) + if "typing." in typehint_str: + typehint_str = typehint_str.replace("typing.", "") + if "typing_extensions." in typehint_str: + typehint_str = typehint_str.replace("typing_extensions.", "") + if isinstance(typehint, TypeVar): + is_abstract_type = True + else: + if isinstance(typehint, type): + typehint_str = typehint.__name__ + origin_cls = get_origin(typehint) + elif isinstance(typehint, tuple): + origin_cls = Union + is_abstract_type = True + typehint_str = str(Union[typehint]) + else: + raise TypeError(f"Unknown type ({typehint!r}) {type(typehint)}") + + typename = "" bound_type = None registry = _abstract_custom_types - class WrappedTypeMeta(AbstractWrappedMeta): + class CustomTypeCheckMeta(CustomTypeCheckMetaBase[T]): __slots__ = () - def __instancecheck__(self, instance): - return typecheck_func(instance) + def __instancecheck__(self, instance: Union[Any, T]) -> TypeGuard[T]: + return func(instance) + + def __str__(self): + return f"{CustomTypeCheckType.__name__}" def __repr__(self): return typename.format(super().__repr__()) @@ -131,7 +206,59 @@ def __getitem__(self, key): raise AttributeError(key) return parse_typedef(bound_type[key]) - class _WrappedType(metaclass=WrappedTypeMeta): + __origin__ = origin_cls + __args__ = get_args(typehint) + + if origin_cls in (Union, Literal): + + def __iter__(self): + return iter(type_args) + + def __contains__(self, key): + return key in type_args + + bases: Tuple[Type, ...] + if is_abstract_type: + bases = (CustomTypeCheck, cast_type(Type, Generic[T])) # type:ignore[misc] + else: + assert isinstance(typehint, type), f"wyf - {typehint!s} ({type(typehint)})" + type_cls: Type[T] = cast_type(Type[T], typehint) + bases = ( + CustomTypeCheck, + cast_type(Type, Generic[T]), # type:ignore[misc] + cast_type(Type, typehint), + ) # type:ignore[misc] + + if not is_abstract_type: + if issubclass(type_cls, AbstractMapping): + key_type, value_type = type_args + typehint_str = f"{type_cls.__name__}[{key_type}, {value_type}]" + + def validate_mapping(iterable, **kwargs): + for key, value in iterable: + if not isinstance(key, key_type): + raise InstructTypeError(f"Key {key!r} is not a {key_type}", key, value) + if not isinstance(value, value_type): + raise InstructTypeError( + f"Value {value!r} is not a {value_type}", key, value + ) + yield key, value + + elif issubclass(type_cls, AbstractIterable): + if issubclass(type_cls, tuple) and Ellipsis in type_args: + typehint_str = f"{Tuple[typehint]}" + else: + typehint_str = f'{type_cls.__name__}[{", ".join(x.__name__ for x in type_args)}]' + + def validate_iterable(values): + for index, item in enumerate(values): + if not isinstance(item, type_args): + raise InstructTypeError( + f"{item!r} at index {index} is not a {typehint_str}", index, item + ) + yield item + + class CustomTypeCheckType(*bases, metaclass=CustomTypeCheckMeta[T]): # type:ignore[misc] __slots__ = () def __class_getitem__(self, key): @@ -139,17 +266,95 @@ def __class_getitem__(self, key): raise AttributeError(key) return parse_typedef(bound_type[key]) - def __new__(cls, *args): - nonlocal bound_type - if bound_type is None: - raise TypeError("Unbound or abstract type!") - return bound_type(*args) + def __new__(cls, iterable=None, **kwargs): + # if bound_type is None: + # return bound_type(*args) + if origin_cls is Union: + raise TypeError(f"Cannot instantiate a {typehint_str} (UnionType) directly!") + if is_abstract_type: + raise TypeError(f"Cannot instantiate abstract class for {typehint}") + if iterable: + if issubclass(typehint, AbstractMapping): + iterable = {**iterable, **kwargs} + iterable = dict(validate_mapping(iterable.items())) + elif issubclass(typehint, AbstractIterable): + iterable = tuple(validate_iterable(iterable)) + + return super().__new__(cls, iterable, **kwargs) + + def __str__(self): + return f"{CustomTypeCheckType.__name__}" + + def __repr__(self): + return f"{CustomTypeCheckType.__name__}({super().__repr__()})" + + if not is_abstract_type: + if issubclass(type_cls, AbstractMutableSequence): + + def __setitem__s(self, index_or_slice: Union[slice, int], value): + if isinstance(index_or_slice, slice): + return super().__setitem__(index_or_slice, validate_iterable(value)) + if not isinstance(value, type_args): + raise InstructTypeError( + f"{value!r} is not a {type_args}", index_or_slice, value + ) + super().__setitem__(index_or_slice, value) + + def insert(self, index, value): + if not isinstance(value, type_args): + raise InstructTypeError(f"{value!r} is not a {type_args}", index, value) + super().insert(index, value) + + def append(self, value): + if not isinstance(value, type_args): + index = len(self) + raise InstructTypeError(f"{value!r} is not a {type_args}", index, value) + return super().append(value) + + def extend(self, values): + return super().extend(validate_iterable(values)) + + __setitem__ = __setitem__s + + elif issubclass(type_cls, AbstractMutableMapping): + + def __setitem__m(self, key, value): + if not isinstance(key, key_type): + raise InstructTypeError(f"Key {key!r} is not a {key_type}", key, value) + if not isinstance(value, value_type): + raise InstructTypeError( + f"Value {value!r} is not a {value_type}", key, value + ) + return super().__setitem__(key, value) + + __setitem__ = __setitem__m + + if hasattr(type_cls, "setdefault"): + + def setdefault(self, key, value): + if not isinstance(key, key_type): + raise InstructTypeError(f"Key {key!r} is not a {key_type}", key, value) + if not isinstance(value, value_type): + raise InstructTypeError( + f"Value {value!r} is not a {value_type}", key, value + ) + return super().setdefault(key, value) + + if hasattr(type_cls, "update"): + + def update(self, iterable): + if isinstance(iterable, AbstractMapping): + iterable = {**iterable}.items() + return super().update(validate_mapping(iterable)) + + CustomTypeCheckType.__name__ = typehint_str + CustomTypeCheckMeta.__name__ = "CustomTypeCheck[{}]".format(typehint_str) def set_name(name): nonlocal typename typename = name - _WrappedType.__name__ = name - _WrappedType._name__ = name + CustomTypeCheckType.__name__ = name + CustomTypeCheckMeta.__name__ = "CustomTypeCheck[{}]".format(name) return name def set_type(t): @@ -157,8 +362,8 @@ def set_type(t): bound_type = t return t - registry.setdefault(_WrappedType, (set_name, set_type)) - return cast(Type[ICustomTypeCheck], _WrappedType) + registry.setdefault(CustomTypeCheckType, (set_name, set_type)) + return CustomTypeCheckType def ismetasubclass(cls, metacls): @@ -317,71 +522,127 @@ def find_class_in_definition( return type_hints -def create_custom_type(container_type, *args, check_ranges=()): +def create_custom_type(container_type: M, *args: Union[Type[Atomic], Any, Type], check_ranges=()): + # An abtract type is one that we cannot make real ourselves, + # like a Union[str, int] cannot be initialized -- it's not real. + is_abstract_type = False + if get_origin(container_type) is not None and get_origin(container_type) is not Literal: + assert get_args(container_type) in ( + args, + None, + ), f"{container_type} has {get_args(container_type)} != {args}" + + def _on_new_type(new_type: Type[CustomTypeCheck[T]]): + if is_typing_definition(container_type): + new_name = f"{container_type}" + elif isinstance(container_type, tuple): + new_name = f"{Union[container_type]}" + elif isinstance(container_type, type): + new_name = container_type.__qualname__ + for prefix in ("builtins.", "typing_extensions.", "typing."): + new_name = remove_prefix(new_name, prefix) + metaclass = cast(CustomTypeCheckMetaBase, type(new_type)) + metaclass.set_name(new_type, new_name) + return new_type + + make_type = run_after(make_custom_typecheck, _on_new_type) + test_func: Optional[Callable] = None + if isinstance(container_type, tuple): + assert not args + assert all(isinstance(x, type) for x in container_type) + return create_custom_type( + cast_type(TypingDefinition, Union[container_type]), + *container_type, + check_ranges=check_ranges, + ) + if is_typing_definition(container_type): origin_cls = get_origin(container_type) - if hasattr(container_type, "_name") and container_type._name is None: - if origin_cls is Union: - types = flatten( - (create_custom_type(arg) for arg in container_type.__args__), eager=True - ) - if check_ranges: - - def test_func(value) -> bool: - if not isinstance(value, types): - return False - failed_ranges = [] - for rng in check_ranges: - if rng.applies(value): - try: - in_range = value in rng - except TypeError: - continue - else: - if in_range: - return True - else: - failed_ranges.append(rng) - if failed_ranges: - raise RangeError(value, failed_ranges) - return False + if origin_cls is Union: + is_abstract_type = True - else: + assert args, f"Got empty args for a {container_type}" + types = flatten((parse_typedef(arg) for arg in args), eager=True) + is_simple_types = not any(issubclass(x, CustomTypeCheck) for x in types) + assert types - def test_func(value) -> bool: - """ - Check if the value is of the type set - """ - return isinstance(value, types) + if check_ranges: - elif origin_cls is Literal: - from . import AtomicMeta, public_class + def test_func_ranged_union(value: Union[Any, T]) -> TypeGuard[T]: + if not isinstance(value, types): + return False + failed_ranges = [] + for rng in check_ranges: + if rng.applies(value): + try: + in_range = value in rng + except TypeError: + continue + else: + if in_range: + return True + else: + failed_ranges.append(rng) + if failed_ranges: + raise RangeError(value, failed_ranges) + return False + + test_func = test_func_ranged_union + + else: + # ARJ: if we have Union[int, str, float], we really + # should return + # types := (int, str, float) for the fastest ``isinstance(value, types)`` + if is_simple_types: + return types - def test_func(value) -> bool: + def test_func_union(value: Union[Any, T]) -> TypeGuard[T]: """ - Operate on a Literal type + Check if the value is of the type set """ - for arg in args: - if arg is value: - # Exact match on ``is``, useful for enums + return isinstance(value, types) + + test_func = test_func_union + + return make_type(container_type, test_func, types, is_abstract_type=is_abstract_type) + + elif origin_cls is Literal: + is_abstract_type = True + from . import public_class + + assert args, "Literals require arguments" + + def test_func_literal(value: Union[Any, T]) -> TypeGuard[T]: + """ + Operate on a Literal type + """ + for arg in args: + if arg is value: + # Exact match on ``is``, useful for enums + return True + elif arg == value: + # Equality by value. This may be part of an + # overridden __eq__, so check the types too! + if isinstance(arg, type): + arg_type = arg + else: + arg_type = type(arg) + if isinstance(arg_type, IAtomic): + arg_type = public_class( + cast(Type["Atomic"], arg_type), preserve_subtraction=True + ) + if isinstance(value, arg_type): return True - elif arg == value: - # Equality by value. This may be part of an - # overridden __eq__, so check the types too! - if isinstance(arg, type): - arg_type = arg - else: - arg_type = type(arg) - if isinstance(arg_type, AtomicMeta): - arg_type = public_class(arg_type, preserve_subtraction=True) - if isinstance(value, arg_type): - return True - return False + return False + + test_func = test_func_literal elif container_type is AnyStr: + is_abstract_type = True + assert not args if check_ranges: - def test_func(value) -> bool: + def test_ranged_anystr(value: Union[T, Any]) -> TypeGuard[T]: if not isinstance(value, (str, bytes)): return False failed_ranges = [] @@ -400,83 +661,202 @@ def test_func(value) -> bool: raise RangeError(value, failed_ranges) return False + test_func = test_ranged_anystr + else: return (str, bytes) elif container_type is Any: + # is_abstract_type = True + assert not args return object - elif origin_cls is not None and ( - issubclass(origin_cls, collections.abc.Iterable) - and issubclass(origin_cls, collections.abc.Container) - ): - return parse_typedef(container_type) + elif isinstance(origin_cls, type): + if not args: + return origin_cls + return create_custom_type(origin_cls, *args, check_ranges=check_ranges) elif isinstance(container_type, TypeVar): def test_func(value) -> bool: return False - else: - raise NotImplementedError(container_type, repr(container_type)) - elif isinstance(container_type, type) and ( - issubclass(container_type, collections.abc.Iterable) - and issubclass(container_type, collections.abc.Container) - ): - test_func = create_typecheck_container(container_type, args) - elif isinstance(container_type, type) and not args: - if check_ranges: - - def test_func(value) -> bool: - if not isinstance(value, container_type): - return False - failed_ranges = [] - for rng in check_ranges: - if rng.applies(value): - try: - in_range = value in rng - except TypeError: + elif isinstance(container_type, type) and Protocol in container_type.mro(): + # Like P[int] where P = class P(Protocol[T]) + protocol: Type = cast(type, container_type) + if is_simple_protocol(protocol): + is_abstract_type = True + + attribute_types: Dict[str, CustomTypeCheck] = {} + attribute_values: Dict[str, Any] = {} + attribute_functions: Set[str] = set() + hints = get_type_hints(protocol) + + for attribute in get_protocol_members(protocol): + with suppress(KeyError): + hint = hints[attribute] + attribute_types[attribute] = parse_typedef(hint) + with suppress(AttributeError): + attribute_value = getattr(protocol, attribute) + if isinstance(attribute_value, FunctionType): + attribute_functions.add(attribute) continue - else: - if in_range: - return True + attribute_values[attribute] = attribute_value + + def test_simple_protocol(value): + assert attribute_types or attribute_values or attribute_functions + if attribute_types: + for attribute, type_cls in attribute_types.items(): + try: + attr_value = getattr(value, attribute) + except AttributeError: + return False else: - failed_ranges.append(rng) - if failed_ranges: - raise RangeError(value, failed_ranges) - return False + if not isinstance(attr_value, type_cls): + return False + if attribute_values: + for attribute, expected_value in attribute_values.items(): + try: + attr_value = getattr(value, attribute) + except AttributeError: + return False + else: + if expected_value != attr_value: + return False + if attribute_functions: + value_type = type(value) + for attribute in attribute_functions: + try: + attr_value = getattr(value_type, attribute) + except AttributeError: + return False + else: + if not callable(attr_value): + return False + return True + test_func = test_simple_protocol + + elif is_protocol(protocol): + # This implements Protocol[T] + assert get_args(protocol) + ... + raise NotImplementedError( + f"generic protocol support not implemented ({get_args(protocol)})" + ) + else: + cls = get_origin(protocol) + assert cls is not None and isinstance(cls, type), Protocol in protocol.mro() + assert is_protocol(cls) + # ARJ: This is the case where someone is referring + # to a specialized version of a protocol, like + # class P(Protocol[T]): + # a: T + # container_type = P[int] + # parse_typedef(container_type) + raise NotImplementedError("Specialized protocol-inheriting classes not implemented") else: - return container_type - else: - assert isinstance(container_type, tuple), f"container_type is {container_type}" - if check_ranges: - # materialized = issubclass(container_type, collections.abc.Container) and issubclass(container_type, collections.abc.Iterable) + raise NotImplementedError( + f"To be implemented: {container_type} ({str(container_type)})" + ) + + # active (collection type, *(generic_types)) ? + if test_func is None: + assert isinstance(container_type, type) + # ARJ: every built-in type implements AbstractCollection... + if (issubclass(container_type, (AbstractIterable,))) and args: + is_abstract_type = not ( + issubclass(container_type, AbstractCollection) + and not container_type.__module__.startswith("collections.abc") + ) + test_func, args = create_typecheck_for_container(container_type, args) + if isinstance(args, type): + args = (args,) + + elif not args: + # This type has no args (i.e. like it's Dict or set or a custom simple class) + if check_ranges: - def test_func(value): - if not isinstance(value, container_type): + def test_regular_type(value: Union[Any, T]) -> TypeGuard[T]: + if not isinstance(value, container_type): + return False + failed_ranges = [] + for rng in check_ranges: + if rng.applies(value): + try: + in_range = value in rng + except TypeError: + continue + else: + if in_range: + return True + else: + failed_ranges.append(rng) + if failed_ranges: + raise RangeError(value, failed_ranges) return False - failed_ranges = [] - for rng in check_ranges: - if rng.applies(value): - try: - in_range = value in rng - except TypeError: - continue - else: - if in_range: - return True + + test_func = test_regular_type + + else: + return container_type + else: + is_abstract_type = True + assert isinstance(container_type, type), f"must be a type - {container_type!r}" + assert not args + assert not get_args(container_type) + if check_ranges: + + def test_ranged_abstract_type(value: Union[Any, T]) -> TypeGuard[T]: + if not isinstance(value, container_type): + return False + failed_ranges = [] + for rng in check_ranges: + if rng.applies(value): + try: + in_range = value in rng + except TypeError: + continue else: - failed_ranges.append(rng) - if failed_ranges: - raise RangeError(value, failed_ranges) - return False + if in_range: + return True + else: + failed_ranges.append(rng) + if failed_ranges: + raise RangeError(value, failed_ranges) + return False + + test_func = test_ranged_abstract_type + + else: + def test_abstract_type(value: Union[T, Any]) -> TypeGuard[T]: + return isinstance(value, container_type) + + test_func = test_abstract_type + + new_type = make_type(container_type, test_func, args, is_abstract_type=is_abstract_type) + return new_type + + +def run_after(func: Callable, *thunks: Callable): + @wraps(func) + def wrapped(*args, **kwargs): + try: + v = func(*args, **kwargs) + except Exception: + raise else: + for thunk in thunks: + r = thunk(v) + if r is not None: + v = r + return v + + return wrapped - def test_func(value): - return isinstance(value, container_type) - t = make_custom_typecheck(test_func) - AbstractWrappedMeta.set_type(t, container_type) - return t +def remove_prefix(s: str, prefix: str): + if s.startswith(prefix): + return s[len(prefix) :] + return s def is_variable_tuple(decl) -> bool: @@ -487,36 +867,77 @@ def is_variable_tuple(decl) -> bool: return False -def create_typecheck_container(container_type, items): +V = TypeVar("V") + + +@overload +def create_typecheck_for_container( + container_cls: Tuple[Type[V], ...], value_types: Tuple[Type[Any], ...] = () +) -> Tuple[Callable[[Any], TypeGuard[Tuple[V, ...]]], Tuple[Any, ...]]: + ... + + +@overload +def create_typecheck_for_container( + container_cls: Type[Tuple[V, EllipsisType]], value_types: Tuple[Type, ...] = () +) -> Tuple[Callable[[Any], TypeGuard[Tuple[V, ...]]], Tuple[Any, ...]]: + ... + + +@overload +def create_typecheck_for_container( + container_cls: Type[Mapping[T, U]], value_types: Tuple[Type, ...] = () +) -> Tuple[Callable[[Any], TypeGuard[Mapping[T, U]]], Tuple[Any, ...]]: + ... + + +@overload +def create_typecheck_for_container( + container_cls: Type[Iterable[T]], value_types: Tuple[Type, ...] = () +) -> Tuple[Callable[[Any], TypeGuard[Iterable[T]]], Tuple[Any, ...]]: + ... + + +def create_typecheck_for_container(container_cls, value_types=()): + """ + Determine a function to determine if given value V is an instance of some container type + and values within are within the + """ + assert isinstance(container_cls, type), f"{container_cls!r} is not a type" + assert issubclass( + container_cls, (AbstractMapping, AbstractIterable, tuple) + ), f"Not a supported container type - {container_cls!r}" + test_types = [] test_func: Optional[Callable[[Any], bool]] = None - if issubclass(container_type, tuple): - container_type = tuple + if issubclass(container_cls, tuple): + container_cls = tuple # Special support: Tuple[type, ...] - if any(item is Ellipsis for item in items): - if len(items) != 2: + if any(item is Ellipsis for item in value_types): + if len(value_types) != 2: raise TypeError("Tuple[type, ...] is allowed but it must be a two pair tuple!") - homogenous_type_spec, ellipsis = items + homogenous_type_spec, ellipsis = value_types if ellipsis is not Ellipsis or homogenous_type_spec is Ellipsis: raise TypeError( "Tuple[type, ...] is allowed but it must have ellipsis as second arg" ) homogenous_type = parse_typedef(homogenous_type_spec) - def test_func(value): - if not isinstance(value, container_type): + def test_func_homogenous_tuple(value): + if not isinstance(value, container_cls): return False return all(isinstance(item, homogenous_type) for item in value) - return test_func + return test_func_homogenous_tuple, homogenous_type else: - for some_type in items: - test_types.append(create_custom_type(some_type)) + test_types = flatten( + (parse_typedef(some_type) for some_type in value_types), eager=True + ) - def test_func(value): - if not isinstance(value, container_type): + def test_func_heterogenous_tuple(value): + if not isinstance(value, container_cls): return False if len(value) != len(test_types): raise ValueError(f"Expecting a {len(test_types)} value tuple!") @@ -526,60 +947,126 @@ def test_func(value): return False return True - elif issubclass(container_type, AbstractMapping): - if items: - key_type_spec, value_type_spec = items + assert all( + isinstance(x, type) for x in test_types + ), f"some test types are invalid - {test_types}" + return test_func_heterogenous_tuple, tuple(test_types) + + elif issubclass(container_cls, AbstractMapping): + if value_types: + key_type_spec, value_type_spec = value_types key_type = parse_typedef(key_type_spec) value_type = parse_typedef(value_type_spec) - def test_func(mapping) -> bool: - if not isinstance(mapping, container_type): + def test_func_mapping(mapping) -> bool: + if not isinstance(mapping, container_cls): return False for key, value in mapping.items(): if not all((isinstance(key, key_type), isinstance(value, value_type))): return False return True - if test_func is None: - if items: - for some_type in items: - test_types.append(create_custom_type(some_type)) - types = tuple(test_types) + return test_func_mapping, (key_type, value_type) - def test_func(value): - if not isinstance(value, container_type): + if test_func is None: + if value_types: + for some_type in value_types: + args = get_args(some_type) + test_types.append(create_custom_type(some_type, *args)) + test_types = flatten(test_types, eager=True) + + def test_func_with_subtypes(value): + if not isinstance(value, container_cls): return False - return all(isinstance(item, types) for item in value) + return all(isinstance(item, test_types) for item in value) + + assert all(isinstance(x, type) for x in test_types) + return test_func_with_subtypes, test_types else: - def test_func(value): - return isinstance(value, container_type) + def test_func_simple(value): + return isinstance(value, container_cls) - return test_func + test_func = test_func_simple + return test_func, None -def is_typing_definition(item): - with suppress(AttributeError): - cls_module = type(item).__module__ - if cls_module in ("typing", "typing_extensions") or cls_module.startswith( - ("typing.", "typing_extensions.") - ): - return True - with suppress(AttributeError): - module = item.__module__ - if module in ("typing", "typing_extensions") or module.startswith( - ("typing.", "typing_extensions.") - ): - return True - if isinstance(item, (TypeVar, TypeVarTuple, ParamSpec)): - return True - origin = get_origin(item) - args = get_args(item) - if origin is not None: - return True - elif args: - return True + +def is_genericizable(item: TypingDefinition): + """ + Tell us if this definition is an unspecialized generic-capable type like + + >>> T = TypeVar('T') + >>> U = TypeVar('U') + >>> is_genericizable(Generic[T]) + False + >>> is_genericizable(Protocol[T]) + False + >>> class PairedNamespace(Generic[T, U]): + ... first: T + ... second: U + ... + >>> is_genericizable(PairedNamespace) + True + >>> class PairedProtocol(Protocol[T, U]): + ... first: T + ... second: U + ... def bar(self, val: T) -> U: + ... ... + ... + >>> is_genericizable(PairedProtocol) + True + >>> + + + This will reject Generic[T], Protocol[T] but allow a class that inherits one + of those. + """ + if isinstance(item, type): + generic_or_protocol = () + with suppress(TypeError): + generic_or_protocol = get_original_bases(item) + cls: Type + for cls in generic_or_protocol: + args = get_args(cls) + if any(isinstance(arg, TypeVar) for arg in args): + return True + origin_cls = get_origin(cls) + if origin_cls is Generic: + return True + return False + + +def is_simple_protocol(item: TypeHint): + """ + Tell us if this class is like: + + class X(Protocol): + required_field: int + + excludes: + + class Y(Protocol[T]): ... + + >>> T = TypeVar('T') + >>> class X(Protocol): + ... required_field: int + ... + >>> class Y(Protocol[T]): + ... required_field: T + ... + >>> is_simple_protocol(X) + True + >>> is_simple_protocol(Y) + False + + """ + if isinstance(item, type): + generic_or_protocol = () + with suppress(TypeError): + generic_or_protocol = get_original_bases(item) + return Protocol in generic_or_protocol return False @@ -616,9 +1103,10 @@ def parse_typedef( # ARJ: Okay, we're not a typing module descendant. # Are we a type itelf? if isinstance(typedef, type): + cls = typedef if check_ranges: - return create_custom_type(typedef, check_ranges=check_ranges) - return typedef + return create_custom_type(cls, check_ranges=check_ranges) + return cls elif isinstance(typedef, tuple): with suppress(ValueError): (single_type,) = typedef @@ -658,31 +1146,21 @@ def parse_typedef( new_name = new_name[len("typing.") :] else: new_name = typedef.__name__ - AbstractWrappedMeta.set_name(new_type, new_name) - AbstractWrappedMeta.set_type(new_type, typedef) + CustomTypeCheckMetaBase.set_name(new_type, new_name) + CustomTypeCheckMetaBase.set_type(new_type, typedef) return new_type elif as_origin_cls is Union: - args = get_args(typedef) assert args - if check_ranges: - return create_custom_type(typedef, check_ranges=check_ranges) - return flatten((parse_typedef(argument) for argument in args), eager=True) + return create_custom_type(typedef, *args, check_ranges=check_ranges) elif as_origin_cls is Literal: - args = get_args(typedef) if not args: raise NotImplementedError("Literals must be non-empty!") - items = [] # ARJ: We *really* should make one single type, however, # this messes with the message in the test_typedef::test_literal # and I'm not comfortable with changing the public messages globally. - for arg in args: - new_type = create_custom_type(typedef, arg) - if isinstance(arg, str): - arg = f'"{arg}"' - AbstractWrappedMeta.set_name(new_type, f"{arg!s}") - items.append(new_type) - return tuple(items) - elif as_origin_cls is not None or isinstance(typedef, TypeVar): + new_type = create_custom_type(typedef, *args) + return new_type + elif as_origin_cls is not None or isinstance(typedef, TypeVar) or is_protocol(typedef): if is_typing_definition(typedef) and hasattr(typedef, "_name") and typedef._name is None: # special cases! raise NotImplementedError( @@ -693,14 +1171,14 @@ def parse_typedef( if as_origin_cls is not None: cls = create_custom_type(as_origin_cls, *args, check_ranges=check_ranges) else: - cls = create_custom_type(typedef, check_ranges=check_ranges) + cls = create_custom_type(typedef, *args, check_ranges=check_ranges) new_name = str(typedef) if new_name.startswith(("typing_extensions.")): new_name = new_name[len("typing_extensions.") :] if new_name.startswith(("typing.")): new_name = new_name[len("typing.") :] - AbstractWrappedMeta.set_name(cls, new_name) - AbstractWrappedMeta.set_type(cls, typedef) + type(cls).set_name(cls, new_name) + type(cls).set_type(cls, typedef) return cls return as_origin_cls diff --git a/instruct/typing.py b/instruct/typing.py index 07e6569..ca1f241 100644 --- a/instruct/typing.py +++ b/instruct/typing.py @@ -1,9 +1,10 @@ import sys import typing +from contextlib import suppress from collections.abc import Collection as AbstractCollection from typing import Collection, ClassVar, Tuple, Dict, Type, Callable, Union, Any, Generic -from typing_extensions import get_origin +from typing_extensions import get_origin, get_args if sys.version_info[:2] >= (3, 8): from typing import Protocol, Literal, runtime_checkable, TypedDict @@ -167,13 +168,27 @@ def is_typing_definition(item: Any) -> TypeGuard[TypingDefinition]: """ Check if the given item is a type hint. """ - module_name: str = getattr(item, "__module__", "") - if module_name in ("typing", "typing_extensions"): + + with suppress(AttributeError): + cls_module = type(item).__module__ + if cls_module in ("typing", "typing_extensions") or cls_module.startswith( + ("typing.", "typing_extensions.") + ): + return True + with suppress(AttributeError): + module = item.__module__ + if module in ("typing", "typing_extensions") or module.startswith( + ("typing.", "typing_extensions.") + ): + return True + if isinstance(item, (TypeVar, TypeVarTuple, ParamSpec)): + return True + origin = get_origin(item) + args = get_args(item) + if origin is not None: + return True + elif args: return True - if module_name == "builtins": - origin = get_origin(item) - if origin is not None: - return is_typing_definition(origin) return False diff --git a/tests/test_atomic.py b/tests/test_atomic.py index c8e2038..60a7ba0 100644 --- a/tests/test_atomic.py +++ b/tests/test_atomic.py @@ -1,6 +1,7 @@ import json import pprint -from typing import Union, List, Tuple, Optional, Dict, Any, Type +import sys +from typing import Union, List, Tuple, Optional, Dict, Any, Type, TypeVar try: from typing import Annotated @@ -16,6 +17,7 @@ import pytest import instruct import inflection +from instruct.types import IAtomic from instruct import ( Base, add_event_listener, @@ -35,6 +37,28 @@ asjson, ) +if sys.version_info < (3, 9): + from typing_extensions import get_type_hints +else: + from typing import get_type_hints + + +def test_simple() -> None: + class Data(SimpleBase): + foo: int + bar: str + baz: Dict[str, Any] + cool: Annotated[int, "is this cool?", "yes"] + + assert get_type_hints(Data) == { + "foo": int, + "bar": str, + "baz": Dict[str, Any], + "cool": int, + } + assert isinstance(Data, IAtomic) + assert isinstance(Data(), IAtomic) + class Data(Base, history=True): __slots__ = {"field": Union[str, int], "other": str} @@ -1152,7 +1176,7 @@ class Position(Base): assert instruct.public_class(fp.worker, preserve_subtraction=True) is not Person -def test_skip_keys_coerce_classmethod(): +def test_skip_keys_coerce_classmethod() -> None: class Person(Base): id: int name: str @@ -1503,3 +1527,61 @@ class Foo(SimpleBase): pass assert list(Foo()) == [] + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.8 or higher") +def test_using_builtin_unions(): + class TestUnion(SimpleBase): + field: str | int + + TestUnion("foo") + TestUnion(1) + with pytest.raises(TypeError): + TestUnion(1.5) + + +def test_with_init_subclass(): + Registry = {} + + class Foo(SimpleBase): + def __init_subclass__(cls, swallow: str, **kwargs): + Registry[cls] = swallow + super().__init_subclass__() + + f = Foo() + + class Bar(Foo, swallow="Barn!"): + ... + + assert Bar in Registry + assert Registry[Bar] == "Barn!" + assert len(Registry) == 1 + + class BarBar(Bar, swallow="Farter"): + def __init_subclass__(cls, **kwargs): + return + + assert len(Registry) == 2 + + class BreakChainBar(BarBar): + ... + + assert len(Registry) == 2 + + +def test_simple_generics(): + T = TypeVar("T") + assert isinstance(T, TypeVar) + + class Foo(SimpleBase): + field: T + + assert isinstance(Foo._slots["field"], TypeVar) + assert Foo.__parameters__ == (T,) + + # any_instance = Foo(None) + # assert any_instance.field is None + + cls = Foo[int] + assert isinstance(cls(1).field, int) + assert get_type_hints(cls)["field"] is int diff --git a/tests/test_typedef.py b/tests/test_typedef.py index 9098247..996a764 100644 --- a/tests/test_typedef.py +++ b/tests/test_typedef.py @@ -5,6 +5,7 @@ make_custom_typecheck, has_collect_class, find_class_in_definition, + is_typing_definition, ) from instruct import Base, AtomicMeta from instruct.typing import Self, Protocol @@ -254,3 +255,35 @@ class Bar(Base): type_hints = Optional[Bar] items = tuple(find_class_in_definition(type_hints, AtomicMeta, metaclass=True)) assert items == (Bar,) + + +def test_parse_typedef_generics(): + T = TypeVar("T") + assert is_typing_definition(T) + assert tuple(find_class_in_definition((T,), TypeVar)) == (T,) + assert tuple(find_class_in_definition(T, TypeVar)) == (T,) + U = TypeVar("U") + ListOfT = List[T] + ListOfTint = ListOfT[int] + + assert not isinstance("any", parse_typedef(T)) + assert not isinstance([1, 2, 3], parse_typedef(ListOfT)) + assert not isinstance("any", parse_typedef(ListOfTint)) + assert not isinstance([None, 2, 3], parse_typedef(ListOfTint)) + + assert isinstance([1, 2, 3], parse_typedef(ListOfTint)) + + cls = parse_typedef(ListOfT) + assert callable(cls) + assert isinstance(cls, type) + cls_int = cls[int] + assert isinstance([1, 2, 3], cls_int) + + SomeDictGeneric = Dict[T, U] + SomeGenericSeq = Tuple[T, U] + assert not isinstance({"1": 1}, parse_typedef(SomeDictGeneric)) + DictStrInt = SomeDictGeneric[str, int] + assert isinstance({"1": 1}, parse_typedef(DictStrInt)) + assert isinstance((1, "str"), parse_typedef(SomeGenericSeq[int, str])) + assert not isinstance((1, 1), parse_typedef(SomeGenericSeq[int, str])) + assert not isinstance((1, 1), parse_typedef(Tuple[int, str]))