diff --git a/src/easynetwork/lowlevel/std_asyncio/_flow_control.py b/src/easynetwork/lowlevel/std_asyncio/_flow_control.py index ef15b429..77628a52 100644 --- a/src/easynetwork/lowlevel/std_asyncio/_flow_control.py +++ b/src/easynetwork/lowlevel/std_asyncio/_flow_control.py @@ -88,6 +88,8 @@ def resume_writing(self) -> None: waiter.set_result(None) def connection_lost(self, exc: Exception | None) -> None: + if self.__connection_lost: # Already called, bail out. + return self.__write_paused = False self.__connection_lost = True self.__connection_lost_exception = exc diff --git a/src/easynetwork/lowlevel/std_asyncio/stream/socket.py b/src/easynetwork/lowlevel/std_asyncio/stream/socket.py index 907857e8..8587b70c 100644 --- a/src/easynetwork/lowlevel/std_asyncio/stream/socket.py +++ b/src/easynetwork/lowlevel/std_asyncio/stream/socket.py @@ -285,14 +285,23 @@ def connection_made(self, transport: asyncio.Transport) -> None: # type: ignore self.__over_ssl = transport.get_extra_info("sslcontext") is not None def connection_lost(self, exc: Exception | None) -> None: + if self.__connection_lost: # Already called, bail out. + return + self.__connection_lost = True + self.__read_paused = False + + if exc is None and self.__buffer_nbytes_written > 0: + exc = _utils.error_from_errno(_errno.ECONNRESET) self.__connection_lost_exception = exc if exc is None: self.__eof_reached = True else: - self.__buffer_nbytes_written = 0 self.__connection_lost_exception_tb = exc.__traceback__ - self._maybe_release_buffer() + + self.__buffer_nbytes_written = 0 + self.__buffer = None + self.__buffer_view.release() if not self.__closed.done(): self.__closed.set_result(None) @@ -329,9 +338,8 @@ def eof_received(self) -> bool: return True async def receive_data(self, bufsize: int, /) -> bytes: - if self.__connection_lost: - if self.__connection_lost_exception is not None: - raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb) + if self.__connection_lost_exception is not None: + raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb) if bufsize == 0: return b"" if bufsize < 0: @@ -354,9 +362,8 @@ async def receive_data(self, bufsize: int, /) -> bytes: return data async def receive_data_into(self, buffer: WriteableBuffer, /) -> int: - if self.__connection_lost: - if self.__connection_lost_exception is not None: - raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb) + if self.__connection_lost_exception is not None: + raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb) with memoryview(buffer).cast("B") as buffer: if not buffer.nbytes: return 0 @@ -395,6 +402,8 @@ async def _wait_for_data(self, requester: str) -> None: await self.__read_waiter finally: self.__read_waiter = None + if self.__connection_lost_exception is not None: + raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb) def _wakeup_read_waiter(self, exc: Exception | None) -> None: if (waiter := self.__read_waiter) is not None: @@ -434,9 +443,7 @@ def _maybe_pause_transport(self) -> None: self.__read_paused = True def _maybe_resume_transport(self) -> None: - if self.__connection_lost: - self._maybe_release_buffer() - elif ( + if ( self.__read_paused and (transport := self.__transport) is not None and self.__buffer_nbytes_written <= self.__read_low_water @@ -444,11 +451,6 @@ def _maybe_resume_transport(self) -> None: transport.resume_reading() self.__read_paused = False - def _maybe_release_buffer(self) -> None: - if not self.__buffer_nbytes_written and self.__connection_lost: - self.__buffer_view.release() - self.__buffer = None - def pause_writing(self) -> None: self.__write_flow.pause_writing() diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_stream.py b/tests/unit_test/test_async/test_asyncio_backend/test_stream.py index e20348a3..08f6e9f2 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_stream.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_stream.py @@ -1080,6 +1080,25 @@ def test____connection_lost____by_closed_transport( with pytest.raises(BufferError): protocol.get_buffer(-1) + def test____connection_lost____close_waiter_done( + self, + protocol: StreamReaderBufferedProtocol, + mock_asyncio_transport: MagicMock, + ) -> None: + # Arrange + close_waiter = protocol._get_close_waiter() + close_waiter.cancel() + assert close_waiter.done() + + # Act + protocol.connection_lost(None) + + # Assert + mock_asyncio_transport.close.assert_not_called() + + with pytest.raises(BufferError): + protocol.get_buffer(-1) + def test____connection_lost____by_unrelated_error( self, protocol: StreamReaderBufferedProtocol, @@ -1252,6 +1271,22 @@ async def test____receive_data____connection_lost_by_unrelated_error( assert exc_info.value is exception + @pytest.mark.asyncio + async def test____receive_data____connection_reset( + self, + event_loop: asyncio.AbstractEventLoop, + protocol: StreamReaderBufferedProtocol, + data_receiver: _ProtocolDataReceiver, + ) -> None: + # Arrange + event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"abc") + event_loop.call_soon(protocol.connection_lost, None) + + # Act & Assert + for _ in range(3): + with pytest.raises(ConnectionResetError): + _ = await data_receiver(protocol, 1024) + @pytest.mark.asyncio async def test____receive_data____invalid_bufsize( self, @@ -1325,31 +1360,10 @@ async def test____receive_data____read_flow_control( mock_asyncio_transport.resume_reading.assert_called_once() @pytest.mark.asyncio - async def test____receive_data____deferred_buffer_release( - self, - protocol: StreamReaderBufferedProtocol, - data_receiver: _ProtocolDataReceiver, - ) -> None: - # Arrange - self.write_in_protocol_buffer(protocol, b"abcdef") - - # Act & Assert - protocol.connection_lost(None) - assert protocol._get_read_buffer_size() == 6 - protocol.get_buffer(-1) # assert not raises - - assert (await data_receiver(protocol, 3)) == b"abc" - assert protocol._get_read_buffer_size() == 3 - protocol.get_buffer(-1) # assert not raises - - assert (await data_receiver(protocol, 3)) == b"def" - assert protocol._get_read_buffer_size() == 0 - with pytest.raises(BufferError): - protocol.get_buffer(-1) - - @pytest.mark.asyncio - async def test____receive_data____immediate_buffer_release( + @pytest.mark.parametrize("exception", [None, OSError("Something bad happen")]) + async def test____receive_data____buffer_release( self, + exception: OSError | None, protocol: StreamReaderBufferedProtocol, data_receiver: _ProtocolDataReceiver, ) -> None: @@ -1357,7 +1371,7 @@ async def test____receive_data____immediate_buffer_release( self.write_in_protocol_buffer(protocol, b"abcdef") # Act & Assert - protocol.connection_lost(OSError("Something bad happen")) + protocol.connection_lost(exception) assert protocol._get_read_buffer_size() == 0 with pytest.raises(BufferError): protocol.get_buffer(-1)