From 3e336ea628186c1b74f17bd0a0963578e1414731 Mon Sep 17 00:00:00 2001 From: "m.kindritskiy" Date: Sun, 30 Jul 2023 20:50:10 +0300 Subject: [PATCH] add enums support --- docs/changelog/changes_07.rst | 1 + docs/enums.rst | 279 ++++++++++++++++++++++++++ docs/index.rst | 1 + hiku/denormalize/base.py | 7 +- hiku/engine.py | 58 +++++- hiku/enum.py | 134 +++++++++---- hiku/federation/directive.py | 6 +- hiku/graph.py | 61 +++++- hiku/introspection/graphql.py | 101 +++++++--- hiku/introspection/types.py | 1 + hiku/types.py | 36 +++- hiku/utils/typing.py | 8 +- hiku/validate/graph.py | 16 ++ hiku/validate/query.py | 1 + tests/test_directives.py | 12 +- tests/test_enum.py | 299 ++++++++++++++++++++++++++++ tests/test_introspection_graphql.py | 102 +++++++++- 17 files changed, 1025 insertions(+), 98 deletions(-) create mode 100644 docs/enums.rst create mode 100644 tests/test_enum.py diff --git a/docs/changelog/changes_07.rst b/docs/changelog/changes_07.rst index 5acde837..2cac965e 100644 --- a/docs/changelog/changes_07.rst +++ b/docs/changelog/changes_07.rst @@ -40,6 +40,7 @@ Changes in 0.7 - Added `ID` type. - Added support for unions :ref:`Check unions documentation ` - Added support for interfaces :ref:`Check interfaces documentation ` + - Added support for enums :ref:`Check enums documentation ` Backward-incompatible changes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/enums.rst b/docs/enums.rst new file mode 100644 index 00000000..d0432893 --- /dev/null +++ b/docs/enums.rst @@ -0,0 +1,279 @@ +Enums +===== + +.. _enums-doc: + +Enums are a special types that can be used to define a set of possible values for a field. + +In graphql you can use enum type like this: + +.. code-block:: + + enum Status { + ACTIVE + DELETED + } + + type Usr { + id: ID! + status: Status! + } + + type Query { + user: User! + } + + +Enum from string +---------------- + +In `hiku` you can define enum type like this: + +.. code-block:: python + + from hiku.graph import Field, Graph, Link, Node, Root + from hiku.enum import Enum + from hiku.types import ID, TypeRef, Optional, EnumRef + + users = { + 1: {'id': "1", 'status': 'ACTIVE'}, + } + + def user_fields_resolver(fields, ids): + def get_field(field, user): + if field.name == 'id': + return user.id + elif field.name == 'status': + return user.status + + return [[get_field(field, users[id]) for field in fields] for id in ids] + + def get_user(opts): + return 1 + + enums = [ + Enum('Status', ['ACTIVE', 'DELETED']) + ] + + GRAPH = Graph([ + Node('User', [ + Field('id', ID, user_fields_resolver), + Field('status', EnumRef['Status'], user_fields_resolver), + ]), + Root([ + Link('user', TypeRef['User'], get_user, requires=None), + ]), + ], enums=enums) + +Lets look at the example above: + +- ``Enum`` type is defined with a name and a list of possible values. +- ``User.status`` field has type ``EnumRef['Status']`` which is a reference to the ``Status`` enum type. +- ``status`` field returns ``user.status`` which is plain string. + +.. note:: + + You can not return a value that is not in the enum list of possible values. Hiku will raise an error if you try to do so. + +Now lets look at the query: + +.. code-block:: python + + query { + user { + id + status + } + } + +The result will be: + +.. code-block:: + + { + 'id': "1", + 'status': 'ACTIVE', + } + + +Enum from builtin Enum type +---------------------------------- + +You can also use python builtin ``Enum`` type to define an enum type in ``hiku``: + +.. code-block:: python + + from enum import Enum as PyEnum + from hiku.enum import Enum + + class Status(PyEnum): + ACTIVE = 'active' + DELETED = 'deleted' + + Graph(..., enums=[Enum.from_builtin(Status)]) + +``Enum.from_builtin`` will create ``hiku.enum.EnumFromBuiltin``: + +- ``EnumFromBuiltin`` will use ``Enum.__name__`` as a enum name. +- ``EnumFromBuiltin`` will use ``Enum.__members__`` to get a list of possible values. +- ``EnumFromBuiltin`` will use ``member.name`` to get a value name: + + .. code-block:: python + + class Status(PyEnum): + ACTIVE = 1 + DELETED = 2 + + is equivalent to: + + .. code-block:: python + + enum Status { ACTIVE, DELETED } + +If you use builtin python ``Enum``, then you MUST return enum value from the resolver function, otherwise ``hiku`` will raise an error. + +.. code-block:: python + + def user_fields_resolver(fields, ids): + def get_field(field, user): + if field.name == 'id': + return user.id + elif field.name == 'status': + return Status(user.status) + + return [[get_field(field, users[id]) for field in fields] for id in ids] + +By default ``Enum.from_builtin`` will use ``Enum.__name__`` as a name for the enum type. + +.. note:: + + You can create enum using ``Enum`` class directly if you want custom name (for example non-pep8 compliant): + + .. code-block:: python + + Status = Enum('User_Status', ['ACTIVE', 'DELETED']) + +If you want to specify different name you can pass ``name`` argument to ``Enum.from_builtin`` method. + +.. code-block:: python + + Graph(..., enums=[Enum.from_builtin(Status, name='User_Status')]) + +Custom Enum type +---------------- + +You can also create custom enum type by subclassing ``hiku.enum.BaseEnum`` class: + +.. code-block:: python + + from hiku.enum import BaseEnum + + class IntToStrEnum(BaseEnum): + _MAPPING = {1: 'one', 2: 'two', 3: 'three'} + _INVERTED_MAPPING = {v: k for k, v in _MAPPING.items()} + + def __init__(self, name: str, values: list[int], description: str = None): + super().__init__(name, [_MAPPING[v] for v in values], description) + + def parse(self, value: str) -> int: + return self._INVERTED_MAPPING[value] + + def serialize(self, value: int) -> str: + return self._MAPPING[value] + +Enum serialization +------------------ + +``Enum`` serializes values into strings. If value is not in the list of possible values, then ``hiku`` will raise an error. + +``EnumFromBuiltin`` serializes values which are instances of ``Enum`` class into strings by calling `.name` on enum value. If value is not an instance of ``Enum`` class, then ``hiku`` will raise an error. + +You can also define custom serialization for your enum type by subclassing ``hiku.enum.BaseEnum`` class. + +Enum parsing +------------ + +``Enum`` parses values into strings. If value is not in the list of possible values, then ``hiku`` will raise an error. + +``EnumFromBuiltin`` parses values into enum values by calling ``Enum(value)``. If value is not in the list of possible values, then ``hiku`` will raise an error. + +You can also define custom parsing for your enum type by subclassing ``hiku.enum.BaseEnum`` class. + +Enum as a field argument +------------------------ + +You can use enum as a field argument: + +.. code-block:: python + + import enum + from hiku.enum import Enum + from hiku.graph import Node, Root, Field, Link, Graph, Option + from hiku.types import ID, TypeRef, Optional, EnumRef + + users = [ + {'id': "1", 'status': Status.ACTIVE}, + {'id': "2", 'status': Status.DELETED}, + ] + + def link_users(opts): + ids = [] + for user in users: + # here opts['status'] will be an instance of Status enum + if user['status'] == opts['status']: + ids.append(user.id) + + return ids + + + class Status(enum.Enum): + ACTIVE = 'active' + DELETED = 'deleted' + + GRAPH = Graph([ + Node('User', [ + Field('id', ID, user_fields_resolver), + Field('status', EnumRef['Status'], user_fields_resolver), + ]), + Root([ + Link( + 'users', + Sequence[TypeRef['User']], + link_users, + requires=None, + options=[ + Option('status', EnumRef['Status'], default=Status.ACTIVE), + ] + ), + ]), + ], enums=[Enum.from_builtin(Status)]) + + +Now you can use enum as a field argument: + +.. code-block:: + + query { + users(status: DELETED) { + id + status + } + } + +The result will be: + +.. code-block:: + + [{ + "id": "2", + "status": "DELETED", + }] + + +.. note:: + + Input value will be parsed using ``.parse`` method of ``Enum`` type. + + For ``Enum`` input value will be parsed into ``str``. + + For ``EnumFromBuiltin`` input value will be parsed into python Enum instance. \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 8b7707ae..673fbde5 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,6 +12,7 @@ User's Guide asyncio graphql protobuf + enums interfaces unions directives diff --git a/hiku/denormalize/base.py b/hiku/denormalize/base.py index 87214ccf..20d4af21 100644 --- a/hiku/denormalize/base.py +++ b/hiku/denormalize/base.py @@ -1,6 +1,7 @@ import typing as t from collections import deque +from ..enum import BaseEnum from ..graph import Graph, Interface, Union from ..query import ( QueryVisitor, @@ -24,9 +25,9 @@ def __init__(self, graph: Graph, result: Proxy) -> None: self._types = graph.__types__ self._unions = graph.unions_map self._result = result - self._type: t.Deque[t.Union[t.Type[Record], Union, Interface]] = deque( - [self._types["__root__"]] - ) + self._type: t.Deque[ + t.Union[t.Type[Record], Union, Interface, BaseEnum] + ] = deque([self._types["__root__"]]) self._data = deque([result]) self._res: t.Deque = deque() diff --git a/hiku/engine.py b/hiku/engine.py index 4b5136ed..ced889a7 100644 --- a/hiku/engine.py +++ b/hiku/engine.py @@ -18,6 +18,7 @@ Optional, DefaultDict, Awaitable, + Sequence as SequenceT, TYPE_CHECKING, ) from functools import partial @@ -25,6 +26,8 @@ from collections import defaultdict from collections.abc import Sequence, Mapping, Hashable +from hiku.types import OptionalMeta, SequenceMeta + from .cache import ( CacheVisitor, CacheInfo, @@ -76,7 +79,9 @@ def _yield_options( - graph_obj: Union[Link, Field], query_obj: Union[QueryField, QueryLink] + graph: Graph, + graph_obj: Union[Link, Field], + query_obj: Union[QueryField, QueryLink], ) -> Iterator[Tuple[str, Any]]: options = query_obj.options or {} for option in graph_obj.options: @@ -87,14 +92,19 @@ def _yield_options( option.name, graph_obj ) ) + elif option.enum_name is not None: + enum = graph.enums_map[option.enum_name] + yield option.name, enum.parse(value) else: yield option.name, value def _get_options( - graph_obj: Union[Link, Field], query_obj: Union[QueryField, QueryLink] + graph: Graph, + graph_obj: Union[Link, Field], + query_obj: Union[QueryField, QueryLink], ) -> Dict: - return dict(_yield_options(graph_obj, query_obj)) + return dict(_yield_options(graph, graph_obj, query_obj)) class InitOptions(QueryTransformer): @@ -138,7 +148,7 @@ def visit_node(self, obj: QueryNode) -> QueryNode: def visit_field(self, obj: QueryField) -> QueryField: graph_obj = self._path[-1].fields_map[obj.name] if graph_obj.options: - return obj.copy(options=_get_options(graph_obj, obj)) + return obj.copy(options=_get_options(self._graph, graph_obj, obj)) else: return obj @@ -160,7 +170,11 @@ def visit_link(self, obj: QueryLink) -> QueryLink: assert isinstance(graph_obj, Field), type(graph_obj) node = obj.node - options = _get_options(graph_obj, obj) if graph_obj.options else None + options = ( + _get_options(self._graph, graph_obj, obj) + if graph_obj.options + else None + ) return obj.copy(node=node, options=options) @@ -295,6 +309,21 @@ def _is_hashable(obj: Any) -> bool: return True +def convert_value(graph: Graph, field: Union[Field, Link], value: Any) -> Any: + if field.enum_name is None: + return value + + enum = graph.enums_map[field.enum_name] + + if isinstance(field.type, SequenceMeta): + return [enum.serialize(v) for v in value] + elif isinstance(field.type, OptionalMeta): + if value is None: + return None + + return enum.serialize(value) + + def update_index( index: Index, node: Node, @@ -312,9 +341,11 @@ def update_index( def store_fields( + graph: Graph, index: Index, node: Node, query_fields: List[Union[QueryField, QueryLink]], + graph_fields: SequenceT[Union[Field, Link]], ids: Optional[Any], query_result: Any, ) -> None: @@ -332,10 +363,12 @@ def store_fields( assert ids is not None node_idx = index[node.name] for i, row in zip(ids, query_result): - node_idx[i].update(zip(names, row)) + for field, name, value in zip(graph_fields, names, row): + node_idx[i][name] = convert_value(graph, field, value) else: assert ids is None - index.root.update(zip(names, query_result)) + for field, name, value in zip(graph_fields, names, query_result): + index.root[name] = convert_value(graph, field, value) return None @@ -785,6 +818,7 @@ def _schedule_fields( ids: Optional[Any], ) -> Union[SubmitRes, TaskSet]: query_fields = [qf for _, qf in fields] + graph_fields = [gf for gf, _ in fields] dep: Union[TaskSet, SubmitRes] if hasattr(func, "__subquery__"): @@ -799,7 +833,15 @@ def _schedule_fields( proc = dep.result def callback() -> None: - store_fields(self._index, node, query_fields, ids, proc()) + store_fields( + self._graph, + self._index, + node, + query_fields, + graph_fields, + ids, + proc(), + ) self._untrack(path) self._queue.add_callback(dep, callback) diff --git a/hiku/enum.py b/hiku/enum.py index 1ee25025..e31e3636 100644 --- a/hiku/enum.py +++ b/hiku/enum.py @@ -1,59 +1,113 @@ +import abc import dataclasses -from enum import EnumMeta -from typing import Any, List, Optional, TypeVar +import enum -EnumType = TypeVar("EnumType", bound=EnumMeta) +from typing import ( + Any, + Generic, + Optional, + Sequence, + TYPE_CHECKING, + TypeVar, + Union, +) + +if TYPE_CHECKING: + from hiku.graph import AbstractGraphVisitor @dataclasses.dataclass class EnumValue: name: str - value: Any + description: Optional[str] = None + deprecation_reason: Optional[str] = None -@dataclasses.dataclass -class EnumInfo: - wrapped_cls: EnumMeta - name: str - values: List[EnumValue] +class BaseEnum(abc.ABC): + def __init__( + self, + name: str, + values: Sequence[Union[str, EnumValue]], + description: Optional[str] = None, + ): + self.name = name + self.description = description + self.values = [ + EnumValue(v) if isinstance(v, str) else v for v in values + ] + self.values_map = {v.name: v for v in self.values} + + @abc.abstractmethod + def parse(self, value: Any) -> Any: + raise NotImplementedError + + @abc.abstractmethod + def serialize(self, value: Any) -> str: + raise NotImplementedError + + def __contains__(self, item: str) -> bool: + return item in self.values_map + + def accept(self, visitor: "AbstractGraphVisitor") -> Any: + return visitor.visit_enum(self) + +EM = TypeVar("EM", bound=enum.EnumMeta) +E = TypeVar("E", bound=enum.Enum) -def _process_enum( - cls: EnumType, - name: Optional[str] = None, -) -> EnumType: - if not isinstance(cls, EnumMeta): - raise TypeError(f"{cls} is not an Enum") - if not name: - name = cls.__name__ +class Enum(BaseEnum): + def parse(self, value: Any) -> str: + if value not in self: + raise TypeError( + "Enum '{}' can not represent value: {!r}".format( + self.name, value + ) + ) + return value - values = [] - for item in cls: # type: ignore - value = EnumValue( - item.name, - item.value, - ) - values.append(value) + def serialize(self, value: str) -> str: + if value not in self: + raise TypeError( + "Enum '{}' can not represent value: {!r}".format( + self.name, value + ) + ) + return value - cls.__enum_info__ = EnumInfo( # type: ignore - wrapped_cls=cls, - name=name, - values=values, - ) + @classmethod + def from_builtin( + cls, + enum_cls: EM, + name: Optional[str] = None, + description: Optional[str] = None, + ) -> "EnumFromBuiltin": + return EnumFromBuiltin(enum_cls, name or enum_cls.__name__, description) - return cls +class EnumFromBuiltin(BaseEnum, Generic[E]): + def __init__( + self, enum_cls: EM, name: str, description: Optional[str] = None + ): + super().__init__(name, [v for v in enum_cls.__members__], description) + self.enum_cls = enum_cls -def enum( - _cls: Optional[EnumType] = None, - *, - name: Optional[str] = None, -) -> Any: - def wrap(cls: EnumType) -> EnumType: - return _process_enum(cls, name) + def parse(self, value: Any) -> E: + try: + return self.enum_cls(value) + except ValueError: + raise TypeError( + "Enum '{}' can not represent value: {!r}".format( + self.name, value + ) + ) - if not _cls: - return wrap + def serialize(self, value: Any) -> str: + if not isinstance(value, self.enum_cls): + raise TypeError( + "Enum '{}' can not represent value: {!r}".format( + self.name, value + ) + ) - return wrap(_cls) + return value.name diff --git a/hiku/federation/directive.py b/hiku/federation/directive.py index dd3eb5c1..a401b902 100644 --- a/hiku/federation/directive.py +++ b/hiku/federation/directive.py @@ -13,7 +13,6 @@ get_fields, wrap_dataclass, ) -from hiku.enum import enum from hiku.utils.typing import builtin_to_introspection_type T = TypeVar("T", bound="FederationSchemaDirective") @@ -29,10 +28,7 @@ class LinkImport: ... -@enum(name="link__Purpose") -class LinkPurpose(Enum): - SECURITY = "SECURITY" - EXECUTION = "EXECUTION" +LinkPurpose = Enum("link__Purpose", ["SECURITY", "EXECUTION"]) @dataclass diff --git a/hiku/graph.py b/hiku/graph.py index b0b2c01f..0405946e 100644 --- a/hiku/graph.py +++ b/hiku/graph.py @@ -14,7 +14,10 @@ from collections import OrderedDict, defaultdict from typing import List +from hiku.enum import BaseEnum + from .types import ( + EnumRefMeta, InterfaceRef, InterfaceRefMeta, Optional, @@ -107,6 +110,7 @@ def __init__( self.type = type_ self.default = default self.description = description + self.enum_name = get_enum_name(type_) if type_ is not None else None def __repr__(self) -> str: return "{}({!r}, {!r}, ...)".format( @@ -216,6 +220,7 @@ def __init__( self.options = options or () self.description = description self.directives = directives or () + self.enum_name = get_enum_name(type_) if type_ is not None else None def __repr__(self) -> str: return "{}({!r}, {!r}, {!r})".format( @@ -270,6 +275,20 @@ def is_interface(type_: GenericMeta) -> bool: return isinstance(type_, InterfaceRefMeta) +def get_enum_name(type_: GenericMeta) -> t.Optional[str]: + if isinstance(type_, OptionalMeta): + if isinstance(type_.__type__, EnumRefMeta): + return type_.__type__.__type_name__ + return None + if isinstance(type_, SequenceMeta): + return get_enum_name(type_.__item_type__) + + if isinstance(type_, EnumRefMeta): + return type_.__type_name__ + + return None + + def collect_interfaces_types( items: t.List["Node"], interfaces: t.List["Interface"] ) -> t.Dict[str, t.List[str]]: @@ -307,7 +326,7 @@ def collect_interfaces_types( # (ctx, ids) -> [] t.Callable[[t.Any, List[LT]], SyncAsync[LR]], # (ctx, ids, opts) -> [] - t.Callable[[t.Any, List[LT], List], SyncAsync[LR]], + t.Callable[[t.Any, List[LT], t.Dict], SyncAsync[LR]], ] RootLinkOne = RootLinkT[LR] @@ -504,6 +523,7 @@ def __init__( # type: ignore[no-untyped-def] self.directives = directives or () self.is_union = is_union(type_) self.is_interface = is_interface(type_) + self.enum_name = get_enum_name(type_) def __repr__(self) -> str: return "{}({!r}, {!r}, {!r}, ...)".format( @@ -687,6 +707,7 @@ def __init__( directives: t.Optional[t.Sequence[t.Type[SchemaDirective]]] = None, unions: t.Optional[t.List[Union]] = None, interfaces: t.Optional[t.List[Interface]] = None, + enums: t.Optional[t.List[BaseEnum]] = None, ): """ :param items: list of nodes @@ -699,15 +720,23 @@ def __init__( if interfaces is None: interfaces = [] - GraphValidator.validate(items, unions, interfaces) + if enums is None: + enums = [] + + GraphValidator.validate(items, unions, interfaces, enums) self.items = GraphInit.init(items) self.unions = unions self.interfaces = interfaces self.interfaces_types = collect_interfaces_types(self.items, interfaces) + self.enums: t.List[BaseEnum] = enums self.data_types = data_types or {} self.__types__ = GraphTypes.get_types( - self.items, self.unions, self.interfaces, self.data_types + self.items, + self.unions, + self.interfaces, + self.enums, + self.data_types, ) self.directives: t.Tuple[t.Type[SchemaDirective], ...] = tuple( directives or () @@ -747,6 +776,10 @@ def unions_map(self) -> OrderedDict: def interfaces_map(self) -> OrderedDict: return OrderedDict((i.name, i) for i in self.interfaces) + @cached_property + def enums_map(self) -> "OrderedDict[str, BaseEnum]": + return OrderedDict((e.name, e) for e in self.enums) + def accept(self, visitor: "AbstractGraphVisitor") -> t.Any: return visitor.visit_graph(self) @@ -780,6 +813,10 @@ def visit_union(self, obj: Union) -> t.Any: def visit_interface(self, obj: Interface) -> t.Any: pass + @abstractmethod + def visit_enum(self, obj: BaseEnum) -> t.Any: + pass + @abstractmethod def visit_root(self, obj: Root) -> t.Any: pass @@ -799,6 +836,9 @@ def visit_union(self, obj: "Union") -> t.Any: def visit_interface(self, obj: "Interface") -> t.Any: pass + def visit_enum(self, obj: "BaseEnum") -> t.Any: + pass + def visit_option(self, obj: "Option") -> t.Any: pass @@ -876,6 +916,9 @@ def visit_interface(self, obj: Interface) -> Interface: description=obj.description, ) + def visit_enum(self, obj: BaseEnum) -> BaseEnum: + return obj + def visit_root(self, obj: Root) -> Root: return Root([self.visit(f) for f in obj.fields]) @@ -929,6 +972,7 @@ def _visit_graph( items: t.List[Node], unions: t.List[Union], interfaces: t.List[Interface], + enums: t.List[BaseEnum], data_types: t.Dict[str, t.Type[Record]], ) -> t.Dict[str, t.Type[Record]]: types = OrderedDict(data_types) @@ -945,6 +989,9 @@ def _visit_graph( for interface in interfaces: types[interface.name] = self.visit(interface) + for enum in enums: + types[enum.name] = self.visit(enum) + types["__root__"] = Record[ chain.from_iterable(r.__field_types__.items() for r in roots) ] @@ -956,13 +1003,14 @@ def get_types( items: t.List[Node], unions: t.List[Union], interfaces: t.List[Interface], + enums: t.List[BaseEnum], data_types: t.Dict[str, t.Type[Record]], ) -> t.Dict[str, t.Type[Record]]: - return cls()._visit_graph(items, unions, interfaces, data_types) + return cls()._visit_graph(items, unions, interfaces, enums, data_types) def visit_graph(self, obj: Graph) -> t.Dict[str, t.Type[Record]]: return self._visit_graph( - obj.items, obj.unions, obj.interfaces, obj.data_types + obj.items, obj.unions, obj.interfaces, obj.enums, obj.data_types ) def visit_node(self, obj: Node) -> t.Type[Record]: @@ -980,5 +1028,8 @@ def visit_field(self, obj: Field) -> t.Union[FieldType, AnyMeta]: def visit_union(self, obj: Union) -> Union: return obj + def visit_enum(self, obj: BaseEnum) -> BaseEnum: + return obj + def visit_interface(self, obj: Interface) -> Interface: return obj diff --git a/hiku/introspection/graphql.py b/hiku/introspection/graphql.py index 28f08d0c..dc9d459b 100644 --- a/hiku/introspection/graphql.py +++ b/hiku/introspection/graphql.py @@ -26,6 +26,7 @@ ) from ..graph import GraphVisitor, GraphTransformer from ..types import ( + EnumRefMeta, IDMeta, InterfaceRefMeta, ScalarMeta, @@ -52,6 +53,8 @@ cached_property, ) from .types import ( + ENUM, + EnumValueIdent, INTERFACE, SCALAR, NON_NULL, @@ -177,6 +180,14 @@ def visit_unionref(self, obj: UnionRefMeta) -> t.Any: def visit_interfaceref(self, obj: InterfaceRefMeta) -> t.Any: return NON_NULL(INTERFACE(obj.__type_name__, tuple())) + def visit_enumref(self, obj: EnumRefMeta) -> t.Any: + return NON_NULL( + ENUM( + obj.__type_name__, + tuple(), + ) + ) + def visit_string(self, obj: StringMeta) -> HashedNamedTuple: return NON_NULL(SCALAR("String")) @@ -193,25 +204,10 @@ def visit_boolean(self, obj: BooleanMeta) -> HashedNamedTuple: return NON_NULL(SCALAR("Boolean")) -def not_implemented(*args: t.Any, **kwargs: t.Any) -> t.NoReturn: - raise NotImplementedError(args, kwargs) - - def na_maybe(schema: SchemaInfo) -> NothingType: return Nothing -def na_many( - schema: SchemaInfo, - ids: t.Optional[t.List] = None, - options: t.Optional[t.Any] = None, -) -> t.List[t.List]: - if ids is None: - return [] - else: - return [[] for _ in ids] - - def schema_link(schema: SchemaInfo) -> None: return None @@ -234,6 +230,12 @@ def type_link( interface.name, tuple(OBJECT(type_name) for type_name in possible_types), ) + elif name in schema.query_graph.enums_map: + enum = schema.query_graph.enums_map[name] + return ENUM( + enum.name, + tuple(OBJECT(value.name) for value in enum.values), + ) else: return Nothing @@ -267,6 +269,12 @@ def root_schema_types(schema: SchemaInfo) -> t.Iterator[HashedNamedTuple]: ), ) + for enum in schema.query_graph.enums: + yield ENUM( + enum.name, + tuple(OBJECT(value.name) for value in enum.values), + ) + def root_schema_query_type(schema: SchemaInfo) -> HashedNamedTuple: return OBJECT(QUERY_ROOT_NAME) @@ -335,6 +343,15 @@ def type_info( ident.name ].description, } + elif isinstance(ident, ENUM): + info = { + "id": ident, + "kind": "ENUM", + "name": ident.name, + "description": schema.query_graph.enums_map[ + ident.name + ].description, + } else: raise TypeError(repr(ident)) yield [info.get(f.name) for f in fields] @@ -342,7 +359,7 @@ def type_info( @listify def type_fields_link( - schema: SchemaInfo, ids: t.List, options: t.List + schema: SchemaInfo, ids: t.List, options: t.Dict ) -> t.Iterator[t.List[HashedNamedTuple]]: for ident in ids: if isinstance(ident, OBJECT): @@ -400,6 +417,23 @@ def possible_types_type_link(schema: SchemaInfo, ids: t.List) -> t.Iterator: yield [] +@listify +def enum_values_type_link( + schema: SchemaInfo, ids: t.List, opts: t.Dict +) -> t.Iterator: + if ids is None: + yield [] + + for ident in ids: + if isinstance(ident, ENUM): + yield [ + EnumValueIdent(ident.name, value.name) + for value in ident.of_types + ] + else: + yield [] + + @listify def interfaces_type_link(schema: SchemaInfo, ids: t.List) -> t.Iterator: if ids is None: @@ -508,7 +542,11 @@ def input_value_info( if option.default is Nothing: default = None else: - default = json.dumps(option.default) + if option.enum_name is not None: + enum = schema.query_graph.enums_map[option.enum_name] + default = enum.serialize(option.default) + else: + default = json.dumps(option.default) info = { "id": ident, "name": option.name, @@ -590,6 +628,24 @@ def directive_args_link( return links +@listify +def enum_value_info( + schema: SchemaInfo, + fields: t.List[Field], + ids: t.List[EnumValueIdent], # type: ignore[valid-type] +) -> t.Iterator[t.List[Any]]: + for ident in ids: + enum = schema.query_graph.enums_map[ident.enum_name] # type: ignore[attr-defined] # noqa: E501 + value = enum.values_map[ident.value_name] # type: ignore[attr-defined] + data = { + "name": value.name, + "description": value.description, + "isDeprecated": bool(value.deprecation_reason), + "deprecationReason": value.deprecation_reason, + } + yield [data[f.name] for f in fields] + + GRAPH = Graph( [ Node( @@ -627,8 +683,7 @@ def directive_args_link( Link( "enumValues", Sequence[TypeRef["__EnumValue"]], - # TODO: add enums handling - na_many, + enum_values_type_link, requires="id", options=[ Option("includeDeprecated", Boolean, default=False) @@ -699,10 +754,10 @@ def directive_args_link( Node( "__EnumValue", [ - Field("name", String, not_implemented), - Field("description", String, not_implemented), - Field("isDeprecated", Boolean, not_implemented), - Field("deprecationReason", String, not_implemented), + Field("name", String, enum_value_info), + Field("description", String, enum_value_info), + Field("isDeprecated", Boolean, enum_value_info), + Field("deprecationReason", String, enum_value_info), ], ), Node( diff --git a/hiku/introspection/types.py b/hiku/introspection/types.py index 9f7e5249..14f8cffb 100644 --- a/hiku/introspection/types.py +++ b/hiku/introspection/types.py @@ -39,6 +39,7 @@ def __hash__(self: Any) -> int: LIST = _namedtuple("LIST", "of_type") NON_NULL = _namedtuple("NON_NULL", "of_type") ENUM = _namedtuple("ENUM", "name of_types") +EnumValueIdent = _namedtuple("EnumValueIdent", "enum_name value_name") FieldIdent = _namedtuple("FieldIdent", "node, name") FieldArgIdent = _namedtuple("FieldArgIdent", "node, field, name") diff --git a/hiku/types.py b/hiku/types.py index 2ebda611..942566a4 100644 --- a/hiku/types.py +++ b/hiku/types.py @@ -6,6 +6,7 @@ if t.TYPE_CHECKING: from hiku.graph import Union, Interface + from hiku.enum import BaseEnum class GenericMeta(type): @@ -297,7 +298,26 @@ class InterfaceRef(metaclass=InterfaceRefMeta): ... -RefMeta = (TypeRefMeta, UnionRefMeta, InterfaceRefMeta) +class EnumRefMeta(TypingMeta): + __type_name__: str + + def __cls_init__(cls, *args: str) -> None: + assert len(args) == 1, f"{cls.__name__} takes exactly one argument" + + cls.__type_name__ = args[0] + + def __cls_repr__(self) -> str: + return "{}[{!r}]".format(self.__name__, self.__type_name__) + + def accept(cls, visitor: "AbstractTypeVisitor") -> t.Any: + return visitor.visit_enumref(cls) + + +class EnumRef(metaclass=EnumRefMeta): + ... + + +RefMeta = (TypeRefMeta, UnionRefMeta, InterfaceRefMeta, EnumRefMeta) @t.overload @@ -378,6 +398,10 @@ def visit_unionref(self, obj: UnionRefMeta) -> t.Any: def visit_interfaceref(self, obj: InterfaceRefMeta) -> t.Any: pass + @abstractmethod + def visit_enumref(self, obj: EnumRefMeta) -> t.Any: + pass + class TypeVisitor(AbstractTypeVisitor): def visit_any(self, obj: AnyMeta) -> t.Any: @@ -407,6 +431,9 @@ def visit_typeref(self, obj: TypeRefMeta) -> t.Any: def visit_unionref(self, obj: UnionRefMeta) -> t.Any: pass + def visit_enumref(self, obj: EnumRefMeta) -> t.Any: + pass + def visit_optional(self, obj: OptionalMeta) -> t.Any: self.visit(obj.__type__) @@ -443,6 +470,13 @@ def get_type(types: Types, typ: UnionRefMeta) -> "Union": # type: ignore[misc] ... +@t.overload +def get_type( # type: ignore[misc] + types: Types, typ: EnumRefMeta +) -> "BaseEnum": + ... + + @t.overload def get_type( # type: ignore[misc] types: Types, typ: InterfaceRefMeta diff --git a/hiku/utils/typing.py b/hiku/utils/typing.py index 9dc2644d..e2d3687f 100644 --- a/hiku/utils/typing.py +++ b/hiku/utils/typing.py @@ -1,4 +1,6 @@ import sys + +from enum import EnumMeta from typing import Any, Type, Union from hiku.introspection import types as int_types @@ -57,10 +59,10 @@ def convert(typ_: Any) -> Any: # SCALAR return int_types.LIST(convert(typ_.__args__[0])) elif hasattr(typ_, "__scalar_info__"): return int_types.SCALAR(typ_.__scalar_info__.name) - elif hasattr(typ_, "__enum_info__"): + elif isinstance(typ_, EnumMeta): return int_types.ENUM( - typ_.__enum_info__.name, - [v.value for v in typ_.__enum_info__.values], + typ_.__name__, + [v for v in typ_.__members__], ) else: raise TypeError(f"Unknown type: {typ_}") diff --git a/hiku/validate/graph.py b/hiku/validate/graph.py index c43a44b9..088bddb2 100644 --- a/hiku/validate/graph.py +++ b/hiku/validate/graph.py @@ -5,6 +5,7 @@ from collections import Counter from ..directives import Deprecated +from ..enum import BaseEnum from ..graph import ( GraphVisitor, Interface, @@ -72,11 +73,13 @@ def validate( items: t.List[Node], unions: t.List[Union], interfaces: t.List[Interface], + enums: t.List[BaseEnum], ) -> None: validator = cls(items, unions, interfaces) validator.visit_graph_items(items) validator.visit_graph_unions(unions) validator.visit_graph_interfaces(interfaces) + validator.visit_graph_enums(enums) if validator.errors.list: raise GraphValidationError(validator.errors.list) @@ -273,6 +276,15 @@ def visit_interface(self, obj: "Interface") -> t.Any: ) return + def visit_enum(self, obj: BaseEnum) -> t.Any: + if not obj.name: + self.errors.report("Enum must have a name") + return + + if not obj.values: + self.errors.report("Enum must have at least one value") + return + def visit_root(self, obj: Root) -> None: self.visit_node(obj) @@ -317,3 +329,7 @@ def visit_graph_unions(self, unions: t.List[Union]) -> None: def visit_graph_interfaces(self, interfaces: t.List[Interface]) -> None: for interface in interfaces: self.visit(interface) + + def visit_graph_enums(self, enums: t.List[BaseEnum]) -> None: + for enum in enums: + self.visit(enum) diff --git a/hiku/validate/query.py b/hiku/validate/query.py index 341e90dd..5ae4d550 100644 --- a/hiku/validate/query.py +++ b/hiku/validate/query.py @@ -77,6 +77,7 @@ def _false(self, obj: t.Any) -> None: visit_callable = _false visit_unionref = _false visit_interfaceref = _false + visit_enumref = _false visit_scalar = _false def visit_optional(self, obj: OptionalMeta) -> t.Optional[t.OrderedDict]: diff --git a/tests/test_directives.py b/tests/test_directives.py index fcf5a840..b753a08d 100644 --- a/tests/test_directives.py +++ b/tests/test_directives.py @@ -1,19 +1,17 @@ from enum import Enum from typing import List -from hiku.directives import Directive, DirectiveField, SchemaDirectiveField, Location, SchemaDirective, directive, \ +from hiku.directives import ( + Directive, Location, SchemaDirective, directive, directive_field, schema_directive, schema_directive_field -from hiku.enum import enum +) from hiku.graph import Field, Graph, Link, Node, Root, apply from hiku.introspection.graphql import GraphQLIntrospection from hiku.introspection.types import ENUM, LIST, NON_NULL, SCALAR from hiku.types import Integer, TypeRef -@enum(name='Custom_Options') -class Options(Enum): - A = 'a' - B = 'b' +Options = Enum('Custom_Options', ['A', 'B']) def test_directive_has_info(): @@ -45,7 +43,7 @@ class Custom(Directive): assert custom.__directive_info__.args[1].name == 'options' assert custom.__directive_info__.args[1].field_name == 'options' - assert custom.__directive_info__.args[1].type_ident == NON_NULL(LIST(ENUM('Custom_Options', ['a', 'b']))) + assert custom.__directive_info__.args[1].type_ident == NON_NULL(LIST(ENUM('Custom_Options', ['A', 'B']))) def test_schema_directive_has_info(): diff --git a/tests/test_enum.py b/tests/test_enum.py new file mode 100644 index 00000000..33f4ca88 --- /dev/null +++ b/tests/test_enum.py @@ -0,0 +1,299 @@ +from enum import Enum as PyEnum + +import pytest + +from hiku.denormalize.graphql import DenormalizeGraphQL +from hiku.engine import Engine +from hiku.enum import Enum +from hiku.executors.sync import SyncExecutor +from hiku.graph import Field, Graph, Link, Node, Option, Root +from hiku.types import Integer, Optional, Sequence, TypeRef, EnumRef +from hiku.utils import listify +from hiku.readers.graphql import read +from hiku.validate.graph import GraphValidationError + + +def execute(graph, query): + engine = Engine(SyncExecutor()) + result = engine.execute(graph, query, {}) + return DenormalizeGraphQL(graph, result, "query").process(query) + + +class Status(PyEnum): + ACTIVE = 'ACTIVE' + DELETED = 'DELETED' + + +def test_validate_graph_enums(): + with pytest.raises(GraphValidationError) as err: + Graph([ + Node('User', [ + Field('id', Integer, lambda: None), + ]), + Root([ + Link('user', Optional[TypeRef['User']], lambda: None, requires=None), + ]), + ], enums=[ + Enum('', ['ACTIVE', 'DELETED']), + Enum('Status', []), + ]) + + assert err.value.errors == [ + 'Enum must have a name', + 'Enum must have at least one value', + ] + + +@pytest.mark.parametrize("enum,status", [ + (Enum.from_builtin(Status, 'UserStatus'), Status.ACTIVE), + (Enum('UserStatus', ['ACTIVE', 'DELETED']), 'ACTIVE'), +]) +def test_serialize_enum_field_correct(enum, status): + @listify + def resolve_user_fields(fields, ids): + def get_field(fname, id_): + if fname == 'id': + return id_ + elif fname == 'status': + return status + elif fname == 'statuses': + return [status] + elif fname == 'maybe_status': + return None + + for id_ in ids: + yield [get_field(f.name, id_) for f in fields] + + graph = Graph([ + Node('User', [ + Field('id', Integer, resolve_user_fields), + Field('status', EnumRef['UserStatus'], resolve_user_fields), + Field('statuses', Sequence[EnumRef['UserStatus']], resolve_user_fields), + Field('maybe_status', Optional[EnumRef['UserStatus']], resolve_user_fields), + ]), + Root([ + Link('user', Optional[TypeRef['User']], lambda: 1, requires=None), + ]), + ], enums=[enum]) + + query = """ + query GetUser { + user { + id + status + statuses + maybe_status + } + } + """ + result = execute(graph, read(query)) + assert result == { + 'user': { + 'id': 1, + 'status': 'ACTIVE', + 'statuses': ['ACTIVE'], + 'maybe_status': None, + } + } + + +@pytest.mark.parametrize("enum,status,field", [ + (Enum.from_builtin(Status, 'UserStatus'), 'ACTIVE', "status"), + (Enum.from_builtin(Status, 'UserStatus'), 'ACTIVE', "statuses"), + (Enum.from_builtin(Status, 'UserStatus'), 'ACTIVE', "maybe_status"), + (Enum('UserStatus', ['ACTIVE', 'DELETED']), Status.ACTIVE, "status"), + (Enum('UserStatus', ['ACTIVE', 'DELETED']), Status.ACTIVE, "statuses"), + (Enum('UserStatus', ['ACTIVE', 'DELETED']), Status.ACTIVE, "maybe_status"), +]) +def test_serialize_enum_field_incorrect(enum, status, field): + @listify + def resolve_user_fields(fields, ids): + def get_field(fname, id_): + if fname == 'id': + return id_ + elif fname == 'status': + return status + elif fname == 'statuses': + return [status] + elif fname == 'maybe_status': + return status + + for id_ in ids: + yield [get_field(f.name, id_) for f in fields] + + def link_user(): + return 1 + + graph = Graph([ + Node('User', [ + Field('id', Integer, resolve_user_fields), + Field('status', EnumRef['UserStatus'], resolve_user_fields), + Field('statuses', Sequence[EnumRef['UserStatus']], resolve_user_fields), + Field('maybe_status', Optional[EnumRef['UserStatus']], resolve_user_fields), + ]), + Root([ + Link('user', Optional[TypeRef['User']], link_user, requires=None), + ]), + ], enums=[enum]) + + query = """ + query GetUser { user { id %s } } + """ % field + + with pytest.raises(TypeError) as err: + execute(graph, read(query)) + + err.match( + "Enum 'UserStatus' can not represent value: {!r}".format(status) + ) + + +def test_root_field_enum(): + def get_statuses(fields): + return [[v for v in Status] for f in fields] + + graph = Graph([ + Root([ + Field('statuses', Sequence[EnumRef['Status']], get_statuses), + ]), + ], enums=[Enum.from_builtin(Status)]) + + query = """ + query GetStatuses { + statuses + } + """ + result = execute(graph, read(query)) + assert result == { + 'statuses': ['ACTIVE', 'DELETED'] + } + + +@pytest.mark.parametrize("enum, status", [ + (Enum.from_builtin(Status, 'UserStatus'), Status.DELETED), + (Enum('UserStatus', ['ACTIVE', 'DELETED']), 'DELETED'), +]) +def test_parse_enum_argument(enum, status): + def link_user(opt): + if opt['status'] == status: + return 1 + raise ValueError( + 'Unknown status: {}'.format(opt['status']) + ) + + @listify + def resolve_user_fields(fields, ids): + def get_field(fname, id_): + if fname == 'id': + return id_ + elif fname == 'status': + return status + + for id_ in ids: + yield [get_field(f.name, id_) for f in fields] + + graph = Graph([ + Node('User', [ + Field('id', Integer, resolve_user_fields), + Field('status', EnumRef['UserStatus'], resolve_user_fields), + ]), + Root([ + Link( + 'user', + Optional[TypeRef['User']], + link_user, + requires=None, + options=[ + Option('status', EnumRef['UserStatus'], default=Status.ACTIVE) + ] + ), + ]), + ], enums=[enum]) + + result = execute(graph, read("query GetUser { user(status: DELETED) { id status } }")) + assert result == { + 'user': { + 'id': 1, + 'status': 'DELETED', + } + } + + +@pytest.mark.parametrize("enum, status", [ + (Enum.from_builtin(Status, 'UserStatus'), Status.ACTIVE), + (Enum('UserStatus', ['ACTIVE', 'DELETED']), 'ACTIVE'), +]) +def test_parse_enum_argument_default_value(enum, status): + def link_user(opt): + if opt['status'] == status: + return 1 + raise ValueError( + 'Unknown status: {}'.format(opt['status']) + ) + + @listify + def resolve_user_fields(fields, ids): + def get_field(fname, id_): + if fname == 'id': + return id_ + elif fname == 'status': + return status + + for id_ in ids: + yield [get_field(f.name, id_) for f in fields] + + graph = Graph([ + Node('User', [ + Field('id', Integer, resolve_user_fields), + Field('status', EnumRef['UserStatus'], resolve_user_fields), + ]), + Root([ + Link( + 'user', + Optional[TypeRef['User']], + link_user, + requires=None, + options=[ + Option('status', EnumRef['UserStatus'], default=status) + ] + ), + ]), + ], enums=[enum]) + + result = execute(graph, read("query GetUser { user { id status } }")) + assert result == { + 'user': { + 'id': 1, + 'status': 'ACTIVE' + } + } + +@pytest.mark.parametrize("enum", [ + Enum.from_builtin(Status, 'UserStatus'), + Enum('UserStatus', ['ACTIVE', 'DELETED']), +]) +def test_parse_enum_invalid_argument(enum): + graph = Graph([ + Node('User', [ + Field('id', Integer, lambda: None), + ]), + Root([ + Link( + 'user', + Optional[TypeRef['User']], + lambda: None, + requires=None, + options=[ + Option('status', EnumRef['UserStatus']) + ] + ), + ]), + ], enums=[enum]) + + query = """ + query GetUser { user(status: INVALID) { id } } + """ + with pytest.raises(TypeError) as err: + execute(graph, read(query)) + + err.match("Enum 'UserStatus' can not represent value: 'INVALID'") diff --git a/tests/test_introspection_graphql.py b/tests/test_introspection_graphql.py index de5ccc6f..1b3a2f5b 100644 --- a/tests/test_introspection_graphql.py +++ b/tests/test_introspection_graphql.py @@ -1,11 +1,14 @@ +import enum + from typing import Dict, List from unittest.mock import ANY import pytest +from hiku.enum import Enum from hiku.directives import Deprecated, Location, SchemaDirective, schema_directive from hiku.graph import Graph, Interface, Root, Field, Node, Link, Union, apply, Option -from hiku.types import InterfaceRef, String, Integer, Sequence, TypeRef, Boolean, Float, Any, UnionRef +from hiku.types import EnumRef, InterfaceRef, String, Integer, Sequence, TypeRef, Boolean, Float, Any, UnionRef from hiku.types import Optional, Record from hiku.result import denormalize from hiku.engine import Engine @@ -51,6 +54,14 @@ def _interface(name): } +def _enum(name): + return { + 'kind': 'ENUM', + 'name': name, + 'ofType': None + } + + def _iobj(name): return {'kind': 'INPUT_OBJECT', 'name': name, 'ofType': None} @@ -75,6 +86,17 @@ def _field(name, type_, **kwargs): return data +def _enum_value(name, **kwargs): + data = { + 'deprecationReason': None, + 'description': None, + 'isDeprecated': False, + 'name': name, + } + data.update(kwargs) + return data + + def _type(name, kind, **kwargs): data = { 'description': None, @@ -161,13 +183,13 @@ def _ival(name, type_, **kwargs): ] -def introspect(query_graph, mutation_graph=None): +def execute(query_str, query_graph, mutation_graph=None): engine = Engine(SyncExecutor()) query_graph = apply(query_graph, [ GraphQLIntrospection(query_graph, mutation_graph), ]) - query = read(INTROSPECTION_QUERY) + query = read(query_str) errors = validate(query_graph, query) assert not errors @@ -175,6 +197,10 @@ def introspect(query_graph, mutation_graph=None): return denormalize(query_graph, norm_result) +def introspect(query_graph, mutation_graph=None): + return execute(INTROSPECTION_QUERY, query_graph, mutation_graph) + + def test_introspection_query(): @schema_directive( name='custom', @@ -467,3 +493,73 @@ def test_interfaces(): _field('duration', _non_null(_STR)), ]), ]) + + +def test_enum(): + class Status(enum.Enum): + ACTIVE = 'ACTIVE' + DELETED = 'DELETED' + + graph = Graph([ + Node('User', [ + Field('id', Integer, _noop), + Field('status', EnumRef['UserStatus'], _noop), + ]), + Root([ + Link( + 'user', + Optional[TypeRef['User']], + _noop, + requires=None, + options=[ + Option('status', EnumRef['UserStatus'], default=Status.ACTIVE) + ] + ), + ]), + ], enums=[Enum.from_builtin(Status, 'UserStatus')]) + + assert introspect(graph) == _schema([ + _type('User', 'OBJECT', fields=[ + _field('id', _non_null(_INT)), + _field('status', _non_null(_enum('UserStatus'))), + ]), + _type('Query', 'OBJECT', fields=[ + _field('user', _obj('User'), args=[ + _ival('status', _non_null(_enum('UserStatus')), defaultValue='ACTIVE'), + ]), + ]), + _type('UserStatus', 'ENUM', enumValues=[ + _enum_value('ACTIVE'), + _enum_value('DELETED'), + ]), + ]) + + +@pytest.mark.parametrize('enum_name, expected', [ + ('Status', {'kind': 'ENUM', "enumValues": [{"name": "ACTIVE"}, {"name": "DELETED"}]}), + ('XXX', None), +]) +def test_query_enum_as_single_type(enum_name, expected): + query = """ + query IntrospectionQuery { + __type(name: "%s") { + kind + enumValues { + name + } + } + } + """ % enum_name + + class Status(enum.Enum): + ACTIVE = 'ACTIVE' + DELETED = 'DELETED' + + graph = Graph([ + Root([ + Field('status', EnumRef['Status'], _noop), + ]), + ], enums=[Enum.from_builtin(Status)]) + + got = execute(query, graph) + assert got['__type'] == expected