diff --git a/lib/charms/tls_certificates_interface/v3/tls_certificates.py b/lib/charms/tls_certificates_interface/v3/tls_certificates.py index cbdd80d1..dcc6d94d 100644 --- a/lib/charms/tls_certificates_interface/v3/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v3/tls_certificates.py @@ -111,6 +111,7 @@ def _on_certificate_request(self, event: CertificateCreationRequestEvent) -> Non ca=ca_certificate, chain=[ca_certificate, certificate], relation_id=event.relation_id, + recommended_expiry_notification_time=720, ) def _on_certificate_revocation_request(self, event: CertificateRevocationRequestEvent) -> None: @@ -276,13 +277,13 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven """ # noqa: D405, D410, D411, D214, D416 import copy +import ipaddress import json import logging import uuid from contextlib import suppress from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from ipaddress import IPv4Address from typing import List, Literal, Optional, Union from cryptography import x509 @@ -316,7 +317,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 10 +LIBPATCH = 17 PYDEPS = ["cryptography", "jsonschema"] @@ -453,11 +454,35 @@ class ProviderCertificate: ca: str chain: List[str] revoked: bool + expiry_time: datetime + expiry_notification_time: Optional[datetime] = None def chain_as_pem(self) -> str: """Return full certificate chain as a PEM string.""" return "\n\n".join(reversed(self.chain)) + def to_json(self) -> str: + """Return the object as a JSON string. + + Returns: + str: JSON representation of the object + """ + return json.dumps( + { + "relation_id": self.relation_id, + "application_name": self.application_name, + "csr": self.csr, + "certificate": self.certificate, + "ca": self.ca, + "chain": self.chain, + "revoked": self.revoked, + "expiry_time": self.expiry_time.isoformat(), + "expiry_notification_time": self.expiry_notification_time.isoformat() + if self.expiry_notification_time + else None, + } + ) + class CertificateAvailableEvent(EventBase): """Charm Event triggered when a TLS certificate is available.""" @@ -682,21 +707,49 @@ def _get_closest_future_time( ) -def _get_certificate_expiry_time(certificate: str) -> Optional[datetime]: - """Extract expiry time from a certificate string. +def calculate_expiry_notification_time( + validity_start_time: datetime, + expiry_time: datetime, + provider_recommended_notification_time: Optional[int], + requirer_recommended_notification_time: Optional[int], +) -> datetime: + """Calculate a reasonable time to notify the user about the certificate expiry. + + It takes into account the time recommended by the provider and by the requirer. + Time recommended by the provider is preferred, + then time recommended by the requirer, + then dynamically calculated time. Args: - certificate (str): x509 certificate as a string + validity_start_time: Certificate validity time + expiry_time: Certificate expiry time + provider_recommended_notification_time: + Time in hours prior to expiry to notify the user. + Recommended by the provider. + requirer_recommended_notification_time: + Time in hours prior to expiry to notify the user. + Recommended by the requirer. Returns: - Optional[datetime]: Expiry datetime or None + datetime: Time to notify the user about the certificate expiry. """ - try: - certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) - return certificate_object.not_valid_after_utc - except ValueError: - logger.warning("Could not load certificate.") - return None + if provider_recommended_notification_time is not None: + provider_recommended_notification_time = abs(provider_recommended_notification_time) + provider_recommendation_time_delta = expiry_time - timedelta( + hours=provider_recommended_notification_time + ) + if validity_start_time < provider_recommendation_time_delta: + return provider_recommendation_time_delta + + if requirer_recommended_notification_time is not None: + requirer_recommended_notification_time = abs(requirer_recommended_notification_time) + requirer_recommendation_time_delta = expiry_time - timedelta( + hours=requirer_recommended_notification_time + ) + if validity_start_time < requirer_recommendation_time_delta: + return requirer_recommendation_time_delta + calculated_hours = (expiry_time - validity_start_time).total_seconds() / (3600 * 3) + return expiry_time - timedelta(hours=calculated_hours) def generate_ca( @@ -965,6 +1018,8 @@ def generate_csr( # noqa: C901 organization: Optional[str] = None, email_address: Optional[str] = None, country_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + locality_name: Optional[str] = None, private_key_password: Optional[bytes] = None, sans: Optional[List[str]] = None, sans_oid: Optional[List[str]] = None, @@ -983,6 +1038,8 @@ def generate_csr( # noqa: C901 organization (str): Name of organization. email_address (str): Email address. country_name (str): Country Name. + state_or_province_name (str): State or Province Name. + locality_name (str): Locality Name. private_key_password (bytes): Private key password sans (list): Use sans_dns - this will be deprecated in a future release List of DNS subject alternative names (keeping it for now for backward compatibility) @@ -1008,13 +1065,19 @@ def generate_csr( # noqa: C901 subject_name.append(x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, email_address)) if country_name: subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name)) + if state_or_province_name: + subject_name.append( + x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name) + ) + if locality_name: + subject_name.append(x509.NameAttribute(x509.NameOID.LOCALITY_NAME, locality_name)) csr = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(subject_name)) _sans: List[x509.GeneralName] = [] if sans_oid: _sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in sans_oid]) if sans_ip: - _sans.extend([x509.IPAddress(IPv4Address(san)) for san in sans_ip]) + _sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in sans_ip]) if sans: _sans.extend([x509.DNSName(san) for san in sans]) if sans_dns: @@ -1030,6 +1093,13 @@ def generate_csr( # noqa: C901 return signed_certificate.public_bytes(serialization.Encoding.PEM) +def get_sha256_hex(data: str) -> str: + """Calculate the hash of the provided data and return the hexadecimal representation.""" + digest = hashes.Hash(hashes.SHA256()) + digest.update(data.encode()) + return digest.finalize().hex() + + def csr_matches_certificate(csr: str, cert: str) -> bool: """Check if a CSR matches a certificate. @@ -1039,25 +1109,16 @@ def csr_matches_certificate(csr: str, cert: str) -> bool: Returns: bool: True/False depending on whether the CSR matches the certificate. """ - try: - csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) - cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) - - if csr_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) != cert_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ): - return False - if ( - csr_object.public_key().public_numbers().n # type: ignore[union-attr] - != cert_object.public_key().public_numbers().n # type: ignore[union-attr] - ): - return False - except ValueError: - logger.warning("Could not load certificate or CSR.") + csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) + cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) + + if csr_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) != cert_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ): return False return True @@ -1135,6 +1196,7 @@ def _add_certificate( certificate_signing_request: str, ca: str, chain: List[str], + recommended_expiry_notification_time: Optional[int] = None, ) -> None: """Add certificate to relation data. @@ -1144,6 +1206,8 @@ def _add_certificate( certificate_signing_request (str): Certificate Signing Request ca (str): CA Certificate chain (list): CA Chain + recommended_expiry_notification_time (int): + Time in hours before the certificate expires to notify the user. Returns: None @@ -1161,6 +1225,7 @@ def _add_certificate( "certificate_signing_request": certificate_signing_request, "ca": ca, "chain": chain, + "recommended_expiry_notification_time": recommended_expiry_notification_time, } provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) @@ -1227,6 +1292,7 @@ def set_relation_certificate( ca: str, chain: List[str], relation_id: int, + recommended_expiry_notification_time: Optional[int] = None, ) -> None: """Add certificates to relation data. @@ -1236,6 +1302,8 @@ def set_relation_certificate( ca (str): CA Certificate chain (list): CA Chain relation_id (int): Juju relation ID + recommended_expiry_notification_time (int): + Recommended time in hours before the certificate expires to notify the user. Returns: None @@ -1257,6 +1325,7 @@ def set_relation_certificate( certificate_signing_request=certificate_signing_request.strip(), ca=ca.strip(), chain=[cert.strip() for cert in chain], + recommended_expiry_notification_time=recommended_expiry_notification_time, ) def remove_certificate(self, certificate: str) -> None: @@ -1310,6 +1379,13 @@ def get_provider_certificates( provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) for certificate in provider_certificates: + try: + certificate_object = x509.load_pem_x509_certificate( + data=certificate["certificate"].encode() + ) + except ValueError as e: + logger.error("Could not load certificate - Skipping: %s", e) + continue provider_certificate = ProviderCertificate( relation_id=relation.id, application_name=relation.app.name, @@ -1318,6 +1394,10 @@ def get_provider_certificates( ca=certificate["ca"], chain=certificate["chain"], revoked=certificate.get("revoked", False), + expiry_time=certificate_object.not_valid_after_utc, + expiry_notification_time=certificate.get( + "recommended_expiry_notification_time" + ), ) certificates.append(provider_certificate) return certificates @@ -1475,15 +1555,17 @@ def __init__( self, charm: CharmBase, relationship_name: str, - expiry_notification_time: int = 168, + expiry_notification_time: Optional[int] = None, ): """Generate/use private key and observes relation changed event. Args: charm: Charm object relationship_name: Juju relation name - expiry_notification_time (int): Time difference between now and expiry (in hours). - Used to trigger the CertificateExpiring event. Default: 7 days. + expiry_notification_time (int): Number of hours prior to certificate expiry. + Used to trigger the CertificateExpiring event. + This value is used as a recommendation only, + The actual value is calculated taking into account the provider's recommendation. """ super().__init__(charm, relationship_name) if not JujuVersion.from_environ().has_secrets: @@ -1544,9 +1626,25 @@ def get_provider_certificates(self) -> List[ProviderCertificate]: if not certificate: logger.warning("No certificate found in relation data - Skipping") continue + try: + certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + except ValueError as e: + logger.error("Could not load certificate - Skipping: %s", e) + continue ca = provider_certificate_dict.get("ca") chain = provider_certificate_dict.get("chain", []) csr = provider_certificate_dict.get("certificate_signing_request") + recommended_expiry_notification_time = provider_certificate_dict.get( + "recommended_expiry_notification_time" + ) + expiry_time = certificate_object.not_valid_after_utc + validity_start_time = certificate_object.not_valid_before_utc + expiry_notification_time = calculate_expiry_notification_time( + validity_start_time=validity_start_time, + expiry_time=expiry_time, + provider_recommended_notification_time=recommended_expiry_notification_time, + requirer_recommended_notification_time=self.expiry_notification_time, + ) if not csr: logger.warning("No CSR found in relation data - Skipping") continue @@ -1559,6 +1657,8 @@ def get_provider_certificates(self) -> List[ProviderCertificate]: ca=ca, chain=chain, revoked=revoked, + expiry_time=expiry_time, + expiry_notification_time=expiry_notification_time, ) provider_certificates.append(provider_certificate) return provider_certificates @@ -1708,13 +1808,9 @@ def get_expiring_certificates(self) -> List[ProviderCertificate]: expiring_certificates: List[ProviderCertificate] = [] for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): if cert := self._find_certificate_in_relation_data(requirer_csr.csr): - expiry_time = _get_certificate_expiry_time(cert.certificate) - if not expiry_time: + if not cert.expiry_time or not cert.expiry_notification_time: continue - expiry_notification_time = expiry_time - timedelta( - hours=self.expiry_notification_time - ) - if datetime.now(timezone.utc) > expiry_notification_time: + if datetime.now(timezone.utc) > cert.expiry_notification_time: expiring_certificates.append(cert) return expiring_certificates @@ -1774,9 +1870,14 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: ] for certificate in provider_certificates: if certificate.csr in requirer_csrs: + csr_in_sha256_hex = get_sha256_hex(certificate.csr) if certificate.revoked: with suppress(SecretNotFoundError): - secret = self.model.get_secret(label=f"{LIBID}-{certificate.csr}") + logger.debug( + "Removing secret with label %s", + f"{LIBID}-{csr_in_sha256_hex}", + ) + secret = self.model.get_secret(label=f"{LIBID}-{csr_in_sha256_hex}") secret.remove_all_revisions() self.on.certificate_invalidated.emit( reason="revoked", @@ -1787,16 +1888,24 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: ) else: try: - secret = self.model.get_secret(label=f"{LIBID}-{certificate.csr}") - secret.set_content({"certificate": certificate.certificate}) + logger.debug( + "Setting secret with label %s", f"{LIBID}-{csr_in_sha256_hex}" + ) + secret = self.model.get_secret(label=f"{LIBID}-{csr_in_sha256_hex}") + secret.set_content( + {"certificate": certificate.certificate, "csr": certificate.csr} + ) secret.set_info( - expire=self._get_next_secret_expiry_time(certificate.certificate), + expire=self._get_next_secret_expiry_time(certificate), ) except SecretNotFoundError: + logger.debug( + "Creating new secret with label %s", f"{LIBID}-{csr_in_sha256_hex}" + ) secret = self.charm.unit.add_secret( - {"certificate": certificate.certificate}, - label=f"{LIBID}-{certificate.csr}", - expire=self._get_next_secret_expiry_time(certificate.certificate), + {"certificate": certificate.certificate, "csr": certificate.csr}, + label=f"{LIBID}-{csr_in_sha256_hex}", + expire=self._get_next_secret_expiry_time(certificate), ) self.on.certificate_available.emit( certificate_signing_request=certificate.csr, @@ -1805,7 +1914,7 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: chain=certificate.chain, ) - def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: + def _get_next_secret_expiry_time(self, certificate: ProviderCertificate) -> Optional[datetime]: """Return the expiry time or expiry notification time. Extracts the expiry time from the provided certificate, calculates the @@ -1813,17 +1922,18 @@ def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: the future. Args: - certificate: x509 certificate + certificate: ProviderCertificate object Returns: Optional[datetime]: None if the certificate expiry time cannot be read, next expiry time otherwise. """ - expiry_time = _get_certificate_expiry_time(certificate) - if not expiry_time: + if not certificate.expiry_time or not certificate.expiry_notification_time: return None - expiry_notification_time = expiry_time - timedelta(hours=self.expiry_notification_time) - return _get_closest_future_time(expiry_notification_time, expiry_time) + return _get_closest_future_time( + certificate.expiry_notification_time, + certificate.expiry_time, + ) def _on_relation_broken(self, event: RelationBrokenEvent) -> None: """Handle Relation Broken Event. @@ -1857,27 +1967,26 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None: """ if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-"): return - csr = event.secret.label[len(f"{LIBID}-") :] + csr = event.secret.get_content()["csr"] provider_certificate = self._find_certificate_in_relation_data(csr) if not provider_certificate: # A secret expired but we did not find matching certificate. Cleaning up event.secret.remove_all_revisions() return - expiry_time = _get_certificate_expiry_time(provider_certificate.certificate) - if not expiry_time: + if not provider_certificate.expiry_time: # A secret expired but matching certificate is invalid. Cleaning up event.secret.remove_all_revisions() return - if datetime.now(timezone.utc) < expiry_time: + if datetime.now(timezone.utc) < provider_certificate.expiry_time: logger.warning("Certificate almost expired") self.on.certificate_expiring.emit( certificate=provider_certificate.certificate, - expiry=expiry_time.isoformat(), + expiry=provider_certificate.expiry_time.isoformat(), ) event.secret.set_info( - expire=_get_certificate_expiry_time(provider_certificate.certificate), + expire=provider_certificate.expiry_time, ) else: logger.warning("Certificate is expired")