Skip to content

Commit

Permalink
add support for multiple certs in metadata based flows
Browse files Browse the repository at this point in the history
  • Loading branch information
johanlundberg committed Dec 11, 2024
1 parent 6ba11d1 commit 8f61896
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 36 deletions.
10 changes: 7 additions & 3 deletions src/auth_server/db/transaction_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from auth_server.db.client import BaseDB, get_motor_client
from auth_server.mdq import MDQData
from auth_server.models.gnap import Access, GrantRequest, GrantResponse, SubjectRequest
from auth_server.models.gnap import Access, GrantRequest, GrantResponse, Key, SubjectRequest
from auth_server.saml2 import SessionInfo
from auth_server.time_utils import utc_now
from auth_server.tls_fed_auth import MetadataEntity
Expand Down Expand Up @@ -87,12 +87,16 @@ class ConfigState(TransactionState):
config_claims: Dict[str, Any] = Field(default_factory=dict)


class MDQState(TransactionState):
class MetadataState(TransactionState):
keys_from_metadata: List[Key] = Field(default_factory=list)


class MDQState(MetadataState):
auth_source: AuthSource = AuthSource.MDQ
mdq_data: Optional[MDQData] = None


class TLSFEDState(TransactionState):
class TLSFEDState(MetadataState):
auth_source: AuthSource = AuthSource.TLSFED
entity: Optional[MetadataEntity] = None

