diff --git a/src/auth_server/mdq.py b/src/auth_server/mdq.py index 297f4a0..ef5e8ee 100644 --- a/src/auth_server/mdq.py +++ b/src/auth_server/mdq.py @@ -12,7 +12,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_serializer from pyexpat import ExpatError -from auth_server.cert_utils import load_cert_from_str, serialize_certificate +from auth_server.cert_utils import load_pem_from_str, serialize_certificate from auth_server.models.gnap import Key, Proof, ProofMethod from auth_server.utils import get_values, hash_with @@ -37,7 +37,7 @@ class MDQCert(MDQBase): def deserialize_cert(cls, v: str) -> Certificate: if isinstance(v, Certificate): return v - return load_cert_from_str(v) + return load_pem_from_str(v) @model_serializer def serialize_mdq_cert(self) -> dict[str, Any]: @@ -84,7 +84,7 @@ async def xml_mdq_get(entity_id: str, mdq_url: str) -> MDQData: for key_descriptor in get_values(key="urn:oasis:names:tc:SAML:2.0:metadata:KeyDescriptor", obj=entity): use = list(get_values(key="@use", obj=key_descriptor))[0] raw_cert = list(get_values(key="http://www.w3.org/2000/09/xmldsig#:X509Certificate", obj=key_descriptor))[0] - cert = load_cert_from_str(raw_cert) + cert = load_pem_from_str(raw_cert) certs.append(MDQCert(use=KeyUse(use), cert=cert)) return MDQData(certs=certs, metadata=entity) except (ExpatError, ValueError): # TODO: handle exceptions properly @@ -97,7 +97,7 @@ async def mdq_data_to_key(mdq_data: MDQData) -> Optional[Key]: # There should only be one or zero signing certs if signing_cert: logger.info("Found cert in metadata") - return Key( # type: ignore[call-arg] + return Key( proof=Proof(method=ProofMethod.MTLS), cert_S256=b64encode(signing_cert[0].fingerprint(algorithm=SHA256())).decode("utf-8"), ) diff --git a/src/auth_server/proof/mtls.py b/src/auth_server/proof/mtls.py index d2be124..d670439 100644 --- a/src/auth_server/proof/mtls.py +++ b/src/auth_server/proof/mtls.py @@ -1,11 +1,8 @@ # -*- coding: utf-8 -*- -from base64 import b64encode - -from cryptography.hazmat.primitives.hashes import SHA256 from loguru import logger -from auth_server.cert_utils import load_cert_from_str +from auth_server.cert_utils import load_pem_from_str, rfc8705_fingerprint from auth_server.models.gnap import Key __author__ = "lundberg" @@ -13,12 +10,12 @@ async def check_mtls_proof(gnap_key: Key, cert: str) -> bool: try: - tls_cert = load_cert_from_str(cert) + tls_cert = load_pem_from_str(cert) except ValueError: logger.error(f"could not load client cert: {cert}") return False - tls_fingerprint = b64encode(tls_cert.fingerprint(algorithm=SHA256())).decode("utf-8") + tls_fingerprint = rfc8705_fingerprint(tls_cert) logger.debug(f"tls cert fingerprint: {str(tls_fingerprint)}") if gnap_key.cert_S256 is not None: @@ -28,8 +25,8 @@ async def check_mtls_proof(gnap_key: Key, cert: str) -> bool: return True logger.info("TLS cert fingerprint does NOT match grant request cert#S256") elif gnap_key.cert is not None: - grant_cert = load_cert_from_str(gnap_key.cert) - grant_cert_fingerprint = b64encode(grant_cert.fingerprint(algorithm=SHA256())).decode("utf-8") + grant_cert = load_pem_from_str(gnap_key.cert) + grant_cert_fingerprint = rfc8705_fingerprint(grant_cert) logger.debug(f"grant cert fingerprint: {grant_cert_fingerprint}") if tls_fingerprint == grant_cert_fingerprint: logger.info("TLS cert fingerprint matches grant request cert fingerprint") diff --git a/src/auth_server/tests/test_app.py b/src/auth_server/tests/test_app.py index 63f3a64..c9fd4d8 100644 --- a/src/auth_server/tests/test_app.py +++ b/src/auth_server/tests/test_app.py @@ -13,11 +13,11 @@ import yaml from cryptography import x509 from cryptography.hazmat.primitives.hashes import SHA256 -from cryptography.hazmat.primitives.serialization import Encoding from jwcrypto import jwk, jws, jwt from starlette.testclient import TestClient from auth_server.api import init_auth_server_api +from auth_server.cert_utils import serialize_certificate from auth_server.config import ClientKey, load_config from auth_server.db.transaction_state import AuthSource, TransactionState from auth_server.models.gnap import ( @@ -88,7 +88,7 @@ def setUp(self) -> None: with open(f"{self.datadir}/test.cert", "rb") as f: self.client_cert = x509.load_pem_x509_certificate(data=f.read()) - self.client_cert_str = base64.b64encode(self.client_cert.public_bytes(encoding=Encoding.DER)).decode("utf-8") + self.client_cert_str = serialize_certificate(cert=self.client_cert) with open(f"{self.datadir}/test_mdq.xml", "rb") as f: self.mdq_response = f.read() self.client_jwk = jwk.JWK.generate(kid="default", kty="EC", crv="P-256") diff --git a/src/auth_server/tests/utils.py b/src/auth_server/tests/utils.py index 493e616..849f0f6 100644 --- a/src/auth_server/tests/utils.py +++ b/src/auth_server/tests/utils.py @@ -59,9 +59,7 @@ def create_tls_fed_metadata( entity_id=entity_id, organization="Test Org", organization_id=organization_id, - issuers=[ - CertIssuers(x509certificate=f"-----BEGIN CERTIFICATE-----\n{client_cert}\n-----END CERTIFICATE-----") - ], + issuers=[CertIssuers(x509certificate=client_cert)], extensions=Extensions(saml_scope=SAMLScopeExtension(scope=scopes)), ) ] diff --git a/src/auth_server/tls_fed_auth.py b/src/auth_server/tls_fed_auth.py index 076ade3..3f4f6a6 100644 --- a/src/auth_server/tls_fed_auth.py +++ b/src/auth_server/tls_fed_auth.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- import asyncio -import base64 import json from datetime import datetime, timedelta from pathlib import Path @@ -9,12 +8,12 @@ import aiohttp from aiofiles import open as async_open from async_lru import alru_cache -from cryptography.hazmat.primitives.hashes import SHA256 from cryptography.x509 import load_pem_x509_certificate from jwcrypto import jwk, jws from loguru import logger from pydantic import BaseModel, ConfigDict, ValidationError +from auth_server.cert_utils import rfc8705_fingerprint from auth_server.config import load_config from auth_server.models.gnap import Key, Proof, ProofMethod from auth_server.models.tls_fed_metadata import Entity @@ -262,8 +261,8 @@ async def entity_to_key(entity: Optional[MetadataEntity]) -> Optional[Key]: if certs: # TODO: how do we handle multiple certs? logger.info("Found cert in metadata") - return Key( # type: ignore[call-arg] + return Key( proof=Proof(method=ProofMethod.MTLS), - cert_S256=base64.b64encode(certs[0].fingerprint(algorithm=SHA256())).decode("utf-8"), + cert_S256=rfc8705_fingerprint(certs[0]), ) return None