diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 18c5fd3..fd144cb 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -23,7 +23,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Run tests uses: get-woke/woke-action@v0 with: @@ -34,7 +34,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install dependencies run: python3 -m pip install tox - name: Run linters @@ -45,12 +45,23 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install dependencies run: python3 -m pip install tox - name: Run tests run: tox -e unit + type-check: + name: Type check + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Install dependencies + run: python3 -m pip install tox + - name: Run tests + run: tox -e type + integration-test: strategy: fail-fast: true @@ -62,10 +73,11 @@ jobs: needs: - inclusive-naming-check - lint + - type-check - unit-test steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup operator environment uses: charmed-kubernetes/actions-operator@main with: diff --git a/.gitignore b/.gitignore index 9ceb2e5..182ce6e 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ __pycache__/ .idea .vscode/ version +.ruff_cache/ diff --git a/README.md b/README.md index 2b7b968..5edc69e 100644 --- a/README.md +++ b/README.md @@ -27,10 +27,10 @@ $ juju deploy slurmd --channel edge $ juju deploy slurmdbd --channel edge $ juju deploy mysql --channel 8.0/edge $ juju deploy mysql-router slurmdbd-mysql-router --channel dpe/edge -$ juju integrate slurmctld:slurmd slurmd:slurmd +$ juju integrate slurmctld:slurmd slurmd:slurmctld +$ juju integrate slurmctld:slurmdbd slurmdbd:slurmctld $ juju integrate slurmdbd-mysql-router:backend-database mysql:database $ juju integrate slurmdbd:database slurmdbd-mysql-router:database -$ juju integrate slurmctld:slurmdbd slurmdbd:slurmdbd ``` ## Project & Community diff --git a/charmcraft.yaml b/charmcraft.yaml index 6a6f7d1..47e7e9f 100644 --- a/charmcraft.yaml +++ b/charmcraft.yaml @@ -1,7 +1,40 @@ -# Copyright 2020 Omnivector Solutions, LLC. +# Copyright 2020-2024 Omnivector, LLC. # See LICENSE file for licensing details. - +name: slurmdbd type: charm + +assumes: + - juju + +summary: | + Slurm DBD accounting daemon. + +description: | + This charm provides slurmdbd, munged, and the bindings to other utilities + that make lifecycle operations a breeze. + + slurmdbd provides a secure enterprise-wide interface to a database for + SLURM. This is particularly useful for archiving accounting records. + +links: + contact: https://matrix.to/#/#hpc:ubuntu.com + + source: + - https://github.com/omnivector-solutions/slurmdbd-operator + + issues: + - https://github.com/omnivector-solutions/slurmdbd-operator/issues + +requires: + database: + interface: mysql_client + fluentbit: + interface: fluentbit + +provides: + slurmctld: + interface: slurmdbd + bases: - build-on: - name: ubuntu @@ -10,22 +43,3 @@ bases: - name: ubuntu channel: "22.04" architectures: [amd64] - -parts: - charm: - build-packages: [git] - charm-python-packages: [setuptools] - - # Create a version file and pack it into the charm. This is dynamically generated - # as part of the build process for a charm to ensure that the git revision of the - # charm is always recorded in this version file. - version-file: - plugin: nil - build-packages: - - git - override-build: | - VERSION=$(git -C $CRAFT_PART_SRC/../../charm/src describe --dirty --always) - echo "Setting version to $VERSION" - echo $VERSION > $CRAFT_PART_INSTALL/version - stage: - - version diff --git a/config.yaml b/config.yaml deleted file mode 100644 index e7fc520..0000000 --- a/config.yaml +++ /dev/null @@ -1,23 +0,0 @@ -options: - custom-slurm-repo: - type: string - default: "" - description: > - Use a custom repository for Slurm installation. - - This can be set to the Organization's local mirror/cache of packages and - supersedes the Omnivector repositories. Alternatively, it can be used to - track a `testing` Slurm version, e.g. by setting to - `ppa:omnivector/osd-testing`. - - Note: The configuration `custom-slurm-repo` must be set *before* - deploying the units. Changing this value after deploying the units will - not reinstall Slurm. - slurmdbd-debug: - type: string - default: info - description: > - The level of detail to provide slurmdbd daemon's logs. The default value - is `info`. If the slurmdbd daemon is initiated with `-v` or `--verbose` - options, that debug level will be preserve or restored upon - reconfiguration. diff --git a/lib/charms/data_platform_libs/v0/data_interfaces.py b/lib/charms/data_platform_libs/v0/data_interfaces.py index 8d860a6..3ce69e1 100644 --- a/lib/charms/data_platform_libs/v0/data_interfaces.py +++ b/lib/charms/data_platform_libs/v0/data_interfaces.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Library to manage the relation for the data-platform products. +r"""Library to manage the relation for the data-platform products. This library contains the Requires and Provides classes for handling the relation between an application and multiple managed application supported by the data-team: -MySQL, Postgresql, MongoDB, Redis, and Kakfa. +MySQL, Postgresql, MongoDB, Redis, and Kafka. ### Database (MySQL, Postgresql, MongoDB, and Redis) @@ -144,6 +144,19 @@ def _on_cluster2_database_created(self, event: DatabaseCreatedEvent) -> None: ``` +When it's needed to check whether a plugin (extension) is enabled on the PostgreSQL +charm, you can use the is_postgresql_plugin_enabled method. To use that, you need to +add the following dependency to your charmcraft.yaml file: + +```yaml + +parts: + charm: + charm-binary-python-packages: + - psycopg[binary] + +``` + ### Provider Charm Following an example of using the DatabaseRequestedEvent, in the context of the @@ -277,23 +290,38 @@ def _on_topic_requested(self, event: TopicRequestedEvent): creating a new topic when other information other than a topic name is exchanged in the relation databag. """ -import abc + +import copy import json import logging from abc import ABC, abstractmethod -from collections import namedtuple +from collections import UserDict, namedtuple from datetime import datetime -from typing import List, Optional +from enum import Enum +from typing import ( + Callable, + Dict, + ItemsView, + KeysView, + List, + Optional, + Set, + Tuple, + Union, + ValuesView, +) +from ops import JujuVersion, Model, Secret, SecretInfo, SecretNotFoundError from ops.charm import ( CharmBase, CharmEvents, RelationChangedEvent, + RelationCreatedEvent, RelationEvent, - RelationJoinedEvent, + SecretChangedEvent, ) from ops.framework import EventSource, Object -from ops.model import Relation +from ops.model import Application, ModelError, Relation, Unit # The unique Charmhub library identifier, never change it LIBID = "6c3e6b6680d64e9c89e611d1a15f65be" @@ -303,7 +331,9 @@ def _on_topic_requested(self, event: TopicRequestedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 6 +LIBPATCH = 34 + +PYDEPS = ["ops>=2.0.0"] logger = logging.getLogger(__name__) @@ -316,7 +346,98 @@ def _on_topic_requested(self, event: TopicRequestedEvent): deleted - key that were deleted""" -def diff(event: RelationChangedEvent, bucket: str) -> Diff: +PROV_SECRET_PREFIX = "secret-" +REQ_SECRET_FIELDS = "requested-secrets" +GROUP_MAPPING_FIELD = "secret_group_mapping" +GROUP_SEPARATOR = "@" + + +class SecretGroup(str): + """Secret groups specific type.""" + + +class SecretGroupsAggregate(str): + """Secret groups with option to extend with additional constants.""" + + def __init__(self): + self.USER = SecretGroup("user") + self.TLS = SecretGroup("tls") + self.EXTRA = SecretGroup("extra") + + def __setattr__(self, name, value): + """Setting internal constants.""" + if name in self.__dict__: + raise RuntimeError("Can't set constant!") + else: + super().__setattr__(name, SecretGroup(value)) + + def groups(self) -> list: + """Return the list of stored SecretGroups.""" + return list(self.__dict__.values()) + + def get_group(self, group: str) -> Optional[SecretGroup]: + """If the input str translates to a group name, return that.""" + return SecretGroup(group) if group in self.groups() else None + + +SECRET_GROUPS = SecretGroupsAggregate() + + +class DataInterfacesError(Exception): + """Common ancestor for DataInterfaces related exceptions.""" + + +class SecretError(DataInterfacesError): + """Common ancestor for Secrets related exceptions.""" + + +class SecretAlreadyExistsError(SecretError): + """A secret that was to be added already exists.""" + + +class SecretsUnavailableError(SecretError): + """Secrets aren't yet available for Juju version used.""" + + +class SecretsIllegalUpdateError(SecretError): + """Secrets aren't yet available for Juju version used.""" + + +class IllegalOperationError(DataInterfacesError): + """To be used when an operation is not allowed to be performed.""" + + +def get_encoded_dict( + relation: Relation, member: Union[Unit, Application], field: str +) -> Optional[Dict[str, str]]: + """Retrieve and decode an encoded field from relation data.""" + data = json.loads(relation.data[member].get(field, "{}")) + if isinstance(data, dict): + return data + logger.error("Unexpected datatype for %s instead of dict.", str(data)) + + +def get_encoded_list( + relation: Relation, member: Union[Unit, Application], field: str +) -> Optional[List[str]]: + """Retrieve and decode an encoded field from relation data.""" + data = json.loads(relation.data[member].get(field, "[]")) + if isinstance(data, list): + return data + logger.error("Unexpected datatype for %s instead of list.", str(data)) + + +def set_encoded_field( + relation: Relation, + member: Union[Unit, Application], + field: str, + value: Union[str, list, Dict[str, str]], +) -> None: + """Set an encoded field from relation data.""" + relation.data[member].update({field: json.dumps(value)}) + + +def diff(event: RelationChangedEvent, bucket: Optional[Union[Unit, Application]]) -> Diff: """Retrieves the diff of the data in the relation changed databag. Args: @@ -328,305 +449,1940 @@ def diff(event: RelationChangedEvent, bucket: str) -> Diff: keys from the event relation databag. """ # Retrieve the old data from the data key in the application relation databag. - old_data = json.loads(event.relation.data[bucket].get("data", "{}")) + if not bucket: + return Diff([], [], []) + + old_data = get_encoded_dict(event.relation, bucket, "data") + + if not old_data: + old_data = {} + # Retrieve the new data from the event relation databag. - new_data = { - key: value for key, value in event.relation.data[event.app].items() if key != "data" - } + new_data = ( + {key: value for key, value in event.relation.data[event.app].items() if key != "data"} + if event.app + else {} + ) # These are the keys that were added to the databag and triggered this event. - added = new_data.keys() - old_data.keys() + added = new_data.keys() - old_data.keys() # pyright: ignore [reportAssignmentType] # These are the keys that were removed from the databag and triggered this event. - deleted = old_data.keys() - new_data.keys() + deleted = old_data.keys() - new_data.keys() # pyright: ignore [reportAssignmentType] # These are the keys that already existed in the databag, # but had their values changed. - changed = {key for key in old_data.keys() & new_data.keys() if old_data[key] != new_data[key]} + changed = { + key + for key in old_data.keys() & new_data.keys() # pyright: ignore [reportAssignmentType] + if old_data[key] != new_data[key] # pyright: ignore [reportAssignmentType] + } # Convert the new_data to a serializable format and save it for a next diff check. - event.relation.data[bucket].update({"data": json.dumps(new_data)}) + set_encoded_field(event.relation, bucket, "data", new_data) # Return the diff with all possible changes. return Diff(added, changed, deleted) -# Base DataProvides and DataRequires +def leader_only(f): + """Decorator to ensure that only leader can perform given operation.""" + def wrapper(self, *args, **kwargs): + if self.component == self.local_app and not self.local_unit.is_leader(): + logger.error( + "This operation (%s()) can only be performed by the leader unit", f.__name__ + ) + return + return f(self, *args, **kwargs) -class DataProvides(Object, ABC): - """Base provides-side of the data products relation.""" - - def __init__(self, charm: CharmBase, relation_name: str) -> None: - super().__init__(charm, relation_name) - self.charm = charm - self.local_app = self.charm.model.app - self.local_unit = self.charm.unit - self.relation_name = relation_name - self.framework.observe( - charm.on[relation_name].relation_changed, - self._on_relation_changed, - ) + wrapper.leader_only = True + return wrapper - def _diff(self, event: RelationChangedEvent) -> Diff: - """Retrieves the diff of the data in the relation changed databag. - Args: - event: relation changed event. +def juju_secrets_only(f): + """Decorator to ensure that certain operations would be only executed on Juju3.""" - Returns: - a Diff instance containing the added, deleted and changed - keys from the event relation databag. - """ - return diff(event, self.local_app) + def wrapper(self, *args, **kwargs): + if not self.secrets_enabled: + raise SecretsUnavailableError("Secrets unavailable on current Juju version") + return f(self, *args, **kwargs) - @abstractmethod - def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Event emitted when the relation data has changed.""" - raise NotImplementedError + return wrapper - def fetch_relation_data(self) -> dict: - """Retrieves data from relation. - This function can be used to retrieve data from a relation - in the charm code when outside an event callback. +def dynamic_secrets_only(f): + """Decorator to ensure that certain operations would be only executed when NO static secrets are defined.""" - Returns: - a dict of the values stored in the relation data bag - for all relation instances (indexed by the relation id). - """ - data = {} - for relation in self.relations: - data[relation.id] = { - key: value for key, value in relation.data[relation.app].items() if key != "data" - } - return data + def wrapper(self, *args, **kwargs): + if self.static_secret_fields: + raise IllegalOperationError( + "Unsafe usage of statically and dynamically defined secrets, aborting." + ) + return f(self, *args, **kwargs) - def _update_relation_data(self, relation_id: int, data: dict) -> None: - """Updates a set of key-value pairs in the relation. + return wrapper - This function writes in the application data bag, therefore, - only the leader unit can call it. - Args: - relation_id: the identifier for a particular relation. - data: dict containing the key-value pairs - that should be updated in the relation. - """ - if self.local_unit.is_leader(): - relation = self.charm.model.get_relation(self.relation_name, relation_id) - relation.data[self.local_app].update(data) +def either_static_or_dynamic_secrets(f): + """Decorator to ensure that static and dynamic secrets won't be used in parallel.""" - @property - def relations(self) -> List[Relation]: - """The list of Relation instances associated with this relation_name.""" - return list(self.charm.model.relations[self.relation_name]) + def wrapper(self, *args, **kwargs): + if self.static_secret_fields and set(self.current_secret_fields) - set( + self.static_secret_fields + ): + raise IllegalOperationError( + "Unsafe usage of statically and dynamically defined secrets, aborting." + ) + return f(self, *args, **kwargs) - def set_credentials(self, relation_id: int, username: str, password: str) -> None: - """Set credentials. + return wrapper - This function writes in the application data bag, therefore, - only the leader unit can call it. - Args: - relation_id: the identifier for a particular relation. - username: user that was created. - password: password of the created user. - """ - self._update_relation_data( - relation_id, - { - "username": username, - "password": password, - }, - ) +class Scope(Enum): + """Peer relations scope.""" - def set_tls(self, relation_id: int, tls: str) -> None: - """Set whether TLS is enabled. + APP = "app" + UNIT = "unit" - Args: - relation_id: the identifier for a particular relation. - tls: whether tls is enabled (True or False). - """ - self._update_relation_data(relation_id, {"tls": tls}) - def set_tls_ca(self, relation_id: int, tls_ca: str) -> None: - """Set the TLS CA in the application relation databag. +################################################################################ +# Secrets internal caching +################################################################################ - Args: - relation_id: the identifier for a particular relation. - tls_ca: TLS certification authority. - """ - self._update_relation_data(relation_id, {"tls_ca": tls_ca}) +class CachedSecret: + """Locally cache a secret. -class DataRequires(Object, ABC): - """Requires-side of the relation.""" + The data structure is precisely re-using/simulating as in the actual Secret Storage + """ def __init__( self, - charm, - relation_name: str, - extra_user_roles: str = None, + model: Model, + component: Union[Application, Unit], + label: str, + secret_uri: Optional[str] = None, + legacy_labels: List[str] = [], ): - """Manager of base client relations.""" - super().__init__(charm, relation_name) - self.charm = charm - self.extra_user_roles = extra_user_roles - self.local_app = self.charm.model.app - self.local_unit = self.charm.unit - self.relation_name = relation_name - self.framework.observe( - self.charm.on[relation_name].relation_joined, self._on_relation_joined_event - ) - self.framework.observe( - self.charm.on[relation_name].relation_changed, self._on_relation_changed_event - ) + self._secret_meta = None + self._secret_content = {} + self._secret_uri = secret_uri + self.label = label + self._model = model + self.component = component + self.legacy_labels = legacy_labels + self.current_label = None + + def add_secret( + self, + content: Dict[str, str], + relation: Optional[Relation] = None, + label: Optional[str] = None, + ) -> Secret: + """Create a new secret.""" + if self._secret_uri: + raise SecretAlreadyExistsError( + "Secret is already defined with uri %s", self._secret_uri + ) - @abstractmethod - def _on_relation_joined_event(self, event: RelationJoinedEvent) -> None: - """Event emitted when the application joins the relation.""" - raise NotImplementedError + label = self.label if not label else label - @abstractmethod - def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: - raise NotImplementedError + secret = self.component.add_secret(content, label=label) + if relation and relation.app != self._model.app: + # If it's not a peer relation, grant is to be applied + secret.grant(relation) + self._secret_uri = secret.id + self._secret_meta = secret + return self._secret_meta - def fetch_relation_data(self) -> dict: - """Retrieves data from relation. + @property + def meta(self) -> Optional[Secret]: + """Getting cached secret meta-information.""" + if not self._secret_meta: + if not (self._secret_uri or self.label): + return + + for label in [self.label] + self.legacy_labels: + try: + self._secret_meta = self._model.get_secret(label=label) + except SecretNotFoundError: + pass + else: + if label != self.label: + self.current_label = label + break + + # If still not found, to be checked by URI, to be labelled with the proposed label + if not self._secret_meta and self._secret_uri: + self._secret_meta = self._model.get_secret(id=self._secret_uri, label=self.label) + return self._secret_meta + + def get_content(self) -> Dict[str, str]: + """Getting cached secret content.""" + if not self._secret_content: + if self.meta: + try: + self._secret_content = self.meta.get_content(refresh=True) + except (ValueError, ModelError) as err: + # https://bugs.launchpad.net/juju/+bug/2042596 + # Only triggered when 'refresh' is set + known_model_errors = [ + "ERROR either URI or label should be used for getting an owned secret but not both", + "ERROR secret owner cannot use --refresh", + ] + if isinstance(err, ModelError) and not any( + msg in str(err) for msg in known_model_errors + ): + raise + # Due to: ValueError: Secret owner cannot use refresh=True + self._secret_content = self.meta.get_content() + return self._secret_content + + def _move_to_new_label_if_needed(self): + """Helper function to re-create the secret with a different label.""" + if not self.current_label or not (self.meta and self._secret_meta): + return - This function can be used to retrieve data from a relation - in the charm code when outside an event callback. - Function cannot be used in `*-relation-broken` events and will raise an exception. + # Create a new secret with the new label + old_meta = self._secret_meta + content = self._secret_meta.get_content() - Returns: - a dict of the values stored in the relation data bag - for all relation instances (indexed by the relation ID). - """ - data = {} - for relation in self.relations: - data[relation.id] = { - key: value for key, value in relation.data[relation.app].items() if key != "data" - } - return data + # I wish we could just check if we are the owners of the secret... + try: + self._secret_meta = self.add_secret(content, label=self.label) + except ModelError as err: + if "this unit is not the leader" not in str(err): + raise + old_meta.remove_all_revisions() + + def set_content(self, content: Dict[str, str]) -> None: + """Setting cached secret content.""" + if not self.meta: + return - def _update_relation_data(self, relation_id: int, data: dict) -> None: - """Updates a set of key-value pairs in the relation. + if content: + self._move_to_new_label_if_needed() + self.meta.set_content(content) + self._secret_content = content + else: + self.meta.remove_all_revisions() - This function writes in the application data bag, therefore, - only the leader unit can call it. + def get_info(self) -> Optional[SecretInfo]: + """Wrapper function to apply the corresponding call on the Secret object within CachedSecret if any.""" + if self.meta: + return self.meta.get_info() - Args: - relation_id: the identifier for a particular relation. - data: dict containing the key-value pairs - that should be updated in the relation. - """ - if self.local_unit.is_leader(): - relation = self.charm.model.get_relation(self.relation_name, relation_id) - relation.data[self.local_app].update(data) + def remove(self) -> None: + """Remove secret.""" + if not self.meta: + raise SecretsUnavailableError("Non-existent secret was attempted to be removed.") + try: + self.meta.remove_all_revisions() + except SecretNotFoundError: + pass + self._secret_content = {} + self._secret_meta = None + self._secret_uri = None + + +class SecretCache: + """A data structure storing CachedSecret objects.""" + + def __init__(self, model: Model, component: Union[Application, Unit]): + self._model = model + self.component = component + self._secrets: Dict[str, CachedSecret] = {} + + def get( + self, label: str, uri: Optional[str] = None, legacy_labels: List[str] = [] + ) -> Optional[CachedSecret]: + """Getting a secret from Juju Secret store or cache.""" + if not self._secrets.get(label): + secret = CachedSecret( + self._model, self.component, label, uri, legacy_labels=legacy_labels + ) + if secret.meta: + self._secrets[label] = secret + return self._secrets.get(label) + + def add(self, label: str, content: Dict[str, str], relation: Relation) -> CachedSecret: + """Adding a secret to Juju Secret.""" + if self._secrets.get(label): + raise SecretAlreadyExistsError(f"Secret {label} already exists") + + secret = CachedSecret(self._model, self.component, label) + secret.add_secret(content, relation) + self._secrets[label] = secret + return self._secrets[label] + + def remove(self, label: str) -> None: + """Remove a secret from the cache.""" + if secret := self.get(label): + try: + secret.remove() + self._secrets.pop(label) + except (SecretsUnavailableError, KeyError): + pass + else: + return + logging.debug("Non-existing Juju Secret was attempted to be removed %s", label) - def _diff(self, event: RelationChangedEvent) -> Diff: - """Retrieves the diff of the data in the relation changed databag. - Args: - event: relation changed event. +################################################################################ +# Relation Data base/abstract ancestors (i.e. parent classes) +################################################################################ - Returns: - a Diff instance containing the added, deleted and changed - keys from the event relation databag. - """ - return diff(event, self.local_unit) + +# Base Data + + +class DataDict(UserDict): + """Python Standard Library 'dict' - like representation of Relation Data.""" + + def __init__(self, relation_data: "Data", relation_id: int): + self.relation_data = relation_data + self.relation_id = relation_id + + @property + def data(self) -> Dict[str, str]: + """Return the full content of the Abstract Relation Data dictionary.""" + result = self.relation_data.fetch_my_relation_data([self.relation_id]) + try: + result_remote = self.relation_data.fetch_relation_data([self.relation_id]) + except NotImplementedError: + result_remote = {self.relation_id: {}} + if result: + result_remote[self.relation_id].update(result[self.relation_id]) + return result_remote.get(self.relation_id, {}) + + def __setitem__(self, key: str, item: str) -> None: + """Set an item of the Abstract Relation Data dictionary.""" + self.relation_data.update_relation_data(self.relation_id, {key: item}) + + def __getitem__(self, key: str) -> str: + """Get an item of the Abstract Relation Data dictionary.""" + result = None + + # Avoiding "leader_only" error when cross-charm non-leader unit, not to report useless error + if ( + not hasattr(self.relation_data.fetch_my_relation_field, "leader_only") + or self.relation_data.component != self.relation_data.local_app + or self.relation_data.local_unit.is_leader() + ): + result = self.relation_data.fetch_my_relation_field(self.relation_id, key) + + if not result: + try: + result = self.relation_data.fetch_relation_field(self.relation_id, key) + except NotImplementedError: + pass + + if not result: + raise KeyError + return result + + def __eq__(self, d: dict) -> bool: + """Equality.""" + return self.data == d + + def __repr__(self) -> str: + """String representation Abstract Relation Data dictionary.""" + return repr(self.data) + + def __len__(self) -> int: + """Length of the Abstract Relation Data dictionary.""" + return len(self.data) + + def __delitem__(self, key: str) -> None: + """Delete an item of the Abstract Relation Data dictionary.""" + self.relation_data.delete_relation_data(self.relation_id, [key]) + + def has_key(self, key: str) -> bool: + """Does the key exist in the Abstract Relation Data dictionary?""" + return key in self.data + + def update(self, items: Dict[str, str]): + """Update the Abstract Relation Data dictionary.""" + self.relation_data.update_relation_data(self.relation_id, items) + + def keys(self) -> KeysView[str]: + """Keys of the Abstract Relation Data dictionary.""" + return self.data.keys() + + def values(self) -> ValuesView[str]: + """Values of the Abstract Relation Data dictionary.""" + return self.data.values() + + def items(self) -> ItemsView[str, str]: + """Items of the Abstract Relation Data dictionary.""" + return self.data.items() + + def pop(self, item: str) -> str: + """Pop an item of the Abstract Relation Data dictionary.""" + result = self.relation_data.fetch_my_relation_field(self.relation_id, item) + if not result: + raise KeyError(f"Item {item} doesn't exist.") + self.relation_data.delete_relation_data(self.relation_id, [item]) + return result + + def __contains__(self, item: str) -> bool: + """Does the Abstract Relation Data dictionary contain item?""" + return item in self.data.values() + + def __iter__(self): + """Iterate through the Abstract Relation Data dictionary.""" + return iter(self.data) + + def get(self, key: str, default: Optional[str] = None) -> Optional[str]: + """Safely get an item of the Abstract Relation Data dictionary.""" + try: + if result := self[key]: + return result + except KeyError: + return default + + +class Data(ABC): + """Base relation data mainpulation (abstract) class.""" + + SCOPE = Scope.APP + + # Local map to associate mappings with secrets potentially as a group + SECRET_LABEL_MAP = { + "username": SECRET_GROUPS.USER, + "password": SECRET_GROUPS.USER, + "uris": SECRET_GROUPS.USER, + "tls": SECRET_GROUPS.TLS, + "tls-ca": SECRET_GROUPS.TLS, + } + + def __init__( + self, + model: Model, + relation_name: str, + ) -> None: + self._model = model + self.local_app = self._model.app + self.local_unit = self._model.unit + self.relation_name = relation_name + self._jujuversion = None + self.component = self.local_app if self.SCOPE == Scope.APP else self.local_unit + self.secrets = SecretCache(self._model, self.component) + self.data_component = None @property def relations(self) -> List[Relation]: """The list of Relation instances associated with this relation_name.""" return [ relation - for relation in self.charm.model.relations[self.relation_name] + for relation in self._model.relations[self.relation_name] if self._is_relation_active(relation) ] + @property + def secrets_enabled(self): + """Is this Juju version allowing for Secrets usage?""" + if not self._jujuversion: + self._jujuversion = JujuVersion.from_environ() + return self._jujuversion.has_secrets + + @property + def secret_label_map(self): + """Exposing secret-label map via a property -- could be overridden in descendants!""" + return self.SECRET_LABEL_MAP + + # Mandatory overrides for internal/helper methods + + @abstractmethod + def _get_relation_secret( + self, relation_id: int, group_mapping: SecretGroup, relation_name: Optional[str] = None + ) -> Optional[CachedSecret]: + """Retrieve a Juju Secret that's been stored in the relation databag.""" + raise NotImplementedError + + @abstractmethod + def _fetch_specific_relation_data( + self, relation: Relation, fields: Optional[List[str]] + ) -> Dict[str, str]: + """Fetch data available (directily or indirectly -- i.e. secrets) from the relation.""" + raise NotImplementedError + + @abstractmethod + def _fetch_my_specific_relation_data( + self, relation: Relation, fields: Optional[List[str]] + ) -> Dict[str, str]: + """Fetch data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" + raise NotImplementedError + + @abstractmethod + def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> None: + """Update data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" + raise NotImplementedError + + @abstractmethod + def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: + """Delete data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" + raise NotImplementedError + + # Internal helper methods + @staticmethod def _is_relation_active(relation: Relation): + """Whether the relation is active based on contained data.""" try: _ = repr(relation.data) return True - except RuntimeError: + except (RuntimeError, ModelError): return False @staticmethod - def _is_resource_created_for_relation(relation: Relation): - return ( - "username" in relation.data[relation.app] and "password" in relation.data[relation.app] - ) + def _is_secret_field(field: str) -> bool: + """Is the field in question a secret reference (URI) field or not?""" + return field.startswith(PROV_SECRET_PREFIX) - def is_resource_created(self, relation_id: Optional[int] = None) -> bool: - """Check if the resource has been created. + @staticmethod + def _generate_secret_label( + relation_name: str, relation_id: int, group_mapping: SecretGroup + ) -> str: + """Generate unique group_mappings for secrets within a relation context.""" + return f"{relation_name}.{relation_id}.{group_mapping}.secret" - This function can be used to check if the Provider answered with data in the charm code - when outside an event callback. + def _generate_secret_field_name(self, group_mapping: SecretGroup) -> str: + """Generate unique group_mappings for secrets within a relation context.""" + return f"{PROV_SECRET_PREFIX}{group_mapping}" - Args: - relation_id (int, optional): When provided the check is done only for the relation id - provided, otherwise the check is done for all relations + def _relation_from_secret_label(self, secret_label: str) -> Optional[Relation]: + """Retrieve the relation that belongs to a secret label.""" + contents = secret_label.split(".") - Returns: - True or False + if not (contents and len(contents) >= 3): + return - Raises: - IndexError: If relation_id is provided but that relation does not exist - """ - if relation_id is not None: - try: - relation = [relation for relation in self.relations if relation.id == relation_id][ - 0 - ] - return self._is_resource_created_for_relation(relation) - except IndexError: - raise IndexError(f"relation id {relation_id} cannot be accessed") - else: - return ( - all( - [ - self._is_resource_created_for_relation(relation) - for relation in self.relations - ] - ) - if self.relations - else False - ) + contents.pop() # ".secret" at the end + contents.pop() # Group mapping + relation_id = contents.pop() + try: + relation_id = int(relation_id) + except ValueError: + return + # In case '.' character appeared in relation name + relation_name = ".".join(contents) -# General events + try: + return self.get_relation(relation_name, relation_id) + except ModelError: + return + def _group_secret_fields(self, secret_fields: List[str]) -> Dict[SecretGroup, List[str]]: + """Helper function to arrange secret mappings under their group. -class ExtraRoleEvent(RelationEvent): - """Base class for data events.""" + NOTE: All unrecognized items end up in the 'extra' secret bucket. + Make sure only secret fields are passed! + """ + secret_fieldnames_grouped = {} + for key in secret_fields: + if group := self.secret_label_map.get(key): + secret_fieldnames_grouped.setdefault(group, []).append(key) + else: + secret_fieldnames_grouped.setdefault(SECRET_GROUPS.EXTRA, []).append(key) + return secret_fieldnames_grouped + + def _get_group_secret_contents( + self, + relation: Relation, + group: SecretGroup, + secret_fields: Union[Set[str], List[str]] = [], + ) -> Dict[str, str]: + """Helper function to retrieve collective, requested contents of a secret.""" + if (secret := self._get_relation_secret(relation.id, group)) and ( + secret_data := secret.get_content() + ): + return { + k: v for k, v in secret_data.items() if not secret_fields or k in secret_fields + } + return {} + + def _content_for_secret_group( + self, content: Dict[str, str], secret_fields: Set[str], group_mapping: SecretGroup + ) -> Dict[str, str]: + """Select : pairs from input, that belong to this particular Secret group.""" + if group_mapping == SECRET_GROUPS.EXTRA: + return { + k: v + for k, v in content.items() + if k in secret_fields and k not in self.secret_label_map.keys() + } - @property - def extra_user_roles(self) -> Optional[str]: - """Returns the extra user roles that were requested.""" - return self.relation.data[self.relation.app].get("extra-user-roles") + return { + k: v + for k, v in content.items() + if k in secret_fields and self.secret_label_map.get(k) == group_mapping + } + + @juju_secrets_only + def _get_relation_secret_data( + self, relation_id: int, group_mapping: SecretGroup, relation_name: Optional[str] = None + ) -> Optional[Dict[str, str]]: + """Retrieve contents of a Juju Secret that's been stored in the relation databag.""" + secret = self._get_relation_secret(relation_id, group_mapping, relation_name) + if secret: + return secret.get_content() + + # Core operations on Relation Fields manipulations (regardless whether the field is in the databag or in a secret) + # Internal functions to be called directly from transparent public interface functions (+closely related helpers) + + def _process_secret_fields( + self, + relation: Relation, + req_secret_fields: Optional[List[str]], + impacted_rel_fields: List[str], + operation: Callable, + *args, + **kwargs, + ) -> Tuple[Dict[str, str], Set[str]]: + """Isolate target secret fields of manipulation, and execute requested operation by Secret Group.""" + result = {} + + # If the relation started on a databag, we just stay on the databag + # (Rolling upgrades may result in a relation starting on databag, getting secrets enabled on-the-fly) + # self.local_app is sufficient to check (ignored if Requires, never has secrets -- works if Provider) + fallback_to_databag = ( + req_secret_fields + and (self.local_unit == self._model.unit and self.local_unit.is_leader()) + and set(req_secret_fields) & set(relation.data[self.component]) + ) + + normal_fields = set(impacted_rel_fields) + if req_secret_fields and self.secrets_enabled and not fallback_to_databag: + normal_fields = normal_fields - set(req_secret_fields) + secret_fields = set(impacted_rel_fields) - set(normal_fields) + + secret_fieldnames_grouped = self._group_secret_fields(list(secret_fields)) + + for group in secret_fieldnames_grouped: + # operation() should return nothing when all goes well + if group_result := operation(relation, group, secret_fields, *args, **kwargs): + # If "meaningful" data was returned, we take it. (Some 'operation'-s only return success/failure.) + if isinstance(group_result, dict): + result.update(group_result) + else: + # If it wasn't found as a secret, let's give it a 2nd chance as "normal" field + # Needed when Juju3 Requires meets Juju2 Provider + normal_fields |= set(secret_fieldnames_grouped[group]) + return (result, normal_fields) + + def _fetch_relation_data_without_secrets( + self, component: Union[Application, Unit], relation: Relation, fields: Optional[List[str]] + ) -> Dict[str, str]: + """Fetching databag contents when no secrets are involved. + + Since the Provider's databag is the only one holding secrest, we can apply + a simplified workflow to read the Require's side's databag. + This is used typically when the Provider side wants to read the Requires side's data, + or when the Requires side may want to read its own data. + """ + if component not in relation.data or not relation.data[component]: + return {} + + if fields: + return { + k: relation.data[component][k] for k in fields if k in relation.data[component] + } + else: + return dict(relation.data[component]) + + def _fetch_relation_data_with_secrets( + self, + component: Union[Application, Unit], + req_secret_fields: Optional[List[str]], + relation: Relation, + fields: Optional[List[str]] = None, + ) -> Dict[str, str]: + """Fetching databag contents when secrets may be involved. + + This function has internal logic to resolve if a requested field may be "hidden" + within a Relation Secret, or directly available as a databag field. Typically + used to read the Provider side's databag (eigher by the Requires side, or by + Provider side itself). + """ + result = {} + normal_fields = [] + + if not fields: + if component not in relation.data: + return {} + + all_fields = list(relation.data[component].keys()) + normal_fields = [field for field in all_fields if not self._is_secret_field(field)] + fields = normal_fields + req_secret_fields if req_secret_fields else normal_fields + + if fields: + result, normal_fields = self._process_secret_fields( + relation, req_secret_fields, fields, self._get_group_secret_contents + ) + + # Processing "normal" fields. May include leftover from what we couldn't retrieve as a secret. + # (Typically when Juju3 Requires meets Juju2 Provider) + if normal_fields: + result.update( + self._fetch_relation_data_without_secrets(component, relation, list(normal_fields)) + ) + return result + + def _update_relation_data_without_secrets( + self, component: Union[Application, Unit], relation: Relation, data: Dict[str, str] + ) -> None: + """Updating databag contents when no secrets are involved.""" + if component not in relation.data or relation.data[component] is None: + return + + if relation: + relation.data[component].update(data) + + def _delete_relation_data_without_secrets( + self, component: Union[Application, Unit], relation: Relation, fields: List[str] + ) -> None: + """Remove databag fields 'fields' from Relation.""" + if component not in relation.data or relation.data[component] is None: + return + + for field in fields: + try: + relation.data[component].pop(field) + except KeyError: + logger.debug( + "Non-existing field '%s' was attempted to be removed from the databag (relation ID: %s)", + str(field), + str(relation.id), + ) + pass + + # Public interface methods + # Handling Relation Fields seamlessly, regardless if in databag or a Juju Secret + + def as_dict(self, relation_id: int) -> UserDict: + """Dict behavior representation of the Abstract Data.""" + return DataDict(self, relation_id) + + def get_relation(self, relation_name, relation_id) -> Relation: + """Safe way of retrieving a relation.""" + relation = self._model.get_relation(relation_name, relation_id) + + if not relation: + raise DataInterfacesError( + "Relation %s %s couldn't be retrieved", relation_name, relation_id + ) + + return relation + + def fetch_relation_data( + self, + relation_ids: Optional[List[int]] = None, + fields: Optional[List[str]] = None, + relation_name: Optional[str] = None, + ) -> Dict[int, Dict[str, str]]: + """Retrieves data from relation. + + This function can be used to retrieve data from a relation + in the charm code when outside an event callback. + Function cannot be used in `*-relation-broken` events and will raise an exception. + + Returns: + a dict of the values stored in the relation data bag + for all relation instances (indexed by the relation ID). + """ + if not relation_name: + relation_name = self.relation_name + + relations = [] + if relation_ids: + relations = [ + self.get_relation(relation_name, relation_id) for relation_id in relation_ids + ] + else: + relations = self.relations + + data = {} + for relation in relations: + if not relation_ids or (relation_ids and relation.id in relation_ids): + data[relation.id] = self._fetch_specific_relation_data(relation, fields) + return data + + def fetch_relation_field( + self, relation_id: int, field: str, relation_name: Optional[str] = None + ) -> Optional[str]: + """Get a single field from the relation data.""" + return ( + self.fetch_relation_data([relation_id], [field], relation_name) + .get(relation_id, {}) + .get(field) + ) + + def fetch_my_relation_data( + self, + relation_ids: Optional[List[int]] = None, + fields: Optional[List[str]] = None, + relation_name: Optional[str] = None, + ) -> Optional[Dict[int, Dict[str, str]]]: + """Fetch data of the 'owner' (or 'this app') side of the relation. + + NOTE: Since only the leader can read the relation's 'this_app'-side + Application databag, the functionality is limited to leaders + """ + if not relation_name: + relation_name = self.relation_name + + relations = [] + if relation_ids: + relations = [ + self.get_relation(relation_name, relation_id) for relation_id in relation_ids + ] + else: + relations = self.relations + + data = {} + for relation in relations: + if not relation_ids or relation.id in relation_ids: + data[relation.id] = self._fetch_my_specific_relation_data(relation, fields) + return data + + def fetch_my_relation_field( + self, relation_id: int, field: str, relation_name: Optional[str] = None + ) -> Optional[str]: + """Get a single field from the relation data -- owner side. + + NOTE: Since only the leader can read the relation's 'this_app'-side + Application databag, the functionality is limited to leaders + """ + if relation_data := self.fetch_my_relation_data([relation_id], [field], relation_name): + return relation_data.get(relation_id, {}).get(field) + + @leader_only + def update_relation_data(self, relation_id: int, data: dict) -> None: + """Update the data within the relation.""" + relation_name = self.relation_name + relation = self.get_relation(relation_name, relation_id) + return self._update_relation_data(relation, data) + + @leader_only + def delete_relation_data(self, relation_id: int, fields: List[str]) -> None: + """Remove field from the relation.""" + relation_name = self.relation_name + relation = self.get_relation(relation_name, relation_id) + return self._delete_relation_data(relation, fields) + + +class EventHandlers(Object): + """Requires-side of the relation.""" + + def __init__(self, charm: CharmBase, relation_data: Data, unique_key: str = ""): + """Manager of base client relations.""" + if not unique_key: + unique_key = relation_data.relation_name + super().__init__(charm, unique_key) + + self.charm = charm + self.relation_data = relation_data + + self.framework.observe( + charm.on[self.relation_data.relation_name].relation_changed, + self._on_relation_changed_event, + ) + + def _diff(self, event: RelationChangedEvent) -> Diff: + """Retrieves the diff of the data in the relation changed databag. + + Args: + event: relation changed event. + + Returns: + a Diff instance containing the added, deleted and changed + keys from the event relation databag. + """ + return diff(event, self.relation_data.data_component) + + @abstractmethod + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation data has changed.""" + raise NotImplementedError + + +# Base ProviderData and RequiresData + + +class ProviderData(Data): + """Base provides-side of the data products relation.""" + + def __init__( + self, + model: Model, + relation_name: str, + ) -> None: + super().__init__(model, relation_name) + self.data_component = self.local_app + + # Private methods handling secrets + + @juju_secrets_only + def _add_relation_secret( + self, + relation: Relation, + group_mapping: SecretGroup, + secret_fields: Set[str], + data: Dict[str, str], + uri_to_databag=True, + ) -> bool: + """Add a new Juju Secret that will be registered in the relation databag.""" + secret_field = self._generate_secret_field_name(group_mapping) + if uri_to_databag and relation.data[self.component].get(secret_field): + logging.error("Secret for relation %s already exists, not adding again", relation.id) + return False + + content = self._content_for_secret_group(data, secret_fields, group_mapping) + + label = self._generate_secret_label(self.relation_name, relation.id, group_mapping) + secret = self.secrets.add(label, content, relation) + + # According to lint we may not have a Secret ID + if uri_to_databag and secret.meta and secret.meta.id: + relation.data[self.component][secret_field] = secret.meta.id + + # Return the content that was added + return True + + @juju_secrets_only + def _update_relation_secret( + self, + relation: Relation, + group_mapping: SecretGroup, + secret_fields: Set[str], + data: Dict[str, str], + ) -> bool: + """Update the contents of an existing Juju Secret, referred in the relation databag.""" + secret = self._get_relation_secret(relation.id, group_mapping) + + if not secret: + logging.error("Can't update secret for relation %s", relation.id) + return False + + content = self._content_for_secret_group(data, secret_fields, group_mapping) + + old_content = secret.get_content() + full_content = copy.deepcopy(old_content) + full_content.update(content) + secret.set_content(full_content) + + # Return True on success + return True + + def _add_or_update_relation_secrets( + self, + relation: Relation, + group: SecretGroup, + secret_fields: Set[str], + data: Dict[str, str], + uri_to_databag=True, + ) -> bool: + """Update contents for Secret group. If the Secret doesn't exist, create it.""" + if self._get_relation_secret(relation.id, group): + return self._update_relation_secret(relation, group, secret_fields, data) + else: + return self._add_relation_secret(relation, group, secret_fields, data, uri_to_databag) + + @juju_secrets_only + def _delete_relation_secret( + self, relation: Relation, group: SecretGroup, secret_fields: List[str], fields: List[str] + ) -> bool: + """Update the contents of an existing Juju Secret, referred in the relation databag.""" + secret = self._get_relation_secret(relation.id, group) + + if not secret: + logging.error("Can't delete secret for relation %s", str(relation.id)) + return False + + old_content = secret.get_content() + new_content = copy.deepcopy(old_content) + for field in fields: + try: + new_content.pop(field) + except KeyError: + logging.debug( + "Non-existing secret was attempted to be removed %s, %s", + str(relation.id), + str(field), + ) + return False + + # Remove secret from the relation if it's fully gone + if not new_content: + field = self._generate_secret_field_name(group) + try: + relation.data[self.component].pop(field) + except KeyError: + pass + label = self._generate_secret_label(self.relation_name, relation.id, group) + self.secrets.remove(label) + else: + secret.set_content(new_content) + + # Return the content that was removed + return True + + # Mandatory internal overrides + + @juju_secrets_only + def _get_relation_secret( + self, relation_id: int, group_mapping: SecretGroup, relation_name: Optional[str] = None + ) -> Optional[CachedSecret]: + """Retrieve a Juju Secret that's been stored in the relation databag.""" + if not relation_name: + relation_name = self.relation_name + + label = self._generate_secret_label(relation_name, relation_id, group_mapping) + if secret := self.secrets.get(label): + return secret + + relation = self._model.get_relation(relation_name, relation_id) + if not relation: + return + + secret_field = self._generate_secret_field_name(group_mapping) + if secret_uri := relation.data[self.local_app].get(secret_field): + return self.secrets.get(label, secret_uri) + + def _fetch_specific_relation_data( + self, relation: Relation, fields: Optional[List[str]] + ) -> Dict[str, str]: + """Fetching relation data for Provider. + + NOTE: Since all secret fields are in the Provider side of the databag, we don't need to worry about that + """ + if not relation.app: + return {} + + return self._fetch_relation_data_without_secrets(relation.app, relation, fields) + + def _fetch_my_specific_relation_data( + self, relation: Relation, fields: Optional[List[str]] + ) -> dict: + """Fetching our own relation data.""" + secret_fields = None + if relation.app: + secret_fields = get_encoded_list(relation, relation.app, REQ_SECRET_FIELDS) + + return self._fetch_relation_data_with_secrets( + self.local_app, + secret_fields, + relation, + fields, + ) + + def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> None: + """Set values for fields not caring whether it's a secret or not.""" + req_secret_fields = [] + if relation.app: + req_secret_fields = get_encoded_list(relation, relation.app, REQ_SECRET_FIELDS) + + _, normal_fields = self._process_secret_fields( + relation, + req_secret_fields, + list(data), + self._add_or_update_relation_secrets, + data=data, + ) + + normal_content = {k: v for k, v in data.items() if k in normal_fields} + self._update_relation_data_without_secrets(self.local_app, relation, normal_content) + + def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: + """Delete fields from the Relation not caring whether it's a secret or not.""" + req_secret_fields = [] + if relation.app: + req_secret_fields = get_encoded_list(relation, relation.app, REQ_SECRET_FIELDS) + + _, normal_fields = self._process_secret_fields( + relation, req_secret_fields, fields, self._delete_relation_secret, fields=fields + ) + self._delete_relation_data_without_secrets(self.local_app, relation, list(normal_fields)) + + # Public methods - "native" + + def set_credentials(self, relation_id: int, username: str, password: str) -> None: + """Set credentials. + + This function writes in the application data bag, therefore, + only the leader unit can call it. + + Args: + relation_id: the identifier for a particular relation. + username: user that was created. + password: password of the created user. + """ + self.update_relation_data(relation_id, {"username": username, "password": password}) + + def set_tls(self, relation_id: int, tls: str) -> None: + """Set whether TLS is enabled. + + Args: + relation_id: the identifier for a particular relation. + tls: whether tls is enabled (True or False). + """ + self.update_relation_data(relation_id, {"tls": tls}) + + def set_tls_ca(self, relation_id: int, tls_ca: str) -> None: + """Set the TLS CA in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + tls_ca: TLS certification authority. + """ + self.update_relation_data(relation_id, {"tls-ca": tls_ca}) + + # Public functions -- inherited + + fetch_my_relation_data = leader_only(Data.fetch_my_relation_data) + fetch_my_relation_field = leader_only(Data.fetch_my_relation_field) + + +class RequirerData(Data): + """Requirer-side of the relation.""" + + SECRET_FIELDS = ["username", "password", "tls", "tls-ca", "uris"] + + def __init__( + self, + model, + relation_name: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + ): + """Manager of base client relations.""" + super().__init__(model, relation_name) + self.extra_user_roles = extra_user_roles + self._secret_fields = list(self.SECRET_FIELDS) + if additional_secret_fields: + self._secret_fields += additional_secret_fields + self.data_component = self.local_unit + + @property + def secret_fields(self) -> Optional[List[str]]: + """Local access to secrets field, in case they are being used.""" + if self.secrets_enabled: + return self._secret_fields + + # Internal helper functions + + def _register_secret_to_relation( + self, relation_name: str, relation_id: int, secret_id: str, group: SecretGroup + ): + """Fetch secrets and apply local label on them. + + [MAGIC HERE] + If we fetch a secret using get_secret(id=, label=), + then will be "stuck" on the Secret object, whenever it may + appear (i.e. as an event attribute, or fetched manually) on future occasions. + + This will allow us to uniquely identify the secret on Provider side (typically on + 'secret-changed' events), and map it to the corresponding relation. + """ + label = self._generate_secret_label(relation_name, relation_id, group) + + # Fetchin the Secret's meta information ensuring that it's locally getting registered with + CachedSecret(self._model, self.component, label, secret_id).meta + + def _register_secrets_to_relation(self, relation: Relation, params_name_list: List[str]): + """Make sure that secrets of the provided list are locally 'registered' from the databag. + + More on 'locally registered' magic is described in _register_secret_to_relation() method + """ + if not relation.app: + return + + for group in SECRET_GROUPS.groups(): + secret_field = self._generate_secret_field_name(group) + if secret_field in params_name_list: + if secret_uri := relation.data[relation.app].get(secret_field): + self._register_secret_to_relation( + relation.name, relation.id, secret_uri, group + ) + + def _is_resource_created_for_relation(self, relation: Relation) -> bool: + if not relation.app: + return False + + data = self.fetch_relation_data([relation.id], ["username", "password"]).get( + relation.id, {} + ) + return bool(data.get("username")) and bool(data.get("password")) + + def is_resource_created(self, relation_id: Optional[int] = None) -> bool: + """Check if the resource has been created. + + This function can be used to check if the Provider answered with data in the charm code + when outside an event callback. + + Args: + relation_id (int, optional): When provided the check is done only for the relation id + provided, otherwise the check is done for all relations + + Returns: + True or False + + Raises: + IndexError: If relation_id is provided but that relation does not exist + """ + if relation_id is not None: + try: + relation = [relation for relation in self.relations if relation.id == relation_id][ + 0 + ] + return self._is_resource_created_for_relation(relation) + except IndexError: + raise IndexError(f"relation id {relation_id} cannot be accessed") + else: + return ( + all( + self._is_resource_created_for_relation(relation) for relation in self.relations + ) + if self.relations + else False + ) + + # Mandatory internal overrides + + @juju_secrets_only + def _get_relation_secret( + self, relation_id: int, group: SecretGroup, relation_name: Optional[str] = None + ) -> Optional[CachedSecret]: + """Retrieve a Juju Secret that's been stored in the relation databag.""" + if not relation_name: + relation_name = self.relation_name + + label = self._generate_secret_label(relation_name, relation_id, group) + return self.secrets.get(label) + + def _fetch_specific_relation_data( + self, relation, fields: Optional[List[str]] = None + ) -> Dict[str, str]: + """Fetching Requirer data -- that may include secrets.""" + if not relation.app: + return {} + return self._fetch_relation_data_with_secrets( + relation.app, self.secret_fields, relation, fields + ) + + def _fetch_my_specific_relation_data(self, relation, fields: Optional[List[str]]) -> dict: + """Fetching our own relation data.""" + return self._fetch_relation_data_without_secrets(self.local_app, relation, fields) + + def _update_relation_data(self, relation: Relation, data: dict) -> None: + """Updates a set of key-value pairs in the relation. + + This function writes in the application data bag, therefore, + only the leader unit can call it. + + Args: + relation: the particular relation. + data: dict containing the key-value pairs + that should be updated in the relation. + """ + return self._update_relation_data_without_secrets(self.local_app, relation, data) + + def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: + """Deletes a set of fields from the relation. + + This function writes in the application data bag, therefore, + only the leader unit can call it. + + Args: + relation: the particular relation. + fields: list containing the field names that should be removed from the relation. + """ + return self._delete_relation_data_without_secrets(self.local_app, relation, fields) + + # Public functions -- inherited + + fetch_my_relation_data = leader_only(Data.fetch_my_relation_data) + fetch_my_relation_field = leader_only(Data.fetch_my_relation_field) + + +class RequirerEventHandlers(EventHandlers): + """Requires-side of the relation.""" + + def __init__(self, charm: CharmBase, relation_data: RequirerData, unique_key: str = ""): + """Manager of base client relations.""" + super().__init__(charm, relation_data, unique_key) + + self.framework.observe( + self.charm.on[relation_data.relation_name].relation_created, + self._on_relation_created_event, + ) + self.framework.observe( + charm.on.secret_changed, + self._on_secret_changed_event, + ) + + # Event handlers + + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: + """Event emitted when the relation is created.""" + if not self.relation_data.local_unit.is_leader(): + return + + if self.relation_data.secret_fields: # pyright: ignore [reportAttributeAccessIssue] + set_encoded_field( + event.relation, + self.relation_data.component, + REQ_SECRET_FIELDS, + self.relation_data.secret_fields, # pyright: ignore [reportAttributeAccessIssue] + ) + + @abstractmethod + def _on_secret_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation data has changed.""" + raise NotImplementedError + + +################################################################################ +# Peer Relation Data +################################################################################ + + +class DataPeerData(RequirerData, ProviderData): + """Represents peer relations data.""" + + SECRET_FIELDS = [] + SECRET_FIELD_NAME = "internal_secret" + SECRET_LABEL_MAP = {} + + def __init__( + self, + model, + relation_name: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + additional_secret_group_mapping: Dict[str, str] = {}, + secret_field_name: Optional[str] = None, + deleted_label: Optional[str] = None, + ): + """Manager of base client relations.""" + RequirerData.__init__( + self, + model, + relation_name, + extra_user_roles, + additional_secret_fields, + ) + self.secret_field_name = secret_field_name if secret_field_name else self.SECRET_FIELD_NAME + self.deleted_label = deleted_label + self._secret_label_map = {} + # Secrets that are being dynamically added within the scope of this event handler run + self._new_secrets = [] + self._additional_secret_group_mapping = additional_secret_group_mapping + + for group, fields in additional_secret_group_mapping.items(): + if group not in SECRET_GROUPS.groups(): + setattr(SECRET_GROUPS, group, group) + for field in fields: + secret_group = SECRET_GROUPS.get_group(group) + internal_field = self._field_to_internal_name(field, secret_group) + self._secret_label_map.setdefault(group, []).append(internal_field) + self._secret_fields.append(internal_field) + + @property + def scope(self) -> Optional[Scope]: + """Turn component information into Scope.""" + if isinstance(self.component, Application): + return Scope.APP + if isinstance(self.component, Unit): + return Scope.UNIT + + @property + def secret_label_map(self) -> Dict[str, str]: + """Property storing secret mappings.""" + return self._secret_label_map + + @property + def static_secret_fields(self) -> List[str]: + """Re-definition of the property in a way that dynamically extended list is retrieved.""" + return self._secret_fields + + @property + def secret_fields(self) -> List[str]: + """Re-definition of the property in a way that dynamically extended list is retrieved.""" + return ( + self.static_secret_fields if self.static_secret_fields else self.current_secret_fields + ) + + @property + def current_secret_fields(self) -> List[str]: + """Helper method to get all currently existing secret fields (added statically or dynamically).""" + if not self.secrets_enabled: + return [] + + if len(self._model.relations[self.relation_name]) > 1: + raise ValueError(f"More than one peer relation on {self.relation_name}") + + relation = self._model.relations[self.relation_name][0] + fields = [] + + ignores = [SECRET_GROUPS.get_group("user"), SECRET_GROUPS.get_group("tls")] + for group in SECRET_GROUPS.groups(): + if group in ignores: + continue + if content := self._get_group_secret_contents(relation, group): + fields += list(content.keys()) + return list(set(fields) | set(self._new_secrets)) + + @dynamic_secrets_only + def set_secret( + self, + relation_id: int, + field: str, + value: str, + group_mapping: Optional[SecretGroup] = None, + ) -> None: + """Public interface method to add a Relation Data field specifically as a Juju Secret. + + Args: + relation_id: ID of the relation + field: The secret field that is to be added + value: The string value of the secret + group_mapping: The name of the "secret group", in case the field is to be added to an existing secret + """ + full_field = self._field_to_internal_name(field, group_mapping) + if self.secrets_enabled and full_field not in self.current_secret_fields: + self._new_secrets.append(full_field) + if self._no_group_with_databag(field, full_field): + self.update_relation_data(relation_id, {full_field: value}) + + # Unlike for set_secret(), there's no harm using this operation with static secrets + # The restricion is only added to keep the concept clear + @dynamic_secrets_only + def get_secret( + self, + relation_id: int, + field: str, + group_mapping: Optional[SecretGroup] = None, + ) -> Optional[str]: + """Public interface method to fetch secrets only.""" + full_field = self._field_to_internal_name(field, group_mapping) + if ( + self.secrets_enabled + and full_field not in self.current_secret_fields + and field not in self.current_secret_fields + ): + return + if self._no_group_with_databag(field, full_field): + return self.fetch_my_relation_field(relation_id, full_field) + + @dynamic_secrets_only + def delete_secret( + self, + relation_id: int, + field: str, + group_mapping: Optional[SecretGroup] = None, + ) -> Optional[str]: + """Public interface method to delete secrets only.""" + full_field = self._field_to_internal_name(field, group_mapping) + if self.secrets_enabled and full_field not in self.current_secret_fields: + logger.warning(f"Secret {field} from group {group_mapping} was not found") + return + if self._no_group_with_databag(field, full_field): + self.delete_relation_data(relation_id, [full_field]) + + # Helpers + + @staticmethod + def _field_to_internal_name(field: str, group: Optional[SecretGroup]) -> str: + if not group or group == SECRET_GROUPS.EXTRA: + return field + return f"{field}{GROUP_SEPARATOR}{group}" + + @staticmethod + def _internal_name_to_field(name: str) -> Tuple[str, SecretGroup]: + parts = name.split(GROUP_SEPARATOR) + if not len(parts) > 1: + return (parts[0], SECRET_GROUPS.EXTRA) + secret_group = SECRET_GROUPS.get_group(parts[1]) + if not secret_group: + raise ValueError(f"Invalid secret field {name}") + return (parts[0], secret_group) + + def _group_secret_fields(self, secret_fields: List[str]) -> Dict[SecretGroup, List[str]]: + """Helper function to arrange secret mappings under their group. + + NOTE: All unrecognized items end up in the 'extra' secret bucket. + Make sure only secret fields are passed! + """ + secret_fieldnames_grouped = {} + for key in secret_fields: + field, group = self._internal_name_to_field(key) + secret_fieldnames_grouped.setdefault(group, []).append(field) + return secret_fieldnames_grouped + + def _content_for_secret_group( + self, content: Dict[str, str], secret_fields: Set[str], group_mapping: SecretGroup + ) -> Dict[str, str]: + """Select : pairs from input, that belong to this particular Secret group.""" + if group_mapping == SECRET_GROUPS.EXTRA: + return {k: v for k, v in content.items() if k in self.secret_fields} + return { + self._internal_name_to_field(k)[0]: v + for k, v in content.items() + if k in self.secret_fields + } + + # Backwards compatibility + + def _check_deleted_label(self, relation, fields) -> None: + """Helper function for legacy behavior.""" + current_data = self.fetch_my_relation_data([relation.id], fields) + if current_data is not None: + # Check if the secret we wanna delete actually exists + # Given the "deleted label", here we can't rely on the default mechanism (i.e. 'key not found') + if non_existent := (set(fields) & set(self.secret_fields)) - set( + current_data.get(relation.id, []) + ): + logger.debug( + "Non-existing secret %s was attempted to be removed.", + ", ".join(non_existent), + ) + + def _remove_secret_from_databag(self, relation, fields: List[str]) -> None: + """For Rolling Upgrades -- when moving from databag to secrets usage. + + Practically what happens here is to remove stuff from the databag that is + to be stored in secrets. + """ + if not self.secret_fields: + return + + secret_fields_passed = set(self.secret_fields) & set(fields) + for field in secret_fields_passed: + if self._fetch_relation_data_without_secrets(self.component, relation, [field]): + self._delete_relation_data_without_secrets(self.component, relation, [field]) + + def _remove_secret_field_name_from_databag(self, relation) -> None: + """Making sure that the old databag URI is gone. + + This action should not be executed more than once. + """ + # Nothing to do if 'internal-secret' is not in the databag + if not (relation.data[self.component].get(self._generate_secret_field_name())): + return + + # Making sure that the secret receives its label + # (This should have happened by the time we get here, rather an extra security measure.) + secret = self._get_relation_secret(relation.id) + + # Either app scope secret with leader executing, or unit scope secret + leader_or_unit_scope = self.component != self.local_app or self.local_unit.is_leader() + if secret and leader_or_unit_scope: + # Databag reference to the secret URI can be removed, now that it's labelled + relation.data[self.component].pop(self._generate_secret_field_name(), None) + + def _previous_labels(self) -> List[str]: + """Generator for legacy secret label names, for backwards compatibility.""" + result = [] + members = [self._model.app.name] + if self.scope: + members.append(self.scope.value) + result.append(f"{'.'.join(members)}") + return result + + def _no_group_with_databag(self, field: str, full_field: str) -> bool: + """Check that no secret group is attempted to be used together with databag.""" + if not self.secrets_enabled and full_field != field: + logger.error( + f"Can't access {full_field}: no secrets available (i.e. no secret groups either)." + ) + return False + return True + + # Event handlers + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + pass + + def _on_secret_changed_event(self, event: SecretChangedEvent) -> None: + """Event emitted when the secret has changed.""" + pass + + # Overrides of Relation Data handling functions + + def _generate_secret_label( + self, relation_name: str, relation_id: int, group_mapping: SecretGroup + ) -> str: + members = [relation_name, self._model.app.name] + if self.scope: + members.append(self.scope.value) + if group_mapping != SECRET_GROUPS.EXTRA: + members.append(group_mapping) + return f"{'.'.join(members)}" + + def _generate_secret_field_name(self, group_mapping: SecretGroup = SECRET_GROUPS.EXTRA) -> str: + """Generate unique group_mappings for secrets within a relation context.""" + return f"{self.secret_field_name}" + + @juju_secrets_only + def _get_relation_secret( + self, + relation_id: int, + group_mapping: SecretGroup = SECRET_GROUPS.EXTRA, + relation_name: Optional[str] = None, + ) -> Optional[CachedSecret]: + """Retrieve a Juju Secret specifically for peer relations. + + In case this code may be executed within a rolling upgrade, and we may need to + migrate secrets from the databag to labels, we make sure to stick the correct + label on the secret, and clean up the local databag. + """ + if not relation_name: + relation_name = self.relation_name + + relation = self._model.get_relation(relation_name, relation_id) + if not relation: + return + + label = self._generate_secret_label(relation_name, relation_id, group_mapping) + secret_uri = relation.data[self.component].get(self._generate_secret_field_name(), None) + + # URI or legacy label is only to applied when moving single legacy secret to a (new) label + if group_mapping == SECRET_GROUPS.EXTRA: + # Fetching the secret with fallback to URI (in case label is not yet known) + # Label would we "stuck" on the secret in case it is found + return self.secrets.get(label, secret_uri, legacy_labels=self._previous_labels()) + return self.secrets.get(label) + + def _get_group_secret_contents( + self, + relation: Relation, + group: SecretGroup, + secret_fields: Union[Set[str], List[str]] = [], + ) -> Dict[str, str]: + """Helper function to retrieve collective, requested contents of a secret.""" + secret_fields = [self._internal_name_to_field(k)[0] for k in secret_fields] + result = super()._get_group_secret_contents(relation, group, secret_fields) + if self.deleted_label: + result = {key: result[key] for key in result if result[key] != self.deleted_label} + if self._additional_secret_group_mapping: + return {self._field_to_internal_name(key, group): result[key] for key in result} + return result + + @either_static_or_dynamic_secrets + def _fetch_my_specific_relation_data( + self, relation: Relation, fields: Optional[List[str]] + ) -> Dict[str, str]: + """Fetch data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" + return self._fetch_relation_data_with_secrets( + self.component, self.secret_fields, relation, fields + ) + + @either_static_or_dynamic_secrets + def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> None: + """Update data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" + self._remove_secret_from_databag(relation, list(data.keys())) + _, normal_fields = self._process_secret_fields( + relation, + self.secret_fields, + list(data), + self._add_or_update_relation_secrets, + data=data, + uri_to_databag=False, + ) + self._remove_secret_field_name_from_databag(relation) + + normal_content = {k: v for k, v in data.items() if k in normal_fields} + self._update_relation_data_without_secrets(self.component, relation, normal_content) + + @either_static_or_dynamic_secrets + def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: + """Delete data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" + if self.secret_fields and self.deleted_label: + # Legacy, backwards compatibility + self._check_deleted_label(relation, fields) + + _, normal_fields = self._process_secret_fields( + relation, + self.secret_fields, + fields, + self._update_relation_secret, + data={field: self.deleted_label for field in fields}, + ) + else: + _, normal_fields = self._process_secret_fields( + relation, self.secret_fields, fields, self._delete_relation_secret, fields=fields + ) + self._delete_relation_data_without_secrets(self.component, relation, list(normal_fields)) + + def fetch_relation_data( + self, + relation_ids: Optional[List[int]] = None, + fields: Optional[List[str]] = None, + relation_name: Optional[str] = None, + ) -> Dict[int, Dict[str, str]]: + """This method makes no sense for a Peer Relation.""" + raise NotImplementedError( + "Peer Relation only supports 'self-side' fetch methods: " + "fetch_my_relation_data() and fetch_my_relation_field()" + ) + + def fetch_relation_field( + self, relation_id: int, field: str, relation_name: Optional[str] = None + ) -> Optional[str]: + """This method makes no sense for a Peer Relation.""" + raise NotImplementedError( + "Peer Relation only supports 'self-side' fetch methods: " + "fetch_my_relation_data() and fetch_my_relation_field()" + ) + + # Public functions -- inherited + + fetch_my_relation_data = Data.fetch_my_relation_data + fetch_my_relation_field = Data.fetch_my_relation_field + + +class DataPeerEventHandlers(RequirerEventHandlers): + """Requires-side of the relation.""" + + def __init__(self, charm: CharmBase, relation_data: RequirerData, unique_key: str = ""): + """Manager of base client relations.""" + super().__init__(charm, relation_data, unique_key) + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + pass + + def _on_secret_changed_event(self, event: SecretChangedEvent) -> None: + """Event emitted when the secret has changed.""" + pass + + +class DataPeer(DataPeerData, DataPeerEventHandlers): + """Represents peer relations.""" + + def __init__( + self, + charm, + relation_name: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + additional_secret_group_mapping: Dict[str, str] = {}, + secret_field_name: Optional[str] = None, + deleted_label: Optional[str] = None, + unique_key: str = "", + ): + DataPeerData.__init__( + self, + charm.model, + relation_name, + extra_user_roles, + additional_secret_fields, + additional_secret_group_mapping, + secret_field_name, + deleted_label, + ) + DataPeerEventHandlers.__init__(self, charm, self, unique_key) + + +class DataPeerUnitData(DataPeerData): + """Unit data abstraction representation.""" + + SCOPE = Scope.UNIT + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class DataPeerUnit(DataPeerUnitData, DataPeerEventHandlers): + """Unit databag representation.""" + + def __init__( + self, + charm, + relation_name: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + additional_secret_group_mapping: Dict[str, str] = {}, + secret_field_name: Optional[str] = None, + deleted_label: Optional[str] = None, + unique_key: str = "", + ): + DataPeerData.__init__( + self, + charm.model, + relation_name, + extra_user_roles, + additional_secret_fields, + additional_secret_group_mapping, + secret_field_name, + deleted_label, + ) + DataPeerEventHandlers.__init__(self, charm, self, unique_key) + + +class DataPeerOtherUnitData(DataPeerUnitData): + """Unit data abstraction representation.""" + + def __init__(self, unit: Unit, *args, **kwargs): + super().__init__(*args, **kwargs) + self.local_unit = unit + self.component = unit + def update_relation_data(self, relation_id: int, data: dict) -> None: + """This method makes no sense for a Other Peer Relation.""" + raise NotImplementedError("It's not possible to update data of another unit.") -class AuthenticationEvent(RelationEvent): - """Base class for authentication fields for events.""" + def delete_relation_data(self, relation_id: int, fields: List[str]) -> None: + """This method makes no sense for a Other Peer Relation.""" + raise NotImplementedError("It's not possible to delete data of another unit.") + + +class DataPeerOtherUnitEventHandlers(DataPeerEventHandlers): + """Requires-side of the relation.""" + + def __init__(self, charm: CharmBase, relation_data: DataPeerUnitData): + """Manager of base client relations.""" + unique_key = f"{relation_data.relation_name}-{relation_data.local_unit.name}" + super().__init__(charm, relation_data, unique_key=unique_key) + + +class DataPeerOtherUnit(DataPeerOtherUnitData, DataPeerOtherUnitEventHandlers): + """Unit databag representation for another unit than the executor.""" + + def __init__( + self, + unit: Unit, + charm: CharmBase, + relation_name: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + additional_secret_group_mapping: Dict[str, str] = {}, + secret_field_name: Optional[str] = None, + deleted_label: Optional[str] = None, + ): + DataPeerOtherUnitData.__init__( + self, + unit, + charm.model, + relation_name, + extra_user_roles, + additional_secret_fields, + additional_secret_group_mapping, + secret_field_name, + deleted_label, + ) + DataPeerOtherUnitEventHandlers.__init__(self, charm, self) + + +################################################################################ +# Cross-charm Relatoins Data Handling and Evenets +################################################################################ + +# Generic events + + +class ExtraRoleEvent(RelationEvent): + """Base class for data events.""" + + @property + def extra_user_roles(self) -> Optional[str]: + """Returns the extra user roles that were requested.""" + if not self.relation.app: + return None + + return self.relation.data[self.relation.app].get("extra-user-roles") + + +class RelationEventWithSecret(RelationEvent): + """Base class for Relation Events that need to handle secrets.""" + + @property + def _secrets(self) -> dict: + """Caching secrets to avoid fetching them each time a field is referrd. + + DON'T USE the encapsulated helper variable outside of this function + """ + if not hasattr(self, "_cached_secrets"): + self._cached_secrets = {} + return self._cached_secrets + + def _get_secret(self, group) -> Optional[Dict[str, str]]: + """Retrieveing secrets.""" + if not self.app: + return + if not self._secrets.get(group): + self._secrets[group] = None + secret_field = f"{PROV_SECRET_PREFIX}{group}" + if secret_uri := self.relation.data[self.app].get(secret_field): + secret = self.framework.model.get_secret(id=secret_uri) + self._secrets[group] = secret.get_content() + return self._secrets[group] + + @property + def secrets_enabled(self): + """Is this Juju version allowing for Secrets usage?""" + return JujuVersion.from_environ().has_secrets + + +class AuthenticationEvent(RelationEventWithSecret): + """Base class for authentication fields for events. + + The amount of logic added here is not ideal -- but this was the only way to preserve + the interface when moving to Juju Secrets + """ @property def username(self) -> Optional[str]: """Returns the created username.""" + if not self.relation.app: + return None + + if self.secrets_enabled: + secret = self._get_secret("user") + if secret: + return secret.get("username") + return self.relation.data[self.relation.app].get("username") @property def password(self) -> Optional[str]: """Returns the password for the created user.""" + if not self.relation.app: + return None + + if self.secrets_enabled: + secret = self._get_secret("user") + if secret: + return secret.get("password") + return self.relation.data[self.relation.app].get("password") @property def tls(self) -> Optional[str]: """Returns whether TLS is configured.""" + if not self.relation.app: + return None + + if self.secrets_enabled: + secret = self._get_secret("tls") + if secret: + return secret.get("tls") + return self.relation.data[self.relation.app].get("tls") @property def tls_ca(self) -> Optional[str]: """Returns TLS CA.""" + if not self.relation.app: + return None + + if self.secrets_enabled: + secret = self._get_secret("tls") + if secret: + return secret.get("tls-ca") + return self.relation.data[self.relation.app].get("tls-ca") @@ -639,12 +2395,26 @@ class DatabaseProvidesEvent(RelationEvent): @property def database(self) -> Optional[str]: """Returns the database that was requested.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("database") class DatabaseRequestedEvent(DatabaseProvidesEvent, ExtraRoleEvent): """Event emitted when a new database is requested for use on this relation.""" + @property + def external_node_connectivity(self) -> bool: + """Returns the requested external_node_connectivity field.""" + if not self.relation.app: + return False + + return ( + self.relation.data[self.relation.app].get("external-node-connectivity", "false") + == "true" + ) + class DatabaseProvidesEvents(CharmEvents): """Database events. @@ -655,17 +2425,39 @@ class DatabaseProvidesEvents(CharmEvents): database_requested = EventSource(DatabaseRequestedEvent) -class DatabaseRequiresEvent(RelationEvent): +class DatabaseRequiresEvent(RelationEventWithSecret): """Base class for database events.""" + @property + def database(self) -> Optional[str]: + """Returns the database name.""" + if not self.relation.app: + return None + + return self.relation.data[self.relation.app].get("database") + @property def endpoints(self) -> Optional[str]: - """Returns a comma separated list of read/write endpoints.""" + """Returns a comma separated list of read/write endpoints. + + In VM charms, this is the primary's address. + In kubernetes charms, this is the service to the primary pod. + """ + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("endpoints") @property def read_only_endpoints(self) -> Optional[str]: - """Returns a comma separated list of read only endpoints.""" + """Returns a comma separated list of read only endpoints. + + In VM charms, this is the address of all the secondary instances. + In kubernetes charms, this is the service to all replica pod instances. + """ + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("read-only-endpoints") @property @@ -674,6 +2466,9 @@ def replset(self) -> Optional[str]: MongoDB only. """ + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("replset") @property @@ -682,6 +2477,14 @@ def uris(self) -> Optional[str]: MongoDB, Redis, OpenSearch. """ + if not self.relation.app: + return None + + if self.secrets_enabled: + secret = self._get_secret("user") + if secret: + return secret.get("uris") + return self.relation.data[self.relation.app].get("uris") @property @@ -690,6 +2493,9 @@ def version(self) -> Optional[str]: Version as informed by the database daemon. """ + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("version") @@ -719,27 +2525,23 @@ class DatabaseRequiresEvents(CharmEvents): # Database Provider and Requires -class DatabaseProvides(DataProvides): - """Provider-side of the database relations.""" - - on = DatabaseProvidesEvents() +class DatabaseProviderData(ProviderData): + """Provider-side data of the database relations.""" - def __init__(self, charm: CharmBase, relation_name: str) -> None: - super().__init__(charm, relation_name) + def __init__(self, model: Model, relation_name: str) -> None: + super().__init__(model, relation_name) - def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Event emitted when the relation has changed.""" - # Only the leader should handle this event. - if not self.local_unit.is_leader(): - return + def set_database(self, relation_id: int, database_name: str) -> None: + """Set database name. - # Check which data has changed to emit customs events. - diff = self._diff(event) + This function writes in the application data bag, therefore, + only the leader unit can call it. - # Emit a database requested event if the setup key (database name and optional - # extra user roles) was added to the relation databag by the application. - if "database" in diff.added: - self.on.database_requested.emit(event.relation, app=event.app, unit=event.unit) + Args: + relation_id: the identifier for a particular relation. + database_name: database name. + """ + self.update_relation_data(relation_id, {"database": database_name}) def set_endpoints(self, relation_id: int, connection_strings: str) -> None: """Set database primary connections. @@ -747,11 +2549,15 @@ def set_endpoints(self, relation_id: int, connection_strings: str) -> None: This function writes in the application data bag, therefore, only the leader unit can call it. + In VM charms, only the primary's address should be passed as an endpoint. + In kubernetes charms, the service endpoint to the primary pod should be + passed as an endpoint. + Args: relation_id: the identifier for a particular relation. connection_strings: database hosts and ports comma separated list. """ - self._update_relation_data(relation_id, {"endpoints": connection_strings}) + self.update_relation_data(relation_id, {"endpoints": connection_strings}) def set_read_only_endpoints(self, relation_id: int, connection_strings: str) -> None: """Set database replicas connection strings. @@ -763,7 +2569,7 @@ def set_read_only_endpoints(self, relation_id: int, connection_strings: str) -> relation_id: the identifier for a particular relation. connection_strings: database hosts and ports comma separated list. """ - self._update_relation_data(relation_id, {"read-only-endpoints": connection_strings}) + self.update_relation_data(relation_id, {"read-only-endpoints": connection_strings}) def set_replset(self, relation_id: int, replset: str) -> None: """Set replica set name in the application relation databag. @@ -774,7 +2580,7 @@ def set_replset(self, relation_id: int, replset: str) -> None: relation_id: the identifier for a particular relation. replset: replica set name. """ - self._update_relation_data(relation_id, {"replset": replset}) + self.update_relation_data(relation_id, {"replset": replset}) def set_uris(self, relation_id: int, uris: str) -> None: """Set the database connection URIs in the application relation databag. @@ -785,48 +2591,152 @@ def set_uris(self, relation_id: int, uris: str) -> None: relation_id: the identifier for a particular relation. uris: connection URIs. """ - self._update_relation_data(relation_id, {"uris": uris}) + self.update_relation_data(relation_id, {"uris": uris}) def set_version(self, relation_id: int, version: str) -> None: """Set the database version in the application relation databag. Args: - relation_id: the identifier for a particular relation. - version: database version. + relation_id: the identifier for a particular relation. + version: database version. + """ + self.update_relation_data(relation_id, {"version": version}) + + +class DatabaseProviderEventHandlers(EventHandlers): + """Provider-side of the database relation handlers.""" + + on = DatabaseProvidesEvents() # pyright: ignore [reportAssignmentType] + + def __init__( + self, charm: CharmBase, relation_data: DatabaseProviderData, unique_key: str = "" + ): + """Manager of base client relations.""" + super().__init__(charm, relation_data, unique_key) + # Just to calm down pyright, it can't parse that the same type is being used in the super() call above + self.relation_data = relation_data + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + # Leader only + if not self.relation_data.local_unit.is_leader(): + return + # Check which data has changed to emit customs events. + diff = self._diff(event) + + # Emit a database requested event if the setup key (database name and optional + # extra user roles) was added to the relation databag by the application. + if "database" in diff.added: + getattr(self.on, "database_requested").emit( + event.relation, app=event.app, unit=event.unit + ) + + +class DatabaseProvides(DatabaseProviderData, DatabaseProviderEventHandlers): + """Provider-side of the database relations.""" + + def __init__(self, charm: CharmBase, relation_name: str) -> None: + DatabaseProviderData.__init__(self, charm.model, relation_name) + DatabaseProviderEventHandlers.__init__(self, charm, self) + + +class DatabaseRequirerData(RequirerData): + """Requirer-side of the database relation.""" + + def __init__( + self, + model: Model, + relation_name: str, + database_name: str, + extra_user_roles: Optional[str] = None, + relations_aliases: Optional[List[str]] = None, + additional_secret_fields: Optional[List[str]] = [], + external_node_connectivity: bool = False, + ): + """Manager of database client relations.""" + super().__init__(model, relation_name, extra_user_roles, additional_secret_fields) + self.database = database_name + self.relations_aliases = relations_aliases + self.external_node_connectivity = external_node_connectivity + + def is_postgresql_plugin_enabled(self, plugin: str, relation_index: int = 0) -> bool: + """Returns whether a plugin is enabled in the database. + + Args: + plugin: name of the plugin to check. + relation_index: optional relation index to check the database + (default: 0 - first relation). + + PostgreSQL only. """ - self._update_relation_data(relation_id, {"version": version}) + # Psycopg 3 is imported locally to avoid the need of its package installation + # when relating to a database charm other than PostgreSQL. + import psycopg + + # Return False if no relation is established. + if len(self.relations) == 0: + return False + + relation_id = self.relations[relation_index].id + host = self.fetch_relation_field(relation_id, "endpoints") + + # Return False if there is no endpoint available. + if host is None: + return False + + host = host.split(":")[0] + + content = self.fetch_relation_data([relation_id], ["username", "password"]).get( + relation_id, {} + ) + user = content.get("username") + password = content.get("password") + + connection_string = ( + f"host='{host}' dbname='{self.database}' user='{user}' password='{password}'" + ) + try: + with psycopg.connect(connection_string) as connection: + with connection.cursor() as cursor: + cursor.execute( + "SELECT TRUE FROM pg_extension WHERE extname=%s::text;", (plugin,) + ) + return cursor.fetchone() is not None + except psycopg.Error as e: + logger.exception( + f"failed to check whether {plugin} plugin is enabled in the database: %s", str(e) + ) + return False -class DatabaseRequires(DataRequires): - """Requires-side of the database relation.""" +class DatabaseRequirerEventHandlers(RequirerEventHandlers): + """Requires-side of the relation.""" - on = DatabaseRequiresEvents() + on = DatabaseRequiresEvents() # pyright: ignore [reportAssignmentType] def __init__( - self, - charm, - relation_name: str, - database_name: str, - extra_user_roles: str = None, - relations_aliases: List[str] = None, + self, charm: CharmBase, relation_data: DatabaseRequirerData, unique_key: str = "" ): - """Manager of database client relations.""" - super().__init__(charm, relation_name, extra_user_roles) - self.database = database_name - self.relations_aliases = relations_aliases + """Manager of base client relations.""" + super().__init__(charm, relation_data, unique_key) + # Just to keep lint quiet, can't resolve inheritance. The same happened in super().__init__() above + self.relation_data = relation_data # Define custom event names for each alias. - if relations_aliases: + if self.relation_data.relations_aliases: # Ensure the number of aliases does not exceed the maximum # of connections allowed in the specific relation. - relation_connection_limit = self.charm.meta.requires[relation_name].limit - if len(relations_aliases) != relation_connection_limit: + relation_connection_limit = self.charm.meta.requires[ + self.relation_data.relation_name + ].limit + if len(self.relation_data.relations_aliases) != relation_connection_limit: raise ValueError( f"The number of aliases must match the maximum number of connections allowed in the relation. " - f"Expected {relation_connection_limit}, got {len(relations_aliases)}" + f"Expected {relation_connection_limit}, got {len(self.relation_data.relations_aliases)}" ) - for relation_alias in relations_aliases: + if self.relation_data.relations_aliases: + for relation_alias in self.relation_data.relations_aliases: self.on.define_event(f"{relation_alias}_database_created", DatabaseCreatedEvent) self.on.define_event( f"{relation_alias}_endpoints_changed", DatabaseEndpointsChangedEvent @@ -836,6 +2746,10 @@ def __init__( DatabaseReadOnlyEndpointsChangedEvent, ) + def _on_secret_changed_event(self, event: SecretChangedEvent): + """Event notifying about a new value of a secret.""" + pass + def _assign_relation_alias(self, relation_id: int) -> None: """Assigns an alias to a relation. @@ -845,29 +2759,32 @@ def _assign_relation_alias(self, relation_id: int) -> None: relation_id: the identifier for a particular relation. """ # If no aliases were provided, return immediately. - if not self.relations_aliases: + if not self.relation_data.relations_aliases: return # Return if an alias was already assigned to this relation # (like when there are more than one unit joining the relation). - if ( - self.charm.model.get_relation(self.relation_name, relation_id) - .data[self.local_unit] - .get("alias") - ): + relation = self.charm.model.get_relation(self.relation_data.relation_name, relation_id) + if relation and relation.data[self.relation_data.local_unit].get("alias"): return # Retrieve the available aliases (the ones that weren't assigned to any relation). - available_aliases = self.relations_aliases[:] - for relation in self.charm.model.relations[self.relation_name]: - alias = relation.data[self.local_unit].get("alias") + available_aliases = self.relation_data.relations_aliases[:] + for relation in self.charm.model.relations[self.relation_data.relation_name]: + alias = relation.data[self.relation_data.local_unit].get("alias") if alias: logger.debug("Alias %s was already assigned to relation %d", alias, relation.id) available_aliases.remove(alias) # Set the alias in the unit relation databag of the specific relation. - relation = self.charm.model.get_relation(self.relation_name, relation_id) - relation.data[self.local_unit].update({"alias": available_aliases[0]}) + relation = self.charm.model.get_relation(self.relation_data.relation_name, relation_id) + if relation: + relation.data[self.relation_data.local_unit].update({"alias": available_aliases[0]}) + + # We need to set relation alias also on the application level so, + # it will be accessible in show-unit juju command, executed for a consumer application unit + if self.relation_data.local_unit.is_leader(): + self.relation_data.update_relation_data(relation_id, {"alias": available_aliases[0]}) def _emit_aliased_event(self, event: RelationChangedEvent, event_name: str) -> None: """Emit an aliased event to a particular relation if it has an alias. @@ -891,40 +2808,54 @@ def _get_relation_alias(self, relation_id: int) -> Optional[str]: Returns: the relation alias or None if the relation was not found. """ - for relation in self.charm.model.relations[self.relation_name]: + for relation in self.charm.model.relations[self.relation_data.relation_name]: if relation.id == relation_id: - return relation.data[self.local_unit].get("alias") + return relation.data[self.relation_data.local_unit].get("alias") return None - def _on_relation_joined_event(self, event: RelationJoinedEvent) -> None: - """Event emitted when the application joins the database relation.""" + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: + """Event emitted when the database relation is created.""" + super()._on_relation_created_event(event) + # If relations aliases were provided, assign one to the relation. self._assign_relation_alias(event.relation.id) # Sets both database and extra user roles in the relation # if the roles are provided. Otherwise, sets only the database. - if self.extra_user_roles: - self._update_relation_data( - event.relation.id, - { - "database": self.database, - "extra-user-roles": self.extra_user_roles, - }, - ) - else: - self._update_relation_data(event.relation.id, {"database": self.database}) + if not self.relation_data.local_unit.is_leader(): + return + + event_data = {"database": self.relation_data.database} + + if self.relation_data.extra_user_roles: + event_data["extra-user-roles"] = self.relation_data.extra_user_roles + + # set external-node-connectivity field + if self.relation_data.external_node_connectivity: + event_data["external-node-connectivity"] = "true" + + self.relation_data.update_relation_data(event.relation.id, event_data) def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the database relation has changed.""" # Check which data has changed to emit customs events. diff = self._diff(event) + # Register all new secrets with their labels + if any(newval for newval in diff.added if self.relation_data._is_secret_field(newval)): + self.relation_data._register_secrets_to_relation(event.relation, diff.added) + # Check if the database is created # (the database charm shared the credentials). - if "username" in diff.added and "password" in diff.added: + secret_field_user = self.relation_data._generate_secret_field_name(SECRET_GROUPS.USER) + if ( + "username" in diff.added and "password" in diff.added + ) or secret_field_user in diff.added: # Emit the default event (the one without an alias). logger.info("database created at %s", datetime.now()) - self.on.database_created.emit(event.relation, app=event.app, unit=event.unit) + getattr(self.on, "database_created").emit( + event.relation, app=event.app, unit=event.unit + ) # Emit the aliased event (if any). self._emit_aliased_event(event, "database_created") @@ -938,7 +2869,9 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: if "endpoints" in diff.added or "endpoints" in diff.changed: # Emit the default event (the one without an alias). logger.info("endpoints changed on %s", datetime.now()) - self.on.endpoints_changed.emit(event.relation, app=event.app, unit=event.unit) + getattr(self.on, "endpoints_changed").emit( + event.relation, app=event.app, unit=event.unit + ) # Emit the aliased event (if any). self._emit_aliased_event(event, "endpoints_changed") @@ -952,7 +2885,7 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: if "read-only-endpoints" in diff.added or "read-only-endpoints" in diff.changed: # Emit the default event (the one without an alias). logger.info("read-only-endpoints changed on %s", datetime.now()) - self.on.read_only_endpoints_changed.emit( + getattr(self.on, "read_only_endpoints_changed").emit( event.relation, app=event.app, unit=event.unit ) @@ -960,7 +2893,37 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: self._emit_aliased_event(event, "read_only_endpoints_changed") -# Kafka related events +class DatabaseRequires(DatabaseRequirerData, DatabaseRequirerEventHandlers): + """Provider-side of the database relations.""" + + def __init__( + self, + charm: CharmBase, + relation_name: str, + database_name: str, + extra_user_roles: Optional[str] = None, + relations_aliases: Optional[List[str]] = None, + additional_secret_fields: Optional[List[str]] = [], + external_node_connectivity: bool = False, + ): + DatabaseRequirerData.__init__( + self, + charm.model, + relation_name, + database_name, + extra_user_roles, + relations_aliases, + additional_secret_fields, + external_node_connectivity, + ) + DatabaseRequirerEventHandlers.__init__(self, charm, self) + + +################################################################################ +# Charm-specific Relations Data and Events +################################################################################ + +# Kafka Events class KafkaProvidesEvent(RelationEvent): @@ -969,8 +2932,19 @@ class KafkaProvidesEvent(RelationEvent): @property def topic(self) -> Optional[str]: """Returns the topic that was requested.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("topic") + @property + def consumer_group_prefix(self) -> Optional[str]: + """Returns the consumer-group-prefix that was requested.""" + if not self.relation.app: + return None + + return self.relation.data[self.relation.app].get("consumer-group-prefix") + class TopicRequestedEvent(KafkaProvidesEvent, ExtraRoleEvent): """Event emitted when a new topic is requested for use on this relation.""" @@ -988,19 +2962,36 @@ class KafkaProvidesEvents(CharmEvents): class KafkaRequiresEvent(RelationEvent): """Base class for Kafka events.""" + @property + def topic(self) -> Optional[str]: + """Returns the topic.""" + if not self.relation.app: + return None + + return self.relation.data[self.relation.app].get("topic") + @property def bootstrap_server(self) -> Optional[str]: - """Returns a a comma-seperated list of broker uris.""" + """Returns a comma-separated list of broker uris.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("endpoints") @property def consumer_group_prefix(self) -> Optional[str]: """Returns the consumer-group-prefix.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("consumer-group-prefix") @property def zookeeper_uris(self) -> Optional[str]: """Returns a comma separated list of Zookeeper uris.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("zookeeper-uris") @@ -1025,27 +3016,20 @@ class KafkaRequiresEvents(CharmEvents): # Kafka Provides and Requires -class KafkaProvides(DataProvides): +class KafkaProvidesData(ProviderData): """Provider-side of the Kafka relation.""" - on = KafkaProvidesEvents() - - def __init__(self, charm: CharmBase, relation_name: str) -> None: - super().__init__(charm, relation_name) - - def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Event emitted when the relation has changed.""" - # Only the leader should handle this event. - if not self.local_unit.is_leader(): - return + def __init__(self, model: Model, relation_name: str) -> None: + super().__init__(model, relation_name) - # Check which data has changed to emit customs events. - diff = self._diff(event) + def set_topic(self, relation_id: int, topic: str) -> None: + """Set topic name in the application relation databag. - # Emit a topic requested event if the setup key (topic name and optional - # extra user roles) was added to the relation databag by the application. - if "topic" in diff.added: - self.on.topic_requested.emit(event.relation, app=event.app, unit=event.unit) + Args: + relation_id: the identifier for a particular relation. + topic: the topic name. + """ + self.update_relation_data(relation_id, {"topic": topic}) def set_bootstrap_server(self, relation_id: int, bootstrap_server: str) -> None: """Set the bootstrap server in the application relation databag. @@ -1054,7 +3038,7 @@ def set_bootstrap_server(self, relation_id: int, bootstrap_server: str) -> None: relation_id: the identifier for a particular relation. bootstrap_server: the bootstrap server address. """ - self._update_relation_data(relation_id, {"endpoints": bootstrap_server}) + self.update_relation_data(relation_id, {"endpoints": bootstrap_server}) def set_consumer_group_prefix(self, relation_id: int, consumer_group_prefix: str) -> None: """Set the consumer group prefix in the application relation databag. @@ -1063,43 +3047,111 @@ def set_consumer_group_prefix(self, relation_id: int, consumer_group_prefix: str relation_id: the identifier for a particular relation. consumer_group_prefix: the consumer group prefix string. """ - self._update_relation_data(relation_id, {"consumer-group-prefix": consumer_group_prefix}) + self.update_relation_data(relation_id, {"consumer-group-prefix": consumer_group_prefix}) def set_zookeeper_uris(self, relation_id: int, zookeeper_uris: str) -> None: """Set the zookeeper uris in the application relation databag. Args: relation_id: the identifier for a particular relation. - zookeeper_uris: comma-seperated list of ZooKeeper server uris. + zookeeper_uris: comma-separated list of ZooKeeper server uris. """ - self._update_relation_data(relation_id, {"zookeeper-uris": zookeeper_uris}) + self.update_relation_data(relation_id, {"zookeeper-uris": zookeeper_uris}) -class KafkaRequires(DataRequires): - """Requires-side of the Kafka relation.""" +class KafkaProvidesEventHandlers(EventHandlers): + """Provider-side of the Kafka relation.""" + + on = KafkaProvidesEvents() # pyright: ignore [reportAssignmentType] + + def __init__(self, charm: CharmBase, relation_data: KafkaProvidesData) -> None: + super().__init__(charm, relation_data) + # Just to keep lint quiet, can't resolve inheritance. The same happened in super().__init__() above + self.relation_data = relation_data + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + # Leader only + if not self.relation_data.local_unit.is_leader(): + return + + # Check which data has changed to emit customs events. + diff = self._diff(event) + + # Emit a topic requested event if the setup key (topic name and optional + # extra user roles) was added to the relation databag by the application. + if "topic" in diff.added: + getattr(self.on, "topic_requested").emit( + event.relation, app=event.app, unit=event.unit + ) + + +class KafkaProvides(KafkaProvidesData, KafkaProvidesEventHandlers): + """Provider-side of the Kafka relation.""" + + def __init__(self, charm: CharmBase, relation_name: str) -> None: + KafkaProvidesData.__init__(self, charm.model, relation_name) + KafkaProvidesEventHandlers.__init__(self, charm, self) - on = KafkaRequiresEvents() - def __init__(self, charm, relation_name: str, topic: str, extra_user_roles: str = None): +class KafkaRequiresData(RequirerData): + """Requirer-side of the Kafka relation.""" + + def __init__( + self, + model: Model, + relation_name: str, + topic: str, + extra_user_roles: Optional[str] = None, + consumer_group_prefix: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + ): """Manager of Kafka client relations.""" - # super().__init__(charm, relation_name) - super().__init__(charm, relation_name, extra_user_roles) - self.charm = charm + super().__init__(model, relation_name, extra_user_roles, additional_secret_fields) self.topic = topic + self.consumer_group_prefix = consumer_group_prefix or "" - def _on_relation_joined_event(self, event: RelationJoinedEvent) -> None: - """Event emitted when the application joins the Kafka relation.""" - # Sets both topic and extra user roles in the relation - # if the roles are provided. Otherwise, sets only the topic. - self._update_relation_data( - event.relation.id, - { - "topic": self.topic, - "extra-user-roles": self.extra_user_roles, - } - if self.extra_user_roles is not None - else {"topic": self.topic}, - ) + @property + def topic(self): + """Topic to use in Kafka.""" + return self._topic + + @topic.setter + def topic(self, value): + # Avoid wildcards + if value == "*": + raise ValueError(f"Error on topic '{value}', cannot be a wildcard.") + self._topic = value + + +class KafkaRequiresEventHandlers(RequirerEventHandlers): + """Requires-side of the Kafka relation.""" + + on = KafkaRequiresEvents() # pyright: ignore [reportAssignmentType] + + def __init__(self, charm: CharmBase, relation_data: KafkaRequiresData) -> None: + super().__init__(charm, relation_data) + # Just to keep lint quiet, can't resolve inheritance. The same happened in super().__init__() above + self.relation_data = relation_data + + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: + """Event emitted when the Kafka relation is created.""" + super()._on_relation_created_event(event) + + if not self.relation_data.local_unit.is_leader(): + return + + # Sets topic, extra user roles, and "consumer-group-prefix" in the relation + relation_data = { + f: getattr(self, f.replace("-", "_"), "") + for f in ["consumer-group-prefix", "extra-user-roles", "topic"] + } + + self.relation_data.update_relation_data(event.relation.id, relation_data) + + def _on_secret_changed_event(self, event: SecretChangedEvent): + """Event notifying about a new value of a secret.""" + pass def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the Kafka relation has changed.""" @@ -1108,21 +3160,306 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: # Check if the topic is created # (the Kafka charm shared the credentials). - if "username" in diff.added and "password" in diff.added: + + # Register all new secrets with their labels + if any(newval for newval in diff.added if self.relation_data._is_secret_field(newval)): + self.relation_data._register_secrets_to_relation(event.relation, diff.added) + + secret_field_user = self.relation_data._generate_secret_field_name(SECRET_GROUPS.USER) + if ( + "username" in diff.added and "password" in diff.added + ) or secret_field_user in diff.added: # Emit the default event (the one without an alias). logger.info("topic created at %s", datetime.now()) - self.on.topic_created.emit(event.relation, app=event.app, unit=event.unit) + getattr(self.on, "topic_created").emit(event.relation, app=event.app, unit=event.unit) # To avoid unnecessary application restarts do not trigger # “endpoints_changed“ event if “topic_created“ is triggered. return - # Emit an endpoints (bootstap-server) changed event if the Kakfa endpoints + # Emit an endpoints (bootstrap-server) changed event if the Kafka endpoints # added or changed this info in the relation databag. if "endpoints" in diff.added or "endpoints" in diff.changed: # Emit the default event (the one without an alias). logger.info("endpoints changed on %s", datetime.now()) - self.on.bootstrap_server_changed.emit( + getattr(self.on, "bootstrap_server_changed").emit( event.relation, app=event.app, unit=event.unit ) # here check if this is the right design return + + +class KafkaRequires(KafkaRequiresData, KafkaRequiresEventHandlers): + """Provider-side of the Kafka relation.""" + + def __init__( + self, + charm: CharmBase, + relation_name: str, + topic: str, + extra_user_roles: Optional[str] = None, + consumer_group_prefix: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + ) -> None: + KafkaRequiresData.__init__( + self, + charm.model, + relation_name, + topic, + extra_user_roles, + consumer_group_prefix, + additional_secret_fields, + ) + KafkaRequiresEventHandlers.__init__(self, charm, self) + + +# Opensearch related events + + +class OpenSearchProvidesEvent(RelationEvent): + """Base class for OpenSearch events.""" + + @property + def index(self) -> Optional[str]: + """Returns the index that was requested.""" + if not self.relation.app: + return None + + return self.relation.data[self.relation.app].get("index") + + +class IndexRequestedEvent(OpenSearchProvidesEvent, ExtraRoleEvent): + """Event emitted when a new index is requested for use on this relation.""" + + +class OpenSearchProvidesEvents(CharmEvents): + """OpenSearch events. + + This class defines the events that OpenSearch can emit. + """ + + index_requested = EventSource(IndexRequestedEvent) + + +class OpenSearchRequiresEvent(DatabaseRequiresEvent): + """Base class for OpenSearch requirer events.""" + + +class IndexCreatedEvent(AuthenticationEvent, OpenSearchRequiresEvent): + """Event emitted when a new index is created for use on this relation.""" + + +class OpenSearchRequiresEvents(CharmEvents): + """OpenSearch events. + + This class defines the events that the opensearch requirer can emit. + """ + + index_created = EventSource(IndexCreatedEvent) + endpoints_changed = EventSource(DatabaseEndpointsChangedEvent) + authentication_updated = EventSource(AuthenticationEvent) + + +# OpenSearch Provides and Requires Objects + + +class OpenSearchProvidesData(ProviderData): + """Provider-side of the OpenSearch relation.""" + + def __init__(self, model: Model, relation_name: str) -> None: + super().__init__(model, relation_name) + + def set_index(self, relation_id: int, index: str) -> None: + """Set the index in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + index: the index as it is _created_ on the provider charm. This needn't match the + requested index, and can be used to present a different index name if, for example, + the requested index is invalid. + """ + self.update_relation_data(relation_id, {"index": index}) + + def set_endpoints(self, relation_id: int, endpoints: str) -> None: + """Set the endpoints in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + endpoints: the endpoint addresses for opensearch nodes. + """ + self.update_relation_data(relation_id, {"endpoints": endpoints}) + + def set_version(self, relation_id: int, version: str) -> None: + """Set the opensearch version in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + version: database version. + """ + self.update_relation_data(relation_id, {"version": version}) + + +class OpenSearchProvidesEventHandlers(EventHandlers): + """Provider-side of the OpenSearch relation.""" + + on = OpenSearchProvidesEvents() # pyright: ignore[reportAssignmentType] + + def __init__(self, charm: CharmBase, relation_data: OpenSearchProvidesData) -> None: + super().__init__(charm, relation_data) + # Just to keep lint quiet, can't resolve inheritance. The same happened in super().__init__() above + self.relation_data = relation_data + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + # Leader only + if not self.relation_data.local_unit.is_leader(): + return + # Check which data has changed to emit customs events. + diff = self._diff(event) + + # Emit an index requested event if the setup key (index name and optional extra user roles) + # have been added to the relation databag by the application. + if "index" in diff.added: + getattr(self.on, "index_requested").emit( + event.relation, app=event.app, unit=event.unit + ) + + +class OpenSearchProvides(OpenSearchProvidesData, OpenSearchProvidesEventHandlers): + """Provider-side of the OpenSearch relation.""" + + def __init__(self, charm: CharmBase, relation_name: str) -> None: + OpenSearchProvidesData.__init__(self, charm.model, relation_name) + OpenSearchProvidesEventHandlers.__init__(self, charm, self) + + +class OpenSearchRequiresData(RequirerData): + """Requires data side of the OpenSearch relation.""" + + def __init__( + self, + model: Model, + relation_name: str, + index: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + ): + """Manager of OpenSearch client relations.""" + super().__init__(model, relation_name, extra_user_roles, additional_secret_fields) + self.index = index + + +class OpenSearchRequiresEventHandlers(RequirerEventHandlers): + """Requires events side of the OpenSearch relation.""" + + on = OpenSearchRequiresEvents() # pyright: ignore[reportAssignmentType] + + def __init__(self, charm: CharmBase, relation_data: OpenSearchRequiresData) -> None: + super().__init__(charm, relation_data) + # Just to keep lint quiet, can't resolve inheritance. The same happened in super().__init__() above + self.relation_data = relation_data + + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: + """Event emitted when the OpenSearch relation is created.""" + super()._on_relation_created_event(event) + + if not self.relation_data.local_unit.is_leader(): + return + + # Sets both index and extra user roles in the relation if the roles are provided. + # Otherwise, sets only the index. + data = {"index": self.relation_data.index} + if self.relation_data.extra_user_roles: + data["extra-user-roles"] = self.relation_data.extra_user_roles + + self.relation_data.update_relation_data(event.relation.id, data) + + def _on_secret_changed_event(self, event: SecretChangedEvent): + """Event notifying about a new value of a secret.""" + if not event.secret.label: + return + + relation = self.relation_data._relation_from_secret_label(event.secret.label) + if not relation: + logging.info( + f"Received secret {event.secret.label} but couldn't parse, seems irrelevant" + ) + return + + if relation.app == self.charm.app: + logging.info("Secret changed event ignored for Secret Owner") + + remote_unit = None + for unit in relation.units: + if unit.app != self.charm.app: + remote_unit = unit + + logger.info("authentication updated") + getattr(self.on, "authentication_updated").emit( + relation, app=relation.app, unit=remote_unit + ) + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the OpenSearch relation has changed. + + This event triggers individual custom events depending on the changing relation. + """ + # Check which data has changed to emit customs events. + diff = self._diff(event) + + # Register all new secrets with their labels + if any(newval for newval in diff.added if self.relation_data._is_secret_field(newval)): + self.relation_data._register_secrets_to_relation(event.relation, diff.added) + + secret_field_user = self.relation_data._generate_secret_field_name(SECRET_GROUPS.USER) + secret_field_tls = self.relation_data._generate_secret_field_name(SECRET_GROUPS.TLS) + updates = {"username", "password", "tls", "tls-ca", secret_field_user, secret_field_tls} + if len(set(diff._asdict().keys()) - updates) < len(diff): + logger.info("authentication updated at: %s", datetime.now()) + getattr(self.on, "authentication_updated").emit( + event.relation, app=event.app, unit=event.unit + ) + + # Check if the index is created + # (the OpenSearch charm shares the credentials). + if ( + "username" in diff.added and "password" in diff.added + ) or secret_field_user in diff.added: + # Emit the default event (the one without an alias). + logger.info("index created at: %s", datetime.now()) + getattr(self.on, "index_created").emit(event.relation, app=event.app, unit=event.unit) + + # To avoid unnecessary application restarts do not trigger + # “endpoints_changed“ event if “index_created“ is triggered. + return + + # Emit a endpoints changed event if the OpenSearch application added or changed this info + # in the relation databag. + if "endpoints" in diff.added or "endpoints" in diff.changed: + # Emit the default event (the one without an alias). + logger.info("endpoints changed on %s", datetime.now()) + getattr(self.on, "endpoints_changed").emit( + event.relation, app=event.app, unit=event.unit + ) # here check if this is the right design + return + + +class OpenSearchRequires(OpenSearchRequiresData, OpenSearchRequiresEventHandlers): + """Requires-side of the OpenSearch relation.""" + + def __init__( + self, + charm: CharmBase, + relation_name: str, + index: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + ) -> None: + OpenSearchRequiresData.__init__( + self, + charm.model, + relation_name, + index, + extra_user_roles, + additional_secret_fields, + ) + OpenSearchRequiresEventHandlers.__init__(self, charm, self) diff --git a/lib/charms/operator_libs_linux/v0/apt.py b/lib/charms/operator_libs_linux/v0/apt.py new file mode 100644 index 0000000..1400df7 --- /dev/null +++ b/lib/charms/operator_libs_linux/v0/apt.py @@ -0,0 +1,1361 @@ +# Copyright 2021 Canonical Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Abstractions for the system's Debian/Ubuntu package information and repositories. + +This module contains abstractions and wrappers around Debian/Ubuntu-style repositories and +packages, in order to easily provide an idiomatic and Pythonic mechanism for adding packages and/or +repositories to systems for use in machine charms. + +A sane default configuration is attainable through nothing more than instantiation of the +appropriate classes. `DebianPackage` objects provide information about the architecture, version, +name, and status of a package. + +`DebianPackage` will try to look up a package either from `dpkg -L` or from `apt-cache` when +provided with a string indicating the package name. If it cannot be located, `PackageNotFoundError` +will be returned, as `apt` and `dpkg` otherwise return `100` for all errors, and a meaningful error +message if the package is not known is desirable. + +To install packages with convenience methods: + +```python +try: + # Run `apt-get update` + apt.update() + apt.add_package("zsh") + apt.add_package(["vim", "htop", "wget"]) +except PackageNotFoundError: + logger.error("a specified package not found in package cache or on system") +except PackageError as e: + logger.error("could not install package. Reason: %s", e.message) +```` + +To find details of a specific package: + +```python +try: + vim = apt.DebianPackage.from_system("vim") + + # To find from the apt cache only + # apt.DebianPackage.from_apt_cache("vim") + + # To find from installed packages only + # apt.DebianPackage.from_installed_package("vim") + + vim.ensure(PackageState.Latest) + logger.info("updated vim to version: %s", vim.fullversion) +except PackageNotFoundError: + logger.error("a specified package not found in package cache or on system") +except PackageError as e: + logger.error("could not install package. Reason: %s", e.message) +``` + + +`RepositoryMapping` will return a dict-like object containing enabled system repositories +and their properties (available groups, baseuri. gpg key). This class can add, disable, or +manipulate repositories. Items can be retrieved as `DebianRepository` objects. + +In order add a new repository with explicit details for fields, a new `DebianRepository` can +be added to `RepositoryMapping` + +`RepositoryMapping` provides an abstraction around the existing repositories on the system, +and can be accessed and iterated over like any `Mapping` object, to retrieve values by key, +iterate, or perform other operations. + +Keys are constructed as `{repo_type}-{}-{release}` in order to uniquely identify a repository. + +Repositories can be added with explicit values through a Python constructor. + +Example: +```python +repositories = apt.RepositoryMapping() + +if "deb-example.com-focal" not in repositories: + repositories.add(DebianRepository(enabled=True, repotype="deb", + uri="https://example.com", release="focal", groups=["universe"])) +``` + +Alternatively, any valid `sources.list` line may be used to construct a new +`DebianRepository`. + +Example: +```python +repositories = apt.RepositoryMapping() + +if "deb-us.archive.ubuntu.com-xenial" not in repositories: + line = "deb http://us.archive.ubuntu.com/ubuntu xenial main restricted" + repo = DebianRepository.from_repo_line(line) + repositories.add(repo) +``` +""" + +import fileinput +import glob +import logging +import os +import re +import subprocess +from collections.abc import Mapping +from enum import Enum +from subprocess import PIPE, CalledProcessError, check_output +from typing import Iterable, List, Optional, Tuple, Union +from urllib.parse import urlparse + +logger = logging.getLogger(__name__) + +# The unique Charmhub library identifier, never change it +LIBID = "7c3dbc9c2ad44a47bd6fcb25caa270e5" + +# Increment this major API version when introducing breaking changes +LIBAPI = 0 + +# Increment this PATCH version before using `charmcraft publish-lib` or reset +# to 0 if you are raising the major API version +LIBPATCH = 13 + + +VALID_SOURCE_TYPES = ("deb", "deb-src") +OPTIONS_MATCHER = re.compile(r"\[.*?\]") + + +class Error(Exception): + """Base class of most errors raised by this library.""" + + def __repr__(self): + """Represent the Error.""" + return "<{}.{} {}>".format(type(self).__module__, type(self).__name__, self.args) + + @property + def name(self): + """Return a string representation of the model plus class.""" + return "<{}.{}>".format(type(self).__module__, type(self).__name__) + + @property + def message(self): + """Return the message passed as an argument.""" + return self.args[0] + + +class PackageError(Error): + """Raised when there's an error installing or removing a package.""" + + +class PackageNotFoundError(Error): + """Raised when a requested package is not known to the system.""" + + +class PackageState(Enum): + """A class to represent possible package states.""" + + Present = "present" + Absent = "absent" + Latest = "latest" + Available = "available" + + +class DebianPackage: + """Represents a traditional Debian package and its utility functions. + + `DebianPackage` wraps information and functionality around a known package, whether installed + or available. The version, epoch, name, and architecture can be easily queried and compared + against other `DebianPackage` objects to determine the latest version or to install a specific + version. + + The representation of this object as a string mimics the output from `dpkg` for familiarity. + + Installation and removal of packages is handled through the `state` property or `ensure` + method, with the following options: + + apt.PackageState.Absent + apt.PackageState.Available + apt.PackageState.Present + apt.PackageState.Latest + + When `DebianPackage` is initialized, the state of a given `DebianPackage` object will be set to + `Available`, `Present`, or `Latest`, with `Absent` implemented as a convenience for removal + (though it operates essentially the same as `Available`). + """ + + def __init__( + self, name: str, version: str, epoch: str, arch: str, state: PackageState + ) -> None: + self._name = name + self._arch = arch + self._state = state + self._version = Version(version, epoch) + + def __eq__(self, other) -> bool: + """Equality for comparison. + + Args: + other: a `DebianPackage` object for comparison + + Returns: + A boolean reflecting equality + """ + return isinstance(other, self.__class__) and ( + self._name, + self._version.number, + ) == (other._name, other._version.number) + + def __hash__(self): + """Return a hash of this package.""" + return hash((self._name, self._version.number)) + + def __repr__(self): + """Represent the package.""" + return "<{}.{}: {}>".format(self.__module__, self.__class__.__name__, self.__dict__) + + def __str__(self): + """Return a human-readable representation of the package.""" + return "<{}: {}-{}.{} -- {}>".format( + self.__class__.__name__, + self._name, + self._version, + self._arch, + str(self._state), + ) + + @staticmethod + def _apt( + command: str, + package_names: Union[str, List], + optargs: Optional[List[str]] = None, + ) -> None: + """Wrap package management commands for Debian/Ubuntu systems. + + Args: + command: the command given to `apt-get` + package_names: a package name or list of package names to operate on + optargs: an (Optional) list of additioanl arguments + + Raises: + PackageError if an error is encountered + """ + optargs = optargs if optargs is not None else [] + if isinstance(package_names, str): + package_names = [package_names] + _cmd = ["apt-get", "-y", *optargs, command, *package_names] + try: + env = os.environ.copy() + env["DEBIAN_FRONTEND"] = "noninteractive" + subprocess.run(_cmd, capture_output=True, check=True, text=True, env=env) + except CalledProcessError as e: + raise PackageError( + "Could not {} package(s) [{}]: {}".format(command, [*package_names], e.stderr) + ) from None + + def _add(self) -> None: + """Add a package to the system.""" + self._apt( + "install", + "{}={}".format(self.name, self.version), + optargs=["--option=Dpkg::Options::=--force-confold"], + ) + + def _remove(self) -> None: + """Remove a package from the system. Implementation-specific.""" + return self._apt("remove", "{}={}".format(self.name, self.version)) + + @property + def name(self) -> str: + """Returns the name of the package.""" + return self._name + + def ensure(self, state: PackageState): + """Ensure that a package is in a given state. + + Args: + state: a `PackageState` to reconcile the package to + + Raises: + PackageError from the underlying call to apt + """ + if self._state is not state: + if state not in (PackageState.Present, PackageState.Latest): + self._remove() + else: + self._add() + self._state = state + + @property + def present(self) -> bool: + """Returns whether or not a package is present.""" + return self._state in (PackageState.Present, PackageState.Latest) + + @property + def latest(self) -> bool: + """Returns whether the package is the most recent version.""" + return self._state is PackageState.Latest + + @property + def state(self) -> PackageState: + """Returns the current package state.""" + return self._state + + @state.setter + def state(self, state: PackageState) -> None: + """Set the package state to a given value. + + Args: + state: a `PackageState` to reconcile the package to + + Raises: + PackageError from the underlying call to apt + """ + if state in (PackageState.Latest, PackageState.Present): + self._add() + else: + self._remove() + self._state = state + + @property + def version(self) -> "Version": + """Returns the version for a package.""" + return self._version + + @property + def epoch(self) -> str: + """Returns the epoch for a package. May be unset.""" + return self._version.epoch + + @property + def arch(self) -> str: + """Returns the architecture for a package.""" + return self._arch + + @property + def fullversion(self) -> str: + """Returns the name+epoch for a package.""" + return "{}.{}".format(self._version, self._arch) + + @staticmethod + def _get_epoch_from_version(version: str) -> Tuple[str, str]: + """Pull the epoch, if any, out of a version string.""" + epoch_matcher = re.compile(r"^((?P\d+):)?(?P.*)") + matches = epoch_matcher.search(version).groupdict() + return matches.get("epoch", ""), matches.get("version") + + @classmethod + def from_system( + cls, package: str, version: Optional[str] = "", arch: Optional[str] = "" + ) -> "DebianPackage": + """Locates a package, either on the system or known to apt, and serializes the information. + + Args: + package: a string representing the package + version: an optional string if a specific version is requested + arch: an optional architecture, defaulting to `dpkg --print-architecture`. If an + architecture is not specified, this will be used for selection. + + """ + try: + return DebianPackage.from_installed_package(package, version, arch) + except PackageNotFoundError: + logger.debug( + "package '%s' is not currently installed or has the wrong architecture.", package + ) + + # Ok, try `apt-cache ...` + try: + return DebianPackage.from_apt_cache(package, version, arch) + except (PackageNotFoundError, PackageError): + # If we get here, it's not known to the systems. + # This seems unnecessary, but virtually all `apt` commands have a return code of `100`, + # and providing meaningful error messages without this is ugly. + raise PackageNotFoundError( + "Package '{}{}' could not be found on the system or in the apt cache!".format( + package, ".{}".format(arch) if arch else "" + ) + ) from None + + @classmethod + def from_installed_package( + cls, package: str, version: Optional[str] = "", arch: Optional[str] = "" + ) -> "DebianPackage": + """Check whether the package is already installed and return an instance. + + Args: + package: a string representing the package + version: an optional string if a specific version is requested + arch: an optional architecture, defaulting to `dpkg --print-architecture`. + If an architecture is not specified, this will be used for selection. + """ + system_arch = check_output( + ["dpkg", "--print-architecture"], universal_newlines=True + ).strip() + arch = arch if arch else system_arch + + # Regexps are a really terrible way to do this. Thanks dpkg + output = "" + try: + output = check_output(["dpkg", "-l", package], stderr=PIPE, universal_newlines=True) + except CalledProcessError: + raise PackageNotFoundError("Package is not installed: {}".format(package)) from None + + # Pop off the output from `dpkg -l' because there's no flag to + # omit it` + lines = str(output).splitlines()[5:] + + dpkg_matcher = re.compile( + r""" + ^(?P\w+?)\s+ + (?P.*?)(?P:\w+?)?\s+ + (?P.*?)\s+ + (?P\w+?)\s+ + (?P.*) + """, + re.VERBOSE, + ) + + for line in lines: + try: + matches = dpkg_matcher.search(line).groupdict() + package_status = matches["package_status"] + + if not package_status.endswith("i"): + logger.debug( + "package '%s' in dpkg output but not installed, status: '%s'", + package, + package_status, + ) + break + + epoch, split_version = DebianPackage._get_epoch_from_version(matches["version"]) + pkg = DebianPackage( + matches["package_name"], + split_version, + epoch, + matches["arch"], + PackageState.Present, + ) + if (pkg.arch == "all" or pkg.arch == arch) and ( + version == "" or str(pkg.version) == version + ): + return pkg + except AttributeError: + logger.warning("dpkg matcher could not parse line: %s", line) + + # If we didn't find it, fail through + raise PackageNotFoundError("Package {}.{} is not installed!".format(package, arch)) + + @classmethod + def from_apt_cache( + cls, package: str, version: Optional[str] = "", arch: Optional[str] = "" + ) -> "DebianPackage": + """Check whether the package is already installed and return an instance. + + Args: + package: a string representing the package + version: an optional string if a specific version is requested + arch: an optional architecture, defaulting to `dpkg --print-architecture`. + If an architecture is not specified, this will be used for selection. + """ + system_arch = check_output( + ["dpkg", "--print-architecture"], universal_newlines=True + ).strip() + arch = arch if arch else system_arch + + # Regexps are a really terrible way to do this. Thanks dpkg + keys = ("Package", "Architecture", "Version") + + try: + output = check_output( + ["apt-cache", "show", package], stderr=PIPE, universal_newlines=True + ) + except CalledProcessError as e: + raise PackageError( + "Could not list packages in apt-cache: {}".format(e.stderr) + ) from None + + pkg_groups = output.strip().split("\n\n") + keys = ("Package", "Architecture", "Version") + + for pkg_raw in pkg_groups: + lines = str(pkg_raw).splitlines() + vals = {} + for line in lines: + if line.startswith(keys): + items = line.split(":", 1) + vals[items[0]] = items[1].strip() + else: + continue + + epoch, split_version = DebianPackage._get_epoch_from_version(vals["Version"]) + pkg = DebianPackage( + vals["Package"], + split_version, + epoch, + vals["Architecture"], + PackageState.Available, + ) + + if (pkg.arch == "all" or pkg.arch == arch) and ( + version == "" or str(pkg.version) == version + ): + return pkg + + # If we didn't find it, fail through + raise PackageNotFoundError("Package {}.{} is not in the apt cache!".format(package, arch)) + + +class Version: + """An abstraction around package versions. + + This seems like it should be strictly unnecessary, except that `apt_pkg` is not usable inside a + venv, and wedging version comparisons into `DebianPackage` would overcomplicate it. + + This class implements the algorithm found here: + https://www.debian.org/doc/debian-policy/ch-controlfields.html#version + """ + + def __init__(self, version: str, epoch: str): + self._version = version + self._epoch = epoch or "" + + def __repr__(self): + """Represent the package.""" + return "<{}.{}: {}>".format(self.__module__, self.__class__.__name__, self.__dict__) + + def __str__(self): + """Return human-readable representation of the package.""" + return "{}{}".format("{}:".format(self._epoch) if self._epoch else "", self._version) + + @property + def epoch(self): + """Returns the epoch for a package. May be empty.""" + return self._epoch + + @property + def number(self) -> str: + """Returns the version number for a package.""" + return self._version + + def _get_parts(self, version: str) -> Tuple[str, str]: + """Separate the version into component upstream and Debian pieces.""" + try: + version.rindex("-") + except ValueError: + # No hyphens means no Debian version + return version, "0" + + upstream, debian = version.rsplit("-", 1) + return upstream, debian + + def _listify(self, revision: str) -> List[str]: + """Split a revision string into a listself. + + This list is comprised of alternating between strings and numbers, + padded on either end to always be "str, int, str, int..." and + always be of even length. This allows us to trivially implement the + comparison algorithm described. + """ + result = [] + while revision: + rev_1, remains = self._get_alphas(revision) + rev_2, remains = self._get_digits(remains) + result.extend([rev_1, rev_2]) + revision = remains + return result + + def _get_alphas(self, revision: str) -> Tuple[str, str]: + """Return a tuple of the first non-digit characters of a revision.""" + # get the index of the first digit + for i, char in enumerate(revision): + if char.isdigit(): + if i == 0: + return "", revision + return revision[0:i], revision[i:] + # string is entirely alphas + return revision, "" + + def _get_digits(self, revision: str) -> Tuple[int, str]: + """Return a tuple of the first integer characters of a revision.""" + # If the string is empty, return (0,'') + if not revision: + return 0, "" + # get the index of the first non-digit + for i, char in enumerate(revision): + if not char.isdigit(): + if i == 0: + return 0, revision + return int(revision[0:i]), revision[i:] + # string is entirely digits + return int(revision), "" + + def _dstringcmp(self, a, b): # noqa: C901 + """Debian package version string section lexical sort algorithm. + + The lexical comparison is a comparison of ASCII values modified so + that all the letters sort earlier than all the non-letters and so that + a tilde sorts before anything, even the end of a part. + """ + if a == b: + return 0 + try: + for i, char in enumerate(a): + if char == b[i]: + continue + # "a tilde sorts before anything, even the end of a part" + # (emptyness) + if char == "~": + return -1 + if b[i] == "~": + return 1 + # "all the letters sort earlier than all the non-letters" + if char.isalpha() and not b[i].isalpha(): + return -1 + if not char.isalpha() and b[i].isalpha(): + return 1 + # otherwise lexical sort + if ord(char) > ord(b[i]): + return 1 + if ord(char) < ord(b[i]): + return -1 + except IndexError: + # a is longer than b but otherwise equal, greater unless there are tildes + if char == "~": + return -1 + return 1 + # if we get here, a is shorter than b but otherwise equal, so check for tildes... + if b[len(a)] == "~": + return 1 + return -1 + + def _compare_revision_strings(self, first: str, second: str): # noqa: C901 + """Compare two debian revision strings.""" + if first == second: + return 0 + + # listify pads results so that we will always be comparing ints to ints + # and strings to strings (at least until we fall off the end of a list) + first_list = self._listify(first) + second_list = self._listify(second) + if first_list == second_list: + return 0 + try: + for i, item in enumerate(first_list): + # explicitly raise IndexError if we've fallen off the edge of list2 + if i >= len(second_list): + raise IndexError + # if the items are equal, next + if item == second_list[i]: + continue + # numeric comparison + if isinstance(item, int): + if item > second_list[i]: + return 1 + if item < second_list[i]: + return -1 + else: + # string comparison + return self._dstringcmp(item, second_list[i]) + except IndexError: + # rev1 is longer than rev2 but otherwise equal, hence greater + # ...except for goddamn tildes + if first_list[len(second_list)][0][0] == "~": + return 1 + return 1 + # rev1 is shorter than rev2 but otherwise equal, hence lesser + # ...except for goddamn tildes + if second_list[len(first_list)][0][0] == "~": + return -1 + return -1 + + def _compare_version(self, other) -> int: + if (self.number, self.epoch) == (other.number, other.epoch): + return 0 + + if self.epoch < other.epoch: + return -1 + if self.epoch > other.epoch: + return 1 + + # If none of these are true, follow the algorithm + upstream_version, debian_version = self._get_parts(self.number) + other_upstream_version, other_debian_version = self._get_parts(other.number) + + upstream_cmp = self._compare_revision_strings(upstream_version, other_upstream_version) + if upstream_cmp != 0: + return upstream_cmp + + debian_cmp = self._compare_revision_strings(debian_version, other_debian_version) + if debian_cmp != 0: + return debian_cmp + + return 0 + + def __lt__(self, other) -> bool: + """Less than magic method impl.""" + return self._compare_version(other) < 0 + + def __eq__(self, other) -> bool: + """Equality magic method impl.""" + return self._compare_version(other) == 0 + + def __gt__(self, other) -> bool: + """Greater than magic method impl.""" + return self._compare_version(other) > 0 + + def __le__(self, other) -> bool: + """Less than or equal to magic method impl.""" + return self.__eq__(other) or self.__lt__(other) + + def __ge__(self, other) -> bool: + """Greater than or equal to magic method impl.""" + return self.__gt__(other) or self.__eq__(other) + + def __ne__(self, other) -> bool: + """Not equal to magic method impl.""" + return not self.__eq__(other) + + +def add_package( + package_names: Union[str, List[str]], + version: Optional[str] = "", + arch: Optional[str] = "", + update_cache: Optional[bool] = False, +) -> Union[DebianPackage, List[DebianPackage]]: + """Add a package or list of packages to the system. + + Args: + package_names: single package name, or list of package names + name: the name(s) of the package(s) + version: an (Optional) version as a string. Defaults to the latest known + arch: an optional architecture for the package + update_cache: whether or not to run `apt-get update` prior to operating + + Raises: + TypeError if no package name is given, or explicit version is set for multiple packages + PackageNotFoundError if the package is not in the cache. + PackageError if packages fail to install + """ + cache_refreshed = False + if update_cache: + update() + cache_refreshed = True + + packages = {"success": [], "retry": [], "failed": []} + + package_names = [package_names] if isinstance(package_names, str) else package_names + if not package_names: + raise TypeError("Expected at least one package name to add, received zero!") + + if len(package_names) != 1 and version: + raise TypeError( + "Explicit version should not be set if more than one package is being added!" + ) + + for p in package_names: + pkg, success = _add(p, version, arch) + if success: + packages["success"].append(pkg) + else: + logger.warning("failed to locate and install/update '%s'", pkg) + packages["retry"].append(p) + + if packages["retry"] and not cache_refreshed: + logger.info("updating the apt-cache and retrying installation of failed packages.") + update() + + for p in packages["retry"]: + pkg, success = _add(p, version, arch) + if success: + packages["success"].append(pkg) + else: + packages["failed"].append(p) + + if packages["failed"]: + raise PackageError("Failed to install packages: {}".format(", ".join(packages["failed"]))) + + return packages["success"] if len(packages["success"]) > 1 else packages["success"][0] + + +def _add( + name: str, + version: Optional[str] = "", + arch: Optional[str] = "", +) -> Tuple[Union[DebianPackage, str], bool]: + """Add a package to the system. + + Args: + name: the name(s) of the package(s) + version: an (Optional) version as a string. Defaults to the latest known + arch: an optional architecture for the package + + Returns: a tuple of `DebianPackage` if found, or a :str: if it is not, and + a boolean indicating success + """ + try: + pkg = DebianPackage.from_system(name, version, arch) + pkg.ensure(state=PackageState.Present) + return pkg, True + except PackageNotFoundError: + return name, False + + +def remove_package( + package_names: Union[str, List[str]] +) -> Union[DebianPackage, List[DebianPackage]]: + """Remove package(s) from the system. + + Args: + package_names: the name of a package + + Raises: + PackageNotFoundError if the package is not found. + """ + packages = [] + + package_names = [package_names] if isinstance(package_names, str) else package_names + if not package_names: + raise TypeError("Expected at least one package name to add, received zero!") + + for p in package_names: + try: + pkg = DebianPackage.from_installed_package(p) + pkg.ensure(state=PackageState.Absent) + packages.append(pkg) + except PackageNotFoundError: + logger.info("package '%s' was requested for removal, but it was not installed.", p) + + # the list of packages will be empty when no package is removed + logger.debug("packages: '%s'", packages) + return packages[0] if len(packages) == 1 else packages + + +def update() -> None: + """Update the apt cache via `apt-get update`.""" + subprocess.run(["apt-get", "update"], capture_output=True, check=True) + + +def import_key(key: str) -> str: + """Import an ASCII Armor key. + + A Radix64 format keyid is also supported for backwards + compatibility. In this case Ubuntu keyserver will be + queried for a key via HTTPS by its keyid. This method + is less preferable because https proxy servers may + require traffic decryption which is equivalent to a + man-in-the-middle attack (a proxy server impersonates + keyserver TLS certificates and has to be explicitly + trusted by the system). + + Args: + key: A GPG key in ASCII armor format, including BEGIN + and END markers or a keyid. + + Returns: + The GPG key filename written. + + Raises: + GPGKeyError if the key could not be imported + """ + key = key.strip() + if "-" in key or "\n" in key: + # Send everything not obviously a keyid to GPG to import, as + # we trust its validation better than our own. eg. handling + # comments before the key. + logger.debug("PGP key found (looks like ASCII Armor format)") + if ( + "-----BEGIN PGP PUBLIC KEY BLOCK-----" in key + and "-----END PGP PUBLIC KEY BLOCK-----" in key + ): + logger.debug("Writing provided PGP key in the binary format") + key_bytes = key.encode("utf-8") + key_name = DebianRepository._get_keyid_by_gpg_key(key_bytes) + key_gpg = DebianRepository._dearmor_gpg_key(key_bytes) + gpg_key_filename = "/etc/apt/trusted.gpg.d/{}.gpg".format(key_name) + DebianRepository._write_apt_gpg_keyfile( + key_name=gpg_key_filename, key_material=key_gpg + ) + return gpg_key_filename + else: + raise GPGKeyError("ASCII armor markers missing from GPG key") + else: + logger.warning( + "PGP key found (looks like Radix64 format). " + "SECURELY importing PGP key from keyserver; " + "full key not provided." + ) + # as of bionic add-apt-repository uses curl with an HTTPS keyserver URL + # to retrieve GPG keys. `apt-key adv` command is deprecated as is + # apt-key in general as noted in its manpage. See lp:1433761 for more + # history. Instead, /etc/apt/trusted.gpg.d is used directly to drop + # gpg + key_asc = DebianRepository._get_key_by_keyid(key) + # write the key in GPG format so that apt-key list shows it + key_gpg = DebianRepository._dearmor_gpg_key(key_asc.encode("utf-8")) + gpg_key_filename = "/etc/apt/trusted.gpg.d/{}.gpg".format(key) + DebianRepository._write_apt_gpg_keyfile(key_name=gpg_key_filename, key_material=key_gpg) + return gpg_key_filename + + +class InvalidSourceError(Error): + """Exceptions for invalid source entries.""" + + +class GPGKeyError(Error): + """Exceptions for GPG keys.""" + + +class DebianRepository: + """An abstraction to represent a repository.""" + + def __init__( + self, + enabled: bool, + repotype: str, + uri: str, + release: str, + groups: List[str], + filename: Optional[str] = "", + gpg_key_filename: Optional[str] = "", + options: Optional[dict] = None, + ): + self._enabled = enabled + self._repotype = repotype + self._uri = uri + self._release = release + self._groups = groups + self._filename = filename + self._gpg_key_filename = gpg_key_filename + self._options = options + + @property + def enabled(self): + """Return whether or not the repository is enabled.""" + return self._enabled + + @property + def repotype(self): + """Return whether it is binary or source.""" + return self._repotype + + @property + def uri(self): + """Return the URI.""" + return self._uri + + @property + def release(self): + """Return which Debian/Ubuntu releases it is valid for.""" + return self._release + + @property + def groups(self): + """Return the enabled package groups.""" + return self._groups + + @property + def filename(self): + """Returns the filename for a repository.""" + return self._filename + + @filename.setter + def filename(self, fname: str) -> None: + """Set the filename used when a repo is written back to disk. + + Args: + fname: a filename to write the repository information to. + """ + if not fname.endswith(".list"): + raise InvalidSourceError("apt source filenames should end in .list!") + + self._filename = fname + + @property + def gpg_key(self): + """Returns the path to the GPG key for this repository.""" + return self._gpg_key_filename + + @property + def options(self): + """Returns any additional repo options which are set.""" + return self._options + + def make_options_string(self) -> str: + """Generate the complete options string for a a repository. + + Combining `gpg_key`, if set, and the rest of the options to find + a complex repo string. + """ + options = self._options if self._options else {} + if self._gpg_key_filename: + options["signed-by"] = self._gpg_key_filename + + return ( + "[{}] ".format(" ".join(["{}={}".format(k, v) for k, v in options.items()])) + if options + else "" + ) + + @staticmethod + def prefix_from_uri(uri: str) -> str: + """Get a repo list prefix from the uri, depending on whether a path is set.""" + uridetails = urlparse(uri) + path = ( + uridetails.path.lstrip("/").replace("/", "-") if uridetails.path else uridetails.netloc + ) + return "/etc/apt/sources.list.d/{}".format(path) + + @staticmethod + def from_repo_line(repo_line: str, write_file: Optional[bool] = True) -> "DebianRepository": + """Instantiate a new `DebianRepository` a `sources.list` entry line. + + Args: + repo_line: a string representing a repository entry + write_file: boolean to enable writing the new repo to disk + """ + repo = RepositoryMapping._parse(repo_line, "UserInput") + fname = "{}-{}.list".format( + DebianRepository.prefix_from_uri(repo.uri), repo.release.replace("/", "-") + ) + repo.filename = fname + + options = repo.options if repo.options else {} + if repo.gpg_key: + options["signed-by"] = repo.gpg_key + + # For Python 3.5 it's required to use sorted in the options dict in order to not have + # different results in the order of the options between executions. + options_str = ( + "[{}] ".format(" ".join(["{}={}".format(k, v) for k, v in sorted(options.items())])) + if options + else "" + ) + + if write_file: + with open(fname, "wb") as f: + f.write( + ( + "{}".format("#" if not repo.enabled else "") + + "{} {}{} ".format(repo.repotype, options_str, repo.uri) + + "{} {}\n".format(repo.release, " ".join(repo.groups)) + ).encode("utf-8") + ) + + return repo + + def disable(self) -> None: + """Remove this repository from consideration. + + Disable it instead of removing from the repository file. + """ + searcher = "{} {}{} {}".format( + self.repotype, self.make_options_string(), self.uri, self.release + ) + for line in fileinput.input(self._filename, inplace=True): + if re.match(r"^{}\s".format(re.escape(searcher)), line): + print("# {}".format(line), end="") + else: + print(line, end="") + + def import_key(self, key: str) -> None: + """Import an ASCII Armor key. + + A Radix64 format keyid is also supported for backwards + compatibility. In this case Ubuntu keyserver will be + queried for a key via HTTPS by its keyid. This method + is less preferable because https proxy servers may + require traffic decryption which is equivalent to a + man-in-the-middle attack (a proxy server impersonates + keyserver TLS certificates and has to be explicitly + trusted by the system). + + Args: + key: A GPG key in ASCII armor format, + including BEGIN and END markers or a keyid. + + Raises: + GPGKeyError if the key could not be imported + """ + self._gpg_key_filename = import_key(key) + + @staticmethod + def _get_keyid_by_gpg_key(key_material: bytes) -> str: + """Get a GPG key fingerprint by GPG key material. + + Gets a GPG key fingerprint (40-digit, 160-bit) by the ASCII armor-encoded + or binary GPG key material. Can be used, for example, to generate file + names for keys passed via charm options. + """ + # Use the same gpg command for both Xenial and Bionic + cmd = ["gpg", "--with-colons", "--with-fingerprint"] + ps = subprocess.run( + cmd, + stdout=PIPE, + stderr=PIPE, + input=key_material, + ) + out, err = ps.stdout.decode(), ps.stderr.decode() + if "gpg: no valid OpenPGP data found." in err: + raise GPGKeyError("Invalid GPG key material provided") + # from gnupg2 docs: fpr :: Fingerprint (fingerprint is in field 10) + return re.search(r"^fpr:{9}([0-9A-F]{40}):$", out, re.MULTILINE).group(1) + + @staticmethod + def _get_key_by_keyid(keyid: str) -> str: + """Get a key via HTTPS from the Ubuntu keyserver. + + Different key ID formats are supported by SKS keyservers (the longer ones + are more secure, see "dead beef attack" and https://evil32.com/). Since + HTTPS is used, if SSLBump-like HTTPS proxies are in place, they will + impersonate keyserver.ubuntu.com and generate a certificate with + keyserver.ubuntu.com in the CN field or in SubjAltName fields of a + certificate. If such proxy behavior is expected it is necessary to add the + CA certificate chain containing the intermediate CA of the SSLBump proxy to + every machine that this code runs on via ca-certs cloud-init directive (via + cloudinit-userdata model-config) or via other means (such as through a + custom charm option). Also note that DNS resolution for the hostname in a + URL is done at a proxy server - not at the client side. + 8-digit (32 bit) key ID + https://keyserver.ubuntu.com/pks/lookup?search=0x4652B4E6 + 16-digit (64 bit) key ID + https://keyserver.ubuntu.com/pks/lookup?search=0x6E85A86E4652B4E6 + 40-digit key ID: + https://keyserver.ubuntu.com/pks/lookup?search=0x35F77D63B5CEC106C577ED856E85A86E4652B4E6 + + Args: + keyid: An 8, 16 or 40 hex digit keyid to find a key for + + Returns: + A string contining key material for the specified GPG key id + + + Raises: + subprocess.CalledProcessError + """ + # options=mr - machine-readable output (disables html wrappers) + keyserver_url = ( + "https://keyserver.ubuntu.com" "/pks/lookup?op=get&options=mr&exact=on&search=0x{}" + ) + curl_cmd = ["curl", keyserver_url.format(keyid)] + # use proxy server settings in order to retrieve the key + return check_output(curl_cmd).decode() + + @staticmethod + def _dearmor_gpg_key(key_asc: bytes) -> bytes: + """Convert a GPG key in the ASCII armor format to the binary format. + + Args: + key_asc: A GPG key in ASCII armor format. + + Returns: + A GPG key in binary format as a string + + Raises: + GPGKeyError + """ + ps = subprocess.run(["gpg", "--dearmor"], stdout=PIPE, stderr=PIPE, input=key_asc) + out, err = ps.stdout, ps.stderr.decode() + if "gpg: no valid OpenPGP data found." in err: + raise GPGKeyError( + "Invalid GPG key material. Check your network setup" + " (MTU, routing, DNS) and/or proxy server settings" + " as well as destination keyserver status." + ) + else: + return out + + @staticmethod + def _write_apt_gpg_keyfile(key_name: str, key_material: bytes) -> None: + """Write GPG key material into a file at a provided path. + + Args: + key_name: A key name to use for a key file (could be a fingerprint) + key_material: A GPG key material (binary) + """ + with open(key_name, "wb") as keyf: + keyf.write(key_material) + + +class RepositoryMapping(Mapping): + """An representation of known repositories. + + Instantiation of `RepositoryMapping` will iterate through the + filesystem, parse out repository files in `/etc/apt/...`, and create + `DebianRepository` objects in this list. + + Typical usage: + + repositories = apt.RepositoryMapping() + repositories.add(DebianRepository( + enabled=True, repotype="deb", uri="https://example.com", release="focal", + groups=["universe"] + )) + """ + + def __init__(self): + self._repository_map = {} + # Repositories that we're adding -- used to implement mode param + self.default_file = "/etc/apt/sources.list" + + # read sources.list if it exists + if os.path.isfile(self.default_file): + self.load(self.default_file) + + # read sources.list.d + for file in glob.iglob("/etc/apt/sources.list.d/*.list"): + self.load(file) + + def __contains__(self, key: str) -> bool: + """Magic method for checking presence of repo in mapping.""" + return key in self._repository_map + + def __len__(self) -> int: + """Return number of repositories in map.""" + return len(self._repository_map) + + def __iter__(self) -> Iterable[DebianRepository]: + """Return iterator for RepositoryMapping.""" + return iter(self._repository_map.values()) + + def __getitem__(self, repository_uri: str) -> DebianRepository: + """Return a given `DebianRepository`.""" + return self._repository_map[repository_uri] + + def __setitem__(self, repository_uri: str, repository: DebianRepository) -> None: + """Add a `DebianRepository` to the cache.""" + self._repository_map[repository_uri] = repository + + def load(self, filename: str): + """Load a repository source file into the cache. + + Args: + filename: the path to the repository file + """ + parsed = [] + skipped = [] + with open(filename, "r") as f: + for n, line in enumerate(f): + try: + repo = self._parse(line, filename) + except InvalidSourceError: + skipped.append(n) + else: + repo_identifier = "{}-{}-{}".format(repo.repotype, repo.uri, repo.release) + self._repository_map[repo_identifier] = repo + parsed.append(n) + logger.debug("parsed repo: '%s'", repo_identifier) + + if skipped: + skip_list = ", ".join(str(s) for s in skipped) + logger.debug("skipped the following lines in file '%s': %s", filename, skip_list) + + if parsed: + logger.info("parsed %d apt package repositories", len(parsed)) + else: + raise InvalidSourceError("all repository lines in '{}' were invalid!".format(filename)) + + @staticmethod + def _parse(line: str, filename: str) -> DebianRepository: + """Parse a line in a sources.list file. + + Args: + line: a single line from `load` to parse + filename: the filename being read + + Raises: + InvalidSourceError if the source type is unknown + """ + enabled = True + repotype = uri = release = gpg_key = "" + options = {} + groups = [] + + line = line.strip() + if line.startswith("#"): + enabled = False + line = line[1:] + + # Check for "#" in the line and treat a part after it as a comment then strip it off. + i = line.find("#") + if i > 0: + line = line[:i] + + # Split a source into substrings to initialize a new repo. + source = line.strip() + if source: + # Match any repo options, and get a dict representation. + for v in re.findall(OPTIONS_MATCHER, source): + opts = dict(o.split("=") for o in v.strip("[]").split()) + # Extract the 'signed-by' option for the gpg_key + gpg_key = opts.pop("signed-by", "") + options = opts + + # Remove any options from the source string and split the string into chunks + source = re.sub(OPTIONS_MATCHER, "", source) + chunks = source.split() + + # Check we've got a valid list of chunks + if len(chunks) < 3 or chunks[0] not in VALID_SOURCE_TYPES: + raise InvalidSourceError("An invalid sources line was found in %s!", filename) + + repotype = chunks[0] + uri = chunks[1] + release = chunks[2] + groups = chunks[3:] + + return DebianRepository( + enabled, repotype, uri, release, groups, filename, gpg_key, options + ) + else: + raise InvalidSourceError("An invalid sources line was found in %s!", filename) + + def add(self, repo: DebianRepository, default_filename: Optional[bool] = False) -> None: + """Add a new repository to the system. + + Args: + repo: a `DebianRepository` object + default_filename: an (Optional) filename if the default is not desirable + """ + new_filename = "{}-{}.list".format( + DebianRepository.prefix_from_uri(repo.uri), repo.release.replace("/", "-") + ) + + fname = repo.filename or new_filename + + options = repo.options if repo.options else {} + if repo.gpg_key: + options["signed-by"] = repo.gpg_key + + with open(fname, "wb") as f: + f.write( + ( + "{}".format("#" if not repo.enabled else "") + + "{} {}{} ".format(repo.repotype, repo.make_options_string(), repo.uri) + + "{} {}\n".format(repo.release, " ".join(repo.groups)) + ).encode("utf-8") + ) + + self._repository_map["{}-{}-{}".format(repo.repotype, repo.uri, repo.release)] = repo + + def disable(self, repo: DebianRepository) -> None: + """Remove a repository. Disable by default. + + Args: + repo: a `DebianRepository` to disable + """ + searcher = "{} {}{} {}".format( + repo.repotype, repo.make_options_string(), repo.uri, repo.release + ) + + for line in fileinput.input(repo.filename, inplace=True): + if re.match(r"^{}\s".format(re.escape(searcher)), line): + print("# {}".format(line), end="") + else: + print(line, end="") + + self._repository_map["{}-{}-{}".format(repo.repotype, repo.uri, repo.release)] = repo diff --git a/lib/charms/operator_libs_linux/v1/systemd.py b/lib/charms/operator_libs_linux/v1/systemd.py index 5be34c1..cdcbad6 100644 --- a/lib/charms/operator_libs_linux/v1/systemd.py +++ b/lib/charms/operator_libs_linux/v1/systemd.py @@ -23,6 +23,7 @@ service_resume with run the mask/unmask and enable/disable invocations. Example usage: + ```python from charms.operator_libs_linux.v0.systemd import service_running, service_reload @@ -33,13 +34,14 @@ # Attempt to reload a service, restarting if necessary success = service_reload("nginx", restart_on_failure=True) ``` - """ -import logging -import subprocess - __all__ = [ # Don't export `_systemctl`. (It's not the intended way of using this lib.) + "SystemdError", + "daemon_reload", + "service_disable", + "service_enable", + "service_failed", "service_pause", "service_reload", "service_restart", @@ -47,9 +49,11 @@ "service_running", "service_start", "service_stop", - "daemon_reload", ] +import logging +import subprocess + logger = logging.getLogger(__name__) # The unique Charmhub library identifier, never change it @@ -60,122 +64,168 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 0 +LIBPATCH = 4 class SystemdError(Exception): - pass + """Custom exception for SystemD related errors.""" -def _popen_kwargs(): - return dict( - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - bufsize=1, - universal_newlines=True, - encoding="utf-8", - ) +def _systemctl(*args: str, check: bool = False) -> int: + """Control a system service using systemctl. + Args: + *args: Arguments to pass to systemctl. + check: Check the output of the systemctl command. Default: False. -def _systemctl( - sub_cmd: str, service_name: str = None, now: bool = None, quiet: bool = None -) -> bool: - """Control a system service. + Returns: + Returncode of systemctl command execution. - Args: - sub_cmd: the systemctl subcommand to issue - service_name: the name of the service to perform the action on - now: passes the --now flag to the shell invocation. - quiet: passes the --quiet flag to the shell invocation. + Raises: + SystemdError: Raised if calling systemctl returns a non-zero returncode and check is True. """ - cmd = ["systemctl", sub_cmd] - - if service_name is not None: - cmd.append(service_name) - if now is not None: - cmd.append("--now") - if quiet is not None: - cmd.append("--quiet") - if sub_cmd != "is-active": - logger.debug("Attempting to {} '{}' with command {}.".format(cmd, service_name, cmd)) - else: - logger.debug("Checking if '{}' is active".format(service_name)) - - proc = subprocess.Popen(cmd, **_popen_kwargs()) - last_line = "" - for line in iter(proc.stdout.readline, ""): - last_line = line - logger.debug(line) - - proc.wait() - - if sub_cmd == "is-active": - # If we are just checking whether a service is running, return True/False, rather - # than raising an error. - if proc.returncode < 1: - return True - if proc.returncode == 3: # Code returned when service is not active. - return False - - if proc.returncode < 1: - return True - - raise SystemdError( - "Could not {}{}: systemd output: {}".format( - sub_cmd, " {}".format(service_name) if service_name else "", last_line + cmd = ["systemctl", *args] + logger.debug(f"Executing command: {cmd}") + try: + proc = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + encoding="utf-8", + check=check, + ) + logger.debug( + f"Command {cmd} exit code: {proc.returncode}. systemctl output:\n{proc.stdout}" + ) + return proc.returncode + except subprocess.CalledProcessError as e: + raise SystemdError( + f"Command {cmd} failed with returncode {e.returncode}. systemctl output:\n{e.stdout}" ) - ) def service_running(service_name: str) -> bool: - """Determine whether a system service is running. + """Report whether a system service is running. + + Args: + service_name: The name of the service to check. + + Return: + True if service is running/active; False if not. + """ + # If returncode is 0, this means that is service is active. + return _systemctl("--quiet", "is-active", service_name) == 0 + + +def service_failed(service_name: str) -> bool: + """Report whether a system service has failed. Args: - service_name: the name of the service + service_name: The name of the service to check. + + Returns: + True if service is marked as failed; False if not. """ - return _systemctl("is-active", service_name, quiet=True) + # If returncode is 0, this means that the service has failed. + return _systemctl("--quiet", "is-failed", service_name) == 0 -def service_start(service_name: str) -> bool: +def service_start(*args: str) -> bool: """Start a system service. Args: - service_name: the name of the service to stop + *args: Arguments to pass to `systemctl start` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl start ...` returns a non-zero returncode. """ - return _systemctl("start", service_name) + return _systemctl("start", *args, check=True) == 0 -def service_stop(service_name: str) -> bool: +def service_stop(*args: str) -> bool: """Stop a system service. Args: - service_name: the name of the service to stop + *args: Arguments to pass to `systemctl stop` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl stop ...` returns a non-zero returncode. """ - return _systemctl("stop", service_name) + return _systemctl("stop", *args, check=True) == 0 -def service_restart(service_name: str) -> bool: +def service_restart(*args: str) -> bool: """Restart a system service. Args: - service_name: the name of the service to restart + *args: Arguments to pass to `systemctl restart` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl restart ...` returns a non-zero returncode. """ - return _systemctl("restart", service_name) + return _systemctl("restart", *args, check=True) == 0 + + +def service_enable(*args: str) -> bool: + """Enable a system service. + + Args: + *args: Arguments to pass to `systemctl enable` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl enable ...` returns a non-zero returncode. + """ + return _systemctl("enable", *args, check=True) == 0 + + +def service_disable(*args: str) -> bool: + """Disable a system service. + + Args: + *args: Arguments to pass to `systemctl disable` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl disable ...` returns a non-zero returncode. + """ + return _systemctl("disable", *args, check=True) == 0 def service_reload(service_name: str, restart_on_failure: bool = False) -> bool: """Reload a system service, optionally falling back to restart if reload fails. Args: - service_name: the name of the service to reload - restart_on_failure: boolean indicating whether to fallback to a restart if the - reload fails. + service_name: The name of the service to reload. + restart_on_failure: + Boolean indicating whether to fall back to a restart if the reload fails. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl reload|restart ...` returns a non-zero returncode. """ try: - return _systemctl("reload", service_name) + return _systemctl("reload", service_name, check=True) == 0 except SystemdError: if restart_on_failure: - return _systemctl("restart", service_name) + return service_restart(service_name) else: raise @@ -183,37 +233,56 @@ def service_reload(service_name: str, restart_on_failure: bool = False) -> bool: def service_pause(service_name: str) -> bool: """Pause a system service. - Stop it, and prevent it from starting again at boot. + Stops the service and prevents the service from starting again at boot. Args: - service_name: the name of the service to pause + service_name: The name of the service to pause. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if service is still running after being paused by systemctl. """ - _systemctl("disable", service_name, now=True) + _systemctl("disable", "--now", service_name) _systemctl("mask", service_name) - if not service_running(service_name): - return True + if service_running(service_name): + raise SystemdError(f"Attempted to pause {service_name!r}, but it is still running.") - raise SystemdError("Attempted to pause '{}', but it is still running.".format(service_name)) + return True def service_resume(service_name: str) -> bool: """Resume a system service. - Re-enable starting again at boot. Start the service. + Re-enable starting the service again at boot. Start the service. Args: - service_name: the name of the service to resume + service_name: The name of the service to resume. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if service is not running after being resumed by systemctl. """ _systemctl("unmask", service_name) - _systemctl("enable", service_name, now=True) + _systemctl("enable", "--now", service_name) - if service_running(service_name): - return True + if not service_running(service_name): + raise SystemdError(f"Attempted to resume {service_name!r}, but it is not running.") - raise SystemdError("Attempted to resume '{}', but it is not running.".format(service_name)) + return True def daemon_reload() -> bool: - """Reload systemd manager configuration.""" - return _systemctl("daemon-reload") + """Reload systemd manager configuration. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl daemon-reload` returns a non-zero returncode. + """ + return _systemctl("daemon-reload", check=True) == 0 diff --git a/metadata.yaml b/metadata.yaml deleted file mode 100644 index 01f2b0d..0000000 --- a/metadata.yaml +++ /dev/null @@ -1,30 +0,0 @@ -name: slurmdbd -summary: | - Slurm DBD accounting daemon -description: | - This charm provides slurmdbd, munged, and the bindings to other utilities - that make lifecycle operations a breeze. - - slurmdbd provides a secure enterprise-wide interface to a database for - SLURM. This is particularly useful for archiving accounting records. -source: https://github.com/omnivector-solutions/slurmdbd-operator -issues: https://github.com/omnivector-solutions/slurmdbd-operator/issues -maintainers: - - OmniVector Solutions - - Jason C. Nucciarone - - David Gomez - -peers: - slurmdbd-peer: - interface: slurmdbd-peer -requires: - database: - interface: mysql_client - fluentbit: - interface: fluentbit -provides: - slurmdbd: - interface: slurmdbd - -assumes: - - juju diff --git a/pyproject.toml b/pyproject.toml index 99734d9..65f01b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,8 +36,8 @@ target-version = ["py38"] # Linting tools configuration [tool.ruff] line-length = 99 -select = ["E", "W", "F", "C", "N", "D", "I001"] -extend-ignore = [ +lint.select = ["E", "W", "F", "C", "N", "D", "I001"] +lint.extend-ignore = [ "D203", "D204", "D213", @@ -50,9 +50,9 @@ extend-ignore = [ "D409", "D413", ] -ignore = ["E501", "D107"] +lint.ignore = ["E501", "D107"] extend-exclude = ["__pycache__", "*.egg_info"] -per-file-ignores = {"tests/*" = ["D100","D101","D102","D103","D104"]} +lint.per-file-ignores = {"tests/*" = ["D100","D101","D102","D103","D104"]} -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] max-complexity = 10 diff --git a/requirements.txt b/requirements.txt index dfa795a..5889f38 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -ops==2.* -git+https://github.com/omnivector-solutions/slurm-ops-manager.git@0.8.16 \ No newline at end of file +ops==2.14.0 +distro==1.9.0 diff --git a/src/charm.py b/src/charm.py index 0c0d775..c86544e 100755 --- a/src/charm.py +++ b/src/charm.py @@ -1,13 +1,12 @@ #!/usr/bin/env python3 -# Copyright 2020 Omnivector Solutions, LLC. +# Copyright 2020-2024 Omnivector, LLC. # See LICENSE file for licensing details. """Slurmdbd Operator Charm.""" import logging -from pathlib import Path from time import sleep -from typing import Any, Dict +from typing import Any, Dict, Union from urllib.parse import urlparse from charms.data_platform_libs.v0.data_interfaces import ( @@ -15,149 +14,105 @@ DatabaseRequires, ) from charms.fluentbit.v0.fluentbit import FluentbitClient -from interface_slurmdbd import Slurmdbd -from interface_slurmdbd_peer import SlurmdbdPeer -from ops.charm import CharmBase, CharmEvents -from ops.framework import EventBase, EventSource, StoredState -from ops.main import main -from ops.model import ActiveStatus, BlockedStatus, ErrorStatus, WaitingStatus -from slurm_ops_manager import SlurmManager -from utils.manager import SlurmdbdManager +from constants import CHARM_MAINTAINED_PARAMETERS, SLURM_ACCT_DB +from interface_slurmctld import Slurmctld, SlurmctldAvailableEvent, SlurmctldUnavailableEvent +from ops import ( + ActiveStatus, + BlockedStatus, + CharmBase, + ConfigChangedEvent, + InstallEvent, + RelationCreatedEvent, + StoredState, + UpdateStatusEvent, + WaitingStatus, + main, +) +from slurmdbd_ops import SlurmdbdOpsManager logger = logging.getLogger(__name__) -SLURM_ACCT_DB = "slurm_acct_db" - - -class JwtAvailable(EventBase): - """Emitted when JWT RSA is available.""" - - -class MungeAvailable(EventBase): - """Emitted when JWT RSA is available.""" - - -class WriteConfigAndRestartSlurmdbd(EventBase): - """Emitted when config needs to be written.""" - - -class SlurmdbdCharmEvents(CharmEvents): - """Slurmdbd emitted events.""" - - jwt_available = EventSource(JwtAvailable) - munge_available = EventSource(MungeAvailable) - write_config = EventSource(WriteConfigAndRestartSlurmdbd) - - class SlurmdbdCharm(CharmBase): """Slurmdbd Charm.""" _stored = StoredState() - on = SlurmdbdCharmEvents() def __init__(self, *args, **kwargs) -> None: """Set the default class attributes.""" super().__init__(*args, **kwargs) self._stored.set_default( - db_info={}, - jwt_available=False, - munge_available=False, slurm_installed=False, - cluster_name=str(), + db_info={}, ) - self._slurmdbd_manager = SlurmdbdManager() self._db = DatabaseRequires(self, relation_name="database", database_name=SLURM_ACCT_DB) - self._slurm_manager = SlurmManager(self, "slurmdbd") - self._slurmdbd = Slurmdbd(self, "slurmdbd") - self._slurmdbd_peer = SlurmdbdPeer(self, "slurmdbd-peer") self._fluentbit = FluentbitClient(self, "fluentbit") + self._slurmdbd_ops_manager = SlurmdbdOpsManager() + self._slurmctld = Slurmctld(self, "slurmctld") - for event, handler in { + event_handler_bindings = { + # Charm core events self.on.install: self._on_install, - self.on.upgrade_charm: self._on_upgrade, self.on.update_status: self._on_update_status, self.on.config_changed: self._write_config_and_restart_slurmdbd, - self.on.jwt_available: self._on_jwt_available, - self.on.munge_available: self._on_munge_available, - self.on.write_config: self._write_config_and_restart_slurmdbd, + # Database relation self._db.on.database_created: self._on_database_created, - self._slurmdbd_peer.on.slurmdbd_peer_available: self._write_config_and_restart_slurmdbd, - self._slurmdbd.on.slurmctld_available: self._on_slurmctld_available, - self._slurmdbd.on.slurmctld_unavailable: self._on_slurmctld_unavailable, + # Slurmctld + self._slurmctld.on.slurmctld_available: self._on_slurmctld_available, + self._slurmctld.on.slurmctld_unavailable: self._on_slurmctld_unavailable, # fluentbit self.on["fluentbit"].relation_created: self._on_fluentbit_relation_created, - }.items(): + } + for event, handler in event_handler_bindings.items(): self.framework.observe(event, handler) - def _on_install(self, event): + def _on_install(self, event: InstallEvent) -> None: """Perform installation operations for slurmdbd.""" - self.unit.set_workload_version(Path("version").read_text().strip()) + if not self.model.unit.is_leader(): + self.unit.status = BlockedStatus("Only singleton slurmdbd currently supported.") + event.defer() + return self.unit.status = WaitingStatus("Installing slurmdbd") - custom_repo = self.config.get("custom-slurm-repo") - successful_installation = self._slurm_manager.install(custom_repo) - - if successful_installation: + if self._slurmdbd_ops_manager.install() is not False: + self.unit.set_workload_version(self._slurmdbd_ops_manager.version) self._stored.slurm_installed = True - self.unit.status = ActiveStatus("slurmdbd successfully installed") - else: - self.unit.status = BlockedStatus("Error installing slurmdbd") - event.defer() - return + if self._slurmdbd_ops_manager.start_munge(): + logger.debug("## Munge started successfully") + else: + logger.error("## Unable to start munge") + self.unit.status = BlockedStatus("Error restarting munge") + event.defer() + return + + self.unit.status = ActiveStatus("slurmdbd successfully installed") self._check_status() - def _on_fluentbit_relation_created(self, event): + def _on_fluentbit_relation_created(self, event: RelationCreatedEvent) -> None: """Set up Fluentbit log forwarding.""" - self._configure_fluentbit() - - def _configure_fluentbit(self): logger.debug("## Configuring fluentbit") - cfg = [] - cfg.extend(self._slurm_manager.fluentbit_config_nhc) - cfg.extend(self._slurm_manager.fluentbit_config_slurm) - self._fluentbit.configure(cfg) + self._fluentbit.configure(self._slurmdbd_ops_manager.fluentbit_config_slurm) - def _on_upgrade(self, event): - """Perform upgrade operations.""" - self.unit.set_workload_version(Path("version").read_text().strip()) - - def _on_update_status(self, event): + def _on_update_status(self, event: UpdateStatusEvent) -> None: """Handle update status.""" self._check_status() - def _on_jwt_available(self, event): - """Retrieve and configure the jwt_rsa key.""" - # jwt rsa lives in slurm spool dir, it is created when slurm is installed - if not self._stored.slurm_installed: - event.defer() - return - - jwt_rsa = self._slurmdbd.get_jwt_rsa() - self._slurm_manager.configure_jwt_rsa(jwt_rsa) - self._stored.jwt_available = True - - def _on_munge_available(self, event): - """Retrieve munge key and start munged.""" - # munge is installed together with slurm - if not self._stored.slurm_installed: + def _on_slurmctld_available(self, event: SlurmctldAvailableEvent) -> None: + """Retrieve and configure the jwt_rsa and munge_key when slurmctld_available.""" + if self._stored.slurm_installed is not True: event.defer() return - munge_key = self._slurmdbd.get_munge_key() - self._slurm_manager.configure_munge_key(munge_key) + if (jwt := event.jwt_rsa) is not None: + self._slurmdbd_ops_manager.write_jwt_rsa(jwt) + if (munge_key := event.munge_key) is not None: + self._slurmdbd_ops_manager.write_munge_key(munge_key) - if self._slurm_manager.restart_munged(): - logger.debug("## Munge restarted successfully") - self._stored.munge_available = True - else: - logger.error("## Unable to restart munge") - self.unit.status = BlockedStatus("Error restarting munge") - event.defer() + self._write_config_and_restart_slurmdbd(event) def _on_database_created(self, event: DatabaseCreatedEvent) -> None: """Process the DatabaseCreatedEvent and updates the database parameters. @@ -192,7 +147,7 @@ def _on_database_created(self, event: DatabaseCreatedEvent) -> None: # a human to look at and resolve the proper next steps. Reprocessing the # deferred event will only result in continual errors. logger.error(f"No endpoints provided: {event.endpoints}") - self.unit.status = ErrorStatus("No database endpoints") + self.unit.status = BlockedStatus("No database endpoints provided") raise ValueError(f"Unexpected endpoint types: {event.endpoints}") for endpoint in [ep.strip() for ep in event.endpoints.split(",")]: @@ -205,9 +160,9 @@ def _on_database_created(self, event: DatabaseCreatedEvent) -> None: tcp_endpoints.append(endpoint) db_info = { - "db_username": event.username, - "db_password": event.password, - "db_name": SLURM_ACCT_DB, + "StorageUser": event.username, + "StoragePass": event.password, + "StorageLoc": SLURM_ACCT_DB, } if socket_endpoints: @@ -222,12 +177,12 @@ def _on_database_created(self, event: DatabaseCreatedEvent) -> None: # Make sure to strip the file:// off the front of the first endpoint # otherwise slurmdbd will not be able to connect to the database socket = urlparse(socket_endpoints[0]).path - self._slurmdbd_manager.set_environment_var(mysql_unix_port=f'"{socket}"') + self._slurmdbd_ops_manager.set_environment_var(mysql_unix_port=f'"{socket}"') elif tcp_endpoints: # This must be using TCP endpoint and the connection information will # be host_address:port. Only one remote mysql service will be configured # in this case. - logger.debug("Using tcp endpoints specified in the relation.") + logger.debug("Using tcp endpoints specified in the relation") if len(tcp_endpoints) > 1: logger.warning( f"{len(tcp_endpoints)} tcp endpoints are specified, " @@ -239,166 +194,136 @@ def _on_database_created(self, event: DatabaseCreatedEvent) -> None: addr = addr[1:-1] db_info.update( { - "db_hostname": addr, - "db_port": port, + "StorageHost": addr, + "StoragePort": port, } ) # Make sure that the MYSQL_UNIX_PORT is removed from the env file. - self._slurmdbd_manager.set_environment_var(mysql_unix_port=None) + self._slurmdbd_ops_manager.set_environment_var(mysql_unix_port=None) else: # This is 100% an error condition that the charm doesn't know how to handle # and is an unexpected condition. This happens when there are commas but no # usable data in the endpoints. logger.error(f"No endpoints provided: {event.endpoints}") - self.unit.status = ErrorStatus("No database endpoints") + self.unit.status = BlockedStatus("No database endpoints provided") raise ValueError(f"No endpoints provided: {event.endpoints}") - self.set_db_info(db_info) + self._stored.db_info = db_info self._write_config_and_restart_slurmdbd(event) - def _on_slurmctld_available(self, event): - self.on.jwt_available.emit() - self.on.munge_available.emit() - - self.on.write_config.emit() - if self._fluentbit._relation is not None: - self._configure_fluentbit() - - def _on_slurmctld_unavailable(self, event): + def _on_slurmctld_unavailable(self, event: SlurmctldUnavailableEvent) -> None: """Reset state and charm status when slurmctld broken.""" - self._stored.jwt_available = False - self._stored.munge_available = False + self._stored.slurmctld_available = False self._check_status() - def _is_leader(self): - return self.model.unit.is_leader() - - def _write_config_and_restart_slurmdbd(self, event): - """Check for prereqs before writing config/restart of slurmdbd.""" + def _write_config_and_restart_slurmdbd( + self, + event: Union[ + ConfigChangedEvent, + DatabaseCreatedEvent, + InstallEvent, + SlurmctldAvailableEvent, + ], + ) -> None: + """Check that we have what we need before we proceed.""" # Ensure all pre-conditions are met with _check_status(), if not # defer the event. if not self._check_status(): event.defer() return - slurmdbd_config = { - "slurmdbd_debug": self.config.get("slurmdbd-debug"), - **self._slurmdbd_peer.get_slurmdbd_info(), - **self._stored.db_info, - } + if ( + charm_config_slurmdbd_conf_params := self.config.get("slurmdbd-conf-parameters") + ) is not None: + if ( + charm_config_slurmdbd_conf_params + != self._stored.user_supplied_slurmdbd_conf_params + ): + logger.debug("## User supplied parameters changed.") + self._stored.user_supplied_slurmdbd_conf_params = charm_config_slurmdbd_conf_params + + if binding := self.model.get_binding("slurmctld"): + slurmdbd_full_config = { + **CHARM_MAINTAINED_PARAMETERS, + **self._stored.db_info, + **{"DbdHost": self._slurmdbd_ops_manager.hostname}, + **{"DbdAddr": f"{binding.network.ingress_address}"}, + **self._get_user_supplied_parameters(), + } + + if self._slurmctld.is_joined: + slurmdbd_full_config["AuthAltParameters"] = ( + '"jwt_key=/var/spool/slurmdbd/jwt_hs256.key"' + ) - self._slurmdbd_manager.stop() - self._slurm_manager.render_slurm_configs(slurmdbd_config) + self._slurmdbd_ops_manager.stop_slurmdbd() + self._slurmdbd_ops_manager.write_slurmdbd_conf(slurmdbd_full_config) - # At this point, we must guarantee that slurmdbd is correctly - # initialized. Its startup might take a while, so we have to wait - # for it. - self._check_slurmdbd() + # At this point, we must guarantee that slurmdbd is correctly + # initialized. Its startup might take a while, so we have to wait + # for it. + self._check_slurmdbd() - # Only the leader can set relation data on the application. - # Enforce that no one other than the leader tries to set - # application relation data. - if self.model.unit.is_leader(): - self._slurmdbd.set_slurmdbd_info_on_app_relation_data( - slurmdbd_config, - ) + # Only the leader can set relation data on the application. + # Enforce that no one other than the leader tries to set + # application relation data. + if self.model.unit.is_leader(): + self._slurmctld.set_slurmdbd_host_on_app_relation_data( + self._slurmdbd_ops_manager.hostname + ) + else: + logger.debug("Cannot get network binding. Please Debug.") + event.defer() + return self._check_status() - def _check_slurmdbd(self, max_attemps=5) -> None: + def _get_user_supplied_parameters(self) -> Dict[Any, Any]: + """Gather, parse, and return the user supplied parameters.""" + user_supplied_parameters = {} + if custom_config := self.config.get("slurmdbd-conf-parameters"): + try: + user_supplied_parameters = { + line.split("=")[0]: line.split("=")[1] + for line in str(custom_config).split("\n") + if not line.startswith("#") and line.strip() != "" + } + except IndexError as e: + logger.error(f"Could not parse user supplied parameters: {e}.") + return user_supplied_parameters + + def _check_slurmdbd(self, max_attemps: int = 5) -> None: """Ensure slurmdbd is up and running.""" logger.debug("## Checking if slurmdbd is active") for i in range(max_attemps): - if self._slurmdbd_manager.active: + if self._slurmdbd_ops_manager.is_slurmdbd_active(): logger.debug("## Slurmdbd running") break else: logger.warning("## Slurmdbd not running, trying to start it") - self.unit.status = WaitingStatus("Starting slurmdbd") - self._slurmdbd_manager.restart() + self.unit.status = WaitingStatus("Starting slurmdbd ...") + self._slurmdbd_ops_manager.restart_slurmdbd() sleep(3 + i) - if self._slurmdbd_manager.active: + if self._slurmdbd_ops_manager.is_slurmdbd_active(): self._check_status() else: - self.unit.status = BlockedStatus("Cannot start slurmdbd") + self.unit.status = BlockedStatus("Cannot start slurmdbd.") - def _check_status(self) -> bool: # noqa C901 + def _check_status(self) -> bool: """Check that we have the things we need.""" - slurm_installed = self._stored.slurm_installed - if not slurm_installed: - self.unit.status = BlockedStatus("Error installing slurm") + if self._stored.slurm_installed is not True: + self.unit.status = BlockedStatus("Error installing slurmdbd.") return False - # we must be sure to initialize the charms correctly. Slurmdbd must - # first connect to the db to be able to connect to slurmctld correctly - slurmctld_available = self._stored.jwt_available and self._stored.munge_available - statuses = { - "MySQL": { - "available": self._stored.db_info != {}, - "joined": self._stored.db_info != {}, - }, - "slurmctld": {"available": slurmctld_available, "joined": self._slurmdbd.is_joined}, - } - - relations_needed = [] - waiting_on = [] - for component in statuses.keys(): - if not statuses[component]["joined"]: - relations_needed.append(component) - if not statuses[component]["available"]: - waiting_on.append(component) - - if len(relations_needed): - msg = f"Need relations: {','.join(relations_needed)}" - self.unit.status = BlockedStatus(msg) - return False - - if len(waiting_on): - msg = f"Waiting on: {','.join(waiting_on)}" - self.unit.status = WaitingStatus(msg) - return False - - slurmdbd_info = self._slurmdbd_peer.get_slurmdbd_info() - if not slurmdbd_info: - self.unit.status = WaitingStatus("slurmdbd starting") - return False - - if not self._slurm_manager.check_munged(): - self.unit.status = WaitingStatus("munged starting") + if self._stored.db_info == {}: + self.unit.status = WaitingStatus("Waiting on: MySQL") return False - self.unit.status = ActiveStatus("slurmdbd available") + self.unit.status = ActiveStatus() return True - def get_port(self): - """Return the port from slurm-ops-manager.""" - return self._slurm_manager.port - - def get_hostname(self): - """Return the hostname from slurm-ops-manager.""" - return self._slurm_manager.hostname - - def set_db_info(self, new_db_info: Dict[str, Any]) -> None: - """Set the db_info in the stored state. - - Args: - new_db_info (Dict[str, Any]): - New backend database information to set. - """ - self._stored.db_info.update(new_db_info) - - @property - def cluster_name(self) -> str: - """Return the cluster-name.""" - return self._stored.cluster_name - - @cluster_name.setter - def cluster_name(self, name: str): - """Set the cluster-name.""" - self._stored.cluster_name = name - if __name__ == "__main__": - main(SlurmdbdCharm) + main.main(SlurmdbdCharm) diff --git a/src/constants.py b/src/constants.py new file mode 100644 index 0000000..23a1afe --- /dev/null +++ b/src/constants.py @@ -0,0 +1,52 @@ +# Copyright 2024 Omnivector, LLC. +# See LICENSE file for licensing details. +"""Constants.""" +from pathlib import Path + +CHARM_MAINTAINED_PARAMETERS = { + "DbdPort": "6819", + "AuthType": "auth/munge", + "AuthInfo": '"socket=/var/run/munge/munge.socket.2"', + "SlurmUser": "slurm", + "PluginDir": "/usr/lib/x86_64-linux-gnu/slurm-wlm", + "PidFile": "/var/run/slurmdbd.pid", + "LogFile": "/var/log/slurm/slurmdbd.log", + "StorageType": "accounting_storage/mysql", +} + +SLURM_ACCT_DB = "slurm_acct_db" + +SLURMDBD_DEFAULTS_FILE = Path("/etc/default/slurmdbd") + +UBUNTU_HPC_PPA_KEY: str = """ +-----BEGIN PGP PUBLIC KEY BLOCK----- +Comment: Hostname: +Version: Hockeypuck 2.1.1-10-gec3b0e7 + +xsFNBGTuZb8BEACtJ1CnZe6/hv84DceHv+a54y3Pqq0gqED0xhTKnbj/E2ByJpmT +NlDNkpeITwPAAN1e3824Me76Qn31RkogTMoPJ2o2XfG253RXd67MPxYhfKTJcnM3 +CEkmeI4u2Lynh3O6RQ08nAFS2AGTeFVFH2GPNWrfOsGZW03Jas85TZ0k7LXVHiBs +W6qonbsFJhshvwC3SryG4XYT+z/+35x5fus4rPtMrrEOD65hij7EtQNaE8owuAju +Kcd0m2b+crMXNcllWFWmYMV0VjksQvYD7jwGrWeKs+EeHgU8ZuqaIP4pYHvoQjag +umqnH9Qsaq5NAXiuAIAGDIIV4RdAfQIR4opGaVgIFJdvoSwYe3oh2JlrLPBlyxyY +dayDifd3X8jxq6/oAuyH1h5K/QLs46jLSR8fUbG98SCHlRmvozTuWGk+e07ALtGe +sGv78ToHKwoM2buXaTTHMwYwu7Rx8LZ4bZPHdersN1VW/m9yn1n5hMzwbFKy2s6/ +D4Q2ZBsqlN+5aW2q0IUmO+m0GhcdaDv8U7RVto1cWWPr50HhiCi7Yvei1qZiD9jq +57oYZVqTUNCTPxi6NeTOdEc+YqNynWNArx4PHh38LT0bqKtlZCGHNfoAJLPVYhbB +b2AHj9edYtHU9AAFSIy+HstET6P0UDxy02IeyE2yxoUBqdlXyv6FL44E+wARAQAB +zRxMYXVuY2hwYWQgUFBBIGZvciBVYnVudHUgSFBDwsGOBBMBCgA4FiEErocSHcPk +oLD4H/Aj9tDF1ca+s3sFAmTuZb8CGwMFCwkIBwIGFQoJCAsCBBYCAwECHgECF4AA +CgkQ9tDF1ca+s3sz3w//RNawsgydrutcbKf0yphDhzWS53wgfrs2KF1KgB0u/H+u +6Kn2C6jrVM0vuY4NKpbEPCduOj21pTCepL6PoCLv++tICOLVok5wY7Zn3WQFq0js +Iy1wO5t3kA1cTD/05v/qQVBGZ2j4DsJo33iMcQS5AjHvSr0nu7XSvDDEE3cQE55D +87vL7lgGjuTOikPh5FpCoS1gpemBfwm2Lbm4P8vGOA4/witRjGgfC1fv1idUnZLM +TbGrDlhVie8pX2kgB6yTYbJ3P3kpC1ZPpXSRWO/cQ8xoYpLBTXOOtqwZZUnxyzHh +gM+hv42vPTOnCo+apD97/VArsp59pDqEVoAtMTk72fdBqR+BB77g2hBkKESgQIEq +EiE1/TOISioMkE0AuUdaJ2ebyQXugSHHuBaqbEC47v8t5DVN5Qr9OriuzCuSDNFn +6SBHpahN9ZNi9w0A/Yh1+lFfpkVw2t04Q2LNuupqOpW+h3/62AeUqjUIAIrmfeML +IDRE2VdquYdIXKuhNvfpJYGdyvx/wAbiAeBWg0uPSepwTfTG59VPQmj0FtalkMnN +ya2212K5q68O5eXOfCnGeMvqIXxqzpdukxSZnLkgk40uFJnJVESd/CxHquqHPUDE +fy6i2AnB3kUI27D4HY2YSlXLSRbjiSxTfVwNCzDsIh7Czefsm6ITK2+cVWs0hNQ= +=cs1s +-----END PGP PUBLIC KEY BLOCK----- +""" diff --git a/src/interface_slurmctld.py b/src/interface_slurmctld.py new file mode 100644 index 0000000..ec692d8 --- /dev/null +++ b/src/interface_slurmctld.py @@ -0,0 +1,118 @@ +"""Slurmctld interface for slurmdbd-operator.""" + +import json +import logging +from typing import List, Union + +from ops import ( + EventBase, + EventSource, + Object, + ObjectEvents, + Relation, + RelationBrokenEvent, + RelationChangedEvent, +) + +logger = logging.getLogger() + + +class SlurmctldAvailableEvent(EventBase): + """Emitted when slurmctld is unavailable.""" + + def __init__(self, handle, munge_key, jwt_rsa): + super().__init__(handle) + + self.munge_key = munge_key + self.jwt_rsa = jwt_rsa + + def snapshot(self): + """Snapshot the event data.""" + return { + "munge_key": self.munge_key, + "jwt_rsa": self.jwt_rsa, + } + + def restore(self, snapshot): + """Restore the snapshot of the event data.""" + self.munge_key = snapshot.get("munge_key") + self.jwt_rsa = snapshot.get("jwt_rsa") + + +class SlurmctldUnavailableEvent(EventBase): + """Emitted when slurmctld joins the relation.""" + + +class Events(ObjectEvents): + """Slurmctld relation events.""" + + slurmctld_available = EventSource(SlurmctldAvailableEvent) + slurmctld_unavailable = EventSource(SlurmctldUnavailableEvent) + + +class Slurmctld(Object): + """Slurmctld interface for slurmdbd.""" + + on = Events() # pyright: ignore [reportIncompatibleMethodOverride, reportAssignmentType] + + def __init__(self, charm, relation_name): + """Observe relation lifecycle events.""" + super().__init__(charm, relation_name) + + self._charm = charm + self._relation_name = relation_name + + self.framework.observe( + self._charm.on[self._relation_name].relation_changed, + self._on_relation_changed, + ) + + self.framework.observe( + self._charm.on[self._relation_name].relation_broken, + self._on_relation_broken, + ) + + @property + def _relations(self) -> Union[List[Relation], None]: + return self.model.relations.get(self._relation_name) + + @property + def is_joined(self) -> bool: + """Return True if self._relation is not None.""" + return True if self._relations else False + + def _on_relation_changed(self, event: RelationChangedEvent) -> None: + """Handle the relation-changed event. + + Get the cluster_info (munge_key and jwt_rsa) from slurmctld and emit to the charm. + """ + if app := event.app: + event_app_data = event.relation.data[app] + if cluster_info_json := event_app_data.get("cluster_info"): + try: + cluster_info = json.loads(cluster_info_json) + except json.JSONDecodeError as e: + logger.debug(e) + raise (e) + + self.on.slurmctld_available.emit(**cluster_info) + logger.debug(f"## 'cluster_info': {cluster_info}.") + else: + logger.debug("'cluster_info' not in application relation data.") + else: + logger.debug("## No application in the event.") + + def _on_relation_broken(self, event: RelationBrokenEvent) -> None: + """Clear the application relation data and emit the event.""" + self.set_slurmdbd_host_on_app_relation_data("") + self.on.slurmctld_unavailable.emit() + + def set_slurmdbd_host_on_app_relation_data(self, slurmdbd_host: str) -> None: + """Send slurmdbd_info to slurmctld.""" + # Iterate over each of the relations setting the relation data. + if (relations := self._relations) is not None: + logger.debug(f"## Setting slurmdbd_host on app relation data: {slurmdbd_host}") + for relation in relations: + relation.data[self.model.app]["slurmdbd_host"] = slurmdbd_host + else: + logger.debug("## No relation, not setting data.") diff --git a/src/interface_slurmdbd.py b/src/interface_slurmdbd.py deleted file mode 100644 index e113f81..0000000 --- a/src/interface_slurmdbd.py +++ /dev/null @@ -1,149 +0,0 @@ -#! /usr/bin/env python3 -"""Slurmdbd.""" -import json -import logging - -from ops.framework import ( - EventBase, - EventSource, - Object, - ObjectEvents, - StoredState, -) - -logger = logging.getLogger() - - -class SlurmctldUnAvailableEvent(EventBase): - """Emitted when slurmctld is unavailable.""" - - -class SlurmctldAvailableEvent(EventBase): - """Emitted when slurmctld joins the relation.""" - - -class SlurmdbdEvents(ObjectEvents): - """Slurmdbd relation events.""" - - slurmctld_available = EventSource(SlurmctldAvailableEvent) - slurmctld_unavailable = EventSource(SlurmctldUnAvailableEvent) - - -class Slurmdbd(Object): - """Slurmdbd.""" - - _stored = StoredState() - on = SlurmdbdEvents() - - def __init__(self, charm, relation_name): - """Observe relation lifecycle events.""" - super().__init__(charm, relation_name) - - self._charm = charm - self._relation_name = relation_name - - self._stored.set_default( - munge_key=str(), - jwt_key=str(), - slurmctld_joined=False, - ) - - self.framework.observe( - self._charm.on[self._relation_name].relation_created, - self._on_relation_created, - ) - - self.framework.observe( - self._charm.on[self._relation_name].relation_joined, - self._on_relation_joined, - ) - - self.framework.observe( - self._charm.on[self._relation_name].relation_broken, - self._on_relation_broken, - ) - - @property - def is_joined(self): - """Return True if juju related slurmdbd <-> slurmctld.""" - return self._stored.slurmctld_joined - - @is_joined.setter - def is_joined(self, flag): - """Set the is_joined property.""" - self._stored.slurmctld_joined = flag - - def _on_relation_created(self, event): - """Handle the relation-created event.""" - self.is_joined = True - - def _on_relation_joined(self, event): - """Handle the relation-joined event. - - Get the munge_key and jwt_rsa from slurmctld and save it to the charm - stored state. - """ - # Since we are in relation-joined (with the app on the other side) - # we can almost guarantee that the app object will exist in - # the event, but check for it just in case. - event_app_data = event.relation.data.get(event.app) - if not event_app_data: - event.defer() - return - - # slurmctld sets the munge_key on the relation-created event - # which happens before relation-joined. We can almost guarantee that - # the munge key will exist at this point, but check for it just in case. - munge_key = event_app_data.get("munge_key") - if not munge_key: - event.defer() - return - - # slurmctld sets the jwt_rsa on the relation-created event - # which happens before relation-joined. We can almost guarantee that - # the jwt_rsa will exist at this point, but check for it just in case. - jwt_rsa = event_app_data.get("jwt_rsa") - if not jwt_rsa: - event.defer() - return - - # Store the munge_key and jwt_rsa in the interface's stored state - # object and emit the slurmctld_available event. - self._store_munge_key(munge_key) - self._store_jwt_rsa(jwt_rsa) - self.on.slurmctld_available.emit() - - self._charm.cluster_name = event_app_data.get("cluster_name") - - def _on_relation_broken(self, event): - """Clear the application relation data and emit the event.""" - self.set_slurmdbd_info_on_app_relation_data("") - self.is_joined = False - self.on.slurmctld_unavailable.emit() - - def set_slurmdbd_info_on_app_relation_data(self, slurmdbd_info): - """Send slurmdbd_info to slurmctld.""" - logger.debug(f"## Setting info in app relation data: {slurmdbd_info}") - relations = self.framework.model.relations["slurmdbd"] - # Iterate over each of the relations setting the relation data. - for relation in relations: - if slurmdbd_info != "": - relation.data[self.model.app]["slurmdbd_info"] = json.dumps(slurmdbd_info) - else: - relation.data[self.model.app]["slurmdbd_info"] = "" - - def _store_munge_key(self, munge_key): - """Set the munge key in the stored state.""" - self._stored.munge_key = munge_key - - def get_munge_key(self): - """Retrieve the munge key from the stored state.""" - return self._stored.munge_key - - def _store_jwt_rsa(self, jwt_rsa): - """Store the jwt_rsa in the interface stored state.""" - self._stored.jwt_rsa = jwt_rsa - - def get_jwt_rsa(self): - """Retrieve the jwt_rsa from the stored state.""" - return self._stored.jwt_rsa diff --git a/src/interface_slurmdbd_peer.py b/src/interface_slurmdbd_peer.py deleted file mode 100644 index d4abd60..0000000 --- a/src/interface_slurmdbd_peer.py +++ /dev/null @@ -1,183 +0,0 @@ -#!/usr/bin/env python3 -"""SlurmdbdPeer.""" -import copy -import json -import logging -import subprocess - -from ops.framework import EventBase, EventSource, Object, ObjectEvents - -logger = logging.getLogger() - - -class SlurmdbdPeerAvailableEvent(EventBase): - """Emitted in the relation_changed event when a peer comes online.""" - - -class SlurmdbdPeerRelationEvents(ObjectEvents): - """Slurmdbd peer relation events.""" - - slurmdbd_peer_available = EventSource(SlurmdbdPeerAvailableEvent) - - -class SlurmdbdPeer(Object): - """SlurmdbdPeer Interface.""" - - on = SlurmdbdPeerRelationEvents() - - def __init__(self, charm, relation_name): - """Initialize and observe.""" - super().__init__(charm, relation_name) - self._charm = charm - self._relation_name = relation_name - - self.framework.observe( - self._charm.on[self._relation_name].relation_created, - self._on_relation_created, - ) - - self.framework.observe( - self._charm.on[self._relation_name].relation_changed, - self._on_relation_changed, - ) - - self.framework.observe( - self._charm.on[self._relation_name].relation_departed, - self._on_relation_departed, - ) - - def _on_relation_created(self, event): - """Set hostname and port on the unit data.""" - relation = self.framework.model.get_relation(self._relation_name) - unit_relation_data = relation.data[self.model.unit] - - unit_relation_data["hostname"] = self._charm.get_hostname() - unit_relation_data["port"] = self._charm.get_port() - - # Call _on_relation_changed to assemble the slurmdbd_info and - # emit the slurmdbd_peer_available event. - self._on_relation_changed(event) - - def _on_relation_changed(self, event): - """Use the leader and app relation data to schedule active/backup.""" - # We only modify the slurmdbd queue if we are the leader. - # As such, we don't need to perform any operations here - # if we are not the leader. - if self.framework.model.unit.is_leader(): - relation = self.framework.model.get_relation(self._relation_name) - - app_relation_data = relation.data[self.model.app] - unit_relation_data = relation.data[self.model.unit] - - slurmdbd_peers = _get_active_peers() - slurmdbd_peers_tmp = copy.deepcopy(slurmdbd_peers) - - active_slurmdbd = app_relation_data.get("active_slurmdbd") - backup_slurmdbd = app_relation_data.get("backup_slurmdbd") - - # Account for the active slurmdbd - # In this case, tightly couple the active slurmdbd to the leader. - # - # If we are the leader but are not the active slurmdbd, - # then the previous slurmdbd leader must have died. - # Set our unit to the active_slurmdbd. - if active_slurmdbd != self.model.unit.name: - app_relation_data["active_slurmdbd"] = self.model.unit.name - - # Account for the backup and standby slurmdbd - # - # If the backup slurmdbd exists in the application relation data - # then check that it also exists in the slurmdbd_peers. If it does - # exist in the slurmdbd peers then remove it from the list of - # active peers and set the rest of the peers to be standbys. - if backup_slurmdbd: - # Just because the backup_slurmdbd exists in the application - # data doesn't mean that it really exists. Check that the - # backup_slurmdbd that we have in the application data still - # exists in the list of active units. If the backup_slurmdbd - # isn't in the list of active units then check for - # slurmdbd_peers > 0 and try to promote a standby to a backup. - if backup_slurmdbd in slurmdbd_peers: - slurmdbd_peers_tmp.remove(backup_slurmdbd) - app_relation_data["standby_slurmdbd"] = json.dumps(slurmdbd_peers_tmp) - else: - if len(slurmdbd_peers) > 0: - app_relation_data["backup_slurmdbd"] = slurmdbd_peers_tmp.pop() - app_relation_data["standby_slurmdbd"] = json.dumps(slurmdbd_peers_tmp) - else: - app_relation_data["backup_slurmdbd"] = "" - app_relation_data["standby_slurmdbd"] = json.dumps([]) - else: - if len(slurmdbd_peers) > 0: - app_relation_data["backup_slurmdbd"] = slurmdbd_peers_tmp.pop() - app_relation_data["standby_slurmdbd"] = json.dumps(slurmdbd_peers_tmp) - else: - app_relation_data["standby_slurmdbd"] = json.dumps([]) - - ctxt = {} - backup_slurmdbd = app_relation_data.get("backup_slurmdbd") - - # NOTE: We only care about the active and backup slurdbd. - # Set the active slurmdbd info and check for and set the - # backup slurmdbd information if one exists. - ctxt["active_slurmdbd_ingress_address"] = unit_relation_data["ingress-address"] - ctxt["active_slurmdbd_hostname"] = self._charm.get_hostname() - ctxt["active_slurmdbd_port"] = str(self._charm.get_port()) - - # If we have > 0 slurmdbd (also have a backup), iterate over - # them retrieving the info for the backup and set it along with - # the info for the active slurmdbd, then emit the - # 'slurmdbd_peer_available' event. - if backup_slurmdbd: - for unit in relation.units: - if unit.name == backup_slurmdbd: - unit_data = relation.data[unit] - ctxt["backup_slurmdbd_ingress_address"] = unit_data["ingress-address"] - ctxt["backup_slurmdbd_hostname"] = unit_data["hostname"] - ctxt["backup_slurmdbd_port"] = unit_data["port"] - else: - ctxt["backup_slurmdbd_ingress_address"] = "" - ctxt["backup_slurmdbd_hostname"] = "" - ctxt["backup_slurmdbd_port"] = "" - - app_relation_data["slurmdbd_info"] = json.dumps(ctxt) - self.on.slurmdbd_peer_available.emit() - - def _on_relation_departed(self, event): - self.on.slurmdbd_peer_available.emit() - - @property - def _relation(self): - return self.framework.model.get_relation(self._relation_name) - - def get_slurmdbd_info(self) -> dict: - """Return slurmdbd info.""" - relation = self._relation - if relation: - app = relation.app - if app: - slurmdbd_info = relation.data[app].get("slurmdbd_info") - if slurmdbd_info: - return json.loads(slurmdbd_info) - return {} - - -def _related_units(relid): - """List of related units.""" - units_cmd_line = ["relation-list", "--format=json", "-r", relid] - return json.loads(subprocess.check_output(units_cmd_line).decode("UTF-8")) or [] - - -def _relation_ids(reltype): - """List of relation_ids.""" - relid_cmd_line = ["relation-ids", "--format=json", reltype] - return json.loads(subprocess.check_output(relid_cmd_line).decode("UTF-8")) or [] - - -def _get_active_peers(): - """Return the active_units.""" - active_units = [] - for rel_id in _relation_ids("slurmdbd-peer"): - for unit in _related_units(rel_id): - active_units.append(unit) - return active_units diff --git a/src/slurmdbd_ops.py b/src/slurmdbd_ops.py new file mode 100644 index 0000000..b7d07bf --- /dev/null +++ b/src/slurmdbd_ops.py @@ -0,0 +1,358 @@ +# Copyright 2020-2024 Omnivector, LLC. +"""SlurmdbdOpsManager.""" +import logging +import os +import socket +import subprocess +import textwrap +from base64 import b64decode +from datetime import datetime +from grp import getgrnam +from pathlib import Path +from pwd import getpwnam +from typing import Optional + +import charms.operator_libs_linux.v0.apt as apt +import charms.operator_libs_linux.v1.systemd as systemd +import distro +from constants import SLURMDBD_DEFAULTS_FILE, UBUNTU_HPC_PPA_KEY + +logger = logging.getLogger() + + +class SlurmdbdManagerError(BaseException): + """Exception for use with SlurmdbdManager.""" + + def __init__(self, message): + super().__init__(message) + self.message = message + + +class CharmedHPCPackageLifecycleManager: + """Facilitate ubuntu-hpc slurm component package lifecycles.""" + + def __init__(self, package_name: str): + self._package_name = package_name + self._keyring_path = Path(f"/usr/share/keyrings/ubuntu-hpc-{self._package_name}.asc") + + def _repo(self) -> apt.DebianRepository: + """Return the ubuntu-hpc repo.""" + ppa_url: str = "https://ppa.launchpadcontent.net/ubuntu-hpc/slurm-wlm-23.02/ubuntu" + sources_list: str = ( + f"deb [signed-by={self._keyring_path}] {ppa_url} {distro.codename()} main" + ) + return apt.DebianRepository.from_repo_line(sources_list) + + def install(self) -> bool: + """Install package using lib apt.""" + package_installed = False + + if self._keyring_path.exists(): + self._keyring_path.unlink() + self._keyring_path.write_text(UBUNTU_HPC_PPA_KEY) + + repositories = apt.RepositoryMapping() + repositories.add(self._repo()) + + try: + apt.update() + apt.add_package([self._package_name]) + package_installed = True + except apt.PackageNotFoundError: + logger.error(f"'{self._package_name}' not found in package cache or on system.") + except apt.PackageError as e: + logger.error(f"Could not install '{self._package_name}'. Reason: {e.message}") + + return package_installed + + def uninstall(self) -> None: + """Uninstall the package using libapt.""" + if apt.remove_package(self._package_name): + logger.info(f"'{self._package_name}' removed from system.") + else: + logger.error(f"'{self._package_name}' not found on system.") + + repositories = apt.RepositoryMapping() + repositories.disable(self._repo()) + + if self._keyring_path.exists(): + self._keyring_path.unlink() + + def upgrade_to_latest(self) -> None: + """Upgrade package to latest.""" + try: + slurm_package = apt.DebianPackage.from_system(self._package_name) + slurm_package.ensure(apt.PackageState.Latest) + logger.info(f"Updated '{self._package_name}' to: {slurm_package.version.number}.") + except apt.PackageNotFoundError: + logger.error(f"'{self._package_name}' not found in package cache or on system.") + except apt.PackageError as e: + logger.error(f"Could not install '{self._package_name}'. Reason: {e.message}") + + def version(self) -> str: + """Return the package version.""" + slurmdbd_vers = "" + try: + slurmdbd_vers = apt.DebianPackage.from_installed_package( + self._package_name + ).version.number + except apt.PackageNotFoundError: + logger.error(f"'{self._package_name}' not found on system.") + return slurmdbd_vers + + +class SlurmdbdOpsManager: + """SlurmdbdOpsManager.""" + + def __init__(self): + """Set the initial attribute values.""" + self._slurm_component = "slurmdbd" + + self._slurm_state_dir = Path("/var/spool/slurmdbd") + self._slurmdbd_conf_dir = Path("/etc/slurm") + self._slurmdbd_log_dir = Path("/var/log/slurm") + self._slurm_user = "slurm" + self._slurm_group = "slurm" + + self._munge_package = CharmedHPCPackageLifecycleManager("munge") + self._slurmdbd_package = CharmedHPCPackageLifecycleManager("slurmdbd") + + def write_slurmdbd_conf(self, slurmdbd_parameters: dict) -> None: + """Render slurmdbd.conf.""" + slurmdbd_conf = self._slurmdbd_conf_dir / "slurmdbd.conf" + slurm_user_uid = getpwnam(self._slurm_user).pw_uid + slurm_group_gid = getgrnam(self._slurm_group).gr_gid + + header = textwrap.dedent( + f""" + # + # {slurmdbd_conf} generated at {datetime.now()} + # + + """ + ) + slurmdbd_conf.write_text( + header + "\n".join([f"{k}={v}" for k, v in slurmdbd_parameters.items() if v != ""]) + ) + + slurmdbd_conf.chmod(0o600) + os.chown(f"{slurmdbd_conf}", slurm_user_uid, slurm_group_gid) + + def write_munge_key(self, munge_key_data: str) -> None: + """Base64 decode and write the munge key.""" + munge_key_path = Path("/etc/munge/munge.key") + munge_key_path.write_bytes(b64decode(munge_key_data.encode())) + + def write_jwt_rsa(self, jwt_rsa: str) -> None: + """Write the jwt_rsa key and set permissions.""" + jwt_rsa_path = self._slurm_state_dir / "jwt_hs256.key" + slurm_user_uid = getpwnam(self._slurm_user).pw_uid + slurm_group_gid = getgrnam(self._slurm_group).gr_gid + + # Write the jwt_rsa key to the file and chmod 0600 + chown to slurm_user. + jwt_rsa_path.write_text(jwt_rsa) + jwt_rsa_path.chmod(0o600) + + os.chown(f"{jwt_rsa_path}", slurm_user_uid, slurm_group_gid) + + def install(self) -> bool: + """Install slurmdbd and munge to the system and setup paths.""" + slurm_user_uid = getpwnam(self._slurm_user).pw_uid + slurm_group_gid = getgrnam(self._slurm_group).gr_gid + + if self._slurmdbd_package.install() is not True: + return False + systemd.service_stop("slurmdbd") + + if self._munge_package.install() is not True: + return False + systemd.service_stop("munge") + + # Create needed paths with correct permissions. + for syspath in [self._slurmdbd_conf_dir, self._slurmdbd_log_dir, self._slurm_state_dir]: + if not syspath.exists(): + syspath.mkdir() + os.chown(f"{syspath}", slurm_user_uid, slurm_group_gid) + return True + + def stop_slurmdbd(self) -> None: + """Stop slurmdbd.""" + systemd.service_stop("slurmdbd") + + def is_slurmdbd_active(self) -> bool: + """Get if slurmdbd daemon is active or not.""" + return systemd.service_running("slurmdbd") + + def stop_munge(self) -> None: + """Stop munge.""" + systemd.service_stop("munge") + + def start_munge(self) -> bool: + """Start the munged process. + + Return True on success, and False otherwise. + """ + logger.debug("Starting munge.") + try: + systemd.service_start("munge") + # Ignore pyright error for is not a valid exception class, reportGeneralTypeIssues + except SlurmdbdManagerError( + "Cannot start munge." + ) as e: # pyright: ignore [reportGeneralTypeIssues] + logger.error(e) + return False + return self.check_munged() + + def check_munged(self) -> bool: + """Check if munge is working correctly.""" + if not systemd.service_running("munge"): + return False + + output = "" + # check if munge is working, i.e., can use the credentials correctly + try: + logger.debug("## Testing if munge is working correctly") + munge = subprocess.Popen( + ["munge", "-n"], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + if munge is not None: + unmunge = subprocess.Popen( + ["unmunge"], stdin=munge.stdout, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + output = unmunge.communicate()[0].decode() + if "Success" in output: + logger.debug(f"## Munge working as expected: {output}") + return True + logger.error(f"## Munge not working: {output}") + except subprocess.CalledProcessError as e: + logger.error(f"## Error testing munge: {e}") + + return False + + def restart_slurmdbd(self) -> bool: + """Restart the slurmdbd process. + + Return True on success, and False otherwise. + """ + logger.debug("Attempting to restart slurmdbd.") + try: + systemd.service_restart("slurmdbd") + # Ignore pyright error for is not a valid exception class, reportGeneralTypeIssues + except SlurmdbdManagerError( + "Cannot restart slurmdbd." + ) as e: # pyright: ignore [reportGeneralTypeIssues] + logger.error(e) + return False + return True + + @property + def fluentbit_config_slurm(self) -> list: + """Return Fluentbit configuration parameters to forward Slurm logs.""" + log_file = self._slurmdbd_log_dir / "slurmdbd.log" + + cfg = [ + { + "input": [ + ("name", "tail"), + ("path", log_file.as_posix()), + ("path_key", "filename"), + ("tag", self._slurm_component), + ("parser", "slurm"), + ] + }, + { + "parser": [ + ("name", "slurm"), + ("format", "regex"), + ("regex", r"^\[(?