diff --git a/lib/charms/data_platform_libs/v0/data_interfaces.py b/lib/charms/data_platform_libs/v0/data_interfaces.py index aaed2e5..3bc2dd8 100644 --- a/lib/charms/data_platform_libs/v0/data_interfaces.py +++ b/lib/charms/data_platform_libs/v0/data_interfaces.py @@ -331,7 +331,7 @@ def _on_topic_requested(self, event: TopicRequestedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 39 +LIBPATCH = 40 PYDEPS = ["ops>=2.0.0"] @@ -391,6 +391,10 @@ class IllegalOperationError(DataInterfacesError): """To be used when an operation is not allowed to be performed.""" +class PrematureDataAccessError(DataInterfacesError): + """To be raised when the Relation Data may be accessed (written) before protocol init complete.""" + + ############################################################################## # Global helpers / utilities ############################################################################## @@ -1453,6 +1457,8 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: class ProviderData(Data): """Base provides-side of the data products relation.""" + RESOURCE_FIELD = "database" + def __init__( self, model: Model, @@ -1618,6 +1624,15 @@ def _fetch_my_specific_relation_data( def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> None: """Set values for fields not caring whether it's a secret or not.""" req_secret_fields = [] + + keys = set(data.keys()) + if self.fetch_relation_field(relation.id, self.RESOURCE_FIELD) is None and ( + keys - {"endpoints", "read-only-endpoints", "replset"} + ): + raise PrematureDataAccessError( + "Premature access to relation data, update is forbidden before the connection is initialized." 
+ ) + if relation.app: req_secret_fields = get_encoded_list(relation, relation.app, REQ_SECRET_FIELDS) @@ -3290,6 +3305,8 @@ class KafkaRequiresEvents(CharmEvents): class KafkaProviderData(ProviderData): """Provider-side of the Kafka relation.""" + RESOURCE_FIELD = "topic" + def __init__(self, model: Model, relation_name: str) -> None: super().__init__(model, relation_name) @@ -3539,6 +3556,8 @@ class OpenSearchRequiresEvents(CharmEvents): class OpenSearchProvidesData(ProviderData): """Provider-side of the OpenSearch relation.""" + RESOURCE_FIELD = "index" + def __init__(self, model: Model, relation_name: str) -> None: super().__init__(model, relation_name) diff --git a/poetry.lock b/poetry.lock index 457db57..d40a600 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,27 @@ # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +[[package]] +name = "anyio" +version = "4.6.2.post1" +description = "High level compatibility layer for multiple asynchronous event loop implementations" +optional = false +python-versions = ">=3.9" +files = [ + {file = "anyio-4.6.2.post1-py3-none-any.whl", hash = "sha256:6d170c36fba3bdd840c73d3868c1e777e33676a69c3a72cf0a0d5d6d8009b61d"}, + {file = "anyio-4.6.2.post1.tar.gz", hash = "sha256:4c8bc31ccdb51c7f7bd251f51c609e038d63e34219b44aa86e47576389880b4c"}, +] + +[package.dependencies] +exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} +idna = ">=2.8" +sniffio = ">=1.1" +typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} + +[package.extras] +doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21.0b1)"] +trio = ["trio (>=0.26.1)"] + [[package]] name = "asttokens" version = "2.4.1" @@ -423,6 
+445,63 @@ pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] reauth = ["pyu2f (>=0.1.5)"] requests = ["requests (>=2.20.0,<3.0.0.dev0)"] +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false +python-versions = ">=3.7" +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + +[[package]] +name = "httpcore" +version = "1.0.6" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpcore-1.0.6-py3-none-any.whl", hash = "sha256:27b59625743b85577a8c0e10e55b50b5368a4f2cfe8cc7bcfa9cf00829c2682f"}, + {file = "httpcore-1.0.6.tar.gz", hash = "sha256:73f6dbd6eb8c21bbf7ef8efad555481853f5f6acdeaff1edb0694289269ee17f"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.13,<0.15" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<1.0)"] + +[[package]] +name = "httpx" +version = "0.27.2" +description = "The next generation HTTP client." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"}, + {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +zstd = ["zstandard (>=0.18.0)"] + [[package]] name = "hvac" version = "2.3.0" @@ -605,6 +684,36 @@ websocket-client = ">=0.32.0,<0.40.0 || >0.40.0,<0.41.dev0 || >=0.43.dev0" [package.extras] adal = ["adal (>=1.0.2)"] +[[package]] +name = "lightkube" +version = "0.15.4" +description = "Lightweight kubernetes client library" +optional = false +python-versions = "*" +files = [ + {file = "lightkube-0.15.4-py3-none-any.whl", hash = "sha256:7dde49694f2933b757ee6dfee0e028a56d7a13f47476bf54a21ec7cee343b09b"}, + {file = "lightkube-0.15.4.tar.gz", hash = "sha256:fe7939f8da5b68d80809243c9abf601fca2e5425228bb02c2ad23ed3e56401a8"}, +] + +[package.dependencies] +httpx = ">=0.24.0" +lightkube-models = ">=1.15.12.0" +PyYAML = "*" + +[package.extras] +dev = ["pytest", "pytest-asyncio (<0.17.0)", "respx"] + +[[package]] +name = "lightkube-models" +version = "1.31.1.8" +description = "Models and Resources for lightkube module" +optional = false +python-versions = "*" +files = [ + {file = "lightkube-models-1.31.1.8.tar.gz", hash = "sha256:14fbfa990b4d3393fa4ac3e9e46d67514c4d659508e296b30f1a5d254eecc097"}, + {file = "lightkube_models-1.31.1.8-py3-none-any.whl", hash = "sha256:50c0e2dd2c125cd9b50e93269e2d212bcbec19f7b00de91aa66a5ec320772fae"}, +] + [[package]] name = "macaroonbakery" version = "1.3.4" @@ -1330,6 +1439,17 @@ files = [ {file = "six-1.16.0.tar.gz", hash = 
"sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +[[package]] +name = "sniffio" +version = "1.3.1" +description = "Sniff out which async library your code is running under" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, + {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, +] + [[package]] name = "stack-data" version = "0.6.3" @@ -1569,4 +1689,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "d350e24f1530101e8ce086ce317f86b141e32a65bf5c6cf38026c47cd1e6b294" +content-hash = "9afe6218f613e098a779acfe673103953ea929f58307a05266e902b0964641ac" diff --git a/pyproject.toml b/pyproject.toml index 478374b..1c68d60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,14 +2,7 @@ # See LICENSE file for licensing details. [tool.poetry] -name = "postgresql-test-app" -version = "0.0.1-dev.0" -description = "" -authors = [] -license = "Apache-2.0" -readme = "README.md" -homepage = "https://charmhub.io/postgresql-test-app" -repository = "https://github.com/canonical/postgresql-test-app" +package-mode = false [tool.poetry.dependencies] python = "^3.10" @@ -34,12 +27,12 @@ optional = true [tool.poetry.group.lint.dependencies] codespell = "2.3.0" - [tool.poetry.group.integration] optional = true [tool.poetry.group.integration.dependencies] pytest = "^8.3.3" +lightkube = "^0.15.4" pytest-github-secrets = {git = "https://github.com/canonical/data-platform-workflows", tag = "v23.0.4", subdirectory = "python/pytest_plugins/github_secrets"} pytest-operator = "^0.38.0" pytest-operator-cache = {git = "https://github.com/canonical/data-platform-workflows", tag = "v23.0.4", subdirectory = "python/pytest_plugins/pytest_operator_cache"} diff --git a/src/charm.py b/src/charm.py index aaaa699..209bf39 100755 --- a/src/charm.py +++ 
b/src/charm.py @@ -22,9 +22,7 @@ DatabaseEndpointsChangedEvent, DatabaseRequires, ) -from ops.charm import ActionEvent, CharmBase -from ops.main import main -from ops.model import ActiveStatus, Relation +from ops import ActionEvent, ActiveStatus, CharmBase, Relation, StartEvent, main from tenacity import RetryError, Retrying, stop_after_delay, wait_fixed logger = logging.getLogger(__name__) @@ -64,10 +62,8 @@ def __init__(self, *args): # Events related to the first database that is requested # (these events are defined in the database requires charm library). - self.first_database_name = f'{self.app.name.replace("-", "_")}_database' - self.database = DatabaseRequires( - self, "database", self.first_database_name, EXTRA_USER_ROLES - ) + self.database_name = f'{self.app.name.replace("-", "_")}_database' + self.database = DatabaseRequires(self, "database", self.database_name, EXTRA_USER_ROLES) self.framework.observe(self.database.on.database_created, self._on_database_created) self.framework.observe( self.database.on.endpoints_changed, self._on_database_endpoints_changed @@ -163,9 +159,31 @@ def __init__(self, *args): self.framework.observe(self.on.run_sql_action, self._on_run_sql_action) self.framework.observe(self.on.test_tls_action, self._on_test_tls_action) - def _on_start(self, _) -> None: + def are_writes_running(self) -> bool: + """Returns whether continuous writes script is running.""" + try: + os.kill(int(self.app_peer_data[PROC_PID_KEY]), 0) + return True + except Exception: + return False + + def _on_start(self, event: StartEvent) -> None: """Only sets an Active status.""" self.unit.status = ActiveStatus() + if ( + self.model.unit.is_leader() + and PROC_PID_KEY in self.app_peer_data + and not self.are_writes_running() + ): + try: + writes = self._get_db_writes() + except Exception: + logger.debug("Connection to db not yet available") + event.defer() + return + if writes > 0: + logger.info("Restarting continuous writes from db") + 
self._start_continuous_writes(writes + 1) # First database events observers. def _on_database_created(self, event: DatabaseCreatedEvent) -> None: @@ -256,7 +274,7 @@ def _connection_string(self) -> Optional[str]: if db_data: data = db_data[0] else: - data = list(self.first_database.fetch_relation_data().values())[0] + data = list(self.database.fetch_relation_data().values())[0] username = data.get("username") password = data.get("password") endpoints = data.get("endpoints") @@ -343,20 +361,24 @@ def _on_start_continuous_writes_action(self, event: ActionEvent) -> None: self._start_continuous_writes(1) event.set_results({"result": "True"}) - def _on_show_continuous_writes_action(self, event: ActionEvent) -> None: - """Count the continuous writes.""" + def _get_db_writes(self) -> int: try: with psycopg2.connect( self._connection_string ) as connection, connection.cursor() as cursor: connection.autocommit = True cursor.execute("SELECT COUNT(*) FROM continuous_writes;") - event.set_results({"writes": cursor.fetchone()[0]}) + writes = cursor.fetchone()[0] except Exception: - event.set_results({"writes": -1}) + writes = -1 logger.exception("Unable to count writes") finally: connection.close() + return writes + + def _on_show_continuous_writes_action(self, event: ActionEvent) -> None: + """Count the continuous writes.""" + event.set_results({"writes": self._get_db_writes()}) def _on_stop_continuous_writes_action(self, event: ActionEvent) -> None: """Stops the continuous writes process.""" diff --git a/tests/integration/helpers.py b/tests/integration/helpers.py new file mode 100644 index 0000000..3ce2ee0 --- /dev/null +++ b/tests/integration/helpers.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. 
+ +import logging +import subprocess + +from pytest_operator.plugin import OpsTest + +logger = logging.getLogger(__name__) + + +async def run_command_on_unit(ops_test: OpsTest, unit_name: str, command: str) -> str: + """Run a command on a specific unit. + + Args: + ops_test: The ops test framework instance + unit_name: The name of the unit to run the command on + command: The command to run + + Returns: + the command output if it succeeds, otherwise raises an exception. + """ + complete_command = ["exec", "--unit", unit_name, "--", *command.split()] + return_code, stdout, _ = await ops_test.juju(*complete_command) + if return_code != 0: + logger.error(stdout) + raise Exception( + f"Expected command '{command}' to succeed instead it failed: {return_code}" + ) + return stdout + + +async def get_machine_from_unit(ops_test: OpsTest, unit_name: str) -> str: + """Get the name of the machine from a specific unit. + + Args: + ops_test: The ops test framework instance + unit_name: The name of the unit to get the machine + + Returns: + The name of the machine. + """ + raw_hostname = await run_command_on_unit(ops_test, unit_name, "hostname") + return raw_hostname.strip() + + +async def restart_machine(ops_test: OpsTest, unit_name: str) -> None: + """Restart the machine where a unit run on. 
+ + Args: + ops_test: The ops test framework instance + unit_name: The name of the unit to restart the machine + """ + raw_hostname = await get_machine_from_unit(ops_test, unit_name) + restart_machine_command = f"lxc restart {raw_hostname}" + subprocess.check_call(restart_machine_command.split()) diff --git a/tests/integration/test_smoke.py b/tests/integration/test_smoke.py index b6bbeb1..68f0e19 100644 --- a/tests/integration/test_smoke.py +++ b/tests/integration/test_smoke.py @@ -8,8 +8,12 @@ import pytest from juju.relation import Relation +from lightkube.core.client import Client +from lightkube.resources.core_v1 import Pod from pytest_operator.plugin import OpsTest +from .helpers import restart_machine logger = logging.getLogger(__name__) TEST_APP_NAME = "postgresql-test-app" @@ -107,3 +111,62 @@ async def test_smoke(ops_test: OpsTest) -> None: maximum = int(maximum) assert writes == count == maximum + + await ( + await ops_test.model.applications[TEST_APP_NAME] + .units[0] + .run_action("clear-continuous-writes") + ).wait() + + +@pytest.mark.group(1) +async def test_restart(ops_test: OpsTest) -> None: + """Verify that continuous writes resume after the unit's pod or machine is restarted.""" + is_k8s = ops_test.model.info.provider_type == "kubernetes" + + logger.info("Start continuous writes") + await ( + await ops_test.model.applications[TEST_APP_NAME] + .units[0] + .run_action("start-continuous-writes") + ).wait() + + time.sleep(10) + + results = await ( + await ops_test.model.applications[TEST_APP_NAME] + .units[0] + .run_action("show-continuous-writes") + ).wait() + early_writes = int(results.results["writes"]) + + if is_k8s: + logger.info("Deleting the pod") + client = Client(namespace=ops_test.model.info.name) + client.delete(Pod, name=f"{TEST_APP_NAME}-0") + else: + logger.info("Restarting lxc") + await restart_machine(ops_test, ops_test.model.applications[TEST_APP_NAME].units[0].name) + + logger.info("Wait for idle") + await 
ops_test.model.wait_for_idle(apps=[TEST_APP_NAME], status="active", timeout=600) + + logger.info("Check that writes are increasing") + results = await ( + await ops_test.model.applications[TEST_APP_NAME] + .units[0] + .run_action("show-continuous-writes") + ).wait() + show_writes = int(results.results["writes"]) + + time.sleep(10) + + results = await ( + await ops_test.model.applications[TEST_APP_NAME] + .units[0] + .run_action("stop-continuous-writes") + ).wait() + + writes = int(results.results["writes"]) + assert writes > 0 + assert writes > show_writes > early_writes