From 053711b93e8526150fb7c4e1b8a147e70b44da04 Mon Sep 17 00:00:00 2001 From: Sviatoslav Sydorenko Date: Mon, 9 Nov 2020 14:59:27 +0100 Subject: [PATCH] Add a test for `WANT_READ` during `sendall()` --- tests/test_ssl.py | 117 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 114 insertions(+), 3 deletions(-) diff --git a/tests/test_ssl.py b/tests/test_ssl.py index aed236703..ab40030d7 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -308,7 +308,7 @@ def loopback_server_factory(socket, version=SSLv23_METHOD): return server -def loopback(server_factory=None, client_factory=None): +def loopback(server_factory=None, client_factory=None, blocking=True): """ Create a connected socket pair and force two connected SSL sockets to talk to each other via memory BIOs. @@ -324,8 +324,8 @@ def loopback(server_factory=None, client_factory=None): handshake(client, server) - server.setblocking(True) - client.setblocking(True) + server.setblocking(blocking) + client.setblocking(blocking) return server, client @@ -3131,11 +3131,122 @@ def test_memoryview_really_doesnt_overfill(self): self._doesnt_overfill_test(_make_memoryview) +@pytest.fixture +def nonblocking_tls_connections_pair(): + """Return a non-blocking TLS loopback connections pair.""" + return loopback(blocking=False) + + +@pytest.fixture +def nonblocking_tls_server_connection(nonblocking_tls_connections_pair): + """Return a non-blocking TLS server socket connected to loopback.""" + return nonblocking_tls_connections_pair[0] + + +@pytest.fixture +def nonblocking_tls_client_connection(nonblocking_tls_connections_pair): + """Return a non-blocking TLS client socket connected to loopback.""" + return nonblocking_tls_connections_pair[1] + + class TestConnectionSendall(object): """ Tests for `Connection.sendall`. """ + def test_want_write( + self, + monkeypatch, + nonblocking_tls_server_connection, + nonblocking_tls_client_connection, + ): + msg = b"x" + garbage_size = 1024 * 1024 * 64 + garbage_payload = msg * garbage_size + large_payload = b"p" * garbage_size * 2 + payload_size = len(large_payload) + + for i in range(garbage_size): + try: + nonblocking_tls_client_connection.send(msg) + except WantWriteError: + break + else: + pytest.fail( + "Failed to fill socket buffer, cannot test " + "'want write' in `sendall()`" + ) + + def consume_garbage(conn): + if patched_ssl_write.want_write_counter < 5: + # NOTE: Ensure that sendall will make a few internal retries + return + + assert not consume_garbage.garbage_consumed + + consume_garbage.consumed += conn.recv(garbage_size) + if len(consume_garbage.consumed) < garbage_size: + return + + assert consume_garbage.consumed == garbage_payload + + consume_garbage.garbage_consumed = True + consume_garbage.garbage_consumed = False + consume_garbage.consumed = b"" + + consumed_payload = b"" + def consume_payload(conn): + consumed_payload += conn.recv(payload_size) + # FIXME: invoke conn.renegotiate()? + + original_ssl_write = _lib.SSL_write + def patched_ssl_write(ctx, data, size): + consume_data_on_server = ( + consume_payload if consume_garbage.garbage_consumed + else consume_garbage + ) + consume_data_on_server(nonblocking_tls_server_connection) + write_result = original_ssl_write(ctx, data, size) + try: + nonblocking_tls_client_connection._raise_ssl_error( + ctx, write_result, + ) + except WantWriteError: + patched_ssl_write.want_write_counter += 1 + consume_data_on_server = ( + consume_payload if consume_garbage.garbage_consumed + else consume_garbage + ) + consume_data_on_server(nonblocking_tls_server_connection) + #breakpoint() + # NOTE: We don't re-raise it as the calling code will do + # NOTE: the same after the call. + return write_result + + patched_ssl_write.want_write_counter = 0 + + # NOTE: Make the client think it needs a handshake so that it'll + # NOTE: attempt to `do_handshake()` on the next `SSL_write()` + # NOTE: that originates from `sendall()`: + nonblocking_tls_client_connection.set_connect_state() + try: + nonblocking_tls_client_connection.do_handshake() + except WantWriteError: + assert True # Sanity check + except: + assert False # This should never happen (see the note above) + + monkeypatch.setattr(_lib, "SSL_write", patched_ssl_write) + + nonblocking_tls_client_connection.sendall(large_payload) + + assert consume_garbage.garbage_consumed + + # NOTE: Read the leftover data from the very last `SSL_write()` + consume_payload(nonblocking_tls_server_connection) + + assert consumed_payload == large_payload + def test_wrong_args(self): """ When called with arguments other than a string argument for its first