Skip to content

Commit

Permalink
implement CA cert auth flow
Browse files Browse the repository at this point in the history
  • Loading branch information
johanlundberg committed Nov 30, 2023
1 parent 83ceed5 commit 708623d
Show file tree
Hide file tree
Showing 15 changed files with 557 additions and 37 deletions.
7 changes: 4 additions & 3 deletions src/auth_server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from auth_server.config import AuthServerConfig, ConfigurationError, FlowName, load_config
from auth_server.context import ContextRequestRoute
from auth_server.flows import BaseAuthFlow, ConfigFlow, InteractionFlow, MDQFlow, TestFlow, TLSFEDFlow
from auth_server.flows import BaseAuthFlow, CAFlow, ConfigFlow, InteractionFlow, MDQFlow, TestFlow, TLSFEDFlow
from auth_server.log import init_logging
from auth_server.middleware import JOSEMiddleware
from auth_server.routers.interaction import interaction_router
Expand All @@ -28,10 +28,11 @@ def __init__(self):

# Load flows
self.builtin_flow: Dict[FlowName, Type[BaseAuthFlow]] = {
FlowName.TESTFLOW: TestFlow,
FlowName.INTERACTIONFLOW: InteractionFlow,
FlowName.CAFLOW: CAFlow,
FlowName.CONFIGFLOW: ConfigFlow,
FlowName.INTERACTIONFLOW: InteractionFlow,
FlowName.MDQFLOW: MDQFlow,
FlowName.TESTFLOW: TestFlow,
FlowName.TLSFEDFLOW: TLSFEDFlow,
}
self.auth_flows = self.load_flows(config=config)
Expand Down
183 changes: 163 additions & 20 deletions src/auth_server/cert_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
# -*- coding: utf-8 -*-
__author__ = "lundberg"

from base64 import b64encode
from datetime import datetime
from enum import Enum
from functools import lru_cache
from pathlib import Path
from typing import Optional

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.bindings._rust import ObjectIdentifier
from cryptography.hazmat.primitives._serialization import Encoding
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
from cryptography.x509 import Certificate, ExtensionNotFound, load_der_x509_certificate, load_pem_x509_certificate
from loguru import logger
from pki_tools import Certificate as PKIToolCertificate
from pki_tools import Chain
Expand All @@ -18,77 +21,217 @@

from auth_server.config import ConfigurationError, load_config

OID_ORGANIZATION_NAME = ObjectIdentifier("2.5.4.10")
OID_COMMON_NAME = ObjectIdentifier("2.5.4.3")
OID_SERIAL_NUMBER = ObjectIdentifier("2.5.4.5")
OID_ENHANCED_KEY_USAGE_CLIENT_AUTHENTICATION = ObjectIdentifier("1.3.6.1.5.5.7.3.2")


class SupportedOrgIdCA(str, Enum):
EFOS = "Swedish Social Insurance Agency"
EXPITRUST = "Expisoft AB"
SITHS = "Inera AB"


def cert_within_validity_period(cert: Certificate) -> bool:
"""
check if certificate is within the validity period
"""
cert_fingerprint = cert.fingerprint(SHA256())
cert_fingerprint = rfc8705_fingerprint(cert)
now = datetime.utcnow()
if now < cert.not_valid_before:
logger.error(f"Certificate {cert_fingerprint!r} not valid before {cert.not_valid_before}")
logger.error(f"Certificate {cert_fingerprint} not valid before {cert.not_valid_before}")
return False
if now > cert.not_valid_after:
logger.error(f"Certificate {cert_fingerprint!r} not valid after {cert.not_valid_after}")
logger.error(f"Certificate {cert_fingerprint} not valid after {cert.not_valid_after}")
return False
return True


def cert_signed_by_ca(cert: Certificate) -> Optional[Certificate]:
def cert_signed_by_ca(cert: Certificate) -> Optional[str]:
"""
check if the cert is signed by any on our loaded CA certs
"""
cert_fingerprint = cert.fingerprint(SHA256())
for ca_cert in load_ca_certs():
cert_fingerprint = rfc8705_fingerprint(cert)
for ca_name, ca_cert in load_ca_certs().items():
try:
cert.verify_directly_issued_by(ca_cert)
logger.debug(f"Certificate {cert_fingerprint!r} signed by CA cert {ca_cert.fingerprint(SHA256())!r}")
return ca_cert
logger.debug(f"Certificate {cert_fingerprint} signed by CA cert {ca_name}")
return ca_name
except (ValueError, TypeError, InvalidSignature):
continue

