Skip to content

Commit

Permalink
Merge pull request #26 from SUNET/lundberg_ca_proof
Browse files Browse the repository at this point in the history
ca proof flow
  • Loading branch information
johanlundberg authored Nov 30, 2023
2 parents c21243f + 5875279 commit 0d70acd
Show file tree
Hide file tree
Showing 31 changed files with 1,413 additions and 770 deletions.
506 changes: 253 additions & 253 deletions dev_requirements.txt

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[mypy]
plugins = pydantic.mypy
432 changes: 216 additions & 216 deletions requirements.txt

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions src/auth_server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from auth_server.config import AuthServerConfig, ConfigurationError, FlowName, load_config
from auth_server.context import ContextRequestRoute
from auth_server.flows import BaseAuthFlow, ConfigFlow, InteractionFlow, MDQFlow, TestFlow, TLSFEDFlow
from auth_server.flows import BaseAuthFlow, CAFlow, ConfigFlow, InteractionFlow, MDQFlow, TestFlow, TLSFEDFlow
from auth_server.log import init_logging
from auth_server.middleware import JOSEMiddleware
from auth_server.routers.interaction import interaction_router
Expand All @@ -28,10 +28,11 @@ def __init__(self):

# Load flows
self.builtin_flow: Dict[FlowName, Type[BaseAuthFlow]] = {
FlowName.TESTFLOW: TestFlow,
FlowName.INTERACTIONFLOW: InteractionFlow,
FlowName.CAFLOW: CAFlow,
FlowName.CONFIGFLOW: ConfigFlow,
FlowName.INTERACTIONFLOW: InteractionFlow,
FlowName.MDQFLOW: MDQFlow,
FlowName.TESTFLOW: TestFlow,
FlowName.TLSFEDFLOW: TLSFEDFlow,
}
self.auth_flows = self.load_flows(config=config)
Expand Down
237 changes: 237 additions & 0 deletions src/auth_server/cert_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
# -*- coding: utf-8 -*-
__author__ = "lundberg"

from base64 import b64encode
from datetime import datetime
from enum import Enum
from functools import lru_cache
from pathlib import Path
from typing import Optional

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.bindings._rust import ObjectIdentifier
from cryptography.hazmat.primitives._serialization import Encoding
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.x509 import Certificate, ExtensionNotFound, load_der_x509_certificate, load_pem_x509_certificate
from loguru import logger
from pki_tools import Certificate as PKIToolCertificate
from pki_tools import Chain
from pki_tools import Error as PKIToolsError
from pki_tools import is_revoked

from auth_server.config import ConfigurationError, load_config

OID_ORGANIZATION_NAME = ObjectIdentifier("2.5.4.10")
OID_COMMON_NAME = ObjectIdentifier("2.5.4.3")
OID_SERIAL_NUMBER = ObjectIdentifier("2.5.4.5")
OID_ENHANCED_KEY_USAGE_CLIENT_AUTHENTICATION = ObjectIdentifier("1.3.6.1.5.5.7.3.2")


class SupportedOrgIdCA(str, Enum):
EFOS = "Swedish Social Insurance Agency"
EXPITRUST = "Expisoft AB"
SITHS = "Inera AB"


def cert_within_validity_period(cert: Certificate) -> bool:
"""
check if certificate is within the validity period
"""
cert_fingerprint = rfc8705_fingerprint(cert)
now = datetime.utcnow()
if now < cert.not_valid_before:
logger.error(f"Certificate {cert_fingerprint} not valid before {cert.not_valid_before}")
return False
if now > cert.not_valid_after:
logger.error(f"Certificate {cert_fingerprint} not valid after {cert.not_valid_after}")
return False
return True


def cert_signed_by_ca(cert: Certificate) -> Optional[str]:
"""
check if the cert is signed by any on our loaded CA certs
"""
cert_fingerprint = rfc8705_fingerprint(cert)
for ca_name, ca_cert in load_ca_certs().items():
try:
cert.verify_directly_issued_by(ca_cert)
logger.debug(f"Certificate {cert_fingerprint} signed by CA cert {ca_name}")
return ca_name
except (ValueError, TypeError, InvalidSignature):
continue

logger.error(f"Certificate {cert_fingerprint} did NOT match any loaded CA cert")
return None


async def is_cert_revoked(cert: Certificate, ca_name: str) -> bool:
"""
check if cert is revoked
"""
ca_cert = load_ca_certs().get(ca_name)
if ca_cert is None:
raise ConfigurationError(f"CA cert {ca_name} not found")
try:
return is_revoked(
cert=PKIToolCertificate.from_cryptography(cert=cert), chain=Chain.from_cryptography([ca_cert])
)
except PKIToolsError as e:
logger.error(f"Certificate {rfc8705_fingerprint(cert)} failed revoke check: {e}")
return True


