From 50003707527c74a286f6573a0959ba303921013a Mon Sep 17 00:00:00 2001 From: Kamal Al Marhubi Date: Thu, 4 Apr 2024 16:18:33 -0700 Subject: [PATCH] feature: add middleware after_{consumer,worker}_thread_boot hooks This allows the middleware to run code in the context of a worker or consumer thread before it enters its run loop. This could be used to set up thread-local resources, or as in the author's case, to get a reference to the thread before it does any work. This was proposed on the mailing list [0] with Bogdan accepting the idea in principle [1]. [0]: https://groups.io/g/dramatiq-users/topic/105311701 [1]: https://groups.io/g/dramatiq-users/topic/105311701#258 --- dramatiq/middleware/middleware.py | 10 +++++++--- dramatiq/worker.py | 2 ++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/dramatiq/middleware/middleware.py b/dramatiq/middleware/middleware.py index 774addfc..5ace255c 100644 --- a/dramatiq/middleware/middleware.py +++ b/dramatiq/middleware/middleware.py @@ -138,18 +138,22 @@ def after_worker_shutdown(self, broker, worker): """Called after the worker process shuts down. """ + def after_consumer_thread_boot(self, broker, thread): + """Called from a consumer thread after it starts but before it starts its run loop. + """ + def before_consumer_thread_shutdown(self, broker, thread): """Called before a consumer thread shuts down. This may be used to clean up thread-local resources (such as Django database connections). + """ - There is no ``after_consumer_thread_boot``. + def after_worker_thread_boot(self, broker, thread): + """Called from a worker thread after it starts but before it starts its run loop. """ def before_worker_thread_shutdown(self, broker, thread): """Called before a worker thread shuts down. This may be used to clean up thread-local resources (such as Django database connections). - - There is no ``after_worker_thread_boot``. """ diff --git a/dramatiq/worker.py b/dramatiq/worker.py index 46e9f0c1..a3058420 100644 --- a/dramatiq/worker.py +++ b/dramatiq/worker.py @@ -247,6 +247,7 @@ def __init__(self, *, broker, queue_name, prefetch, work_queue, worker_timeout): def run(self): self.logger.debug("Running consumer thread...") self.running = True + self.broker.emit_after("consumer_thread_boot", self) while self.running: if self.paused: self.logger.debug("Consumer is paused. Sleeping for %.02fms...", self.worker_timeout) @@ -448,6 +449,7 @@ def __init__(self, *, broker, consumers, work_queue, worker_timeout): def run(self): self.logger.debug("Running worker thread...") self.running = True + self.broker.emit_after("worker_thread_boot", self) while self.running: if self.paused: self.logger.debug("Worker is paused. Sleeping for %.02f...", self.timeout)