logger.error(f"Certificate {cert_fingerprint!r} did NOT match any loaded CA cert")
logger.error(f"Certificate {cert_fingerprint} did NOT match any loaded CA cert")
return None


async def is_cert_revoked(cert: Certificate, ca_cert: Certificate) -> bool:
async def is_cert_revoked(cert: Certificate, ca_name: str) -> bool:
"""
check if cert is revoked
"""
ca_cert = load_ca_certs().get(ca_name)
if ca_cert is None:
raise ConfigurationError(f"CA cert {ca_name} not found")
try:
return is_revoked(
cert=PKIToolCertificate.from_cryptography(cert=cert), chain=Chain.from_cryptography([ca_cert])
)
except PKIToolsError as e:
logger.error(f"Certificate {cert.fingerprint(SHA256())!r} failed revoke check: {e}")
logger.error(f"Certificate {rfc8705_fingerprint(cert)} failed revoke check: {e}")
return True


def get_org_id_from_cert(cert: Certificate, ca_name: str) -> Optional[str]:
ca_cert = load_ca_certs().get(ca_name)
if not ca_cert:
raise ConfigurationError(f"CA cert {ca_name} not found")
try:
ca_org_name = ca_cert.issuer.get_attributes_for_oid(OID_ORGANIZATION_NAME)[0].value
except IndexError:
logger.error(f"CA certificate {ca_name} has no org name")
return None
try:
supported_ca = SupportedOrgIdCA(ca_org_name)
except ValueError:
logger.info(f"CA {ca_name} is not supported for org id extraction")
return None

if supported_ca is SupportedOrgIdCA.EXPITRUST:
return get_org_id_expitrust(cert=cert)
elif supported_ca is SupportedOrgIdCA.EFOS:
return get_org_id_efos(cert=cert)
elif supported_ca is SupportedOrgIdCA.SITHS:
return get_org_id_siths(cert=cert)
else:
logger.info(f"CA {ca_name} / {ca_org_name} is not implemented for org id extraction")
return None


def get_org_id_expitrust(cert: Certificate) -> Optional[str]:
"""
The org number is just the serial number of the certificate.
"""
cert_fingerprint = rfc8705_fingerprint(cert)
try:
ret = cert.subject.get_attributes_for_oid(OID_SERIAL_NUMBER)[0].value
if isinstance(ret, bytes):
ret = ret.decode("utf-8")
return ret
except IndexError:
logger.error(f"certificate {cert_fingerprint} has no subject serial number")
return None


def get_org_id_siths(cert: Certificate) -> Optional[str]:
"""
The org number is the first part of the serial number of the certificate with a prefix of SE.
ex. SE5565594230-AAAA -> 5565594230
"""
cert_fingerprint = rfc8705_fingerprint(cert)

# Check that the certificate has enhancedKeyUsage clientAuthentication
try:
cert.extensions.get_extension_for_oid(OID_ENHANCED_KEY_USAGE_CLIENT_AUTHENTICATION)
except ExtensionNotFound:
logger.error(f"certificate {cert_fingerprint} has no enhancedKeyUsage clientAuthentication")
return None

# Check that the certificate has a subject serial number and parse the org id
try:
serial_number = cert.subject.get_attributes_for_oid(OID_SERIAL_NUMBER)[0].value
if isinstance(serial_number, bytes):
serial_number = serial_number.decode("utf-8")
org_id, _ = serial_number.split("-")
return org_id.removeprefix("SE")
except IndexError:
logger.error(f"certificate {cert_fingerprint} has no subject serial number")
return None


