Skip to content

Commit

Permalink
[DPE-5028] restart continuous writes (#236)
Browse files Browse the repository at this point in the history
* Bump libs

* Try to use stop event

* Restart continuous writes on start hook

* Integration test

* Wrong on start check

* Defer restart if db connection is not available yet

* Typo
  • Loading branch information
dragomirp authored Oct 18, 2024
1 parent ad568e6 commit cd2734f
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 24 deletions.
21 changes: 20 additions & 1 deletion lib/charms/data_platform_libs/v0/data_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def _on_topic_requested(self, event: TopicRequestedEvent):

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 39
LIBPATCH = 40

PYDEPS = ["ops>=2.0.0"]

Expand Down Expand Up @@ -391,6 +391,10 @@ class IllegalOperationError(DataInterfacesError):
"""To be used when an operation is not allowed to be performed."""


class PrematureDataAccessError(DataInterfacesError):
    """To be raised when relation data is accessed (written) before the protocol init is complete.

    In this file it is raised by `_update_relation_data` when the relation's resource
    field has not been set yet and the update contains fields other than the
    endpoint-related ones.
    """


##############################################################################
# Global helpers / utilities
##############################################################################
Expand Down Expand Up @@ -1453,6 +1457,8 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None:
class ProviderData(Data):
"""Base provides-side of the data products relation."""

RESOURCE_FIELD = "database"

def __init__(
self,
model: Model,
Expand Down Expand Up @@ -1618,6 +1624,15 @@ def _fetch_my_specific_relation_data(
def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> None:
"""Set values for fields not caring whether it's a secret or not."""
req_secret_fields = []

keys = set(data.keys())
if self.fetch_relation_field(relation.id, self.RESOURCE_FIELD) is None and (
keys - {"endpoints", "read-only-endpoints", "replset"}
):
raise PrematureDataAccessError(
"Premature access to relation data, update is forbidden before the connection is initialized."
)

if relation.app:
req_secret_fields = get_encoded_list(relation, relation.app, REQ_SECRET_FIELDS)

Expand Down Expand Up @@ -3290,6 +3305,8 @@ class KafkaRequiresEvents(CharmEvents):
class KafkaProviderData(ProviderData):
"""Provider-side of the Kafka relation."""

RESOURCE_FIELD = "topic"

def __init__(self, model: Model, relation_name: str) -> None:
super().__init__(model, relation_name)

Expand Down Expand Up @@ -3539,6 +3556,8 @@ class OpenSearchRequiresEvents(CharmEvents):
class OpenSearchProvidesData(ProviderData):
"""Provider-side of the OpenSearch relation."""

RESOURCE_FIELD = "index"

def __init__(self, model: Model, relation_name: str) -> None:
super().__init__(model, relation_name)

Expand Down
122 changes: 121 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 2 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,7 @@
# See LICENSE file for licensing details.

[tool.poetry]
name = "postgresql-test-app"
version = "0.0.1-dev.0"
description = ""
authors = []
license = "Apache-2.0"
readme = "README.md"
homepage = "https://charmhub.io/postgresql-test-app"
repository = "https://github.com/canonical/postgresql-test-app"
package-mode = false

[tool.poetry.dependencies]
python = "^3.10"
Expand All @@ -34,12 +27,12 @@ optional = true
[tool.poetry.group.lint.dependencies]
codespell = "2.3.0"


[tool.poetry.group.integration]
optional = true

[tool.poetry.group.integration.dependencies]
pytest = "^8.3.3"
lightkube = "^0.15.4"
pytest-github-secrets = {git = "https://github.com/canonical/data-platform-workflows", tag = "v23.0.4", subdirectory = "python/pytest_plugins/github_secrets"}
pytest-operator = "^0.38.0"
pytest-operator-cache = {git = "https://github.com/canonical/data-platform-workflows", tag = "v23.0.4", subdirectory = "python/pytest_plugins/pytest_operator_cache"}
Expand Down
48 changes: 35 additions & 13 deletions src/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
DatabaseEndpointsChangedEvent,
DatabaseRequires,
)
from ops.charm import ActionEvent, CharmBase
from ops.main import main
from ops.model import ActiveStatus, Relation
from ops import ActionEvent, ActiveStatus, CharmBase, Relation, StartEvent, main
from tenacity import RetryError, Retrying, stop_after_delay, wait_fixed

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -64,10 +62,8 @@ def __init__(self, *args):

# Events related to the first database that is requested
# (these events are defined in the database requires charm library).
self.first_database_name = f'{self.app.name.replace("-", "_")}_database'
self.database = DatabaseRequires(
self, "database", self.first_database_name, EXTRA_USER_ROLES
)
self.database_name = f'{self.app.name.replace("-", "_")}_database'
self.database = DatabaseRequires(self, "database", self.database_name, EXTRA_USER_ROLES)
self.framework.observe(self.database.on.database_created, self._on_database_created)
self.framework.observe(
self.database.on.endpoints_changed, self._on_database_endpoints_changed
Expand Down Expand Up @@ -163,9 +159,31 @@ def __init__(self, *args):
self.framework.observe(self.on.run_sql_action, self._on_run_sql_action)
self.framework.observe(self.on.test_tls_action, self._on_test_tls_action)

def _on_start(self, _) -> None:
def are_writes_running(self) -> bool:
    """Return whether the continuous writes script is running.

    The PID of the script is read from the application peer data and probed with
    signal 0, which only performs the existence and permission checks and does not
    deliver any signal to the process.

    Returns:
        True if a process with the stored PID exists; False if the PID is missing,
        malformed, or does not refer to a live process.
    """
    try:
        pid = int(self.app_peer_data[PROC_PID_KEY])
    except (KeyError, TypeError, ValueError):
        # No PID stored, or the stored value is not a number.
        return False

    if pid <= 0:
        # Non-positive PIDs address process groups in kill(2), never a single process.
        return False

    try:
        # os.kill() requires the signal argument; without it the call always raises
        # TypeError and the process would always be reported as not running.
        os.kill(pid, 0)
    except PermissionError:
        # The process exists but is owned by another user.
        return True
    except (OSError, OverflowError):
        return False
    return True

def _on_start(self, event: StartEvent) -> None:
    """Set an Active status and, on the leader, restart interrupted continuous writes.

    Continuous writes are restarted only when a writer PID is recorded in the
    application peer data but that process is no longer running. The event is
    deferred when the current number of writes cannot be read from the database yet.
    """
    self.unit.status = ActiveStatus()

    # Guard clauses: only the leader resumes writes, and only if they were
    # started before and the writer process is gone.
    if not self.model.unit.is_leader():
        return
    if PROC_PID_KEY not in self.app_peer_data:
        return
    if self.are_writes_running():
        return

    try:
        current_writes = self._get_db_writes()
    except Exception:
        logger.debug("Connection to db not yet available")
        event.defer()
        return

    if current_writes <= 0:
        # Nothing was written (or the count could not be determined): nothing to resume.
        return

    logger.info("Restarting continuous writes from db")
    # Resume with the current count plus one.
    self._start_continuous_writes(current_writes + 1)

# First database events observers.
def _on_database_created(self, event: DatabaseCreatedEvent) -> None:
Expand Down Expand Up @@ -256,7 +274,7 @@ def _connection_string(self) -> Optional[str]:
if db_data:
data = db_data[0]
else:
data = list(self.first_database.fetch_relation_data().values())[0]
data = list(self.database.fetch_relation_data().values())[0]
username = data.get("username")
password = data.get("password")
endpoints = data.get("endpoints")
Expand Down Expand Up @@ -343,20 +361,24 @@ def _on_start_continuous_writes_action(self, event: ActionEvent) -> None:
self._start_continuous_writes(1)
event.set_results({"result": "True"})

def _on_show_continuous_writes_action(self, event: ActionEvent) -> None:
"""Count the continuous writes."""
def _get_db_writes(self) -> int:
    """Count the rows in the continuous_writes table.

    Returns:
        The number of rows in the table, or -1 if the count query fails.

    Raises:
        Exception: whatever building the connection string or connecting to the
            database raises. Callers use this to detect that the database is not
            reachable yet.
    """
    # Connect outside the try block: previously a failed connection left
    # `connection` unbound, so the `finally` clause raised an UnboundLocalError
    # that masked the real connection error.
    connection = psycopg2.connect(self._connection_string)
    try:
        with connection, connection.cursor() as cursor:
            connection.autocommit = True
            cursor.execute("SELECT COUNT(*) FROM continuous_writes;")
            return cursor.fetchone()[0]
    except Exception:
        logger.exception("Unable to count writes")
        return -1
    finally:
        # The connection context manager only ends the transaction; close explicitly.
        connection.close()

def _on_show_continuous_writes_action(self, event: ActionEvent) -> None:
    """Report the number of continuous writes as the action result."""
    writes_count = self._get_db_writes()
    event.set_results({"writes": writes_count})

def _on_stop_continuous_writes_action(self, event: ActionEvent) -> None:
"""Stops the continuous writes process."""
Expand Down
57 changes: 57 additions & 0 deletions tests/integration/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
# Copyright 2024 Canonical Ltd.
# See LICENSE file for licensing details.

import logging
import subprocess

from pytest_operator.plugin import OpsTest

logger = logging.getLogger(__name__)


async def run_command_on_unit(ops_test: OpsTest, unit_name: str, command: str) -> str:
    """Run a command on a specific unit.

    Args:
        ops_test: The ops test framework instance
        unit_name: The name of the unit to run the command on
        command: The command to run; it is split on whitespace into separate arguments

    Returns:
        The command output (stdout) if it succeeds.

    Raises:
        Exception: if the command exits with a non-zero return code.
    """
    return_code, stdout, stderr = await ops_test.juju(
        "exec", "--unit", unit_name, "--", *command.split()
    )
    if return_code != 0:
        # The failure reason is normally reported on stderr, so log both streams.
        logger.error("stdout: %s", stdout)
        logger.error("stderr: %s", stderr)
        raise Exception(
            f"Expected command '{command}' to succeed instead it failed: {return_code}"
        )
    return stdout


async def get_machine_from_unit(ops_test: OpsTest, unit_name: str) -> str:
    """Get the name of the machine from a specific unit.

    The name is obtained by running `hostname` on the unit.

    Args:
        ops_test: The ops test framework instance
        unit_name: The name of the unit to get the machine

    Returns:
        The name of the machine, with surrounding whitespace stripped.
    """
    hostname_output = await run_command_on_unit(ops_test, unit_name, "hostname")
    machine_name = hostname_output.strip()
    return machine_name


async def restart_machine(ops_test: OpsTest, unit_name: str) -> None:
    """Restart the machine a unit runs on.

    The machine is restarted with `lxc restart`, run as a local subprocess.

    Args:
        ops_test: The ops test framework instance
        unit_name: The name of the unit to restart the machine
    """
    machine_name = await get_machine_from_unit(ops_test, unit_name)
    lxc_args = f"lxc restart {machine_name}".split()
    subprocess.check_call(lxc_args)
Loading

0 comments on commit cd2734f

Please sign in to comment.