diff --git a/docs/code_examples/aiohttp_websockets_async.py b/docs/code_examples/aiohttp_websockets_async.py new file mode 100644 index 00000000..69520053 --- /dev/null +++ b/docs/code_examples/aiohttp_websockets_async.py @@ -0,0 +1,50 @@ +import asyncio +import logging + +from gql import Client, gql +from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + +logging.basicConfig(level=logging.INFO) + + +async def main(): + + transport = AIOHTTPWebsocketsTransport( + url="wss://countries.trevorblades.com/graphql" + ) + + # Using `async with` on the client will start a connection on the transport + # and provide a `session` variable to execute queries on this connection + async with Client( + transport=transport, + ) as session: + + # Execute single query + query = gql( + """ + query getContinents { + continents { + code + name + } + } + """ + ) + result = await session.execute(query) + print(result) + + # Request subscription + subscription = gql( + """ + subscription { + somethingChanged { + id + } + } + """ + ) + async for result in session.subscribe(subscription): + print(result) + + +asyncio.run(main()) diff --git a/docs/intro.rst b/docs/intro.rst index 8f59ed16..21de16bd 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -36,25 +36,27 @@ which needs the :code:`aiohttp` dependency, then you can install GQL with:: The corresponding between extra dependencies required and the GQL classes is: -+---------------------+----------------------------------------------------------------+ -| Extra dependencies | Classes | -+=====================+================================================================+ -| aiohttp | :ref:`AIOHTTPTransport ` | -+---------------------+----------------------------------------------------------------+ -| websockets | :ref:`WebsocketsTransport ` | -| | | -| | :ref:`PhoenixChannelWebsocketsTransport ` | -| | | -| | :ref:`AppSyncWebsocketsTransport ` | -+---------------------+----------------------------------------------------------------+ -| requests | :ref:`RequestsHTTPTransport ` | -+---------------------+----------------------------------------------------------------+ -| httpx | :ref:`HTTPTXTransport ` | -| | | -| | :ref:`HTTPXAsyncTransport ` | -+---------------------+----------------------------------------------------------------+ -| botocore | :ref:`AppSyncIAMAuthentication ` | -+---------------------+----------------------------------------------------------------+ ++---------------------+------------------------------------------------------------------+ +| Extra dependencies | Classes | ++=====================+==================================================================+ +| aiohttp | :ref:`AIOHTTPTransport ` | +| | | +| | :ref:`AIOHTTPWebsocketsTransport ` | ++---------------------+------------------------------------------------------------------+ +| websockets | :ref:`WebsocketsTransport ` | +| | | +| | :ref:`PhoenixChannelWebsocketsTransport ` | +| | | +| | :ref:`AppSyncWebsocketsTransport ` | ++---------------------+------------------------------------------------------------------+ +| requests | :ref:`RequestsHTTPTransport ` | ++---------------------+------------------------------------------------------------------+ +| httpx | :ref:`HTTPTXTransport ` | +| | | +| | :ref:`HTTPXAsyncTransport ` | ++---------------------+------------------------------------------------------------------+ +| botocore | :ref:`AppSyncIAMAuthentication ` | ++---------------------+------------------------------------------------------------------+ .. note:: diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index 5f9edebe..b7c13c7c 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -21,6 +21,7 @@ Sub-Packages client transport transport_aiohttp + transport_aiohttp_websockets transport_appsync_auth transport_appsync_websockets transport_exceptions diff --git a/docs/modules/transport_aiohttp_websockets.rst b/docs/modules/transport_aiohttp_websockets.rst new file mode 100644 index 00000000..efa7e1bc --- /dev/null +++ b/docs/modules/transport_aiohttp_websockets.rst @@ -0,0 +1,7 @@ +gql.transport.aiohttp_websockets +================================ + +.. currentmodule:: gql.transport.aiohttp_websockets + +.. automodule:: gql.transport.aiohttp_websockets + :member-order: bysource diff --git a/docs/transports/aiohttp.rst b/docs/transports/aiohttp.rst index 68b3eb99..b852108b 100644 --- a/docs/transports/aiohttp.rst +++ b/docs/transports/aiohttp.rst @@ -10,7 +10,9 @@ Reference: :class:`gql.transport.aiohttp.AIOHTTPTransport` .. note:: GraphQL subscriptions are not supported on the HTTP transport. - For subscriptions you should use the :ref:`websockets transport `. + For subscriptions you should use a websockets transport: + :ref:`WebsocketsTransport ` or + :ref:`AIOHTTPWebsocketsTransport `. .. literalinclude:: ../code_examples/aiohttp_async.py diff --git a/docs/transports/aiohttp_websockets.rst b/docs/transports/aiohttp_websockets.rst new file mode 100644 index 00000000..def3372e --- /dev/null +++ b/docs/transports/aiohttp_websockets.rst @@ -0,0 +1,31 @@ +.. _aiohttp_websockets_transport: + +AIOHTTPWebsocketsTransport +========================== + +The AIOHTTPWebsocketsTransport is an alternative to the :ref:`websockets_transport`, +using the `aiohttp` dependency instead of the `websockets` dependency. + +It also supports both: + + - the `Apollo websockets transport protocol`_. + - the `GraphQL-ws websockets transport protocol`_ + +It will propose both subprotocols to the backend and detect the supported protocol +from the response http headers returned by the backend. + +.. note:: + For some backends (graphql-ws before `version 5.6.1`_ without backwards compatibility), it may be necessary to specify + only one subprotocol to the backend. It can be done by using + :code:`subprotocols=[AIOHTTPWebsocketsTransport.GRAPHQLWS_SUBPROTOCOL]` + or :code:`subprotocols=[AIOHTTPWebsocketsTransport.APOLLO_SUBPROTOCOL]` in the transport arguments. + +This transport allows to do multiple queries, mutations and subscriptions on the same websocket connection. + +Reference: :class:`gql.transport.aiohttp_websockets.AIOHTTPWebsocketsTransport` + +.. literalinclude:: ../code_examples/aiohttp_websockets_async.py + +.. _version 5.6.1: https://github.com/enisdenjo/graphql-ws/releases/tag/v5.6.1 +.. _Apollo websockets transport protocol: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md +.. _GraphQL-ws websockets transport protocol: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md diff --git a/docs/transports/async_transports.rst b/docs/transports/async_transports.rst index 7d751df0..ba5ca136 100644 --- a/docs/transports/async_transports.rst +++ b/docs/transports/async_transports.rst @@ -12,5 +12,6 @@ Async transports are transports which are using an underlying async library. The aiohttp httpx_async websockets + aiohttp_websockets phoenix appsync diff --git a/gql/cli.py b/gql/cli.py index dd991546..a7d129e2 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -159,6 +159,7 @@ def get_parser(with_examples: bool = False) -> ArgumentParser: "aiohttp", "phoenix", "websockets", + "aiohttp_websockets", "appsync_http", "appsync_websockets", ], @@ -286,7 +287,12 @@ def autodetect_transport(url: URL) -> str: """Detects which transport should be used depending on url.""" if url.scheme in ["ws", "wss"]: - transport_name = "websockets" + try: + import websockets # noqa: F401 + + transport_name = "websockets" + except ImportError: # pragma: no cover + transport_name = "aiohttp_websockets" else: assert url.scheme in ["http", "https"] @@ -338,6 +344,11 @@ def get_transport(args: Namespace) -> Optional[AsyncTransport]: return WebsocketsTransport(url=args.server, **transport_args) + elif transport_name == "aiohttp_websockets": + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + return AIOHTTPWebsocketsTransport(url=args.server, **transport_args) + else: from gql.transport.appsync_auth import AppSyncAuthentication diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py new file mode 100644 index 00000000..ff310a82 --- /dev/null +++ b/gql/transport/aiohttp_websockets.py @@ -0,0 +1,1196 @@ +import asyncio +import json +import logging +import sys +import warnings +from contextlib import suppress +from ssl import SSLContext +from typing import ( + Any, + AsyncGenerator, + Collection, + Dict, + Mapping, + Optional, + Tuple, + Union, +) + +import aiohttp +from aiohttp import BasicAuth, Fingerprint, WSMsgType +from aiohttp.typedefs import LooseHeaders, StrOrURL +from graphql import DocumentNode, ExecutionResult, print_ast +from multidict import CIMultiDictProxy + +from gql.transport.aiohttp import AIOHTTPTransport +from gql.transport.async_transport import AsyncTransport +from gql.transport.exceptions import ( + TransportAlreadyConnected, + TransportClosed, + TransportProtocolError, + TransportQueryError, + TransportServerError, +) + +""" +Load the appropriate instance of the Literal type +Note: we cannot use try: except ImportError because of the following mypy issue: +https://github.com/python/mypy/issues/8520 +""" +if sys.version_info[:2] >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal # pragma: no cover + +log = logging.getLogger("gql.transport.aiohttp_websockets") + +ParsedAnswer = Tuple[str, Optional[ExecutionResult]] + + +class ListenerQueue: + """Special queue used for each query waiting for server answers + + If the server is stopped while the listener is still waiting, + Then we send an exception to the queue and this exception will be raised + to the consumer once all the previous messages have been consumed from the queue + """ + + def __init__(self, query_id: int, send_stop: bool) -> None: + self.query_id: int = query_id + self.send_stop: bool = send_stop + self._queue: asyncio.Queue = asyncio.Queue() + self._closed: bool = False + + async def get(self) -> ParsedAnswer: + + try: + item = self._queue.get_nowait() + except asyncio.QueueEmpty: + item = await self._queue.get() + + self._queue.task_done() + + # If we receive an exception when reading the queue, we raise it + if isinstance(item, Exception): + self._closed = True + raise item + + # Don't need to save new answers or + # send the stop message if we already received the complete message + answer_type, execution_result = item + if answer_type == "complete": + self.send_stop = False + self._closed = True + + return item + + async def put(self, item: ParsedAnswer) -> None: + + if not self._closed: + await self._queue.put(item) + + async def set_exception(self, exception: Exception) -> None: + + # Put the exception in the queue + await self._queue.put(exception) + + # Don't need to send stop messages in case of error + self.send_stop = False + self._closed = True + + +class AIOHTTPWebsocketsTransport(AsyncTransport): + + # This transport supports two subprotocols and will autodetect the + # subprotocol supported on the server + APOLLO_SUBPROTOCOL: str = "graphql-ws" + GRAPHQLWS_SUBPROTOCOL: str = "graphql-transport-ws" + + def __init__( + self, + url: StrOrURL, + *, + subprotocols: Optional[Collection[str]] = None, + heartbeat: Optional[float] = None, + auth: Optional[BasicAuth] = None, + origin: Optional[str] = None, + params: Optional[Mapping[str, str]] = None, + headers: Optional[LooseHeaders] = None, + proxy: Optional[StrOrURL] = None, + proxy_auth: Optional[BasicAuth] = None, + proxy_headers: Optional[LooseHeaders] = None, + ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None, + websocket_close_timeout: float = 10.0, + receive_timeout: Optional[float] = None, + ssl_close_timeout: Optional[Union[int, float]] = 10, + connect_timeout: Optional[Union[int, float]] = 10, + close_timeout: Optional[Union[int, float]] = 10, + ack_timeout: Optional[Union[int, float]] = 10, + keep_alive_timeout: Optional[Union[int, float]] = None, + init_payload: Dict[str, Any] = {}, + ping_interval: Optional[Union[int, float]] = None, + pong_timeout: Optional[Union[int, float]] = None, + answer_pings: bool = True, + client_session_args: Optional[Dict[str, Any]] = None, + connect_args: Dict[str, Any] = {}, + ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. + :param subprotocols: list of subprotocols sent to the + backend in the 'subprotocols' http header. + By default: both apollo and graphql-ws subprotocols. + :param float heartbeat: Send low level `ping` message every `heartbeat` + seconds and wait `pong` response, close + connection if `pong` response is not + received. The timer is reset on any data reception. + :param auth: An object that represents HTTP Basic Authorization. + :class:`~aiohttp.BasicAuth` (optional) + :param str origin: Origin header to send to server(optional) + :param params: Mapping, iterable of tuple of *key*/*value* pairs or + string to be sent as parameters in the query + string of the new request. Ignored for subsequent + redirected requests (optional) + + Allowed values are: + + - :class:`collections.abc.Mapping` e.g. :class:`dict`, + :class:`multidict.MultiDict` or + :class:`multidict.MultiDictProxy` + - :class:`collections.abc.Iterable` e.g. :class:`tuple` or + :class:`list` + - :class:`str` with preferably url-encoded content + (**Warning:** content will not be encoded by *aiohttp*) + :param headers: HTTP Headers that sent with every request + May be either *iterable of key-value pairs* or + :class:`~collections.abc.Mapping` + (e.g. :class:`dict`, + :class:`~multidict.CIMultiDict`). + :param proxy: Proxy URL, :class:`str` or :class:`~yarl.URL` (optional) + :param aiohttp.BasicAuth proxy_auth: an object that represents proxy HTTP + Basic Authorization (optional) + :param ssl: SSL validation mode. ``True`` for default SSL check + (:func:`ssl.create_default_context` is used), + ``False`` for skip SSL certificate validation, + :class:`aiohttp.Fingerprint` for fingerprint + validation, :class:`ssl.SSLContext` for custom SSL + certificate validation. + :param float websocket_close_timeout: Timeout for websocket to close. + ``10`` seconds by default + :param float receive_timeout: Timeout for websocket to receive + complete message. ``None`` (unlimited) + seconds by default + :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection + to close properly + :param connect_timeout: Timeout in seconds for the establishment + of the websocket connection. If None is provided this will wait forever. + :param close_timeout: Timeout in seconds for the close. If None is provided + this will wait forever. + :param ack_timeout: Timeout in seconds to wait for the connection_ack message + from the server. If None is provided this will wait forever. + :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. + :param init_payload: Dict of the payload sent in the connection_init message. + :param ping_interval: Delay in seconds between pings sent by the client to + the backend for the graphql-ws protocol. None (by default) means that + we don't send pings. Note: there are also pings sent by the underlying + websockets protocol. See the + :ref:`keepalive documentation ` + for more information about this. + :param pong_timeout: Delay in seconds to receive a pong from the backend + after we sent a ping (only for the graphql-ws protocol). + By default equal to half of the ping_interval. + :param answer_pings: Whether the client answers the pings from the backend + (for the graphql-ws protocol). + By default: True + :param client_session_args: Dict of extra args passed to + `aiohttp.ClientSession`_ + :param connect_args: Dict of extra args passed to + `aiohttp.ClientSession.ws_connect`_ + + .. _aiohttp.ClientSession.ws_connect: + https://docs.aiohttp.org/en/stable/client_reference.html#aiohttp.ClientSession.ws_connect + .. _aiohttp.ClientSession: + https://docs.aiohttp.org/en/stable/client_reference.html#aiohttp.ClientSession + """ + self.url: StrOrURL = url + self.heartbeat: Optional[float] = heartbeat + self.auth: Optional[BasicAuth] = auth + self.origin: Optional[str] = origin + self.params: Optional[Mapping[str, str]] = params + self.headers: Optional[LooseHeaders] = headers + + self.proxy: Optional[StrOrURL] = proxy + self.proxy_auth: Optional[BasicAuth] = proxy_auth + self.proxy_headers: Optional[LooseHeaders] = proxy_headers + + self.ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = ssl + + self.websocket_close_timeout: float = websocket_close_timeout + self.receive_timeout: Optional[float] = receive_timeout + + self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout + self.connect_timeout: Optional[Union[int, float]] = connect_timeout + self.close_timeout: Optional[Union[int, float]] = close_timeout + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout + + self.init_payload: Dict[str, Any] = init_payload + + # We need to set an event loop here if there is none + # Or else we will not be able to create an asyncio.Event() + try: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + self._loop = asyncio.get_event_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + self._next_keep_alive_message: asyncio.Event = asyncio.Event() + self._next_keep_alive_message.set() + + self.session: Optional[aiohttp.ClientSession] = None + self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None + self.next_query_id: int = 1 + self.listeners: Dict[int, ListenerQueue] = {} + self._connecting: bool = False + self.response_headers: Optional[CIMultiDictProxy[str]] = None + + self.receive_data_task: Optional[asyncio.Future] = None + self.check_keep_alive_task: Optional[asyncio.Future] = None + self.close_task: Optional[asyncio.Future] = None + + self._wait_closed: asyncio.Event = asyncio.Event() + self._wait_closed.set() + + self._no_more_listeners: asyncio.Event = asyncio.Event() + self._no_more_listeners.set() + + self.payloads: Dict[str, Any] = {} + + self.ping_interval: Optional[Union[int, float]] = ping_interval + self.pong_timeout: Optional[Union[int, float]] + self.answer_pings: bool = answer_pings + + if ping_interval is not None: + if pong_timeout is None: + self.pong_timeout = ping_interval / 2 + else: + self.pong_timeout = pong_timeout + + self.send_ping_task: Optional[asyncio.Future] = None + + self.ping_received: asyncio.Event = asyncio.Event() + """ping_received is an asyncio Event which will fire each time + a ping is received with the graphql-ws protocol""" + + self.pong_received: asyncio.Event = asyncio.Event() + """pong_received is an asyncio Event which will fire each time + a pong is received with the graphql-ws protocol""" + + self.supported_subprotocols: Collection[str] = subprotocols or ( + self.APOLLO_SUBPROTOCOL, + self.GRAPHQLWS_SUBPROTOCOL, + ) + + self.close_exception: Optional[Exception] = None + + self.client_session_args = client_session_args + self.connect_args = connect_args + + def _parse_answer_graphqlws( + self, answer: Dict[str, Any] + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server if the server supports the + graphql-ws protocol. + + Returns a list consisting of: + - the answer_type (between: + 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + + Differences with the apollo websockets protocol (superclass): + - the "data" message is now called "next" + - the "stop" message is now called "complete" + - there is no connection_terminate or connection_error messages + - instead of a unidirectional keep-alive (ka) message from server to client, + there is now the possibility to send bidirectional ping/pong messages + - connection_ack has an optional payload + - the 'error' answer type returns a list of errors instead of a single error + """ + + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None + + try: + answer_type = str(answer.get("type")) + + if answer_type in ["next", "error", "complete"]: + answer_id = int(str(answer.get("id"))) + + if answer_type == "next" or answer_type == "error": + + payload = answer.get("payload") + + if answer_type == "next": + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + if "errors" not in payload and "data" not in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) + + execution_result = ExecutionResult( + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), + ) + + # Saving answer_type as 'data' to be understood with superclass + answer_type = "data" + + elif answer_type == "error": + + if not isinstance(payload, list): + raise ValueError("payload is not a list") + + raise TransportQueryError( + str(payload[0]), query_id=answer_id, errors=payload + ) + + elif answer_type in ["ping", "pong", "connection_ack"]: + self.payloads[answer_type] = answer.get("payload", None) + + else: + raise ValueError + + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() + + except ValueError as e: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) from e + + return answer_type, answer_id, execution_result + + def _parse_answer_apollo( + self, answer: Dict[str, Any] + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server if the server supports the + apollo websockets protocol. + + Returns a list consisting of: + - the answer_type (between: + 'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + """ + + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None + + try: + answer_type = str(answer.get("type")) + + if answer_type in ["data", "error", "complete"]: + answer_id = int(str(answer.get("id"))) + + if answer_type == "data" or answer_type == "error": + + payload = answer.get("payload") + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + if answer_type == "data": + + if "errors" not in payload and "data" not in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) + + execution_result = ExecutionResult( + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), + ) + + elif answer_type == "error": + + raise TransportQueryError( + str(payload), query_id=answer_id, errors=[payload] + ) + + elif answer_type == "ka": + # Keep-alive message + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() + elif answer_type == "connection_ack": + pass + elif answer_type == "connection_error": + error_payload = answer.get("payload") + raise TransportServerError(f"Server error: '{repr(error_payload)}'") + else: + raise ValueError + + except ValueError as e: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) from e + + return answer_type, answer_id, execution_result + + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server depending on + the detected subprotocol. + """ + try: + json_answer = json.loads(answer) + except ValueError: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + return self._parse_answer_graphqlws(json_answer) + + return self._parse_answer_apollo(json_answer) + + async def _wait_ack(self) -> None: + """Wait for the connection_ack message. Keep alive messages are ignored""" + + while True: + init_answer = await self._receive() + + answer_type, _, _ = self._parse_answer(init_answer) + + if answer_type == "connection_ack": + return + + if answer_type != "ka": + raise TransportProtocolError( + "Websocket server did not return a connection ack" + ) + + async def _send_init_message_and_wait_ack(self) -> None: + """Send init message to the provided websocket and wait for the connection ACK. + + If the answer is not a connection_ack message, we will return an Exception. + """ + + init_message = {"type": "connection_init", "payload": self.init_payload} + + await self._send(init_message) + + # Wait for the connection_ack message or raise a TimeoutError + await asyncio.wait_for(self._wait_ack(), self.ack_timeout) + + async def _initialize(self): + """Hook to send the initialization messages after the connection + and potentially wait for the backend ack. + """ + await self._send_init_message_and_wait_ack() + + async def _stop_listener(self, query_id: int): + """Hook to stop to listen to a specific query. + Will send a stop message in some subclasses. + """ + log.debug(f"stop listener {query_id}") + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + await self._send_complete_message(query_id) + await self.listeners[query_id].put(("complete", None)) + else: + await self._send_stop_message(query_id) + + async def _after_connect(self): + """Hook to add custom code for subclasses after the connection + has been established. + """ + # Find the backend subprotocol returned in the response headers + response_headers = self.websocket._response.headers + log.debug(f"Response headers: {response_headers!r}") + try: + self.subprotocol = response_headers["Sec-WebSocket-Protocol"] + except KeyError: + self.subprotocol = self.APOLLO_SUBPROTOCOL + + log.debug(f"backend subprotocol returned: {self.subprotocol!r}") + + async def send_ping(self, payload: Optional[Any] = None) -> None: + """Send a ping message for the graphql-ws protocol""" + + ping_message = {"type": "ping"} + + if payload is not None: + ping_message["payload"] = payload + + await self._send(ping_message) + + async def send_pong(self, payload: Optional[Any] = None) -> None: + """Send a pong message for the graphql-ws protocol""" + + pong_message = {"type": "pong"} + + if payload is not None: + pong_message["payload"] = payload + + await self._send(pong_message) + + async def _send_stop_message(self, query_id: int) -> None: + """Send stop message to the provided websocket connection and query_id. + + The server should afterwards return a 'complete' message. + """ + + stop_message = {"id": str(query_id), "type": "stop"} + + await self._send(stop_message) + + async def _send_complete_message(self, query_id: int) -> None: + """Send a complete message for the provided query_id. + + This is only for the graphql-ws protocol. + """ + + complete_message = {"id": str(query_id), "type": "complete"} + + await self._send(complete_message) + + async def _send_ping_coro(self) -> None: + """Coroutine to periodically send a ping from the client to the backend. + + Only used for the graphql-ws protocol. + + Send a ping every ping_interval seconds. + Close the connection if a pong is not received within pong_timeout seconds. + """ + + assert self.ping_interval is not None + + try: + while True: + await asyncio.sleep(self.ping_interval) + + await self.send_ping() + + await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout) + + # Reset for the next iteration + self.pong_received.clear() + + except asyncio.TimeoutError: + # No pong received in the appriopriate time, close with error + # If the timeout happens during a close already in progress, do nothing + if self.close_task is None: + await self._fail( + TransportServerError( + f"No pong received after {self.pong_timeout!r} seconds" + ), + clean_close=False, + ) + + async def _after_initialize(self): + """Hook to add custom code for subclasses after the initialization + has been done. + """ + + # If requested, create a task to send periodic pings to the backend + if ( + self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL + and self.ping_interval is not None + ): + + self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) + + async def _close_hook(self): + """Hook to add custom code for subclasses for the connection close""" + # Properly shut down the send ping task if enabled + if self.send_ping_task is not None: + self.send_ping_task.cancel() + with suppress(asyncio.CancelledError): + await self.send_ping_task + self.send_ping_task = None + + async def _connection_terminate(self): + """Hook to add custom code for subclasses after the initialization + has been done. + """ + if self.subprotocol == self.APOLLO_SUBPROTOCOL: + await self._send_connection_terminate_message() + + async def _send_connection_terminate_message(self) -> None: + """Send a connection_terminate message to the provided websocket connection. + + This message indicates that the connection will disconnect. + """ + + connection_terminate_message = {"type": "connection_terminate"} + + await self._send(connection_terminate_message) + + async def _send_query( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> int: + """Send a query to the provided websocket connection. + + We use an incremented id to reference the query. + + Returns the used id for this query. + """ + + query_id = self.next_query_id + self.next_query_id += 1 + + payload: Dict[str, Any] = {"query": print_ast(document)} + if variable_values: + payload["variables"] = variable_values + if operation_name: + payload["operationName"] = operation_name + + query_type = "start" + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + query_type = "subscribe" + + query = {"id": str(query_id), "type": query_type, "payload": payload} + + await self._send(query) + + return query_id + + async def _send(self, message: Dict[str, Any]) -> None: + """Send the provided message to the websocket connection and log the message""" + + if self.websocket is None: + raise TransportClosed("WebSocket connection is closed") + + try: + await self.websocket.send_json(message) + log.info(">>> %s", message) + except ConnectionResetError as e: + await self._fail(e, clean_close=False) + raise e + + async def _receive(self) -> str: + """Wait the next message from the websocket connection and log the answer""" + + # It is possible that the websocket has been already closed in another task + if self.websocket is None: + raise TransportClosed("Transport is already closed") + + while True: + ws_message = await self.websocket.receive() + + # Ignore low-level ping and pong received + if ws_message.type not in (WSMsgType.PING, WSMsgType.PONG): + break + + if ws_message.type in ( + WSMsgType.CLOSE, + WSMsgType.CLOSED, + WSMsgType.CLOSING, + WSMsgType.ERROR, + ): + raise ConnectionResetError + elif ws_message.type is WSMsgType.BINARY: + raise TransportProtocolError("Binary data received in the websocket") + + assert ws_message.type is WSMsgType.TEXT + + answer: str = ws_message.data + + log.info("<<< %s", answer) + + return answer + + def _remove_listener(self, query_id) -> None: + """After exiting from a subscription, remove the listener and + signal an event if this was the last listener for the client. + """ + if query_id in self.listeners: + del self.listeners[query_id] + + remaining = len(self.listeners) + log.debug(f"listener {query_id} deleted, {remaining} remaining") + + if remaining == 0: + self._no_more_listeners.set() + + async def _check_ws_liveness(self) -> None: + """Coroutine which will periodically check the liveness of the connection + through keep-alive messages + """ + + try: + while True: + await asyncio.wait_for( + self._next_keep_alive_message.wait(), self.keep_alive_timeout + ) + + # Reset for the next iteration + self._next_keep_alive_message.clear() + + except asyncio.TimeoutError: + # No keep-alive message in the appriopriate interval, close with error + # while trying to notify the server of a proper close (in case + # the keep-alive interval of the client or server was not aligned + # the connection still remains) + + # If the timeout happens during a close already in progress, do nothing + if self.close_task is None: + await self._fail( + TransportServerError( + "No keep-alive message has been received within " + "the expected interval ('keep_alive_timeout' parameter)" + ), + clean_close=False, + ) + + except asyncio.CancelledError: + # The client is probably closing, handle it properly + pass + + async def _handle_answer( + self, + answer_type: str, + answer_id: Optional[int], + execution_result: Optional[ExecutionResult], + ) -> None: + + try: + # Put the answer in the queue + if answer_id is not None: + await self.listeners[answer_id].put((answer_type, execution_result)) + except KeyError: + # Do nothing if no one is listening to this query_id. + pass + + # Answer pong to ping for graphql-ws protocol + if answer_type == "ping": + self.ping_received.set() + if self.answer_pings: + await self.send_pong() + + elif answer_type == "pong": + self.pong_received.set() + + async def _receive_data_loop(self) -> None: + """Main asyncio task which will listen to the incoming messages and will + call the parse_answer and handle_answer methods of the subclass.""" + log.debug("Entering _receive_data_loop()") + + try: + while True: + + # Wait the next answer from the websocket server + try: + answer = await self._receive() + except (ConnectionResetError, TransportProtocolError) as e: + await self._fail(e, clean_close=False) + break + except TransportClosed as e: + await self._fail(e, clean_close=False) + raise e + + # Parse the answer + try: + answer_type, answer_id, execution_result = self._parse_answer( + answer + ) + except TransportQueryError as e: + # Received an exception for a specific query + # ==> Add an exception to this query queue + # The exception is raised for this specific query, + # but the transport is not closed. + assert isinstance( + e.query_id, int + ), "TransportQueryError should have a query_id defined here" + try: + await self.listeners[e.query_id].set_exception(e) + except KeyError: + # Do nothing if no one is listening to this query_id + pass + + continue + + except (TransportServerError, TransportProtocolError) as e: + # Received a global exception for this transport + # ==> close the transport + # The exception will be raised for all current queries. + await self._fail(e, clean_close=False) + break + + await self._handle_answer(answer_type, answer_id, execution_result) + + finally: + log.debug("Exiting _receive_data_loop()") + + async def connect(self) -> None: + log.debug("connect: starting") + + if self.session is None: + client_session_args: Dict[str, Any] = {} + + # Adding custom parameters passed from init + if self.client_session_args: + client_session_args.update(self.client_session_args) # type: ignore + + self.session = aiohttp.ClientSession(**client_session_args) + + if self.websocket is None and not self._connecting: + self._connecting = True + + connect_args: Dict[str, Any] = {} + + # Adding custom parameters passed from init + if self.connect_args: + connect_args.update(self.connect_args) + + try: + # Connection to the specified url + # Generate a TimeoutError if taking more than connect_timeout seconds + # Set the _connecting flag to False after in all cases + self.websocket = await asyncio.wait_for( + self.session.ws_connect( + url=self.url, + headers=self.headers, + auth=self.auth, + heartbeat=self.heartbeat, + origin=self.origin, + params=self.params, + protocols=self.supported_subprotocols, + proxy=self.proxy, + proxy_auth=self.proxy_auth, + proxy_headers=self.proxy_headers, + timeout=self.websocket_close_timeout, + receive_timeout=self.receive_timeout, + ssl=self.ssl, + **connect_args, + ), + self.connect_timeout, + ) + finally: + self._connecting = False + + self.response_headers = self.websocket._response.headers + + await self._after_connect() + + self.next_query_id = 1 + self.close_exception = None + self._wait_closed.clear() + + # Send the init message and wait for the ack from the server + # Note: This should generate a TimeoutError + # if no ACKs are received within the ack_timeout + try: + await self._initialize() + except ConnectionResetError as e: + raise e + except ( + TransportProtocolError, + TransportServerError, + asyncio.TimeoutError, + ) as e: + await self._fail(e, clean_close=False) + raise e + + # Run the after_init hook of the subclass + await self._after_initialize() + + # If specified, create a task to check liveness of the connection + # through keep-alive messages + if self.keep_alive_timeout is not None: + self.check_keep_alive_task = asyncio.ensure_future( + self._check_ws_liveness() + ) + + # Create a task to listen to the incoming websocket messages + self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) + + else: + raise TransportAlreadyConnected("Transport is already connected") + + log.debug("connect: done") + + async def _clean_close(self) -> None: + """Coroutine which will: + + - send stop messages for each active subscription to the server + - send the connection terminate message + """ + log.debug(f"Listeners: {self.listeners}") + + # Send 'stop' message for all current queries + for query_id, listener in self.listeners.items(): + print(f"Listener {query_id} send_stop: {listener.send_stop}") + + if listener.send_stop: + await self._stop_listener(query_id) + listener.send_stop = False + + # Wait that there is no more listeners (we received 'complete' for all queries) + try: + await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) + except asyncio.TimeoutError: # pragma: no cover + log.debug("Timer close_timeout fired") + + # Calling the subclass hook + await self._connection_terminate() + + async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: + """Coroutine which will: + + - do a clean_close if possible: + - send stop messages for each active query to the server + - send the connection terminate message + - close the websocket connection + - send the exception to all the remaining listeners + """ + + log.debug("_close_coro: starting") + + try: + + try: + # Properly shut down liveness checker if enabled + if self.check_keep_alive_task is not None: + # More info: https://stackoverflow.com/a/43810272/1113207 + self.check_keep_alive_task.cancel() + with suppress(asyncio.CancelledError): + await self.check_keep_alive_task + except Exception as exc: # pragma: no cover + log.warning( + "_close_coro cancel keep alive task exception: " + repr(exc) + ) + + try: + # Calling the subclass close hook + await self._close_hook() + except Exception as exc: # pragma: no cover + log.warning("_close_coro close_hook exception: " + repr(exc)) + + # Saving exception to raise it later if trying to use the transport + # after it has already closed. + self.close_exception = e + + if clean_close: + log.debug("_close_coro: starting clean_close") + try: + await self._clean_close() + except Exception as exc: # pragma: no cover + log.warning("Ignoring exception in _clean_close: " + repr(exc)) + + log.debug("_close_coro: sending exception to listeners") + + # Send an exception to all remaining listeners + for query_id, listener in self.listeners.items(): + await listener.set_exception(e) + + log.debug("_close_coro: close websocket connection") + + try: + assert self.websocket is not None + + await self.websocket.close() + self.websocket = None + except Exception as exc: + log.warning("_close_coro websocket close exception: " + repr(exc)) + + log.debug("_close_coro: close aiohttp session") + + if ( + self.client_session_args + and self.client_session_args.get("connector_owner") is False + ): + + log.debug("connector_owner is False -> not closing connector") + + else: + try: + assert self.session is not None + + closed_event = AIOHTTPTransport.create_aiohttp_closed_event( + self.session + ) + await self.session.close() + try: + await asyncio.wait_for( + closed_event.wait(), self.ssl_close_timeout + ) + except asyncio.TimeoutError: + pass + except Exception as exc: # pragma: no cover + log.warning("_close_coro session close exception: " + repr(exc)) + + self.session = None + + log.debug("_close_coro: aiohttp session closed") + + try: + assert self.receive_data_task is not None + + self.receive_data_task.cancel() + with suppress(asyncio.CancelledError): + await self.receive_data_task + except Exception as exc: # pragma: no cover + log.warning( + "_close_coro cancel receive data task exception: " + repr(exc) + ) + + except Exception as exc: # pragma: no cover + log.warning("Exception catched in _close_coro: " + repr(exc)) + + finally: + + log.debug("_close_coro: final cleanup") + + self.websocket = None + self.close_task = None + self.check_keep_alive_task = None + self.receive_data_task = None + self._wait_closed.set() + + log.debug("_close_coro: exiting") + + async def _fail(self, e: Exception, clean_close: bool = True) -> None: + log.debug("_fail: starting with exception: " + repr(e)) + + if self.close_task is None: + + if self._wait_closed.is_set(): + log.debug("_fail started but transport is already closed") + else: + self.close_task = asyncio.shield( + asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) + ) + else: + log.debug( + "close_task is not None in _fail. Previous exception is: " + + repr(self.close_exception) + + " New exception is: " + + repr(e) + ) + + async def close(self) -> None: + log.debug("close: starting") + + await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) + await self.wait_closed() + + log.debug("close: done") + + async def wait_closed(self) -> None: + log.debug("wait_close: starting") + + if not self._wait_closed.is_set(): + await self._wait_closed.wait() + + log.debug("wait_close: done") + + async def execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> ExecutionResult: + """Execute the provided document AST against the configured remote server + using the current session. + + Send a query but close the async generator as soon as we have the first answer. + + The result is sent as an ExecutionResult object. + """ + first_result = None + + generator = self.subscribe( + document, variable_values, operation_name, send_stop=False + ) + + async for result in generator: + first_result = result + + # Note: we need to run generator.aclose() here or the finally block in + # the subscribe will not be reached in pypy3 (python version 3.6.1) + await generator.aclose() + + break + + if first_result is None: + raise TransportQueryError( + "Query completed without any answer received from the server" + ) + + return first_result + + async def subscribe( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + send_stop: Optional[bool] = True, + ) -> AsyncGenerator[ExecutionResult, None]: + """Send a query and receive the results using a python async generator. + + The query can be a graphql query, mutation or subscription. + + The results are sent as an ExecutionResult object. + """ + + # Send the query and receive the id + query_id: int = await self._send_query( + document, variable_values, operation_name + ) + + # Create a queue to receive the answers for this query_id + listener = ListenerQueue(query_id, send_stop=(send_stop is True)) + self.listeners[query_id] = listener + + # We will need to wait at close for this query to clean properly + self._no_more_listeners.clear() + + try: + # Loop over the received answers + while True: + + # Wait for the answer from the queue of this query_id + # This can raise a TransportError or ConnectionClosed exception. + answer_type, execution_result = await listener.get() + + # If the received answer contains data, + # Then we will yield the results back as an ExecutionResult object + if execution_result is not None: + yield execution_result + + # If we receive a 'complete' answer from the server, + # Then we will end this async generator output without errors + elif answer_type == "complete": + log.debug( + f"Complete received for query {query_id} --> exit without error" + ) + break + + except (asyncio.CancelledError, GeneratorExit) as e: + log.debug(f"Exception in subscribe: {e!r}") + if listener.send_stop: + await self._stop_listener(query_id) + listener.send_stop = False + + finally: + log.debug(f"In subscribe finally for query_id {query_id}") + self._remove_listener(query_id) diff --git a/tests/conftest.py b/tests/conftest.py index 6a37a5d3..c164c355 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -119,6 +119,7 @@ async def ssl_aiohttp_server(): for name in [ "websockets.legacy.server", "gql.transport.aiohttp", + "gql.transport.aiohttp_websockets", "gql.transport.appsync", "gql.transport.phoenix_channel_websockets", "gql.transport.requests", @@ -210,6 +211,145 @@ async def stop(self): print("Server stopped\n\n\n") +class AIOHTTPWebsocketServer: + def __init__(self, with_ssl=False): + self.runner = None + self.site = None + self.port = None + self.hostname = "127.0.0.1" + self.with_ssl = with_ssl + self.ssl_context = None + if with_ssl: + _, self.ssl_context = get_localhost_ssl_context() + + def get_default_server_handler(answers): + async def default_server_handler(request): + + import aiohttp + import aiohttp.web + from aiohttp import WSMsgType + + ws = aiohttp.web.WebSocketResponse() + ws.headers.update({"dummy": "test1234"}) + await ws.prepare(request) + + try: + # Init and ack + msg = await ws.__anext__() + assert msg.type == WSMsgType.TEXT + result = msg.data + json_result = json.loads(result) + assert json_result["type"] == "connection_init" + await ws.send_str('{"type":"connection_ack"}') + query_id = 1 + + # Wait for queries and send answers + for answer in answers: + msg = await ws.__anext__() + if msg.type == WSMsgType.TEXT: + result = msg.data + + print(f"Server received: {result}", file=sys.stderr) + if isinstance(answer, str) and "{query_id}" in answer: + answer_format_params = {"query_id": query_id} + formatted_answer = answer.format(**answer_format_params) + else: + formatted_answer = answer + await ws.send_str(formatted_answer) + await ws.send_str( + f'{{"type":"complete","id":"{query_id}","payload":null}}' + ) + query_id += 1 + + elif msg.type == WSMsgType.ERROR: + print(f"WebSocket connection closed with: {ws.exception()}") + raise ws.exception() + elif msg.type in ( + WSMsgType.CLOSE, + WSMsgType.CLOSED, + WSMsgType.CLOSING, + ): + print("WebSocket connection closed") + raise ConnectionResetError + + # Wait for connection_terminate + msg = await ws.__anext__() + result = msg.data + json_result = json.loads(result) + assert json_result["type"] == "connection_terminate" + + # Wait for connection close + msg = await ws.__anext__() + + except ConnectionResetError: + pass + + except Exception as e: + print(f"Server exception {e!s}", file=sys.stderr) + + await ws.close() + return ws + + return default_server_handler + + async def shutdown_server(self, app): + print("Shutting down server...") + await app.shutdown() + await app.cleanup() + + async def start(self, handler): + import aiohttp + import aiohttp.web + + app = aiohttp.web.Application() + app.router.add_get("/graphql", handler) + self.runner = aiohttp.web.AppRunner(app) + await self.runner.setup() + + # Use port 0 to bind to an available port + self.site = aiohttp.web.TCPSite( + self.runner, self.hostname, 0, ssl_context=self.ssl_context + ) + await self.site.start() + + # Retrieve the actual port the server is listening on + sockets = self.site._server.sockets + if sockets: + self.port = sockets[0].getsockname()[1] + protocol = "https" if self.with_ssl else "http" + print(f"Server started at {protocol}://{self.hostname}:{self.port}") + + async def stop(self): + if self.site: + await self.site.stop() + if self.runner: + await self.runner.cleanup() + + +@pytest_asyncio.fixture +async def aiohttp_ws_server(request): + """Fixture used to start a dummy server to test the client behaviour + using the aiohttp dependency. + + It can take as argument either a handler function for the websocket server for + complete control OR an array of answers to be sent by the default server handler. + """ + + server_handler = get_aiohttp_ws_server_handler(request) + + try: + test_server = AIOHTTPWebsocketServer() + + # Starting the server with the fixture param as the handler function + await test_server.start(server_handler) + + yield test_server + except Exception as e: + print("Exception received in server fixture:", e) + finally: + await test_server.stop() + + class WebSocketServerHelper: @staticmethod async def send_complete(ws, query_id): @@ -306,6 +446,23 @@ def __exit__(self, type, value, traceback): os.unlink(self.filename) +def get_aiohttp_ws_server_handler(request): + """Get the server handler for the aiohttp websocket server. + + Either get it from test or use the default server handler + if the test provides only an array of answers. + """ + + if isinstance(request.param, types.FunctionType): + server_handler = request.param + + else: + answers = request.param + server_handler = AIOHTTPWebsocketServer.get_default_server_handler(answers) + + return server_handler + + def get_server_handler(request): """Get the server handler. @@ -462,6 +619,48 @@ async def client_and_server(server): yield session, server +@pytest_asyncio.fixture +async def aiohttp_client_and_server(server): + """ + Helper fixture to start a server and a client connected to its port + with an aiohttp websockets transport. + """ + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + # Generate transport to connect to the server fixture + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + async with Client(transport=sample_transport) as session: + + # Yield both client session and server + yield session, server + + +@pytest_asyncio.fixture +async def aiohttp_client_and_aiohttp_ws_server(aiohttp_ws_server): + """ + Helper fixture to start an aiohttp websocket server and + a client connected to its port with an aiohttp websockets transport. + """ + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + server = aiohttp_ws_server + + # Generate transport to connect to the server fixture + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + async with Client(transport=sample_transport) as session: + + # Yield both client session and server + yield session, server + + @pytest_asyncio.fixture async def client_and_graphqlws_server(graphqlws_server): """Helper fixture to start a server with the graphql-ws prototocol @@ -483,6 +682,27 @@ async def client_and_graphqlws_server(graphqlws_server): yield session, graphqlws_server +@pytest_asyncio.fixture +async def client_and_aiohttp_websocket_graphql_server(graphqlws_server): + """Helper fixture to start a server with the graphql-ws prototocol + and a client connected to its port.""" + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + # Generate transport to connect to the server fixture + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + sample_transport = AIOHTTPWebsocketsTransport( + url=url, + subprotocols=[AIOHTTPWebsocketsTransport.GRAPHQLWS_SUBPROTOCOL], + ) + + async with Client(transport=sample_transport) as session: + + # Yield both client session and server + yield session, graphqlws_server + + @pytest_asyncio.fixture async def run_sync_test(): async def run_sync_test_inner(event_loop, server, test_function): diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py new file mode 100644 index 00000000..ea48824f --- /dev/null +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -0,0 +1,406 @@ +import asyncio +import json +import types +from typing import List + +import pytest + +from gql import Client, gql +from gql.transport.exceptions import ( + TransportClosed, + TransportProtocolError, + TransportQueryError, +) + +from .conftest import MS, WebSocketServerHelper + +# Marking all tests in this file with the aiohttp AND websockets marker +pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] + +invalid_query_str = """ + query getContinents { + continents { + code + bloh + } + } +""" + +invalid_query1_server_answer = ( + '{{"type":"data","id":"{query_id}",' + '"payload":{{"errors":[' + '{{"message":"Cannot query field \\"bloh\\" on type \\"Continent\\".",' + '"locations":[{{"line":4,"column":5}}],' + '"extensions":{{"code":"INTERNAL_SERVER_ERROR"}}}}]}}}}' +) + +invalid_query1_server = [invalid_query1_server_answer] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_aiohttp_websocket_invalid_query( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + await session.execute(query) + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" + + +invalid_subscription_str = """ + subscription getContinents { + continents { + code + bloh + } + } +""" + + +async def server_invalid_subscription(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(invalid_query1_server_answer.format(query_id=1)) + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_invalid_subscription], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_subscription_str]) +async def test_aiohttp_websocket_invalid_subscription( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + async for result in session.subscribe(query): + pass + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" + + +connection_error_server_answer = ( + '{"type":"connection_error","id":null,' + '"payload":{"message":"Unexpected token Q in JSON at position 0"}}' +) + + +async def server_no_ack(ws, path): + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_no_ack], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_aiohttp_websocket_server_does_not_send_ack( + event_loop, server, query_str +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + + sample_transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=1) + + with pytest.raises(asyncio.TimeoutError): + async with Client(transport=sample_transport): + pass + + +async def server_connection_error(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(connection_error_server_answer) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_connection_error], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_aiohttp_websocket_sending_invalid_data( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + invalid_data = "QSDF" + print(f">>> {invalid_data}") + await session.transport.websocket.send_str(invalid_data) + + await asyncio.sleep(2 * MS) + + +invalid_payload_server_answer = ( + '{"type":"error","id":"1","payload":{"message":"Must provide document"}}' +) + + +async def server_invalid_payload(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(invalid_payload_server_answer) + await WebSocketServerHelper.wait_connection_terminate(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_invalid_payload], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_aiohttp_websocket_sending_invalid_payload( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + # Monkey patching the _send_query method to send an invalid payload + + async def monkey_patch_send_query( + self, + document, + variable_values=None, + operation_name=None, + ) -> int: + query_id = self.next_query_id + self.next_query_id += 1 + + query_str = json.dumps( + {"id": str(query_id), "type": "start", "payload": "BLAHBLAH"} + ) + + await self._send(query_str) + return query_id + + session.transport._send_query = types.MethodType( + monkey_patch_send_query, session.transport + ) + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + await session.execute(query) + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["message"] == "Must provide document" + + +not_json_answer = ["BLAHBLAH"] +missing_type_answer = ["{}"] +missing_id_answer_1 = ['{"type": "data"}'] +missing_id_answer_2 = ['{"type": "error"}'] +missing_id_answer_3 = ['{"type": "complete"}'] +data_without_payload = ['{"type": "data", "id":"1"}'] +error_without_payload = ['{"type": "error", "id":"1"}'] +payload_is_not_a_dict = ['{"type": "data", "id":"1", "payload": "BLAH"}'] +empty_payload = ['{"type": "data", "id":"1", "payload": {}}'] +sending_bytes = [b"\x01\x02\x03"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [ + not_json_answer, + missing_type_answer, + missing_id_answer_1, + missing_id_answer_2, + missing_id_answer_3, + data_without_payload, + error_without_payload, + payload_is_not_a_dict, + empty_payload, + sending_bytes, + ], + indirect=True, +) +async def test_aiohttp_websocket_transport_protocol_errors( + event_loop, aiohttp_client_and_server +): + + session, server = aiohttp_client_and_server + + query = gql("query { hello }") + + with pytest.raises((TransportProtocolError, TransportQueryError)): + await session.execute(query) + + +async def server_without_ack(ws, path): + # Sending something else than an ack + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_without_ack], indirect=True) +async def test_aiohttp_websocket_server_does_not_ack(event_loop, server): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + with pytest.raises(TransportProtocolError): + async with Client(transport=sample_transport): + pass + + +async def server_closing_directly(ws, path): + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_closing_directly], indirect=True) +async def test_aiohttp_websocket_server_closing_directly(event_loop, server): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + with pytest.raises(ConnectionResetError): + async with Client(transport=sample_transport): + pass + + +async def server_closing_after_ack(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_closing_after_ack], indirect=True) +async def test_aiohttp_websocket_server_closing_after_ack( + event_loop, aiohttp_client_and_server +): + + session, server = aiohttp_client_and_server + + query = gql("query { hello }") + + with pytest.raises(TransportClosed): + await session.execute(query) + + +async def server_sending_invalid_query_errors(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + invalid_error = ( + '{"type":"error","id":"404","payload":' + '{"message":"error for no good reason on non existing query"}}' + ) + await ws.send(invalid_error) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True) +async def test_aiohttp_websocket_server_sending_invalid_query_errors( + event_loop, server +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + # Invalid server message is ignored + async with Client(transport=sample_transport): + await asyncio.sleep(2 * MS) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True) +async def test_aiohttp_websocket_non_regression_bug_105(event_loop, server): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + # This test will check a fix to a race condition which happens if the user is trying + # to connect using the same client twice at the same time + # See bug #105 + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=sample_transport) + + # Create a coroutine which start the connection with the transport but does nothing + async def client_connect(client): + async with client: + await asyncio.sleep(2 * MS) + + # Create two tasks which will try to connect using the same client (not allowed) + connect_task1 = asyncio.ensure_future(client_connect(client)) + connect_task2 = asyncio.ensure_future(client_connect(client)) + + result = await asyncio.gather(connect_task1, connect_task2, return_exceptions=True) + + assert result[0] is None + assert type(result[1]).__name__ == "TransportAlreadyConnected" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) +async def test_aiohttp_websocket_using_cli_invalid_query( + event_loop, server, monkeypatch, capsys +): + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + import io + + from gql.cli import get_parser, main + + parser = get_parser(with_examples=True) + args = parser.parse_args([url]) + + # Monkeypatching sys.stdin to simulate getting the query + # via the standard input + monkeypatch.setattr("sys.stdin", io.StringIO(invalid_query_str)) + + # Flush captured output + captured = capsys.readouterr() + + await main(args) + + # Check that the error has been printed on stdout + captured = capsys.readouterr() + captured_err = str(captured.err).strip() + print(f"Captured: {captured_err}") + + expected_error = 'Cannot query field "bloh" on type "Continent"' + + assert expected_error in captured_err diff --git a/tests/test_aiohttp_websocket_graphqlws_exceptions.py b/tests/test_aiohttp_websocket_graphqlws_exceptions.py new file mode 100644 index 00000000..d87315c9 --- /dev/null +++ b/tests/test_aiohttp_websocket_graphqlws_exceptions.py @@ -0,0 +1,276 @@ +import asyncio +from typing import List + +import pytest + +from gql import Client, gql +from gql.transport.exceptions import ( + TransportClosed, + TransportProtocolError, + TransportQueryError, +) + +from .conftest import WebSocketServerHelper + +# Marking all tests in this file with the aiohttp AND websockets marker +pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] + +invalid_query_str = """ + query getContinents { + continents { + code + bloh + } + } +""" + +invalid_query1_server_answer = ( + '{{"type":"next","id":"{query_id}",' + '"payload":{{"errors":[' + '{{"message":"Cannot query field \\"bloh\\" on type \\"Continent\\".",' + '"locations":[{{"line":4,"column":5}}],' + '"extensions":{{"code":"INTERNAL_SERVER_ERROR"}}}}]}}}}' +) + +invalid_query1_server = [invalid_query1_server_answer] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [invalid_query1_server], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_aiohttp_websocket_graphqlws_invalid_query( + event_loop, client_and_aiohttp_websocket_graphql_server, query_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + await session.execute(query) + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" + + +invalid_subscription_str = """ + subscription getContinents { + continents { + code + bloh + } + } +""" + + +async def server_invalid_subscription(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(invalid_query1_server_answer.format(query_id=1)) + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_invalid_subscription], indirect=True +) +@pytest.mark.parametrize("query_str", [invalid_subscription_str]) +async def test_aiohttp_websocket_graphqlws_invalid_subscription( + event_loop, client_and_aiohttp_websocket_graphql_server, query_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + async for result in session.subscribe(query): + pass + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" + + +async def server_no_ack(ws, path): + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_no_ack], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_aiohttp_websocket_graphqlws_server_does_not_send_ack( + event_loop, graphqlws_server, query_str +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + + transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=1) + + with pytest.raises(asyncio.TimeoutError): + async with Client(transport=transport): + pass + + +invalid_query_server_answer = ( + '{"id":"1","type":"error","payload":[{"message":"Cannot query field ' + '\\"helo\\" on type \\"Query\\". Did you mean \\"hello\\"?",' + '"locations":[{"line":2,"column":3}]}]}' +) + + +async def server_invalid_query(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(invalid_query_server_answer) + await WebSocketServerHelper.wait_connection_terminate(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_invalid_query], indirect=True) +async def test_aiohttp_websocket_graphqlws_sending_invalid_query( + event_loop, client_and_aiohttp_websocket_graphql_server +): + + session, server = client_and_aiohttp_websocket_graphql_server + + query = gql("{helo}") + + with pytest.raises(TransportQueryError) as exc_info: + await session.execute(query) + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert ( + error["message"] + == 'Cannot query field "helo" on type "Query". Did you mean "hello"?' + ) + + +not_json_answer = ["BLAHBLAH"] +missing_type_answer = ["{}"] +missing_id_answer_1 = ['{"type": "next"}'] +missing_id_answer_2 = ['{"type": "error"}'] +missing_id_answer_3 = ['{"type": "complete"}'] +data_without_payload = ['{"type": "next", "id":"1"}'] +error_without_payload = ['{"type": "error", "id":"1"}'] +error_with_payload_not_a_list = ['{"type": "error", "id":"1", "payload": "NOT A LIST"}'] +payload_is_not_a_dict = ['{"type": "next", "id":"1", "payload": "BLAH"}'] +empty_payload = ['{"type": "next", "id":"1", "payload": {}}'] +sending_bytes = [b"\x01\x02\x03"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", + [ + not_json_answer, + missing_type_answer, + missing_id_answer_1, + missing_id_answer_2, + missing_id_answer_3, + data_without_payload, + error_without_payload, + payload_is_not_a_dict, + error_with_payload_not_a_list, + empty_payload, + sending_bytes, + ], + indirect=True, +) +async def test_aiohttp_websocket_graphqlws_transport_protocol_errors( + event_loop, client_and_aiohttp_websocket_graphql_server +): + + session, server = client_and_aiohttp_websocket_graphql_server + + query = gql("query { hello }") + + with pytest.raises((TransportProtocolError, TransportQueryError)): + await session.execute(query) + + +async def server_without_ack(ws, path): + # Sending something else than an ack + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_without_ack], indirect=True) +async def test_aiohttp_websocket_graphqlws_server_does_not_ack( + event_loop, graphqlws_server +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + with pytest.raises(TransportProtocolError): + async with Client(transport=transport): + pass + + +async def server_closing_directly(ws, path): + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_closing_directly], indirect=True) +async def test_aiohttp_websocket_graphqlws_server_closing_directly( + event_loop, graphqlws_server +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + with pytest.raises(ConnectionResetError): + async with Client(transport=transport): + pass + + +async def server_closing_after_ack(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_closing_after_ack], indirect=True) +async def test_aiohttp_websocket_graphqlws_server_closing_after_ack( + event_loop, client_and_aiohttp_websocket_graphql_server +): + + session, _ = client_and_aiohttp_websocket_graphql_server + + query = gql("query { hello }") + + with pytest.raises(TransportClosed): + await session.execute(query) + + await session.transport.wait_closed() + + with pytest.raises(TransportClosed): + await session.execute(query) diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py new file mode 100644 index 00000000..e5db7ca1 --- /dev/null +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -0,0 +1,879 @@ +import asyncio +import json +import sys +import warnings +from typing import List + +import pytest +from parse import search + +from gql import Client, gql +from gql.transport.exceptions import TransportServerError + +from .conftest import MS, WebSocketServerHelper + +# Marking all tests in this file with the aiohttp AND websockets marker +pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] + +countdown_server_answer = ( + '{{"type":"next","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' +) + +COUNTING_DELAY = 20 * MS +PING_SENDING_DELAY = 50 * MS +PONG_TIMEOUT = 100 * MS + +# List which can used to store received messages by the server +logged_messages: List[str] = [] + + +def server_countdown_factory( + keepalive=False, answer_pings=True, simulate_disconnect=False +): + async def server_countdown_template(ws, path): + import websockets + + logged_messages.clear() + + try: + await WebSocketServerHelper.send_connection_ack( + ws, payload="dummy_connection_ack_payload" + ) + + result = await ws.recv() + logged_messages.append(result) + + json_result = json.loads(result) + assert json_result["type"] == "subscribe" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["id"] + + count_found = search("count: {:d}", query) + count = count_found[0] + print(f" Server: Countdown started from: {count}") + + if simulate_disconnect and count == 8: + await ws.close() + + pong_received: asyncio.Event = asyncio.Event() + + async def counting_coro(): + print(" Server: counting task started") + try: + for number in range(count, -1, -1): + await ws.send( + countdown_server_answer.format( + query_id=query_id, number=number + ) + ) + await asyncio.sleep(COUNTING_DELAY) + finally: + print(" Server: counting task ended") + + print(" Server: starting counting task") + counting_task = asyncio.ensure_future(counting_coro()) + + async def keepalive_coro(): + print(" Server: keepalive task started") + try: + while True: + await asyncio.sleep(PING_SENDING_DELAY) + try: + # Send a ping + await WebSocketServerHelper.send_ping( + ws, payload="dummy_ping_payload" + ) + + # Wait for a pong + try: + await asyncio.wait_for( + pong_received.wait(), PONG_TIMEOUT + ) + except asyncio.TimeoutError: + print( + "\n Server: No pong received in time!\n" + ) + break + + pong_received.clear() + + except websockets.exceptions.ConnectionClosed: + break + finally: + print(" Server: keepalive task ended") + + if keepalive: + print(" Server: starting keepalive task") + keepalive_task = asyncio.ensure_future(keepalive_coro()) + + async def receiving_coro(): + print(" Server: receiving task started") + try: + nonlocal counting_task + while True: + + try: + result = await ws.recv() + logged_messages.append(result) + except websockets.exceptions.ConnectionClosed: + break + + json_result = json.loads(result) + + answer_type = json_result["type"] + + if answer_type == "complete" and json_result["id"] == str( + query_id + ): + print("Cancelling counting task now") + counting_task.cancel() + if keepalive: + print("Cancelling keep alive task now") + keepalive_task.cancel() + + elif answer_type == "ping": + if answer_pings: + payload = json_result.get("payload", None) + await WebSocketServerHelper.send_pong( + ws, payload=payload + ) + + elif answer_type == "pong": + pong_received.set() + finally: + print(" Server: receiving task ended") + if keepalive: + keepalive_task.cancel() + + print(" Server: starting receiving task") + receiving_task = asyncio.ensure_future(receiving_coro()) + + try: + print(" Server: waiting for counting task to complete") + await counting_task + except asyncio.CancelledError: + print(" Server: Now counting task is cancelled") + + print(" Server: sending complete message") + await WebSocketServerHelper.send_complete(ws, query_id) + + if keepalive: + print(" Server: cancelling keepalive task") + keepalive_task.cancel() + try: + await keepalive_task + except asyncio.CancelledError: + print(" Server: Now keepalive task is cancelled") + + print(" Server: waiting for client to close the connection") + try: + await asyncio.wait_for(receiving_task, 1000 * MS) + except asyncio.TimeoutError: + pass + + print(" Server: cancelling receiving task") + receiving_task.cancel() + + try: + await receiving_task + except asyncio.CancelledError: + print(" Server: Now receiving task is cancelled") + + except websockets.exceptions.ConnectionClosedOK: + pass + except AssertionError as e: + print(f"\n Server: Assertion failed: {e!s}\n") + finally: + print(" Server: waiting for websocket connection to close") + await ws.wait_closed() + print(" Server: connection closed") + + return server_countdown_template + + +async def server_countdown(ws, path): + + server = server_countdown_factory() + await server(ws, path) + + +async def server_countdown_keepalive(ws, path): + + server = server_countdown_factory(keepalive=True) + await server(ws, path) + + +async def server_countdown_dont_answer_pings(ws, path): + + server = server_countdown_factory(answer_pings=False) + await server(ws, path) + + +async def server_countdown_disconnect(ws, path): + + server = server_countdown_factory(simulate_disconnect=True) + await server(ws, path) + + +countdown_subscription_str = """ + subscription {{ + countdown (count: {count}) {{ + number + }} + }} +""" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_break( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + if count <= 5: + # Note: the following line is only necessary for pypy3 v3.6.1 + if sys.version_info < (3, 7): + await session._generator.aclose() + break + + count -= 1 + + assert count == 5 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_task_cancel( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async def task_coro(): + nonlocal count + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + task = asyncio.ensure_future(task_coro()) + + async def cancel_task_coro(): + nonlocal task + + await asyncio.sleep(5.5 * COUNTING_DELAY) + + task.cancel() + + cancel_task = asyncio.ensure_future(cancel_task_coro()) + + await asyncio.gather(task, cancel_task) + + assert count > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_close_transport( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async def task_coro(): + nonlocal count + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + task = asyncio.ensure_future(task_coro()) + + async def close_transport_task_coro(): + nonlocal task + + await asyncio.sleep(5.5 * COUNTING_DELAY) + + await session.transport.close() + + close_transport_task = asyncio.ensure_future(close_transport_task_coro()) + + await asyncio.gather(task, close_transport_task) + + assert count > 0 + + +async def server_countdown_close_connection_in_middle(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + + result = await ws.recv() + json_result = json.loads(result) + assert json_result["type"] == "subscribe" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["id"] + + count_found = search("count: {:d}", query) + count = count_found[0] + stopping_before = count // 2 + print(f"Countdown started from: {count}, stopping server before {stopping_before}") + for number in range(count, stopping_before, -1): + await ws.send(countdown_server_answer.format(query_id=query_id, number=number)) + await asyncio.sleep(COUNTING_DELAY) + + print("Closing server while subscription is still running now") + await ws.close() + await ws.wait_closed() + print("Server is now closed") + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_close_connection_in_middle], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_server_connection_closed( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + session, _ = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(ConnectionResetError): + async for result in session.subscribe(subscription): + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_with_operation_name( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe( + subscription, operation_name="CountdownSubscription" + ): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + # Check that the query contains the operationName + assert '"operationName": "CountdownSubscription"' in logged_messages[0] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + assert "ping" in session.transport.payloads + assert session.transport.payloads["ping"] == "dummy_ping_payload" + assert ( + session.transport.payloads["connection_ack"] == "dummy_connection_ack_payload" + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive_with_timeout_ok( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport( + url=url, keep_alive_timeout=(5 * COUNTING_DELAY) + ) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive_with_timeout_nok( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport( + url=url, keep_alive_timeout=(COUNTING_DELAY / 2) + ) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + with pytest.raises(TransportServerError) as exc_info: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert "No keep-alive message has been received" in str(exc_info.value) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_with_ping_interval_ok( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport( + url=url, + ping_interval=(10 * COUNTING_DELAY), + pong_timeout=(8 * COUNTING_DELAY), + ) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_dont_answer_pings], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_with_ping_interval_nok( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url, ping_interval=(5 * COUNTING_DELAY)) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + with pytest.raises(TransportServerError) as exc_info: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert "No pong received" in str(exc_info.value) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_manual_pings_with_payload( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + payload = {"count_received": count} + + await transport.send_ping(payload=payload) + + await asyncio.wait_for(transport.pong_received.wait(), 10000 * MS) + + transport.pong_received.clear() + + assert transport.payloads["pong"] == payload + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_manual_pong_with_payload( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url, answer_pings=False) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + + async def answer_ping_coro(): + while True: + await transport.ping_received.wait() + transport.ping_received.clear() + await transport.send_pong(payload={"some": "data"}) + + answer_ping_task = asyncio.ensure_future(answer_ping_coro()) + + try: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + finally: + answer_ping_task.cancel() + + assert count == -1 + + +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_aiohttp_websocket_graphqlws_subscription_sync( + graphqlws_server, subscription_str +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.skipif(sys.platform.startswith("win"), reason="test failing on windows") +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_aiohttp_websocket_graphqlws_subscription_sync_graceful_shutdown( + graphqlws_server, subscription_str +): + """Note: this test will simulate a control-C happening while a sync subscription + is in progress. To do that we will throw a KeyboardInterrupt exception inside + the subscription async generator. + + The code should then do a clean close: + - send stop messages for each active query + - send a connection_terminate message + Then the KeyboardInterrupt will be reraise (to warn potential user code) + + This test does not work on Windows but the behaviour with Windows is correct. + """ + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(KeyboardInterrupt): + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + if count == 5: + + # Simulate a KeyboardInterrupt in the generator + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + asyncio.ensure_future( + client.session._generator.athrow(KeyboardInterrupt) + ) + + count -= 1 + + assert count == 4 + + # Check that the server received a connection_terminate message last + # assert logged_messages.pop() == '{"type": "connection_terminate"}' + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_running_in_thread( + event_loop, graphqlws_server, subscription_str, run_sync_test +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + def test_code(): + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + await run_sync_test(event_loop, graphqlws_server, test_code) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_disconnect], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +@pytest.mark.parametrize("execute_instead_of_subscribe", [False, True]) +async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( + event_loop, graphqlws_server, subscription_str, execute_instead_of_subscribe +): + + from gql.transport.exceptions import TransportClosed + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 8 + subscription_with_disconnect = gql(subscription_str.format(count=count)) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + session = await client.connect_async( + reconnecting=True, retry_connect=False, retry_execute=False + ) + + # First we make a subscription which will cause a disconnect in the backend + # (count=8) + try: + print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") + async for result in session.subscribe(subscription_with_disconnect): + pass + except ConnectionResetError: + pass + + await asyncio.sleep(50 * MS) + + # Then with the same session handle, we make a subscription or an execute + # which will detect that the transport is closed so that the client could + # try to reconnect + try: + if execute_instead_of_subscribe: + print("\nEXECUTION_2\n") + await session.execute(subscription) + else: + print("\nSUBSCRIPTION_2\n") + async for result in session.subscribe(subscription): + pass + except TransportClosed: + pass + + await asyncio.sleep(50 * MS) + + # And finally with the same session handle, we make a subscription + # which works correctly + print("\nSUBSCRIPTION_3\n") + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + await client.close_async() diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py new file mode 100644 index 00000000..f154386b --- /dev/null +++ b/tests/test_aiohttp_websocket_query.py @@ -0,0 +1,707 @@ +import asyncio +import json +import ssl +import sys +from typing import Dict, Mapping + +import pytest + +from gql import Client, gql +from gql.transport.exceptions import ( + TransportAlreadyConnected, + TransportClosed, + TransportQueryError, + TransportServerError, +) + +from .conftest import MS, WebSocketServerHelper + +# Marking all tests in this file with the aiohttp AND websockets marker +pytestmark = pytest.mark.aiohttp + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +query1_server_answer_data = ( + '{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}' +) + +query1_server_answer = ( + '{{"type":"data","id":"{query_id}","payload":{{"data":{{"continents":[' + '{{"code":"AF","name":"Africa"}},{{"code":"AN","name":"Antarctica"}},' + '{{"code":"AS","name":"Asia"}},{{"code":"EU","name":"Europe"}},' + '{{"code":"NA","name":"North America"}},{{"code":"OC","name":"Oceania"}},' + '{{"code":"SA","name":"South America"}}]}}}}}}' +) + +server1_answers = [ + query1_server_answer, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_starting_client_in_context_manager( + event_loop, aiohttp_ws_server +): + + server = aiohttp_ws_server + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url, websocket_close_timeout=10) + + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + result = await session.execute(query1) + + print("Client received:", result) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + # Check client is disconnect here + assert transport.websocket is None + + +@pytest.mark.asyncio +@pytest.mark.websockets +@pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) +@pytest.mark.parametrize("ssl_close_timeout", [0, 10]) +async def test_aiohttp_websocket_using_ssl_connection( + event_loop, ws_ssl_server, ssl_close_timeout +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + server = ws_ssl_server + + url = f"wss://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.load_verify_locations(ws_ssl_server.testcert) + + transport = AIOHTTPWebsocketsTransport( + url=url, ssl=ssl_context, ssl_close_timeout=ssl_close_timeout + ) + + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + result = await session.execute(query1) + + print("Client received:", result) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Check client is disconnect here + assert transport.websocket is None + + +@pytest.mark.asyncio +@pytest.mark.websockets +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_simple_query( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + result = await session.execute(query) + + print("Client received:", result) + + +server1_two_answers_in_series = [ + query1_server_answer, + query1_server_answer, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "aiohttp_ws_server", [server1_two_answers_in_series], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_two_queries_in_series( + event_loop, aiohttp_client_and_aiohttp_ws_server, query_str +): + + session, server = aiohttp_client_and_aiohttp_ws_server + + query = gql(query_str) + + result1 = await session.execute(query) + + print("Query1 received:", result1) + + result2 = await session.execute(query) + + print("Query2 received:", result2) + + assert result1 == result2 + + +async def server1_two_queries_in_parallel(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}", file=sys.stderr) + result = await ws.recv() + print(f"Server received: {result}", file=sys.stderr) + await ws.send(query1_server_answer.format(query_id=1)) + await ws.send(query1_server_answer.format(query_id=2)) + await WebSocketServerHelper.send_complete(ws, 1) + await WebSocketServerHelper.send_complete(ws, 2) + await WebSocketServerHelper.wait_connection_terminate(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.websockets +@pytest.mark.parametrize("server", [server1_two_queries_in_parallel], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_two_queries_in_parallel( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + result1 = None + result2 = None + + async def task1_coro(): + nonlocal result1 + result1 = await session.execute(query) + + async def task2_coro(): + nonlocal result2 + result2 = await session.execute(query) + + task1 = asyncio.ensure_future(task1_coro()) + task2 = asyncio.ensure_future(task2_coro()) + + await asyncio.gather(task1, task2) + + print("Query1 received:", result1) + print("Query2 received:", result2) + + assert result1 == result2 + + +async def server_closing_while_we_are_doing_something_else(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}", file=sys.stderr) + await ws.send(query1_server_answer.format(query_id=1)) + await WebSocketServerHelper.send_complete(ws, 1) + await asyncio.sleep(1 * MS) + + # Closing server after first query + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.websockets +@pytest.mark.parametrize( + "server", [server_closing_while_we_are_doing_something_else], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_server_closing_after_first_query( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + # First query is working + await session.execute(query) + + # Then we do other things + await asyncio.sleep(1000 * MS) + + # Now the server is closed but we don't know it yet, we have to send a query + # to notice it and to receive the exception + with pytest.raises(TransportClosed): + await session.execute(query) + + +ignore_invalid_id_answers = [ + query1_server_answer, + '{"type":"complete","id": "55"}', + query1_server_answer, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "aiohttp_ws_server", [ignore_invalid_id_answers], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_ignore_invalid_id( + event_loop, aiohttp_client_and_aiohttp_ws_server, query_str +): + + session, server = aiohttp_client_and_aiohttp_ws_server + + query = gql(query_str) + + # First query is working + await session.execute(query) + + # Second query gets no answer -> raises + with pytest.raises(TransportQueryError): + await session.execute(query) + + # Third query is working + await session.execute(query) + + +async def assert_client_is_working(session): + query1 = gql(query1_str) + + result = await session.execute(query1) + + print("Client received:", result) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_multiple_connections_in_series( + event_loop, aiohttp_ws_server +): + + server = aiohttp_ws_server + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + async with Client(transport=transport) as session: + await assert_client_is_working(session) + + # Check client is disconnect here + assert transport.websocket is None + + async with Client(transport=transport) as session: + await assert_client_is_working(session) + + # Check client is disconnect here + assert transport.websocket is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_multiple_connections_in_parallel( + event_loop, aiohttp_ws_server +): + + server = aiohttp_ws_server + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + async def task_coro(): + transport = AIOHTTPWebsocketsTransport(url=url) + async with Client(transport=transport) as session: + await assert_client_is_working(session) + + task1 = asyncio.ensure_future(task_coro()) + task2 = asyncio.ensure_future(task_coro()) + + await asyncio.gather(task1, task2) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_trying_to_connect_to_already_connected_transport( + event_loop, aiohttp_ws_server +): + server = aiohttp_ws_server + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + async with Client(transport=transport) as session: + await assert_client_is_working(session) + + with pytest.raises(TransportAlreadyConnected): + async with Client(transport=transport): + pass + + +async def server_with_authentication_in_connection_init_payload(ws, path): + # Wait the connection_init message + init_message_str = await ws.recv() + init_message = json.loads(init_message_str) + payload = init_message["payload"] + + if "Authorization" in payload: + if payload["Authorization"] == 12345: + await ws.send('{"type":"connection_ack"}') + + result = await ws.recv() + print(f"Server received: {result}", file=sys.stderr) + await ws.send(query1_server_answer.format(query_id=1)) + await WebSocketServerHelper.send_complete(ws, 1) + else: + await ws.send( + '{"type":"connection_error", "payload": "Invalid Authorization token"}' + ) + else: + await ws.send( + '{"type":"connection_error", "payload": "No Authorization token"}' + ) + + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.websockets +@pytest.mark.parametrize( + "server", [server_with_authentication_in_connection_init_payload], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_connect_success_with_authentication_in_connection_init( + event_loop, server, query_str +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + init_payload = {"Authorization": 12345} + + transport = AIOHTTPWebsocketsTransport(url=url, init_payload=init_payload) + + async with Client(transport=transport) as session: + + query1 = gql(query_str) + + result = await session.execute(query1) + + print("Client received:", result) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.asyncio +@pytest.mark.websockets +@pytest.mark.parametrize( + "server", [server_with_authentication_in_connection_init_payload], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +@pytest.mark.parametrize("init_payload", [{}, {"Authorization": "invalid_code"}]) +async def test_aiohttp_websocket_connect_failed_with_authentication_in_connection_init( + event_loop, server, query_str, init_payload +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url, init_payload=init_payload) + + for _ in range(2): + with pytest.raises(TransportServerError): + async with Client(transport=transport) as session: + query1 = gql(query_str) + + await session.execute(query1) + + assert transport.session is None + assert transport.websocket is None + + +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +def test_aiohttp_websocket_execute_sync(aiohttp_ws_server): + server = aiohttp_ws_server + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + query1 = gql(query1_str) + + result = client.execute(query1) + + print("Client received:", result) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Execute sync a second time + result = client.execute(query1) + + print("Client received:", result) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Check client is disconnect here + assert transport.websocket is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_add_extra_parameters_to_connect( + event_loop, aiohttp_ws_server +): + + server = aiohttp_ws_server + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + + # Increase max payload size + transport = AIOHTTPWebsocketsTransport( + url=url, + connect_args={ + "max_msg_size": 2**21, + }, + ) + + query = gql(query1_str) + + async with Client(transport=transport) as session: + await session.execute(query) + + +async def server_sending_keep_alive_before_connection_ack(ws, path): + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}", file=sys.stderr) + await ws.send(query1_server_answer.format(query_id=1)) + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.websockets +@pytest.mark.parametrize( + "server", [server_sending_keep_alive_before_connection_ack], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_non_regression_bug_108( + event_loop, aiohttp_client_and_server, query_str +): + + # This test will check that we now ignore keepalive message + # arriving before the connection_ack + # See bug #108 + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + result = await session.execute(query) + + print("Client received:", result) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +@pytest.mark.parametrize("transport_arg", [[], ["--transport=aiohttp_websockets"]]) +async def test_aiohttp_websocket_using_cli( + event_loop, aiohttp_ws_server, transport_arg, monkeypatch, capsys +): + + """ + Note: depending on the transport_arg parameter, if there is no transport argument, + then we will use WebsocketsTransport if the websockets dependency is installed, + or AIOHTTPWebsocketsTransport if that is not the case. + """ + + server = aiohttp_ws_server + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + import io + import json + + from gql.cli import get_parser, main + + parser = get_parser(with_examples=True) + args = parser.parse_args([url, *transport_arg]) + + # Monkeypatching sys.stdin to simulate getting the query + # via the standard input + monkeypatch.setattr("sys.stdin", io.StringIO(query1_str)) + + # Flush captured output + captured = capsys.readouterr() + + exit_code = await main(args) + + assert exit_code == 0 + + # Check that the result has been printed on stdout + captured = capsys.readouterr() + captured_out = str(captured.out).strip() + + expected_answer = json.loads(query1_server_answer_data) + print(f"Captured: {captured_out}") + received_answer = json.loads(captured_out) + + assert received_answer == expected_answer + + +query1_server_answer_with_extensions = ( + '{{"type":"data","id":"{query_id}","payload":{{"data":{{"continents":[' + '{{"code":"AF","name":"Africa"}},{{"code":"AN","name":"Antarctica"}},' + '{{"code":"AS","name":"Asia"}},{{"code":"EU","name":"Europe"}},' + '{{"code":"NA","name":"North America"}},{{"code":"OC","name":"Oceania"}},' + '{{"code":"SA","name":"South America"}}]}},' + '"extensions": {{"key1": "val1"}}}}}}' +) + +server1_answers_with_extensions = [ + query1_server_answer_with_extensions, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "aiohttp_ws_server", [server1_answers_with_extensions], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_simple_query_with_extensions( + event_loop, aiohttp_client_and_aiohttp_ws_server, query_str +): + + session, server = aiohttp_client_and_aiohttp_ws_server + + query = gql(query_str) + + execution_result = await session.execute(query, get_execution_result=True) + + assert execution_result.extensions["key1"] == "val1" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_connector_owner_false(event_loop, aiohttp_ws_server): + + server = aiohttp_ws_server + + from aiohttp import TCPConnector + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + connector = TCPConnector() + transport = AIOHTTPWebsocketsTransport( + url=url, + client_session_args={ + "connector": connector, + "connector_owner": False, + }, + ) + + for _ in range(2): + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + result = await session.execute(query1) + + print("Client received:", result) + + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Check client is disconnect here + assert transport.websocket is None + + await connector.close() diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py new file mode 100644 index 00000000..3ebf4dbc --- /dev/null +++ b/tests/test_aiohttp_websocket_subscription.py @@ -0,0 +1,809 @@ +import asyncio +import json +import sys +import warnings +from typing import List + +import pytest +from graphql import ExecutionResult +from parse import search + +from gql import Client, gql +from gql.transport.exceptions import TransportClosed, TransportServerError + +from .conftest import MS, WebSocketServerHelper +from .starwars.schema import StarWarsIntrospection, StarWarsSchema, StarWarsTypeDef + +# Marking all tests in this file with the aiohttp AND websockets marker +pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] + +starwars_expected_one = { + "stars": 3, + "commentary": "Was expecting more stuff", + "episode": "JEDI", +} + +starwars_expected_two = { + "stars": 5, + "commentary": "This is a great movie!", + "episode": "JEDI", +} + + +async def server_starwars(ws, path): + import websockets + + await WebSocketServerHelper.send_connection_ack(ws) + + try: + await ws.recv() + + reviews = [starwars_expected_one, starwars_expected_two] + + for review in reviews: + + data = ( + '{"type":"data","id":"1","payload":{"data":{"reviewAdded": ' + + json.dumps(review) + + "}}}" + ) + await ws.send(data) + await asyncio.sleep(2 * MS) + + await WebSocketServerHelper.send_complete(ws, 1) + await WebSocketServerHelper.wait_connection_terminate(ws) + + except websockets.exceptions.ConnectionClosedOK: + pass + + print("Server is now closed") + + +starwars_subscription_str = """ + subscription ListenEpisodeReviews($ep: Episode!) { + reviewAdded(episode: $ep) { + stars, + commentary, + episode + } + } +""" + +starwars_invalid_subscription_str = """ + subscription ListenEpisodeReviews($ep: Episode!) { + reviewAdded(episode: $ep) { + not_valid_field, + stars, + commentary, + episode + } + } +""" + +countdown_server_answer = ( + '{{"type":"data","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' +) + +WITH_KEEPALIVE = False + + +# List which can used to store received messages by the server +logged_messages: List[str] = [] + + +async def server_countdown(ws, path): + import websockets + + logged_messages.clear() + + global WITH_KEEPALIVE + try: + await WebSocketServerHelper.send_connection_ack(ws) + if WITH_KEEPALIVE: + await WebSocketServerHelper.send_keepalive(ws) + + result = await ws.recv() + logged_messages.append(result) + + json_result = json.loads(result) + assert json_result["type"] == "start" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["id"] + + count_found = search("count: {:d}", query) + count = count_found[0] + print(f"Countdown started from: {count}") + + async def counting_coro(): + for number in range(count, -1, -1): + await ws.send( + countdown_server_answer.format(query_id=query_id, number=number) + ) + await asyncio.sleep(2 * MS) + + counting_task = asyncio.ensure_future(counting_coro()) + + async def stopping_coro(): + nonlocal counting_task + while True: + + try: + result = await ws.recv() + logged_messages.append(result) + except websockets.exceptions.ConnectionClosed: + break + + json_result = json.loads(result) + + if json_result["type"] == "stop" and json_result["id"] == str(query_id): + print("Cancelling counting task now") + counting_task.cancel() + + async def keepalive_coro(): + while True: + await asyncio.sleep(5 * MS) + try: + await WebSocketServerHelper.send_keepalive(ws) + except websockets.exceptions.ConnectionClosed: + break + + stopping_task = asyncio.ensure_future(stopping_coro()) + if WITH_KEEPALIVE: + keepalive_task = asyncio.ensure_future(keepalive_coro()) + + try: + await counting_task + except asyncio.CancelledError: + print("Now counting task is cancelled") + except Exception as exc: + print(f"Exception in counting task: {exc!s}") + + stopping_task.cancel() + + try: + await stopping_task + except asyncio.CancelledError: + print("Now stopping task is cancelled") + + if WITH_KEEPALIVE: + keepalive_task.cancel() + try: + await keepalive_task + except asyncio.CancelledError: + print("Now keepalive task is cancelled") + + await WebSocketServerHelper.send_complete(ws, query_id) + await WebSocketServerHelper.wait_connection_terminate(ws) + except websockets.exceptions.ConnectionClosedOK: + pass + finally: + await ws.wait_closed() + + +countdown_subscription_str = """ + subscription {{ + countdown (count: {count}) {{ + number + }} + }} +""" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_get_execution_result( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription, get_execution_result=True): + + assert isinstance(result, ExecutionResult) + + number = result.data["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_break( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + if count <= 5: + # Note: the following line is only necessary for pypy3 v3.6.1 + if sys.version_info < (3, 7): + await session._generator.aclose() + break + + count -= 1 + + assert count == 5 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_task_cancel( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async def task_coro(): + nonlocal count + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + task = asyncio.ensure_future(task_coro()) + + async def cancel_task_coro(): + nonlocal task + + await asyncio.sleep(11 * MS) + + task.cancel() + + cancel_task = asyncio.ensure_future(cancel_task_coro()) + + await asyncio.gather(task, cancel_task) + + assert count > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_close_transport( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, _ = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async def task_coro(): + nonlocal count + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + task = asyncio.ensure_future(task_coro()) + + async def close_transport_task_coro(): + nonlocal task + + await asyncio.sleep(11 * MS) + + await session.transport.close() + + close_transport_task = asyncio.ensure_future(close_transport_task_coro()) + + await asyncio.gather(task, close_transport_task) + + assert count > 0 + + +async def server_countdown_close_connection_in_middle(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + + result = await ws.recv() + json_result = json.loads(result) + assert json_result["type"] == "start" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["id"] + + count_found = search("count: {:d}", query) + count = count_found[0] + stopping_before = count // 2 + print(f"Countdown started from: {count}, stopping server before {stopping_before}") + for number in range(count, stopping_before, -1): + await ws.send(countdown_server_answer.format(query_id=query_id, number=number)) + await asyncio.sleep(2 * MS) + + print("Closing server while subscription is still running now") + await ws.close() + await ws.wait_closed() + print("Server is now closed") + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [server_countdown_close_connection_in_middle], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_server_connection_closed( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(ConnectionResetError): + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_slow_consumer( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + await asyncio.sleep(10 * MS) + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_with_operation_name( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe( + subscription, operation_name="CountdownSubscription" + ): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + # Check that the query contains the operationName + assert '"operationName": "CountdownSubscription"' in logged_messages[0] + + +WITH_KEEPALIVE = True + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_with_keepalive( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_with_keepalive_with_timeout_ok( + event_loop, server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url, keep_alive_timeout=(20 * MS)) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_with_keepalive_with_timeout_nok( + event_loop, server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url, keep_alive_timeout=(1 * MS)) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + with pytest.raises(TransportServerError) as exc_info: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert "No keep-alive message has been received" in str(exc_info.value) + + +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_aiohttp_websocket_subscription_sync(server, subscription_str): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_aiohttp_websocket_subscription_sync_user_exception(server, subscription_str): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(Exception) as exc_info: + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + if count == 5: + raise Exception("This is an user exception") + + assert count == 5 + assert "This is an user exception" in str(exc_info.value) + + +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_aiohttp_websocket_subscription_sync_break(server, subscription_str): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + if count == 5: + break + + assert count == 5 + + +@pytest.mark.skipif(sys.platform.startswith("win"), reason="test failing on windows") +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_aiohttp_websocket_subscription_sync_graceful_shutdown( + server, subscription_str +): + """Note: this test will simulate a control-C happening while a sync subscription + is in progress. To do that we will throw a KeyboardInterrupt exception inside + the subscription async generator. + + The code should then do a clean close: + - send stop messages for each active query + - send a connection_terminate message + Then the KeyboardInterrupt will be reraise (to warn potential user code) + + This test does not work on Windows but the behaviour with Windows is correct. + """ + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + interrupt_task = None + + with pytest.raises(KeyboardInterrupt): + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + if count == 5: + + # Simulate a KeyboardInterrupt in the generator + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + interrupt_task = asyncio.ensure_future( + client.session._generator.athrow(KeyboardInterrupt) + ) + + count -= 1 + + assert count == 4 + + # Catch interrupt_task exception to remove warning + interrupt_task.exception() + + # Check that the server received a connection_terminate message last + assert logged_messages.pop() == '{"type": "connection_terminate"}' + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_running_in_thread( + event_loop, server, subscription_str, run_sync_test +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + def test_code(): + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_starwars], indirect=True) +@pytest.mark.parametrize("subscription_str", [starwars_subscription_str]) +@pytest.mark.parametrize( + "client_params", + [ + {"schema": StarWarsSchema}, + {"introspection": StarWarsIntrospection}, + {"schema": StarWarsTypeDef}, + ], +) +async def test_async_aiohttp_client_validation( + event_loop, server, subscription_str, client_params +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport, **client_params) + + async with client as session: + + variable_values = {"ep": "JEDI"} + + subscription = gql(subscription_str) + + expected = [] + + async for result in session.subscribe( + subscription, variable_values=variable_values, parse_result=False + ): + + review = result["reviewAdded"] + expected.append(review) + + assert "stars" in review + assert "commentary" in review + assert "episode" in review + + assert expected[0] == starwars_expected_one + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_subscribe_on_closing_transport(event_loop, server, subscription_str): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + count = 1 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + session.transport.websocket._writer._closing = True + + with pytest.raises(ConnectionResetError) as e: + async for _ in session.subscribe(subscription): + pass + + assert e.value.args[0] == "Cannot write to closing transport" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_subscribe_on_null_transport(event_loop, server, subscription_str): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + count = 1 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + + session.transport.websocket = None + with pytest.raises(TransportClosed) as e: + async for _ in session.subscribe(subscription): + pass + + assert e.value.args[0] == "WebSocket connection is closed"