diff --git a/demo_server/demo_server/app.py b/demo_server/demo_server/app.py index c65845f3..7c708a8e 100644 --- a/demo_server/demo_server/app.py +++ b/demo_server/demo_server/app.py @@ -1,7 +1,11 @@ from fastapi import FastAPI, HTTPException, Depends -import uvicorn +from hypercorn.config import Config +from hypercorn.asyncio import serve +import asyncio import os, binascii +import uvicorn +import sys from sqlmodel import create_engine, SQLModel, Session @@ -22,6 +26,7 @@ def on_startup(): SQLModel.metadata.create_all(engine) + if __name__ == "__main__": app_port = os.getenv('DEMO_SERVER_PORT') @@ -34,5 +39,17 @@ def on_startup(): if app_host is None: app_host = "0.0.0.0" - uvicorn.run("app:app", reload=True, host=app_host, port=app_port) + use_http2 = False + for i in range(len(sys.argv)): + if sys.argv[i] == '--use_http2': + use_http2 = True + + if not use_http2: + uvicorn.run("app:app", reload=True, host=app_host, port=app_port) + else: + config = Config() + config.bind = [app_host + ":" + str(app_port)] + loop = asyncio.get_event_loop() + loop.run_until_complete(serve(app, config)) + diff --git a/demo_server/requirements.txt b/demo_server/requirements.txt index 8e86e6cc..15632749 100644 --- a/demo_server/requirements.txt +++ b/demo_server/requirements.txt @@ -1,3 +1,4 @@ fastapi==0.78.0 +hypercorn==0.14.3 uvicorn==0.18.1 sqlmodel==0.0.6 \ No newline at end of file diff --git a/docker/Dockerfile b/docker/Dockerfile index a3ce90a7..7bbff50c 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -11,6 +11,7 @@ RUN python3 -m ensurepip RUN pip3 install --upgrade pip RUN pip3 install requests RUN pip3 install applicationinsights +RUN pip3 install h2 COPY ./engine /RESTler/engine RUN python3 -m compileall -b /RESTler/engine COPY ./resultsAnalyzer /RESTler/resultsAnalyzer diff --git a/restler-quick-start.py b/restler-quick-start.py old mode 100644 new mode 100755 index f7040cb7..7378f2da --- a/restler-quick-start.py +++ b/restler-quick-start.py @@ -42,9 +42,11 @@ def compile_spec(api_spec_path, restler_dll_path): print(f"command: {command}") subprocess.run(command, shell=True) -def add_common_settings(ip, port, host, use_ssl, command): +def add_common_settings(ip, port, host, use_ssl, use_http2, command): if not use_ssl: command = f"{command} --no_ssl" + if use_http2: + command = f"{command} --http2" if ip is not None: command = f"{command} --target_ip {ip}" if port is not None: @@ -53,18 +55,18 @@ def add_common_settings(ip, port, host, use_ssl, command): command = f"{command} --host {host}" return command -def replay_bug(ip, port, host, use_ssl, restler_dll_path, replay_log): +def replay_bug(ip, port, host, use_ssl, use_http2, restler_dll_path, replay_log): """ Runs RESTler's replay mode on the specified replay file """ with usedir(RESTLER_TEMP_DIR): command = ( f"dotnet \"{restler_dll_path}\" replay --replay_log \"{replay_log}\"" ) - command = add_common_settings(ip, port, host, use_ssl, command) + command = add_common_settings(ip, port, host, use_ssl, use_http2, command) print(f"command: {command}\n") subprocess.run(command, shell=True) -def replay_from_dir(ip, port, host, use_ssl, restler_dll_path, replay_dir): +def replay_from_dir(ip, port, host, use_ssl, use_http2, restler_dll_path, replay_dir): import glob from pathlib import Path # get all the 500 replay files in the bug buckets directory @@ -74,10 +76,10 @@ def replay_from_dir(ip, port, host, use_ssl, restler_dll_path, replay_dir): if "bug_buckets" in os.path.basename(file_path): continue print(f"Testing replay file: {file_path}") - replay_bug(ip, port, host, use_ssl, restler_dll_path, Path(file_path).absolute()) + replay_bug(ip, port, host, use_ssl, use_http2, restler_dll_path, Path(file_path).absolute()) pass -def test_spec(ip, port, host, use_ssl, restler_dll_path, task): +def test_spec(ip, port, host, use_ssl, use_http2, restler_dll_path, task): """ Runs RESTler's test mode on a specified Compile directory @param ip: The IP of the service to test @@ -88,6 +90,8 @@ def test_spec(ip, port, host, use_ssl, restler_dll_path, task): @type host: Str @param use_ssl: If False, set the --no_ssl parameter when executing RESTler @type use_ssl: Boolean + @param use_http2: If True, set the --http2 parameter when executing RESTler + @type use_http2: Boolean @param restler_dll_path: The absolute path to the RESTler driver's dll @type restler_dll_path: Str @@ -107,8 +111,8 @@ def test_spec(ip, port, host, use_ssl, restler_dll_path, task): f" --settings \"{settings_file_path}\"" ) print(f"command: {command}\n") - command = add_common_settings(ip, port, host, use_ssl, command) - + command = add_common_settings(ip, port, host, use_ssl, use_http2, command) + print(f"command: {command}\n") subprocess.run(command, shell=True) if __name__ == '__main__': @@ -129,6 +133,9 @@ def test_spec(ip, port, host, use_ssl, restler_dll_path, task): parser.add_argument('--use_ssl', help='Set this flag if you want to use SSL validation for the socket', action='store_true') + parser.add_argument('--use_http2', + help='Set this flag if you want to use HTTP2', + action='store_true') parser.add_argument('--host', help='The hostname of the service to test', type=str, required=False, default=None) @@ -146,13 +153,13 @@ def test_spec(ip, port, host, use_ssl, restler_dll_path, task): print(f"\nrestler_dll_path: {restler_dll_path}\n") if args.task == "replay": - replay_from_dir(args.ip, args.port, args.host, args.use_ssl, restler_dll_path.absolute(), args.replay_bug_buckets_dir) + replay_from_dir(args.ip, args.port, args.host, args.use_ssl, args.use_http2, restler_dll_path.absolute(), args.replay_bug_buckets_dir) else: if args.api_spec_path is None: print("api_spec_path is required for all tasks except the replay task.") exit(-1) api_spec_path = os.path.abspath(args.api_spec_path) compile_spec(api_spec_path, restler_dll_path.absolute()) - test_spec(args.ip, args.port, args.host, args.use_ssl, restler_dll_path.absolute(), args.task) + test_spec(args.ip, args.port, args.host, args.use_ssl, args.use_http2, restler_dll_path.absolute(), args.task) print(f"Test complete.\nSee {os.path.abspath(RESTLER_TEMP_DIR)} for results.") diff --git a/restler/end_to_end_tests/test_quick_start.py b/restler/end_to_end_tests/test_quick_start.py index 4b27341e..693f1351 100644 --- a/restler/end_to_end_tests/test_quick_start.py +++ b/restler/end_to_end_tests/test_quick_start.py @@ -7,6 +7,7 @@ created during quick start test. To call: python ./test_quick_start.py +To call with HTTP/2: python ./test_quick_start.py --use_http2 """ import sys import os @@ -51,10 +52,10 @@ def check_expected_output(restler_working_dir, expected_strings, output, task_di net_log = nf.read() raise QuickStartFailedException(f"Failing because expected output '{expected_str}' was not found:\n{stdout}{out}{err}{net_log}") -def test_test_task(restler_working_dir, swagger_path, restler_drop_dir): +def test_test_task(restler_working_dir, swagger_path, restler_drop_dir, use_http2): # Run the quick start script output = subprocess.run( - f'python ./restler-quick-start.py --api_spec_path {swagger_path} --restler_drop_dir {restler_drop_dir} --task test', + f'python ./restler-quick-start.py --api_spec_path {swagger_path} --restler_drop_dir {restler_drop_dir} --task test {use_http2}', shell=True, capture_output=True ) expected_strings = [ @@ -66,10 +67,10 @@ def test_test_task(restler_working_dir, swagger_path, restler_drop_dir): check_output_errors(output) check_expected_output(restler_working_dir, expected_strings, output, "Test") -def test_fuzzlean_task(restler_working_dir, swagger_path, restler_drop_dir): +def test_fuzzlean_task(restler_working_dir, swagger_path, restler_drop_dir, use_http2): # Run the quick start script output = subprocess.run( - f'python ./restler-quick-start.py --api_spec_path {swagger_path} --restler_drop_dir {restler_drop_dir} --task fuzz-lean', + f'python ./restler-quick-start.py --api_spec_path {swagger_path} --restler_drop_dir {restler_drop_dir} --task fuzz-lean {use_http2}', shell=True, capture_output=True ) expected_strings = [ @@ -85,7 +86,7 @@ def test_fuzzlean_task(restler_working_dir, swagger_path, restler_drop_dir): check_output_errors(output) check_expected_output(restler_working_dir, expected_strings, output, "FuzzLean") -def test_fuzz_task(restler_working_dir, swagger_path, restler_drop_dir): +def test_fuzz_task(restler_working_dir, swagger_path, restler_drop_dir, use_http2): import json compile_dir = Path(restler_working_dir, f'Compile') settings_file_path = compile_dir.joinpath('engine_settings.json') @@ -105,17 +106,17 @@ def test_fuzz_task(restler_working_dir, swagger_path, restler_drop_dir): 'Task Fuzz succeeded.' ] output = subprocess.run( - f'python ./restler-quick-start.py --api_spec_path {swagger_path} --restler_drop_dir {restler_drop_dir} --task fuzz', + f'python ./restler-quick-start.py --api_spec_path {swagger_path} --restler_drop_dir {restler_drop_dir} --task fuzz {use_http2}', shell=True, capture_output=True ) check_output_errors(output) # check_expected_output(restler_working_dir, expected_strings, output) -def test_replay_task(restler_working_dir, task_output_dir, restler_drop_dir): +def test_replay_task(restler_working_dir, task_output_dir, restler_drop_dir, use_http2): # Run the quick start script print(f"Testing replay for bugs found in task output dir: {task_output_dir}") output = subprocess.run( - f'python ./restler-quick-start.py --replay_bug_buckets_dir {task_output_dir} --restler_drop_dir {restler_drop_dir} --task replay', + f'python ./restler-quick-start.py --replay_bug_buckets_dir {task_output_dir} --restler_drop_dir {restler_drop_dir} --task replay {use_http2}', shell=True, capture_output=True ) check_output_errors(output) @@ -137,7 +138,7 @@ def test_replay_task(restler_working_dir, task_output_dir, restler_drop_dir): with open(network_log) as rf, open(original_bug_buckets_file_path) as of: orig_buckets = of.read() log_contents = rf.read() - if 'HTTP/1.1 500 Internal Server Error' not in log_contents: + if 'HTTP/1.1 500 Internal Server Error' not in log_contents and 'HTTP/2.0 500 Internal Server Error' not in log_contents: raise QuickStartFailedException(f"Failing because bug buckets {orig_buckets} were not reproduced. Replay log: {log_contents}.") else: print("500 error was reproduced.") @@ -165,7 +166,11 @@ def get_demo_server_output(demo_server_process): else: creationflags = 0 - demo_server_process = subprocess.Popen([sys.executable, demo_server_path], + use_http2 = "" + if len(sys.argv) > 2: + use_http2 = sys.argv[2] + + demo_server_process = subprocess.Popen([sys.executable, demo_server_path, use_http2], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, creationflags=creationflags) @@ -181,17 +186,17 @@ def get_demo_server_output(demo_server_process): test_failed = False try: print("+++++++++++++++++++++++++++++test...") - test_test_task(restler_working_dir, swagger_path, restler_drop_dir) + test_test_task(restler_working_dir, swagger_path, restler_drop_dir, use_http2) print("+++++++++++++++++++++++++++++fuzzlean...") - test_fuzzlean_task(restler_working_dir, swagger_path, restler_drop_dir) + test_fuzzlean_task(restler_working_dir, swagger_path, restler_drop_dir, use_http2) print("+++++++++++++++++++++++++++++replay...") fuzzlean_task_dir = os.path.join(curr, RESTLER_WORKING_DIR, 'FuzzLean') - test_replay_task(restler_working_dir, fuzzlean_task_dir, restler_drop_dir) + test_replay_task(restler_working_dir, fuzzlean_task_dir, restler_drop_dir, use_http2) #print("+++++++++++++++++++++++++++++fuzz...") - #test_fuzz_task(restler_working_dir, swagger_path, restler_drop_dir) + #test_fuzz_task(restler_working_dir, swagger_path, restler_drop_dir, use_http2) except Exception: test_failed = True diff --git a/restler/engine/transport_layer/messaging.py b/restler/engine/transport_layer/messaging.py index f2e7f30c..a153f9c1 100644 --- a/restler/engine/transport_layer/messaging.py +++ b/restler/engine/transport_layer/messaging.py @@ -8,6 +8,10 @@ import time import threading from importlib import util +import h2.connection +import h2.events +import h2.config +from http.client import responses from utils.logger import raw_network_logging as RAW_LOGGING from engine.errors import TransportLayerException @@ -20,11 +24,396 @@ DELIM = "\r\n\r\n" UTF8 = 'utf-8' - class HttpSock(object): __last_request_sent_time = time.time() __request_sem = threading.Semaphore() + def __init__(self, connection_settings): + if Settings().use_http2: + self._subject = Http20Sock(connection_settings) + else: + self._subject = Http11Sock(connection_settings) + + def sendRecv(self, message: str, req_timeout_sec: int, reconnect: bool = False): + """ Sends a specified request to the server and waits for a response + @param message: Message to be sent. + @type message : Str + @param req_timeout_sec: The time, in seconds, to wait for request to complete + @type req_timeout_sec : Int + @return: + False if failure, True if success + Response if True returned, Error if False returned + @rtype : Tuple (Bool, String) + """ + return self._subject.sendRecv(message, req_timeout_sec, reconnect=reconnect) + +class Http20Sock(object): + __last_request_sent_time = time.time() + __request_sem = threading.Semaphore() + + def set_up_connection(self): + try: + host = Settings().host + target_ip = self.connection_settings.target_ip or host + target_port = self.connection_settings.target_port + if Settings().use_test_socket: #TODO: investigate test_server for http2 problems + self._sock = TestSocket(Settings().test_server) + elif self.connection_settings.use_ssl: + if self.connection_settings.disable_cert_validation: + context = ssl._create_unverified_context() + self._scheme = "https" + else: + context = ssl.create_default_context() + self._scheme = "https" + if Settings().client_certificate_path: + context.load_cert_chain( + certfile = Settings().client_certificate_path, + keyfile = Settings().client_certificate_key_path, + ) + + with socket.create_connection((target_ip, target_port or 443)) as sock: + self._sock = context.wrap_socket(sock, server_hostname=host) + self._connection = h2.connection.H2Connection(config=h2.config.H2Configuration()) + self._connection.initiate_connection() + self._sock.sendall(self._connection.data_to_send()) + self._stream_id = self._connection.get_next_available_stream_id() + + else: + self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._scheme = "http" + self._sock.connect((target_ip, target_port or 80)) + self._connection = h2.connection.H2Connection(config=h2.config.H2Configuration()) + self._connection.initiate_connection() + self._sock.sendall(self._connection.data_to_send()) + self._stream_id = self._connection.get_next_available_stream_id() + except Exception as error: + raise TransportLayerException(f"Exception Creating Socket: {error!s}") + + def __init__(self, connection_settings): + """ Initializes a socket object using low-level python socket objects. + + @param connection_settings: The connection settings for this socket + @type connection_settings: ConnectionSettings + + @return: None + @rtype : None + + """ + self._request_throttle_sec = (float)(Settings().request_throttle_ms/1000.0)\ + if Settings().request_throttle_ms else None + + self.connection_settings = connection_settings + + self.ignore_decoding_failures = Settings().ignore_decoding_failures + self._connected = False + self._sock = None + + def __del__(self): + """ Destructor - Closes socket + + """ + if self._sock: + self._closeSocket() + + def _get_method_from_message(self, message): + end_of_method_idx = message.find(" ") + method_name = message[0:end_of_method_idx] + return method_name + + def sendRecv(self, message, req_timeout_sec, reconnect=False): + """ Sends a specified request to the server and waits for a response + + @param message: Message to be sent. + @type message : Str + @param req_timeout_sec: The time, in seconds, to wait for request to complete + @type req_timeout_sec : Int + + @return: + False if failure, True if success + Response if True returned, Error if False returned + @rtype : Tuple (Bool, String) + + """ + + try: + if reconnect or not self._connected: + if reconnect: + if self._sock: + self._closeSocket() + self.set_up_connection() + self._connected = True + + self._sendRequest(message) + if not Settings().use_test_socket: # TODO: investigate test socket situation + http_method_name = self._get_method_from_message(message) + received_response = self._recvResponse(req_timeout_sec, http_method_name) + if not received_response and not reconnect: + # Re-connect and try again, since this may be due to the connection being closed. + RAW_LOGGING("Empty response received. Re-creating connection and re-trying.") + return self.sendRecv(message, req_timeout_sec, reconnect=True) + + response = HttpResponse(received_response) + self._stream_id = self._connection.get_next_available_stream_id() + else: + response = self._sock.recv() ##TODO: recv with h2 maybe? investigate test_server problems + self._stream_id = self._connection.get_next_available_stream_id() + RAW_LOGGING(f'Received: {response.to_str!r}\n') + + return (True, response) + except TransportLayerException as error: + response = HttpResponse(str(error).strip('"\'')) + if 'timed out' in str(error): + response._status_code = TIMEOUT_CODE + RAW_LOGGING(f"Reached max req_timeout_sec of {req_timeout_sec}.") + elif self._contains_connection_closed(str(error)): + response._status_code = CONNECTION_CLOSED_CODE + RAW_LOGGING(f"Connection error: {error!s}") + if not reconnect: + RAW_LOGGING("Re-creating connection and re-trying.") + return self.sendRecv(message, req_timeout_sec, reconnect=True) + else: + RAW_LOGGING(f"Unknown error: {error!s}") + if not reconnect: + RAW_LOGGING("Re-creating connection and re-trying.") + return self.sendRecv(message, req_timeout_sec, reconnect=True) + return (False, response) + + def _contains_connection_closed(self, error_str): + """ Returns whether or not the error string contains a connection closed error + + @param error_str: The error string to check for connection closed error + @type error_str: Str + + @return: True if the error string contains the connection closed error + @rtype : Bool + + """ + # WinError 10054 occurs when the server terminates the connection and RESTler + # is being run from a Windows system. + # Errno 104 occurs when the server terminates the connection and RESTler + # is being run from a Linux system. + connection_closed_strings = [ + # Windows + '[WinError 10054]', + '[WinError 10053]', + # Linux + '[Errno 104]' + ] + return any(filter(lambda x : x in error_str, connection_closed_strings)) + + def _sendRequest(self, message): + """ Sends message via current instance of socket object. + + @param message: Message to be sent. + @type message : Str + + @return: None + @rtype : None + + """ + def _get_end_of_header(message): + return message.index(DELIM) + + def _get_start_of_body(message): + return _get_end_of_header(message) + len(DELIM) + + def _append_to_header(message, content): + header = message[:_get_end_of_header(message)] + "\r\n" + content + DELIM + return header + message[_get_start_of_body(message):] + + def _get_data_from_message(message): + data_index = message.find(DELIM) + data = message[data_index+len(DELIM):] + return data + + def _get_uri_segment_from_message(message): + segment_start_index = message.find(' ') + segment_end_index = message[segment_start_index+1:].find(' ') + segment_start_index+1 + segment = message[segment_start_index+1:segment_end_index] + return segment + + def _create_h2_header(message): + + h11_request_header = message[message.index('\r\n')+2:_get_end_of_header(message)] # we only want the header fields, e.g. Accept: ..., Host: ..., ... + + def _get_host_header_value_from_message(h11_request_header): + for line in h11_request_header.split('\r\n'): + key, value = line.split(': ', 1) + if key == 'Host': + return value + + h2_request_header = [ #special headers must appear at the start of the header block + (':method', self._get_method_from_message(message)), + (':authority', _get_host_header_value_from_message(h11_request_header) or Settings().host), + (':scheme', self._scheme), + (':path', _get_uri_segment_from_message(message)) + ] + + for line in h11_request_header.split('\r\n'): + key, value = line.split(': ', 1) + if key != "Host": + h2_request_header.append((key, value)) + + return h2_request_header + + def _send_h2_data(data): + encoded_data = data.encode(UTF8) + window_size = self._connection.local_flow_control_window(stream_id = self._stream_id) + max_frame_size = self._connection.max_outbound_frame_size + bytes_to_send = min(window_size, len(encoded_data)) + bytes_read = 0 + + while bytes_to_send > 0: + chunk_size = min(bytes_to_send, max_frame_size) + data_chunk = encoded_data[bytes_read:bytes_read+max_frame_size] + self._connection.send_data(stream_id=self._stream_id, data=data_chunk) + + bytes_to_send -= chunk_size + bytes_read += chunk_size + + h2_header = _create_h2_header(message) + h2_data = _get_data_from_message(message) + + # Attempt to throttle the request if necessary + self._begin_throttle_request() + + try: + RAW_LOGGING(f'Sending: {message!r}\n') + self._connection.send_headers(self._stream_id, h2_header) + if h2_data: + _send_h2_data(h2_data) + self._connection.end_stream(self._stream_id) + self._sock.sendall(self._connection.data_to_send()) + except Exception as error: + raise TransportLayerException(f"Exception Sending Data: {error!s}") + finally: + self._end_throttle_request() + + def _begin_throttle_request(self): + """ Will attempt to throttle a request by comparing the last time + a request was sent to the throttle time (if any). + + @return: None + @rtype : None + + """ + if self._request_throttle_sec: + Http20Sock.__request_sem.acquire() + elapsed = time.time() - Http20Sock.__last_request_sent_time + throttle_time_remaining = self._request_throttle_sec - elapsed + if throttle_time_remaining > 0: + time.sleep(throttle_time_remaining) + + def _end_throttle_request(self): + """ Will release the throttle lock (if held), so another + request can be sent. + + Sets last_request_sent_time + + @return: None + @rtype : None + + """ + if self._request_throttle_sec: + Http20Sock.__last_request_sent_time = time.time() + Http20Sock.__request_sem.release() + + def _recvResponse(self, req_timeout_sec, method_name): + """ Reads data from socket object. + + @param req_timeout_sec: The time, in seconds, to wait for request to complete + @type req_timeout_sec : Int + + @return: Data received on current socket. + @rtype : Str + + """ + global DELIM + received_bytes = '' + header = '' + data = '' + response_stream_ended = False + + def decode_buf (buf): + try: + return buf.decode(UTF8) + except Exception as ex: + if self.ignore_decoding_failures: + RAW_LOGGING(f'Failed to decode data due to {ex}. \ + Trying again while ignoring offending bytes.') + return buf.decode(UTF8, "ignore") + else: + raise + + while not response_stream_ended: + try: + self._sock.settimeout(req_timeout_sec) + received_bytes = self._sock.recv(2**20) + events = self._connection.receive_data(received_bytes) + + for event in events: + if isinstance(event, h2.events.DataReceived): + # update flow control so the server doesn't starve us + self._connection.acknowledge_received_data(event.flow_controlled_length, event.stream_id) + if event.stream_id == self._stream_id: + data += decode_buf(event.data) + received_bytes += event.data + if isinstance(event, h2.events.ResponseReceived): + if event.stream_id == self._stream_id: + header = 'HTTP/2.0 ' + amount_header_fields = len(event.headers) + for i, tuples in enumerate(event.headers): + header_field = decode_buf(tuples[0]) + header_value = decode_buf(tuples[1]) + if header_field == ':status': + header += header_value + ' ' + responses[int(header_value)] + else: + header += decode_buf(tuples[0]) + ': ' + decode_buf(tuples[1]) + if i == amount_header_fields - 1: + header += DELIM + else: + header += '\r\n' + if isinstance(event, h2.events.StreamEnded): + if event.stream_id == self._stream_id: + response_stream_ended = True + if isinstance(event, h2.events.StreamReset): + if event.stream_id == self._stream_id: + return '' + if isinstance(event, h2.events.ConnectionTerminated): + return '' + + # send any pending data to the server + self._sock.sendall(self._connection.data_to_send()) + except Exception as error: + raise TransportLayerException(f"Exception: {error!s}") + if len(received_bytes) == 0: + return header + data + return header + data + + + + + def _closeSocket(self): + """ Closes open socket object. + + @return: None + @rtype : None + + """ + try: + self._connection.close_connection() + self._sock.sendall(self._connection.data_to_send()) + self._sock.close() + except BrokenPipeError: + self._sock.close() #if the peer is unresponsive due to a crash, the h2 close_connection call will raise an exception + except Exception as error: + raise TransportLayerException(f"Exception: {error!s}") + + +class Http11Sock(object): + __last_request_sent_time = time.time() + __request_sem = threading.Semaphore() + def set_up_connection(self): try: host = Settings().host @@ -209,8 +598,8 @@ def _begin_throttle_request(self): """ if self._request_throttle_sec: - HttpSock.__request_sem.acquire() - elapsed = time.time() - HttpSock.__last_request_sent_time + Http11Sock.__request_sem.acquire() + elapsed = time.time() - Http11Sock.__last_request_sent_time throttle_time_remaining = self._request_throttle_sec - elapsed if throttle_time_remaining > 0: time.sleep(throttle_time_remaining) @@ -226,8 +615,8 @@ def _end_throttle_request(self): """ if self._request_throttle_sec: - HttpSock.__last_request_sent_time = time.time() - HttpSock.__request_sem.release() + Http11Sock.__last_request_sent_time = time.time() + Http11Sock.__request_sem.release() def _recvResponse(self, req_timeout_sec, method_name): """ Reads data from socket object. diff --git a/restler/requirements.txt b/restler/requirements.txt index b60b27a2..adba72ab 100644 --- a/restler/requirements.txt +++ b/restler/requirements.txt @@ -1,2 +1,3 @@ applicationinsights -pytest \ No newline at end of file +pytest +h2 \ No newline at end of file diff --git a/restler/restler.py b/restler/restler.py index 5e9c6da9..24785d3f 100644 --- a/restler/restler.py +++ b/restler/restler.py @@ -256,6 +256,9 @@ def signal_handler(sig, frame): parser.add_argument('--no_ssl', help='Set this flag if you do not want to use SSL validation for the socket', action='store_true') + parser.add_argument('--http2', + help='Set this flag if you would like to use HTTP2', + action='store_true') parser.add_argument('--include_user_agent', help='Set this flag if you would like to add User-Agent to the request headers', action='store_true') diff --git a/restler/restler_settings.py b/restler/restler_settings.py index 95fadc82..a3d9ff5e 100644 --- a/restler/restler_settings.py +++ b/restler/restler_settings.py @@ -20,7 +20,7 @@ class OptionValidationError(Exception): pass class ConnectionSettings(object): - def __init__(self, target_ip, target_port, use_ssl=True, include_user_agent=False, disable_cert_validation=False): + def __init__(self, target_ip, target_port, use_ssl=True, include_user_agent=False, disable_cert_validation=False, use_http2=False): """ Initializes an object that contains the connection settings for the socket @param target_ip: The ip of the target service. @type target_ip: Str @@ -32,6 +32,8 @@ def __init__(self, target_ip, target_port, use_ssl=True, include_user_agent=Fals @type include_user_agent: Boolean @param disable_cert_validation: Whether or not to disable SSL certificate validation @type disable_cert_validation: Bool + @param use_http2: Whether or not to use HTTP2 for connection + @type use_http2: Boolean @return: None @rtype : None @@ -42,6 +44,7 @@ def __init__(self, target_ip, target_port, use_ssl=True, include_user_agent=Fals self.use_ssl = use_ssl self.include_user_agent = include_user_agent self.disable_cert_validation = disable_cert_validation + self.use_http2 = use_http2 class SettingsArg(object): """ Holds a setting's information """ @@ -440,6 +443,8 @@ def convert_wildcards_to_regex(str_value): self._max_sequence_length = SettingsArg('max_sequence_length', int, MAX_SEQUENCE_LENGTH_DEFAULT, user_args, minval=0) ## Do not use SSL validation self._no_ssl = SettingsArg('no_ssl', bool, False, user_args) + ## Use HTTP2 + self._use_http2 = SettingsArg('http2', bool, False, user_args) ## Do not print auth token data in logs self._no_tokens_in_logs = SettingsArg('no_tokens_in_logs', bool, True, user_args) ## Save the results in a dir with a fixed name (skip 'experiment' subdir) @@ -492,7 +497,8 @@ def convert_wildcards_to_regex(str_value): self._target_port.val, not self._no_ssl.val, self._include_user_agent.val, - self._disable_cert_validation.val) + self._disable_cert_validation.val, + self._use_http2.val) # Set per resource arguments if 'per_resource_settings' in user_args: @@ -684,6 +690,10 @@ def use_test_socket(self): def test_server(self): return self._test_server.val + @property + def use_http2(self): + return self._use_http2.val + @property def time_budget(self): return self._time_budget.val diff --git a/src/driver/Program.fs b/src/driver/Program.fs index da2a7321..303c0aeb 100644 --- a/src/driver/Program.fs +++ b/src/driver/Program.fs @@ -67,6 +67,8 @@ let usage() = Example: (\w*)/virtualNetworks/(\w*) --no_ssl When connecting to the service, do not use SSL. The default is to connect with SSL. + --http2 + When connecting to the service, use HTTP2. The default is to connect with HTTP/1.1. --host If specified, this string will set or override the Host in each request. Example: management.web.com @@ -239,6 +241,7 @@ module Fuzz = sprintf "--producer_timing_delay %d" parameters.producerTimingDelay else "") (if not parameters.useSsl then "--no_ssl" else "") + (if not parameters.useHTTP2 then "" else "--http2") (match parameters.host with | Some h -> sprintf "--host %s" h | None -> "") @@ -636,6 +639,8 @@ let rec parseEngineArgs task (args:EngineParameters) = function usage() | "--no_ssl"::rest -> parseEngineArgs task { args with useSsl = false } rest + | "--http2"::rest -> + parseEngineArgs task { args with useHTTP2 = true} rest | "--host"::host::rest-> parseEngineArgs task { args with host = Some host } rest | "--settings"::settingsFilePath::rest -> diff --git a/src/driver/Types.fs b/src/driver/Types.fs index 891dc1c8..7a4025dc 100644 --- a/src/driver/Types.fs +++ b/src/driver/Types.fs @@ -53,6 +53,9 @@ type EngineParameters = /// Specifies to use SSL when connecting to the server useSsl : bool + /// Specifies to use HTTP2 when connecting to the server + useHTTP2 : bool + /// The string to use in overriding the Host for each request host : string option @@ -86,6 +89,7 @@ let DefaultEngineParameters = searchStrategy = None producerTimingDelay = 0 useSsl = true + useHTTP2 = false host = None settingsFilePath= "" checkerOptions = []