Skip to content

Commit

Permalink
Add a test for WANT_READ during sendall()
Browse files Browse the repository at this point in the history
  • Loading branch information
webknjaz committed Jan 23, 2024
1 parent c8fdfdf commit 9dae3c6
Showing 1 changed file with 126 additions and 3 deletions.
129 changes: 126 additions & 3 deletions tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def loopback_server_factory(socket, version=SSLv23_METHOD):
return server


def loopback(server_factory=None, client_factory=None):
def loopback(server_factory=None, client_factory=None, blocking=True):
"""
Create a connected socket pair and force two connected SSL sockets
to talk to each other via memory BIOs.
Expand All @@ -337,8 +337,8 @@ def loopback(server_factory=None, client_factory=None):

handshake(client, server)

server.setblocking(True)
client.setblocking(True)
server.setblocking(blocking)
client.setblocking(blocking)
return server, client


Expand Down Expand Up @@ -3297,11 +3297,134 @@ def test_memoryview_really_doesnt_overfill(self):
self._doesnt_overfill_test(_make_memoryview)


@pytest.fixture
def nonblocking_tls_connections_pair():
"""Return a non-blocking TLS loopback connections pair."""
return loopback(blocking=False)


@pytest.fixture
def nonblocking_tls_server_connection(nonblocking_tls_connections_pair):
"""Return a non-blocking TLS server socket connected to loopback."""
return nonblocking_tls_connections_pair[0]


@pytest.fixture
def nonblocking_tls_client_connection(nonblocking_tls_connections_pair):
"""Return a non-blocking TLS client socket connected to loopback."""
return nonblocking_tls_connections_pair[1]


class TestConnectionSendall:
"""
Tests for `Connection.sendall`.
"""

def test_want_write(
self,
monkeypatch,
nonblocking_tls_server_connection,
nonblocking_tls_client_connection,
):
msg = b"x"
garbage_size = 1024 * 1024 * 64
large_payload = b"p" * garbage_size * 2
payload_size = len(large_payload)

sent_garbage_size = 0
try:
sent_garbage_size += nonblocking_tls_client_connection.send(
msg * garbage_size,
)
except WantWriteError:
pass
for i in range(garbage_size):
try:
sent_garbage_size += nonblocking_tls_client_connection.send(
msg,
)
except WantWriteError:
break
else:
pytest.fail(
"Failed to fill socket buffer, cannot test "
"'want write' in `sendall()`"
)
garbage_payload = sent_garbage_size * msg

def consume_garbage(conn):
assert patched_ssl_write.want_write_counter >= 1
assert not consume_garbage.garbage_consumed

while len(consume_garbage.consumed) < sent_garbage_size:
try:
consume_garbage.consumed += conn.recv(
sent_garbage_size - len(consume_garbage.consumed),
)
except WantReadError:
pass

assert consume_garbage.consumed == garbage_payload

consume_garbage.garbage_consumed = True

consume_garbage.garbage_consumed = False
consume_garbage.consumed = b""

def consume_payload(conn):
try:
consume_payload.consumed += conn.recv(payload_size)
except WantReadError:
pass

consume_payload.consumed = b""

original_ssl_write = _lib.SSL_write

def patched_ssl_write(ctx, data, size):
write_result = original_ssl_write(ctx, data, size)
try:
nonblocking_tls_client_connection._raise_ssl_error(
ctx,
write_result,
)
except WantWriteError:
patched_ssl_write.want_write_counter += 1
consume_data_on_server = (
consume_payload
if consume_garbage.garbage_consumed
else consume_garbage
)

consume_data_on_server(nonblocking_tls_server_connection)
# NOTE: We don't re-raise it as the calling code will do
# NOTE: the same after the call.
return write_result

patched_ssl_write.want_write_counter = 0

# NOTE: Make the client think it needs a handshake so that it'll
# NOTE: attempt to `do_handshake()` on the next `SSL_write()`
# NOTE: that originates from `sendall()`:
nonblocking_tls_client_connection.set_connect_state()
try:
nonblocking_tls_client_connection.do_handshake()
except WantWriteError:
assert True # Sanity check
except:
assert False # This should never happen (see the note above)

with monkeypatch.context() as mp_ctx:
mp_ctx.setattr(_lib, "SSL_write", patched_ssl_write)
nonblocking_tls_client_connection.sendall(large_payload)

assert consume_garbage.garbage_consumed

# NOTE: Read the leftover data from the very last `SSL_write()`
consume_payload(nonblocking_tls_server_connection)

assert consume_payload.consumed == large_payload

def test_wrong_args(self):
"""
When called with arguments other than a string argument for its first
Expand Down

0 comments on commit 9dae3c6

Please sign in to comment.