diff --git a/tests/conftest.py b/tests/conftest.py index 1b0c0e84e..4d2c46a22 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -54,6 +54,13 @@ def tls_certificate(tls_certificate_authority: trustme.CA) -> trustme.LeafCert: ) +@pytest.fixture +def tls_client_certificate(request, tls_certificate_authority: trustme.CA) -> trustme.LeafCert: + return tls_certificate_authority.issue_cert( + "client@example.com", common_name=getattr(request, "param", "uvicorn client") + ) + + @pytest.fixture def tls_ca_certificate_pem_path(tls_certificate_authority: trustme.CA): with tls_certificate_authority.cert_pem.tempfile() as ca_cert_pem: @@ -107,6 +114,20 @@ def tls_ca_ssl_context(tls_certificate_authority: trustme.CA) -> ssl.SSLContext: return ssl_ctx +@pytest.fixture +def tls_client_ssl_context( + tls_certificate_authority: trustme.CA, tls_client_certificate: trustme.LeafCert +) -> ssl.SSLContext: + ssl_ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + tls_certificate_authority.configure_trust(ssl_ctx) + + # Load the client certificate chain into the SSL context + with tls_client_certificate.private_key_and_cert_chain_pem.tempfile() as client_cert_pem: + ssl_ctx.load_cert_chain(certfile=client_cert_pem) + + return ssl_ctx + + @pytest.fixture(scope="package") def reload_directory_structure(tmp_path_factory: pytest.TempPathFactory): """ diff --git a/tests/test_ssl.py b/tests/test_ssl.py index da60bb8dd..0ea37e5c1 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -1,5 +1,8 @@ +import ssl + import httpx import pytest +from cryptography import x509 from tests.utils import run_server from uvicorn.config import Config @@ -34,6 +37,79 @@ async def test_run( assert response.status_code == 204 +@pytest.mark.anyio +@pytest.mark.parametrize( + "tls_client_certificate, expected_common_name", + [ + ("test common name", "test common name"), + ], + indirect=["tls_client_certificate"], +) +@pytest.mark.anyio +async def test_run_httptools_client_cert( + tls_client_ssl_context, + tls_certificate_server_cert_path, + tls_certificate_private_key_path, + tls_ca_certificate_pem_path, + expected_common_name, + unused_tcp_port: int, +): + async def app(scope, receive, send): + assert scope["type"] == "http" + assert len(scope["extensions"]["tls"]["client_cert_chain"]) >= 1 + cert = x509.load_pem_x509_certificate(scope["extensions"]["tls"]["client_cert_chain"][0].encode("utf-8")) + assert cert.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)[0].value == expected_common_name + cipher_suites = [cipher["name"] for cipher in ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER).get_ciphers()] + assert scope["extensions"]["tls"]["cipher_suite"] in cipher_suites + assert scope["extensions"]["tls"]["tls_version"].startswith("TLSv") or scope["extensions"]["tls"][ + "tls_version" + ].startswith("SSLv") + + await send({"type": "http.response.start", "status": 204, "headers": []}) + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + config = Config( + app=app, + loop="asyncio", + http="httptools", + limit_max_requests=1, + ssl_keyfile=tls_certificate_private_key_path, + ssl_certfile=tls_certificate_server_cert_path, + ssl_ca_certs=tls_ca_certificate_pem_path, + ssl_cert_reqs=ssl.CERT_REQUIRED, + port=unused_tcp_port, + ) + async with run_server(config): + async with httpx.AsyncClient(verify=tls_client_ssl_context) as client: + response = await client.get(f"https://127.0.0.1:{unused_tcp_port}") + assert response.status_code == 204 + + +@pytest.mark.anyio +async def test_run_h11_client_cert( + tls_client_ssl_context, + tls_ca_certificate_pem_path, + tls_certificate_server_cert_path, + tls_certificate_private_key_path, + unused_tcp_port: int, +): + config = Config( + app=app, + loop="asyncio", + http="h11", + limit_max_requests=1, + ssl_keyfile=tls_certificate_private_key_path, + ssl_certfile=tls_certificate_server_cert_path, + ssl_ca_certs=tls_ca_certificate_pem_path, + ssl_cert_reqs=ssl.CERT_REQUIRED, + port=unused_tcp_port, + ) + async with run_server(config): + async with httpx.AsyncClient(verify=tls_client_ssl_context) as client: + response = await client.get(f"https://127.0.0.1:{unused_tcp_port}") + assert response.status_code == 204 + + @pytest.mark.anyio async def test_run_chain( tls_ca_ssl_context, diff --git a/uvicorn/_types.py b/uvicorn/_types.py index c927cc11d..ad144b4a3 100644 --- a/uvicorn/_types.py +++ b/uvicorn/_types.py @@ -53,6 +53,17 @@ class ASGIVersions(TypedDict): version: Literal["2.0"] | Literal["3.0"] +class TLSExtensionInfo(TypedDict, total=False): + server_cert: str | None + client_cert_chain: list[str] + tls_version: str | None + cipher_suite: str | None + + +class Extensions(TypedDict, total=False): + tls: TLSExtensionInfo + + class HTTPScope(TypedDict): type: Literal["http"] asgi: ASGIVersions @@ -67,7 +78,7 @@ class HTTPScope(TypedDict): client: tuple[str, int] | None server: tuple[str, int | None] | None state: NotRequired[dict[str, Any]] - extensions: NotRequired[dict[str, dict[object, object]]] + extensions: NotRequired[Extensions] class WebSocketScope(TypedDict): diff --git a/uvicorn/config.py b/uvicorn/config.py index ae996c1cb..594d1707e 100644 --- a/uvicorn/config.py +++ b/uvicorn/config.py @@ -260,6 +260,7 @@ def __init__( self.callback_notify = callback_notify self.ssl_keyfile = ssl_keyfile self.ssl_certfile = ssl_certfile + self.ssl_cert_pem: str | None = None self.ssl_keyfile_password = ssl_keyfile_password self.ssl_version = ssl_version self.ssl_cert_reqs = ssl_cert_reqs @@ -407,6 +408,8 @@ def load(self) -> None: ca_certs=self.ssl_ca_certs, ciphers=self.ssl_ciphers, ) + with open(self.ssl_certfile) as file: + self.ssl_cert_pem = file.read() else: self.ssl = None diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index b8cdde3ab..2794110d9 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -21,7 +21,14 @@ from uvicorn.config import Config from uvicorn.logging import TRACE_LOG_LEVEL from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable -from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl +from uvicorn.protocols.utils import ( + get_client_addr, + get_local_addr, + get_path_with_query_string, + get_remote_addr, + get_tls_info, + is_ssl, +) from uvicorn.server import ServerState @@ -212,7 +219,12 @@ def handle_events(self) -> None: "query_string": query_string, "headers": self.headers, "state": self.app_state.copy(), + "extensions": {}, } + + if self.config.is_ssl: + self.scope["extensions"]["tls"] = get_tls_info(self.transport, self.config) + if self._should_upgrade(): self.handle_websocket_upgrade(event) return diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index e8795ed35..f00bd14da 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -22,7 +22,14 @@ from uvicorn.config import Config from uvicorn.logging import TRACE_LOG_LEVEL from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable -from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl +from uvicorn.protocols.utils import ( + get_client_addr, + get_local_addr, + get_path_with_query_string, + get_remote_addr, + get_tls_info, + is_ssl, +) from uvicorn.server import ServerState HEADER_RE = re.compile(b'[\x00-\x1f\x7f()<>@,;:[]={} \t\\"]') @@ -230,8 +237,12 @@ def on_message_begin(self) -> None: "root_path": self.root_path, "headers": self.headers, "state": self.app_state.copy(), + "extensions": {}, } + if self.config.is_ssl: + self.scope["extensions"]["tls"] = get_tls_info(self.transport, self.config) + # Parser callbacks def on_url(self, url: bytes) -> None: self.url += url diff --git a/uvicorn/protocols/utils.py b/uvicorn/protocols/utils.py index e1d6f01d5..b71c803b8 100644 --- a/uvicorn/protocols/utils.py +++ b/uvicorn/protocols/utils.py @@ -1,9 +1,11 @@ from __future__ import annotations import asyncio +import ssl import urllib.parse -from uvicorn._types import WWWScope +from uvicorn._types import TLSExtensionInfo, WWWScope +from uvicorn.config import Config class ClientDisconnected(OSError): ... @@ -54,3 +56,37 @@ def get_path_with_query_string(scope: WWWScope) -> str: if scope["query_string"]: path_with_query_string = "{}?{}".format(path_with_query_string, scope["query_string"].decode("ascii")) return path_with_query_string + + +def get_tls_info(transport: asyncio.Transport, config: Config) -> TLSExtensionInfo: + ### + # server_cert: Unable to set from transport information, need to set from server_config + # client_cert_chain: + # tls_version: + # cipher_suite: + ### + + ssl_info: TLSExtensionInfo = { + "server_cert": None, + "client_cert_chain": [], + "tls_version": None, + "cipher_suite": None, + } + + ssl_info["server_cert"] = config.ssl_cert_pem + + ssl_object = transport.get_extra_info("ssl_object") + if ssl_object is not None: + client_chain = ( + ssl_object.get_verified_chain() + if hasattr(ssl_object, "get_verified_chain") + else [ssl_object.getpeercert(binary_form=True)] + ) + for cert in client_chain: + if cert is not None: + ssl_info["client_cert_chain"].append(ssl.DER_cert_to_PEM_cert(cert)) + + ssl_info["tls_version"] = ssl_object.version() + ssl_info["cipher_suite"] = ssl_object.cipher()[0] if ssl_object.cipher() else None + + return ssl_info