diff --git a/docker-compose.yaml b/docker-compose.yaml
index e03e651..026fba0 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -4,8 +4,8 @@ services:
   rabbitmq:
     image: rabbitmq:3.9-management-alpine
     ports:
-      - '127.0.0.1:5672:5672'
-      - '127.0.0.1:15672:15672'
+      - '127.0.0.1:5673:5672'
+      - '127.0.0.1:15673:15672'
 
   shell:
     build:
diff --git a/servicelayer/logs.py b/servicelayer/logs.py
index 08ca9a0..b12a222 100644
--- a/servicelayer/logs.py
+++ b/servicelayer/logs.py
@@ -86,7 +86,7 @@ def apply_task_context(task, **kwargs):
         start_time=time.time(),
         trace_id=str(uuid.uuid4()),
         retry=unpack_int(task.context.get("retries")),
-        **kwargs
+        **kwargs,
     )
 
 
diff --git a/servicelayer/taskqueue.py b/servicelayer/taskqueue.py
index 0cd3468..de9fb71 100644
--- a/servicelayer/taskqueue.py
+++ b/servicelayer/taskqueue.py
@@ -13,6 +13,8 @@
 from collections import defaultdict
 from threading import Thread
 from timeit import default_timer
+import uuid
+from random import randrange
 
 import pika.spec
 from prometheus_client import start_http_server
@@ -753,3 +755,87 @@ def get_rabbitmq_connection():
             backoff(failures=attempt)
 
     raise RuntimeError("Could not connect to RabbitMQ")
+
+
+def get_task_count(collection_id, redis_conn) -> int:
+    """Get the total task count for a given dataset."""
+    status = Dataset.get_active_dataset_status(conn=redis_conn)
+    try:
+        collection = status["datasets"][str(collection_id)]
+        total = collection["finished"] + collection["running"] + collection["pending"]
+    except KeyError:
+        total = 0
+    return total
+
+
+def get_priority(collection_id, redis_conn) -> int:
+    """
+    Priority buckets for tasks, based on the dataset's total task count (finished + running + pending).
+    """
+    total_task_count = get_task_count(collection_id, redis_conn)
+    if total_task_count < 500:
+        return randrange(7, 9)
+    elif total_task_count < 10000:
+        return randrange(4, 7)
+    return randrange(1, 4)
+
+
+def dataset_from_collection(collection):
+    """servicelayer dataset from a collection"""
+    if collection is None:
+        return NO_COLLECTION
+    return str(collection.id)
+
+
+def queue_task(
+    rmq_conn, redis_conn, collection_id, stage, job_id=None, context=None, **payload
+):
+    task_id = uuid.uuid4().hex
+    priority = get_priority(collection_id, redis_conn)
+    body = {
+        "collection_id": collection_id,
+        "job_id": job_id or uuid.uuid4().hex,
+        "task_id": task_id,
+        "operation": stage,
+        "context": context,
+        "payload": payload,
+        "priority": priority,
+    }
+    channel = None
+    try:
+        channel = rmq_conn.channel()
+        # Enable publisher confirms so failed deliveries raise an exception.
+        channel.confirm_delivery()
+        channel.basic_publish(
+            exchange="",
+            routing_key=stage,
+            body=json.dumps(body),
+            properties=pika.BasicProperties(
+                delivery_mode=pika.spec.PERSISTENT_DELIVERY_MODE, priority=priority
+            ),
+        )
+        dataset = Dataset(conn=redis_conn, name=str(collection_id))
+        dataset.add_task(task_id, stage)
+    except (pika.exceptions.UnroutableError, pika.exceptions.AMQPConnectionError):
+        log.exception("Error while queuing task")
+    finally:
+        try:
+            if channel:
+                channel.close()
+        except pika.exceptions.ChannelWrongStateError:
+            log.exception("Failed to explicitly close RabbitMQ channel.")
+
+
+def flush_queues(rmq_conn, redis_conn, queues):
+    try:
+        channel = rmq_conn.channel()
+        for queue in queues:
+            try:
+                channel.queue_purge(queue)
+            except ValueError:
+                log.exception(f"Error while flushing the {queue} queue")
+        channel.close()
+    except pika.exceptions.AMQPError:
+        log.exception("Error while flushing task queue")
+    for key in redis_conn.scan_iter(PREFIX + "*"):
+        redis_conn.delete(key)
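Note on the scheduling change above: get_priority() maps a dataset's total task count to a RabbitMQ message priority bucket, so large bulk datasets cannot starve small, likely interactive ones. Below is a minimal, self-contained sketch of the same bucketing; bucket_priority is a hypothetical stand-in for get_priority that takes the count directly instead of reading it from Redis, and the randrange jitter spreads same-bucket tasks so they do not always dequeue in a fixed order.

    from random import randrange

    def bucket_priority(total_task_count: int) -> int:
        # Higher number = dequeued first, assuming the queues are declared
        # with RabbitMQ's x-max-priority argument.
        if total_task_count < 500:
            return randrange(7, 9)  # 7 or 8: small datasets
        elif total_task_count < 10000:
            return randrange(4, 7)  # 4, 5 or 6: medium datasets
        return randrange(1, 4)      # 1, 2 or 3: bulk workloads

    assert bucket_priority(0) in (7, 8)
    assert bucket_priority(9999) in (4, 5, 6)
    assert bucket_priority(10**6) in (1, 2, 3)
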
diff --git a/tests/test_taskqueue.py b/tests/test_taskqueue.py
index 68f504c..9f710cd 100644
--- a/tests/test_taskqueue.py
+++ b/tests/test_taskqueue.py
@@ -15,9 +15,13 @@
     get_rabbitmq_connection,
     dataset_from_collection_id,
     declare_rabbitmq_queue,
+    flush_queues,
+    get_priority,
+    get_task_count,
+    queue_task,
 )
 from servicelayer.util import unpack_datetime
 
 
 class CountingWorker(Worker):
     def dispatch_task(self, task: Task) -> Task:
@@ -183,3 +187,76 @@ def did_nack():
     assert stage["pending"] == 1
     assert stage["running"] == 0
     assert dataset.is_task_tracked(Task(**body))
+
+
+def test_get_priority_bucket():
+    redis = get_fakeredis()
+    rmq = get_rabbitmq_connection()
+    flush_queues(rmq, redis, ["index"])
+    collection_id = 1
+
+    assert get_task_count(collection_id, redis) == 0
+    assert get_priority(collection_id, redis) in (7, 8)
+
+    queue_task(rmq, redis, collection_id, "index")
+
+    assert get_task_count(collection_id, redis) == 1
+    assert get_priority(collection_id, redis) in (7, 8)
+
+    with patch.object(
+        Dataset,
+        "get_active_dataset_status",
+        return_value={
+            "total": 9999,
+            "datasets": {
+                "1": {
+                    "finished": 9999,
+                    "running": 0,
+                    "pending": 0,
+                    "stages": [
+                        {
+                            "job_id": "",
+                            "stage": "index",
+                            "pending": 0,
+                            "running": 0,
+                            "finished": 9999,
+                        }
+                    ],
+                    "start_time": "2024-06-25T10:58:49.779811",
+                    "end_time": None,
+                    "last_update": "2024-06-25T10:58:49.779819",
+                }
+            },
+        },
+    ):
+        assert get_task_count(collection_id, redis) == 9999
+        assert get_priority(collection_id, redis) in (4, 5, 6)
+
+    with patch.object(
+        Dataset,
+        "get_active_dataset_status",
+        return_value={
+            "total": 10001,
+            "datasets": {
+                "1": {
+                    "finished": 10000,
+                    "running": 0,
+                    "pending": 1,
+                    "stages": [
+                        {
+                            "job_id": "",
+                            "stage": "index",
+                            "pending": 1,
+                            "running": 0,
+                            "finished": 10000,
+                        }
+                    ],
+                    "start_time": "2024-06-25T10:58:49.779811",
+                    "end_time": None,
+                    "last_update": "2024-06-25T10:58:49.779819",
+                }
+            },
+        },
+    ):
+        assert get_task_count(collection_id, redis) == 10001
+        assert get_priority(collection_id, redis) in (1, 2, 3)
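
For context, a hedged usage sketch of the new helpers. The get_redis import from servicelayer.cache is an assumption (the tests use get_fakeredis from the same module); everything else is defined in this diff.

    from servicelayer.cache import get_redis
    from servicelayer.taskqueue import (
        flush_queues,
        get_rabbitmq_connection,
        queue_task,
    )

    rmq = get_rabbitmq_connection()
    redis = get_redis()

    # Publish one "index" task for collection 1: queue_task derives the message
    # priority from the dataset's current task count, publishes a persistent
    # message to the "index" queue (assumed to be declared already, e.g. via
    # declare_rabbitmq_queue), and records the task in Redis. Extra keyword
    # arguments end up in the task's payload.
    queue_task(rmq, redis, 1, "index")

    # Tests (see test_get_priority_bucket above) reset state by purging the
    # RabbitMQ queues and deleting the servicelayer keys from Redis:
    flush_queues(rmq, redis, ["index"])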