Skip to content

Commit

Permalink
TLS transport: Fixed a big issue due to wrong unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia committed Jun 27, 2024
1 parent 61d89ea commit b30d27b
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 44 deletions.
12 changes: 6 additions & 6 deletions src/easynetwork/lowlevel/api_async/transports/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,17 @@ async def recv_into(self, buffer: WriteableBuffer) -> int:

@_utils.inherit_doc(AsyncStreamTransport)
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
if not self.__closing:
self.__data_to_send.append(memoryview(data))
if self.__closing:
raise _utils.error_from_errno(errno.ECONNABORTED)
self.__data_to_send.append(memoryview(data))
del data
return await self.__flush_data_to_send()

@_utils.inherit_doc(AsyncStreamTransport)
async def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview]) -> None:
if not self.__closing:
self.__data_to_send.extend(map(memoryview, iterable_of_data))
if self.__closing:
raise _utils.error_from_errno(errno.ECONNABORTED)
self.__data_to_send.extend(map(memoryview, iterable_of_data))
del iterable_of_data
return await self.__flush_data_to_send()

Expand Down Expand Up @@ -254,8 +256,6 @@ async def _retry_ssl_method(
*args: *_T_PosArgs,
) -> _T_Return:
assert _ssl_module is not None, "stdlib ssl module not available" # nosec assert_used
if self.__closing:
raise _utils.error_from_errno(errno.ECONNABORTED)
while True:
try:
result = ssl_object_method(*args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,12 @@ async def tls_transport(
@pytest.fixture
@staticmethod
def mock_tls_transport_retry(mocker: MockerFixture) -> AsyncMock:
side_effect: Callable[..., Any] = lambda ssl_object_method, *args: ssl_object_method(*args)
return mocker.patch.object(AsyncTLSStreamTransport, "_retry_ssl_method", side_effect=side_effect)
return mocker.patch.object(
AsyncTLSStreamTransport,
"_retry_ssl_method",
autospec=True,
wraps=AsyncTLSStreamTransport._retry_ssl_method,
)

async def test____wrap____default(
self,
Expand Down Expand Up @@ -135,7 +139,7 @@ async def test____wrap____default(
server_hostname="server_hostname",
session=None,
)
mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.do_handshake)
mock_tls_transport_retry.assert_awaited_once_with(tls_transport, mock_ssl_object.do_handshake)
assert mock_ssl_object.mock_calls == [mocker.call.do_handshake(), mocker.call.getpeercert()]
## Attributes
assert tls_transport._shutdown_timeout == DEFAULT_SSL_SHUTDOWN_TIMEOUT
Expand Down Expand Up @@ -199,7 +203,7 @@ async def test____wrap____with_parameters(
server_hostname=server_hostname,
session=session,
)
mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.do_handshake)
mock_tls_transport_retry.assert_awaited_once_with(tls_transport, mock_ssl_object.do_handshake)
assert mock_ssl_object.mock_calls == [mocker.call.do_handshake(), mocker.call.getpeercert()]
assert mock_wrapped_transport.mock_calls == [mocker.call.backend()]
## Attributes
Expand Down Expand Up @@ -257,7 +261,7 @@ async def test____wrap____handshake_timeout(
mock_ssl_object: MagicMock,
) -> None:
# Arrange
async def retry_side_effect(ssl_object_method: Callable[..., Any], *args: Any) -> Any:
async def retry_side_effect(self: AsyncTLSStreamTransport, ssl_object_method: Callable[..., Any], *args: Any) -> Any:
await asyncio.sleep(5)
return ssl_object_method(*args)

Expand Down Expand Up @@ -325,7 +329,6 @@ async def test____dunder_del____ResourceWarning(
async def test____aclose____close_transport(
self,
tls_transport: AsyncTLSStreamTransport,
mock_tls_transport_retry: AsyncMock,
mock_wrapped_transport: MagicMock,
mock_ssl_object: MagicMock,
standard_compatible: bool,
Expand All @@ -341,22 +344,19 @@ async def test____aclose____close_transport(
# Assert
assert tls_transport.is_closing()
if standard_compatible:
mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.unwrap)
mock_ssl_object.unwrap.assert_called_once_with()
assert read_bio.eof
assert write_bio.eof
else:
mock_tls_transport_retry.assert_not_called()
mock_ssl_object.unwrap.assert_not_called()
assert not read_bio.eof
assert not write_bio.eof
mock_wrapped_transport.aclose.assert_awaited_once_with()

@pytest.mark.parametrize("standard_compatible", [False, True], indirect=True, ids=lambda p: f"standard_compatible=={p}")
async def test____aclose____indempotent(
async def test____aclose____idempotent(
self,
tls_transport: AsyncTLSStreamTransport,
mock_tls_transport_retry: AsyncMock,
mock_wrapped_transport: MagicMock,
mock_ssl_object: MagicMock,
standard_compatible: bool,
Expand All @@ -373,12 +373,10 @@ async def test____aclose____indempotent(

# Assert
if standard_compatible:
mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.unwrap)
mock_ssl_object.unwrap.assert_called_once_with()
assert read_bio.eof
assert write_bio.eof
else:
mock_tls_transport_retry.assert_not_called()
mock_ssl_object.unwrap.assert_not_called()
assert not read_bio.eof
assert not write_bio.eof
Expand All @@ -396,7 +394,7 @@ async def test____aclose____shutdown_timeout(
write_bio: ssl.MemoryBIO,
) -> None:
# Arrange
async def retry_side_effect(ssl_object_method: Callable[..., Any], *args: Any) -> Any:
async def retry_side_effect(self: AsyncTLSStreamTransport, ssl_object_method: Callable[..., Any], *args: Any) -> Any:
await asyncio.sleep(5)
return ssl_object_method(*args)

Expand All @@ -406,7 +404,7 @@ async def retry_side_effect(ssl_object_method: Callable[..., Any], *args: Any) -
await tls_transport.aclose()

# Assert
mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.unwrap)
mock_tls_transport_retry.assert_awaited_once_with(tls_transport, mock_ssl_object.unwrap)
assert not read_bio.eof
assert not write_bio.eof
mock_wrapped_transport.aclose.assert_awaited_once_with()
Expand All @@ -422,7 +420,7 @@ async def test____aclose____mask_unwrap_error(
write_bio: ssl.MemoryBIO,
) -> None:
# Arrange
async def retry_side_effect(ssl_object_method: Callable[..., Any], *args: Any) -> Any:
async def retry_side_effect(self: AsyncTLSStreamTransport, ssl_object_method: Callable[..., Any], *args: Any) -> Any:
try:
return ssl_object_method(*args)
except ssl.SSLError:
Expand All @@ -437,7 +435,7 @@ async def retry_side_effect(ssl_object_method: Callable[..., Any], *args: Any) -
await tls_transport.aclose()

# Assert
mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.unwrap)
mock_tls_transport_retry.assert_awaited_once_with(tls_transport, mock_ssl_object.unwrap)
assert read_bio.eof
assert write_bio.eof
mock_wrapped_transport.aclose.assert_awaited_once_with()
Expand All @@ -456,7 +454,7 @@ async def test____recv____default(

# Assert
assert data == b"decrypted-data"
mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.read, 123456)
mock_tls_transport_retry.assert_awaited_once_with(tls_transport, mock_ssl_object.read, 123456)
mock_ssl_object.read.assert_called_once_with(123456)

async def test____recv____null_buffer(
Expand All @@ -473,7 +471,7 @@ async def test____recv____null_buffer(

# Assert
assert data == b""
mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.read, 0)
mock_tls_transport_retry.assert_awaited_once_with(tls_transport, mock_ssl_object.read, 0)
mock_ssl_object.read.assert_called_once_with(0)

@pytest.mark.parametrize("standard_compatible", [False, True], indirect=True, ids=lambda p: f"standard_compatible=={p}")
Expand Down Expand Up @@ -540,7 +538,7 @@ async def test____recv_into____default(

# Assert
assert nbytes == 42
mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.read, 1234, buffer)
mock_tls_transport_retry.assert_awaited_once_with(tls_transport, mock_ssl_object.read, 1234, buffer)
mock_ssl_object.read.assert_called_once_with(1234, buffer)

async def test____recv_into____null_buffer(
Expand All @@ -558,7 +556,7 @@ async def test____recv_into____null_buffer(

# Assert
assert nbytes == 0
mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.read, 1024, buffer)
mock_tls_transport_retry.assert_awaited_once_with(tls_transport, mock_ssl_object.read, 1024, buffer)
mock_ssl_object.read.assert_called_once_with(1024, buffer)

@pytest.mark.parametrize("standard_compatible", [False, True], indirect=True, ids=lambda p: f"standard_compatible=={p}")
Expand Down Expand Up @@ -728,17 +726,20 @@ async def test____send_all____unrelated_ssl_error(
async def test____send_all____closed_transport(
self,
tls_transport: AsyncTLSStreamTransport,
mock_tls_transport_retry: AsyncMock,
mock_ssl_object: MagicMock,
) -> None:
# Arrange
mock_ssl_object.unwrap.return_value = None
await tls_transport.aclose()
mock_tls_transport_retry.reset_mock()

# Act & Assert
with pytest.raises(ConnectionAbortedError):
await tls_transport.send_all(b"decrypted-data")

mock_ssl_object.write.assert_not_called()
mock_tls_transport_retry.assert_not_awaited()
assert tls_transport._test__data_queue() == []

@pytest.mark.usefixtures("mock_tls_transport_retry")
Expand Down Expand Up @@ -786,17 +787,20 @@ async def test____send_all_from_iterable____partial_data(
async def test____send_all_from_iterable____closed_transport(
self,
tls_transport: AsyncTLSStreamTransport,
mock_tls_transport_retry: AsyncMock,
mock_ssl_object: MagicMock,
) -> None:
# Arrange
mock_ssl_object.unwrap.return_value = None
await tls_transport.aclose()
mock_tls_transport_retry.reset_mock()

# Act & Assert
with pytest.raises(ConnectionAbortedError):
await tls_transport.send_all_from_iterable([b"decrypted-data-1", b"decrypted-data-2"])

mock_ssl_object.write.assert_not_called()
mock_tls_transport_retry.assert_not_awaited()
assert tls_transport._test__data_queue() == []

async def test____send_eof____default(
Expand Down Expand Up @@ -992,24 +996,6 @@ async def test____retry____unrelated_ssl_error(
assert read_bio.eof
assert write_bio.eof

@pytest.mark.parametrize("standard_compatible", [False], ids=lambda p: f"standard_compatible=={p}")
async def test____retry____closed_transport(
self,
tls_transport: AsyncTLSStreamTransport,
mock_wrapped_transport: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_wrapped_transport.send_all.return_value = None
ssl_object_method = mocker.stub()
await tls_transport.aclose()

# Act & Assert
with pytest.raises(ConnectionAbortedError):
await tls_transport._retry_ssl_method(ssl_object_method)

ssl_object_method.assert_not_called()

async def test____get_backend____returns_inner_transport_backend(
self,
tls_transport: AsyncTLSStreamTransport,
Expand Down

0 comments on commit b30d27b

Please sign in to comment.