Skip to content

Commit

Permalink
[*] support generics
Browse files Browse the repository at this point in the history
  • Loading branch information
autumnjolitz committed Jun 13, 2024
1 parent 1c33b63 commit 4955f18
Show file tree
Hide file tree
Showing 5 changed files with 355 additions and 104 deletions.
108 changes: 88 additions & 20 deletions instruct/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import sys
import types
import typing
from contextlib import suppress

from base64 import urlsafe_b64encode
from collections.abc import (
Expand Down Expand Up @@ -41,6 +42,8 @@
Type,
TYPE_CHECKING,
Union,
TypeVar,
Generic,
)
from weakref import WeakValueDictionary

Expand All @@ -59,7 +62,7 @@
)
from .typing import T, CellType, CoerceMapping, NoneType
from .types import FrozenMapping, ReadOnly, AttrsDict, ClassOrInstanceFuncsDescriptor
from .utils import flatten_fields
from .utils import flatten_fields, invert_mapping
from .subtype import wrapper_for_type
from .exceptions import (
OrphanedListenersError,
Expand Down Expand Up @@ -1139,12 +1142,13 @@ class Atomic(type):
if TYPE_CHECKING:
_data_class: Atomic
_columns: Mapping[str, Any]
_slots: Mapping[str, type]
_column_types: Mapping[str, Union[Type, Tuple[Type, ...]]]
_all_coercions: Mapping[str, Tuple[Union[Type, Tuple[Type, ...]], Callable]]
_support_columns: Tuple[str, ...]
_properties: typing.collections.KeysView[str]
_properties: typing.KeysView[str]
_configuration: AttrsDict
_all_accessible_fields: typing.collections.KeysView[str]
_all_accessible_fields: typing.KeysView[str]
# i.e. key -> List[Union[AtomicDerived, bool]] means key can hold an Atomic derived type.
_nested_atomic_collection_keys: Mapping[str, Tuple[Atomic, ...]]

Expand Down Expand Up @@ -1183,12 +1187,49 @@ def __and__(
)
return self - skip_fields

def __getitem__(self, key):
if not isinstance(key, tuple):
new_params = (key,)
else:
new_params = key
if self.__parameters__:
assert len(new_params) == len(self.__parameters__)
param_mapping = {p: n for p, n in zip(self.__parameters__, new_params)}
new_annotations = {}
complex_fields = []
typehints = get_type_hints(self)
for field_name in self.__parameters_by_field__:
typevar_params = self.__parameters_by_field__[field_name]
typehint = typehints[field_name]
if isinstance(typehint, TypeVar):
new_annotations[field_name] = param_mapping[typehint]
else:
complex_fields.append(field_name)

for attr_name in self.__parameters_by_field__.keys() & frozenset(complex_fields):
typevar_params = self.__parameters_by_field__[attr_name]
attr_type_args = []
for new_type, typevar in zip(new_params, typevar_params):
attr_type_args.append(new_type)
type_genericable = self._columns[attr_name]
new_annotations[attr_name] = type_genericable[tuple(attr_type_args)]
return type(
self.__name__,
(self - frozenset(new_annotations.keys()),),
{
"__annotations__": new_annotations,
"__args__": new_params,
"__parameters__": self.__parameters__,
},
)
raise AttributeError(key)

def __sub__(self: Atomic, skip_fields: Union[Mapping[str, Any], Iterable[Any]]) -> Atomic:
assert isinstance(skip_fields, (list, frozenset, set, tuple, dict, str, FrozenMapping))
debug_mode = is_debug_mode("skip")

root_class = type(self)
cls = public_class(self)
cls: Type[Atomic] = public_class(self)

if not skip_fields:
return self
Expand Down Expand Up @@ -1294,7 +1335,7 @@ def __new__(
concrete_class=False,
skip_fields=FrozenMapping(),
include_fields=FrozenMapping(),
**mixins,
**mixins: Any,
):
if concrete_class:
attrs["_is_data_class"] = ReadOnly(True)
Expand All @@ -1312,7 +1353,7 @@ def __new__(
if include_fields and skip_fields:
raise TypeError("Cannot specify both include_fields and skip_fields!")
data_class_attrs = {}
base_class_functions = []
pending_base_class_functions = []
# Move overrides to the data class,
# so we call them first, then the codegen pieces.
# Suitable for a single level override.
Expand All @@ -1328,15 +1369,20 @@ def __new__(
):
if key in attrs:
if hasattr(attrs[key], "_instruct_base_cls"):
base_class_functions.append(key)
pending_base_class_functions.append(key)
continue
data_class_attrs[key] = attrs.pop(key)
base_class_functions = tuple(base_class_functions)
base_class_functions = tuple(pending_base_class_functions)
support_cls_attrs = attrs
del attrs

if "__slots__" not in support_cls_attrs and "__annotations__" in support_cls_attrs:
module = import_module(support_cls_attrs["__module__"])
try:
module = import_module(support_cls_attrs["__module__"])
except KeyError:
globalns = typing.__dict__
else:
globalns = ChainMap(module.__dict__, typing.__dict__)
kwargs = {}
if _GET_TYPEHINTS_ALLOWS_EXTRA:
kwargs["include_extras"] = True
Expand All @@ -1345,7 +1391,7 @@ def __new__(
support_cls_attrs,
# First look in the module, then failsafe to the typing to support
# unimported 'Optional', et al
ChainMap(module.__dict__, typing.__dict__),
globalns,
**kwargs,
)
support_cls_attrs["__slots__"] = hints
Expand All @@ -1358,7 +1404,7 @@ def __new__(
coerce_mappings: Optional[CoerceMapping] = None
if "__coerce__" in support_cls_attrs:
if support_cls_attrs["__coerce__"] is not None:
coerce_mappings: AbstractMapping = support_cls_attrs["__coerce__"]
coerce_mappings = support_cls_attrs["__coerce__"]
if isinstance(coerce_mappings, ReadOnly):
# Unwrap
coerce_mappings = coerce_mappings.value
Expand Down Expand Up @@ -1394,7 +1440,7 @@ def __new__(
if fast is None:
fast = not __debug__

combined_columns = {}
combined_columns: Dict[Type, Type] = {}
combined_slots = {}
nested_atomic_collections: Dict[str, Atomic] = {}
# Mapping of public name -> custom type vector for `isinstance(...)` checks!
Expand Down Expand Up @@ -1524,17 +1570,32 @@ def __new__(
derived_classes = {}
current_class_columns = {}
current_class_slots = {}
for key, value in support_cls_attrs["__slots__"].items():
if isinstance(value, dict):
value = type("{}".format(key.capitalize()), bases, {"__slots__": value})
derived_classes[key] = value
if not ismetasubclass(value, Atomic):
nested_atomics = tuple(find_class_in_definition(value, Atomic, metaclass=True))
avail_generics = ()
generics_by_field = {}
for key, typehint_or_anonymous_struct_decl in support_cls_attrs["__slots__"].items():
if isinstance(typehint_or_anonymous_struct_decl, dict):
anonymous_struct_decl = typehint_or_anonymous_struct_decl
typehint = type(
"{}".format(key.capitalize()), bases, {"__slots__": anonymous_struct_decl}
)
derived_classes[key] = typehint
else:
typehint = typehint_or_anonymous_struct_decl
del typehint_or_anonymous_struct_decl
if not ismetasubclass(typehint, Atomic):
nested_atomics = tuple(find_class_in_definition(typehint, Atomic, metaclass=True))
if nested_atomics:
nested_atomic_collections[key] = nested_atomics
del nested_atomics
current_class_slots[key] = combined_slots[key] = value
current_class_columns[key] = combined_columns[key] = parse_typedef(value)
nested_generics = tuple(find_class_in_definition(typehint, TypeVar))
if nested_generics:
for item in nested_generics:
if item not in avail_generics:
avail_generics = (*avail_generics, item)
generics_by_field[key] = nested_generics
del nested_generics
current_class_slots[key] = combined_slots[key] = typehint
current_class_columns[key] = combined_columns[key] = parse_typedef(typehint)

no_op_skip_keys = []
if skip_fields:
Expand Down Expand Up @@ -1569,6 +1630,9 @@ def __new__(
no_op_skip_keys.append(key)
del current_class_slots[key]
del current_class_columns[key]
# ARJ: https://stackoverflow.com/a/54497260
# if avail_generics and Generic not in bases:
# bases = (*bases[:-1], *bases[-1], Generic[avail_generics])

# Gather listeners:
listeners, post_coerce_failure_handlers = gather_listeners(
Expand Down Expand Up @@ -1818,6 +1882,10 @@ def __new__(
support_cls_attrs["_listener_funcs"] = ReadOnly(listeners)
# Ensure public class has zero slots!
support_cls_attrs["__slots__"] = ()
if avail_generics:
support_cls_attrs["__parameters__"] = tuple(avail_generics)
support_cls_attrs["__parameters_by_field__"] = ReadOnly(generics_by_field)
support_cls_attrs["__parameter_fields__"] = ReadOnly(invert_mapping(generics_by_field))

support_cls_attrs["_data_class"] = support_cls_attrs[f"_{class_name}"] = dc = ReadOnly(None)
support_cls_attrs["_parent"] = parent_cell = ReadOnly(None)
Expand Down
Loading

0 comments on commit 4955f18

Please sign in to comment.