From 8dbe5785e397c9a75a7b7aafa6c84ba393b5f812 Mon Sep 17 00:00:00 2001 From: Noctua Date: Tue, 19 Mar 2024 21:14:48 +0100 Subject: [PATCH] chore: update charm libraries (#290) Co-authored-by: Github Actions --- .../v0/certificate_transfer.py | 18 +- lib/charms/hydra/v0/oauth.py | 5 +- .../v0/kubernetes_compute_resources_patch.py | 8 +- lib/charms/tempo_k8s/v1/charm_tracing.py | 16 +- lib/charms/tempo_k8s/v1/tracing.py | 16 +- .../v2/tls_certificates.py | 363 +++++++++++------- 6 files changed, 256 insertions(+), 170 deletions(-) diff --git a/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py b/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py index 44ddfdae..b07b8355 100644 --- a/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py +++ b/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py @@ -21,7 +21,9 @@ from ops.charm import CharmBase, RelationJoinedEvent from ops.main import main -from lib.charms.certificate_transfer_interface.v0.certificate_transfer import CertificateTransferProvides # noqa: E501 W505 +from lib.charms.certificate_transfer_interface.v0.certificate_transfer import( + CertificateTransferProvides, +) class DummyCertificateTransferProviderCharm(CharmBase): @@ -36,7 +38,9 @@ def _on_certificates_relation_joined(self, event: RelationJoinedEvent): certificate = "my certificate" ca = "my CA certificate" chain = ["certificate 1", "certificate 2"] - self.certificate_transfer.set_certificate(certificate=certificate, ca=ca, chain=chain, relation_id=event.relation.id) + self.certificate_transfer.set_certificate( + certificate=certificate, ca=ca, chain=chain, relation_id=event.relation.id + ) if __name__ == "__main__": @@ -95,7 +99,7 @@ def _on_certificate_removed(self, event: CertificateRemovedEvent): import json import logging -from typing import List +from typing import List, Mapping from jsonschema import exceptions, validate # type: ignore[import-untyped] from ops.charm import CharmBase, CharmEvents, RelationBrokenEvent, RelationChangedEvent @@ -109,7 +113,7 @@ def _on_certificate_removed(self, event: CertificateRemovedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 5 +LIBPATCH = 7 PYDEPS = ["jsonschema"] @@ -210,7 +214,7 @@ def restore(self, snapshot: dict): self.relation_id = snapshot["relation_id"] -def _load_relation_data(raw_relation_data: dict) -> dict: +def _load_relation_data(raw_relation_data: Mapping[str, str]) -> dict: """Load relation data from the relation data bag. Args: @@ -313,7 +317,7 @@ def remove_certificate(self, relation_id: int) -> None: class CertificateTransferRequires(Object): """TLS certificates requirer class to be instantiated by TLS certificates requirers.""" - on = CertificateTransferRequirerCharmEvents() + on = CertificateTransferRequirerCharmEvents() # type: ignore def __init__( self, @@ -379,7 +383,7 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: ) def _on_relation_broken(self, event: RelationBrokenEvent) -> None: - """Handler triggered on relation broken event. + """Handle relation broken event. Args: event: Juju event diff --git a/lib/charms/hydra/v0/oauth.py b/lib/charms/hydra/v0/oauth.py index 6d8ed1ef..a12137c7 100644 --- a/lib/charms/hydra/v0/oauth.py +++ b/lib/charms/hydra/v0/oauth.py @@ -74,7 +74,7 @@ def _set_client_config(self): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 5 +LIBPATCH = 6 logger = logging.getLogger(__name__) @@ -395,9 +395,6 @@ def _on_relation_broken_event(self, event: RelationBrokenEvent) -> None: self.on.oauth_info_removed.emit() def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: - if not self.model.unit.is_leader(): - return - data = event.relation.data[event.app] if not data: logger.info("No relation data available.") diff --git a/lib/charms/observability_libs/v0/kubernetes_compute_resources_patch.py b/lib/charms/observability_libs/v0/kubernetes_compute_resources_patch.py index a6ad4dfb..2ab8a22c 100644 --- a/lib/charms/observability_libs/v0/kubernetes_compute_resources_patch.py +++ b/lib/charms/observability_libs/v0/kubernetes_compute_resources_patch.py @@ -133,7 +133,7 @@ def setUp(self, *unused): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 6 +LIBPATCH = 7 _Decimal = Union[Decimal, float, str, int] # types that are potentially convertible to Decimal @@ -364,7 +364,7 @@ def is_patched(self, resource_reqs: ResourceRequirements) -> bool: Returns: bool: A boolean indicating if the service patch has been applied. """ - return equals_canonically(self.get_templated(), resource_reqs) + return equals_canonically(self.get_templated(), resource_reqs) # pyright: ignore def get_templated(self) -> Optional[ResourceRequirements]: """Returns the resource limits specified in the StatefulSet template.""" @@ -397,8 +397,8 @@ def is_ready(self, pod_name, resource_reqs: ResourceRequirements): self.get_templated(), self.get_actual(pod_name), ) - return self.is_patched(resource_reqs) and equals_canonically( - resource_reqs, self.get_actual(pod_name) + return self.is_patched(resource_reqs) and equals_canonically( # pyright: ignore + resource_reqs, self.get_actual(pod_name) # pyright: ignore ) def apply(self, resource_reqs: ResourceRequirements) -> None: diff --git a/lib/charms/tempo_k8s/v1/charm_tracing.py b/lib/charms/tempo_k8s/v1/charm_tracing.py index 64ac0bd8..c146e6d3 100644 --- a/lib/charms/tempo_k8s/v1/charm_tracing.py +++ b/lib/charms/tempo_k8s/v1/charm_tracing.py @@ -146,7 +146,7 @@ def my_tracing_endpoint(self) -> Optional[str]: # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 1 +LIBPATCH = 2 PYDEPS = ["opentelemetry-exporter-otlp-proto-http>=1.21.0"] @@ -200,15 +200,12 @@ def _get_tracer() -> Optional[Tracer]: return tracer.get() except LookupError: try: - logger.debug("tracer was not found in context variable, looking up in default context") ctx: Context = copy_context() if context_tracer := _get_tracer_from_context(ctx): return context_tracer.get() else: - logger.debug("Couldn't find context var for tracer: span will be skipped") return None except LookupError as err: - logger.debug(f"Couldn't find tracer: span will be skipped, err: {err}") return None @@ -219,7 +216,6 @@ def _span(name: str) -> Generator[Optional[Span], Any, Any]: with tracer.start_as_current_span(name) as span: yield cast(Span, span) else: - logger.debug("tracer not found") yield None @@ -243,9 +239,9 @@ def _get_tracing_endpoint(tracing_endpoint_getter, self, charm): tracing_endpoint = tracing_endpoint_getter(self) if tracing_endpoint is None: - logger.warning( - f"{charm}.{getattr(tracing_endpoint_getter, '__qualname__', str(tracing_endpoint_getter))} " - f"returned None; continuing with tracing DISABLED." + logger.debug( + "Charm tracing is disabled. Tracing endpoint is not defined - " + "tracing is not available or relation is not set." ) return elif not isinstance(tracing_endpoint, str): @@ -266,7 +262,7 @@ def _get_server_cert(server_cert_getter, self, charm): if server_cert is None: logger.warning( - f"{charm}.{server_cert_getter} returned None; continuing with INSECURE connection." + f"{charm}.{server_cert_getter} returned None; sending traces over INSECURE connection." ) return elif not Path(server_cert).is_absolute(): @@ -274,7 +270,6 @@ def _get_server_cert(server_cert_getter, self, charm): f"{charm}.{server_cert_getter} should return a valid tls cert absolute path (string | Path)); " f"got {server_cert} instead." ) - logger.debug("Certificate successfully retrieved.") # todo: some more validation? return server_cert @@ -300,7 +295,6 @@ def wrap_init(self: CharmBase, framework: Framework, *args, **kwargs): original_event_context = framework._event_context - logging.debug("Initializing opentelemetry tracer...") _service_name = service_name or self.app.name resource = Resource.create( diff --git a/lib/charms/tempo_k8s/v1/tracing.py b/lib/charms/tempo_k8s/v1/tracing.py index 79ddebf7..3ffcc044 100644 --- a/lib/charms/tempo_k8s/v1/tracing.py +++ b/lib/charms/tempo_k8s/v1/tracing.py @@ -1,4 +1,4 @@ -# Copyright 2022 Pietro Pasotti +# Copyright 2024 Canonical Ltd. # See LICENSE file for licensing details. """## Overview. @@ -93,7 +93,7 @@ def __init__(self, *args): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 2 +LIBPATCH = 3 PYDEPS = ["pydantic>=2"] @@ -151,8 +151,12 @@ def load(cls, databag: MutableMapping): try: return cls.parse_raw(json.dumps(data)) # type: ignore except pydantic.ValidationError as e: - msg = f"failed to validate databag: {databag}" - logger.error(msg, exc_info=True) + if not data: + # databag is empty; this is usually expected + raise DataValidationError("empty databag") + + msg = f"failed to validate databag contents: {data!r} as {cls}" + logger.debug(msg, exc_info=True) raise DataValidationError(msg) from e def dump(self, databag: Optional[MutableMapping] = None, clear: bool = True): @@ -194,8 +198,8 @@ class TracingProviderAppData(DatabagModel): # noqa: D101 class _AutoSnapshotEvent(RelationEvent): - __args__ = () # type: Tuple[str, ...] - __optional_kwargs__ = {} # type: Dict[str, Any] + __args__: Tuple[str, ...] = () + __optional_kwargs__: Dict[str, Any] = {} @classmethod def __attrs__(cls): diff --git a/lib/charms/tls_certificates_interface/v2/tls_certificates.py b/lib/charms/tls_certificates_interface/v2/tls_certificates.py index b8855bea..9f67833b 100644 --- a/lib/charms/tls_certificates_interface/v2/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v2/tls_certificates.py @@ -277,7 +277,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven import logging import uuid from contextlib import suppress -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from ipaddress import IPv4Address from typing import Any, Dict, List, Literal, Optional, Union @@ -286,8 +286,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.serialization import pkcs12 -from cryptography.x509.extensions import Extension, ExtensionNotFound -from jsonschema import exceptions, validate # type: ignore[import-untyped] +from jsonschema import exceptions, validate from ops.charm import ( CharmBase, CharmEvents, @@ -308,13 +307,13 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 21 +LIBPATCH = 28 PYDEPS = ["cryptography", "jsonschema"] REQUIRER_JSON_SCHEMA = { "$schema": "http://json-schema.org/draft-04/schema#", - "$id": "https://canonical.github.io/charm-relation-interfaces/tls_certificates/v2/schemas/requirer.json", # noqa: E501 + "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/requirer.json", "type": "object", "title": "`tls_certificates` requirer root schema", "description": "The `tls_certificates` root schema comprises the entire requirer databag for this interface.", # noqa: E501 @@ -349,7 +348,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven PROVIDER_JSON_SCHEMA = { "$schema": "http://json-schema.org/draft-04/schema#", - "$id": "https://canonical.github.io/charm-relation-interfaces/tls_certificates/v2/schemas/provider.json", # noqa: E501 + "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/provider.json", "type": "object", "title": "`tls_certificates` provider root schema", "description": "The `tls_certificates` root schema comprises the entire provider databag for this interface.", # noqa: E501 @@ -441,7 +440,7 @@ def __init__( self.chain = chain def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { "certificate": self.certificate, "certificate_signing_request": self.certificate_signing_request, @@ -450,7 +449,7 @@ def snapshot(self) -> dict: } def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate = snapshot["certificate"] self.certificate_signing_request = snapshot["certificate_signing_request"] self.ca = snapshot["ca"] @@ -474,11 +473,11 @@ def __init__(self, handle, certificate: str, expiry: str): self.expiry = expiry def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return {"certificate": self.certificate, "expiry": self.expiry} def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate = snapshot["certificate"] self.expiry = snapshot["expiry"] @@ -503,7 +502,7 @@ def __init__( self.chain = chain def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { "reason": self.reason, "certificate_signing_request": self.certificate_signing_request, @@ -513,7 +512,7 @@ def snapshot(self) -> dict: } def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.reason = snapshot["reason"] self.certificate_signing_request = snapshot["certificate_signing_request"] self.certificate = snapshot["certificate"] @@ -528,11 +527,11 @@ def __init__(self, handle: Handle): super().__init__(handle) def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return {} def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" pass @@ -552,7 +551,7 @@ def __init__( self.is_ca = is_ca def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { "certificate_signing_request": self.certificate_signing_request, "relation_id": self.relation_id, @@ -560,7 +559,7 @@ def snapshot(self) -> dict: } def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate_signing_request = snapshot["certificate_signing_request"] self.relation_id = snapshot["relation_id"] self.is_ca = snapshot["is_ca"] @@ -584,7 +583,7 @@ def __init__( self.chain = chain def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { "certificate": self.certificate, "certificate_signing_request": self.certificate_signing_request, @@ -593,7 +592,7 @@ def snapshot(self) -> dict: } def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate = snapshot["certificate"] self.certificate_signing_request = snapshot["certificate_signing_request"] self.ca = snapshot["ca"] @@ -601,7 +600,7 @@ def restore(self, snapshot: dict): def _load_relation_data(relation_data_content: RelationDataContent) -> dict: - """Loads relation data from the relation data bag. + """Load relation data from the relation data bag. Json loads all data. @@ -611,7 +610,7 @@ def _load_relation_data(relation_data_content: RelationDataContent) -> dict: Returns: dict: Relation data in dict format. """ - certificate_data = dict() + certificate_data = {} try: for key in relation_data_content: try: @@ -623,6 +622,42 @@ def _load_relation_data(relation_data_content: RelationDataContent) -> dict: return certificate_data +def _get_closest_future_time( + expiry_notification_time: datetime, expiry_time: datetime +) -> datetime: + """Return expiry_notification_time if not in the past, otherwise return expiry_time. + + Args: + expiry_notification_time (datetime): Notification time of impending expiration + expiry_time (datetime): Expiration time + + Returns: + datetime: expiry_notification_time if not in the past, expiry_time otherwise + """ + return ( + expiry_notification_time + if datetime.now(timezone.utc) < expiry_notification_time + else expiry_time + ) + + +def _get_certificate_expiry_time(certificate: str) -> Optional[datetime]: + """Extract expiry time from a certificate string. + + Args: + certificate (str): x509 certificate as a string + + Returns: + Optional[datetime]: Expiry datetime or None + """ + try: + certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + return certificate_object.not_valid_after_utc + except ValueError: + logger.warning("Could not load certificate.") + return None + + def generate_ca( private_key: bytes, subject: str, @@ -630,11 +665,11 @@ def generate_ca( validity: int = 365, country: str = "US", ) -> bytes: - """Generates a CA Certificate. + """Generate a CA Certificate. Args: private_key (bytes): Private key - subject (str): Certificate subject + subject (str): Common Name that can be an IP or a Full Qualified Domain Name (FQDN). private_key_password (bytes): Private key password validity (int): Certificate validity time (in days) country (str): Certificate Issuing country @@ -645,7 +680,7 @@ def generate_ca( private_key_object = serialization.load_pem_private_key( private_key, password=private_key_password ) - subject = issuer = x509.Name( + subject_name = x509.Name( [ x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country), x509.NameAttribute(x509.NameOID.COMMON_NAME, subject), @@ -668,12 +703,12 @@ def generate_ca( ) cert = ( x509.CertificateBuilder() - .subject_name(subject) - .issuer_name(issuer) + .subject_name(subject_name) + .issuer_name(subject_name) .public_key(private_key_object.public_key()) # type: ignore[arg-type] .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.utcnow()) - .not_valid_after(datetime.utcnow() + timedelta(days=validity)) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity)) .add_extension(x509.SubjectKeyIdentifier(digest=subject_identifier), critical=False) .add_extension( x509.AuthorityKeyIdentifier( @@ -699,7 +734,7 @@ def get_certificate_extensions( alt_names: Optional[List[str]], is_ca: bool, ) -> List[x509.Extension]: - """Generates a list of certificate extensions from a CSR and other known information. + """Generate a list of certificate extensions from a CSR and other known information. Args: authority_key_identifier (bytes): Authority key identifier @@ -801,7 +836,7 @@ def generate_certificate( alt_names: Optional[List[str]] = None, is_ca: bool = False, ) -> bytes: - """Generates a TLS certificate based on a CSR. + """Generate a TLS certificate based on a CSR. Args: csr (bytes): CSR @@ -827,8 +862,8 @@ def generate_certificate( .issuer_name(issuer) .public_key(csr_object.public_key()) .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.utcnow()) - .not_valid_after(datetime.utcnow() + timedelta(days=validity)) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity)) ) extensions = get_certificate_extensions( authority_key_identifier=ca_pem.extensions.get_extension_for_class( @@ -857,7 +892,7 @@ def generate_pfx_package( package_password: str, private_key_password: Optional[bytes] = None, ) -> bytes: - """Generates a PFX package to contain the TLS certificate and private key. + """Generate a PFX package to contain the TLS certificate and private key. Args: certificate (bytes): TLS certificate @@ -888,7 +923,7 @@ def generate_private_key( key_size: int = 2048, public_exponent: int = 65537, ) -> bytes: - """Generates a private key. + """Generate a private key. Args: password (bytes): Password for decrypting the private key @@ -905,14 +940,16 @@ def generate_private_key( key_bytes = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.BestAvailableEncryption(password) - if password - else serialization.NoEncryption(), + encryption_algorithm=( + serialization.BestAvailableEncryption(password) + if password + else serialization.NoEncryption() + ), ) return key_bytes -def generate_csr( +def generate_csr( # noqa: C901 private_key: bytes, subject: str, add_unique_id_to_subject_name: bool = True, @@ -926,11 +963,11 @@ def generate_csr( sans_dns: Optional[List[str]] = None, additional_critical_extensions: Optional[List] = None, ) -> bytes: - """Generates a CSR using private key and subject. + """Generate a CSR using private key and subject. Args: private_key (bytes): Private key - subject (str): CSR Subject. + subject (str): CSR Common Name that can be an IP or a Full Qualified Domain Name (FQDN). add_unique_id_to_subject_name (bool): Whether a unique ID must be added to the CSR's subject name. Always leave to "True" when the CSR is used to request certificates using the tls-certificates relation. @@ -984,6 +1021,38 @@ def generate_csr( return signed_certificate.public_bytes(serialization.Encoding.PEM) +def csr_matches_certificate(csr: str, cert: str) -> bool: + """Check if a CSR matches a certificate. + + Args: + csr (str): Certificate Signing Request as a string + cert (str): Certificate as a string + Returns: + bool: True/False depending on whether the CSR matches the certificate. + """ + try: + csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) + cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) + + if csr_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) != cert_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ): + return False + if ( + csr_object.public_key().public_numbers().n # type: ignore[union-attr] + != cert_object.public_key().public_numbers().n # type: ignore[union-attr] + ): + return False + except ValueError: + logger.warning("Could not load certificate or CSR.") + return False + return True + + class CertificatesProviderCharmEvents(CharmEvents): """List of events that the TLS Certificates provider charm can leverage.""" @@ -1003,7 +1072,7 @@ class CertificatesRequirerCharmEvents(CharmEvents): class TLSCertificatesProvidesV2(Object): """TLS certificates provider class to be instantiated by TLS certificates providers.""" - on = CertificatesProviderCharmEvents() + on = CertificatesProviderCharmEvents() # type: ignore[reportAssignmentType] def __init__(self, charm: CharmBase, relationship_name: str): super().__init__(charm, relationship_name) @@ -1014,12 +1083,12 @@ def __init__(self, charm: CharmBase, relationship_name: str): self.relationship_name = relationship_name def _load_app_relation_data(self, relation: Relation) -> dict: - """Loads relation data from the application relation data bag. + """Load relation data from the application relation data bag. Json loads all data. Args: - relation_object: Relation data from the application databag + relation: Relation data from the application databag Returns: dict: Relation data in dict format. @@ -1037,7 +1106,7 @@ def _add_certificate( ca: str, chain: List[str], ) -> None: - """Adds certificate to relation data. + """Add certificate to relation data. Args: relation_id (int): Relation id @@ -1078,7 +1147,7 @@ def _remove_certificate( certificate: Optional[str] = None, certificate_signing_request: Optional[str] = None, ) -> None: - """Removes certificate from a given relation based on user provided certificate or csr. + """Remove certificate from a given relation based on user provided certificate or csr. Args: relation_id (int): Relation id @@ -1111,7 +1180,7 @@ def _remove_certificate( @staticmethod def _relation_data_is_valid(certificates_data: dict) -> bool: - """Uses JSON schema validator to validate relation data content. + """Use JSON schema validator to validate relation data content. Args: certificates_data (dict): Certificate data dictionary as retrieved from relation data. @@ -1126,7 +1195,7 @@ def _relation_data_is_valid(certificates_data: dict) -> bool: return False def revoke_all_certificates(self) -> None: - """Revokes all certificates of this provider. + """Revoke all certificates of this provider. This method is meant to be used when the Root CA has changed. """ @@ -1145,7 +1214,7 @@ def set_relation_certificate( chain: List[str], relation_id: int, ) -> None: - """Adds certificates to relation data. + """Add certificates to relation data. Args: certificate (str): Certificate @@ -1177,7 +1246,7 @@ def set_relation_certificate( ) def remove_certificate(self, certificate: str) -> None: - """Removes a given certificate from relation data. + """Remove a given certificate from relation data. Args: certificate (str): TLS Certificate @@ -1194,7 +1263,7 @@ def remove_certificate(self, certificate: str) -> None: def get_issued_certificates( self, relation_id: Optional[int] = None ) -> Dict[str, List[Dict[str, str]]]: - """Returns a dictionary of issued certificates. + """Return a dictionary of issued certificates. It returns certificates from all relations if relation_id is not specified. Certificates are returned per application name and CSR. @@ -1229,7 +1298,7 @@ def get_issued_certificates( return certificates def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Handler triggered on relation changed event. + """Handle relation changed event. Looks at the relation data and either emits: - certificate request event: If the unit relation data contains a CSR for which @@ -1276,7 +1345,7 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: self._revoke_certificates_for_which_no_csr_exists(relation_id=event.relation.id) def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None: - """Revokes certificates for which no unit has a CSR. + """Revoke certificates for which no unit has a CSR. Goes through all generated certificates and compare against the list of CSRs for all units of a given relationship. @@ -1312,7 +1381,7 @@ def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None def get_outstanding_certificate_requests( self, relation_id: Optional[int] = None ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: - """Returns CSR's for which no certificate has been issued. + """Return CSR's for which no certificate has been issued. Example return: [ { @@ -1354,7 +1423,7 @@ def get_outstanding_certificate_requests( def get_requirer_csrs( self, relation_id: Optional[int] = None ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: - """Returns a list of requirers' CSRs grouped by unit. + """Return a list of requirers' CSRs grouped by unit. It returns CSRs from all relations if relation_id is not specified. CSRs are returned per relation id, application name and unit name. @@ -1393,7 +1462,7 @@ def get_requirer_csrs( def certificate_issued_for_csr( self, app_name: str, csr: str, relation_id: Optional[int] ) -> bool: - """Checks whether a certificate has been issued for a given CSR. + """Check whether a certificate has been issued for a given CSR. Args: app_name (str): Application name that the CSR belongs to. @@ -1414,7 +1483,7 @@ def certificate_issued_for_csr( class TLSCertificatesRequiresV2(Object): """TLS certificates requirer class to be instantiated by TLS certificates requirers.""" - on = CertificatesRequirerCharmEvents() + on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType] def __init__( self, @@ -1422,7 +1491,7 @@ def __init__( relationship_name: str, expiry_notification_time: int = 168, ): - """Generates/use private key and observes relation changed event. + """Generate/use private key and observes relation changed event. Args: charm: Charm object @@ -1447,7 +1516,7 @@ def __init__( @property def _requirer_csrs(self) -> List[Dict[str, Union[bool, str]]]: - """Returns list of requirer's CSRs from relation data. + """Return list of requirer's CSRs from relation unit data. Example: [ @@ -1465,7 +1534,7 @@ def _requirer_csrs(self) -> List[Dict[str, Union[bool, str]]]: @property def _provider_certificates(self) -> List[Dict[str, str]]: - """Returns list of certificates from the provider's relation data.""" + """Return list of certificates from the provider's relation data.""" relation = self.model.get_relation(self.relationship_name) if not relation: logger.debug("No relation: %s", self.relationship_name) @@ -1480,7 +1549,7 @@ def _provider_certificates(self) -> List[Dict[str, str]]: return provider_relation_data.get("certificates", []) def _add_requirer_csr(self, csr: str, is_ca: bool) -> None: - """Adds CSR to relation data. + """Add CSR to relation data. Args: csr (str): Certificate Signing Request @@ -1507,7 +1576,7 @@ def _add_requirer_csr(self, csr: str, is_ca: bool) -> None: relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) def _remove_requirer_csr(self, csr: str) -> None: - """Removes CSR from relation data. + """Remove CSR from relation data. Args: csr (str): Certificate signing request @@ -1552,7 +1621,7 @@ def request_certificate_creation( logger.info("Certificate request sent to provider") def request_certificate_revocation(self, certificate_signing_request: bytes) -> None: - """Removes CSR from relation data. + """Remove CSR from relation data. The provider of this relation is then expected to remove certificates associated to this CSR from the relation data as well and emit a request_certificate_revocation event for the @@ -1570,7 +1639,7 @@ def request_certificate_revocation(self, certificate_signing_request: bytes) -> def request_certificate_renewal( self, old_certificate_signing_request: bytes, new_certificate_signing_request: bytes ) -> None: - """Renews certificate. + """Renew certificate. Removes old CSR from relation data and adds new one. @@ -1592,9 +1661,95 @@ def request_certificate_renewal( ) logger.info("Certificate renewal request completed.") + def get_assigned_certificates(self) -> List[Dict[str, str]]: + """Get a list of certificates that were assigned to this unit. + + Returns: + List of certificates. For example: + [ + { + "ca": "-----BEGIN CERTIFICATE-----...", + "chain": [ + "-----BEGIN CERTIFICATE-----..." + ], + "certificate": "-----BEGIN CERTIFICATE-----...", + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + } + ] + """ + final_list = [] + for csr in self.get_certificate_signing_requests(fulfilled_only=True): + assert isinstance(csr["certificate_signing_request"], str) + if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]): + final_list.append(cert) + return final_list + + def get_expiring_certificates(self) -> List[Dict[str, str]]: + """Get a list of certificates that were assigned to this unit that are expiring or expired. + + Returns: + List of certificates. For example: + [ + { + "ca": "-----BEGIN CERTIFICATE-----...", + "chain": [ + "-----BEGIN CERTIFICATE-----..." + ], + "certificate": "-----BEGIN CERTIFICATE-----...", + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + } + ] + """ + final_list = [] + for csr in self.get_certificate_signing_requests(fulfilled_only=True): + assert isinstance(csr["certificate_signing_request"], str) + if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]): + expiry_time = _get_certificate_expiry_time(cert["certificate"]) + if not expiry_time: + continue + expiry_notification_time = expiry_time - timedelta( + hours=self.expiry_notification_time + ) + if datetime.now(timezone.utc) > expiry_notification_time: + final_list.append(cert) + return final_list + + def get_certificate_signing_requests( + self, + fulfilled_only: bool = False, + unfulfilled_only: bool = False, + ) -> List[Dict[str, Union[bool, str]]]: + """Get the list of CSR's that were sent to the provider. + + You can choose to get only the CSR's that have a certificate assigned or only the CSR's + that don't. + + Args: + fulfilled_only (bool): This option will discard CSRs that don't have certificates yet. + unfulfilled_only (bool): This option will discard CSRs that have certificates signed. + + Returns: + List of CSR dictionaries. For example: + [ + { + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + "ca": false + } + ] + """ + final_list = [] + for csr in self._requirer_csrs: + assert isinstance(csr["certificate_signing_request"], str) + cert = self._find_certificate_in_relation_data(csr["certificate_signing_request"]) + if (unfulfilled_only and cert) or (fulfilled_only and not cert): + continue + final_list.append(csr) + + return final_list + @staticmethod def _relation_data_is_valid(certificates_data: dict) -> bool: - """Checks whether relation data is valid based on json schema. + """Check whether relation data is valid based on json schema. Args: certificates_data: Certificate data in dict format. @@ -1609,7 +1764,7 @@ def _relation_data_is_valid(certificates_data: dict) -> bool: return False def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Handler triggered on relation changed events. + """Handle relation changed event. Goes through all providers certificates that match a requested CSR. @@ -1694,7 +1849,7 @@ def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: return _get_closest_future_time(expiry_notification_time, expiry_time) def _on_relation_broken(self, event: RelationBrokenEvent) -> None: - """Handler triggered on relation broken event. + """Handle relation broken event. Emitting `all_certificates_invalidated` from `relation-broken` rather than `relation-departed` since certs are stored in app data. @@ -1708,7 +1863,7 @@ def _on_relation_broken(self, event: RelationBrokenEvent) -> None: self.on.all_certificates_invalidated.emit() def _on_secret_expired(self, event: SecretExpiredEvent) -> None: - """Triggered when a certificate is set to expire. + """Handle secret expired event. Loads the certificate from the secret, and will emit 1 of 2 events. @@ -1738,7 +1893,7 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None: event.secret.remove_all_revisions() return - if datetime.utcnow() < expiry_time: + if datetime.now(timezone.utc) < expiry_time: logger.warning("Certificate almost expired") self.on.certificate_expiring.emit( certificate=certificate_dict["certificate"], @@ -1760,7 +1915,7 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None: event.secret.remove_all_revisions() def _find_certificate_in_relation_data(self, csr: str) -> Optional[Dict[str, Any]]: - """Returns the certificate that match the given CSR.""" + """Return the certificate that match the given CSR.""" for certificate_dict in self._provider_certificates: if certificate_dict["certificate_signing_request"] != csr: continue @@ -1768,7 +1923,7 @@ def _find_certificate_in_relation_data(self, csr: str) -> Optional[Dict[str, Any return None def _on_update_status(self, event: UpdateStatusEvent) -> None: - """Triggered on update status event. + """Handle update status event. Goes through each certificate in the "certificates" relation and checks their expiry date. If they are close to expire (<7 days), emits a CertificateExpiringEvent event and if @@ -1784,7 +1939,7 @@ def _on_update_status(self, event: UpdateStatusEvent) -> None: expiry_time = _get_certificate_expiry_time(certificate_dict["certificate"]) if not expiry_time: continue - time_difference = expiry_time - datetime.utcnow() + time_difference = expiry_time - datetime.now(timezone.utc) if time_difference.total_seconds() < 0: logger.warning("Certificate is expired") self.on.certificate_invalidated.emit( @@ -1802,71 +1957,3 @@ def _on_update_status(self, event: UpdateStatusEvent) -> None: certificate=certificate_dict["certificate"], expiry=expiry_time.isoformat(), ) - - -def csr_matches_certificate(csr: str, cert: str) -> bool: - """Check if a CSR matches a certificate. - - expects to get the original string representations. - - Args: - csr (str): Certificate Signing Request - cert (str): Certificate - Returns: - bool: True/False depending on whether the CSR matches the certificate. - """ - try: - csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) - cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) - - if csr_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) != cert_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ): - return False - if ( - csr_object.public_key().public_numbers().n # type: ignore[union-attr] - != cert_object.public_key().public_numbers().n # type: ignore[union-attr] - ): - return False - except ValueError: - logger.warning("Could not load certificate or CSR.") - return False - return True - - -def _get_closest_future_time( - expiry_notification_time: datetime, expiry_time: datetime -) -> datetime: - """Return expiry_notification_time if not in the past, otherwise return expiry_time. - - Args: - expiry_notification_time (datetime): Notification time of impending expiration - expiry_time (datetime): Expiration time - - Returns: - datetime: expiry_notification_time if not in the past, expiry_time otherwise - """ - return ( - expiry_notification_time if datetime.utcnow() < expiry_notification_time else expiry_time - ) - - -def _get_certificate_expiry_time(certificate: str) -> Optional[datetime]: - """Extract expiry time from a certificate string. - - Args: - certificate (str): x509 certificate as a string - - Returns: - Optional[datetime]: Expiry datetime or None - """ - try: - certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) - return certificate_object.not_valid_after - except ValueError: - logger.warning("Could not load certificate.") - return None