def get_org_id_efos(cert: Certificate):
"""
The org number is the first part of the serial number of the certificate with a prefix of EFOS16.
ex. EFOS165565594230-012345 -> 5565594230
"""
cert_fingerprint = rfc8705_fingerprint(cert)
# Check that the certificate has a subject serial number and parse the org id
try:
serial_number = cert.subject.get_attributes_for_oid(OID_SERIAL_NUMBER)[0].value
if isinstance(serial_number, bytes):
serial_number = serial_number.decode("utf-8")
org_id, _ = serial_number.split("-")
return org_id.removeprefix("EFOS16")
except IndexError:
logger.error(f"certificate {cert_fingerprint} has no subject serial number")
return None


def get_subject_cn(cert: Certificate) -> Optional[str]:
cert_fingerprint = rfc8705_fingerprint(cert)
try:
ret = cert.subject.get_attributes_for_oid(OID_COMMON_NAME)[0].value
if isinstance(ret, bytes):
ret = ret.decode("utf-8")
return ret
except IndexError:
logger.error(f"certificate {cert_fingerprint} has no subject common name")
return None


def get_issuer_cn(ca_name: str) -> Optional[str]:
ca_cert = load_ca_certs().get(ca_name)
if ca_cert is None:
logger.error(f"CA {ca_name} not found")
return None
try:
ret = ca_cert.subject.get_attributes_for_oid(OID_COMMON_NAME)[0].value
if isinstance(ret, bytes):
ret = ret.decode("utf-8")
return ret
except IndexError:
logger.error(f"CA {ca_name} has no subject common name")
return None


@lru_cache()
def load_ca_certs() -> list[Certificate]:
def load_ca_certs() -> dict[str, Certificate]:
config = load_config()
if config.ca_certs_path is None:
raise ConfigurationError("no CA certs path specified in config")
certs = []
certs = {}
path = Path(config.ca_certs_path)
for crt in path.glob("**/*.crt"):
for crt in path.glob("**/*.c*"): # match .crt and .cer files
if crt.is_dir():
continue
try:
with open(crt, "rb") as f:
content = f.read()
try:
certs.append(load_pem_x509_certificate(content))
cert = load_pem_x509_certificate(content)
except ValueError:
certs.append(load_der_x509_certificate(content))
cert = load_der_x509_certificate(content)
if cert_within_validity_period(cert):
certs[cert.subject.rfc4514_string()] = cert
except (IOError, ValueError) as e:
logger.error(f"Failed to load CA cert {crt}: {e}")
logger.info(f"Loaded {len(certs)} CA certs")
logger.debug(f"Certs loaded: {certs.keys()}")
return certs


def load_cert_from_str(cert: str) -> Certificate:
def load_pem_from_str(cert: str) -> Certificate:
if not cert.startswith("-----BEGIN CERTIFICATE-----"):
cert = f"-----BEGIN CERTIFICATE-----\n{cert}\n-----END CERTIFICATE-----"
return load_pem_x509_certificate(cert.encode())


def serialize_certificate(cert: Certificate) -> str:
return cert.public_bytes(encoding=Encoding.PEM).decode("utf-8")
def serialize_certificate(cert: Certificate, encoding: Encoding = Encoding.PEM) -> str:
public_bytes = cert.public_bytes(encoding=encoding)
if encoding == Encoding.DER:
return b64encode(public_bytes).decode("ascii")
else:
return public_bytes.decode("ascii")


def rfc8705_fingerprint(cert: Certificate):
return b64encode(cert.fingerprint(algorithm=SHA256())).decode("ascii")
3 changes: 2 additions & 1 deletion src/auth_server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ class Environment(str, Enum):


class FlowName(str, Enum):
CAFLOW = "CAFlow"
CONFIGFLOW = "ConfigFlow"
MDQFLOW = "MDQFlow"
INTERACTIONFLOW = "InteractionFlow"
MDQFLOW = "MDQFlow"
TESTFLOW = "TestFlow"
TLSFEDFLOW = "TLSFEDFlow"

Expand Down
3 changes: 2 additions & 1 deletion src/auth_server/db/transaction_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,9 @@ class TLSFEDState(TransactionState):

class CAState(TransactionState):
auth_source: AuthSource = AuthSource.CA
issuer_common_name: Optional[str] = None
client_common_name: Optional[str] = None
organization_id: Optional[str] = None
ca: Optional[str] = None


