diff --git a/.github/workflows/integration_test.yaml b/.github/workflows/integration_test.yaml index eaf859f..727e3f8 100644 --- a/.github/workflows/integration_test.yaml +++ b/.github/workflows/integration_test.yaml @@ -14,7 +14,7 @@ jobs: extra-arguments: -x --localstack-address 172.17.0.1 pre-run-script: localstack-installation.sh charmcraft-channel: latest/edge - modules: '["test_charm.py", "test_cos.py", "test_database.py", "test_db_migration.py", "test_django.py", "test_django_integrations.py", "test_fastapi.py", "test_go.py", "test_integrations.py", "test_proxy.py"]' + modules: '["test_charm.py", "test_cos.py", "test_database.py", "test_db_migration.py", "test_django.py", "test_django_integrations.py", "test_fastapi.py", "test_go.py", "test_integrations.py", "test_proxy.py", "test_workers.py"]' rockcraft-channel: latest/edge juju-channel: ${{ matrix.juju-version }} channel: 1.29-strict/stable diff --git a/examples/flask/test_rock/app.py b/examples/flask/test_rock/app.py index 2f21b52..db803da 100644 --- a/examples/flask/test_rock/app.py +++ b/examples/flask/test_rock/app.py @@ -1,7 +1,9 @@ # Copyright 2024 Canonical Ltd. # See LICENSE file for licensing details. +import logging import os +import socket import time import urllib.parse from urllib.parse import urlparse @@ -16,11 +18,86 @@ import pymysql import pymysql.cursors import redis +from celery import Celery, Task from flask import Flask, g, jsonify, request + +def hostname(): + """Get the hostname of the current machine.""" + return socket.gethostbyname(socket.gethostname()) + + +def celery_init_app(app: Flask, broker_url: str) -> Celery: + """Initialise celery using the redis connection string. + + See https://flask.palletsprojects.com/en/3.0.x/patterns/celery/#integrate-celery-with-flask. + """ + + class FlaskTask(Task): + def __call__(self, *args: object, **kwargs: object) -> object: + with app.app_context(): + return self.run(*args, **kwargs) + + celery_app = Celery(app.name, task_cls=FlaskTask) + celery_app.set_default() + app.extensions["celery"] = celery_app + app.config.from_mapping( + CELERY=dict( + broker_url=broker_url, + result_backend=broker_url, + task_ignore_result=True, + ), + ) + celery_app.config_from_object(app.config["CELERY"]) + return celery_app + + app = Flask(__name__) app.config.from_prefixed_env() +broker_url = os.environ.get("REDIS_DB_CONNECT_STRING") +# Configure Celery only if Redis is configured +if broker_url: + celery_app = celery_init_app(app, broker_url) + redis_client = redis.Redis.from_url(broker_url) + + @celery_app.on_after_configure.connect + def setup_periodic_tasks(sender, **kwargs): + """Set up periodic tasks in the scheduler.""" + try: + # This will only have an effect in the beat scheduler. + sender.add_periodic_task(0.5, scheduled_task.s(hostname()), name="every 0.5s") + except NameError as e: + logging.exception("Failed to configure the periodic task") + + @celery_app.task + def scheduled_task(scheduler_hostname): + """Function to run a schedule task in a worker. + + The worker that will run this task will add the scheduler hostname argument + to the "schedulers" set in Redis, and the worker's hostname to the "workers" + set in Redis. 
+ """ + worker_hostname = hostname() + logging.info( + "scheduler host received %s in worker host %s", scheduler_hostname, worker_hostname + ) + redis_client.sadd("schedulers", scheduler_hostname) + redis_client.sadd("workers", worker_hostname) + logging.info("schedulers: %s", redis_client.smembers("schedulers")) + logging.info("workers: %s", redis_client.smembers("workers")) + # The goal is to have all workers busy in all processes. + # For that it maybe necessary to exhaust all workers, but not to get the pending tasks + # too big, so all schedulers can manage to run their scheduled tasks. + # Celery prefetches tasks, and if they cannot be run they are put in reserved. + # If all processes have tasks in reserved, this task will finish immediately to not make + # queues any longer. + inspect_obj = celery_app.control.inspect() + reserved_sizes = [len(tasks) for tasks in inspect_obj.reserved().values()] + logging.info("number of reserved tasks %s", reserved_sizes) + delay = 0 if min(reserved_sizes) > 0 else 5 + time.sleep(delay) + def get_mysql_database(): """Get the mysql db connection.""" @@ -213,16 +290,42 @@ def mongodb_status(): @app.route("/redis/status") def redis_status(): - """Mongodb status endpoint.""" + """Redis status endpoint.""" if database := get_redis_database(): try: database.set("foo", "bar") return "SUCCESS" - except pymongo.errors.PyMongoError: - pass + except redis.exceptions.RedisError: + logging.exception("Error querying redis") return "FAIL" +@app.route("/redis/clear_celery_stats") +def redis_celery_clear_stats(): + """Reset Redis statistics about workers and schedulers.""" + if database := get_redis_database(): + try: + database.delete("workers") + database.delete("schedulers") + return "SUCCESS" + except redis.exceptions.RedisError: + logging.exception("Error querying redis") + return "FAIL", 500 + + +@app.route("/redis/celery_stats") +def redis_celery_stats(): + """Read Redis statistics about workers and schedulers.""" + if database := get_redis_database(): + try: + worker_set = [str(host) for host in database.smembers("workers")] + beat_set = [str(host) for host in database.smembers("schedulers")] + return jsonify({"workers": worker_set, "schedulers": beat_set}) + except redis.exceptions.RedisError: + logging.exception("Error querying redis") + return "FAIL", 500 + + @app.route("/rabbitmq/send") def rabbitmq_send(): """Send a message to "charm" queue.""" diff --git a/examples/flask/test_rock/requirements.txt b/examples/flask/test_rock/requirements.txt index 5f1ab38..2ff69c0 100644 --- a/examples/flask/test_rock/requirements.txt +++ b/examples/flask/test_rock/requirements.txt @@ -7,3 +7,4 @@ pymongo redis[hiredis] boto3 pika +celery diff --git a/examples/flask/test_rock/rockcraft.yaml b/examples/flask/test_rock/rockcraft.yaml index 97dc7a7..91ad18b 100644 --- a/examples/flask/test_rock/rockcraft.yaml +++ b/examples/flask/test_rock/rockcraft.yaml @@ -11,3 +11,19 @@ platforms: extensions: - flask-framework + +services: + celery-worker: + override: replace + # redis is not mandatory in the charm. We do not want the charm to fail immediately, so the sleep + command: bash -c "sleep 5; celery -A app:celery_app worker -c 2 --loglevel DEBUG" + startup: enabled + user: _daemon_ + working-dir: /flask/app + celery-beat-scheduler: + override: replace + # redis is not mandatory in the charm. 
We do not want the charm to fail immediately, so the sleep + command: bash -c "sleep 5; celery -A app:celery_app beat --loglevel DEBUG -s /tmp/celerybeat-schedule" + startup: enabled + user: _daemon_ + working-dir: /flask/app diff --git a/paas_app_charmer/_gunicorn/charm.py b/paas_app_charmer/_gunicorn/charm.py index 95fa8dc..ad749eb 100644 --- a/paas_app_charmer/_gunicorn/charm.py +++ b/paas_app_charmer/_gunicorn/charm.py @@ -19,7 +19,9 @@ class GunicornBase(PaasCharm): @property def _workload_config(self) -> WorkloadConfig: """Return a WorkloadConfig instance.""" - return create_workload_config(self._framework_name) + return create_workload_config( + framework_name=self._framework_name, unit_name=self.unit.name + ) def _create_app(self) -> App: """Build an App instance for the Gunicorn based charm. diff --git a/paas_app_charmer/_gunicorn/workload_config.py b/paas_app_charmer/_gunicorn/workload_config.py index 0859512..d93bbf3 100644 --- a/paas_app_charmer/_gunicorn/workload_config.py +++ b/paas_app_charmer/_gunicorn/workload_config.py @@ -12,11 +12,12 @@ APPLICATION_ERROR_LOG_FILE_FMT = "/var/log/{framework}/error.log" -def create_workload_config(framework_name: str) -> WorkloadConfig: +def create_workload_config(framework_name: str, unit_name: str) -> WorkloadConfig: """Create an WorkloadConfig for Gunicorn. Args: framework_name: framework name. + unit_name: name of the app unit. Returns: new WorkloadConfig @@ -35,4 +36,5 @@ def create_workload_config(framework_name: str) -> WorkloadConfig: pathlib.Path(str.format(APPLICATION_ERROR_LOG_FILE_FMT, framework=framework_name)), ], metrics_target="*:9102", + unit_name=unit_name, ) diff --git a/paas_app_charmer/_gunicorn/wsgi_app.py b/paas_app_charmer/_gunicorn/wsgi_app.py index 9e9c773..05dfe90 100644 --- a/paas_app_charmer/_gunicorn/wsgi_app.py +++ b/paas_app_charmer/_gunicorn/wsgi_app.py @@ -20,6 +20,7 @@ class WsgiApp(App): def __init__( # pylint: disable=too-many-arguments self, + *, container: ops.Container, charm_state: CharmState, workload_config: WorkloadConfig, diff --git a/paas_app_charmer/app.py b/paas_app_charmer/app.py index 3cd2ce3..e547260 100644 --- a/paas_app_charmer/app.py +++ b/paas_app_charmer/app.py @@ -17,6 +17,9 @@ logger = logging.getLogger(__name__) +WORKER_SUFFIX = "-worker" +SCHEDULER_SUFFIX = "-scheduler" + @dataclass(kw_only=True) class WorkloadConfig: # pylint: disable=too-many-instance-attributes @@ -37,6 +40,7 @@ class WorkloadConfig: # pylint: disable=too-many-instance-attributes log_files: list of files to monitor. metrics_target: target to scrape for metrics. metrics_path: path to scrape for metrics. + unit_name: Name of the unit. Needed to know if schedulers should run here. """ framework: str @@ -51,6 +55,16 @@ class WorkloadConfig: # pylint: disable=too-many-instance-attributes log_files: List[pathlib.Path] metrics_target: str | None = None metrics_path: str | None = "/metrics" + unit_name: str + + def should_run_scheduler(self) -> bool: + """Return if the unit should run scheduler processes. + + Return: + True if the unit should run scheduler processes, False otherwise. 
+ """ + unit_id = self.unit_name.split("/")[1] + return unit_id == "0" class App: @@ -58,6 +72,7 @@ class App: def __init__( # pylint: disable=too-many-arguments self, + *, container: ops.Container, charm_state: CharmState, workload_config: WorkloadConfig, @@ -192,6 +207,19 @@ def _app_layer(self) -> ops.pebble.LayerDict: services[self._workload_config.service_name]["override"] = "replace" services[self._workload_config.service_name]["environment"] = self.gen_environment() + for service_name, service in services.items(): + normalised_service_name = service_name.lower() + # Add environment variables to all worker processes. + if normalised_service_name.endswith(WORKER_SUFFIX): + service["environment"] = self.gen_environment() + # For scheduler processes, add environment variables if + # the scheduler should run in the unit, disable it otherwise. + if normalised_service_name.endswith(SCHEDULER_SUFFIX): + if self._workload_config.should_run_scheduler(): + service["environment"] = self.gen_environment() + else: + service["startup"] = "disabled" + return ops.pebble.LayerDict(services=services) diff --git a/paas_app_charmer/charm.py b/paas_app_charmer/charm.py index 0938682..510a57c 100644 --- a/paas_app_charmer/charm.py +++ b/paas_app_charmer/charm.py @@ -128,7 +128,7 @@ def __init__(self, framework: ops.Framework, framework_name: str) -> None: ) self._observability = Observability( - self, + charm=self, log_files=self._workload_config.log_files, container_name=self._workload_config.container_name, cos_dir=self.get_cos_dir(), diff --git a/paas_app_charmer/charm_state.py b/paas_app_charmer/charm_state.py index ca30baf..d320a9c 100644 --- a/paas_app_charmer/charm_state.py +++ b/paas_app_charmer/charm_state.py @@ -80,6 +80,7 @@ def __init__( # pylint: disable=too-many-arguments @classmethod def from_charm( # pylint: disable=too-many-arguments cls, + *, charm: ops.CharmBase, framework: str, framework_config: BaseModel, @@ -221,6 +222,7 @@ class IntegrationsState: @classmethod def build( # pylint: disable=too-many-arguments cls, + *, redis_uri: str | None, database_requirers: dict[str, DatabaseRequires], s3_connection_info: dict[str, str] | None, diff --git a/paas_app_charmer/database_migration.py b/paas_app_charmer/database_migration.py index 3ebf7b0..2d58bcc 100644 --- a/paas_app_charmer/database_migration.py +++ b/paas_app_charmer/database_migration.py @@ -75,6 +75,7 @@ def _set_status(self, status: DatabaseMigrationStatus) -> None: # pylint: disable=too-many-arguments def run( self, + *, command: list[str], environment: dict[str, str], working_dir: pathlib.Path, diff --git a/paas_app_charmer/fastapi/charm.py b/paas_app_charmer/fastapi/charm.py index f926731..744852a 100644 --- a/paas_app_charmer/fastapi/charm.py +++ b/paas_app_charmer/fastapi/charm.py @@ -72,6 +72,7 @@ def _workload_config(self) -> WorkloadConfig: log_files=[], metrics_target=f"*:{framework_config.metrics_port}", metrics_path=framework_config.metrics_path, + unit_name=self.unit.name, ) def get_cos_dir(self) -> str: diff --git a/paas_app_charmer/go/charm.py b/paas_app_charmer/go/charm.py index 6c68c02..af11dc4 100644 --- a/paas_app_charmer/go/charm.py +++ b/paas_app_charmer/go/charm.py @@ -64,6 +64,7 @@ def _workload_config(self) -> WorkloadConfig: log_files=[], metrics_target=f"*:{framework_config.metrics_port}", metrics_path=framework_config.metrics_path, + unit_name=self.unit.name, ) def get_cos_dir(self) -> str: diff --git a/paas_app_charmer/observability.py b/paas_app_charmer/observability.py index 4e4df7a..cd82489 100644 
--- a/paas_app_charmer/observability.py +++ b/paas_app_charmer/observability.py @@ -17,6 +17,7 @@ class Observability(ops.Object): def __init__( # pylint: disable=too-many-arguments self, + *, charm: ops.CharmBase, container_name: str, cos_dir: str, diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 91bc089..124616d 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -153,6 +153,26 @@ async def deploy_postgres_fixture(ops_test: OpsTest, model: Model): return await model.deploy("postgresql-k8s", channel="14/stable", revision=300, trust=True) +@pytest_asyncio.fixture(scope="module", name="redis_k8s_app") +async def deploy_redisk8s_fixture(ops_test: OpsTest, model: Model): + """Deploy Redis k8s charm.""" + redis_app = await model.deploy("redis-k8s", channel="edge") + await model.wait_for_idle(apps=[redis_app.name], status="active") + return redis_app + + +@pytest_asyncio.fixture(scope="function", name="integrate_redis_k8s_flask") +async def integrate_redis_k8s_flask_fixture( + ops_test: OpsTest, model: Model, flask_app: Application, redis_k8s_app: Application +): + """Integrate redis_k8s with flask apps.""" + relation = await model.integrate(flask_app.name, redis_k8s_app.name) + await model.wait_for_idle(apps=[redis_k8s_app.name], status="active") + yield relation + await flask_app.destroy_relation("redis", f"{redis_k8s_app.name}") + await model.wait_for_idle() + + @pytest_asyncio.fixture def run_action(ops_test: OpsTest): async def _run_action(application_name, action_name, **params): diff --git a/tests/integration/flask/test_workers.py b/tests/integration/flask/test_workers.py new file mode 100644 index 0000000..9dea9fd --- /dev/null +++ b/tests/integration/flask/test_workers.py @@ -0,0 +1,70 @@ +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Integration tests for Flask workers and schedulers.""" + +import asyncio +import logging +import time + +import pytest +import requests +from juju.application import Application +from juju.model import Model +from juju.utils import block_until +from pytest_operator.plugin import OpsTest + +logger = logging.getLogger(__name__) + + +@pytest.mark.parametrize( + "num_units", + [1, 3], +) +@pytest.mark.usefixtures("integrate_redis_k8s_flask") +async def test_workers_and_scheduler_services( + ops_test: OpsTest, model: Model, flask_app: Application, get_unit_ips, num_units: int +): + """ + arrange: Flask and redis deployed and integrated. + act: Scale the app to the desired number of units. + assert: There should be only one scheduler and as many workers as units. That will + be checked because the scheduler is constantly sending tasks with its hostname + to the workers, and the workers will put its own hostname and the schedulers + hostname in Redis sets. Those sets are checked through the Flask app, + that queries Redis. + """ + await flask_app.scale(num_units) + await model.wait_for_idle(apps=[flask_app.name], status="active") + + # the flask unit is not important. Take the first one + flask_unit_ip = (await get_unit_ips(flask_app.name))[0] + + def check_correct_celery_stats(num_schedulers, num_workers): + """Check that the expected number of workers and schedulers is right.""" + response = requests.get(f"http://{flask_unit_ip}:8000/redis/celery_stats", timeout=5) + assert response.status_code == 200 + data = response.json() + logger.info( + "check_correct_celery_stats. Expected schedulers: %d, expected workers %d. 
Result %s", + num_schedulers, + num_workers, + data, + ) + return len(data["workers"]) == num_workers and len(data["schedulers"]) == num_schedulers + + # clean the current celery stats + response = requests.get(f"http://{flask_unit_ip}:8000/redis/clear_celery_stats", timeout=5) + assert response.status_code == 200 + assert "SUCCESS" == response.text + + # enough time for all the schedulers to send messages + time.sleep(10) + try: + await block_until( + lambda: check_correct_celery_stats(num_schedulers=1, num_workers=num_units), + timeout=60, + wait_period=1, + ) + except asyncio.TimeoutError: + assert False, "Failed to get 2 workers and 1 scheduler" diff --git a/tests/unit/django/test_charm.py b/tests/unit/django/test_charm.py index c872da7..42900d5 100644 --- a/tests/unit/django/test_charm.py +++ b/tests/unit/django/test_charm.py @@ -63,7 +63,7 @@ def test_django_config(harness: Harness, config: dict, env: dict) -> None: database_requirers={}, ) webserver_config = WebserverConfig.from_charm_config(harness.charm.config) - workload_config = create_workload_config(framework_name="django") + workload_config = create_workload_config(framework_name="django", unit_name="django/0") webserver = GunicornWebserver( webserver_config=webserver_config, workload_config=workload_config, diff --git a/tests/unit/fastapi/test_charm.py b/tests/unit/fastapi/test_charm.py index f4948bb..f20aa8b 100644 --- a/tests/unit/fastapi/test_charm.py +++ b/tests/unit/fastapi/test_charm.py @@ -13,9 +13,10 @@ @pytest.mark.parametrize( - "config, env", + "config, postgresql_relation_data, env", [ pytest.param( + {}, {}, { "UVICORN_PORT": "8080", @@ -37,6 +38,12 @@ "metrics-path": "/othermetrics", "user-defined-config": "userdefined", }, + { + "database": "test-database", + "endpoints": "test-postgresql:5432,test-postgresql-2:5432", + "password": "test-password", + "username": "test-username", + }, { "UVICORN_PORT": "9000", "WEB_CONCURRENCY": "1", @@ -47,19 +54,36 @@ "METRICS_PATH": "/othermetrics", "APP_SECRET_KEY": "foobar", "APP_USER_DEFINED_CONFIG": "userdefined", + "APP_POSTGRESQL_DB_CONNECT_STRING": "postgresql://test-username:test-password@test-postgresql:5432/test-database", + "APP_POSTGRESQL_DB_FRAGMENT": "", + "APP_POSTGRESQL_DB_HOSTNAME": "test-postgresql", + "APP_POSTGRESQL_DB_NAME": "test-database", + "APP_POSTGRESQL_DB_NETLOC": "test-username:test-password@test-postgresql:5432", + "APP_POSTGRESQL_DB_PARAMS": "", + "APP_POSTGRESQL_DB_PASSWORD": "test-password", + "APP_POSTGRESQL_DB_PATH": "/test-database", + "APP_POSTGRESQL_DB_PORT": "5432", + "APP_POSTGRESQL_DB_QUERY": "", + "APP_POSTGRESQL_DB_SCHEME": "postgresql", + "APP_POSTGRESQL_DB_USERNAME": "test-username", }, id="custom config", ), ], ) -def test_fastapi_config(harness: Harness, config: dict, env: dict) -> None: +def test_fastapi_config( + harness: Harness, config: dict, postgresql_relation_data: dict, env: dict +) -> None: """ - arrange: none - act: start the fastapi charm and set the container to be ready. + arrange: prepare the charm optionally with the postgresql relation. + act: start the fastapi charm update the config options. assert: fastapi charm should submit the correct fastapi pebble layer to pebble. 
""" container = harness.model.unit.get_container(FASTAPI_CONTAINER_NAME) container.add_layer("a_layer", DEFAULT_LAYER) + if postgresql_relation_data: + harness.add_relation("postgresql", "postgresql-k8s", app_data=postgresql_relation_data) + harness.begin_with_initial_hooks() harness.charm._secret_storage.get_secret_key = unittest.mock.MagicMock(return_value="test") harness.update_config(config) diff --git a/tests/unit/flask/constants.py b/tests/unit/flask/constants.py index df1b0f4..d9936c3 100644 --- a/tests/unit/flask/constants.py +++ b/tests/unit/flask/constants.py @@ -25,6 +25,60 @@ }, } } + +LAYER_WITH_WORKER = { + "services": { + "flask": { + "override": "replace", + "startup": "enabled", + "command": f"/bin/python3 -m gunicorn -c /flask/gunicorn.conf.py app:app", + "after": ["statsd-exporter"], + "user": "_daemon_", + }, + "statsd-exporter": { + "override": "merge", + "command": ( + "/bin/statsd_exporter --statsd.mapping-config=/statsd-mapping.conf " + "--statsd.listen-udp=localhost:9125 " + "--statsd.listen-tcp=localhost:9125" + ), + "summary": "statsd exporter service", + "startup": "enabled", + "user": "_daemon_", + }, + "not-worker-service": { + "override": "replace", + "startup": "enabled", + "command": "/bin/noworker", + "user": "_daemon_", + }, + "real-worker": { + "override": "replace", + "startup": "enabled", + "command": "/bin/worker", + "user": "_daemon_", + }, + "Another-Real-WorkeR": { + "override": "replace", + "startup": "enabled", + "command": "/bin/worker", + "user": "_daemon_", + }, + "real-scheduler": { + "override": "replace", + "startup": "enabled", + "command": "/bin/scheduler", + "user": "_daemon_", + }, + "ANOTHER-REAL-SCHEDULER": { + "override": "replace", + "startup": "enabled", + "command": "/bin/worker", + "user": "_daemon_", + }, + } +} + FLASK_CONTAINER_NAME = "flask-app" SAML_APP_RELATION_DATA_EXAMPLE = { diff --git a/tests/unit/flask/test_charm.py b/tests/unit/flask/test_charm.py index ddcad7b..286726c 100644 --- a/tests/unit/flask/test_charm.py +++ b/tests/unit/flask/test_charm.py @@ -51,7 +51,7 @@ def test_flask_pebble_layer(harness: Harness) -> None: database_requirers={}, ) webserver_config = WebserverConfig.from_charm_config(harness.charm.config) - workload_config = create_workload_config(framework_name="flask") + workload_config = create_workload_config(framework_name="flask", unit_name="flask/0") webserver = GunicornWebserver( webserver_config=webserver_config, workload_config=workload_config, diff --git a/tests/unit/flask/test_database_migration.py b/tests/unit/flask/test_database_migration.py index 99930c8..5546d95 100644 --- a/tests/unit/flask/test_database_migration.py +++ b/tests/unit/flask/test_database_migration.py @@ -35,7 +35,7 @@ def test_database_migration(harness: Harness): is_secret_storage_ready=True, secret_key="", ) - workload_config = create_workload_config(framework_name="flask") + workload_config = create_workload_config(framework_name="flask", unit_name="flask/0") webserver_config = WebserverConfig() webserver = GunicornWebserver( webserver_config=webserver_config, @@ -109,7 +109,7 @@ def test_database_migrate_command(harness: Harness, file: str, command: list[str secret_key="", ) webserver_config = WebserverConfig() - workload_config = create_workload_config(framework_name="flask") + workload_config = create_workload_config(framework_name="flask", unit_name="flask/0") webserver = GunicornWebserver( webserver_config=webserver_config, workload_config=workload_config, @@ -150,8 +150,12 @@ def 
test_database_migration_status(harness: Harness): ) assert database_migration.get_status() == DatabaseMigrationStatus.PENDING with pytest.raises(CharmConfigInvalidError): - database_migration.run(["migrate"], {}, pathlib.Path("/flask/app")) + database_migration.run( + command=["migrate"], environment={}, working_dir=pathlib.Path("/flask/app") + ) assert database_migration.get_status() == DatabaseMigrationStatus.FAILED harness.handle_exec(container, [], result=0) - database_migration.run(["migrate"], {}, pathlib.Path("/flask/app")) + database_migration.run( + command=["migrate"], environment={}, working_dir=pathlib.Path("/flask/app") + ) assert database_migration.get_status() == DatabaseMigrationStatus.COMPLETED diff --git a/tests/unit/flask/test_flask_app.py b/tests/unit/flask/test_flask_app.py index 473649d..86b2ac3 100644 --- a/tests/unit/flask/test_flask_app.py +++ b/tests/unit/flask/test_flask_app.py @@ -41,7 +41,7 @@ def test_flask_env(flask_config: dict, app_config: dict, database_migration_mock framework_config=flask_config, app_config=app_config, ) - workload_config = create_workload_config(framework_name="flask") + workload_config = create_workload_config(framework_name="flask", unit_name="flask/0") flask_app = WsgiApp( container=unittest.mock.MagicMock(), charm_state=charm_state, @@ -104,7 +104,7 @@ def test_http_proxy( secret_key="foobar", is_secret_storage_ready=True, ) - workload_config = create_workload_config(framework_name="flask") + workload_config = create_workload_config(framework_name="flask", unit_name="flask/0") flask_app = WsgiApp( container=unittest.mock.MagicMock(), charm_state=charm_state, @@ -157,7 +157,7 @@ def test_integrations_env( is_secret_storage_ready=True, integrations=integrations, ) - workload_config = create_workload_config(framework_name="flask") + workload_config = create_workload_config(framework_name="flask", unit_name="flask/0") flask_app = WsgiApp( container=unittest.mock.MagicMock(), charm_state=charm_state, diff --git a/tests/unit/flask/test_webserver.py b/tests/unit/flask/test_webserver.py index c2c54d3..cf51bf6 100644 --- a/tests/unit/flask/test_webserver.py +++ b/tests/unit/flask/test_webserver.py @@ -73,7 +73,7 @@ def test_gunicorn_config( secret_key="", is_secret_storage_ready=True, ) - workload_config = create_workload_config(framework_name="flask") + workload_config = create_workload_config(framework_name="flask", unit_name="flask/0") webserver_config = WebserverConfig(**charm_state_params) webserver = GunicornWebserver( webserver_config=webserver_config, @@ -116,7 +116,7 @@ def test_webserver_reload(monkeypatch, harness: Harness, is_running, database_mi is_secret_storage_ready=True, ) webserver_config = WebserverConfig() - workload_config = create_workload_config(framework_name="flask") + workload_config = create_workload_config(framework_name="flask", unit_name="flask/0") webserver = GunicornWebserver( webserver_config=webserver_config, workload_config=workload_config, @@ -171,7 +171,7 @@ def test_gunicorn_config_with_pebble_log_forwarding( secret_key="", is_secret_storage_ready=True, ) - workload_config = create_workload_config(framework_name="flask") + workload_config = create_workload_config(framework_name="flask", unit_name="flask/0") webserver_config = WebserverConfig() webserver = GunicornWebserver( webserver_config=webserver_config, diff --git a/tests/unit/flask/test_workers.py b/tests/unit/flask/test_workers.py new file mode 100644 index 0000000..144dc8f --- /dev/null +++ b/tests/unit/flask/test_workers.py @@ -0,0 +1,72 @@ +# 
Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Unit tests for worker services.""" + +import copy +import unittest.mock +from secrets import token_hex + +import ops +import pytest +from ops.testing import Harness + +from .constants import FLASK_CONTAINER_NAME, LAYER_WITH_WORKER + + +def test_worker(harness: Harness): + """ + arrange: Prepare a unit with workers and schedulers. + act: Run initial hooks. + assert: The workers should have all the environment variables. Also the schedulers, as + the unit is 0. + """ + container = harness.model.unit.get_container(FLASK_CONTAINER_NAME) + flask_layer = copy.deepcopy(LAYER_WITH_WORKER) + container.add_layer("a_layer", LAYER_WITH_WORKER) + + harness.begin_with_initial_hooks() + + assert harness.model.unit.status == ops.ActiveStatus() + services = container.get_plan().services + assert "FLASK_SECRET_KEY" in services["flask"].environment + assert services["flask"].environment == services["real-worker"].environment + assert services["flask"].environment == services["Another-Real-WorkeR"].environment + assert services["real-scheduler"].startup == "enabled" + assert services["flask"].environment == services["real-scheduler"].environment + assert services["ANOTHER-REAL-SCHEDULER"].startup == "enabled" + assert services["flask"].environment == services["ANOTHER-REAL-SCHEDULER"].environment + assert "FLASK_SECRET_KEY" not in services["not-worker-service"].environment + + +def test_worker_multiple_units(harness: Harness): + """ + arrange: Prepare a unit with workers that is not the first one (number 1) + act: Run initial hooks. + assert: The workers should have all the environment variables. The schedulers should be + disabled and not have the environment variables + """ + + # This is tricky and could be problematic + harness.framework.model.unit.name = f"{harness._meta.name}/1" + harness.set_planned_units(3) + + # Just think that we are not the leader unit. For this it is necessary to put data + # in the peer relation for the secret.. + harness.set_leader(False) + harness.add_relation( + "secret-storage", harness.framework.model.app.name, app_data={"flask_secret_key": "XX"} + ) + + container = harness.model.unit.get_container(FLASK_CONTAINER_NAME) + container.add_layer("a_layer", LAYER_WITH_WORKER) + + harness.begin_with_initial_hooks() + + assert harness.model.unit.status == ops.ActiveStatus() + services = container.get_plan().services + assert "FLASK_SECRET_KEY" in services["flask"].environment + assert services["flask"].environment == services["real-worker"].environment + assert services["real-scheduler"].startup == "disabled" + assert "FLASK_SECRET_KEY" not in services["real-scheduler"].environment + assert "FLASK_SECRET_KEY" not in services["not-worker-service"].environment diff --git a/tests/unit/go/test_app.py b/tests/unit/go/test_app.py index 4c2201b..faaabf8 100644 --- a/tests/unit/go/test_app.py +++ b/tests/unit/go/test_app.py @@ -94,6 +94,7 @@ def test_go_environment_vars( log_files=[], metrics_target=f"*:{framework_config.metrics_port}", metrics_path=framework_config.metrics_path, + unit_name="go/0", ) charm_state = CharmState(
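
Note on the convention this diff introduces: services in the rock whose name ends in "-worker" receive the full app environment on every unit, while services ending in "-scheduler" receive it only on unit 0 and are otherwise disabled, so exactly one beat process runs per application. The following is a minimal standalone sketch of that rule, assuming Pebble services are plain dicts and unit names have the usual "app/N" form; the helper name apply_worker_convention is hypothetical and only mirrors the loop added to paas_app_charmer/app.py together with WorkloadConfig.should_run_scheduler, it is not part of the library's API.

    # Illustrative sketch only: approximates the worker/scheduler handling added
    # in this diff. apply_worker_convention is a hypothetical name for clarity.
    WORKER_SUFFIX = "-worker"
    SCHEDULER_SUFFIX = "-scheduler"


    def apply_worker_convention(
        services: dict[str, dict], environment: dict[str, str], unit_name: str
    ) -> dict[str, dict]:
        """Return a copy of the Pebble services with the suffix convention applied.

        Services whose lower-cased name ends in "-worker" get the app environment;
        services ending in "-scheduler" get it only on unit 0 and are otherwise
        marked startup=disabled, so a single scheduler runs per application.
        """
        adjusted = {name: dict(service) for name, service in services.items()}
        unit_id = unit_name.split("/")[1]
        for name, service in adjusted.items():
            lowered = name.lower()
            if lowered.endswith(WORKER_SUFFIX):
                service["environment"] = environment
            if lowered.endswith(SCHEDULER_SUFFIX):
                if unit_id == "0":
                    service["environment"] = environment
                else:
                    service["startup"] = "disabled"
        return adjusted


    if __name__ == "__main__":
        layer_services = {
            "flask": {"command": "gunicorn app:app", "startup": "enabled"},
            "celery-worker": {"command": "celery -A app:celery_app worker", "startup": "enabled"},
            "celery-beat-scheduler": {"command": "celery -A app:celery_app beat", "startup": "enabled"},
        }
        env = {"FLASK_SECRET_KEY": "example"}
        # On unit 1 the worker keeps the environment but the scheduler is disabled.
        print(apply_worker_convention(layer_services, env, "flask-k8s/1"))

Restricting the scheduler to unit 0 is what the integration test above relies on: after scaling, the Redis sets populated by the Celery tasks should contain one scheduler hostname and as many worker hostnames as there are units.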