Skip to content

Commit

Permalink
add country code as prefix for org id
Browse files Browse the repository at this point in the history
add subject O and C to claims
  • Loading branch information
johanlundberg committed Dec 4, 2023
1 parent 0d70acd commit a011e11
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 40 deletions.
88 changes: 50 additions & 38 deletions src/auth_server/cert_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from cryptography.hazmat.bindings._rust import ObjectIdentifier
from cryptography.hazmat.primitives._serialization import Encoding
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.x509 import Certificate, ExtensionNotFound, load_der_x509_certificate, load_pem_x509_certificate
from cryptography.x509 import Certificate, ExtensionNotFound, Name, load_der_x509_certificate, load_pem_x509_certificate
from loguru import logger
from pki_tools import Certificate as PKIToolCertificate
from pki_tools import Chain
Expand All @@ -21,8 +21,9 @@

from auth_server.config import ConfigurationError, load_config

OID_ORGANIZATION_NAME = ObjectIdentifier("2.5.4.10")
OID_COMMON_NAME = ObjectIdentifier("2.5.4.3")
OID_ORGANIZATION_NAME = ObjectIdentifier("2.5.4.10")
OID_COUNTRY_CODE = ObjectIdentifier("2.5.4.6")
OID_SERIAL_NUMBER = ObjectIdentifier("2.5.4.5")
OID_ENHANCED_KEY_USAGE_CLIENT_AUTHENTICATION = ObjectIdentifier("1.3.6.1.5.5.7.3.2")

Expand Down Expand Up @@ -97,29 +98,32 @@ def get_org_id_from_cert(cert: Certificate, ca_name: str) -> Optional[str]:
return None

if supported_ca is SupportedOrgIdCA.EXPITRUST:
return get_org_id_expitrust(cert=cert)
org_id = get_org_id_expitrust(cert=cert)
elif supported_ca is SupportedOrgIdCA.EFOS:
return get_org_id_efos(cert=cert)
org_id = get_org_id_efos(cert=cert)
elif supported_ca is SupportedOrgIdCA.SITHS:
return get_org_id_siths(cert=cert)
org_id = get_org_id_siths(cert=cert)
else:
logger.info(f"CA {ca_name} / {ca_org_name} is not implemented for org id extraction")
return None

if org_id is None:
return None

# Add country code as prefix to org id as TLSFED does
client_country_code = get_oid_for_name(x509_name=cert.subject, oid=OID_COUNTRY_CODE)
return f"{client_country_code}{org_id}"


def get_org_id_expitrust(cert: Certificate) -> Optional[str]:
"""
The org number is just the serial number of the certificate.
"""
cert_fingerprint = rfc8705_fingerprint(cert)
try:
ret = cert.subject.get_attributes_for_oid(OID_SERIAL_NUMBER)[0].value
if isinstance(ret, bytes):
ret = ret.decode("utf-8")
return ret
except IndexError:
serial_number = get_oid_for_name(x509_name=cert.subject, oid=OID_SERIAL_NUMBER)
if serial_number is None:
logger.error(f"certificate {cert_fingerprint} has no subject serial number")
return None
return serial_number


def get_org_id_siths(cert: Certificate) -> Optional[str]:
Expand All @@ -128,7 +132,6 @@ def get_org_id_siths(cert: Certificate) -> Optional[str]:
ex. SE5565594230-AAAA -> 5565594230
"""
cert_fingerprint = rfc8705_fingerprint(cert)

# Check that the certificate has enhancedKeyUsage clientAuthentication
try:
cert.extensions.get_extension_for_oid(OID_ENHANCED_KEY_USAGE_CLIENT_AUTHENTICATION)
Expand All @@ -137,15 +140,12 @@ def get_org_id_siths(cert: Certificate) -> Optional[str]:
return None

# Check that the certificate has a subject serial number and parse the org id
try:
serial_number = cert.subject.get_attributes_for_oid(OID_SERIAL_NUMBER)[0].value
if isinstance(serial_number, bytes):
serial_number = serial_number.decode("utf-8")
org_id, _ = serial_number.split("-")
return org_id.removeprefix("SE")
except IndexError:
serial_number = get_oid_for_name(x509_name=cert.subject, oid=OID_SERIAL_NUMBER)
if serial_number is None:
logger.error(f"certificate {cert_fingerprint} has no subject serial number")
return None
org_id, _ = serial_number.split("-")
return org_id.removeprefix("SE")


def get_org_id_efos(cert: Certificate):
Expand All @@ -155,41 +155,53 @@ def get_org_id_efos(cert: Certificate):
"""
cert_fingerprint = rfc8705_fingerprint(cert)
# Check that the certificate has a subject serial number and parse the org id
try:
serial_number = cert.subject.get_attributes_for_oid(OID_SERIAL_NUMBER)[0].value
if isinstance(serial_number, bytes):
serial_number = serial_number.decode("utf-8")
org_id, _ = serial_number.split("-")
return org_id.removeprefix("EFOS16")
except IndexError:
serial_number = get_oid_for_name(x509_name=cert.subject, oid=OID_SERIAL_NUMBER)
if serial_number is None:
logger.error(f"certificate {cert_fingerprint} has no subject serial number")
return None
org_id, _ = serial_number.split("-")
return org_id.removeprefix("EFOS16")


