diff --git a/Dockerfile b/Dockerfile index e9721cc4..5368f213 100644 --- a/Dockerfile +++ b/Dockerfile @@ -30,7 +30,7 @@ RUN pdm sync -G docs FROM base as tests -RUN pdm sync -G test +RUN pdm sync -G test -G dev RUN python3 -m pip install tox tox-pdm FROM base as examples diff --git a/hiku/denormalize/base.py b/hiku/denormalize/base.py index 1fb15cf3..87214ccf 100644 --- a/hiku/denormalize/base.py +++ b/hiku/denormalize/base.py @@ -1,7 +1,7 @@ import typing as t from collections import deque -from ..graph import Graph, Union +from ..graph import Graph, Interface, Union from ..query import ( QueryVisitor, Link, @@ -12,10 +12,9 @@ from ..types import ( Record, RecordMeta, - TypeRefMeta, + RefMeta, OptionalMeta, SequenceMeta, - UnionRefMeta, get_type, ) @@ -25,7 +24,7 @@ def __init__(self, graph: Graph, result: Proxy) -> None: self._types = graph.__types__ self._unions = graph.unions_map self._result = result - self._type: t.Deque[t.Union[t.Type[Record], Union]] = deque( + self._type: t.Deque[t.Union[t.Type[Record], Union, Interface]] = deque( [self._types["__root__"]] ) self._data = deque([result]) @@ -49,7 +48,7 @@ def visit_link(self, obj: Link) -> None: else: raise AssertionError(repr(self._type[-1])) - if isinstance(type_, (TypeRefMeta, UnionRefMeta)): + if isinstance(type_, RefMeta): self._type.append(get_type(self._types, type_)) self._res.append({}) self._data.append(self._data[-1][obj.result_key]) @@ -61,7 +60,7 @@ def visit_link(self, obj: Link) -> None: type_ref = type_.__item_type__ if isinstance(type_.__item_type__, OptionalMeta): type_ref = type_.__item_type__.__type__ - assert isinstance(type_ref, (TypeRefMeta, UnionRefMeta)) + assert isinstance(type_ref, RefMeta) self._type.append(get_type(self._types, type_ref)) items = [] for item in self._data[-1][obj.result_key]: @@ -76,7 +75,7 @@ def visit_link(self, obj: Link) -> None: if self._data[-1][obj.result_key] is None: self._res[-1][obj.result_key] = None else: - assert isinstance(type_.__type__, (TypeRefMeta, UnionRefMeta)) + assert isinstance(type_.__type__, RefMeta) self._type.append(get_type(self._types, type_.__type__)) self._res.append({}) self._data.append(self._data[-1][obj.result_key]) diff --git a/hiku/denormalize/graphql.py b/hiku/denormalize/graphql.py index 708f3983..6674d465 100644 --- a/hiku/denormalize/graphql.py +++ b/hiku/denormalize/graphql.py @@ -1,15 +1,14 @@ from collections import deque -from ..graph import Graph, Union +from ..graph import Graph, Interface, Union from ..query import Field, Link from ..result import Proxy from ..types import ( RecordMeta, - TypeRefMeta, + RefMeta, SequenceMeta, OptionalMeta, GenericMeta, - UnionRefMeta, ) from .base import Denormalize @@ -25,11 +24,11 @@ def __init__( def visit_field(self, obj: Field) -> None: if obj.name == "__typename": type_name = self._type_name[-1] - if isinstance(self._type[-1], Union): + if isinstance(self._type[-1], (Union, Interface)): type_name = self._data[-1].__ref__.node self._res[-1][obj.result_key] = type_name else: - if isinstance(self._type[-1], Union): + if isinstance(self._type[-1], (Union, Interface)): type_name = self._data[-1].__ref__.node if obj.name not in self._types[type_name].__field_types__: @@ -46,16 +45,16 @@ def visit_link(self, obj: Link) -> None: raise AssertionError(repr(self._type[-1])) type_ref: GenericMeta - if isinstance(type_, (TypeRefMeta, UnionRefMeta)): + if isinstance(type_, RefMeta): type_ref = type_ elif isinstance(type_, SequenceMeta): type_ref = type_.__item_type__ if isinstance(type_ref, OptionalMeta): type_ref = type_ref.__type__ - assert isinstance(type_ref, (TypeRefMeta, UnionRefMeta)), type_ref + assert isinstance(type_ref, RefMeta), type_ref elif isinstance(type_, OptionalMeta): type_ref = type_.__type__ - assert isinstance(type_ref, (TypeRefMeta, UnionRefMeta)), type_ref + assert isinstance(type_ref, RefMeta), type_ref else: raise AssertionError(repr(type_)) self._type_name.append(type_ref.__type_name__) diff --git a/hiku/engine.py b/hiku/engine.py index a75529ff..4b5136ed 100644 --- a/hiku/engine.py +++ b/hiku/engine.py @@ -32,6 +32,7 @@ ) from .compat import Concatenate, ParamSpec from .executors.base import SyncAsyncExecutor +from .interface import SplitInterfaceQueryByNodes from .query import ( Node as QueryNode, Field as QueryField, @@ -40,6 +41,7 @@ QueryVisitor, ) from .graph import ( + Interface, Link, Maybe, MaybeMany, @@ -63,7 +65,7 @@ TaskSet, SubmitRes, ) -from .union import SplitUnionByNodes +from .union import SplitUnionQueryByNodes from .utils import ImmutableDict if TYPE_CHECKING: @@ -103,24 +105,32 @@ def __init__(self, graph: Graph) -> None: @contextlib.contextmanager def enter_path(self, type_: Any) -> Iterator[None]: try: - self._path.append(type_) + if type_ is not None: + self._path.append(type_) yield finally: - self._path.pop() + if type_ is not None: + self._path.pop() def visit_node(self, obj: QueryNode) -> QueryNode: fields = [] is_union = isinstance(self._path[-1], GraphUnion) + is_interface = isinstance(self._path[-1], Interface) for f in obj.fields: if f.name == "__typename": fields.append(f) continue - enter_path = self.enter_path if is_union else contextlib.nullcontext - type_ = self._graph.nodes_map[f.parent_type] if is_union else None + type_ = None + + if is_union: + type_ = self._graph.nodes_map[f.parent_type] + elif is_interface: + if f.parent_type is not None: + type_ = self._graph.nodes_map[f.parent_type] - with enter_path(type_): # type: ignore[operator] + with self.enter_path(type_): fields.append(self.visit(f)) return obj.copy(fields=fields) @@ -138,6 +148,8 @@ def visit_link(self, obj: QueryLink) -> QueryLink: if isinstance(graph_obj, Link): if graph_obj.is_union: self._path.append(self._graph.unions_map[graph_obj.node]) + elif graph_obj.is_interface: + self._path.append(self._graph.interfaces_map[graph_obj.node]) else: self._path.append(self._graph.nodes_map[graph_obj.node]) try: @@ -358,7 +370,7 @@ def link_ref_maybe(graph_link: Link, ident: Any) -> Optional[Reference]: if ident is Nothing: return None else: - if graph_link.is_union: + if graph_link.is_union or graph_link.is_interface: return Reference(ident[1].__type_name__, ident[0]) return Reference(graph_link.node, ident) @@ -366,13 +378,13 @@ def link_ref_maybe(graph_link: Link, ident: Any) -> Optional[Reference]: def link_ref_one(graph_link: Link, ident: Any) -> Reference: assert ident is not Nothing - if graph_link.is_union: + if graph_link.is_union or graph_link.is_interface: return Reference(ident[1].__type_name__, ident[0]) return Reference(graph_link.node, ident) def link_ref_many(graph_link: Link, idents: List) -> List[Reference]: - if graph_link.is_union: + if graph_link.is_union or graph_link.is_interface: return [Reference(i[1].__type_name__, i[0]) for i in idents] return [Reference(graph_link.node, i) for i in idents] @@ -380,7 +392,7 @@ def link_ref_many(graph_link: Link, idents: List) -> List[Reference]: def link_ref_maybe_many( graph_link: Link, idents: List ) -> List[Optional[Reference]]: - if graph_link.is_union: + if graph_link.is_union or graph_link.is_interface: return [ Reference(i[1].__type_name__, i[0]) if i is not Nothing else None for i in idents @@ -718,7 +730,7 @@ def process_link( # FIXME: call track len(ids) - 1 times because first track was # already called by process_node for this link track_times = len(grouped_ids) - 1 - union_nodes = SplitUnionByNodes( + union_nodes = SplitUnionQueryByNodes( self._graph, self._graph.unions_map[graph_link.node] ).split(query_link.node) for type_name, type_ids in grouped_ids.items(): @@ -730,7 +742,25 @@ def process_link( ) for _ in range(track_times): self._track(path) + elif graph_link.is_interface and isinstance(to_ids, list): + grouped_ids = defaultdict(list) + for id_, type_ref in to_ids: + grouped_ids[type_ref.__type_name__].append(id_) + + track_times = len(grouped_ids) - 1 + interface_nodes = SplitInterfaceQueryByNodes(self._graph).split( + query_link.node + ) + for type_name, type_ids in grouped_ids.items(): + self.process_node( + path, + self._graph.nodes_map[type_name], + interface_nodes[type_name], + list(type_ids), + ) + for _ in range(track_times): + self._track(path) else: if graph_link.type_enum is MaybeMany: to_ids = [id_ for id_ in to_ids if id_ is not Nothing] diff --git a/hiku/federation/graph.py b/hiku/federation/graph.py index 0ed55643..2d4cef7c 100644 --- a/hiku/federation/graph.py +++ b/hiku/federation/graph.py @@ -10,6 +10,7 @@ Field, Graph as _Graph, GraphTransformer, + Interface, Link, Node, Option, @@ -130,6 +131,7 @@ def __init__( data_types: t.Optional[t.Dict[str, t.Type[Record]]] = None, directives: t.Optional[t.Sequence[t.Type[SchemaDirective]]] = None, unions: t.Optional[t.List[Union]] = None, + interfaces: t.Optional[t.List[Interface]] = None, is_async: bool = False, ): if unions is None: @@ -141,4 +143,4 @@ def __init__( items = GraphInit.init(items, is_async, bool(entity_types)) - super().__init__(items, data_types, directives, unions) + super().__init__(items, data_types, directives, unions, interfaces) diff --git a/hiku/federation/sdl.py b/hiku/federation/sdl.py index b5fc921f..0a4addd2 100644 --- a/hiku/federation/sdl.py +++ b/hiku/federation/sdl.py @@ -498,6 +498,7 @@ def skip(node: Node) -> bool: obj.data_types, obj.directives, obj.unions, + obj.interfaces, ) def visit_node(self, obj: Node) -> Node: @@ -509,6 +510,7 @@ def skip(field: t.Union[Field, Link]) -> bool: fields=[self.visit(f) for f in obj.fields if not skip(f)], description=obj.description, directives=obj.directives, + implements=obj.implements, ) diff --git a/hiku/graph.py b/hiku/graph.py index 8568ad02..b0b2c01f 100644 --- a/hiku/graph.py +++ b/hiku/graph.py @@ -11,16 +11,18 @@ from abc import ABC, abstractmethod from itertools import chain from functools import reduce -from collections import OrderedDict +from collections import OrderedDict, defaultdict from typing import List from .types import ( + InterfaceRef, + InterfaceRefMeta, Optional, OptionalMeta, + RefMeta, Sequence, SequenceMeta, TypeRef, - TypeRefMeta, Record, Any, GenericMeta, @@ -236,18 +238,16 @@ class AbstractLink(AbstractBase, ABC): def get_type_enum(type_: TypingMeta) -> t.Tuple[Const, str]: - if isinstance(type_, (TypeRefMeta, UnionRefMeta)): + if isinstance(type_, RefMeta): return One, type_.__type_name__ elif isinstance(type_, OptionalMeta): - if isinstance(type_.__type__, (TypeRefMeta, UnionRefMeta)): + if isinstance(type_.__type__, RefMeta): return Maybe, type_.__type__.__type_name__ elif isinstance(type_, SequenceMeta): - if isinstance(type_.__item_type__, (TypeRefMeta, UnionRefMeta)): + if isinstance(type_.__item_type__, RefMeta): return Many, type_.__item_type__.__type_name__ elif isinstance(type_.__item_type__, OptionalMeta): - if isinstance( - type_.__item_type__.__type__, (TypeRefMeta, UnionRefMeta) - ): + if isinstance(type_.__item_type__.__type__, RefMeta): return MaybeMany, type_.__item_type__.__type__.__type_name__ raise TypeError("Invalid type specified: {!r}".format(type_)) @@ -257,11 +257,35 @@ def is_union(type_: GenericMeta) -> bool: return isinstance(type_.__type__, UnionRefMeta) if isinstance(type_, SequenceMeta): return is_union(type_.__item_type__) - # return isinstance(type_.__item_type__, UnionRefMeta) return isinstance(type_, UnionRefMeta) +def is_interface(type_: GenericMeta) -> bool: + if isinstance(type_, OptionalMeta): + return isinstance(type_.__type__, InterfaceRefMeta) + if isinstance(type_, SequenceMeta): + return is_interface(type_.__item_type__) + + return isinstance(type_, InterfaceRefMeta) + + +def collect_interfaces_types( + items: t.List["Node"], interfaces: t.List["Interface"] +) -> t.Dict[str, t.List[str]]: + interfaces_types = defaultdict(list) + for item in items: + if item.name is not None and item.implements: + for impl in item.implements: + interfaces_types[impl].append(item.name) + + for i in interfaces: + if i.name not in interfaces_types: + interfaces_types[i.name] = [] + + return dict(interfaces_types) + + LT = t.TypeVar("LT", bound=t.Hashable) LR = t.TypeVar("LR", bound=t.Optional[t.Hashable]) @@ -408,7 +432,7 @@ def __init__( def __init__( self, name: str, - type_: t.Type[UnionRef], + type_: t.Type[t.Union[UnionRef, InterfaceRef]], func: LinkOneFunc, *, requires: t.Optional[t.Union[str, t.List[str]]], @@ -479,6 +503,7 @@ def __init__( # type: ignore[no-untyped-def] self.description = description self.directives = directives or () self.is_union = is_union(type_) + self.is_interface = is_interface(type_) def __repr__(self) -> str: return "{}({!r}, {!r}, {!r}, ...)".format( @@ -518,6 +543,31 @@ def accept(self, visitor: "AbstractGraphVisitor") -> t.Any: return visitor.visit_union(self) +class Interface(AbstractBase): + def __init__( + self, + name: str, + fields: t.List["Field"], + *, + description: t.Optional[str] = None, + ): + self.name = name + self.fields = fields + self.description = description + + def __repr__(self) -> str: + return "{}({!r}, {!r}, ...)".format( + self.__class__.__name__, self.name, self.fields + ) + + @cached_property + def fields_map(self) -> OrderedDict: + return OrderedDict((f.name, f) for f in self.fields) + + def accept(self, visitor: "AbstractGraphVisitor") -> t.Any: + return visitor.visit_interface(self) + + class Node(AbstractNode): """Collection of the fields and links, which describes some entity and relations with other entities @@ -542,16 +592,20 @@ def __init__( *, description: t.Optional[str] = None, directives: t.Optional[t.Sequence[SchemaDirective]] = None, + implements: t.Optional[t.Sequence[str]] = None, ): """ :param name: name of the node :param fields: list of fields and links :param description: description of the node + :param directives: list of directives for the node + :param implements: list of interfaces implemented by the node """ self.name = name self.fields = fields self.description = description self.directives: t.Tuple[SchemaDirective, ...] = tuple(directives or ()) + self.implements = tuple(implements or []) def __repr__(self) -> str: return "{}({!r}, {!r}, ...)".format( @@ -571,6 +625,7 @@ def copy(self) -> "Node": fields=self.fields[:], description=self.description, directives=self.directives, + implements=self.implements, ) @@ -631,6 +686,7 @@ def __init__( data_types: t.Optional[t.Dict[str, t.Type[Record]]] = None, directives: t.Optional[t.Sequence[t.Type[SchemaDirective]]] = None, unions: t.Optional[t.List[Union]] = None, + interfaces: t.Optional[t.List[Interface]] = None, ): """ :param items: list of nodes @@ -640,13 +696,18 @@ def __init__( if unions is None: unions = [] - GraphValidator.validate(items, unions) + if interfaces is None: + interfaces = [] + + GraphValidator.validate(items, unions, interfaces) self.items = GraphInit.init(items) self.unions = unions + self.interfaces = interfaces + self.interfaces_types = collect_interfaces_types(self.items, interfaces) self.data_types = data_types or {} self.__types__ = GraphTypes.get_types( - self.items, self.unions, self.data_types + self.items, self.unions, self.interfaces, self.data_types ) self.directives: t.Tuple[t.Type[SchemaDirective], ...] = tuple( directives or () @@ -682,6 +743,10 @@ def nodes_map(self) -> OrderedDict: def unions_map(self) -> OrderedDict: return OrderedDict((u.name, u) for u in self.unions) + @cached_property + def interfaces_map(self) -> OrderedDict: + return OrderedDict((i.name, i) for i in self.interfaces) + def accept(self, visitor: "AbstractGraphVisitor") -> t.Any: return visitor.visit_graph(self) @@ -711,6 +776,10 @@ def visit_node(self, obj: Node) -> t.Any: def visit_union(self, obj: Union) -> t.Any: pass + @abstractmethod + def visit_interface(self, obj: Interface) -> t.Any: + pass + @abstractmethod def visit_root(self, obj: Root) -> t.Any: pass @@ -727,6 +796,9 @@ def visit(self, obj: t.Any) -> t.Any: def visit_union(self, obj: "Union") -> t.Any: pass + def visit_interface(self, obj: "Interface") -> t.Any: + pass + def visit_option(self, obj: "Option") -> t.Any: pass @@ -787,6 +859,7 @@ def visit_node(self, obj: Node) -> Node: [self.visit(f) for f in obj.fields], description=obj.description, directives=obj.directives, + implements=obj.implements, ) def visit_union(self, obj: Union) -> Union: @@ -796,6 +869,13 @@ def visit_union(self, obj: Union) -> Union: description=obj.description, ) + def visit_interface(self, obj: Interface) -> Interface: + return Interface( + obj.name, + obj.fields, + description=obj.description, + ) + def visit_root(self, obj: Root) -> Root: return Root([self.visit(f) for f in obj.fields]) @@ -805,6 +885,7 @@ def visit_graph(self, obj: Graph) -> Graph: obj.data_types, obj.directives, obj.unions, + obj.interfaces, ) @@ -847,6 +928,7 @@ def _visit_graph( self, items: t.List[Node], unions: t.List[Union], + interfaces: t.List[Interface], data_types: t.Dict[str, t.Type[Record]], ) -> t.Dict[str, t.Type[Record]]: types = OrderedDict(data_types) @@ -860,6 +942,9 @@ def _visit_graph( for union in unions: types[union.name] = self.visit(union) + for interface in interfaces: + types[interface.name] = self.visit(interface) + types["__root__"] = Record[ chain.from_iterable(r.__field_types__.items() for r in roots) ] @@ -870,12 +955,15 @@ def get_types( cls, items: t.List[Node], unions: t.List[Union], + interfaces: t.List[Interface], data_types: t.Dict[str, t.Type[Record]], ) -> t.Dict[str, t.Type[Record]]: - return cls()._visit_graph(items, unions, data_types) + return cls()._visit_graph(items, unions, interfaces, data_types) def visit_graph(self, obj: Graph) -> t.Dict[str, t.Type[Record]]: - return self._visit_graph(obj.items, obj.unions, obj.data_types) + return self._visit_graph( + obj.items, obj.unions, obj.interfaces, obj.data_types + ) def visit_node(self, obj: Node) -> t.Type[Record]: return Record[[(f.name, self.visit(f)) for f in obj.fields]] @@ -891,3 +979,6 @@ def visit_field(self, obj: Field) -> t.Union[FieldType, AnyMeta]: def visit_union(self, obj: Union) -> Union: return obj + + def visit_interface(self, obj: Interface) -> Interface: + return obj diff --git a/hiku/interface.py b/hiku/interface.py new file mode 100644 index 00000000..eae50df7 --- /dev/null +++ b/hiku/interface.py @@ -0,0 +1,40 @@ +import typing as t +from itertools import chain + +from hiku.graph import Graph +from hiku.query import Node + + +class SplitInterfaceQueryByNodes: + """ + Split query node into query nodes by interface types with keys + as graph node names. + + Useful when you need to get query nodes for interface + + :return: dict with query nodes as values and graph node names as keys. + """ + + def __init__(self, graph: Graph) -> None: + self._graph = graph + + def split(self, obj: Node) -> t.Dict[str, Node]: + types = [ + self._graph.nodes_map[type_] + for type_ in set( + chain.from_iterable(self._graph.interfaces_types.values()) + ) + ] + + nodes = {} + for type_ in types: + fields = [] + for field in obj.fields: + if field.name in type_.fields_map and ( + field.parent_type == type_.name or field.parent_type is None + ): + fields.append(field) + + nodes[type_.name] = obj.copy(fields=fields) + + return nodes diff --git a/hiku/introspection/graphql.py b/hiku/introspection/graphql.py index 1af7d4bf..28f08d0c 100644 --- a/hiku/introspection/graphql.py +++ b/hiku/introspection/graphql.py @@ -27,6 +27,7 @@ from ..graph import GraphVisitor, GraphTransformer from ..types import ( IDMeta, + InterfaceRefMeta, ScalarMeta, TypeRef, String, @@ -51,6 +52,7 @@ cached_property, ) from .types import ( + INTERFACE, SCALAR, NON_NULL, LIST, @@ -172,6 +174,9 @@ def visit_typeref(self, obj: TypeRefMeta) -> HashedNamedTuple: def visit_unionref(self, obj: UnionRefMeta) -> t.Any: return NON_NULL(UNION(obj.__type_name__, tuple())) + def visit_interfaceref(self, obj: InterfaceRefMeta) -> t.Any: + return NON_NULL(INTERFACE(obj.__type_name__, tuple())) + def visit_string(self, obj: StringMeta) -> HashedNamedTuple: return NON_NULL(SCALAR("String")) @@ -222,6 +227,13 @@ def type_link( return UNION( union.name, tuple(OBJECT(type_name) for type_name in union.types) ) + elif name in schema.query_graph.interfaces_map: + interface = schema.query_graph.interfaces_map[name] + possible_types = schema.query_graph.interfaces_types[name] + return INTERFACE( + interface.name, + tuple(OBJECT(type_name) for type_name in possible_types), + ) else: return Nothing @@ -244,6 +256,17 @@ def root_schema_types(schema: SchemaInfo) -> t.Iterator[HashedNamedTuple]: union.name, tuple(OBJECT(type_name) for type_name in union.types) ) + for interface in schema.query_graph.interfaces: + yield INTERFACE( + interface.name, + tuple( + OBJECT(type_name) + for type_name in schema.query_graph.interfaces_types[ + interface.name + ] + ), + ) + def root_schema_query_type(schema: SchemaInfo) -> HashedNamedTuple: return OBJECT(QUERY_ROOT_NAME) @@ -303,6 +326,15 @@ def type_info( ident.name ].description, } + elif isinstance(ident, INTERFACE): + info = { + "id": ident, + "kind": "INTERFACE", + "name": ident.name, + "description": schema.query_graph.interfaces_map[ + ident.name + ].description, + } else: raise TypeError(repr(ident)) yield [info.get(f.name) for f in fields] @@ -334,6 +366,13 @@ def type_fields_link( "to define schema type".format(ident.name) ) yield field_idents + elif isinstance(ident, INTERFACE): + interface = schema.query_graph.interfaces_map[ident.name] + yield [ + FieldIdent(ident.name, f.name) + for f in interface.fields + if not schema.is_field_hidden(f) + ] else: yield [] @@ -355,12 +394,27 @@ def possible_types_type_link(schema: SchemaInfo, ids: t.List) -> t.Iterator: yield [] for ident in ids: - if isinstance(ident, UNION): + if isinstance(ident, (UNION, INTERFACE)): yield ident.possible_types else: yield [] +@listify +def interfaces_type_link(schema: SchemaInfo, ids: t.List) -> t.Iterator: + if ids is None: + yield [] + + for ident in ids: + if isinstance(ident, OBJECT) and ident.name in schema.nodes_map: + node = schema.nodes_map[ident.name] + yield [ + INTERFACE(interface, tuple()) for interface in node.implements + ] + else: + yield [] + + @listify def field_info( schema: SchemaInfo, fields: t.List[Field], ids: t.List @@ -401,6 +455,10 @@ def field_type_link( node = schema.nodes_map[ident.node] field = node.fields_map[ident.name] yield type_ident.visit(field.type or Any) + elif ident.node in schema.query_graph.interfaces_map: + interface = schema.query_graph.interfaces_map[ident.node] + field = interface.fields_map[ident.name] + yield type_ident.visit(field.type or Any) else: data_type = schema.data_types[ident.node] field_type = data_type.__field_types__[ident.name] @@ -555,7 +613,7 @@ def directive_args_link( Link( "interfaces", Sequence[TypeRef["__Type"]], - na_many, + interfaces_type_link, requires="id", ), # INTERFACE and UNION only @@ -891,6 +949,7 @@ def visit_graph(self, obj: Graph) -> Graph: data_types=obj.data_types, directives=obj.directives, unions=obj.unions, + interfaces=obj.interfaces, ) diff --git a/hiku/introspection/types.py b/hiku/introspection/types.py index 95fa52ee..9f7e5249 100644 --- a/hiku/introspection/types.py +++ b/hiku/introspection/types.py @@ -33,6 +33,7 @@ def __hash__(self: Any) -> int: SCALAR = _namedtuple("SCALAR", "name") OBJECT = _namedtuple("OBJECT", "name") UNION = _namedtuple("UNION", "name possible_types") +INTERFACE = _namedtuple("INTERFACE", "name possible_types") DIRECTIVE = _namedtuple("DIRECTIVE", "name") INPUT_OBJECT = _namedtuple("INPUT_OBJECT", "name") LIST = _namedtuple("LIST", "of_type") diff --git a/hiku/result.py b/hiku/result.py index 04b9b957..45138cf5 100644 --- a/hiku/result.py +++ b/hiku/result.py @@ -175,12 +175,6 @@ def _denormalize( graph, graph_obj.fields_map[f.name], result[f.result_key], f ) return r - # return { - # f.result_key: _denormalize( - # graph, graph_obj.fields_map[f.name], result[f.result_key], f - # ) - # for f in query_obj.fields - # } elif isinstance(query_obj, Field): return result diff --git a/hiku/types.py b/hiku/types.py index 3f6b096b..2ebda611 100644 --- a/hiku/types.py +++ b/hiku/types.py @@ -5,7 +5,7 @@ from typing import TypeVar if t.TYPE_CHECKING: - from hiku.graph import Union + from hiku.graph import Union, Interface class GenericMeta(type): @@ -278,6 +278,28 @@ class UnionRef(metaclass=UnionRefMeta): ... +class InterfaceRefMeta(TypingMeta): + __type_name__: str + + def __cls_init__(cls, *args: str) -> None: + assert len(args) == 1, f"{cls.__name__} takes exactly one argument" + + cls.__type_name__ = args[0] + + def __cls_repr__(self) -> str: + return "{}[{!r}]".format(self.__name__, self.__type_name__) + + def accept(cls, visitor: "AbstractTypeVisitor") -> t.Any: + return visitor.visit_interfaceref(cls) + + +class InterfaceRef(metaclass=InterfaceRefMeta): + ... + + +RefMeta = (TypeRefMeta, UnionRefMeta, InterfaceRefMeta) + + @t.overload def _maybe_typeref(typ: str) -> TypeRefMeta: ... @@ -352,6 +374,10 @@ def visit_callable(self, obj: CallableMeta) -> t.Any: def visit_unionref(self, obj: UnionRefMeta) -> t.Any: pass + @abstractmethod + def visit_interfaceref(self, obj: InterfaceRefMeta) -> t.Any: + pass + class TypeVisitor(AbstractTypeVisitor): def visit_any(self, obj: AnyMeta) -> t.Any: @@ -402,7 +428,7 @@ def visit_callable(self, obj: CallableMeta) -> t.Any: T = TypeVar("T", bound=GenericMeta) -Types = t.Mapping[str, t.Union[t.Type[Record], "Union"]] +Types = t.Mapping[str, t.Union[t.Type[Record], "Union", "Interface"]] @t.overload @@ -417,15 +443,20 @@ def get_type(types: Types, typ: UnionRefMeta) -> "Union": # type: ignore[misc] ... +@t.overload +def get_type( # type: ignore[misc] + types: Types, typ: InterfaceRefMeta +) -> "Interface": + ... + + @t.overload def get_type(types: Types, typ: T) -> T: ... def get_type(types: Types, typ: t.Any) -> t.Any: - if isinstance(typ, TypeRefMeta): - return types[typ.__type_name__] - if isinstance(typ, UnionRefMeta): + if isinstance(typ, RefMeta): return types[typ.__type_name__] else: return typ diff --git a/hiku/union.py b/hiku/union.py index 6b27fea9..bc7fd315 100644 --- a/hiku/union.py +++ b/hiku/union.py @@ -4,7 +4,7 @@ from hiku.query import Node -class SplitUnionByNodes: +class SplitUnionQueryByNodes: """ Split query node into query nodes by union types with keys as graph node names. diff --git a/hiku/utils/__init__.py b/hiku/utils/__init__.py index 15cd75fc..eb378c29 100644 --- a/hiku/utils/__init__.py +++ b/hiku/utils/__init__.py @@ -1,17 +1,23 @@ from functools import wraps from typing import ( + Any, Callable, + TYPE_CHECKING, TypeVar, List, Iterator, ) + from hiku.compat import ParamSpec from .immutable import ImmutableDict, to_immutable_dict from .cached_property import cached_property from .const import const, Const +if TYPE_CHECKING: + from hiku.query import Field + T = TypeVar("T") P = ParamSpec("P") @@ -25,6 +31,10 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> List[T]: return wrapper +def empty_field(fields: List["Field"], ids: Any) -> Any: + return [[None]] * len(ids) + + __all__ = [ "ImmutableDict", "to_immutable_dict", diff --git a/hiku/validate/graph.py b/hiku/validate/graph.py index 9d34b042..c43a44b9 100644 --- a/hiku/validate/graph.py +++ b/hiku/validate/graph.py @@ -7,6 +7,7 @@ from ..directives import Deprecated from ..graph import ( GraphVisitor, + Interface, Root, Field, Node, @@ -51,17 +52,31 @@ def visit_option(self, obj: Option) -> str: _link_accept_types = (AbstractOption,) _field_accept_types = (AbstractOption,) - def __init__(self, items: t.List[Node], unions: t.List[Union]) -> None: + def __init__( + self, + items: t.List[Node], + unions: t.List[Union], + interfaces: t.List[Interface], + ) -> None: self.items = items self.unions = unions + self.unions_map = {u.name: u for u in unions} + self.interfaces = interfaces + self.interfaces_map = {i.name: i for i in interfaces} self.errors = Errors() self._ctx: t.List = [] @classmethod - def validate(cls, items: t.List[Node], unions: t.List[Union]) -> None: - validator = cls(items, unions) + def validate( + cls, + items: t.List[Node], + unions: t.List[Union], + interfaces: t.List[Interface], + ) -> None: + validator = cls(items, unions, interfaces) validator.visit_graph_items(items) validator.visit_graph_unions(unions) + validator.visit_graph_interfaces(interfaces) if validator.errors.list: raise GraphValidationError(validator.errors.list) @@ -152,15 +167,15 @@ def visit_link(self, obj: Link) -> None: with self.push_ctx(obj): super(GraphValidator, self).visit_link(obj) - graph_nodes_map = {e.name for e in self.items if e.name is not None} - unions_map = {u.name: u for u in self.unions} - if obj.node not in graph_nodes_map: - if obj.node not in unions_map: - self.errors.report( - 'Link "{}" points to the missing node "{}"'.format( - self._format_path(obj), obj.node - ) + graph_nodes = {e.name for e in self.items if e.name is not None} + if obj.node not in ( + graph_nodes | self.unions_map.keys() | self.interfaces_map.keys() + ): + self.errors.report( + 'Link "{}" points to the missing node "{}"'.format( + self._format_path(obj), obj.node ) + ) if obj.requires is not None: requires = ( @@ -207,6 +222,15 @@ def visit_node(self, obj: Node) -> None: if sum((1 for d in obj.directives if isinstance(d, Deprecated))) > 0: self.errors.report("Deprecated directive can not be used in Node") + if obj.implements: + for i in obj.implements: + if i not in self.interfaces_map: + self.errors.report( + 'Node "{}" implements missing interface "{}"'.format( + node_name, i + ) + ) + def visit_union(self, obj: "Union") -> t.Any: if not obj.name: self.errors.report("Union must have a name") @@ -226,6 +250,29 @@ def visit_union(self, obj: "Union") -> t.Any: ) return + def visit_interface(self, obj: "Interface") -> t.Any: + if not obj.name: + self.errors.report("Interface must have a name") + + if not obj.fields: + self.errors.report( + "Interface '{}' must have at least one field".format(obj.name) + ) + + invalid = [ + type_ + for type_ in obj.fields + if not isinstance(type_, AbstractField) + ] + + if invalid: + self.errors.report( + "Interface '{}' fields must be of type 'Field', found '{}'".format( # noqa: E501 + obj.name, invalid + ) + ) + return + def visit_root(self, obj: Root) -> None: self.visit_node(obj) @@ -266,3 +313,7 @@ def visit_graph_items(self, items: t.List[Node]) -> None: def visit_graph_unions(self, unions: t.List[Union]) -> None: for union in unions: self.visit(union) + + def visit_graph_interfaces(self, interfaces: t.List[Interface]) -> None: + for interface in interfaces: + self.visit(interface) diff --git a/hiku/validate/query.py b/hiku/validate/query.py index df15acde..b5d6cd00 100644 --- a/hiku/validate/query.py +++ b/hiku/validate/query.py @@ -28,6 +28,7 @@ QueryVisitor, ) from hiku.graph import ( + Interface, Node, Field, Link, @@ -75,6 +76,7 @@ def _false(self, obj: t.Any) -> None: visit_mapping = _false visit_callable = _false visit_unionref = _false + visit_interfaceref = _false visit_scalar = _false def visit_optional(self, obj: OptionalMeta) -> t.Optional[t.OrderedDict]: @@ -407,6 +409,8 @@ def visit_link(self, obj: QueryLink) -> None: elif isinstance(graph_obj, Link): if graph_obj.is_union: linked_node = self.graph.unions_map[graph_obj.node] + elif graph_obj.is_interface: + linked_node = self.graph.interfaces_map[graph_obj.node] else: linked_node = self.graph.nodes_map[graph_obj.node] @@ -434,6 +438,7 @@ def visit_node(self, obj: QueryNode) -> None: fields: t.Dict = {} is_union_link = isinstance(self.path[-1], Union) + is_interface_link = isinstance(self.path[-1], Interface) for field in obj.fields: if field.name == "__typename": @@ -455,6 +460,46 @@ def visit_node(self, obj: QueryNode) -> None: continue self.path.append(self.graph.nodes_map[field.parent_type]) + elif is_interface_link: + interface = self.path[-1] + if field.parent_type is not None: + self.path.append(self.graph.nodes_map[field.parent_type]) + else: + interface_types = self.graph.interfaces_types[ + interface.name + ] + if not interface_types: + self.errors.report( + "Can not query field '{0}' on interface '{1}'. " + "Interface '{1}' is not implemented by any type. " + "Add at least one type implementing this interface.".format( # noqa: E501 + field.name, interface.name + ) + ) + continue + + if field.name not in interface.fields_map: + implementation = None + for impl in interface_types: + if ( + field.name + in self.graph.nodes_map[impl].fields_map + ): + implementation = impl + break + + self.errors.report( + "Can not query field '{}' on type '{}'. " + "Did you mean to use an inline fragment on '{}'?".format( # noqa: E501 + field.name, interface.name, implementation + ) + ) + continue + + # take the first implementing node + # it does not matter which one, because all of them have + # the same field from interface + self.path.append(self.graph.nodes_map[interface_types[0]]) seen = fields.get(field.result_key) if seen is not None: @@ -470,7 +515,7 @@ def visit_node(self, obj: QueryNode) -> None: fields[field.result_key] = field self.visit(field) - if is_union_link: + if is_union_link or is_interface_link: self.path.pop() diff --git a/tests/test_interface.py b/tests/test_interface.py new file mode 100644 index 00000000..0493d1ec --- /dev/null +++ b/tests/test_interface.py @@ -0,0 +1,523 @@ +import pytest + +from hiku.denormalize.graphql import DenormalizeGraphQL +from hiku.engine import Engine +from hiku.interface import SplitInterfaceQueryByNodes +from hiku.executors.sync import SyncExecutor +from hiku.graph import Field, Graph, Interface, Link, Node, Option, Root +from hiku.types import Integer, InterfaceRef, Optional, Sequence, String, TypeRef +from hiku.utils import empty_field, listify +from hiku.readers.graphql import read +from hiku.validate.graph import GraphValidationError +from hiku.validate.query import validate + + +def execute(graph, query): + engine = Engine(SyncExecutor()) + result = engine.execute(graph, query, {}) + return DenormalizeGraphQL(graph, result, "query").process(query) + + +@listify +def resolve_user_fields(fields, ids): + def get_field(fname, id_): + if fname == 'id': + return id_ + + for id_ in ids: + yield [get_field(f.name, id_) for f in fields] + + +@listify +def resolve_audio_fields(fields, ids): + def get_field(fname, id_): + if fname == 'id': + return id_ + if fname == 'duration': + return f'{id_}s' + if fname == 'album': + return f'album#{id_}' + + for id_ in ids: + yield [get_field(f.name, id_) for f in fields] + + +@listify +def resolve_video_fields(fields, ids): + def get_field(fname, id_): + if fname == 'id': + return id_ + if fname == 'duration': + return f'{id_}s' + if fname == 'thumbnailUrl': + return f'/video/{id_}' + + for id_ in ids: + yield [get_field(f.name, id_) for f in fields] + + +def link_user_media(): + return [ + (1, TypeRef['Audio']), + (2, TypeRef['Video']), + ] + + +def link_user(): + return 111 + + +def search_media(opts): + if opts['text'] != 'foo': + return [] + return [ + (1, TypeRef['Audio']), + (2, TypeRef['Video']), + (3, TypeRef['Audio']), + (4, TypeRef['Video']), + ] + + +def get_media(): + return 1, TypeRef['Audio'] + + +def maybe_get_media(): + return 2, TypeRef['Video'] + + +GRAPH = Graph([ + Node('Audio', [ + Field('id', Integer, resolve_audio_fields), + Field('duration', String, resolve_audio_fields), + Field('album', String, resolve_audio_fields), + ], implements=['Media']), + Node('Video', [ + Field('id', Integer, resolve_video_fields), + Field('duration', String, resolve_video_fields), + Field('thumbnailUrl', String, resolve_video_fields, options=[ + Option('size', Integer), + ]), + ], implements=['Media']), + Node('User', [ + Field('id', Integer, resolve_user_fields), + Link('media', Sequence[InterfaceRef['Media']], link_user_media, requires=None), + ]), + Root([ + Link( + 'searchMedia', + Sequence[InterfaceRef['Media']], + search_media, + options=[ + Option('text', String), + ], + requires=None + ), + Link('media', InterfaceRef['Media'], get_media, requires=None), + Link('maybeMedia', Optional[InterfaceRef['Media']], maybe_get_media, requires=None), + Link('user', Optional[TypeRef['User']], link_user, requires=None), + ]), +], interfaces=[ + Interface('Media', [ + Field('id', Integer, empty_field), + Field('duration', String, empty_field), + ]), +]) + + +def test_validate_graph_with_interface(): + with pytest.raises(GraphValidationError) as err: + Graph([ + Node('Audio', [ + Field('id', Integer, resolve_audio_fields), + Field('duration', String, resolve_audio_fields), + Field('album', String, resolve_audio_fields), + ], implements=['WrongInterface']), + Node('Video', [ + Field('id', Integer, resolve_video_fields), + Field('duration', String, resolve_video_fields), + Field('thumbnailUrl', String, resolve_video_fields), + ], implements=['Media']), + Root([ + Link( + 'searchMedia', + Sequence[InterfaceRef['Media']], + search_media, + options=[ + Option('text', String), + ], + requires=None + ), + ]), + ], interfaces=[ + Interface('Media', []), + Interface('', [Field('id', Integer, empty_field)]), + Interface('Invalid', ['WrongType']), # type: ignore + ]) + + assert err.value.errors == [ + 'Node "Audio" implements missing interface "WrongInterface"', + "Interface 'Media' must have at least one field", + 'Interface must have a name', + "Interface 'Invalid' fields must be of type 'Field', found '['WrongType']'", + ] + + +def test_option_not_provided_for_field(): + query = """ + query GetMedia { + media { + __typename + id + duration + ... on Audio { + album + } + ... on Video { + thumbnailUrl + } + } + } + """ + with pytest.raises(TypeError) as err: + execute(GRAPH, read(query)) + err.match("Required option \"size\" for Field('thumbnailUrl'") + + +def test_root_link_to_interface_list(): + query = """ + query SearchMedia($text: String) { + searchMedia(text: $text) { + __typename + id + duration + ... on Audio { + album + } + ... on Video { + thumbnailUrl(size: 100) + } + } + } + """ + """ + [(f.name, f.parent_type) for f in node.fields] + where node is searchMedia query node + + [('__typename', None), ('id', None), ('duration', None), ('album', 'Audio'), ('thumbnailUrl', 'Video')] + + """ + result = execute(GRAPH, read(query, {'text': 'foo'})) + assert result == { + 'searchMedia': [ + {'__typename': 'Audio', 'id': 1, 'duration': '1s', 'album': 'album#1'}, + {'__typename': 'Video', 'id': 2, 'duration': '2s', 'thumbnailUrl': '/video/2'}, + {'__typename': 'Audio', 'id': 3, 'duration': '3s', 'album': 'album#3'}, + {'__typename': 'Video', 'id': 4, 'duration': '4s', 'thumbnailUrl': '/video/4'}, + ] + } + + +def test_root_link_to_interface_one(): + query = """ + query GetMedia { + media { + __typename + id + duration + ... on Audio { + album + } + ... on Video { + thumbnailUrl(size: 100) + } + } + } + """ + result = execute(GRAPH, read(query)) + assert result == { + 'media': {'__typename': 'Audio', 'id': 1, 'duration': '1s', 'album': 'album#1'}, + } + + +def test_root_link_to_interface_optional(): + query = """ + query MaybeGetMedia { + maybeMedia { + __typename + id + duration + ... on Audio { + album + } + ... on Video { + thumbnailUrl(size: 100) + } + } + } + """ + result = execute(GRAPH, read(query)) + assert result == { + 'maybeMedia': {'__typename': 'Video', 'id': 2, 'duration': '2s', 'thumbnailUrl': '/video/2'}, + } + + +def test_non_root_link_to_interface_list(): + query = """ + query GetUserMedia { + user { + id + media { + __typename + id + duration + ... on Audio { + album + } + ... on Video { + thumbnailUrl(size: 100) + } + } + } + } + """ + result = execute(GRAPH, read(query)) + assert result == { + 'user': { + 'id': 111, + 'media': [ + {'__typename': 'Audio', 'id': 1, 'duration': '1s', 'album': 'album#1'}, + {'__typename': 'Video', 'id': 2, 'duration': '2s', 'thumbnailUrl': '/video/2'}, + ] + } + } + + +def test_query_with_inline_fragment_and_fragment_spread(): + query = """ + query GetMedia { + media { + __typename + id + duration + ...AudioFragment + ... on Video { + thumbnailUrl(size: 100) + } + } + } + + fragment AudioFragment on Audio { + album + } + """ + result = execute(GRAPH, read(query)) + assert result == { + 'media': {'__typename': 'Audio', 'id': 1, 'duration': '1s', 'album': 'album#1'}, + } + + +def test_query_can_be_without_shared_fields(): + query = """ + query GetMedia { + media { + __typename + ... on Audio { + id + duration + album + } + ... on Video { + id + duration + thumbnailUrl(size: 100) + } + } + } + """ + + result = execute(GRAPH, read(query)) + assert result == { + 'media': {'__typename': 'Audio', 'id': 1, 'duration': '1s', 'album': 'album#1'}, + } + + +def test_validate_interface_has_no_implementations(): + graph = Graph([ + Root([ + Link( + 'media', + InterfaceRef['Media'], + lambda: None, + requires=None + ), + ]), + ], interfaces=[ + Interface('Media', [ + Field('id', Integer, empty_field), + Field('duration', String, empty_field), + ]), + ]) + + query = """ + query GetMedia { + media { + id + duration + } + } + """ + + errors = validate(graph, read(query)) + + assert errors == [ + "Can not query field 'id' on interface 'Media'. " + "Interface 'Media' is not implemented by any type. " + "Add at least one type implementing this interface.", + + "Can not query field 'duration' on interface 'Media'. Interface 'Media' " + "is not implemented by any type. Add at least one type implementing this " + "interface.", + ] + + +def test_validate_query_implementation_node_field_without_inline_fragment(): + query = """ + query GetMedia { + media { + id + duration + album + } + } + """ + + errors = validate(GRAPH, read(query)) + + assert errors == [ + "Can not query field 'album' on type 'Media'. " + "Did you mean to use an inline fragment on 'Audio'?" + ] + + +def test_validate_interface_type_has_no_such_field(): + query = """ + query SearchMedia($text: String) { + searchMedia(text: $text) { + ... on Audio { + invalid_field + duration + } + ... on Video { + thumbnailUrl(size: 100) + } + } + } + """ + + errors = validate(GRAPH, read(query, {'text': 'foo'})) + + assert errors == [ + 'Field "invalid_field" is not implemented in the "Audio" node', + ] + + +def test_validate_interface_type_field_has_no_such_option(): + query = """ + query SearchMedia($text: String) { + searchMedia(text: $text) { + ... on Audio { + duration(size: 100) + } + ... on Video { + thumbnailUrl(size: 100) + } + } + } + """ + + errors = validate(GRAPH, read(query, {'text': 'foo'})) + + assert errors == [ + 'Unknown options for "Audio.duration": size', + ] + + +def test_split_interface_query_into_nodes(): + query_raw = """ + query SearchMedia($text: String) { + searchMedia(text: $text) { + __typename + id + duration + ... on Audio { + album + } + ... on Video { + thumbnailUrl(size: 100) + } + } + } + """ + query = read(query_raw, {'text': 'foo'}) + + nodes = SplitInterfaceQueryByNodes(GRAPH).split( + query.fields_map['searchMedia'].node + ) + + assert len(nodes) == 2 + + assert 'Audio' in nodes + assert 'Video' in nodes + + assert nodes['Audio'].fields_map.keys() == { + 'id', + 'duration', + ('Audio', 'album') + } + assert nodes['Video'].fields_map.keys() == { + 'id', + 'duration', + ('Video', 'thumbnailUrl') + } + + +def test_split_interface_query_into_nodes_if_no_shared_fields_in_query(): + query_raw = """ + query SearchMedia($text: String) { + searchMedia(text: $text) { + __typename + ... on Audio { + id + duration + album + } + ... on Video { + id + duration + thumbnailUrl(size: 100) + } + } + } + """ + query = read(query_raw, {'text': 'foo'}) + + nodes = SplitInterfaceQueryByNodes(GRAPH).split( + query.fields_map['searchMedia'].node + ) + + assert len(nodes) == 2 + + assert 'Audio' in nodes + assert 'Video' in nodes + + assert nodes['Audio'].fields_map.keys() == { + ('Audio', 'id'), + ('Audio', 'duration'), + ('Audio', 'album') + } + assert nodes['Video'].fields_map.keys() == { + ('Video', 'id'), + ('Video', 'duration'), + ('Video', 'thumbnailUrl') + } diff --git a/tests/test_introspection_graphql.py b/tests/test_introspection_graphql.py index 096559bc..de5ccc6f 100644 --- a/tests/test_introspection_graphql.py +++ b/tests/test_introspection_graphql.py @@ -4,8 +4,8 @@ import pytest from hiku.directives import Deprecated, Location, SchemaDirective, schema_directive -from hiku.graph import Graph, Root, Field, Node, Link, Union, apply, Option -from hiku.types import String, Integer, Sequence, TypeRef, Boolean, Float, Any, UnionRef +from hiku.graph import Graph, Interface, Root, Field, Node, Link, Union, apply, Option +from hiku.types import InterfaceRef, String, Integer, Sequence, TypeRef, Boolean, Float, Any, UnionRef from hiku.types import Optional, Record from hiku.result import denormalize from hiku.engine import Engine @@ -43,6 +43,14 @@ def _union(name): } +def _interface(name): + return { + 'kind': 'INTERFACE', + 'name': name, + 'ofType': None + } + + def _iobj(name): return {'kind': 'INPUT_OBJECT', 'name': name, 'ofType': None} @@ -409,3 +417,53 @@ def test_unions(): _obj('Video'), ]), ]) + + +def test_interfaces(): + graph = Graph([ + Node('Audio', [ + Field('id', Integer, _noop), + Field('duration', String, _noop), + Field('album', String, _noop), + ], implements=['Media']), + Node('Video', [ + Field('id', Integer, _noop), + Field('duration', String, _noop), + Field('thumbnailUrl', String, _noop), + ], implements=['Media']), + Root([ + Link('media', InterfaceRef['Media'], _noop, requires=None), + Link('mediaList', Sequence[InterfaceRef['Media']], _noop, requires=None), + Link('maybeMedia', Optional[InterfaceRef['Media']], _noop, requires=None), + ]), + ], interfaces=[ + Interface('Media', [ + Field('id', Integer, _noop), + Field('duration', String, _noop), + ]), + ]) + + assert introspect(graph) == _schema([ + _type('Audio', 'OBJECT', fields=[ + _field('id', _non_null(_INT)), + _field('duration', _non_null(_STR)), + _field('album', _non_null(_STR)), + ], interfaces=[_interface('Media')]), + _type('Video', 'OBJECT', fields=[ + _field('id', _non_null(_INT)), + _field('duration', _non_null(_STR)), + _field('thumbnailUrl', _non_null(_STR)), + ], interfaces=[_interface('Media')]), + _type('Query', 'OBJECT', fields=[ + _field('media', _non_null(_interface('Media'))), + _field('mediaList', _seq_of(_interface('Media'))), + _field('maybeMedia', _interface('Media')), + ]), + _type('Media', 'INTERFACE', possibleTypes=[ + _obj('Audio'), + _obj('Video'), + ], fields=[ + _field('id', _non_null(_INT)), + _field('duration', _non_null(_STR)), + ]), + ]) diff --git a/tests/test_union.py b/tests/test_union.py index 8b384c3e..d6595b40 100644 --- a/tests/test_union.py +++ b/tests/test_union.py @@ -1,7 +1,8 @@ import pytest from hiku.denormalize.graphql import DenormalizeGraphQL -from hiku.engine import Engine, SplitUnionByNodes +from hiku.engine import Engine +from hiku.union import SplitUnionQueryByNodes from hiku.executors.sync import SyncExecutor from hiku.graph import Field, Graph, Link, Node, Option, Root, Union from hiku.types import Integer, Optional, Sequence, String, TypeRef, UnionRef @@ -380,7 +381,7 @@ def test_split_union_into_nodes(): """ query = read(query_raw, {'text': 'foo'}) - nodes = SplitUnionByNodes(GRAPH, GRAPH.unions_map['Media']).split( + nodes = SplitUnionQueryByNodes(GRAPH, GRAPH.unions_map['Media']).split( query.fields_map['searchMedia'].node )