Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ca proof flow #26

Merged
merged 8 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
506 changes: 253 additions & 253 deletions dev_requirements.txt

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[mypy]
plugins = pydantic.mypy
432 changes: 216 additions & 216 deletions requirements.txt

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions src/auth_server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from auth_server.config import AuthServerConfig, ConfigurationError, FlowName, load_config
from auth_server.context import ContextRequestRoute
from auth_server.flows import BaseAuthFlow, ConfigFlow, InteractionFlow, MDQFlow, TestFlow, TLSFEDFlow
from auth_server.flows import BaseAuthFlow, CAFlow, ConfigFlow, InteractionFlow, MDQFlow, TestFlow, TLSFEDFlow
from auth_server.log import init_logging
from auth_server.middleware import JOSEMiddleware
from auth_server.routers.interaction import interaction_router
Expand All @@ -28,10 +28,11 @@ def __init__(self):

# Load flows
self.builtin_flow: Dict[FlowName, Type[BaseAuthFlow]] = {
FlowName.TESTFLOW: TestFlow,
FlowName.INTERACTIONFLOW: InteractionFlow,
FlowName.CAFLOW: CAFlow,
FlowName.CONFIGFLOW: ConfigFlow,
FlowName.INTERACTIONFLOW: InteractionFlow,
FlowName.MDQFLOW: MDQFlow,
FlowName.TESTFLOW: TestFlow,
FlowName.TLSFEDFLOW: TLSFEDFlow,
}
self.auth_flows = self.load_flows(config=config)
Expand Down
237 changes: 237 additions & 0 deletions src/auth_server/cert_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
# -*- coding: utf-8 -*-
__author__ = "lundberg"

from base64 import b64encode
from datetime import datetime
from enum import Enum
from functools import lru_cache
from pathlib import Path
from typing import Optional

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.bindings._rust import ObjectIdentifier
from cryptography.hazmat.primitives._serialization import Encoding
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.x509 import Certificate, ExtensionNotFound, load_der_x509_certificate, load_pem_x509_certificate
from loguru import logger
from pki_tools import Certificate as PKIToolCertificate
from pki_tools import Chain
from pki_tools import Error as PKIToolsError
from pki_tools import is_revoked

from auth_server.config import ConfigurationError, load_config

OID_ORGANIZATION_NAME = ObjectIdentifier("2.5.4.10")
OID_COMMON_NAME = ObjectIdentifier("2.5.4.3")
OID_SERIAL_NUMBER = ObjectIdentifier("2.5.4.5")
OID_ENHANCED_KEY_USAGE_CLIENT_AUTHENTICATION = ObjectIdentifier("1.3.6.1.5.5.7.3.2")


class SupportedOrgIdCA(str, Enum):
EFOS = "Swedish Social Insurance Agency"
EXPITRUST = "Expisoft AB"
SITHS = "Inera AB"


def cert_within_validity_period(cert: Certificate) -> bool:
"""
check if certificate is within the validity period
"""
cert_fingerprint = rfc8705_fingerprint(cert)
now = datetime.utcnow()
if now < cert.not_valid_before:
logger.error(f"Certificate {cert_fingerprint} not valid before {cert.not_valid_before}")
return False
if now > cert.not_valid_after:
logger.error(f"Certificate {cert_fingerprint} not valid after {cert.not_valid_after}")
return False
return True


def cert_signed_by_ca(cert: Certificate) -> Optional[str]:
"""
check if the cert is signed by any on our loaded CA certs
"""
cert_fingerprint = rfc8705_fingerprint(cert)
for ca_name, ca_cert in load_ca_certs().items():
try:
cert.verify_directly_issued_by(ca_cert)
logger.debug(f"Certificate {cert_fingerprint} signed by CA cert {ca_name}")
return ca_name
except (ValueError, TypeError, InvalidSignature):
continue

logger.error(f"Certificate {cert_fingerprint} did NOT match any loaded CA cert")
return None


async def is_cert_revoked(cert: Certificate, ca_name: str) -> bool:
"""
check if cert is revoked
"""
ca_cert = load_ca_certs().get(ca_name)
if ca_cert is None:
raise ConfigurationError(f"CA cert {ca_name} not found")
try:
return is_revoked(
cert=PKIToolCertificate.from_cryptography(cert=cert), chain=Chain.from_cryptography([ca_cert])
)
except PKIToolsError as e:
logger.error(f"Certificate {rfc8705_fingerprint(cert)} failed revoke check: {e}")
return True