class TransactionStateDB(BaseDB):
Expand Down
40 changes: 28 additions & 12 deletions src/auth_server/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,15 @@
from jwcrypto.jwk import JWK
from loguru import logger

from auth_server.cert_utils import cert_signed_by_ca, cert_within_validity_period, is_cert_revoked
from auth_server.cert_utils import (
cert_signed_by_ca,
cert_within_validity_period,
get_issuer_cn,
get_org_id_from_cert,
get_subject_cn,
is_cert_revoked,
load_pem_from_str,
)
from auth_server.config import AuthServerConfig
from auth_server.context import ContextRequest
from auth_server.db.transaction_state import (
Expand Down Expand Up @@ -553,7 +561,7 @@ async def create_claims(self) -> MDQClaims:
source = self.config.mdq_server # Default source to mdq server if registrationAuthority is not set

base_claims = await super().create_claims()
return MDQClaims(**base_claims.dict(exclude_none=True), entity_id=entity_id, scopes=scopes, source=source)
return MDQClaims(**base_claims.model_dump(exclude_none=True), entity_id=entity_id, scopes=scopes, source=source)


class TLSFEDFlow(OnlyMTLSProofFlow):
Expand Down Expand Up @@ -614,25 +622,33 @@ async def validate_proof(self) -> Optional[GrantResponse]:
if not self.state.proof_ok:
raise NextFlowException(status_code=401, detail="client certificate does not match grant request")

if not cert_within_validity_period(cert=self.request.context.client_cert):
raise NextFlowException(status_code=401, detail="client certificate expired")
client_cert = load_pem_from_str(self.request.context.client_cert)
if not cert_within_validity_period(cert=client_cert):
raise StopTransactionException(status_code=401, detail="client certificate expired or not yet valid")

ca_name = cert_signed_by_ca(cert=client_cert)
if ca_name is None:
raise StopTransactionException(status_code=401, detail="client certificate not signed by CA")

ca_cert = cert_signed_by_ca(cert=self.request.context.client_cert)
if ca_cert is None:
raise NextFlowException(status_code=401, detail="client certificate not signed by CA")
if await is_cert_revoked(cert=client_cert, ca_name=ca_name) is True:
raise StopTransactionException(status_code=401, detail="client certificate revoked")

if await is_cert_revoked(cert=self.request.context.client_cert, ca_cert=ca_cert) is True:
raise NextFlowException(status_code=401, detail="client certificate revoked")
# set client CN and issuer CN in state for use in claims
self.state.client_common_name = get_subject_cn(cert=client_cert)
self.state.issuer_common_name = get_issuer_cn(ca_name=ca_name)
# try to get an organization id from the client certificate
self.state.organization_id = get_org_id_from_cert(cert=client_cert, ca_name=ca_name)

return None

async def create_claims(self) -> CAClaims:
if not self.state.organization_id:
raise NextFlowException(status_code=400, detail="missing organization id")
if self.config.ca_certs_mandatory_org_id and self.state.organization_id is None:
raise StopTransactionException(status_code=401, detail="missing organization id in client certificate")

base_claims = await super().create_claims()
return CAClaims(
**base_claims.model_dump(exclude_none=True),
organization_id=self.state.organization_id,
source=self.state.ca,
common_name=self.state.client_common_name,
source=self.state.issuer_common_name,
)
1 change: 1 addition & 0 deletions src/auth_server/models/claims.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ConfigClaims(Claims):


class CAClaims(Claims):
common_name: str
organization_id: Optional[str] = None


Expand Down
12 changes: 12 additions & 0 deletions src/auth_server/tests/data/ca/README
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Test certs:
https://eid.expisoft.se/expitrust-test-certifikat/

Convert p12 to x509 PEM certificate AND key format, use these commands:

openssl pkcs12 -in path.p12 -out newfile.crt.pem -clcerts -nokeys
openssl pkcs12 -in path.p12 -out newfile.key.pem -nocerts -nodes

Test CA certs:
https://eid.expisoft.se/expitrust-test-certifikat/
https://inera.atlassian.net/wiki/spaces/IAM/pages/289082989/PKI-struktur+och+rotcertifikat
https://repository.efos.se/
Loading

0 comments on commit 708623d

Please sign in to comment.