diff --git a/src/easynetwork/api_async/server/tcp.py b/src/easynetwork/api_async/server/tcp.py index 334d115b..6c793394 100644 --- a/src/easynetwork/api_async/server/tcp.py +++ b/src/easynetwork/api_async/server/tcp.py @@ -364,7 +364,7 @@ async def __listener_accept(self, listener: AsyncListenerSocketAdapter, task_gro _errno.errorcode[exc.errno], os.strerror(exc.errno), ACCEPT_CAPACITY_ERROR_SLEEP_TIME, - exc_info=True, + exc_info=exc, ) await backend.sleep(ACCEPT_CAPACITY_ERROR_SLEEP_TIME) else: @@ -428,7 +428,14 @@ async def __client_coroutine(self, accepted_socket: AcceptedSocket) -> None: assert inspect.isawaitable(_on_connection_hook) # nosec assert_used await _on_connection_hook del _on_connection_hook - client_exit_stack.push_async_callback(self.__request_handler.on_disconnection, client) + + async def disconnect_client() -> None: + try: + await self.__request_handler.on_disconnection(client) + except* ConnectionError: + self.__logger.warning("ConnectionError raised in request_handler.on_disconnection()") + + client_exit_stack.push_async_callback(disconnect_client) del client_exit_stack @@ -504,7 +511,7 @@ def __suppress_and_log_remaining_exception(self, client_address: SocketAddress | self.__logger.warning( "There have been attempts to do operation on closed client %s", client_address, - exc_info=True, + exc_info=excgrp, ) except* ConnectionError: # This exception come from the request handler ( most likely due to client.send_packet() ) @@ -515,9 +522,9 @@ def __suppress_and_log_remaining_exception(self, client_address: SocketAddress | _remove_traceback_frames_in_place(exc, 1) # Removes the 'yield' frame just above self.__logger.error("-" * 40) if client_address is None: - self.__logger.exception("Error in client task") + self.__logger.error("Error in client task", exc_info=exc) else: - self.__logger.exception("Exception occurred during processing of request from %s", client_address) + self.__logger.error("Exception occurred during processing of request from %s", client_address, exc_info=exc) self.__logger.error("-" * 40) def get_addresses(self) -> Sequence[SocketAddress]: diff --git a/src/easynetwork/api_async/server/udp.py b/src/easynetwork/api_async/server/udp.py index b9565ec5..d31b333f 100644 --- a/src/easynetwork/api_async/server/udp.py +++ b/src/easynetwork/api_async/server/udp.py @@ -371,12 +371,12 @@ def __suppress_and_log_remaining_exception(self, client_address: SocketAddress) self.__logger.warning( "There have been attempts to do operation on closed client %s", client_address, - exc_info=True, + exc_info=excgrp, ) except Exception as exc: _remove_traceback_frames_in_place(exc, 1) # Removes the 'yield' frame just above self.__logger.error("-" * 40) - self.__logger.exception("Exception occurred during processing of request from %s", client_address) + self.__logger.error("Exception occurred during processing of request from %s", client_address, exc_info=exc) self.__logger.error("-" * 40) @_contextlib.contextmanager diff --git a/tests/functional_test/test_communication/test_async/test_server/test_tcp.py b/tests/functional_test/test_communication/test_async/test_server/test_tcp.py index 277f4c59..740b737a 100644 --- a/tests/functional_test/test_communication/test_async/test_server/test_tcp.py +++ b/tests/functional_test/test_communication/test_async/test_server/test_tcp.py @@ -71,6 +71,7 @@ class MyAsyncTCPRequestHandler(AsyncStreamRequestHandler[str, str]): close_all_clients_on_connection: bool = False close_client_after_n_request: int = -1 server: AsyncTCPNetworkServer[str, str] + fail_on_disconnection: bool = False async def service_init(self, exit_stack: contextlib.AsyncExitStack, server: AsyncTCPNetworkServer[str, str]) -> None: await super().service_init(exit_stack, server) @@ -100,6 +101,8 @@ async def on_connection(self, client: AsyncStreamClient[str]) -> None: async def on_disconnection(self, client: AsyncStreamClient[str]) -> None: del self.connected_clients[client.address] del self.request_count[client.address] + if self.fail_on_disconnection: + raise ConnectionError("Trying to use the client in a disconnected state") async def handle(self, client: AsyncStreamClient[str]) -> AsyncGenerator[None, str]: if self.close_client_after_n_request >= 0 and self.request_count[client.address] >= self.close_client_after_n_request: @@ -361,6 +364,14 @@ async def factory() -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: yield factory + @staticmethod + async def _wait_client_disconnected(writer: asyncio.StreamWriter, request_handler: MyAsyncTCPRequestHandler) -> None: + writer.close() + await writer.wait_closed() + async with asyncio.timeout(1): + while request_handler.connected_clients: + await asyncio.sleep(0.1) + @pytest.mark.parametrize("host", [None, ""], ids=repr) @pytest.mark.parametrize("log_client_connection", [True, False], ids=lambda p: f"log_client_connection=={p}") @pytest.mark.parametrize("use_ssl", ["NO_SSL"], indirect=True) @@ -479,12 +490,7 @@ async def test____serve_forever____accept_client( assert request_handler.request_received[client_address] == ["hello, world."] - writer.close() - await writer.wait_closed() - - async with asyncio.timeout(1): - while client_address in request_handler.connected_clients: - await asyncio.sleep(0.1) + await self._wait_client_disconnected(writer, request_handler) # skip Windows for this test, the ECONNRESET will happen on socket.send() or socket.recv() @pytest.mark.xfail('sys.platform == "win32"', reason="socket.getpeername() works by some magic") @@ -641,11 +647,7 @@ async def test____serve_forever____connection_reset_error( enable_socket_linger(writer.get_extra_info("socket"), timeout=0) - writer.close() - await writer.wait_closed() - async with asyncio.timeout(1): - while request_handler.connected_clients: - await asyncio.sleep(0.1) + await self._wait_client_disconnected(writer, request_handler) # ECONNRESET not logged assert len(caplog.records) == 0 @@ -744,6 +746,24 @@ async def test____serve_forever____connection_error_in_request_handler( assert await reader.read() == b"" assert len(caplog.records) == 0 + async def test____serve_forever____connection_error_in_disconnect_hook( + self, + client_factory: Callable[[], Awaitable[tuple[asyncio.StreamReader, asyncio.StreamWriter]]], + request_handler: MyAsyncTCPRequestHandler, + caplog: pytest.LogCaptureFixture, + server: MyAsyncTCPServer, + ) -> None: + caplog.set_level(logging.WARNING, server.logger.name) + _, writer = await client_factory() + request_handler.fail_on_disconnection = True + + await self._wait_client_disconnected(writer, request_handler) + + # ECONNRESET not logged + assert len(caplog.records) == 1 + assert caplog.records[0].levelno == logging.WARNING + assert caplog.records[0].message == "ConnectionError raised in request_handler.on_disconnection()" + async def test____serve_forever____explicitly_closed_by_request_handler( self, client_factory: Callable[[], Awaitable[tuple[asyncio.StreamReader, asyncio.StreamWriter]]],