def get_org_id_from_cert(cert: Certificate, ca_name: str) -> Optional[str]:
ca_cert = load_ca_certs().get(ca_name)
if not ca_cert:
raise ConfigurationError(f"CA cert {ca_name} not found")
try:
ca_org_name = ca_cert.issuer.get_attributes_for_oid(OID_ORGANIZATION_NAME)[0].value
except IndexError:
logger.error(f"CA certificate {ca_name} has no org name")
return None
try:
supported_ca = SupportedOrgIdCA(ca_org_name)
except ValueError:
logger.info(f"CA {ca_name} is not supported for org id extraction")
return None

if supported_ca is SupportedOrgIdCA.EXPITRUST:
return get_org_id_expitrust(cert=cert)
elif supported_ca is SupportedOrgIdCA.EFOS:
return get_org_id_efos(cert=cert)
elif supported_ca is SupportedOrgIdCA.SITHS:
return get_org_id_siths(cert=cert)
else:
logger.info(f"CA {ca_name} / {ca_org_name} is not implemented for org id extraction")
return None


def get_org_id_expitrust(cert: Certificate) -> Optional[str]:
"""
The org number is just the serial number of the certificate.
"""
cert_fingerprint = rfc8705_fingerprint(cert)
try:
ret = cert.subject.get_attributes_for_oid(OID_SERIAL_NUMBER)[0].value
if isinstance(ret, bytes):
ret = ret.decode("utf-8")
return ret
except IndexError:
logger.error(f"certificate {cert_fingerprint} has no subject serial number")
return None


def get_org_id_siths(cert: Certificate) -> Optional[str]:
"""
The org number is the first part of the serial number of the certificate with a prefix of SE.
ex. SE5565594230-AAAA -> 5565594230
"""
cert_fingerprint = rfc8705_fingerprint(cert)

# Check that the certificate has enhancedKeyUsage clientAuthentication
try:
cert.extensions.get_extension_for_oid(OID_ENHANCED_KEY_USAGE_CLIENT_AUTHENTICATION)
except ExtensionNotFound:
logger.error(f"certificate {cert_fingerprint} has no enhancedKeyUsage clientAuthentication")
return None

# Check that the certificate has a subject serial number and parse the org id
try:
serial_number = cert.subject.get_attributes_for_oid(OID_SERIAL_NUMBER)[0].value
if isinstance(serial_number, bytes):
serial_number = serial_number.decode("utf-8")
org_id, _ = serial_number.split("-")
return org_id.removeprefix("SE")
except IndexError:
logger.error(f"certificate {cert_fingerprint} has no subject serial number")
return None


def get_org_id_efos(cert: Certificate):
"""
The org number is the first part of the serial number of the certificate with a prefix of EFOS16.
ex. EFOS165565594230-012345 -> 5565594230
"""
cert_fingerprint = rfc8705_fingerprint(cert)
# Check that the certificate has a subject serial number and parse the org id
try:
serial_number = cert.subject.get_attributes_for_oid(OID_SERIAL_NUMBER)[0].value
if isinstance(serial_number, bytes):
serial_number = serial_number.decode("utf-8")
org_id, _ = serial_number.split("-")
return org_id.removeprefix("EFOS16")
except IndexError:
logger.error(f"certificate {cert_fingerprint} has no subject serial number")
return None


def get_subject_cn(cert: Certificate) -> Optional[str]:
cert_fingerprint = rfc8705_fingerprint(cert)
try:
ret = cert.subject.get_attributes_for_oid(OID_COMMON_NAME)[0].value
if isinstance(ret, bytes):
ret = ret.decode("utf-8")
return ret
except IndexError:
logger.error(f"certificate {cert_fingerprint} has no subject common name")
return None


def get_issuer_cn(ca_name: str) -> Optional[str]:
ca_cert = load_ca_certs().get(ca_name)
if ca_cert is None:
logger.error(f"CA {ca_name} not found")
return None
try:
ret = ca_cert.subject.get_attributes_for_oid(OID_COMMON_NAME)[0].value
if isinstance(ret, bytes):
ret = ret.decode("utf-8")
return ret
except IndexError:
logger.error(f"CA {ca_name} has no subject common name")
return None


@lru_cache()
def load_ca_certs() -> dict[str, Certificate]:
config = load_config()
if config.ca_certs_path is None:
raise ConfigurationError("no CA certs path specified in config")
certs = {}
path = Path(config.ca_certs_path)
for crt in path.glob("**/*.c*"): # match .crt and .cer files
if crt.is_dir():
continue
try:
with open(crt, "rb") as f:
content = f.read()
try:
cert = load_pem_x509_certificate(content)
except ValueError:
cert = load_der_x509_certificate(content)
if cert_within_validity_period(cert):
certs[cert.subject.rfc4514_string()] = cert
except (IOError, ValueError) as e:
logger.error(f"Failed to load CA cert {crt}: {e}")
logger.info(f"Loaded {len(certs)} CA certs")
logger.debug(f"Certs loaded: {certs.keys()}")
return certs


