fix nack related tests
stchris committed Jun 11, 2024
1 parent 1ea8ac3 commit 1b4381f
Showing 2 changed files with 26 additions and 25 deletions.
servicelayer/taskqueue.py (27 changes: 6 additions & 21 deletions)
@@ -376,10 +376,6 @@ def declare_rabbitmq_queue(channel, queue, prefetch_count=1):
     )
 
 
-class MaxRetriesExceededError(Exception):
-    pass
-
-
 class Worker(ABC):
     def __init__(
         self,
@@ -494,14 +490,9 @@ def handle(self, task: Task, channel):
             dataset = Dataset(
                 conn=self.conn, name=dataset_from_collection_id(task.collection_id)
             )
-            # The worker will attempt to complete a task a number of times
-            # defined by WORKER_RETRY.
-            task_retry_count = task.get_retry_count(self.conn)
-            # The worker will only attempt to run a task of the task_id
-            # exists in Redis in the "running" or "pending" lists.
             should_execute = dataset.should_execute(task.task_id)
-            if should_execute or task_retry_count <= settings.WORKER_RETRY:
-                # Increase the num. of tasks that have failed but will be reattempted
+            task_retry_count = task.get_retry_count(self.conn)
+            if should_execute and task_retry_count <= settings.WORKER_RETRY:
                 if task_retry_count:
                     metrics.TASKS_FAILED.labels(
                         stage=task.operation,
@@ -512,7 +503,6 @@ def handle(self, task: Task, channel):
                 dataset.checkout_task(task.task_id, task.operation)
                 task.increment_retry_count(self.conn)
 
-                # Emit Prometheus metrics
                 metrics.TASKS_STARTED.labels(stage=task.operation).inc()
                 start_time = default_timer()
                 log.info(
@@ -522,7 +512,6 @@ def handle(self, task: Task, channel):
 
                 task = self.dispatch_task(task)
 
-                # Emit Prometheus metrics
                 duration = max(0, default_timer() - start_time)
                 metrics.TASK_DURATION.labels(stage=task.operation).observe(duration)
                 metrics.TASKS_SUCCEEDED.labels(
@@ -531,13 +520,10 @@ def handle(self, task: Task, channel):
             else:
                 reason = ""
                 if not should_execute:
-                    reason = (
-                        f"Task {task.task_id} was not in found"
-                        f"neither the Running nor Pending lists"
-                    )
-                else:
+                    reason = f"Task {task.task_id} neither in 'running' nor 'pending'"
+                elif task_retry_count > settings.WORKER_RETRY:
                     reason = (
-                        f"Task {task.task_id} was"
+                        f"Task {task.task_id} was "
                         f"attempted more than {settings.WORKER_RETRY} times"
                     )
                 # This task can be tracked as having failed permanently.
@@ -552,8 +538,7 @@ def handle(self, task: Task, channel):
                         f"{reason}. Message will be requeued."
                     )
 
-                if channel.is_open:
-                    channel.basic_nack(task.delivery_tag)
+                channel.basic_nack(task.delivery_tag)
         except Exception:
             log.exception("Error in task handling")
         finally:
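Taken together, these hunks leave handle() with a single guard: a task runs only when it is still tracked in Redis and within the retry budget; anything else is nacked back to RabbitMQ unconditionally. A condensed sketch of the resulting control flow (metrics and logging elided; not the verbatim method):

    def handle(self, task, channel):
        try:
            dataset = Dataset(
                conn=self.conn, name=dataset_from_collection_id(task.collection_id)
            )
            should_execute = dataset.should_execute(task.task_id)
            task_retry_count = task.get_retry_count(self.conn)
            if should_execute and task_retry_count <= settings.WORKER_RETRY:
                dataset.checkout_task(task.task_id, task.operation)
                task.increment_retry_count(self.conn)
                task = self.dispatch_task(task)
            else:
                # Hand the message back to RabbitMQ; pika's basic_nack
                # defaults to multiple=False, requeue=True, which is what
                # the updated tests assert below.
                channel.basic_nack(task.delivery_tag)
        except Exception:
            log.exception("Error in task handling")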
tests/test_taskqueue.py (24 changes: 20 additions & 4 deletions)
@@ -27,6 +27,7 @@ def dispatch_task(self, task):
             self.test_done = 0
         self.test_done += 1
         self.test_task = task
+        return task
 
 
 class TaskQueueTest(TestCase):
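The new return value matters because Worker.handle() reuses it: after dispatch, the task is reassigned from the return value and then read again for metrics (see the taskqueue.py hunks above). A stub that implicitly returns None would break the lines that follow the dispatch:

    task = self.dispatch_task(task)
    duration = max(0, default_timer() - start_time)
    # With task = None, this raises AttributeError on task.operation:
    metrics.TASK_DURATION.labels(stage=task.operation).observe(duration)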
@@ -91,9 +92,13 @@ def test_task_queue(self):
         dataset = Dataset(conn=conn, name=dataset_from_collection_id(collection_id))
         dataset.add_task(task_id, "test-op")
         channel.close()
-        with self.assertLogs(level="ERROR") as ctx:
+        with patch.object(
+            pika.channel.Channel,
+            attribute="basic_nack",
+            return_value=None,
+        ) as nack_fn:
             worker.process(blocking=False)
-        assert "Max retries reached for task test-task. Aborting." in ctx.output[0]
+        nack_fn.assert_any_call(delivery_tag=1, multiple=False, requeue=True)
         # Assert that retry count stays the same
         assert task.get_retry_count(conn) == 1
 
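Note that assert_any_call passes only if some recorded call matches these arguments exactly, keyword for keyword; an unspecced mock does not fill in default values. A minimal standalone illustration (hypothetical values):

    from unittest.mock import MagicMock

    nack = MagicMock()
    nack(delivery_tag=1, multiple=False, requeue=True)

    nack.assert_any_call(delivery_tag=1, multiple=False, requeue=True)  # passes
    nack.assert_any_call(delivery_tag=1)  # raises AssertionError: no exact match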
@@ -107,7 +112,6 @@ def test_task_queue(self):
         end_time = unpack_datetime(status["end_time"])
         assert started < end_time < last_updated
 
-    @skip("Unfinished")
     @patch("servicelayer.taskqueue.Dataset.should_execute")
     def test_task_that_shouldnt_execute(self, mock_should_execute):
         test_queue_name = "sls-queue-ingest"
@@ -147,6 +151,18 @@ def did_nack():
 
         worker = CountingWorker(queues=[test_queue_name], conn=conn, num_threads=1)
         assert not dataset.should_execute(task_id=task_id)
-        worker.process(blocking=False)
+        with patch.object(
+            worker,
+            attribute="dispatch_task",
+            return_value=None,
+        ) as dispatch_fn:
+            with patch.object(
+                pika.channel.Channel,
+                attribute="basic_nack",
+                return_value=None,
+            ) as nack_fn:
+                worker.process(blocking=False)
+            nack_fn.assert_any_call(delivery_tag=1, multiple=False, requeue=True)
+        dispatch_fn.assert_not_called()
 
         channel.close()
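The nesting gives the test two independent probes without running a real handler: the outer stub proves the task was never dispatched, the inner one proves the message went back to RabbitMQ. The same patch.object pattern in isolation (hypothetical class and names):

    from unittest.mock import patch

    class Stubbed:
        def run(self):
            raise RuntimeError("must not execute in this test")

    obj = Stubbed()
    with patch.object(obj, attribute="run", return_value=None) as run_fn:
        pass  # exercise code that is expected *not* to call obj.run()
    run_fn.assert_not_called()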
