diff --git a/servicelayer/taskqueue.py b/servicelayer/taskqueue.py
index ff7f0f3..872efa7 100644
--- a/servicelayer/taskqueue.py
+++ b/servicelayer/taskqueue.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Tuple
 import json
 import time
 import threading
@@ -310,6 +310,23 @@ def mark_done(self, task: Task):
         # delete stages key
         pipe.delete(self.active_stages_key)
 
+    def mark_for_retry(self, task):
+        pipe = self.conn.pipeline()
+        stage_key = self.get_stage_key(task.operation)
+
+        log.info(
+            f"Marking task {task.task_id} (stage {task.operation})"
+            f" for retry after NACK"
+        )
+
+        pipe.sadd(make_key(stage_key, "pending"), task.task_id)
+        pipe.srem(make_key(stage_key, "running"), task.task_id)
+        pipe.delete(task.retry_key)
+        pipe.srem(stage_key, task.task_id)
+
+        pipe.set(self.last_update_key, pack_now())
+        pipe.execute()
+
     def is_done(self):
         status = self.get_status()
         return status["pending"] == 0 and status["running"] == 0
@@ -455,8 +471,15 @@ def process_blocking(self):
             try:
                 (task, channel, connection) = self.local_queue.get(timeout=TIMEOUT)
                 apply_task_context(task, v=self.version)
-                self.handle(task, channel)
-                cb = functools.partial(self.ack_message, task, channel)
+                success, retry = self.handle(task, channel)
+                log.debug(
+                    f"Task {task.task_id} finished with success={success}"
+                    f" and retry={retry}"
+                )
+                if success:
+                    cb = functools.partial(self.ack_message, task, channel)
+                else:
+                    cb = functools.partial(self.nack_message, task, channel, retry)
                 connection.add_callback_threadsafe(cb)
             except Empty:
                 pass
@@ -480,7 +503,11 @@ def process_nonblocking(self):
                 else:
                     queue_active[queue] = True
                     task = get_task(body, method.delivery_tag)
-                    self.handle(task, channel)
+                    success, retry = self.handle(task, channel)
+                    if success:
+                        channel.basic_ack(task.delivery_tag)
+                    else:
+                        channel.basic_nack(task.delivery_tag, requeue=retry)
 
     def process(self, blocking=True):
         if blocking:
@@ -488,8 +515,12 @@ def process(self, blocking=True):
         else:
             self.process_nonblocking()
 
-    def handle(self, task: Task, channel):
-        """Execute a task."""
+    def handle(self, task: Task, channel) -> Tuple[bool, bool]:
+        """Execute a task.
+
+        Returns a tuple of (success, retry)."""
+        success = True
+        retry = True
         try:
             dataset = Dataset(
                 conn=self.conn, name=dataset_from_collection_id(task.collection_id)
             )
@@ -531,7 +562,7 @@ def handle(self, task: Task, channel):
                 log.info(
                     f"Sending a NACK for message {task.delivery_tag}"
                     f" for task_id {task.task_id}."
-                    f"Message will be requeued."
+                    f" Message will be requeued."
                 )
                 # In this case, a task ID was found neither in the
                 # list of Pending, nor the list of Running tasks
@@ -541,12 +572,19 @@ def handle(self, task: Task, channel):
                     retries=0,
                     failed_permanently=True,
                 ).inc()
-                if channel.is_open:
-                    channel.basic_nack(task.delivery_tag)
+                success = False
+        except MaxRetriesExceededError:
+            log.exception(
+                f"Task {task.task_id} permanently failed and will be discarded."
+            )
+            success = False
+            retry = False
         except Exception:
             log.exception("Error in task handling")
+            success = False
         finally:
             self.after_task(task)
+        return success, retry
 
     @abstractmethod
     def dispatch_task(self, task: Task) -> Task:
@@ -584,6 +622,17 @@ def ack_message(self, task, channel):
             channel.basic_ack(task.delivery_tag)
         clear_contextvars()
 
+    def nack_message(self, task, channel, requeue=True):
+        """NACK task and update status."""
+        apply_task_context(task, v=self.version)
+        log.info(f"NACKing message {task.delivery_tag} for task_id {task.task_id}")
+        dataset = task.get_dataset(conn=self.conn)
+        # Sync state to redis
+        dataset.mark_for_retry(task)
+        if channel.is_open:
+            channel.basic_nack(delivery_tag=task.delivery_tag, requeue=requeue)
+        clear_contextvars()
+
     def run(self):
         """Run a blocking worker instance"""
 
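Note: handle() now returns a (success, retry) pair and leaves message settlement
to the caller. Below is a minimal, self-contained sketch of how that pair maps
onto AMQP acknowledgements, assuming a pika BlockingChannel; settle_message is a
hypothetical name used for illustration and is not part of this patch:

    def settle_message(channel, delivery_tag, success, retry):
        """Mirror the worker's new ACK/NACK decision for one delivery."""
        if success:
            # Task finished: remove the message from the queue.
            channel.basic_ack(delivery_tag=delivery_tag)
        elif retry:
            # Transient failure: requeue so RabbitMQ redelivers the message.
            channel.basic_nack(delivery_tag=delivery_tag, requeue=True)
        else:
            # Permanent failure (MaxRetriesExceededError): drop the message,
            # or dead-letter it if the queue has a DLX configured.
            channel.basic_nack(delivery_tag=delivery_tag, requeue=False)

In the blocking worker the same decision is wrapped in functools.partial and
handed to connection.add_callback_threadsafe, which keeps all channel
operations on the connection's own I/O thread; pika's BlockingConnection is
not thread-safe, so worker threads must not call basic_ack/basic_nack directly.
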
diff --git a/tests/test_taskqueue.py b/tests/test_taskqueue.py
index ebd11f8..286e23d 100644
--- a/tests/test_taskqueue.py
+++ b/tests/test_taskqueue.py
@@ -20,18 +20,20 @@
 
 
 class CountingWorker(Worker):
-    def dispatch_task(self, task):
+    def dispatch_task(self, task: Task) -> Task:
         assert isinstance(task, Task), task
         if not hasattr(self, "test_done"):
             self.test_done = 0
         self.test_done += 1
         self.test_task = task
+        return task
 
 
 class TaskQueueTest(TestCase):
     def test_task_queue(self):
         test_queue_name = "sls-queue-ingest"
         conn = get_fakeredis()
+        conn.flushdb()
         collection_id = 2
         task_id = "test-task"
         priority = randrange(1, settings.RABBITMQ_MAX_PRIORITY + 1)
@@ -110,6 +112,7 @@
     def test_task_that_shouldnt_execute(self, mock_should_execute):
         test_queue_name = "sls-queue-ingest"
         conn = get_fakeredis()
+        conn.flushdb()
         collection_id = 2
         task_id = "test-task"
         priority = randrange(1, settings.RABBITMQ_MAX_PRIORITY + 1)
@@ -142,6 +145,10 @@ def did_nack():
 
         dataset = Dataset(conn=conn, name=dataset_from_collection_id(collection_id))
         dataset.add_task(task_id, "test-op")
+        status = dataset.get_active_dataset_status(conn=conn)
+        stage = status["datasets"]["2"]["stages"][0]
+        assert stage["pending"] == 1
+        assert stage["running"] == 0
 
         worker = CountingWorker(queues=[test_queue_name], conn=conn, num_threads=1)
         assert not dataset.should_execute(task_id=task_id)
@@ -160,3 +167,8 @@ def did_nack():
         dispatch_fn.assert_not_called()
 
         channel.close()
+
+        status = dataset.get_active_dataset_status(conn=conn)
+        stage = status["datasets"]["2"]["stages"][0]
+        assert stage["pending"] == 1
+        assert stage["running"] == 0
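
Note: the strengthened test pins down the Redis bookkeeping around a NACK: the
task id has to end up back in the stage's "pending" set, with "running" empty.
Below is a standalone sketch of that transition against fakeredis; the key
names and the requeue_task helper are illustrative and do not reflect
servicelayer's actual key schema:

    import fakeredis

    def requeue_task(conn, stage_key, task_id):
        """Move a task id from "running" back to "pending" in one pipeline."""
        pipe = conn.pipeline()
        pipe.sadd(stage_key + ":pending", task_id)
        pipe.srem(stage_key + ":running", task_id)
        pipe.execute()

    conn = fakeredis.FakeStrictRedis()
    conn.sadd("stage:test-op:running", "test-task")  # task starts out running
    requeue_task(conn, "stage:test-op", "test-task")
    assert conn.scard("stage:test-op:pending") == 1
    assert conn.scard("stage:test-op:running") == 0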