Expand Down
38 changes: 28 additions & 10 deletions src/auth_server/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
TLSFEDState,
get_transaction_state_db,
)
from auth_server.mdq import mdq_data_to_key, xml_mdq_get
from auth_server.mdq import mdq_data_to_keys, xml_mdq_get
from auth_server.models.claims import CAClaims, Claims, ConfigClaims, MDQClaims, SAMLAssertionClaims, TLSFEDClaims
from auth_server.models.gnap import (
AccessTokenFlags,
Expand All @@ -58,7 +58,7 @@
from auth_server.proof.jws import check_jws_proof, check_jwsd_proof
from auth_server.proof.mtls import check_mtls_proof
from auth_server.time_utils import utc_now
from auth_server.tls_fed_auth import entity_to_key, get_entity
from auth_server.tls_fed_auth import entity_to_keys, get_entity
from auth_server.utils import get_hex_uuid4, get_values

__author__ = "lundberg"
Expand Down Expand Up @@ -519,7 +519,23 @@ async def handle_interaction(self) -> Optional[GrantResponse]:
return None


class MDQFlow(OnlyMTLSProofFlow):
class MetadataFlow(OnlyMTLSProofFlow):
# Used to handle multiple keys in metadata when rolling out new a new key
async def validate_proof(self) -> Optional[GrantResponse]:
for client_key in self.state.keys_from_metadata:
self.state.grant_request.client.key = client_key
try:
await super().validate_proof()
except NextFlowException:
pass
if self.state.proof_ok:
break
if not self.state.proof_ok:
raise NextFlowException(status_code=401, detail="no client certificate found")
return None


class MDQFlow(MetadataFlow):
@classmethod
def load_state(cls, state: Mapping[str, Any]) -> MDQState:
return MDQState.from_dict(state=state)
Expand All @@ -541,11 +557,12 @@ async def lookup_client_key(self) -> Optional[GrantResponse]:
# Look for a key using mdq
logger.info(f"Trying to load key from mdq")
self.state.mdq_data = await xml_mdq_get(entity_id=key_id, mdq_url=self.config.mdq_server)
client_key = await mdq_data_to_key(self.state.mdq_data)
client_keys = await mdq_data_to_keys(self.state.mdq_data)

if not client_key:
if not client_keys:
raise NextFlowException(status_code=400, detail=f"no client key found for {key_id}")
self.state.grant_request.client.key = client_key

self.state.keys_from_metadata = client_keys
return None

async def create_claims(self) -> MDQClaims:
Expand Down Expand Up @@ -578,7 +595,7 @@ async def create_claims(self) -> MDQClaims:
return MDQClaims(**base_claims.model_dump(exclude_none=True), entity_id=entity_id, scopes=scopes, source=source)


class TLSFEDFlow(OnlyMTLSProofFlow):
class TLSFEDFlow(MetadataFlow):
@classmethod
def load_state(cls, state: Mapping[str, Any]) -> TLSFEDState:
return TLSFEDState.from_dict(state=state)
Expand All @@ -600,11 +617,12 @@ async def lookup_client_key(self) -> Optional[GrantResponse]:
# Look for a key in the TLS fed metadata
logger.info("Trying to load key from TLS fed auth")
self.state.entity = await get_entity(entity_id=key_id)
client_key = await entity_to_key(self.state.entity)
client_keys = await entity_to_keys(self.state.entity)

if not client_key:
if not client_keys:
raise NextFlowException(status_code=400, detail=f"no client key found for {key_id}")
self.state.grant_request.client.key = client_key

self.state.keys_from_metadata = client_keys
return None

async def create_claims(self) -> TLSFEDClaims:
Expand Down
30 changes: 16 additions & 14 deletions src/auth_server/mdq.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
# -*- coding: utf-8 -*-
import logging
from base64 import b64encode
from collections import OrderedDict as _OrderedDict
from enum import Enum
from typing import Any, List, Optional, OrderedDict
from typing import Any, List, OrderedDict

import aiohttp
import xmltodict
from cryptography.hazmat.primitives.hashes import SHA1, SHA256
from cryptography.hazmat.primitives.hashes import SHA1
from cryptography.x509 import Certificate
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_serializer
from pyexpat import ExpatError

from auth_server.cert_utils import load_pem_from_str, serialize_certificate
from auth_server.cert_utils import load_pem_from_str, rfc8705_fingerprint, serialize_certificate
from auth_server.models.gnap import Key, Proof, ProofMethod
from auth_server.utils import get_values, hash_with

Expand Down Expand Up @@ -83,7 +82,7 @@ async def xml_mdq_get(entity_id: str, mdq_url: str) -> MDQData:
entity = xmltodict.parse(xml, process_namespaces=True)
certs = []
# Certs
for key_descriptor in get_values(key="urn:oasis:names:tc:SAML:2.0:metadata:KeyDescriptor", obj=entity):
for key_descriptor in list(get_values(key="urn:oasis:names:tc:SAML:2.0:metadata:KeyDescriptor", obj=entity))[0]:
use = list(get_values(key="@use", obj=key_descriptor))[0]
raw_cert = list(get_values(key="http://www.w3.org/2000/09/xmldsig#:X509Certificate", obj=key_descriptor))[0]
cert = load_pem_from_str(raw_cert)
Expand All @@ -94,13 +93,16 @@ async def xml_mdq_get(entity_id: str, mdq_url: str) -> MDQData:
return MDQData()


async def mdq_data_to_key(mdq_data: MDQData) -> Optional[Key]:
signing_cert = [item.cert for item in mdq_data.certs if item.use == KeyUse.SIGNING]
# There should only be one or zero signing certs
if signing_cert:
logger.info("Found cert in metadata")
return Key(
proof=Proof(method=ProofMethod.MTLS),
cert_S256=b64encode(signing_cert[0].fingerprint(algorithm=SHA256())).decode("utf-8"),
async def mdq_data_to_keys(mdq_data: MDQData) -> list[Key]:
keys = list()
signing_certs = [item.cert for item in mdq_data.certs if item.use == KeyUse.SIGNING]
for cert in signing_certs:
_fingerprint = rfc8705_fingerprint(cert)
logger.info(f"Found cert in metadata, S256: {_fingerprint}")
keys.append(
Key(
proof=Proof(method=ProofMethod.MTLS),
cert_S256=_fingerprint,
)
)
return None
return keys
21 changes: 12 additions & 9 deletions src/auth_server/tls_fed_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,20 +271,23 @@ async def get_entity(entity_id: str) -> Optional[MetadataEntity]:
return None


async def entity_to_key(entity: Optional[MetadataEntity]) -> Optional[Key]:
async def entity_to_keys(entity: Optional[MetadataEntity]) -> list[Key]:
keys: list[Key] = []
if entity is None:
return None
return keys

certs = [
load_pem_x509_certificate(item.x509certificate.encode())
for item in entity.issuers
if item.x509certificate is not None
]
if certs:
# TODO: how do we handle multiple certs?
logger.info("Found cert in metadata")
return Key(
proof=Proof(method=ProofMethod.MTLS),
cert_S256=rfc8705_fingerprint(certs[0]),
for cert in certs:
_fingerprint = rfc8705_fingerprint(cert)
logger.info(f"Found cert in metadata, S256: {_fingerprint}")
keys.append(
Key(
proof=Proof(method=ProofMethod.MTLS),
cert_S256=_fingerprint,
)
)
return None
return keys

0 comments on commit 8f61896

Please sign in to comment.