def get_subject_cn(cert: Certificate) -> Optional[str]:
cert_fingerprint = rfc8705_fingerprint(cert)
try:
ret = cert.subject.get_attributes_for_oid(OID_COMMON_NAME)[0].value
if isinstance(ret, bytes):
ret = ret.decode("utf-8")
return ret
except IndexError:
logger.error(f"certificate {cert_fingerprint} has no subject common name")
return None
common_name = get_oid_for_name(x509_name=cert.subject, oid=OID_COMMON_NAME)
if common_name is None:
logger.error(f"certificate {rfc8705_fingerprint(cert)} has no subject common name")
return common_name


def get_subject_c(cert: Certificate) -> Optional[str]:
country_code = get_oid_for_name(x509_name=cert.subject, oid=OID_COUNTRY_CODE)
if country_code is None:
logger.error(f"certificate {rfc8705_fingerprint(cert)} has no subject country code")
return country_code


def get_subject_o(cert: Certificate) -> Optional[str]:
org_name = get_oid_for_name(x509_name=cert.subject, oid=OID_ORGANIZATION_NAME)
if org_name is None:
logger.error(f"certificate {rfc8705_fingerprint(cert)} has no subject organization name")
return org_name


def get_issuer_cn(ca_name: str) -> Optional[str]:
ca_cert = load_ca_certs().get(ca_name)
if ca_cert is None:
logger.error(f"CA {ca_name} not found")
return None
issuer_common_name = get_oid_for_name(x509_name=ca_cert.subject, oid=OID_COMMON_NAME)
if issuer_common_name is None:
logger.error(f"CA {ca_name} has no subject common name")
return issuer_common_name


def get_oid_for_name(x509_name: Name, oid: ObjectIdentifier) -> Optional[str]:
try:
ret = ca_cert.subject.get_attributes_for_oid(OID_COMMON_NAME)[0].value
ret = x509_name.get_attributes_for_oid(oid)[0].value
if isinstance(ret, bytes):
ret = ret.decode("utf-8")
return ret
except IndexError:
logger.error(f"CA {ca_name} has no subject common name")
return None


Expand Down
2 changes: 2 additions & 0 deletions src/auth_server/db/transaction_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ class CAState(TransactionState):
auth_source: AuthSource = AuthSource.CA
issuer_common_name: Optional[str] = None
client_common_name: Optional[str] = None
client_organization_name: Optional[str] = None
client_country_code: Optional[str] = None
organization_id: Optional[str] = None


Expand Down
8 changes: 7 additions & 1 deletion src/auth_server/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
cert_within_validity_period,
get_issuer_cn,
get_org_id_from_cert,
get_subject_c,
get_subject_cn,
get_subject_o,
is_cert_revoked,
load_pem_from_str,
)
Expand Down Expand Up @@ -638,7 +640,9 @@ async def validate_proof(self) -> Optional[GrantResponse]:
self.state.issuer_common_name = get_issuer_cn(ca_name=ca_name)
# try to get an organization id from the client certificate
self.state.organization_id = get_org_id_from_cert(cert=client_cert, ca_name=ca_name)

# add extra claims from client certificate
self.state.client_organization_name = get_subject_o(cert=client_cert)
self.state.client_country_code = get_subject_c(cert=client_cert)
return None

async def create_claims(self) -> CAClaims:
Expand All @@ -650,5 +654,7 @@ async def create_claims(self) -> CAClaims:
**base_claims.model_dump(exclude_none=True),
organization_id=self.state.organization_id,
common_name=self.state.client_common_name,
organization_name=self.state.client_organization_name,
country_code=self.state.client_country_code,
source=self.state.issuer_common_name,
)
2 changes: 2 additions & 0 deletions src/auth_server/models/claims.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class ConfigClaims(Claims):

class CAClaims(Claims):
common_name: str
organization_name: Optional[str] = None
country_code: Optional[str] = None
organization_id: Optional[str] = None


Expand Down
6 changes: 5 additions & 1 deletion src/auth_server/tests/test_ca_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from auth_server.cert_utils import (
cert_signed_by_ca,
cert_within_validity_period,
get_subject_c,
get_subject_cn,
get_subject_o,
is_cert_revoked,
load_ca_certs,
rfc8705_fingerprint,
Expand Down Expand Up @@ -142,7 +144,7 @@ def _do_mtls_transaction(self, cert: Certificate) -> Response:

def test_mtls_transaction(self):
parameters = [
("bolag_a.crt", True, "165560000167"),
("bolag_a.crt", True, "SE165560000167"),
("bolag_b.crt", False, "client certificate revoked"),
("bolag_c.crt", False, "client certificate expired or not yet valid"),
("bolag_e.crt", False, "client certificate expired or not yet valid"),
Expand All @@ -163,6 +165,8 @@ def test_mtls_transaction(self):
assert claims is not None
assert claims["organization_id"] == expected_result, f"{cert_name} has wrong org id"
assert claims["common_name"] == get_subject_cn(cert=cert), f"{cert_name} has wrong common name"
assert claims["organization_name"] == get_subject_o(cert=cert), f"{cert_name} has wrong common name"
assert claims["country_code"] == get_subject_c(cert=cert), f"{cert_name} has wrong common name"
assert claims["source"] is not None
else:
assert response.status_code == 401
Expand Down

0 comments on commit a011e11

Please sign in to comment.