From a011e11d6a5893190f6565d9430eb3c3a3f01355 Mon Sep 17 00:00:00 2001 From: Johan Lundberg Date: Mon, 4 Dec 2023 09:48:04 +0100 Subject: [PATCH] add country code as prefix for org id add subject O and C to claims --- src/auth_server/cert_utils.py | 88 ++++++++++++++----------- src/auth_server/db/transaction_state.py | 2 + src/auth_server/flows.py | 8 ++- src/auth_server/models/claims.py | 2 + src/auth_server/tests/test_ca_flow.py | 6 +- 5 files changed, 66 insertions(+), 40 deletions(-) diff --git a/src/auth_server/cert_utils.py b/src/auth_server/cert_utils.py index b25c326..50a3bbc 100644 --- a/src/auth_server/cert_utils.py +++ b/src/auth_server/cert_utils.py @@ -12,7 +12,7 @@ from cryptography.hazmat.bindings._rust import ObjectIdentifier from cryptography.hazmat.primitives._serialization import Encoding from cryptography.hazmat.primitives.hashes import SHA256 -from cryptography.x509 import Certificate, ExtensionNotFound, load_der_x509_certificate, load_pem_x509_certificate +from cryptography.x509 import Certificate, ExtensionNotFound, Name, load_der_x509_certificate, load_pem_x509_certificate from loguru import logger from pki_tools import Certificate as PKIToolCertificate from pki_tools import Chain @@ -21,8 +21,9 @@ from auth_server.config import ConfigurationError, load_config -OID_ORGANIZATION_NAME = ObjectIdentifier("2.5.4.10") OID_COMMON_NAME = ObjectIdentifier("2.5.4.3") +OID_ORGANIZATION_NAME = ObjectIdentifier("2.5.4.10") +OID_COUNTRY_CODE = ObjectIdentifier("2.5.4.6") OID_SERIAL_NUMBER = ObjectIdentifier("2.5.4.5") OID_ENHANCED_KEY_USAGE_CLIENT_AUTHENTICATION = ObjectIdentifier("1.3.6.1.5.5.7.3.2") @@ -97,29 +98,32 @@ def get_org_id_from_cert(cert: Certificate, ca_name: str) -> Optional[str]: return None if supported_ca is SupportedOrgIdCA.EXPITRUST: - return get_org_id_expitrust(cert=cert) + org_id = get_org_id_expitrust(cert=cert) elif supported_ca is SupportedOrgIdCA.EFOS: - return get_org_id_efos(cert=cert) + org_id = get_org_id_efos(cert=cert) elif supported_ca is SupportedOrgIdCA.SITHS: - return get_org_id_siths(cert=cert) + org_id = get_org_id_siths(cert=cert) else: logger.info(f"CA {ca_name} / {ca_org_name} is not implemented for org id extraction") return None + if org_id is None: + return None + + # Add country code as prefix to org id as TLSFED does + client_country_code = get_oid_for_name(x509_name=cert.subject, oid=OID_COUNTRY_CODE) + return f"{client_country_code}{org_id}" + def get_org_id_expitrust(cert: Certificate) -> Optional[str]: """ The org number is just the serial number of the certificate. """ cert_fingerprint = rfc8705_fingerprint(cert) - try: - ret = cert.subject.get_attributes_for_oid(OID_SERIAL_NUMBER)[0].value - if isinstance(ret, bytes): - ret = ret.decode("utf-8") - return ret - except IndexError: + serial_number = get_oid_for_name(x509_name=cert.subject, oid=OID_SERIAL_NUMBER) + if serial_number is None: logger.error(f"certificate {cert_fingerprint} has no subject serial number") - return None + return serial_number def get_org_id_siths(cert: Certificate) -> Optional[str]: @@ -128,7 +132,6 @@ def get_org_id_siths(cert: Certificate) -> Optional[str]: ex. SE5565594230-AAAA -> 5565594230 """ cert_fingerprint = rfc8705_fingerprint(cert) - # Check that the certificate has enhancedKeyUsage clientAuthentication try: cert.extensions.get_extension_for_oid(OID_ENHANCED_KEY_USAGE_CLIENT_AUTHENTICATION) @@ -137,15 +140,12 @@ def get_org_id_siths(cert: Certificate) -> Optional[str]: return None # Check that the certificate has a subject serial number and parse the org id - try: - serial_number = cert.subject.get_attributes_for_oid(OID_SERIAL_NUMBER)[0].value - if isinstance(serial_number, bytes): - serial_number = serial_number.decode("utf-8") - org_id, _ = serial_number.split("-") - return org_id.removeprefix("SE") - except IndexError: + serial_number = get_oid_for_name(x509_name=cert.subject, oid=OID_SERIAL_NUMBER) + if serial_number is None: logger.error(f"certificate {cert_fingerprint} has no subject serial number") return None + org_id, _ = serial_number.split("-") + return org_id.removeprefix("SE") def get_org_id_efos(cert: Certificate): @@ -155,27 +155,33 @@ def get_org_id_efos(cert: Certificate): """ cert_fingerprint = rfc8705_fingerprint(cert) # Check that the certificate has a subject serial number and parse the org id - try: - serial_number = cert.subject.get_attributes_for_oid(OID_SERIAL_NUMBER)[0].value - if isinstance(serial_number, bytes): - serial_number = serial_number.decode("utf-8") - org_id, _ = serial_number.split("-") - return org_id.removeprefix("EFOS16") - except IndexError: + serial_number = get_oid_for_name(x509_name=cert.subject, oid=OID_SERIAL_NUMBER) + if serial_number is None: logger.error(f"certificate {cert_fingerprint} has no subject serial number") return None + org_id, _ = serial_number.split("-") + return org_id.removeprefix("EFOS16") def get_subject_cn(cert: Certificate) -> Optional[str]: - cert_fingerprint = rfc8705_fingerprint(cert) - try: - ret = cert.subject.get_attributes_for_oid(OID_COMMON_NAME)[0].value - if isinstance(ret, bytes): - ret = ret.decode("utf-8") - return ret - except IndexError: - logger.error(f"certificate {cert_fingerprint} has no subject common name") - return None + common_name = get_oid_for_name(x509_name=cert.subject, oid=OID_COMMON_NAME) + if common_name is None: + logger.error(f"certificate {rfc8705_fingerprint(cert)} has no subject common name") + return common_name + + +def get_subject_c(cert: Certificate) -> Optional[str]: + country_code = get_oid_for_name(x509_name=cert.subject, oid=OID_COUNTRY_CODE) + if country_code is None: + logger.error(f"certificate {rfc8705_fingerprint(cert)} has no subject country code") + return country_code + + +def get_subject_o(cert: Certificate) -> Optional[str]: + org_name = get_oid_for_name(x509_name=cert.subject, oid=OID_ORGANIZATION_NAME) + if org_name is None: + logger.error(f"certificate {rfc8705_fingerprint(cert)} has no subject organization name") + return org_name def get_issuer_cn(ca_name: str) -> Optional[str]: @@ -183,13 +189,19 @@ def get_issuer_cn(ca_name: str) -> Optional[str]: if ca_cert is None: logger.error(f"CA {ca_name} not found") return None + issuer_common_name = get_oid_for_name(x509_name=ca_cert.subject, oid=OID_COMMON_NAME) + if issuer_common_name is None: + logger.error(f"CA {ca_name} has no subject common name") + return issuer_common_name + + +def get_oid_for_name(x509_name: Name, oid: ObjectIdentifier) -> Optional[str]: try: - ret = ca_cert.subject.get_attributes_for_oid(OID_COMMON_NAME)[0].value + ret = x509_name.get_attributes_for_oid(oid)[0].value if isinstance(ret, bytes): ret = ret.decode("utf-8") return ret except IndexError: - logger.error(f"CA {ca_name} has no subject common name") return None diff --git a/src/auth_server/db/transaction_state.py b/src/auth_server/db/transaction_state.py index 4b24d10..5e1b4f5 100644 --- a/src/auth_server/db/transaction_state.py +++ b/src/auth_server/db/transaction_state.py @@ -101,6 +101,8 @@ class CAState(TransactionState): auth_source: AuthSource = AuthSource.CA issuer_common_name: Optional[str] = None client_common_name: Optional[str] = None + client_organization_name: Optional[str] = None + client_country_code: Optional[str] = None organization_id: Optional[str] = None diff --git a/src/auth_server/flows.py b/src/auth_server/flows.py index 04bf3b3..04cbf73 100644 --- a/src/auth_server/flows.py +++ b/src/auth_server/flows.py @@ -14,7 +14,9 @@ cert_within_validity_period, get_issuer_cn, get_org_id_from_cert, + get_subject_c, get_subject_cn, + get_subject_o, is_cert_revoked, load_pem_from_str, ) @@ -638,7 +640,9 @@ async def validate_proof(self) -> Optional[GrantResponse]: self.state.issuer_common_name = get_issuer_cn(ca_name=ca_name) # try to get an organization id from the client certificate self.state.organization_id = get_org_id_from_cert(cert=client_cert, ca_name=ca_name) - + # add extra claims from client certificate + self.state.client_organization_name = get_subject_o(cert=client_cert) + self.state.client_country_code = get_subject_c(cert=client_cert) return None async def create_claims(self) -> CAClaims: @@ -650,5 +654,7 @@ async def create_claims(self) -> CAClaims: **base_claims.model_dump(exclude_none=True), organization_id=self.state.organization_id, common_name=self.state.client_common_name, + organization_name=self.state.client_organization_name, + country_code=self.state.client_country_code, source=self.state.issuer_common_name, ) diff --git a/src/auth_server/models/claims.py b/src/auth_server/models/claims.py index adab3a2..b61a36c 100644 --- a/src/auth_server/models/claims.py +++ b/src/auth_server/models/claims.py @@ -23,6 +23,8 @@ class ConfigClaims(Claims): class CAClaims(Claims): common_name: str + organization_name: Optional[str] = None + country_code: Optional[str] = None organization_id: Optional[str] = None diff --git a/src/auth_server/tests/test_ca_flow.py b/src/auth_server/tests/test_ca_flow.py index f6a6026..11416fc 100644 --- a/src/auth_server/tests/test_ca_flow.py +++ b/src/auth_server/tests/test_ca_flow.py @@ -15,7 +15,9 @@ from auth_server.cert_utils import ( cert_signed_by_ca, cert_within_validity_period, + get_subject_c, get_subject_cn, + get_subject_o, is_cert_revoked, load_ca_certs, rfc8705_fingerprint, @@ -142,7 +144,7 @@ def _do_mtls_transaction(self, cert: Certificate) -> Response: def test_mtls_transaction(self): parameters = [ - ("bolag_a.crt", True, "165560000167"), + ("bolag_a.crt", True, "SE165560000167"), ("bolag_b.crt", False, "client certificate revoked"), ("bolag_c.crt", False, "client certificate expired or not yet valid"), ("bolag_e.crt", False, "client certificate expired or not yet valid"), @@ -163,6 +165,8 @@ def test_mtls_transaction(self): assert claims is not None assert claims["organization_id"] == expected_result, f"{cert_name} has wrong org id" assert claims["common_name"] == get_subject_cn(cert=cert), f"{cert_name} has wrong common name" + assert claims["organization_name"] == get_subject_o(cert=cert), f"{cert_name} has wrong common name" + assert claims["country_code"] == get_subject_c(cert=cert), f"{cert_name} has wrong common name" assert claims["source"] is not None else: assert response.status_code == 401