diff --git a/src/auth_server/tests/data/test_mdq.xml b/src/auth_server/tests/data/test_mdq.xml
index 7ca8675..810993e 100644
--- a/src/auth_server/tests/data/test_mdq.xml
+++ b/src/auth_server/tests/data/test_mdq.xml
@@ -32,6 +32,13 @@
+
+
+
+ MIIFOTCCAyGgAwIBAgIUFfCwL9eeKjTqY5RZCuLLnPvYxdgwDQYJKoZIhvcNAQELBQAwTzELMAkGA1UEBhMCU0UxCTAHBgNVBAgMADENMAsGA1UEBwwEVGVzdDENMAsGA1UECgwEVGVzdDEXMBUGA1UEAwwOdGVzdC5sb2NhbGhvc3QwHhcNMjQxMjExMTMzNTI3WhcNMjQxMjEyMTMzNTI3WjBPMQswCQYDVQQGEwJTRTEJMAcGA1UECAwAMQ0wCwYDVQQHDARUZXN0MQ0wCwYDVQQKDARUZXN0MRcwFQYDVQQDDA50ZXN0LmxvY2FsaG9zdDCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBALougAZhSedNXRcPVYMpCZKKHscY5l8Kb1pLk14++Ktz5olIZKdfY9SWfYkZpAmshubEQ13n0PJFzZohEvZ/xDczbK7xrCAjYZuCFzVLgUj1E3rBm7yN5D1wTSKzmhmGs2JFSDxo5a+NAJDEuZXvi2ypOuWn/KZzmY+aZY9e/L7jTz7e8kT9xZN8n4Nd7Uc50S1RB89zkmbc4M/sRLkFypv7rO8BGStEn+KnaPfAVCsyiPDjoIss5Qm1KdDAl+7g/gmYch+u/ilv+52jkUecDo7cyoipvNcSIawH3pIM7S3tmF7PuUl/Ko7qotNG9OxIQJCSlkyIO3F/hFKWe60tG+Gxh6PnDangOhAt6kvCYUpemELqFQwVjB8KveqkddlQx3TUPM1x5oJ/p6JKklgUxbzWrC9oMrxR9gsjs4jd2384WuADM3C5UxDoPQLEirLUB50Gj9Xkx3dPtEM3kqpAxOh4SKMMvN7vGE8iCAcga7HZzDekwn/R8gUzxLpY7qSvMwgADX7GW+Cb+z9wrPg8gg9vRAFV1XCBMH1+1l4m6+ZaWE+rKFTlLT8YPLKBjBlzZ5MjX3XWhdvQesRu1SlgA+mR7GrAj9xF3BMvUE2Vn2hbQqgJYBSFasP5PvkLITClbR4uMUfeskLcllogHQt2a4Pj71pyRsN8s7SlDRLviAElAgMBAAGjDTALMAkGA1UdEQQCMAAwDQYJKoZIhvcNAQELBQADggIBAAHo8UTXtytQmf0Q6c2pRsn96uVxlxP4+tQ6J1GXAtGq511SpqAR/BnBYbMw6VOwPjfZxKN2HK43dKX6us2wz4vD5RV7rt7ssZwysSn0kCJGqmH8/vRewQrKceamnRsF3Y+PUdXWhqDTJsLnYev/XnkpFQjhKs/1ALY7D7PaH8UoQCNrwa0ZQPKUJaCqZ08E43wbvOlk4Gwosa+HN3eMMsmCj4nURxGV8IpSc445GWHzMGw3JrfWwENFcVp4He9CB3Uem0MqUnU6H4FlFpbiOYGS3oH6fnfqAmTa4aLm0Hg75t5xc/nXPPNZXmwlWzG91QgP/AFv/PpFvc4HdmDIl7kgSYol7SPvwC9Stvw2nXXcc4Vg/ceeYxmbcZWB4bAy8oYPNqq/+GWOQeC2SFlie2H2NtYBRqFEJhlspYpjRR79cU+98syWe76ccDYw2w7+RhX5NEdE3/+VDmlPIePhy0iPXueLjL0VgGvIRWmcxcZ2ZaF/hQ8yTqP7f92igU7Y6ynej+mzPcDzQhXA1wDNSD3cBM2E56/MLQTKmgbeFGgr/MsGOiSpUMYR9Dh1nao1itlBhkvcLkdKy8Ulx4RqsnCohtbexSW3Qu1ObLGOabafL069DzcHL9JmainO3UwFpp/z+SFfyq/ZgRz4I34AXDg/x7BtLIKO/c8Rkzhr3fF4
+
+
+
urn:oasis:names:tc:SAML:2.0:nameid-format:transient
diff --git a/src/auth_server/tests/test_app.py b/src/auth_server/tests/test_app.py
index 264a136..52372ce 100644
--- a/src/auth_server/tests/test_app.py
+++ b/src/auth_server/tests/test_app.py
@@ -46,7 +46,7 @@
from auth_server.models.status import Status
from auth_server.saml2 import AuthnInfo, NameID, SAMLAttributes, SessionInfo
from auth_server.testing import MongoTemporaryInstance
-from auth_server.tests.utils import create_tls_fed_metadata, tls_fed_metadata_to_jws
+from auth_server.tests.utils import create_cert, create_tls_fed_metadata, tls_fed_metadata_to_jws
from auth_server.time_utils import utc_now
from auth_server.tls_fed_auth import get_tls_fed_metadata
from auth_server.utils import get_hash_by_name, get_signing_key, hash_with, load_jwks
@@ -503,8 +503,14 @@ def test_mdq_flow(self, mock_mdq):
assert claims["scopes"] == ["localhost"]
assert claims["source"] == "http://www.swamid.se/"
- @mock.patch("aiohttp.ClientSession.get", new_callable=AsyncMock)
- def test_tls_fed_flow_remote_metadata(self, mock_metadata):
+ def _setup_remote_tls_fed_test(
+ self, entity_id: str, scopes: list[str] | None = None, client_certs: list[str] | None = None
+ ) -> bytes:
+ if scopes is None:
+ scopes = ["test.localhost"]
+ if client_certs is None:
+ client_certs = [self.client_cert_str]
+
self.config["auth_flows"] = json.dumps(["TestFlow", "TLSFEDFlow"])
self.config["tls_fed_metadata"] = json.dumps(
[{"remote": "https://metadata.example.com/metadata.jws", "jwks": f"{self.datadir}/tls_fed_jwks.json"}]
@@ -516,10 +522,7 @@ def test_tls_fed_flow_remote_metadata(self, mock_metadata):
tls_fed_jwks = jwk.JWKSet()
tls_fed_jwks.import_keyset(f.read())
- entity_id = "https://test.localhost"
- metadata = create_tls_fed_metadata(
- entity_id=entity_id, scopes=["test.localhost"], client_cert=self.client_cert_str
- )
+ metadata = create_tls_fed_metadata(entity_id=entity_id, scopes=scopes, client_certs=client_certs)
metadata_jws = tls_fed_metadata_to_jws(
metadata,
key=tls_fed_jwks.get_key("metadata_signing_key_id"),
@@ -527,6 +530,12 @@ def test_tls_fed_flow_remote_metadata(self, mock_metadata):
expires=timedelta(days=14),
alg=SupportedAlgorithms.ES256,
)
+ return metadata_jws
+
+ @mock.patch("aiohttp.ClientSession.get", new_callable=AsyncMock)
+ def test_tls_fed_flow_remote_metadata(self, mock_metadata):
+ entity_id = "https://test.localhost"
+ metadata_jws = self._setup_remote_tls_fed_test(entity_id=entity_id)
mock_metadata.return_value = MockResponse(content=metadata_jws)
# Start transaction
@@ -550,6 +559,36 @@ def test_tls_fed_flow_remote_metadata(self, mock_metadata):
assert claims["organization_id"] == "SE0123456789"
assert claims["source"] == "metadata.example.com"
+ @mock.patch("aiohttp.ClientSession.get", new_callable=AsyncMock)
+ def test_tls_fed_flow_remote_metadata_multi_certs(self, mock_metadata):
+ entity_id = "https://test.localhost"
+ new_client_key, new_client_cert = create_cert(common_name="test.localhost")
+ new_client_cert_str = serialize_certificate(cert=new_client_cert)
+ client_certs = [new_client_cert_str, self.client_cert_str]
+ metadata_jws = self._setup_remote_tls_fed_test(entity_id=entity_id, client_certs=client_certs)
+ mock_metadata.return_value = MockResponse(content=metadata_jws)
+
+ # Start transaction
+ req = GrantRequest(
+ client=Client(key=entity_id),
+ access_token=[AccessTokenRequest(flags=[AccessTokenFlags.BEARER])],
+ )
+ client_header = {"Client-Cert": new_client_cert_str}
+ response = self.client.post("/transaction", json=req.model_dump(exclude_none=True), headers=client_header)
+ assert response.status_code == 200
+ assert "access_token" in response.json()
+ access_token = response.json()["access_token"]
+ assert AccessTokenFlags.BEARER.value in access_token["flags"]
+ assert access_token["value"] is not None
+
+ # Verify token and check claims
+ claims = self._get_access_token_claims(access_token=access_token, client=self.client)
+ assert claims["auth_source"] == AuthSource.TLSFED
+ assert claims["entity_id"] == "https://test.localhost"
+ assert claims["scopes"] == ["test.localhost"]
+ assert claims["organization_id"] == "SE0123456789"
+ assert claims["source"] == "metadata.example.com"
+
def test_tls_fed_flow_local_metadata(self):
# Create metadata jws and save it as a temporary file
with open(f"{self.datadir}/tls_fed_jwks.json", "r") as f:
@@ -558,7 +597,7 @@ def test_tls_fed_flow_local_metadata(self):
entity_id = "https://test.localhost"
metadata = create_tls_fed_metadata(
- entity_id=entity_id, scopes=["test.localhost"], client_cert=self.client_cert_str
+ entity_id=entity_id, scopes=["test.localhost"], client_certs=[self.client_cert_str]
)
metadata_jws = tls_fed_metadata_to_jws(
metadata,
@@ -613,7 +652,7 @@ def test_tls_fed_flow_expired_entity(self, mock_metadata):
tls_fed_jwks.import_keyset(f.read())
entity_id = "https://test.localhost"
- metadata = create_tls_fed_metadata(entity_id=entity_id, client_cert=self.client_cert_str)
+ metadata = create_tls_fed_metadata(entity_id=entity_id, client_certs=[self.client_cert_str])
metadata_jws = tls_fed_metadata_to_jws(
metadata,
key=tls_fed_jwks.get_key("metadata_signing_key_id"),
diff --git a/src/auth_server/tests/test_tls_fed_metadata.py b/src/auth_server/tests/test_tls_fed_metadata.py
index a3cf7d9..cb56120 100644
--- a/src/auth_server/tests/test_tls_fed_metadata.py
+++ b/src/auth_server/tests/test_tls_fed_metadata.py
@@ -48,7 +48,7 @@ async def _load_metadata(
entity_id=self.entity_id,
cache_ttl=self.cache_ttl.seconds,
scopes=self.scopes,
- client_cert=self.client_cert_str,
+ client_certs=[self.client_cert_str],
)
metadata_jws = tls_fed_metadata_to_jws(
metadata,
@@ -85,7 +85,7 @@ async def test_parse_faulty_metadata(self):
entity_id=self.entity_id,
cache_ttl=self.cache_ttl.seconds,
scopes=self.scopes,
- client_cert=self.client_cert_str,
+ client_certs=[self.client_cert_str],
).json(by_alias=True)
deserialized_metadata = json.loads(serialized_metadata)
entity = deserialized_metadata["entities"][0]
@@ -117,7 +117,7 @@ async def test_parse_unregistered_extension_in_metadata(self):
entity_id=self.entity_id,
cache_ttl=self.cache_ttl.seconds,
scopes=self.scopes,
- client_cert=self.client_cert_str,
+ client_certs=[self.client_cert_str],
).model_dump_json(by_alias=True)
deserialized_metadata = json.loads(serialized_metadata)
diff --git a/src/auth_server/tests/utils.py b/src/auth_server/tests/utils.py
index 849f0f6..b37d25b 100644
--- a/src/auth_server/tests/utils.py
+++ b/src/auth_server/tests/utils.py
@@ -4,6 +4,12 @@
from datetime import datetime, timedelta
from typing import List, Optional, Union
+from cryptography import x509
+from cryptography.hazmat._oid import NameOID
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.asymmetric import rsa
+from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
+from cryptography.x509 import Certificate
from jwcrypto import jwk, jws
from auth_server.models.jose import SupportedAlgorithms
@@ -46,7 +52,7 @@ def tls_fed_metadata_to_jws(
def create_tls_fed_metadata(
entity_id: str,
- client_cert: str,
+ client_certs: list[str],
cache_ttl: int = 3600,
organization_id: str = "SE0123456789",
scopes: Optional[List[str]] = None,
@@ -59,8 +65,43 @@ def create_tls_fed_metadata(
entity_id=entity_id,
organization="Test Org",
organization_id=organization_id,
- issuers=[CertIssuers(x509certificate=client_cert)],
+ issuers=[CertIssuers(x509certificate=client_cert) for client_cert in client_certs],
extensions=Extensions(saml_scope=SAMLScopeExtension(scope=scopes)),
)
]
return TLSFEDMetadata(version="1.0.0", cache_ttl=cache_ttl, entities=entities)
+
+
+def create_cert(
+ common_name: str, alt_names: list[str] | None = None, days_valid: int = 1
+) -> tuple[RSAPrivateKey, Certificate]:
+ if alt_names is None:
+ alt_names = list()
+ key = rsa.generate_private_key(public_exponent=65537, key_size=4096)
+ subject = issuer = x509.Name(
+ [
+ x509.NameAttribute(NameOID.COUNTRY_NAME, "SE"),
+ x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, ""),
+ x509.NameAttribute(NameOID.LOCALITY_NAME, "Test"),
+ x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Test"),
+ x509.NameAttribute(NameOID.COMMON_NAME, common_name),
+ ]
+ )
+ _alt_names = [x509.DNSName(alt_name) for alt_name in alt_names]
+ now = utc_now()
+ cert = (
+ x509.CertificateBuilder()
+ .subject_name(subject)
+ .issuer_name(issuer)
+ .public_key(key.public_key())
+ .serial_number(x509.random_serial_number())
+ .not_valid_before(now)
+ .not_valid_after(now + timedelta(days=days_valid))
+ .add_extension(
+ x509.SubjectAlternativeName(_alt_names),
+ critical=False,
+ # Sign our certificate with our private key
+ )
+ .sign(key, hashes.SHA256())
+ )
+ return key, cert
diff --git a/src/auth_server/utils.py b/src/auth_server/utils.py
index 94b4d7a..1c711dc 100644
--- a/src/auth_server/utils.py
+++ b/src/auth_server/utils.py
@@ -3,6 +3,7 @@
import json
import logging
from base64 import urlsafe_b64encode
+from datetime import datetime, timezone
from functools import lru_cache
from typing import Any, Callable, Generator, Mapping, Sequence, Union
from uuid import uuid4
@@ -20,6 +21,10 @@
logger = logging.getLogger(__name__)
+def utc_now() -> datetime:
+ return datetime.now(tz=timezone.utc)
+
+
@lru_cache()
def load_jwks() -> jwk.JWKSet:
config = load_config()