diff --git a/servicelayer/taskqueue.py b/servicelayer/taskqueue.py index 611244b..00ce64c 100644 --- a/servicelayer/taskqueue.py +++ b/servicelayer/taskqueue.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Tuple, List import json import time import threading @@ -17,11 +17,13 @@ from random import randrange import pika.spec +from pika.adapters.blocking_connection import BlockingChannel from prometheus_client import start_http_server from structlog.contextvars import clear_contextvars, bind_contextvars import pika from banal import ensure_list +from redis import Redis from servicelayer.cache import get_redis, make_key from servicelayer.util import pack_now, unpack_int @@ -434,7 +436,7 @@ class Worker(ABC): def __init__( self, queues, - conn=None, + conn: Redis = None, num_threads=settings.WORKER_THREADS, version=None, prefetch_count_mapping=defaultdict(lambda: 1), @@ -491,19 +493,18 @@ def on_message(self, channel, method, properties, body, args): We have to make sure it doesn't block for long to ensure that RabbitMQ heartbeats are not interrupted. """ - connection = args[0] task = get_task(body, method.delivery_tag) # the task needs to be acknowledged in the same channel that it was # received. So store the channel. This is useful when executing batched # indexing tasks since they are acknowledged late. task._channel = channel - self.local_queue.put((task, channel, connection)) + self.local_queue.put((task, channel)) def process_blocking(self): """Blocking worker thread - executes tasks from a queue and periodic tasks""" while True: try: - (task, channel, connection) = self.local_queue.get(timeout=TIMEOUT) + (task, channel) = self.local_queue.get(timeout=TIMEOUT) apply_task_context(task, v=self.version) success, retry = self.handle(task, channel) log.debug( @@ -514,7 +515,7 @@ def process_blocking(self): cb = functools.partial(self.ack_message, task, channel) else: cb = functools.partial(self.nack_message, task, channel, retry) - connection.add_callback_threadsafe(cb) + channel.connection.add_callback_threadsafe(cb) except Empty: pass finally: @@ -523,8 +524,7 @@ def process_blocking(self): def process_nonblocking(self): """Non-blocking worker is used for tests only.""" - connection = get_rabbitmq_connection() - channel = connection.channel() + channel = get_rabbitmq_channel() queue_active = {queue: True for queue in self.queues} while True: for queue in self.queues: @@ -632,7 +632,7 @@ def periodic(self): """Periodic tasks to run.""" pass - def ack_message(self, task, channel): + def ack_message(self, task, channel, multiple=False): """Acknowledge a task after execution. RabbitMQ requires that the channel used for receiving the message must be used @@ -653,7 +653,7 @@ def ack_message(self, task, channel): # Sync state to redis dataset.mark_done(task) if channel.is_open: - channel.basic_ack(task.delivery_tag) + channel.basic_ack(task.delivery_tag, multiple=multiple) clear_contextvars() def nack_message(self, task, channel, requeue=True): @@ -697,9 +697,8 @@ def process(): log.info(f"Worker has {self.num_threads} worker threads.") - connection = get_rabbitmq_connection() - channel = connection.channel() - on_message_callback = functools.partial(self.on_message, args=(connection,)) + channel = get_rabbitmq_channel() + on_message_callback = functools.partial(self.on_message, args=(channel,)) for queue in self.queues: declare_rabbitmq_queue( @@ -709,13 +708,14 @@ def process(): channel.start_consuming() -def get_rabbitmq_connection(): +def get_rabbitmq_channel() -> BlockingChannel: for attempt in service_retries(): try: if ( not hasattr(local, "connection") or not local.connection or not local.connection.is_open + or not local.channel or attempt > 0 ): log.debug( @@ -735,16 +735,17 @@ def get_rabbitmq_connection(): ) ) local.connection = connection + local.channel = connection.channel() # Check that the connection is alive - result = local.connection.channel().exchange_declare( + result = local.channel.exchange_declare( exchange="amq.topic", exchange_type=pika.exchange_type.ExchangeType.topic, passive=True, ) assert isinstance(result.method, pika.spec.Exchange.DeclareOk) - return local.connection + return local.channel except ( pika.exceptions.AMQPConnectionError, @@ -764,6 +765,7 @@ def get_rabbitmq_connection(): f"Attempt: {attempt}/{service_retries().stop}" ) local.connection = None + local.channel = None backoff(failures=attempt) raise RuntimeError("Could not connect to RabbitMQ") @@ -804,7 +806,13 @@ def dataset_from_collection(collection): def queue_task( - rmq_conn, redis_conn, collection_id, stage, job_id=None, context=None, **payload + rmq_channel: BlockingChannel, + redis_conn, + collection_id: int, + stage: str, + job_id=None, + context=None, + **payload, ): task_id = uuid.uuid4().hex priority = get_priority(collection_id, redis_conn) @@ -818,9 +826,8 @@ def queue_task( "priority": priority, } try: - channel = rmq_conn.channel() - channel.confirm_delivery() - channel.basic_publish( + rmq_channel.confirm_delivery() + rmq_channel.basic_publish( exchange="", routing_key=stage, body=json.dumps(body), @@ -832,23 +839,15 @@ def queue_task( dataset.add_task(task_id, stage) except (pika.exceptions.UnroutableError, pika.exceptions.AMQPConnectionError): log.exception("Error while queuing task") - finally: - try: - if channel: - channel.close() - except pika.exceptions.ChannelWrongStateError: - log.exception("Failed to explicitly close RabbitMQ channel.") -def flush_queues(rmq_conn, redis_conn, queues): +def flush_queues(rmq_channel: BlockingChannel, redis_conn: Redis, queues: List[str]): try: - channel = rmq_conn.channel() for queue in queues: try: - channel.queue_purge(queue) + rmq_channel.queue_purge(queue) except ValueError: logging.exception(f"Error while flushing the {queue} queue") - channel.close() except pika.exceptions.AMQPError: logging.exception("Error while flushing task queue") for key in redis_conn.scan_iter(PREFIX + "*"): diff --git a/tests/test_taskqueue.py b/tests/test_taskqueue.py index b75a42a..29e4c64 100644 --- a/tests/test_taskqueue.py +++ b/tests/test_taskqueue.py @@ -12,7 +12,7 @@ Worker, Dataset, Task, - get_rabbitmq_connection, + get_rabbitmq_channel, dataset_from_collection_id, declare_rabbitmq_queue, flush_queues, @@ -49,8 +49,7 @@ def test_task_queue(self): "payload": {}, "priority": priority, } - connection = get_rabbitmq_connection() - channel = connection.channel() + channel = get_rabbitmq_channel() declare_rabbitmq_queue(channel, test_queue_name) channel.queue_purge(test_queue_name) channel.basic_publish( @@ -84,7 +83,7 @@ def test_task_queue(self): assert task.get_retry_count(conn) == 1 with patch("servicelayer.settings.WORKER_RETRY", 0): - channel = connection.channel() + channel = get_rabbitmq_channel() channel.queue_purge(test_queue_name) channel.basic_publish( properties=pika.BasicProperties(priority=priority), @@ -138,8 +137,7 @@ def test_task_that_shouldnt_execute(self, mock_should_execute): "collection_id": 2, } - connection = get_rabbitmq_connection() - channel = connection.channel() + channel = get_rabbitmq_channel() declare_rabbitmq_queue(channel, test_queue_name) channel.queue_purge(test_queue_name) channel.basic_publish( @@ -171,16 +169,14 @@ def did_nack(): return_value=None, ) as dispatch_fn: with patch.object( - pika.channel.Channel, + channel, attribute="basic_nack", return_value=None, ) as nack_fn: worker.process(blocking=False) - nack_fn.assert_any_call(delivery_tag=1, multiple=False, requeue=True) + nack_fn.assert_called_once() dispatch_fn.assert_not_called() - channel.close() - status = dataset.get_active_dataset_status(conn=conn) stage = status["datasets"]["2"]["stages"][0] assert stage["pending"] == 1 @@ -190,14 +186,14 @@ def did_nack(): def test_get_priority_bucket(): redis = get_fakeredis() - rmq = get_rabbitmq_connection() - flush_queues(rmq, redis, ["index"]) + rmq_channel = get_rabbitmq_channel() + flush_queues(rmq_channel, redis, ["index"]) collection_id = 1 assert get_task_count(collection_id, redis) == 0 assert get_priority(collection_id, redis) in (7, 8) - queue_task(rmq, redis, collection_id, "index") + queue_task(rmq_channel, redis, collection_id, "index") assert get_task_count(collection_id, redis) == 1 assert get_priority(collection_id, redis) in (7, 8)