def get_org_id_from_cert(cert: Certificate, ca_name: str) -> Optional[str]:
ca_cert = load_ca_certs().get(ca_name)
if not ca_cert:
raise ConfigurationError(f"CA cert {ca_name} not found")
try:
ca_org_name = ca_cert.issuer.get_attributes_for_oid(OID_ORGANIZATION_NAME)[0].value
except IndexError:
logger.error(f"CA certificate {ca_name} has no org name")
return None
try:
supported_ca = SupportedOrgIdCA(ca_org_name)
except ValueError:
logger.info(f"CA {ca_name} is not supported for org id extraction")
return None

if supported_ca is SupportedOrgIdCA.EXPITRUST:
return get_org_id_expitrust(cert=cert)
elif supported_ca is SupportedOrgIdCA.EFOS:
return get_org_id_efos(cert=cert)
elif supported_ca is SupportedOrgIdCA.SITHS:
return get_org_id_siths(cert=cert)
else:
logger.info(f"CA {ca_name} / {ca_org_name} is not implemented for org id extraction")
return None


def get_org_id_expitrust(cert: Certificate) -> Optional[str]:
"""
The org number is just the serial number of the certificate.
"""
cert_fingerprint = rfc8705_fingerprint(cert)
try:
ret = cert.subject.get_attributes_for_oid(OID_SERIAL_NUMBER)[0].value
if isinstance(ret, bytes):
ret = ret.decode("utf-8")
return ret
except IndexError:
logger.error(f"certificate {cert_fingerprint} has no subject serial number")
return None


def get_org_id_siths(cert: Certificate) -> Optional[str]:
"""
The org number is the first part of the serial number of the certificate with a prefix of SE.
ex. SE5565594230-AAAA -> 5565594230
"""
cert_fingerprint = rfc8705_fingerprint(cert)

# Check that the certificate has enhancedKeyUsage clientAuthentication
try:
cert.extensions.get_extension_for_oid(OID_ENHANCED_KEY_USAGE_CLIENT_AUTHENTICATION)
except ExtensionNotFound:
logger.error(f"certificate {cert_fingerprint} has no enhancedKeyUsage clientAuthentication")
return None

# Check that the certificate has a subject serial number and parse the org id
try:
serial_number = cert.subject.get_attributes_for_oid(OID_SERIAL_NUMBER)[0].value
if isinstance(serial_number, bytes):
serial_number = serial_number.decode("utf-8")
org_id, _ = serial_number.split("-")
return org_id.removeprefix("SE")
except IndexError:
logger.error(f"certificate {cert_fingerprint} has no subject serial number")
return None


def get_org_id_efos(cert: Certificate):
"""
The org number is the first part of the serial number of the certificate with a prefix of EFOS16.
ex. EFOS165565594230-012345 -> 5565594230
"""
cert_fingerprint = rfc8705_fingerprint(cert)
# Check that the certificate has a subject serial number and parse the org id
try:
serial_number = cert.subject.get_attributes_for_oid(OID_SERIAL_NUMBER)[0].value
if isinstance(serial_number, bytes):
serial_number = serial_number.decode("utf-8")
org_id, _ = serial_number.split("-")
return org_id.removeprefix("EFOS16")
except IndexError:
logger.error(f"certificate {cert_fingerprint} has no subject serial number")
return None


def get_subject_cn(cert: Certificate) -> Optional[str]:
cert_fingerprint = rfc8705_fingerprint(cert)
try:
ret = cert.subject.get_attributes_for_oid(OID_COMMON_NAME)[0].value
if isinstance(ret, bytes):
ret = ret.decode("utf-8")
return ret
except IndexError:
logger.error(f"certificate {cert_fingerprint} has no subject common name")
return None


def get_issuer_cn(ca_name: str) -> Optional[str]:
ca_cert = load_ca_certs().get(ca_name)
if ca_cert is None:
logger.error(f"CA {ca_name} not found")
return None
try:
ret = ca_cert.subject.get_attributes_for_oid(OID_COMMON_NAME)[0].value
if isinstance(ret, bytes):
ret = ret.decode("utf-8")
return ret
except IndexError:
logger.error(f"CA {ca_name} has no subject common name")
return None


@lru_cache()
def load_ca_certs() -> dict[str, Certificate]:
config = load_config()
if config.ca_certs_path is None:
raise ConfigurationError("no CA certs path specified in config")
certs = {}
path = Path(config.ca_certs_path)
for crt in path.glob("**/*.c*"): # match .crt and .cer files
if crt.is_dir():
continue
try:
with open(crt, "rb") as f:
content = f.read()
try:
cert = load_pem_x509_certificate(content)
except ValueError:
cert = load_der_x509_certificate(content)
if cert_within_validity_period(cert):
certs[cert.subject.rfc4514_string()] = cert
except (IOError, ValueError) as e:
logger.error(f"Failed to load CA cert {crt}: {e}")
logger.info(f"Loaded {len(certs)} CA certs")
logger.debug(f"Certs loaded: {certs.keys()}")
return certs


