Skip to content

Commit

Permalink
Run scenarios flask application as gunicorn app. (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
fwkz authored Nov 19, 2017
1 parent 1393c2f commit 657c706
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 24 deletions.
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ def test_exploit(target):
cgi_mock.assert_called_once()
assert check_mock.call_count == 2
```
#### Random port
To avoid `port` collison during tests you can tell `HttpServiceMock` to set
it for you by passing `0`
```python
@pytest.fixture(scope="session")
def trash_target():
with HttpServiceMock("127.0.0.1", 0) as http_service:
yield http_service
```
### `HttpScenarioService`
`HttpScenarioService` allows for creating test utilities using pre-defined
[scenarios](#http-scenarios)
Expand All @@ -86,6 +95,29 @@ def trash_target():

```

#### Adhoc SSL support
You can serve `HttpScenarioService` using adhoc SSL certificate by setting
`ssl` keyword argument to `True`:

```python
@pytest.fixture(scope="session")
def trash_target():
with HttpScenarioService("127.0.0.1", 8443, HttpScenario.TRASH,
ssl=True) as http_service:
yield http_service
```

#### Random port
To avoid `port` collison during tests you can tell `HttpScenarioService` to set
it for you by passing `0`
```python
@pytest.fixture(scope="session")
def trash_target():
with HttpScenarioService("127.0.0.1", 0,
HttpScenario.TRASH) as http_service:
yield http_service
```

## Services
### `http`
```bash
Expand Down
8 changes: 4 additions & 4 deletions threat9_test_bed/http_server.py → threat9_test_bed/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

logger = logging.getLogger(__name__)

http_server = Flask(__name__)
app = Flask(__name__)

ALLOWED_METHODS = [
"GET",
Expand All @@ -20,11 +20,11 @@
]


@http_server.route('/', defaults={'path': ''}, methods=ALLOWED_METHODS)
@http_server.route('/<path:path>', methods=ALLOWED_METHODS)
@app.route('/', defaults={'path': ''}, methods=ALLOWED_METHODS)
@app.route('/<path:path>', methods=ALLOWED_METHODS)
def catch_all(path):
scenario_handler = SCENARIO_TO_HANDLER_MAP.get(
http_server.config["SCENARIO"],
app.config["SCENARIO"],
error,
)
logger.debug(
Expand Down
24 changes: 19 additions & 5 deletions threat9_test_bed/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import click

from .http_server import http_server
from .app import app
from .gunicorn_server import GunicornServer
from .scenarios import HttpScenario

logger = logging.getLogger(__name__)
Expand All @@ -24,10 +25,16 @@ def cli():
help='HTTP server behaviour.')
def run_http_server(scenario, port):
logger.debug("Starting `http` server...")
http_server.config.update(
app.config.update(
SCENARIO=HttpScenario[scenario],
)
http_server.run(port=port)
GunicornServer(
app=app,
bind=f"127.0.0.1:{port}",
worker_class="gthread",
threads=8,
accesslog="-",
).run()
logger.debug(f"`http` server has been started on port {port}.")


Expand All @@ -42,8 +49,15 @@ def run_http_server(scenario, port):
help='HTTP server behaviour.')
def run_https_server(scenario, port):
logger.debug("Starting `https` server...")
http_server.config.update(
app.config.update(
SCENARIO=HttpScenario[scenario],
)
http_server.run(port=port, ssl_context='adhoc')
GunicornServer(
app=app,
bind=f"127.0.0.1:{port}",
worker_class="gthread",
threads=8,
ssl=True,
accesslog="-",
).run()
logger.debug(f"`https` server has been started on port {port}.")
43 changes: 43 additions & 0 deletions threat9_test_bed/gunicorn_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import tempfile
from pathlib import Path

from gunicorn.app.base import BaseApplication
from OpenSSL import crypto
from werkzeug.serving import generate_adhoc_ssl_pair


class GunicornServer(BaseApplication):
def __init__(self, app, **kwargs):
self.options = kwargs
self.application = app
super().__init__()

def load_config(self):
if self.options.get("ssl"):
cert_path, pkey_path = self.generate_devel_ssl_pair()
self.options["certfile"] = str(cert_path)
self.options["keyfile"] = str(pkey_path)

config = {
key: value for key, value in self.options.items()
if key in self.cfg.settings and value is not None
}
for key, value in config.items():
self.cfg.set(key.lower(), value)

def load(self):
return self.application

@staticmethod
def generate_devel_ssl_pair() -> (Path, Path):
cert_path = Path(tempfile.gettempdir()) / "threat9-test-bed.crt"
pkey_path = Path(tempfile.gettempdir()) / "threat9-test-bed.key"

if not cert_path.exists() or not pkey_path.exists():
cert, pkey = generate_adhoc_ssl_pair()
with open(cert_path, 'wb') as f:
f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
with open(pkey_path, 'wb') as f:
f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, pkey))

return cert_path, pkey_path
79 changes: 71 additions & 8 deletions threat9_test_bed/service_mocks/base_http_service.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
import logging
import multiprocessing
import socket
import threading
import time
from contextlib import closing
from wsgiref.simple_server import make_server

from flask import Flask

from ..gunicorn_server import GunicornServer

logger = logging.getLogger(__name__)


def find_free_port():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(('', 0))
return int(s.getsockname()[1])


class BaseHttpService:
def __init__(self, host: str, port: int, app: Flask):
self.host = host
self.port = port
self.port = find_free_port() if port == 0 else port
self.app = app
self.server = make_server(self.host, self.port, self.app)
self.server_thread = threading.Thread(target=self.server.serve_forever)

def _wait_for_service(self):
elapsed_time = 0
Expand All @@ -24,7 +32,7 @@ def _wait_for_service(self):
s = socket.socket()
s.settimeout(1)
try:
s.connect(self.server.server_address)
s.connect((self.host, self.port))
except (ConnectionRefusedError, ConnectionAbortedError):
elapsed_time = time.time() - start_time
s.close()
Expand All @@ -35,21 +43,76 @@ def _wait_for_service(self):
raise TimeoutError(f"{self.__class__.__name__} "
f"couldn't be set up before test.")

def start(self):
raise NotImplementedError()

def teardown(self):
raise NotImplementedError()

def __enter__(self):
logger.debug(f"Starting {self}...")
self.server_thread.start()
self.start()
self._wait_for_service()
logger.debug(f"{self} has been started.")
return self

def __exit__(self, exc_type, exc_val, exc_tb):
logger.debug(f"Terminating {self}...")
self.server.shutdown()
self.server_thread.join()
self.server.server_close()
self.teardown()
logger.debug(f"{self} has been terminated.")

def __repr__(self):
return (
f"{self.__class__.__name__}(host='{self.host}', port={self.port})"
)


class GunicornBasedHttpService(BaseHttpService):
""" `gunicorn` based HTTP service
`Flask` application served using `gunicorn` in separate process using
async workers (threads in this case).
Application served by this base class suppose to handle unbuffered
requests, `nginx` in this case is no option hence async workers.
"""
def __init__(self, host: str, port: int, app: Flask, ssl=False):
super().__init__(host, port, app)
self.server = GunicornServer(
app=self.app,
bind=f"{self.host}:{self.port}",
worker_class="gthread",
threads=8,
ssl=ssl,
accesslog="-",
)
self.server_process = multiprocessing.Process(target=self.server.run)

def start(self):
self.server_process.start()

def teardown(self):
self.server_process.terminate()
self.server_process.join()


class WSGIRefBasedHttpService(BaseHttpService):
""" `wsgiref` based HTTP service
`Flask` application served using `wsgiref` in separate thread.
We can leverage shared state between main thread and thread handling
`wsgiref` server and dynamically attach `Mock` object as view functions.
"""
def __init__(self, host: str, port: int, app: Flask):
super().__init__(host, port, app)
self.server = make_server(self.host, self.port, self.app)
self.server_thread = threading.Thread(target=self.server.serve_forever)

def start(self):
self.server_thread.start()

def teardown(self):
self.server.shutdown()
self.server_thread.join()
self.server.server_close()
10 changes: 5 additions & 5 deletions threat9_test_bed/service_mocks/http_scenario_service.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import logging

from ..http_server import http_server
from ..app import app
from ..scenarios import HttpScenario
from .base_http_service import BaseHttpService
from .base_http_service import GunicornBasedHttpService

logger = logging.getLogger(__name__)


class HttpScenarioService(BaseHttpService):
class HttpScenarioService(GunicornBasedHttpService):
def __init__(self, host: str, port: int, scenario: HttpScenario):
http_server.config.update(SCENARIO=scenario)
super().__init__(host, port, http_server)
app.config.update(SCENARIO=scenario)
super().__init__(host, port, app)
4 changes: 2 additions & 2 deletions threat9_test_bed/service_mocks/http_service_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from flask import Flask

from .base_http_service import BaseHttpService
from .base_http_service import WSGIRefBasedHttpService

logger = logging.getLogger(__name__)


class HttpServiceMock(BaseHttpService):
class HttpServiceMock(WSGIRefBasedHttpService):
def __init__(self, host: str, port: int):
super().__init__(host, port, Flask("target"))

Expand Down

0 comments on commit 657c706

Please sign in to comment.