diff --git a/src/auth_server/tests/data/test_mdq.xml b/src/auth_server/tests/data/test_mdq.xml index 7ca8675..810993e 100644 --- a/src/auth_server/tests/data/test_mdq.xml +++ b/src/auth_server/tests/data/test_mdq.xml @@ -32,6 +32,13 @@ + + + + MIIFOTCCAyGgAwIBAgIUFfCwL9eeKjTqY5RZCuLLnPvYxdgwDQYJKoZIhvcNAQELBQAwTzELMAkGA1UEBhMCU0UxCTAHBgNVBAgMADENMAsGA1UEBwwEVGVzdDENMAsGA1UECgwEVGVzdDEXMBUGA1UEAwwOdGVzdC5sb2NhbGhvc3QwHhcNMjQxMjExMTMzNTI3WhcNMjQxMjEyMTMzNTI3WjBPMQswCQYDVQQGEwJTRTEJMAcGA1UECAwAMQ0wCwYDVQQHDARUZXN0MQ0wCwYDVQQKDARUZXN0MRcwFQYDVQQDDA50ZXN0LmxvY2FsaG9zdDCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBALougAZhSedNXRcPVYMpCZKKHscY5l8Kb1pLk14++Ktz5olIZKdfY9SWfYkZpAmshubEQ13n0PJFzZohEvZ/xDczbK7xrCAjYZuCFzVLgUj1E3rBm7yN5D1wTSKzmhmGs2JFSDxo5a+NAJDEuZXvi2ypOuWn/KZzmY+aZY9e/L7jTz7e8kT9xZN8n4Nd7Uc50S1RB89zkmbc4M/sRLkFypv7rO8BGStEn+KnaPfAVCsyiPDjoIss5Qm1KdDAl+7g/gmYch+u/ilv+52jkUecDo7cyoipvNcSIawH3pIM7S3tmF7PuUl/Ko7qotNG9OxIQJCSlkyIO3F/hFKWe60tG+Gxh6PnDangOhAt6kvCYUpemELqFQwVjB8KveqkddlQx3TUPM1x5oJ/p6JKklgUxbzWrC9oMrxR9gsjs4jd2384WuADM3C5UxDoPQLEirLUB50Gj9Xkx3dPtEM3kqpAxOh4SKMMvN7vGE8iCAcga7HZzDekwn/R8gUzxLpY7qSvMwgADX7GW+Cb+z9wrPg8gg9vRAFV1XCBMH1+1l4m6+ZaWE+rKFTlLT8YPLKBjBlzZ5MjX3XWhdvQesRu1SlgA+mR7GrAj9xF3BMvUE2Vn2hbQqgJYBSFasP5PvkLITClbR4uMUfeskLcllogHQt2a4Pj71pyRsN8s7SlDRLviAElAgMBAAGjDTALMAkGA1UdEQQCMAAwDQYJKoZIhvcNAQELBQADggIBAAHo8UTXtytQmf0Q6c2pRsn96uVxlxP4+tQ6J1GXAtGq511SpqAR/BnBYbMw6VOwPjfZxKN2HK43dKX6us2wz4vD5RV7rt7ssZwysSn0kCJGqmH8/vRewQrKceamnRsF3Y+PUdXWhqDTJsLnYev/XnkpFQjhKs/1ALY7D7PaH8UoQCNrwa0ZQPKUJaCqZ08E43wbvOlk4Gwosa+HN3eMMsmCj4nURxGV8IpSc445GWHzMGw3JrfWwENFcVp4He9CB3Uem0MqUnU6H4FlFpbiOYGS3oH6fnfqAmTa4aLm0Hg75t5xc/nXPPNZXmwlWzG91QgP/AFv/PpFvc4HdmDIl7kgSYol7SPvwC9Stvw2nXXcc4Vg/ceeYxmbcZWB4bAy8oYPNqq/+GWOQeC2SFlie2H2NtYBRqFEJhlspYpjRR79cU+98syWe76ccDYw2w7+RhX5NEdE3/+VDmlPIePhy0iPXueLjL0VgGvIRWmcxcZ2ZaF/hQ8yTqP7f92igU7Y6ynej+mzPcDzQhXA1wDNSD3cBM2E56/MLQTKmgbeFGgr/MsGOiSpUMYR9Dh1nao1itlBhkvcLkdKy8Ulx4RqsnCohtbexSW3Qu1ObLGOabafL069DzcHL9JmainO3UwFpp/z+SFfyq/ZgRz4I34AXDg/x7BtLIKO/c8Rkzhr3fF4 + + + urn:oasis:names:tc:SAML:2.0:nameid-format:transient diff --git a/src/auth_server/tests/test_app.py b/src/auth_server/tests/test_app.py index 264a136..52372ce 100644 --- a/src/auth_server/tests/test_app.py +++ b/src/auth_server/tests/test_app.py @@ -46,7 +46,7 @@ from auth_server.models.status import Status from auth_server.saml2 import AuthnInfo, NameID, SAMLAttributes, SessionInfo from auth_server.testing import MongoTemporaryInstance -from auth_server.tests.utils import create_tls_fed_metadata, tls_fed_metadata_to_jws +from auth_server.tests.utils import create_cert, create_tls_fed_metadata, tls_fed_metadata_to_jws from auth_server.time_utils import utc_now from auth_server.tls_fed_auth import get_tls_fed_metadata from auth_server.utils import get_hash_by_name, get_signing_key, hash_with, load_jwks @@ -503,8 +503,14 @@ def test_mdq_flow(self, mock_mdq): assert claims["scopes"] == ["localhost"] assert claims["source"] == "http://www.swamid.se/" - @mock.patch("aiohttp.ClientSession.get", new_callable=AsyncMock) - def test_tls_fed_flow_remote_metadata(self, mock_metadata): + def _setup_remote_tls_fed_test( + self, entity_id: str, scopes: list[str] | None = None, client_certs: list[str] | None = None + ) -> bytes: + if scopes is None: + scopes = ["test.localhost"] + if client_certs is None: + client_certs = [self.client_cert_str] + self.config["auth_flows"] = json.dumps(["TestFlow", "TLSFEDFlow"]) self.config["tls_fed_metadata"] = json.dumps( [{"remote": "https://metadata.example.com/metadata.jws", "jwks": f"{self.datadir}/tls_fed_jwks.json"}] @@ -516,10 +522,7 @@ def test_tls_fed_flow_remote_metadata(self, mock_metadata): tls_fed_jwks = jwk.JWKSet() tls_fed_jwks.import_keyset(f.read()) - entity_id = "https://test.localhost" - metadata = create_tls_fed_metadata( - entity_id=entity_id, scopes=["test.localhost"], client_cert=self.client_cert_str - ) + metadata = create_tls_fed_metadata(entity_id=entity_id, scopes=scopes, client_certs=client_certs) metadata_jws = tls_fed_metadata_to_jws( metadata, key=tls_fed_jwks.get_key("metadata_signing_key_id"), @@ -527,6 +530,12 @@ def test_tls_fed_flow_remote_metadata(self, mock_metadata): expires=timedelta(days=14), alg=SupportedAlgorithms.ES256, ) + return metadata_jws + + @mock.patch("aiohttp.ClientSession.get", new_callable=AsyncMock) + def test_tls_fed_flow_remote_metadata(self, mock_metadata): + entity_id = "https://test.localhost" + metadata_jws = self._setup_remote_tls_fed_test(entity_id=entity_id) mock_metadata.return_value = MockResponse(content=metadata_jws) # Start transaction @@ -550,6 +559,36 @@ def test_tls_fed_flow_remote_metadata(self, mock_metadata): assert claims["organization_id"] == "SE0123456789" assert claims["source"] == "metadata.example.com" + @mock.patch("aiohttp.ClientSession.get", new_callable=AsyncMock) + def test_tls_fed_flow_remote_metadata_multi_certs(self, mock_metadata): + entity_id = "https://test.localhost" + new_client_key, new_client_cert = create_cert(common_name="test.localhost") + new_client_cert_str = serialize_certificate(cert=new_client_cert) + client_certs = [new_client_cert_str, self.client_cert_str] + metadata_jws = self._setup_remote_tls_fed_test(entity_id=entity_id, client_certs=client_certs) + mock_metadata.return_value = MockResponse(content=metadata_jws) + + # Start transaction + req = GrantRequest( + client=Client(key=entity_id), + access_token=[AccessTokenRequest(flags=[AccessTokenFlags.BEARER])], + ) + client_header = {"Client-Cert": new_client_cert_str} + response = self.client.post("/transaction", json=req.model_dump(exclude_none=True), headers=client_header) + assert response.status_code == 200 + assert "access_token" in response.json() + access_token = response.json()["access_token"] + assert AccessTokenFlags.BEARER.value in access_token["flags"] + assert access_token["value"] is not None + + # Verify token and check claims + claims = self._get_access_token_claims(access_token=access_token, client=self.client) + assert claims["auth_source"] == AuthSource.TLSFED + assert claims["entity_id"] == "https://test.localhost" + assert claims["scopes"] == ["test.localhost"] + assert claims["organization_id"] == "SE0123456789" + assert claims["source"] == "metadata.example.com" + def test_tls_fed_flow_local_metadata(self): # Create metadata jws and save it as a temporary file with open(f"{self.datadir}/tls_fed_jwks.json", "r") as f: @@ -558,7 +597,7 @@ def test_tls_fed_flow_local_metadata(self): entity_id = "https://test.localhost" metadata = create_tls_fed_metadata( - entity_id=entity_id, scopes=["test.localhost"], client_cert=self.client_cert_str + entity_id=entity_id, scopes=["test.localhost"], client_certs=[self.client_cert_str] ) metadata_jws = tls_fed_metadata_to_jws( metadata, @@ -613,7 +652,7 @@ def test_tls_fed_flow_expired_entity(self, mock_metadata): tls_fed_jwks.import_keyset(f.read()) entity_id = "https://test.localhost" - metadata = create_tls_fed_metadata(entity_id=entity_id, client_cert=self.client_cert_str) + metadata = create_tls_fed_metadata(entity_id=entity_id, client_certs=[self.client_cert_str]) metadata_jws = tls_fed_metadata_to_jws( metadata, key=tls_fed_jwks.get_key("metadata_signing_key_id"), diff --git a/src/auth_server/tests/test_tls_fed_metadata.py b/src/auth_server/tests/test_tls_fed_metadata.py index a3cf7d9..cb56120 100644 --- a/src/auth_server/tests/test_tls_fed_metadata.py +++ b/src/auth_server/tests/test_tls_fed_metadata.py @@ -48,7 +48,7 @@ async def _load_metadata( entity_id=self.entity_id, cache_ttl=self.cache_ttl.seconds, scopes=self.scopes, - client_cert=self.client_cert_str, + client_certs=[self.client_cert_str], ) metadata_jws = tls_fed_metadata_to_jws( metadata, @@ -85,7 +85,7 @@ async def test_parse_faulty_metadata(self): entity_id=self.entity_id, cache_ttl=self.cache_ttl.seconds, scopes=self.scopes, - client_cert=self.client_cert_str, + client_certs=[self.client_cert_str], ).json(by_alias=True) deserialized_metadata = json.loads(serialized_metadata) entity = deserialized_metadata["entities"][0] @@ -117,7 +117,7 @@ async def test_parse_unregistered_extension_in_metadata(self): entity_id=self.entity_id, cache_ttl=self.cache_ttl.seconds, scopes=self.scopes, - client_cert=self.client_cert_str, + client_certs=[self.client_cert_str], ).model_dump_json(by_alias=True) deserialized_metadata = json.loads(serialized_metadata) diff --git a/src/auth_server/tests/utils.py b/src/auth_server/tests/utils.py index 849f0f6..b37d25b 100644 --- a/src/auth_server/tests/utils.py +++ b/src/auth_server/tests/utils.py @@ -4,6 +4,12 @@ from datetime import datetime, timedelta from typing import List, Optional, Union +from cryptography import x509 +from cryptography.hazmat._oid import NameOID +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey +from cryptography.x509 import Certificate from jwcrypto import jwk, jws from auth_server.models.jose import SupportedAlgorithms @@ -46,7 +52,7 @@ def tls_fed_metadata_to_jws( def create_tls_fed_metadata( entity_id: str, - client_cert: str, + client_certs: list[str], cache_ttl: int = 3600, organization_id: str = "SE0123456789", scopes: Optional[List[str]] = None, @@ -59,8 +65,43 @@ def create_tls_fed_metadata( entity_id=entity_id, organization="Test Org", organization_id=organization_id, - issuers=[CertIssuers(x509certificate=client_cert)], + issuers=[CertIssuers(x509certificate=client_cert) for client_cert in client_certs], extensions=Extensions(saml_scope=SAMLScopeExtension(scope=scopes)), ) ] return TLSFEDMetadata(version="1.0.0", cache_ttl=cache_ttl, entities=entities) + + +def create_cert( + common_name: str, alt_names: list[str] | None = None, days_valid: int = 1 +) -> tuple[RSAPrivateKey, Certificate]: + if alt_names is None: + alt_names = list() + key = rsa.generate_private_key(public_exponent=65537, key_size=4096) + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COUNTRY_NAME, "SE"), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, ""), + x509.NameAttribute(NameOID.LOCALITY_NAME, "Test"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Test"), + x509.NameAttribute(NameOID.COMMON_NAME, common_name), + ] + ) + _alt_names = [x509.DNSName(alt_name) for alt_name in alt_names] + now = utc_now() + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + timedelta(days=days_valid)) + .add_extension( + x509.SubjectAlternativeName(_alt_names), + critical=False, + # Sign our certificate with our private key + ) + .sign(key, hashes.SHA256()) + ) + return key, cert diff --git a/src/auth_server/utils.py b/src/auth_server/utils.py index 94b4d7a..1c711dc 100644 --- a/src/auth_server/utils.py +++ b/src/auth_server/utils.py @@ -3,6 +3,7 @@ import json import logging from base64 import urlsafe_b64encode +from datetime import datetime, timezone from functools import lru_cache from typing import Any, Callable, Generator, Mapping, Sequence, Union from uuid import uuid4 @@ -20,6 +21,10 @@ logger = logging.getLogger(__name__) +def utc_now() -> datetime: + return datetime.now(tz=timezone.utc) + + @lru_cache() def load_jwks() -> jwk.JWKSet: config = load_config()