diff --git a/lib/charms/tls_certificates_interface/v4/tls_certificates.py b/lib/charms/tls_certificates_interface/v4/tls_certificates.py index 10ca873..0f131de 100644 --- a/lib/charms/tls_certificates_interface/v4/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v4/tls_certificates.py @@ -32,7 +32,7 @@ from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import NameOID -from ops import BoundEvent, CharmBase, CharmEvents, SecretExpiredEvent +from ops import BoundEvent, CharmBase, CharmEvents, SecretExpiredEvent, SecretRemoveEvent from ops.framework import EventBase, EventSource, Handle, Object from ops.jujuversion import JujuVersion from ops.model import ( @@ -52,7 +52,7 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 1 +LIBPATCH = 3 PYDEPS = ["cryptography", "pydantic"] @@ -305,6 +305,37 @@ def from_string(cls, certificate: str) -> "Certificate": validity_start_time=validity_start_time, ) + def matches_private_key(self, private_key: PrivateKey) -> bool: + """Check if this certificate matches a given private key. + + Args: + private_key (PrivateKey): The private key to validate against. + + Returns: + bool: True if the certificate matches the private key, False otherwise. + """ + try: + cert_object = x509.load_pem_x509_certificate(self.raw.encode()) + key_object = serialization.load_pem_private_key( + private_key.raw.encode(), password=None + ) + + cert_public_key = cert_object.public_key() + key_public_key = key_object.public_key() + + if not isinstance(cert_public_key, rsa.RSAPublicKey): + logger.warning("Certificate does not use RSA public key") + return False + + if not isinstance(key_public_key, rsa.RSAPublicKey): + logger.warning("Private key is not an RSA key") + return False + + return cert_public_key.public_numbers() == key_public_key.public_numbers() + except Exception as e: + logger.warning("Failed to validate certificate and private key match: %s", e) + return False + @dataclass(frozen=True) class CertificateSigningRequest: @@ -974,6 +1005,7 @@ def __init__( self.framework.observe(charm.on[relationship_name].relation_created, self._configure) self.framework.observe(charm.on[relationship_name].relation_changed, self._configure) self.framework.observe(charm.on.secret_expired, self._on_secret_expired) + self.framework.observe(charm.on.secret_remove, self._on_secret_remove) for event in refresh_events: self.framework.observe(event, self._configure) @@ -996,6 +1028,10 @@ def _configure(self, _: EventBase): def _mode_is_valid(self, mode) -> bool: return mode in [Mode.UNIT, Mode.APP] + def _on_secret_remove(self, event: SecretRemoveEvent) -> None: + """Handle Secret Removed Event.""" + event.secret.remove_revision(event.revision) + def _on_secret_expired(self, event: SecretExpiredEvent) -> None: """Handle Secret Expired Event. @@ -1069,7 +1105,7 @@ def _get_app_or_unit(self) -> Union[Application, Unit]: raise TLSCertificatesError("Invalid mode") @property - def private_key(self) -> PrivateKey | None: + def private_key(self) -> Optional[PrivateKey]: """Return the private key.""" if not self._private_key_generated(): return None @@ -1238,7 +1274,7 @@ def _send_certificate_requests(self): def get_assigned_certificate( self, certificate_request: CertificateRequestAttributes - ) -> Tuple[ProviderCertificate | None, PrivateKey | None]: + ) -> Tuple[Optional[ProviderCertificate], Optional[PrivateKey]]: """Get the certificate that was assigned to the given certificate request.""" for requirer_csr in self.get_csrs_from_requirer_relation_data(): if certificate_request == CertificateRequestAttributes.from_csr( @@ -1248,7 +1284,9 @@ def get_assigned_certificate( return self._find_certificate_in_relation_data(requirer_csr), self.private_key return None, None - def get_assigned_certificates(self) -> Tuple[List[ProviderCertificate], PrivateKey | None]: + def get_assigned_certificates( + self, + ) -> Tuple[List[ProviderCertificate], Optional[PrivateKey]]: """Get a list of certificates that were assigned to this or app.""" assigned_certificates = [] for requirer_csr in self.get_csrs_from_requirer_relation_data(): @@ -1259,12 +1297,19 @@ def get_assigned_certificates(self) -> Tuple[List[ProviderCertificate], PrivateK def _find_certificate_in_relation_data( self, csr: RequirerCertificateRequest ) -> Optional[ProviderCertificate]: - """Return the certificate that match the given CSR.""" + """Return the certificate that matches the given CSR, validated against the private key.""" + if not self.private_key: + return None for provider_certificate in self.get_provider_certificates(): if ( provider_certificate.certificate_signing_request == csr.certificate_signing_request and provider_certificate.certificate.is_ca == csr.is_ca ): + if not provider_certificate.certificate.matches_private_key(self.private_key): + logger.warning( + "Certificate does not match the private key. Ignoring invalid certificate." + ) + continue return provider_certificate return None