Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Update charm libraries #444

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions lib/charms/sdcore_nms_k8s/v0/sdcore_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ def _on_sdcore_config_relation_joined(self, event: RelationJoinedEvent):
import logging
from typing import Optional

from interface_tester.schema_base import DataBagSchema # type: ignore[import]
from interface_tester.schema_base import DataBagSchema
from ops.charm import CharmBase, CharmEvents, RelationBrokenEvent, RelationChangedEvent
from ops.framework import EventBase, EventSource, Handle, Object
from ops.model import Relation
from ops.model import ModelError, Relation
from pydantic import BaseModel, Field, ValidationError

# The unique Charmhub library identifier, never change it
Expand All @@ -120,7 +120,7 @@ def _on_sdcore_config_relation_joined(self, event: RelationJoinedEvent):

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 1
LIBPATCH = 3

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -150,7 +150,7 @@ class SdcoreConfigProviderAppData(BaseModel):

class ProviderSchema(DataBagSchema):
"""The schema for the provider side of the sdcore-config interface."""
app: SdcoreConfigProviderAppData
app_data: SdcoreConfigProviderAppData


def data_is_valid(data: dict) -> bool:
Expand All @@ -163,7 +163,7 @@ def data_is_valid(data: dict) -> bool:
bool: True if data is valid, False otherwise.
"""
try:
ProviderSchema(app=data)
ProviderSchema(app_data=SdcoreConfigProviderAppData(**data))
return True
except ValidationError as e:
logger.error("Invalid data: %s", e)
Expand Down Expand Up @@ -207,7 +207,7 @@ class SdcoreConfigRequirerCharmEvents(CharmEvents):
class SdcoreConfigRequires(Object):
"""Class to be instantiated by the SD-Core config requirer charm."""

on = SdcoreConfigRequirerCharmEvents()
on = SdcoreConfigRequirerCharmEvents() # type: ignore

def __init__(self, charm: CharmBase, relation_name: str):
"""Init."""
Expand Down Expand Up @@ -336,4 +336,7 @@ def set_webui_url_in_all_relations(self, webui_url: str) -> None:
raise RuntimeError(f"Relation {self.relation_name} not created yet.")

for relation in relations:
relation.data[self.charm.app].update({"webui_url": webui_url})
try:
relation.data[self.charm.app].update({"webui_url": webui_url})
except ModelError as exc:
logger.error("Error updating the relation data: %s", str(exc))
66 changes: 59 additions & 7 deletions lib/charms/tls_certificates_interface/v4/tls_certificates.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
from ops import BoundEvent, CharmBase, CharmEvents, SecretExpiredEvent
from ops import BoundEvent, CharmBase, CharmEvents, SecretExpiredEvent, SecretRemoveEvent
from ops.framework import EventBase, EventSource, Handle, Object
from ops.jujuversion import JujuVersion
from ops.model import (
Expand All @@ -52,7 +52,7 @@

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 1
LIBPATCH = 5

PYDEPS = ["cryptography", "pydantic"]

Expand Down Expand Up @@ -305,6 +305,37 @@ def from_string(cls, certificate: str) -> "Certificate":
validity_start_time=validity_start_time,
)

def matches_private_key(self, private_key: PrivateKey) -> bool:
"""Check if this certificate matches a given private key.

Args:
private_key (PrivateKey): The private key to validate against.

Returns:
bool: True if the certificate matches the private key, False otherwise.
"""
try:
cert_object = x509.load_pem_x509_certificate(self.raw.encode())
key_object = serialization.load_pem_private_key(
private_key.raw.encode(), password=None
)

cert_public_key = cert_object.public_key()
key_public_key = key_object.public_key()

if not isinstance(cert_public_key, rsa.RSAPublicKey):
logger.warning("Certificate does not use RSA public key")
return False

if not isinstance(key_public_key, rsa.RSAPublicKey):
logger.warning("Private key is not an RSA key")
return False

return cert_public_key.public_numbers() == key_public_key.public_numbers()
except Exception as e:
logger.warning("Failed to validate certificate and private key match: %s", e)
return False


@dataclass(frozen=True)
class CertificateSigningRequest:
Expand Down Expand Up @@ -974,6 +1005,7 @@ def __init__(
self.framework.observe(charm.on[relationship_name].relation_created, self._configure)
self.framework.observe(charm.on[relationship_name].relation_changed, self._configure)
self.framework.observe(charm.on.secret_expired, self._on_secret_expired)
self.framework.observe(charm.on.secret_remove, self._on_secret_remove)
for event in refresh_events:
self.framework.observe(event, self._configure)

Expand All @@ -993,9 +1025,20 @@ def _configure(self, _: EventBase):
self._find_available_certificates()
self._cleanup_certificate_requests()

def _mode_is_valid(self, mode) -> bool:
def _mode_is_valid(self, mode: Mode) -> bool:
return mode in [Mode.UNIT, Mode.APP]

def _on_secret_remove(self, event: SecretRemoveEvent) -> None:
"""Handle Secret Removed Event."""
try:
event.secret.remove_revision(event.revision)
except SecretNotFoundError:
logger.warning(
"No such secret %s, nothing to remove",
event.secret.label or event.secret.id,
)
return

def _on_secret_expired(self, event: SecretExpiredEvent) -> None:
"""Handle Secret Expired Event.

Expand Down Expand Up @@ -1069,7 +1112,7 @@ def _get_app_or_unit(self) -> Union[Application, Unit]:
raise TLSCertificatesError("Invalid mode")

@property
def private_key(self) -> PrivateKey | None:
def private_key(self) -> Optional[PrivateKey]:
"""Return the private key."""
if not self._private_key_generated():
return None
Expand Down Expand Up @@ -1238,7 +1281,7 @@ def _send_certificate_requests(self):

def get_assigned_certificate(
self, certificate_request: CertificateRequestAttributes
) -> Tuple[ProviderCertificate | None, PrivateKey | None]:
) -> Tuple[Optional[ProviderCertificate], Optional[PrivateKey]]:
"""Get the certificate that was assigned to the given certificate request."""
for requirer_csr in self.get_csrs_from_requirer_relation_data():
if certificate_request == CertificateRequestAttributes.from_csr(
Expand All @@ -1248,7 +1291,9 @@ def get_assigned_certificate(
return self._find_certificate_in_relation_data(requirer_csr), self.private_key
return None, None

def get_assigned_certificates(self) -> Tuple[List[ProviderCertificate], PrivateKey | None]:
def get_assigned_certificates(
self,
) -> Tuple[List[ProviderCertificate], Optional[PrivateKey]]:
"""Get a list of certificates that were assigned to this or app."""
assigned_certificates = []
for requirer_csr in self.get_csrs_from_requirer_relation_data():
Expand All @@ -1259,12 +1304,19 @@ def get_assigned_certificates(self) -> Tuple[List[ProviderCertificate], PrivateK
def _find_certificate_in_relation_data(
self, csr: RequirerCertificateRequest
) -> Optional[ProviderCertificate]:
"""Return the certificate that match the given CSR."""
"""Return the certificate that matches the given CSR, validated against the private key."""
if not self.private_key:
return None
for provider_certificate in self.get_provider_certificates():
if (
provider_certificate.certificate_signing_request == csr.certificate_signing_request
and provider_certificate.certificate.is_ca == csr.is_ca
):
if not provider_certificate.certificate.matches_private_key(self.private_key):
logger.warning(
"Certificate does not match the private key. Ignoring invalid certificate."
)
continue
return provider_certificate
return None

Expand Down
Loading