Skip to content

Commit

Permalink
use shared function for cert fingerprint
Browse files Browse the repository at this point in the history
  • Loading branch information
johanlundberg committed Nov 30, 2023
1 parent aa4b356 commit f6b7769
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 21 deletions.
8 changes: 4 additions & 4 deletions src/auth_server/mdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_serializer
from pyexpat import ExpatError

from auth_server.cert_utils import load_cert_from_str, serialize_certificate
from auth_server.cert_utils import load_pem_from_str, serialize_certificate
from auth_server.models.gnap import Key, Proof, ProofMethod
from auth_server.utils import get_values, hash_with

Expand All @@ -37,7 +37,7 @@ class MDQCert(MDQBase):
def deserialize_cert(cls, v: str) -> Certificate:
if isinstance(v, Certificate):
return v
return load_cert_from_str(v)
return load_pem_from_str(v)

@model_serializer
def serialize_mdq_cert(self) -> dict[str, Any]:
Expand Down Expand Up @@ -84,7 +84,7 @@ async def xml_mdq_get(entity_id: str, mdq_url: str) -> MDQData:
for key_descriptor in get_values(key="urn:oasis:names:tc:SAML:2.0:metadata:KeyDescriptor", obj=entity):
use = list(get_values(key="@use", obj=key_descriptor))[0]
raw_cert = list(get_values(key="http://www.w3.org/2000/09/xmldsig#:X509Certificate", obj=key_descriptor))[0]
cert = load_cert_from_str(raw_cert)
cert = load_pem_from_str(raw_cert)
certs.append(MDQCert(use=KeyUse(use), cert=cert))
return MDQData(certs=certs, metadata=entity)
except (ExpatError, ValueError): # TODO: handle exceptions properly
Expand All @@ -97,7 +97,7 @@ async def mdq_data_to_key(mdq_data: MDQData) -> Optional[Key]:
# There should only be one or zero signing certs
if signing_cert:
logger.info("Found cert in metadata")
return Key( # type: ignore[call-arg]
return Key(
proof=Proof(method=ProofMethod.MTLS),
cert_S256=b64encode(signing_cert[0].fingerprint(algorithm=SHA256())).decode("utf-8"),
)
Expand Down
13 changes: 5 additions & 8 deletions src/auth_server/proof/mtls.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
# -*- coding: utf-8 -*-

from base64 import b64encode

from cryptography.hazmat.primitives.hashes import SHA256
from loguru import logger

from auth_server.cert_utils import load_cert_from_str
from auth_server.cert_utils import load_pem_from_str, rfc8705_fingerprint
from auth_server.models.gnap import Key

__author__ = "lundberg"


async def check_mtls_proof(gnap_key: Key, cert: str) -> bool:
try:
tls_cert = load_cert_from_str(cert)
tls_cert = load_pem_from_str(cert)
except ValueError:
logger.error(f"could not load client cert: {cert}")
return False

tls_fingerprint = b64encode(tls_cert.fingerprint(algorithm=SHA256())).decode("utf-8")
tls_fingerprint = rfc8705_fingerprint(tls_cert)
logger.debug(f"tls cert fingerprint: {str(tls_fingerprint)}")

if gnap_key.cert_S256 is not None:
Expand All @@ -28,8 +25,8 @@ async def check_mtls_proof(gnap_key: Key, cert: str) -> bool:
return True
logger.info("TLS cert fingerprint does NOT match grant request cert#S256")
elif gnap_key.cert is not None:
grant_cert = load_cert_from_str(gnap_key.cert)
grant_cert_fingerprint = b64encode(grant_cert.fingerprint(algorithm=SHA256())).decode("utf-8")
grant_cert = load_pem_from_str(gnap_key.cert)
grant_cert_fingerprint = rfc8705_fingerprint(grant_cert)
logger.debug(f"grant cert fingerprint: {grant_cert_fingerprint}")
if tls_fingerprint == grant_cert_fingerprint:
logger.info("TLS cert fingerprint matches grant request cert fingerprint")
Expand Down
4 changes: 2 additions & 2 deletions src/auth_server/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
import yaml
from cryptography import x509
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.hazmat.primitives.serialization import Encoding
from jwcrypto import jwk, jws, jwt
from starlette.testclient import TestClient

from auth_server.api import init_auth_server_api
from auth_server.cert_utils import serialize_certificate
from auth_server.config import ClientKey, load_config
from auth_server.db.transaction_state import AuthSource, TransactionState
from auth_server.models.gnap import (
Expand Down Expand Up @@ -88,7 +88,7 @@ def setUp(self) -> None:

with open(f"{self.datadir}/test.cert", "rb") as f:
self.client_cert = x509.load_pem_x509_certificate(data=f.read())
self.client_cert_str = base64.b64encode(self.client_cert.public_bytes(encoding=Encoding.DER)).decode("utf-8")
self.client_cert_str = serialize_certificate(cert=self.client_cert)
with open(f"{self.datadir}/test_mdq.xml", "rb") as f:
self.mdq_response = f.read()
self.client_jwk = jwk.JWK.generate(kid="default", kty="EC", crv="P-256")
Expand Down
4 changes: 1 addition & 3 deletions src/auth_server/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ def create_tls_fed_metadata(
entity_id=entity_id,
organization="Test Org",
organization_id=organization_id,
issuers=[
CertIssuers(x509certificate=f"-----BEGIN CERTIFICATE-----\n{client_cert}\n-----END CERTIFICATE-----")
],
issuers=[CertIssuers(x509certificate=client_cert)],
extensions=Extensions(saml_scope=SAMLScopeExtension(scope=scopes)),
)
]
Expand Down
7 changes: 3 additions & 4 deletions src/auth_server/tls_fed_auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
import asyncio
import base64
import json
from datetime import datetime, timedelta
from pathlib import Path
Expand All @@ -9,12 +8,12 @@
import aiohttp
from aiofiles import open as async_open
from async_lru import alru_cache
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.x509 import load_pem_x509_certificate
from jwcrypto import jwk, jws
from loguru import logger
from pydantic import BaseModel, ConfigDict, ValidationError

from auth_server.cert_utils import rfc8705_fingerprint
from auth_server.config import load_config
from auth_server.models.gnap import Key, Proof, ProofMethod
from auth_server.models.tls_fed_metadata import Entity
Expand Down Expand Up @@ -262,8 +261,8 @@ async def entity_to_key(entity: Optional[MetadataEntity]) -> Optional[Key]:
if certs:
# TODO: how do we handle multiple certs?
logger.info("Found cert in metadata")
return Key( # type: ignore[call-arg]
return Key(
proof=Proof(method=ProofMethod.MTLS),
cert_S256=base64.b64encode(certs[0].fingerprint(algorithm=SHA256())).decode("utf-8"),
cert_S256=rfc8705_fingerprint(certs[0]),
)
return None

0 comments on commit f6b7769

Please sign in to comment.