Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] #277: HTTP2 support #412

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 153 additions & 8 deletions restler/engine/transport_layer/messaging.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

""" Transport layer fuctionality using python sockets. """
""" Transport layer functionality using python sockets. """
from __future__ import print_function
from abc import ABCMeta, abstractmethod
import ssl
import socket
import time
import threading
from importlib import util
from typing import Dict, Tuple, Union
from hyper import HTTP20Connection


from utils.logger import raw_network_logging as RAW_LOGGING
from engine.errors import TransportLayerException
Expand All @@ -22,14 +26,155 @@


class HttpSock(object):
"""
Proxy to return the correct socket object.
"""
def __init__(self, connection_settings: Dict) -> None:
if Settings().use_http2:
self._subject = Http2Sock(connection_settings)
else:
self._subject = HttpRawSock(connection_settings)

def sendRecv(self, message: str, req_timeout_sec: int, reconnect: bool = False) -> Tuple[bool, Union[HttpResponse ,str]]:
""" Sends a specified request to the server and waits for a response

@param message: Message to be sent.
@type message : Str
@param req_timeout_sec: The time, in seconds, to wait for request to complete
@type req_timeout_sec : Int

@return:
False if failure, True if success
Response if True returned, Error if False returned
@rtype : Tuple (Bool, String)

"""
return self._subject.sendRecv(message, req_timeout_sec, reconnect=reconnect)


class BaseSocket(object, metaclass=ABCMeta):
__last_request_sent_time = time.time()
__req_sem = threading.Semaphore()

@abstractmethod
def __init__(self, connection_settings: Dict) -> None:
self.connection_settings = connection_settings

host = Settings().host
self.target_ip = connection_settings.target_ip or host
self.target_port = connection_settings.target_port or 433

self.connection_settings = connection_settings

self.ignore_decoding_failures = Settings().ignore_decoding_failures

@abstractmethod
def __del__(self):
pass

@abstractmethod
def sendRecv(self, message: str, req_timeout_sec: int) -> Tuple[bool, str]:
pass

def _get_method_from_message(self, message):
end_of_method_idx = message.find(" ")
method_name = message[0:end_of_method_idx]
return method_name

def _get_payload_from_message(self, message):
# FIXME: really not a safe way of doing this...
payload_index = message.find(DELIM)
body = message[payload_index+len(DELIM):]
return body

def _get_uri_segment_from_message(self, message):
segment_start_index = message.find(' ')
segment_end_index = message[segment_start_index+1:].find(' ') + segment_start_index+1
segment = message[segment_start_index+1:segment_end_index]
return segment

def _get_headers_from_message(self, message) -> Dict:
# FIXME: ugly
header_index = message.find('\r\n')
payload_index = message.find(DELIM)

headers = message[header_index+2:payload_index]
h = dict()
for line in headers.split('\r\n'):
k, v = line.split(':')
h[k.strip()] = v.strip()
return h


class Http2Sock(BaseSocket):
def __init__(self, connection_settings: Dict) -> None:
""" Initializes a socket object using hyper.

@param connection_settings: The connection settings for this socket
@type connection_settings: ConnectionSettings

@return: None
@rtype : None

"""
super().__init__(connection_settings)

self.client = HTTP20Connection(
host=self.target_ip,
port=self.target_port,
secure=self.connection_settings.use_ssl
)

def sendRecv(self, message: str, req_timeout_sec: int, *args, **kwargs) -> Tuple[bool, str]:
super().sendRecv(message, req_timeout_sec)
method = self._get_method_from_message(message)
message_body = self._get_payload_from_message(message)
uri_segment = self._get_uri_segment_from_message(message)

print(message)

self.client.request(
method,
url=uri_segment,
body=bytes(message_body, UTF8), #TODO: allow other encodings
headers=self._get_headers_from_message(message)
)

response = self.client.get_response()

res = Http2Response(response)

return (True, res)

def __del__(self):
pass


class HttpRawSock(BaseSocket):
__last_request_sent_time = time.time()
__request_sem = threading.Semaphore()

def __init__(self, connection_settings):
""" Initializes a socket object using low-level python socket objects.

@param connection_settings: The connection settings for this socket
@type connection_settings: ConnectionSettings

@return: None
@rtype : None

"""
super().__init__(connection_settings)

def set_up_connection(self):
try:
host = Settings().host
target_ip = self.connection_settings.target_ip or host
target_port = self.connection_settings.target_port
self._sock = None

if Settings().request_throttle_ms:
self._request_throttle_sec = Settings().request_throttle_ms/1000.0
else:
self._request_throttle_sec = None

if Settings().use_test_socket:
self._sock = TestSocket(Settings().test_server)
elif self.connection_settings.use_ssl:
Expand All @@ -42,13 +187,13 @@ def set_up_connection(self):
certfile = Settings().client_certificate_path,
keyfile = Settings().client_certificate_key_path,
)

with socket.create_connection((target_ip, target_port or 443)) as sock:
self._sock = context.wrap_socket(sock, server_hostname=host)
with socket.create_connection((self.target_ip, self.target_port)) as sock:
self._sock = context.wrap_socket(sock, server_hostname=self.host)

else:
self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._sock.connect((target_ip, target_port or 80))
self._sock.connect((self.target_ip, self.target_port or 80))
except Exception as error:
raise TransportLayerException(f"Exception Creating Socket: {error!s}")

