Skip to content

Commit

Permalink
Use werkzeug.serving.make_server to levrage shared state.
Browse files Browse the repository at this point in the history
* Use `werkzeug.serving.make_server` to levrage shared state
* Dib port for service on class init
  • Loading branch information
fwkz authored Dec 5, 2017
1 parent 06f0bf9 commit 11be942
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 41 deletions.
43 changes: 11 additions & 32 deletions threat9_test_bed/service_mocks/base_http_service.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import logging
import multiprocessing
import threading
import uuid
from wsgiref.simple_server import make_server
from wsgiref.simple_server import make_server as wsgiref_make_server

import requests
from flask import Flask, request
from flask import Flask
from werkzeug.serving import make_server as werkzeug_make_server

from ..http_service.gunicorn_server import GunicornServer
from .base_service import BaseService
Expand Down Expand Up @@ -54,7 +53,7 @@ class WSGIRefBasedHttpService(BaseService):
def __init__(self, host: str, port: int, app: Flask):
super().__init__(host, port)
self.app = app
self.server = make_server(self.host, self.port, self.app)
self.server = wsgiref_make_server(self.host, self.port, self.app)
self.server_thread = threading.Thread(target=self.server.serve_forever)

def start(self):
Expand All @@ -76,38 +75,18 @@ class WerkzeugBasedHttpService(BaseService):
"""
def __init__(self, host: str, port: int, app: Flask, ssl=False):
super().__init__(host, port)
self.url_scheme = "https" if ssl else "http"
self.terminate_url = uuid.uuid4().hex
self.app = app

self.app.add_url_rule(
f"/{self.terminate_url}",
"shutdown_server",
self.shutdown_server,
methods=['POST'],
)

self.server_thread = threading.Thread(
target=self.app.run,
args=(self.host, self.port),
kwargs={"ssl_context": "adhoc"} if ssl else None
self.server = werkzeug_make_server(
self.host, self.port, self.app,
threaded=True,
ssl_context="adhoc" if ssl else None
)
self.server_thread = threading.Thread(target=self.server.serve_forever)

def start(self):
self.server_thread.start()

def teardown(self):
requests.post(
f"{self.url_scheme}://{self.host}:{self.port}"
f"/{self.terminate_url}",
verify=False,
)
self.server.shutdown()
self.server_thread.join()

@staticmethod
def shutdown_server():
func = request.environ.get('werkzeug.server.shutdown')
if func is None:
raise RuntimeError('Not running with the Werkzeug Server')
func()
return "Server terminated.", 200
self.server.server_close()
19 changes: 10 additions & 9 deletions threat9_test_bed/service_mocks/base_service.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
import logging
import socket
import time
from contextlib import closing

logger = logging.getLogger(__name__)


def find_free_port():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(('', 0))
return int(s.getsockname()[1])


class BaseService:
def __init__(self, host: str, port: int):
self.host = host
self.port = find_free_port() if port == 0 else port
self.port, self.dibbed_port_socket = self.dib_port(port)

def _wait_for_service(self):
elapsed_time = 0
Expand All @@ -25,7 +18,8 @@ def _wait_for_service(self):
s.settimeout(1)
try:
s.connect((self.host, self.port))
except (ConnectionRefusedError, ConnectionAbortedError):
except (ConnectionRefusedError, ConnectionAbortedError,
socket.timeout):
elapsed_time = time.time() - start_time
s.close()
else:
Expand All @@ -43,6 +37,7 @@ def teardown(self):

def __enter__(self):
logger.debug(f"Starting {self}...")
self.dibbed_port_socket.close()
self.start()
self._wait_for_service()
logger.debug(f"{self} has been started.")
Expand All @@ -53,6 +48,12 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.teardown()
logger.debug(f"{self} has been terminated.")

@staticmethod
def dib_port(port=0) -> (int, socket.socket):
socket_ = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_.bind(('', port))
return int(socket_.getsockname()[1]), socket_

def __repr__(self):
return (
f"{self.__class__.__name__}(host='{self.host}', port={self.port})"
Expand Down

0 comments on commit 11be942

Please sign in to comment.