Enable Prometheus metrics on the RabbitMQ Worker (#181)
* Enable Prometheus metrics on the RabbitMQ Worker

* Extract Prometheus metrics to a separate module

* Fix linter errors, fix missing import

* Fix import

* Add comment
catileptic authored Jun 5, 2024
1 parent 8c8e531 commit d6d6129
Showing 3 changed files with 126 additions and 58 deletions.
56 changes: 56 additions & 0 deletions servicelayer/metrics.py
@@ -0,0 +1,56 @@
from prometheus_client import (
    Counter,
    Histogram,
    REGISTRY,
    GC_COLLECTOR,
    PROCESS_COLLECTOR,
)

# These definitions should be moved as close to the place
# where they are used as possible. However, since we
# support both a homebrewed Worker and one based on
# RabbitMQ, these definitions would come into conflict.

REGISTRY.unregister(GC_COLLECTOR)
REGISTRY.unregister(PROCESS_COLLECTOR)

TASKS_STARTED = Counter(
    "servicelayer_tasks_started_total",
    "Number of tasks that a worker started processing",
    ["stage"],
)

TASKS_SUCCEEDED = Counter(
    "servicelayer_tasks_succeeded_total",
    "Number of successfully processed tasks",
    ["stage", "retries"],
)

TASKS_FAILED = Counter(
    "servicelayer_tasks_failed_total",
    "Number of failed tasks",
    ["stage", "retries", "failed_permanently"],
)

TASK_DURATION = Histogram(
    "servicelayer_task_duration_seconds",
    "Task duration in seconds",
    ["stage"],
    # The bucket sizes are a rough guess right now, we might want to adjust
    # them later based on observed runtimes
    buckets=[
        0.25,
        0.5,
        1,
        5,
        15,
        30,
        60,
        60 * 15,
        60 * 30,
        60 * 60,
        60 * 60 * 2,
        60 * 60 * 6,
        60 * 60 * 24,
    ],
)
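Editor's note: the snippet below is not part of the commit. It is a minimal sketch of how these module-level metric objects are meant to be used from worker code, and how the resulting samples can be checked with the Prometheus client's registry helper. The "ingest" stage label and the retry count of 0 are arbitrary example values; running it assumes servicelayer is installed.

from timeit import default_timer

from prometheus_client import REGISTRY

from servicelayer import metrics

stage = "ingest"  # example stage label, not a value mandated by servicelayer

metrics.TASKS_STARTED.labels(stage=stage).inc()
start_time = default_timer()
# ... the actual task would run here ...
duration = max(0, default_timer() - start_time)
metrics.TASK_DURATION.labels(stage=stage).observe(duration)
metrics.TASKS_SUCCEEDED.labels(stage=stage, retries=0).inc()

# Counters declared with a *_total name are exported under that same name;
# get_sample_value() is the registry's helper for reading a single sample.
started = REGISTRY.get_sample_value(
    "servicelayer_tasks_started_total", {"stage": stage}
)
assert started == 1.0

Because the metrics live at module level, every importer shares the same collectors, which is why the commit moves them out of worker.py instead of redefining them in taskqueue.py.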
61 changes: 60 additions & 1 deletion servicelayer/taskqueue.py
@@ -11,8 +11,11 @@
from queue import Queue, Empty
import platform
from collections import defaultdict
from threading import Thread
from timeit import default_timer

import pika.spec
from prometheus_client import start_http_server

from structlog.contextvars import clear_contextvars, bind_contextvars
import pika
@@ -22,6 +25,7 @@
from servicelayer.util import pack_now, unpack_int
from servicelayer import settings
from servicelayer.util import service_retries, backoff
from servicelayer import metrics

log = logging.getLogger(__name__)
local = threading.local()
@@ -385,6 +389,17 @@ def __init__(
        version=None,
        prefetch_count_mapping=defaultdict(lambda: 1),
    ):
        if settings.SENTRY_DSN:
            import sentry_sdk

            sentry_sdk.init(
                dsn=settings.SENTRY_DSN,
                traces_sample_rate=0,
                release=settings.SENTRY_RELEASE,
                environment=settings.SENTRY_ENVIRONMENT,
                send_default_pii=False,
            )

        self.conn = conn or get_redis()
        self.num_threads = num_threads
        self.queues = ensure_list(queues)
@@ -402,6 +417,19 @@ def __init__(
                send_default_pii=False,
            )

    def run_prometheus_server(self):
        if not settings.PROMETHEUS_ENABLED:
            return

        def run_server():
            port = settings.PROMETHEUS_PORT
            log.info(f"Running Prometheus metrics server on port {port}")
            start_http_server(port)

        thread = Thread(target=run_server)
        thread.start()
        thread.join()

    def on_signal(self, signal, _):
        log.warning(f"Shutting down worker (signal {signal})")
        # Exit eagerly without waiting for current task to finish running
@@ -463,23 +491,52 @@ def handle(self, task: Task, channel):
                conn=self.conn, name=dataset_from_collection_id(task.collection_id)
            )
            if dataset.should_execute(task.task_id):
                if task.get_retry_count(self.conn) > settings.WORKER_RETRY:
                task_retry_count = task.get_retry_count(self.conn)
                if task_retry_count:
                    metrics.TASKS_FAILED.labels(
                        stage=task.operation,
                        retries=task_retry_count,
                        failed_permanently=False,
                    ).inc()

                if task_retry_count > settings.WORKER_RETRY:
                    raise MaxRetriesExceededError(
                        f"Max retries reached for task {task.task_id}. Aborting."
                    )

                dataset.checkout_task(task.task_id, task.operation)
                task.increment_retry_count(self.conn)

                # Emit Prometheus metrics
                metrics.TASKS_STARTED.labels(stage=task.operation).inc()
                start_time = default_timer()
                log.info(
                    f"Dispatching task {task.task_id} from job {task.job_id}"
                    f" to worker {platform.node()}"
                )

                task = self.dispatch_task(task)

                # Emit Prometheus metrics
                duration = max(0, default_timer() - start_time)
                metrics.TASK_DURATION.labels(stage=task.operation).observe(duration)
                metrics.TASKS_SUCCEEDED.labels(
                    stage=task.operation, retries=task_retry_count
                ).inc()
            else:
                log.info(
                    f"Sending a NACK for message {task.delivery_tag}"
                    f" for task_id {task.task_id}."
                    f" Message will be requeued."
                )
                # In this case, a task ID was found neither in the
                # list of Pending, nor the list of Running tasks
                # in Redis. It was never attempted.
                metrics.TASKS_FAILED.labels(
                    stage=task.operation,
                    retries=0,
                    failed_permanently=True,
                ).inc()
                if channel.is_open:
                    channel.basic_nack(task.delivery_tag)
        except Exception:
@@ -530,6 +587,8 @@ def run(self):
        signal.signal(signal.SIGINT, self.on_signal)
        signal.signal(signal.SIGTERM, self.on_signal)

        self.run_prometheus_server()

        # worker threads
        def process():
            return self.process(blocking=True)
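Editor's note, not part of the diff: run_prometheus_server() relies on the fact that prometheus_client's start_http_server() returns immediately and serves the default registry from a background daemon thread, so the worker's own threads stay free for task processing. Below is a minimal, self-contained sketch of that behaviour; the port 9100 is an arbitrary example standing in for settings.PROMETHEUS_PORT, and the counter name is made up so the sketch does not depend on servicelayer.

from urllib.request import urlopen

from prometheus_client import Counter, start_http_server

EXAMPLE_TASKS_STARTED = Counter(
    "example_tasks_started_total",
    "Number of tasks that a worker started processing",
    ["stage"],
)

if __name__ == "__main__":
    start_http_server(9100)  # returns at once; serving happens in a daemon thread
    EXAMPLE_TASKS_STARTED.labels(stage="ingest").inc()

    # A Prometheus server would normally perform this scrape; here we only
    # check that the exposition endpoint is up and includes our counter.
    body = urlopen("http://localhost:9100/metrics").read().decode("utf-8")
    assert "example_tasks_started_total" in body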
67 changes: 10 additions & 57 deletions servicelayer/worker.py
@@ -6,19 +6,14 @@
from banal import ensure_list
from abc import ABC, abstractmethod

from prometheus_client import (
    start_http_server,
    Counter,
    Histogram,
    REGISTRY,
    GC_COLLECTOR,
    PROCESS_COLLECTOR,
)
from prometheus_client import start_http_server

from servicelayer import settings
from servicelayer.jobs import Stage
from servicelayer.cache import get_redis
from servicelayer.util import unpack_int
from servicelayer import metrics


log = logging.getLogger(__name__)

@@ -29,50 +24,6 @@
INTERVAL = 2
TASK_FETCH_RETRY = 60 / INTERVAL

REGISTRY.unregister(GC_COLLECTOR)
REGISTRY.unregister(PROCESS_COLLECTOR)

TASKS_STARTED = Counter(
    "servicelayer_tasks_started_total",
    "Number of tasks that a worker started processing",
    ["stage"],
)

TASKS_SUCCEEDED = Counter(
    "servicelayer_tasks_succeeded_total",
    "Number of successfully processed tasks",
    ["stage", "retries"],
)

TASKS_FAILED = Counter(
    "servicelayer_tasks_failed_total",
    "Number of failed tasks",
    ["stage", "retries", "failed_permanently"],
)

TASK_DURATION = Histogram(
    "servicelayer_task_duration_seconds",
    "Task duration in seconds",
    ["stage"],
    # The bucket sizes are a rough guess right now, we might want to adjust
    # them later based on observed runtimes
    buckets=[
        0.25,
        0.5,
        1,
        5,
        15,
        30,
        60,
        60 * 15,
        60 * 30,
        60 * 60,
        60 * 60 * 2,
        60 * 60 * 6,
        60 * 60 * 24,
    ],
)


class Worker(ABC):
"""Workers of all microservices, unite!"""
@@ -108,12 +59,14 @@ def handle_safe(self, task):
        retries = unpack_int(task.context.get("retries"))

        try:
            TASKS_STARTED.labels(stage=task.stage.stage).inc()
            metrics.TASKS_STARTED.labels(stage=task.stage.stage).inc()
            start_time = default_timer()
            self.handle(task)
            duration = max(0, default_timer() - start_time)
            TASK_DURATION.labels(stage=task.stage.stage).observe(duration)
            TASKS_SUCCEEDED.labels(stage=task.stage.stage, retries=retries).inc()
            metrics.TASK_DURATION.labels(stage=task.stage.stage).observe(duration)
            metrics.TASKS_SUCCEEDED.labels(
                stage=task.stage.stage, retries=retries
            ).inc()
        except SystemExit as exc:
            self.exit_code = exc.code
            self.retry(task)
@@ -153,7 +106,7 @@ def retry(self, task):
            log.warning(
                f"Queueing failed task for retry #{retry_count}/{settings.WORKER_RETRY}..."  # noqa
            )
            TASKS_FAILED.labels(
            metrics.TASKS_FAILED.labels(
                stage=task.stage.stage,
                retries=retries,
                failed_permanently=False,
@@ -164,7 +117,7 @@ def retry(self, task):
            log.warning(
                f"Failed task, exhausted retry count of {settings.WORKER_RETRY}"
            )
            TASKS_FAILED.labels(
            metrics.TASKS_FAILED.labels(
                stage=task.stage.stage,
                retries=retries,
                failed_permanently=True,
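One final editorial note on why the metric definitions were extracted into servicelayer/metrics.py rather than duplicated: prometheus_client keeps a single default registry, and registering a second collector under an already-used name raises a ValueError, which is what would happen if worker.py and taskqueue.py each kept their own copies of these counters. A short sketch, not from the commit, of that failure mode:

from prometheus_client import Counter

Counter("servicelayer_tasks_started_total", "first definition", ["stage"])
try:
    # Registering the same metric name twice in one process is rejected.
    Counter("servicelayer_tasks_started_total", "second definition", ["stage"])
except ValueError as exc:
    print(f"duplicate registration rejected: {exc}")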