Expand Down
111 changes: 110 additions & 1 deletion restler/engine/transport_layer/response.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from abc import ABCMeta, abstractmethod, abstractproperty
import string
import re
from typing import Dict, List

import hyper
from restler_settings import Settings

DELIM = "\r\n\r\n"
Expand All @@ -14,7 +18,112 @@
# of a sequence because the sequence failed prior to that request being reached.
RESTLER_INVALID_CODE = '999'

class HttpResponse(object):
class AbstractHttpResponse(object, metaclass=ABCMeta):
@abstractmethod
def __init__(self, response):
pass

@abstractproperty
def to_str(self):
pass

@abstractproperty
def to_str(self) -> str:
pass

@abstractproperty
def status_code(self) -> str:
pass

@abstractproperty
def body(self) -> str:
pass

@abstractproperty
def headers(self) -> str:
"""Raw response header section of response"""
pass

@abstractproperty
def headers_dict(self) -> Dict:
pass

@abstractmethod
def has_valid_code(self) -> bool:
pass

@abstractmethod
def has_bug_code(self) -> bool:
pass

@abstractproperty
def json_body(self) -> str:
pass

@abstractproperty
def status_text(self) -> str:
pass


class Http2Response(AbstractHttpResponse):
def __init__(self, response: hyper.HTTP20Response):
""" Hyper response facade
"""
self._response = response
self._body = self._response.read(decode_content=True).decode('utf-8')

@property
def to_str(self) -> str:
#TODO: remove the need for this function.
# It is hacky.
return f"{self.headers}{DELIM}{self.body}"

@property
def status_code(self) -> str:
return str(self._response.status)

@property
def body(self) -> str:
return self._body

@property
def headers(self) -> str:
"""Raw response header section of response"""
h_generator = self._response.headers.iter_raw()
header_str = '\n\r'.join(f"{k.decode('utf-8')}: {v.decode('utf-8')}" for k,v in h_generator)
return header_str

@property
def headers_dict(self) -> Dict:
h_dict = dict()
for k, v in self._response.headers.iter_raw():
h_dict[k] = v
return h_dict

def has_valid_code(self) -> bool:
sc = self._response.status
return sc in VALID_CODES

def has_bug_code(self) -> bool:
sc = self._response.status
custom_bug = sc in Settings().custom_non_bug_codes
fiveXX_code = sc >= 500
return custom_bug or fiveXX_code

@property
def json_body(self) -> str:
# TODO: actually parse json data
return self.body

@property
def status_text(self) -> str:
"""
This is not used in HTTP/2, and so is always an empty string.
"""
return ""


class HttpResponse(AbstractHttpResponse):
def __init__(self, response_str: str=None):
""" Initializes an HttpResponse object

Expand Down
3 changes: 2 additions & 1 deletion restler/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
applicationinsights
pytest
pytest
hyper==0.7.0
3 changes: 3 additions & 0 deletions restler/restler.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ def signal_handler(sig, frame):
parser.add_argument('--host',
help='Set to override Host in the grammar (default: do not override)',
type=str, default=None, required=False)
parser.add_argument('--http2',
help='Use HTTP2/0 for server communication (BETA)',
action='store_true')
parser.add_argument('--no_ssl',
help='Set this flag if you do not want to use SSL validation for the socket',
action='store_true')
Expand Down
12 changes: 10 additions & 2 deletions restler/restler_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class OptionValidationError(Exception):
pass

class ConnectionSettings(object):
def __init__(self, target_ip, target_port, use_ssl=True, include_user_agent=False, disable_cert_validation=False):
def __init__(self, target_ip, target_port, use_ssl=True, include_user_agent=False, disable_cert_validation=False, use_http2=False):
""" Initializes an object that contains the connection settings for the socket
@param target_ip: The ip of the target service.
@type target_ip: Str
Expand All @@ -42,6 +42,7 @@ def __init__(self, target_ip, target_port, use_ssl=True, include_user_agent=Fals
self.use_ssl = use_ssl
self.include_user_agent = include_user_agent
self.disable_cert_validation = disable_cert_validation
self.use_http2 = use_http2

class SettingsArg(object):
""" Holds a setting's information """
Expand Down Expand Up @@ -462,6 +463,8 @@ def convert_wildcards_to_regex(str_value):
self._target_port = SettingsArg('target_port', int, None, user_args, minval=0, maxval=TARGET_PORT_MAX)
## Set to use test server/run in test mode
self._use_test_socket = SettingsArg('use_test_socket', bool, False, user_args)
## Use HTTP2/0 for server communication
self._use_http2 = SettingsArg('http2', bool, False, user_args)
## Set the test server identifier
self._test_server = SettingsArg('test_server', str, DEFAULT_TEST_SERVER_ID, user_args)
## Stops fuzzing after given time (hours)
Expand All @@ -483,7 +486,8 @@ def convert_wildcards_to_regex(str_value):
self._target_port.val,
not self._no_ssl.val,
self._include_user_agent.val,
self._disable_cert_validation.val)
self._disable_cert_validation.val,
self._use_http2.val)

# Set per resource arguments
if 'per_resource_settings' in user_args:
Expand Down Expand Up @@ -667,6 +671,10 @@ def settings_file_exists(self):
def use_test_socket(self):
return self._use_test_socket.val

@property
def use_http2(self):
return self._use_http2.val

@property
def test_server(self):
return self._test_server.val
Expand Down