Skip to content

Commit

Permalink
chore: Update charm libraries (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
telcobot authored Jan 23, 2024
1 parent d1daf71 commit ccd3a5a
Showing 1 changed file with 156 additions and 72 deletions.
228 changes: 156 additions & 72 deletions lib/charms/tls_certificates_interface/v2/tls_certificates.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,13 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 21
LIBPATCH = 22

PYDEPS = ["cryptography", "jsonschema"]

REQUIRER_JSON_SCHEMA = {
"$schema": "http://json-schema.org/draft-04/schema#",
"$id": "https://canonical.github.io/charm-relation-interfaces/tls_certificates/v2/schemas/requirer.json", # noqa: E501
"$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/requirer.json",
"type": "object",
"title": "`tls_certificates` requirer root schema",
"description": "The `tls_certificates` root schema comprises the entire requirer databag for this interface.", # noqa: E501
Expand Down Expand Up @@ -349,7 +349,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven

PROVIDER_JSON_SCHEMA = {
"$schema": "http://json-schema.org/draft-04/schema#",
"$id": "https://canonical.github.io/charm-relation-interfaces/tls_certificates/v2/schemas/provider.json", # noqa: E501
"$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/provider.json",
"type": "object",
"title": "`tls_certificates` provider root schema",
"description": "The `tls_certificates` root schema comprises the entire provider databag for this interface.", # noqa: E501
Expand Down Expand Up @@ -623,6 +623,40 @@ def _load_relation_data(relation_data_content: RelationDataContent) -> dict:
return certificate_data


def _get_closest_future_time(
expiry_notification_time: datetime, expiry_time: datetime
) -> datetime:
"""Return expiry_notification_time if not in the past, otherwise return expiry_time.
Args:
expiry_notification_time (datetime): Notification time of impending expiration
expiry_time (datetime): Expiration time
Returns:
datetime: expiry_notification_time if not in the past, expiry_time otherwise
"""
return (
expiry_notification_time if datetime.utcnow() < expiry_notification_time else expiry_time
)


def _get_certificate_expiry_time(certificate: str) -> Optional[datetime]:
"""Extract expiry time from a certificate string.
Args:
certificate (str): x509 certificate as a string
Returns:
Optional[datetime]: Expiry datetime or None
"""
try:
certificate_object = x509.load_pem_x509_certificate(data=certificate.encode())
return certificate_object.not_valid_after
except ValueError:
logger.warning("Could not load certificate.")
return None


def generate_ca(
private_key: bytes,
subject: str,
Expand Down Expand Up @@ -984,6 +1018,38 @@ def generate_csr(
return signed_certificate.public_bytes(serialization.Encoding.PEM)


def csr_matches_certificate(csr: str, cert: str) -> bool:
"""Check if a CSR matches a certificate.
Args:
csr (str): Certificate Signing Request as a string
cert (str): Certificate as a string
Returns:
bool: True/False depending on whether the CSR matches the certificate.
"""
try:
csr_object = x509.load_pem_x509_csr(csr.encode("utf-8"))
cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8"))

if csr_object.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
) != cert_object.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
):
return False
if (
csr_object.public_key().public_numbers().n # type: ignore[union-attr]
!= cert_object.public_key().public_numbers().n # type: ignore[union-attr]
):
return False
except ValueError:
logger.warning("Could not load certificate or CSR.")
return False
return True


class CertificatesProviderCharmEvents(CharmEvents):
"""List of events that the TLS Certificates provider charm can leverage."""

Expand Down Expand Up @@ -1447,7 +1513,7 @@ def __init__(

@property
def _requirer_csrs(self) -> List[Dict[str, Union[bool, str]]]:
"""Returns list of requirer's CSRs from relation data.
"""Returns list of requirer's CSRs from relation unit data.
Example:
[
Expand Down Expand Up @@ -1592,6 +1658,92 @@ def request_certificate_renewal(
)
logger.info("Certificate renewal request completed.")

def get_assigned_certificates(self) -> List[Dict[str, str]]:
"""Get a list of certificates that were assigned to this unit.
Returns:
List of certificates. For example:
[
{
"ca": "-----BEGIN CERTIFICATE-----...",
"chain": [
"-----BEGIN CERTIFICATE-----..."
],
"certificate": "-----BEGIN CERTIFICATE-----...",
"certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...",
}
]
"""
final_list = []
for csr in self.get_certificate_signing_requests(fulfilled_only=True):
assert type(csr["certificate_signing_request"]) == str
if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]):
final_list.append(cert)
return final_list

def get_expiring_certificates(self) -> List[Dict[str, str]]:
"""Get a list of certificates that were assigned to this unit that are expiring or expired.
Returns:
List of certificates. For example:
[
{
"ca": "-----BEGIN CERTIFICATE-----...",
"chain": [
"-----BEGIN CERTIFICATE-----..."
],
"certificate": "-----BEGIN CERTIFICATE-----...",
"certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...",
}
]
"""
final_list = []
for csr in self.get_certificate_signing_requests(fulfilled_only=True):
assert type(csr["certificate_signing_request"]) == str
if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]):
expiry_time = _get_certificate_expiry_time(cert["certificate"])
if not expiry_time:
continue
expiry_notification_time = expiry_time - timedelta(
hours=self.expiry_notification_time
)
if datetime.utcnow() > expiry_notification_time:
final_list.append(cert)
return final_list

def get_certificate_signing_requests(
self,
fulfilled_only: bool = False,
unfulfilled_only: bool = False,
) -> List[Dict[str, Union[bool, str]]]:
"""Gets the list of CSR's that were sent to the provider.
You can choose to get only the CSR's that have a certificate assigned or only the CSR's
that don't.
Args:
fulfilled_only (bool): This option will discard CSRs that don't have certificates yet.
unfulfilled_only (bool): This option will discard CSRs that have certificates signed.
Returns:
List of CSR dictionaries. For example:
[
{
"certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...",
"ca": false
}
]
"""

final_list = []
for csr in self._requirer_csrs:
assert type(csr["certificate_signing_request"]) == str
cert = self._find_certificate_in_relation_data(csr["certificate_signing_request"])
if (unfulfilled_only and cert) or (fulfilled_only and not cert):
continue
final_list.append(csr)

return final_list

@staticmethod
def _relation_data_is_valid(certificates_data: dict) -> bool:
"""Checks whether relation data is valid based on json schema.
Expand Down Expand Up @@ -1802,71 +1954,3 @@ def _on_update_status(self, event: UpdateStatusEvent) -> None:
certificate=certificate_dict["certificate"],
expiry=expiry_time.isoformat(),
)


def csr_matches_certificate(csr: str, cert: str) -> bool:
"""Check if a CSR matches a certificate.
expects to get the original string representations.
Args:
csr (str): Certificate Signing Request
cert (str): Certificate
Returns:
bool: True/False depending on whether the CSR matches the certificate.
"""
try:
csr_object = x509.load_pem_x509_csr(csr.encode("utf-8"))
cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8"))

if csr_object.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
) != cert_object.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
):
return False
if (
csr_object.public_key().public_numbers().n # type: ignore[union-attr]
!= cert_object.public_key().public_numbers().n # type: ignore[union-attr]
):
return False
except ValueError:
logger.warning("Could not load certificate or CSR.")
return False
return True


def _get_closest_future_time(
expiry_notification_time: datetime, expiry_time: datetime
) -> datetime:
"""Return expiry_notification_time if not in the past, otherwise return expiry_time.
Args:
expiry_notification_time (datetime): Notification time of impending expiration
expiry_time (datetime): Expiration time
Returns:
datetime: expiry_notification_time if not in the past, expiry_time otherwise
"""
return (
expiry_notification_time if datetime.utcnow() < expiry_notification_time else expiry_time
)


def _get_certificate_expiry_time(certificate: str) -> Optional[datetime]:
"""Extract expiry time from a certificate string.
Args:
certificate (str): x509 certificate as a string
Returns:
Optional[datetime]: Expiry datetime or None
"""
try:
certificate_object = x509.load_pem_x509_certificate(data=certificate.encode())
return certificate_object.not_valid_after
except ValueError:
logger.warning("Could not load certificate.")
return None

0 comments on commit ccd3a5a

Please sign in to comment.