def load_pem_from_str(cert: str) -> Certificate:
if not cert.startswith("-----BEGIN CERTIFICATE-----"):
cert = f"-----BEGIN CERTIFICATE-----\n{cert}\n-----END CERTIFICATE-----"
return load_pem_x509_certificate(cert.encode())


def serialize_certificate(cert: Certificate, encoding: Encoding = Encoding.PEM) -> str:
public_bytes = cert.public_bytes(encoding=encoding)
if encoding == Encoding.DER:
return b64encode(public_bytes).decode("ascii")
else:
return public_bytes.decode("ascii")


def rfc8705_fingerprint(cert: Certificate):
return b64encode(cert.fingerprint(algorithm=SHA256())).decode("ascii")
18 changes: 10 additions & 8 deletions src/auth_server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ class Environment(str, Enum):


class FlowName(str, Enum):
CAFLOW = "CAFlow"
CONFIGFLOW = "ConfigFlow"
MDQFLOW = "MDQFlow"
INTERACTIONFLOW = "InteractionFlow"
MDQFLOW = "MDQFlow"
TESTFLOW = "TestFlow"
TLSFEDFLOW = "TLSFEDFlow"

Expand Down Expand Up @@ -65,21 +66,22 @@ class AuthServerConfig(BaseSettings):
auth_flows: List[str] = Field(default_factory=list)
mdq_server: Optional[str] = Field(default=None)
tls_fed_metadata: List[TLSFEDMetadata] = Field(default_factory=list)
tls_fed_metadata_max_age: timedelta = Field(default="PT1H")
keystore_path: Path = Field(default="keystore.jwks")
tls_fed_metadata_max_age: timedelta = Field(default=timedelta(hours=1))
keystore_path: Path = Field(default=Path("keystore.jwks"))
signing_key_id: str = Field(default="default")
auth_token_issuer: str
auth_token_audience: Optional[str] = Field(default=None)
auth_token_expires_in: timedelta = Field(default="PT10H")
proof_jws_max_age: timedelta = Field(default="PT5M")
auth_token_expires_in: timedelta = Field(default=timedelta(hours=10))
proof_jws_max_age: timedelta = Field(default=timedelta(minutes=5))
client_keys: Dict[str, ClientKey] = Field(default_factory=dict)
mongo_uri: Optional[str] = None
transaction_state_expires_in: timedelta = Field(default="PT10M")
transaction_state_expires_in: timedelta = Field(default=timedelta(minutes=10))
pysaml2_config_path: Optional[Path] = Field(default=None)
pysaml2_config_name: str = "SAML_CONFIG"
saml2_discovery_service_url: Optional[AnyUrl] = None
saml2_single_idp: Optional[str] = None
ca_certs_path: Optional[Path] = None # all files ending with .crt will be loaded recursively. PEM and DER supported
ca_certs_path: Optional[Path] = None # all files ending with .c* will be loaded recursively. PEM and DER supported
ca_certs_mandatory_org_id: bool = False # fail grant requests where no org id is found in the certificate

@field_validator("application_root")
@classmethod
Expand Down Expand Up @@ -112,7 +114,7 @@ def load_config() -> AuthServerConfig:
config = AuthServerConfig.parse_obj(data)
else:
# config will be instantiated with env vars if there is no config file
config = AuthServerConfig() # type: ignore[call-arg]
config = AuthServerConfig()
# Save config to a file in /dev/shm for introspection
fd_int = os.open(f"/dev/shm/{config.app_name}_config.yaml", os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with open(fd_int, "w") as fd:
Expand Down
8 changes: 8 additions & 0 deletions src/auth_server/db/transaction_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class AuthSource(str, Enum):
CONFIG = "config"
MDQ = "mdq"
TLSFED = "tlsfed"
CA = "ca"
TEST = "test"


Expand Down Expand Up @@ -96,6 +97,13 @@ class TLSFEDState(TransactionState):
entity: Optional[MetadataEntity] = None


class CAState(TransactionState):
auth_source: AuthSource = AuthSource.CA
issuer_common_name: Optional[str] = None
client_common_name: Optional[str] = None
organization_id: Optional[str] = None


class TransactionStateDB(BaseDB):
def __init__(self, db_client: AsyncIOMotorClient):
super().__init__(db_client=db_client, db_name="auth_server", collection="transaction_states")
Expand Down
Loading

0 comments on commit 0d70acd

Please sign in to comment.