Skip to content

Commit

Permalink
Merge pull request #75 from canonical/correct-upgrade-libpatch
Browse files Browse the repository at this point in the history
correct lib patch on upgrade lib
  • Loading branch information
MiaAltieri authored Nov 12, 2024
2 parents 5205fcc + a2327d7 commit d2fd216
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 20 deletions.
2 changes: 1 addition & 1 deletion lib/charms/mongos/v0/upgrade_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 6
LIBPATCH = 4

logger = logging.getLogger(__name__)

Expand Down
89 changes: 70 additions & 19 deletions lib/charms/tls_certificates_interface/v3/tls_certificates.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven
ModelError,
Relation,
RelationDataContent,
Secret,
SecretNotFoundError,
Unit,
)
Expand All @@ -317,7 +318,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 17
LIBPATCH = 23

PYDEPS = ["cryptography", "jsonschema"]

Expand Down Expand Up @@ -735,16 +736,16 @@ def calculate_expiry_notification_time(
"""
if provider_recommended_notification_time is not None:
provider_recommended_notification_time = abs(provider_recommended_notification_time)
provider_recommendation_time_delta = (
expiry_time - timedelta(hours=provider_recommended_notification_time)
provider_recommendation_time_delta = expiry_time - timedelta(
hours=provider_recommended_notification_time
)
if validity_start_time < provider_recommendation_time_delta:
return provider_recommendation_time_delta

if requirer_recommended_notification_time is not None:
requirer_recommended_notification_time = abs(requirer_recommended_notification_time)
requirer_recommendation_time_delta = (
expiry_time - timedelta(hours=requirer_recommended_notification_time)
requirer_recommendation_time_delta = expiry_time - timedelta(
hours=requirer_recommended_notification_time
)
if validity_start_time < requirer_recommendation_time_delta:
return requirer_recommendation_time_delta
Expand Down Expand Up @@ -1448,18 +1449,31 @@ def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None
Returns:
None
"""
provider_certificates = self.get_provider_certificates(relation_id)
requirer_csrs = self.get_requirer_csrs(relation_id)
provider_certificates = self.get_unsolicited_certificates(relation_id=relation_id)
for provider_certificate in provider_certificates:
self.on.certificate_revocation_request.emit(
certificate=provider_certificate.certificate,
certificate_signing_request=provider_certificate.csr,
ca=provider_certificate.ca,
chain=provider_certificate.chain,
)
self.remove_certificate(certificate=provider_certificate.certificate)

def get_unsolicited_certificates(
self, relation_id: Optional[int] = None
) -> List[ProviderCertificate]:
"""Return provider certificates for which no certificate requests exists.
Those certificates should be revoked.
"""
unsolicited_certificates: List[ProviderCertificate] = []
provider_certificates = self.get_provider_certificates(relation_id=relation_id)
requirer_csrs = self.get_requirer_csrs(relation_id=relation_id)
list_of_csrs = [csr.csr for csr in requirer_csrs]
for certificate in provider_certificates:
if certificate.csr not in list_of_csrs:
self.on.certificate_revocation_request.emit(
certificate=certificate.certificate,
certificate_signing_request=certificate.csr,
ca=certificate.ca,
chain=certificate.chain,
)
self.remove_certificate(certificate=certificate.certificate)
unsolicited_certificates.append(certificate)
return unsolicited_certificates

def get_outstanding_certificate_requests(
self, relation_id: Optional[int] = None
Expand Down Expand Up @@ -1877,8 +1891,7 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None:
"Removing secret with label %s",
f"{LIBID}-{csr_in_sha256_hex}",
)
secret = self.model.get_secret(
label=f"{LIBID}-{csr_in_sha256_hex}")
secret = self.model.get_secret(label=f"{LIBID}-{csr_in_sha256_hex}")
secret.remove_all_revisions()
self.on.certificate_invalidated.emit(
reason="revoked",
Expand All @@ -1889,10 +1902,20 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None:
)
else:
try:
secret = self.model.get_secret(label=f"{LIBID}-{csr_in_sha256_hex}")
logger.debug(
"Setting secret with label %s", f"{LIBID}-{csr_in_sha256_hex}"
)
secret = self.model.get_secret(label=f"{LIBID}-{csr_in_sha256_hex}")
# Juju < 3.6 will create a new revision even if the content is the same
if (
secret.get_content(refresh=True).get("certificate", "")
== certificate.certificate
):
logger.debug(
"Secret %s with correct certificate already exists",
f"{LIBID}-{csr_in_sha256_hex}",
)
continue
secret.set_content(
{"certificate": certificate.certificate, "csr": certificate.csr}
)
Expand Down Expand Up @@ -1966,17 +1989,26 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None:
Args:
event (SecretExpiredEvent): Juju event
"""
if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-"):
csr = self._get_csr_from_secret(event.secret)
if not csr:
logger.error("Failed to get CSR from secret %s", event.secret.label)
return
csr = event.secret.get_content()["csr"]
provider_certificate = self._find_certificate_in_relation_data(csr)
if not provider_certificate:
# A secret expired but we did not find matching certificate. Cleaning up
logger.warning(
"Failed to find matching certificate for csr, cleaning up secret %s",
event.secret.label,
)
event.secret.remove_all_revisions()
return

if not provider_certificate.expiry_time:
# A secret expired but matching certificate is invalid. Cleaning up
logger.warning(
"Certificate matching csr is invalid, cleaning up secret %s",
event.secret.label,
)
event.secret.remove_all_revisions()
return

Expand Down Expand Up @@ -2008,3 +2040,22 @@ def _find_certificate_in_relation_data(self, csr: str) -> Optional[ProviderCerti
continue
return provider_certificate
return None

def _get_csr_from_secret(self, secret: Secret) -> Union[str, None]:
"""Extract the CSR from the secret label or content.
This function is a workaround to maintain backwards compatibility
and fix the issue reported in
https://github.com/canonical/tls-certificates-interface/issues/228
"""
try:
content = secret.get_content(refresh=True)
except SecretNotFoundError:
return None
if not (csr := content.get("csr", None)):
# In versions <14 of the Lib we were storing the CSR in the label of the secret
# The CSR now is stored int the content of the secret, which was a breaking change
# Here we get the CSR if the secret was created by an app using libpatch 14 or lower
if secret.label and secret.label.startswith(f"{LIBID}-"):
csr = secret.label[len(f"{LIBID}-") :]
return csr

0 comments on commit d2fd216

Please sign in to comment.