def load_pem_from_str(cert: str) -> Certificate:
if not cert.startswith("-----BEGIN CERTIFICATE-----"):
cert = f"-----BEGIN CERTIFICATE-----\n{cert}\n-----END CERTIFICATE-----"
return load_pem_x509_certificate(cert.encode())


def serialize_certificate(cert: Certificate, encoding: Encoding = Encoding.PEM) -> str:
public_bytes = cert.public_bytes(encoding=encoding)
if encoding == Encoding.DER:
return b64encode(public_bytes).decode("ascii")
else:
return public_bytes.decode("ascii")


def rfc8705_fingerprint(cert: Certificate):
return b64encode(cert.fingerprint(algorithm=SHA256())).decode("ascii")
18 changes: 10 additions & 8 deletions src/auth_server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ class Environment(str, Enum):


class FlowName(str, Enum):
CAFLOW = "CAFlow"
CONFIGFLOW = "ConfigFlow"
MDQFLOW = "MDQFlow"
INTERACTIONFLOW = "InteractionFlow"
MDQFLOW = "MDQFlow"
TESTFLOW = "TestFlow"
TLSFEDFLOW = "TLSFEDFlow"

Expand Down Expand Up @@ -65,21 +66,22 @@ class AuthServerConfig(BaseSettings):
auth_flows: List[str] = Field(default_factory=list)
mdq_server: Optional[str] = Field(default=None)
tls_fed_metadata: List[TLSFEDMetadata] = Field(default_factory=list)
tls_fed_metadata_max_age: timedelta = Field(default="PT1H")
keystore_path: Path = Field(default="keystore.jwks")
tls_fed_metadata_max_age: timedelta = Field(default=timedelta(hours=1))
keystore_path: Path = Field(default=Path("keystore.jwks"))
signing_key_id: str = Field(default="default")
auth_token_issuer: str
auth_token_audience: Optional[str] = Field(default=None)
auth_token_expires_in: timedelta = Field(default="PT10H")
proof_jws_max_age: timedelta = Field(default="PT5M")
auth_token_expires_in: timedelta = Field(default=timedelta(hours=10))
proof_jws_max_age: timedelta = Field(default=timedelta(minutes=5))
client_keys: Dict[str, ClientKey] = Field(default_factory=dict)
mongo_uri: Optional[str] = None
transaction_state_expires_in: timedelta = Field(default="PT10M")
transaction_state_expires_in: timedelta = Field(default=timedelta(minutes=10))
pysaml2_config_path: Optional[Path] = Field(default=None)
pysaml2_config_name: str = "SAML_CONFIG"
saml2_discovery_service_url: Optional[AnyUrl] = None
saml2_single_idp: Optional[str] = None
ca_certs_path: Optional[Path] = None # all files ending with .crt will be loaded recursively. PEM and DER supported
ca_certs_path: Optional[Path] = None # all files ending with .c* will be loaded recursively. PEM and DER supported
ca_certs_mandatory_org_id: bool = False # fail grant requests where no org id is found in the certificate

@field_validator("application_root")
@classmethod
Expand Down Expand Up @@ -112,7 +114,7 @@ def load_config() -> AuthServerConfig:
config = AuthServerConfig.parse_obj(data)
else:
# config will be instantiated with env vars if there is no config file
config = AuthServerConfig() # type: ignore[call-arg]
config = AuthServerConfig()
# Save config to a file in /dev/shm for introspection
fd_int = os.open(f"/dev/shm/{config.app_name}_config.yaml", os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with open(fd_int, "w") as fd:
Expand Down
8 changes: 8 additions & 0 deletions src/auth_server/db/transaction_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class AuthSource(str, Enum):
CONFIG = "config"
MDQ = "mdq"
TLSFED = "tlsfed"
CA = "ca"
TEST = "test"


Expand Down Expand Up @@ -96,6 +97,13 @@ class TLSFEDState(TransactionState):
entity: Optional[MetadataEntity] = None


class CAState(TransactionState):
auth_source: AuthSource = AuthSource.CA
issuer_common_name: Optional[str] = None
client_common_name: Optional[str] = None
organization_id: Optional[str] = None


class TransactionStateDB(BaseDB):
def __init__(self, db_client: AsyncIOMotorClient):
super().__init__(db_client=db_client, db_name="auth_server", collection="transaction_states")
Expand Down
Loading
Loading