Skip to content

Commit

Permalink
TCP servers: Fixed memory leak on connection reset (#249)
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia authored Mar 3, 2024
1 parent 54d02f2 commit a2f1f94
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 41 deletions.
2 changes: 2 additions & 0 deletions src/easynetwork/lowlevel/std_asyncio/_flow_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ def resume_writing(self) -> None:
waiter.set_result(None)

def connection_lost(self, exc: Exception | None) -> None:
if self.__connection_lost: # Already called, bail out.
return
self.__write_paused = False
self.__connection_lost = True
self.__connection_lost_exception = exc
Expand Down
34 changes: 18 additions & 16 deletions src/easynetwork/lowlevel/std_asyncio/stream/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,14 +285,23 @@ def connection_made(self, transport: asyncio.Transport) -> None: # type: ignore
self.__over_ssl = transport.get_extra_info("sslcontext") is not None

def connection_lost(self, exc: Exception | None) -> None:
if self.__connection_lost: # Already called, bail out.
return

self.__connection_lost = True
self.__read_paused = False

if exc is None and self.__buffer_nbytes_written > 0:
exc = _utils.error_from_errno(_errno.ECONNRESET)
self.__connection_lost_exception = exc
if exc is None:
self.__eof_reached = True
else:
self.__buffer_nbytes_written = 0
self.__connection_lost_exception_tb = exc.__traceback__
self._maybe_release_buffer()

self.__buffer_nbytes_written = 0
self.__buffer = None
self.__buffer_view.release()

if not self.__closed.done():
self.__closed.set_result(None)
Expand Down Expand Up @@ -329,9 +338,8 @@ def eof_received(self) -> bool:
return True

async def receive_data(self, bufsize: int, /) -> bytes:
if self.__connection_lost:
if self.__connection_lost_exception is not None:
raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb)
if self.__connection_lost_exception is not None:
raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb)
if bufsize == 0:
return b""
if bufsize < 0:
Expand All @@ -354,9 +362,8 @@ async def receive_data(self, bufsize: int, /) -> bytes:
return data

async def receive_data_into(self, buffer: WriteableBuffer, /) -> int:
if self.__connection_lost:
if self.__connection_lost_exception is not None:
raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb)
if self.__connection_lost_exception is not None:
raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb)
with memoryview(buffer).cast("B") as buffer:
if not buffer.nbytes:
return 0
Expand Down Expand Up @@ -395,6 +402,8 @@ async def _wait_for_data(self, requester: str) -> None:
await self.__read_waiter
finally:
self.__read_waiter = None
if self.__connection_lost_exception is not None:
raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb)

def _wakeup_read_waiter(self, exc: Exception | None) -> None:
if (waiter := self.__read_waiter) is not None:
Expand Down Expand Up @@ -434,21 +443,14 @@ def _maybe_pause_transport(self) -> None:
self.__read_paused = True

def _maybe_resume_transport(self) -> None:
if self.__connection_lost:
self._maybe_release_buffer()
elif (
if (
self.__read_paused
and (transport := self.__transport) is not None
and self.__buffer_nbytes_written <= self.__read_low_water
):
transport.resume_reading()
self.__read_paused = False

def _maybe_release_buffer(self) -> None:
if not self.__buffer_nbytes_written and self.__connection_lost:
self.__buffer_view.release()
self.__buffer = None

def pause_writing(self) -> None:
self.__write_flow.pause_writing()

Expand Down
64 changes: 39 additions & 25 deletions tests/unit_test/test_async/test_asyncio_backend/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,25 @@ def test____connection_lost____by_closed_transport(
with pytest.raises(BufferError):
protocol.get_buffer(-1)

def test____connection_lost____close_waiter_done(
self,
protocol: StreamReaderBufferedProtocol,
mock_asyncio_transport: MagicMock,
) -> None:
# Arrange
close_waiter = protocol._get_close_waiter()
close_waiter.cancel()
assert close_waiter.done()

# Act
protocol.connection_lost(None)

# Assert
mock_asyncio_transport.close.assert_not_called()

with pytest.raises(BufferError):
protocol.get_buffer(-1)

def test____connection_lost____by_unrelated_error(
self,
protocol: StreamReaderBufferedProtocol,
Expand Down Expand Up @@ -1252,6 +1271,22 @@ async def test____receive_data____connection_lost_by_unrelated_error(

assert exc_info.value is exception

@pytest.mark.asyncio
async def test____receive_data____connection_reset(
self,
event_loop: asyncio.AbstractEventLoop,
protocol: StreamReaderBufferedProtocol,
data_receiver: _ProtocolDataReceiver,
) -> None:
# Arrange
event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"abc")
event_loop.call_soon(protocol.connection_lost, None)

# Act & Assert
for _ in range(3):
with pytest.raises(ConnectionResetError):
_ = await data_receiver(protocol, 1024)

@pytest.mark.asyncio
async def test____receive_data____invalid_bufsize(
self,
Expand Down Expand Up @@ -1325,39 +1360,18 @@ async def test____receive_data____read_flow_control(
mock_asyncio_transport.resume_reading.assert_called_once()

@pytest.mark.asyncio
async def test____receive_data____deferred_buffer_release(
self,
protocol: StreamReaderBufferedProtocol,
data_receiver: _ProtocolDataReceiver,
) -> None:
# Arrange
self.write_in_protocol_buffer(protocol, b"abcdef")

# Act & Assert
protocol.connection_lost(None)
assert protocol._get_read_buffer_size() == 6
protocol.get_buffer(-1) # assert not raises

assert (await data_receiver(protocol, 3)) == b"abc"
assert protocol._get_read_buffer_size() == 3
protocol.get_buffer(-1) # assert not raises

assert (await data_receiver(protocol, 3)) == b"def"
assert protocol._get_read_buffer_size() == 0
with pytest.raises(BufferError):
protocol.get_buffer(-1)

@pytest.mark.asyncio
async def test____receive_data____immediate_buffer_release(
@pytest.mark.parametrize("exception", [None, OSError("Something bad happen")])
async def test____receive_data____buffer_release(
self,
exception: OSError | None,
protocol: StreamReaderBufferedProtocol,
data_receiver: _ProtocolDataReceiver,
) -> None:
# Arrange
self.write_in_protocol_buffer(protocol, b"abcdef")

# Act & Assert
protocol.connection_lost(OSError("Something bad happen"))
protocol.connection_lost(exception)
assert protocol._get_read_buffer_size() == 0
with pytest.raises(BufferError):
protocol.get_buffer(-1)
Expand Down

0 comments on commit a2f1f94

Please sign in to comment.