diff --git a/restler/engine/transport_layer/messaging.py b/restler/engine/transport_layer/messaging.py index f2e7f30c..1d359f00 100644 --- a/restler/engine/transport_layer/messaging.py +++ b/restler/engine/transport_layer/messaging.py @@ -1,13 +1,17 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -""" Transport layer fuctionality using python sockets. """ +""" Transport layer functionality using python sockets. """ from __future__ import print_function +from abc import ABCMeta, abstractmethod import ssl import socket import time import threading from importlib import util +from typing import Dict, Tuple, Union +from hyper import HTTP20Connection + from utils.logger import raw_network_logging as RAW_LOGGING from engine.errors import TransportLayerException @@ -22,14 +26,155 @@ class HttpSock(object): + """ + Proxy to return the correct socket object. + """ + def __init__(self, connection_settings: Dict) -> None: + if Settings().use_http2: + self._subject = Http2Sock(connection_settings) + else: + self._subject = HttpRawSock(connection_settings) + + def sendRecv(self, message: str, req_timeout_sec: int, reconnect: bool = False) -> Tuple[bool, Union[HttpResponse ,str]]: + """ Sends a specified request to the server and waits for a response + + @param message: Message to be sent. + @type message : Str + @param req_timeout_sec: The time, in seconds, to wait for request to complete + @type req_timeout_sec : Int + + @return: + False if failure, True if success + Response if True returned, Error if False returned + @rtype : Tuple (Bool, String) + + """ + return self._subject.sendRecv(message, req_timeout_sec, reconnect=reconnect) + + +class BaseSocket(object, metaclass=ABCMeta): + __last_request_sent_time = time.time() + __req_sem = threading.Semaphore() + + @abstractmethod + def __init__(self, connection_settings: Dict) -> None: + self.connection_settings = connection_settings + + host = Settings().host + self.target_ip = connection_settings.target_ip or host + self.target_port = connection_settings.target_port or 433 + + self.connection_settings = connection_settings + + self.ignore_decoding_failures = Settings().ignore_decoding_failures + + @abstractmethod + def __del__(self): + pass + + @abstractmethod + def sendRecv(self, message: str, req_timeout_sec: int) -> Tuple[bool, str]: + pass + + def _get_method_from_message(self, message): + end_of_method_idx = message.find(" ") + method_name = message[0:end_of_method_idx] + return method_name + + def _get_payload_from_message(self, message): + # FIXME: really not a safe way of doing this... + payload_index = message.find(DELIM) + body = message[payload_index+len(DELIM):] + return body + + def _get_uri_segment_from_message(self, message): + segment_start_index = message.find(' ') + segment_end_index = message[segment_start_index+1:].find(' ') + segment_start_index+1 + segment = message[segment_start_index+1:segment_end_index] + return segment + + def _get_headers_from_message(self, message) -> Dict: + # FIXME: ugly + header_index = message.find('\r\n') + payload_index = message.find(DELIM) + + headers = message[header_index+2:payload_index] + h = dict() + for line in headers.split('\r\n'): + k, v = line.split(':') + h[k.strip()] = v.strip() + return h + + +class Http2Sock(BaseSocket): + def __init__(self, connection_settings: Dict) -> None: + """ Initializes a socket object using hyper. + + @param connection_settings: The connection settings for this socket + @type connection_settings: ConnectionSettings + + @return: None + @rtype : None + + """ + super().__init__(connection_settings) + + self.client = HTTP20Connection( + host=self.target_ip, + port=self.target_port, + secure=self.connection_settings.use_ssl + ) + + def sendRecv(self, message: str, req_timeout_sec: int, *args, **kwargs) -> Tuple[bool, str]: + super().sendRecv(message, req_timeout_sec) + method = self._get_method_from_message(message) + message_body = self._get_payload_from_message(message) + uri_segment = self._get_uri_segment_from_message(message) + + print(message) + + self.client.request( + method, + url=uri_segment, + body=bytes(message_body, UTF8), #TODO: allow other encodings + headers=self._get_headers_from_message(message) + ) + + response = self.client.get_response() + + res = Http2Response(response) + + return (True, res) + + def __del__(self): + pass + + +class HttpRawSock(BaseSocket): __last_request_sent_time = time.time() __request_sem = threading.Semaphore() + def __init__(self, connection_settings): + """ Initializes a socket object using low-level python socket objects. + + @param connection_settings: The connection settings for this socket + @type connection_settings: ConnectionSettings + + @return: None + @rtype : None + + """ + super().__init__(connection_settings) + def set_up_connection(self): try: - host = Settings().host - target_ip = self.connection_settings.target_ip or host - target_port = self.connection_settings.target_port + self._sock = None + + if Settings().request_throttle_ms: + self._request_throttle_sec = Settings().request_throttle_ms/1000.0 + else: + self._request_throttle_sec = None + if Settings().use_test_socket: self._sock = TestSocket(Settings().test_server) elif self.connection_settings.use_ssl: @@ -42,13 +187,13 @@ def set_up_connection(self): certfile = Settings().client_certificate_path, keyfile = Settings().client_certificate_key_path, ) - - with socket.create_connection((target_ip, target_port or 443)) as sock: - self._sock = context.wrap_socket(sock, server_hostname=host) + + with socket.create_connection((self.target_ip, self.target_port)) as sock: + self._sock = context.wrap_socket(sock, server_hostname=self.host) else: self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._sock.connect((target_ip, target_port or 80)) + self._sock.connect((self.target_ip, self.target_port or 80)) except Exception as error: raise TransportLayerException(f"Exception Creating Socket: {error!s}") diff --git a/restler/engine/transport_layer/response.py b/restler/engine/transport_layer/response.py index 5b7897cf..06088ef3 100644 --- a/restler/engine/transport_layer/response.py +++ b/restler/engine/transport_layer/response.py @@ -1,7 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from abc import ABCMeta, abstractmethod, abstractproperty import string import re +from typing import Dict, List + +import hyper from restler_settings import Settings DELIM = "\r\n\r\n" @@ -14,7 +18,112 @@ # of a sequence because the sequence failed prior to that request being reached. RESTLER_INVALID_CODE = '999' -class HttpResponse(object): +class AbstractHttpResponse(object, metaclass=ABCMeta): + @abstractmethod + def __init__(self, response): + pass + + @abstractproperty + def to_str(self): + pass + + @abstractproperty + def to_str(self) -> str: + pass + + @abstractproperty + def status_code(self) -> str: + pass + + @abstractproperty + def body(self) -> str: + pass + + @abstractproperty + def headers(self) -> str: + """Raw response header section of response""" + pass + + @abstractproperty + def headers_dict(self) -> Dict: + pass + + @abstractmethod + def has_valid_code(self) -> bool: + pass + + @abstractmethod + def has_bug_code(self) -> bool: + pass + + @abstractproperty + def json_body(self) -> str: + pass + + @abstractproperty + def status_text(self) -> str: + pass + + +class Http2Response(AbstractHttpResponse): + def __init__(self, response: hyper.HTTP20Response): + """ Hyper response facade + """ + self._response = response + self._body = self._response.read(decode_content=True).decode('utf-8') + + @property + def to_str(self) -> str: + #TODO: remove the need for this function. + # It is hacky. + return f"{self.headers}{DELIM}{self.body}" + + @property + def status_code(self) -> str: + return str(self._response.status) + + @property + def body(self) -> str: + return self._body + + @property + def headers(self) -> str: + """Raw response header section of response""" + h_generator = self._response.headers.iter_raw() + header_str = '\n\r'.join(f"{k.decode('utf-8')}: {v.decode('utf-8')}" for k,v in h_generator) + return header_str + + @property + def headers_dict(self) -> Dict: + h_dict = dict() + for k, v in self._response.headers.iter_raw(): + h_dict[k] = v + return h_dict + + def has_valid_code(self) -> bool: + sc = self._response.status + return sc in VALID_CODES + + def has_bug_code(self) -> bool: + sc = self._response.status + custom_bug = sc in Settings().custom_non_bug_codes + fiveXX_code = sc >= 500 + return custom_bug or fiveXX_code + + @property + def json_body(self) -> str: + # TODO: actually parse json data + return self.body + + @property + def status_text(self) -> str: + """ + This is not used in HTTP/2, and so is always an empty string. + """ + return "" + + +class HttpResponse(AbstractHttpResponse): def __init__(self, response_str: str=None): """ Initializes an HttpResponse object diff --git a/restler/requirements.txt b/restler/requirements.txt index b60b27a2..53aac584 100644 --- a/restler/requirements.txt +++ b/restler/requirements.txt @@ -1,2 +1,3 @@ applicationinsights -pytest \ No newline at end of file +pytest +hyper==0.7.0 diff --git a/restler/restler.py b/restler/restler.py index af2c8233..d402c431 100644 --- a/restler/restler.py +++ b/restler/restler.py @@ -257,6 +257,9 @@ def signal_handler(sig, frame): parser.add_argument('--host', help='Set to override Host in the grammar (default: do not override)', type=str, default=None, required=False) + parser.add_argument('--http2', + help='Use HTTP2/0 for server communication (BETA)', + action='store_true') parser.add_argument('--no_ssl', help='Set this flag if you do not want to use SSL validation for the socket', action='store_true') diff --git a/restler/restler_settings.py b/restler/restler_settings.py index a89a118c..8dc2fdd1 100644 --- a/restler/restler_settings.py +++ b/restler/restler_settings.py @@ -20,7 +20,7 @@ class OptionValidationError(Exception): pass class ConnectionSettings(object): - def __init__(self, target_ip, target_port, use_ssl=True, include_user_agent=False, disable_cert_validation=False): + def __init__(self, target_ip, target_port, use_ssl=True, include_user_agent=False, disable_cert_validation=False, use_http2=False): """ Initializes an object that contains the connection settings for the socket @param target_ip: The ip of the target service. @type target_ip: Str @@ -42,6 +42,7 @@ def __init__(self, target_ip, target_port, use_ssl=True, include_user_agent=Fals self.use_ssl = use_ssl self.include_user_agent = include_user_agent self.disable_cert_validation = disable_cert_validation + self.use_http2 = use_http2 class SettingsArg(object): """ Holds a setting's information """ @@ -462,6 +463,8 @@ def convert_wildcards_to_regex(str_value): self._target_port = SettingsArg('target_port', int, None, user_args, minval=0, maxval=TARGET_PORT_MAX) ## Set to use test server/run in test mode self._use_test_socket = SettingsArg('use_test_socket', bool, False, user_args) + ## Use HTTP2/0 for server communication + self._use_http2 = SettingsArg('http2', bool, False, user_args) ## Set the test server identifier self._test_server = SettingsArg('test_server', str, DEFAULT_TEST_SERVER_ID, user_args) ## Stops fuzzing after given time (hours) @@ -483,7 +486,8 @@ def convert_wildcards_to_regex(str_value): self._target_port.val, not self._no_ssl.val, self._include_user_agent.val, - self._disable_cert_validation.val) + self._disable_cert_validation.val, + self._use_http2.val) # Set per resource arguments if 'per_resource_settings' in user_args: @@ -667,6 +671,10 @@ def settings_file_exists(self): def use_test_socket(self): return self._use_test_socket.val + @property + def use_http2(self): + return self._use_http2.val + @property def test_server(self): return self._test_server.val