Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ssl support #712

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ or ``main.py``) or to a specific file. The ``--app-factory`` option can be used
from the app path file, if not supplied some default method names are tried
(namely `app`, `app_factory`, `get_app` and `create_app`, which can be
variables, functions, or coroutines).
The ``--ssl-context-factory`` option can be used to define method from the app path file, which returns ssl.SSLContext
for ssl support.

All ``runserver`` arguments can be set via environment variables.

Expand Down
4 changes: 4 additions & 0 deletions aiohttp_devtools/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def serve(path: str, livereload: bool, bind_address: str, port: int, verbose: bo
'or just an instance of aiohttp.Application. env variable AIO_APP_FACTORY')
port_help = 'Port to serve app from, default 8000. env variable: AIO_PORT'
aux_port_help = 'Port to serve auxiliary app (reload and static) on, default port + 1. env variable: AIO_AUX_PORT'
ssl_context_factory_help = 'name of the ssl context factory to create ssl.SSLContext with'
ssl_rootcert_file_help = 'path to a rootCA certificate file for self-signed cert chain (if needed)'


# defaults are all None here so default settings are defined in one place: DEV_DICT validation
Expand All @@ -83,6 +85,8 @@ def serve(path: str, livereload: bool, bind_address: str, port: int, verbose: bo
@click.option('-v', '--verbose', is_flag=True, help=verbose_help)
@click.option("--browser-cache/--no-browser-cache", envvar="AIO_BROWSER_CACHE", default=None,
help=browser_cache_help)
@click.option('--ssl-context-factory', 'ssl_context_factory_name', default=None, help=ssl_context_factory_help)
@click.option('--ssl-rootcert', 'ssl_rootcert_file_path', default=None, help=ssl_rootcert_file_help)
@click.argument('project_args', nargs=-1)
def runserver(**config: Any) -> None:
"""
Expand Down
70 changes: 62 additions & 8 deletions aiohttp_devtools/runserver/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import sys
from importlib import import_module
from pathlib import Path
from typing import Awaitable, Callable, Optional, Union
from typing import Awaitable, Callable, Optional, Union, Literal
from types import ModuleType

from aiohttp import web
from ssl import SSLContext, create_default_context as create_default_ssl_context

import __main__
from ..exceptions import AiohttpDevConfigError as AdevConfigError
Expand All @@ -26,6 +28,8 @@
'create_app',
]

DEFAULT_PORT = 8000

INFER_HOST = '<inference>'


Expand All @@ -43,9 +47,11 @@
app_factory_name: Optional[str] = None,
host: str = INFER_HOST,
bind_address: str = "localhost",
main_port: int = 8000,
main_port: Optional[int] = None,
aux_port: Optional[int] = None,
browser_cache: bool = False):
browser_cache: bool = False,
ssl_context_factory_name: Optional[str] = None,
ssl_rootcert_file_path: Optional[str] = None):
if root_path:
self.root_path = Path(root_path).resolve()
logger.debug('Root path specified: %s', self.root_path)
Expand Down Expand Up @@ -83,15 +89,39 @@
self.host = bind_address

self.bind_address = bind_address
if main_port is None:
main_port = DEFAULT_PORT if ssl_context_factory_name is None else DEFAULT_PORT + 443
self.main_port = main_port
self.aux_port = aux_port or (main_port + 1)
if aux_port is None:
aux_port = main_port + 1 if ssl_context_factory_name is None else DEFAULT_PORT + 1
self.aux_port = aux_port
self.browser_cache = browser_cache
self.ssl_context_factory_name = ssl_context_factory_name
self.ssl_rootcert_file_path = ssl_rootcert_file_path
logger.debug('config loaded:\n%s', self)

@property
def protocol(self) -> Literal["http", "https"]:
return "http" if self.ssl_context_factory_name is None else "https"

@property
def static_path_str(self) -> Optional[str]:
return str(self.static_path) if self.static_path else None

@property
def client_ssl_context(self) -> Union[SSLContext, None]:
client_ssl_context = None
if self.protocol == 'https':
client_ssl_context = create_default_ssl_context()
if self.ssl_rootcert_file_path:
try:
client_ssl_context.load_verify_locations(self.ssl_rootcert_file_path)
except FileNotFoundError as e:
raise AdevConfigError('{}: {}'.format(e.strerror, self.ssl_rootcert_file_path))
except Exception:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also wonder about catching arbitrary exceptions. I'd assume the only other errors we should expect here are ssl.SSLError exceptions.

raise AdevConfigError('invalid root cert file: {}'.format(self.ssl_rootcert_file_path))

Check warning on line 122 in aiohttp_devtools/runserver/config.py

View check run for this annotation

Codecov / codecov/patch

aiohttp_devtools/runserver/config.py#L117-L122

Added lines #L117 - L122 were not covered by tests
Comment on lines +116 to +122
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it difficult to create a test for this code? Seems to be the only missing coverage now.

return client_ssl_context

def _find_app_path(self, app_path: str) -> Path:
# for backwards compatibility try this first
path = (self.root_path / app_path).resolve()
Expand Down Expand Up @@ -136,15 +166,14 @@
raise AdevConfigError('{} is not a directory'.format(path))
return path

def import_app_factory(self) -> AppFactory:
"""Import and return attribute/class from a python module.
def import_module(self) -> ModuleType:
"""Import and return python module.

Raises:
AdevConfigError - If the import failed.
"""
rel_py_file = self.py_file.relative_to(self.python_path)
module_path = '.'.join(rel_py_file.with_suffix('').parts)

sys.path.insert(0, str(self.python_path))
module = import_module(module_path)
# Rewrite the package name, so it will appear the same as running the app.
Expand All @@ -153,6 +182,16 @@

logger.debug('successfully loaded "%s" from "%s"', module_path, self.python_path)

self.watch_path = self.watch_path or Path(module.__file__ or ".").parent
return module

def get_app_factory(self, module: ModuleType) -> AppFactory:
"""Return attribute/class from a python module.

Raises:
AdevConfigError - If the import failed.
"""

if self.app_factory_name is None:
try:
self.app_factory_name = next(an for an in APP_FACTORY_NAMES if hasattr(module, an))
Expand All @@ -179,9 +218,24 @@
raise AdevConfigError("'{}.{}' should not have required arguments.".format(
self.py_file.name, self.app_factory_name))

self.watch_path = self.watch_path or Path(module.__file__ or ".").parent
return attr # type: ignore[no-any-return]

def get_ssl_context(self, module: ModuleType) -> Union[SSLContext, None]:
if self.ssl_context_factory_name is None:
return None
else:
try:
attr = getattr(module, self.ssl_context_factory_name)
except AttributeError:
raise AdevConfigError("Module '{}' does not define a '{}' attribute/class".format(
self.py_file.name, self.ssl_context_factory_name))
ssl_context = attr()
if isinstance(ssl_context, SSLContext):
return ssl_context
else:
raise AdevConfigError("ssl-context-factory '{}' in module '{}' didn't return valid SSLContext".format(
self.ssl_context_factory_name, self.py_file.name))

async def load_app(self, app_factory: AppFactory) -> web.Application:
if isinstance(app_factory, web.Application):
return app_factory
Expand Down
11 changes: 6 additions & 5 deletions aiohttp_devtools/runserver/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import os
from multiprocessing import set_start_method
from typing import Any, Type, TypedDict
from typing import Any, Type, TypedDict, Union

from aiohttp.abc import AbstractAccessLogger
from aiohttp.web import Application
Expand All @@ -11,6 +11,7 @@
from .log_handlers import AuxAccessLogger
from .serve import check_port_open, create_auxiliary_app
from .watch import AppTask, LiveReloadTask
from ssl import SSLContext


class RunServer(TypedDict):
Expand All @@ -19,6 +20,7 @@ class RunServer(TypedDict):
port: int
shutdown_timeout: float
access_log_class: Type[AbstractAccessLogger]
ssl_context: Union[SSLContext, None]


def runserver(**config_kwargs: Any) -> RunServer:
Expand All @@ -29,9 +31,8 @@ def runserver(**config_kwargs: Any) -> RunServer:
"""
# force a full reload in sub processes so they load an updated version of code, this must be called only once
set_start_method('spawn')

config = Config(**config_kwargs)
config.import_app_factory()
config.import_module()

asyncio.run(check_port_open(config.main_port, host=config.bind_address))

Expand All @@ -57,7 +58,7 @@ def runserver(**config_kwargs: Any) -> RunServer:
logger.info('serving static files from ./%s/ at %s%s', rel_path, url, config.static_url)

return {"app": aux_app, "host": config.bind_address, "port": config.aux_port,
"shutdown_timeout": 0.01, "access_log_class": AuxAccessLogger}
"shutdown_timeout": 0.01, "access_log_class": AuxAccessLogger, "ssl_context": None}


def serve_static(*, static_path: str, livereload: bool = True, bind_address: str = "localhost", port: int = 8000,
Expand All @@ -75,4 +76,4 @@ def serve_static(*, static_path: str, livereload: bool = True, bind_address: str
livereload_status = 'ON' if livereload else 'OFF'
logger.info('Serving "%s" at http://%s:%d, livereload %s', static_path, bind_address, port, livereload_status)
return {"app": app, "host": bind_address, "port": port,
"shutdown_timeout": 0.01, "access_log_class": AuxAccessLogger}
"shutdown_timeout": 0.01, "access_log_class": AuxAccessLogger, "ssl_context": None}
19 changes: 12 additions & 7 deletions aiohttp_devtools/runserver/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings
from errno import EADDRINUSE
from pathlib import Path
from typing import Any, Iterator, List, NoReturn, Optional, Set, Tuple
from typing import Any, Iterator, List, NoReturn, Optional, Set, Tuple, Union

from aiohttp import WSMsgType, web
from aiohttp.hdrs import LAST_MODIFIED, CONTENT_LENGTH
Expand All @@ -25,6 +25,8 @@
from .log_handlers import AccessLogger
from .utils import MutableValue

from ssl import SSLContext

try:
from aiohttp_jinja2 import static_root_key
except ImportError:
Expand Down Expand Up @@ -120,7 +122,8 @@ def shutdown() -> NoReturn:

path = config.path_prefix + "/shutdown"
app.router.add_route("GET", path, do_shutdown, name="_devtools.shutdown")
dft_logger.debug("Created shutdown endpoint at http://{}:{}{}".format(config.host, config.main_port, path))
dft_logger.debug("Created shutdown endpoint at {}://{}:{}{}".format(
config.protocol, config.host, config.main_port, path))

if config.static_path is not None:
static_url = 'http://{}:{}/{}'.format(config.host, config.aux_port, static_path)
Expand Down Expand Up @@ -164,12 +167,14 @@ def set_tty(tty_path: Optional[str]) -> Iterator[None]:
def serve_main_app(config: Config, tty_path: Optional[str]) -> None:
with set_tty(tty_path):
setup_logging(config.verbose)
app_factory = config.import_app_factory()
module = config.import_module()
app_factory = config.get_app_factory(module)
ssl_context = config.get_ssl_context(module)
if sys.version_info >= (3, 11):
with asyncio.Runner() as runner:
app_runner = runner.run(create_main_app(config, app_factory))
try:
runner.run(start_main_app(app_runner, config.bind_address, config.main_port))
runner.run(start_main_app(app_runner, config.bind_address, config.main_port, ssl_context))
runner.get_loop().run_forever()
except KeyboardInterrupt:
pass
Expand All @@ -180,7 +185,7 @@ def serve_main_app(config: Config, tty_path: Optional[str]) -> None:
loop = asyncio.new_event_loop()
runner = loop.run_until_complete(create_main_app(config, app_factory))
try:
loop.run_until_complete(start_main_app(runner, config.bind_address, config.main_port))
loop.run_until_complete(start_main_app(runner, config.bind_address, config.main_port, ssl_context))
loop.run_forever()
except KeyboardInterrupt: # pragma: no cover
pass
Expand All @@ -197,9 +202,9 @@ async def create_main_app(config: Config, app_factory: AppFactory) -> web.AppRun
return web.AppRunner(app, access_log_class=AccessLogger, shutdown_timeout=0.1)


async def start_main_app(runner: web.AppRunner, host: str, port: int) -> None:
async def start_main_app(runner: web.AppRunner, host: str, port: int, ssl_context: Union[SSLContext, None]) -> None:
await runner.setup()
site = web.TCPSite(runner, host=host, port=port)
site = web.TCPSite(runner, host=host, port=port, ssl_context=ssl_context)
await site.start()


Expand Down
16 changes: 11 additions & 5 deletions aiohttp_devtools/runserver/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..logs import rs_dft_logger as logger
from .config import Config
from .serve import LAST_RELOAD, STATIC_PATH, WS, serve_main_app, src_reload
from ssl import SSLContext


class WatchTask:
Expand Down Expand Up @@ -55,13 +56,17 @@ def __init__(self, config: Config):
self._reloads = 0
self._session: Optional[ClientSession] = None
self._runner = None
self._client_ssl_context: Union[None, SSLContext] = None
assert self._config.watch_path

super().__init__(self._config.watch_path)

async def _run(self, live_checks: int = 150) -> None:
assert self._app is not None

self._session = ClientSession()
self._client_ssl_context = self._config.client_ssl_context

try:
self._start_dev_server()

Expand Down Expand Up @@ -107,12 +112,12 @@ async def _src_reload_when_live(self, checks: int) -> None:
assert self._app is not None and self._session is not None

if self._app[WS]:
url = "http://{0.host}:{0.main_port}/?_checking_alive=1".format(self._config)
url = "{0.protocol}://{0.host}:{0.main_port}/?_checking_alive=1".format(self._config)
logger.debug('checking app at "%s" is running before prompting reload...', url)
for i in range(checks):
await asyncio.sleep(0.1)
try:
async with self._session.get(url):
async with self._session.get(url, ssl=self._client_ssl_context):
pass
except OSError as e:
logger.debug('try %d | OSError %d app not running', i, e.errno)
Expand All @@ -123,7 +128,8 @@ async def _src_reload_when_live(self, checks: int) -> None:

def _start_dev_server(self) -> None:
act = 'Start' if self._reloads == 0 else 'Restart'
logger.info('%sing dev server at http://%s:%s ●', act, self._config.host, self._config.main_port)
logger.info('%sing dev server at %s://%s:%s ●',
act, self._config.protocol, self._config.host, self._config.main_port)

try:
tty_path = os.ttyname(sys.stdin.fileno())
Expand All @@ -141,12 +147,12 @@ async def _stop_dev_server(self) -> None:
if self._process.is_alive():
logger.debug('stopping server process...')
if self._config.shutdown_by_url: # Workaround for signals not working on Windows
url = "http://{0.host}:{0.main_port}{0.path_prefix}/shutdown".format(self._config)
url = "{0.protocol}://{0.host}:{0.main_port}{0.path_prefix}/shutdown".format(self._config)
logger.debug("Attempting to stop process via shutdown endpoint {}".format(url))
try:
with suppress(ClientConnectionError):
async with ClientSession() as session:
async with session.get(url):
async with session.get(url, ssl=self._client_ssl_context):
pass
except (ConnectionError, ClientError, asyncio.TimeoutError) as ex:
if self._process.is_alive():
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ pytest-mock==3.14.0
pytest-sugar==1.0.0
pytest-timeout==2.2.0
pytest-toolbox==0.4
pytest-datafiles==3.0.0
watchfiles==1.0.4
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
else:
forked = pytest.mark.forked

if sys.platform == 'linux':
linux_forked = pytest.mark.forked
else:
def linux_forked(func):
return func

SIMPLE_APP = {
'app.py': """\
from aiohttp import web
Expand Down
32 changes: 32 additions & 0 deletions tests/test_certs/rootCA.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
-----BEGIN CERTIFICATE-----
MIIFlTCCA32gAwIBAgIUMqRqzVHCUfN7kz43bWrwlfmtl7kwDQYJKoZIhvcNAQEN
BQAwWjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDETMBEGA1UEAwwKVGVzdFJvb3RDQTAe
Fw0yNTAxMjYxMjE3MDBaFw0zNTAxMjQxMjE3MDBaMFoxCzAJBgNVBAYTAkFVMRMw
EQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0
eSBMdGQxEzARBgNVBAMMClRlc3RSb290Q0EwggIiMA0GCSqGSIb3DQEBAQUAA4IC
DwAwggIKAoICAQDFvixQRLk0R2WOnXDkdMrmittYqWfHr3ZhZtS6HvFWBSV6AWc3
DbseUgE7uD5xdJFlId35UH7HCFeeu8y/KkOPwH9KIzSWbZNcT3UJSDtnoA/sVYtN
MuS6Uu4DNkbDRNHf1udqc+0EwPpiZ7/3FwQify0pXyq7PbkOcJyFQh2YHG/EjZ4I
mBSz8NMwYQDeVMLxhQHTXruHIef1clLSSTRCXKLLKoKw/Rzje1jrBvLLollOJxLT
UXC1Fbpuh3KMnhwWsX4F4N8iWczcPxwCGcmYJA5xjo5tstkYzShUtNmMbFu3FCS8
Vl/h25I3Znq7VdEI+brR7ZEeJj0yp9H1Aiev6XAojqWoNC1M63HgYY7uhl3YGC6f
uwx0qgmGI32dzv5JHCpOtI8N2V5rwwtYBVws8lGmkqbUEkF5oO5V6yQHulVsdGr1
Kn5OPGolY8QmGcCE0LmvzRZCwZU2UcVxJsDJkNwup1C7wQEWC5pePEr58j3H3z6y
d3pkxaQmzXSB4jGJRzKbth6BQF47WwcphYjMtdWZUvy860isu9CEGjxbLjweATra
5o/8MIRuRPiJI2wlnEXHYWY96vrBQ202seQzMtJAtVoQxdpfokRHY8+jKfwZ/gRR
7tXxIRGfHoOgU9I8jtLNp782o/gjVTs9UGT0I66+PzpzS+XjshdH25OktwIDAQAB
o1MwUTAdBgNVHQ4EFgQUlT7d176QebrmSVanT1sGL2TyFuIwHwYDVR0jBBgwFoAU
lT7d176QebrmSVanT1sGL2TyFuIwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0B
AQ0FAAOCAgEAr6ZFZu6WYPUVY9zxJesNmnrm3xGbQn62iU6SrG9tsi/SFkQNPcVe
0CJ/zdA89yKet2Hpo95NSz9O4Jm5gapvGk8ie9UecqzEuKSLWR7mozupaPqDfF0O
YGgnhVMJIPGXbbm52oVV6FZtTRatQHatEnUS/09w2HkA/fyXbvRFA9O6RREevhjU
jcsB/ORx4Ni162Nr8waf6/2pJIturomz8hRtVsD5m6dGQuk7R6d7KZQQ+4Td7Cru
1xOxoWNDc0BBTbkv7DjOcy3YewgANgXqSsLrjprv30InoBgHvL8303EUkge268vd
jZ9mEsXdbZAVX1exetdBcoMQG8UmkKPnyU09w9NltnR7gVqZQyPDNZKTefP505X6
67du/bw3Try/qUbiwJoyr1hf2d7rAJQ2CHDgedz8v5UszX4FAZ/yB5gUUxczld+r
6CCNR7FRfCCNmU6WPSa6CFvlg3x7JRXIdITHMtr14bhtLSmcfmRZhpG9N8r54C4P
L5OluPzU2P2JpV8i8YX8az5mFCdPxrAzjoAN8KU9WYp1LjKkTRT0UGYaTXLcVxyx
4+AWPJgT2GLXRyAcoEFdRQDSG+8jUy+ra0iEN6jp6JN04zBhIWVoQoA6+8u3PAna
DBVn5n32PZQjfu21u+cjvR3TrA3dXwi0/DPOYAeYr2S4D2R+6EAwFAo=
-----END CERTIFICATE-----
Loading