Skip to content

Commit

Permalink
Support faster shutdown.
Browse files Browse the repository at this point in the history
  • Loading branch information
coleifer committed Dec 18, 2024
1 parent 327854c commit c3d48fb
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 32 deletions.
64 changes: 36 additions & 28 deletions huey/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def create_logger(self):
def initialize(self):
pass

def sleep_for_interval(self, start_ts, nseconds):
def sleep_for_interval(self, evt, start_ts, nseconds):
"""
Sleep for a given interval with respect to the start timestamp.
Expand All @@ -58,9 +58,9 @@ def sleep_for_interval(self, start_ts, nseconds):
# pre-empted by the kernel while logging.
sleep_time = nseconds - (time_clock() - start_ts)
if sleep_time > 0:
time.sleep(sleep_time)
evt.wait(sleep_time)

def loop(self, now=None):
def loop(self, evt, now=None):
"""
Process-specific implementation. Called repeatedly for as long as the
consumer is running. The `now` parameter is currently only used in the
Expand Down Expand Up @@ -107,7 +107,7 @@ def shutdown(self):
except Exception as exc:
self._logger.exception('shutdown hook "%s" failed', name)

def loop(self, now=None):
def loop(self, evt, now=None):
task = None
try:
task = self.huey.dequeue()
Expand All @@ -123,13 +123,13 @@ def loop(self, now=None):
self._logger.exception('Unhandled error during execution '
'of task %s.', task.id)
elif not self.huey.storage.blocking:
self.sleep()
self.sleep(evt)

def sleep(self):
def sleep(self, evt):
if self.delay > self.max_delay:
self.delay = self.max_delay

time.sleep(self.delay)
evt.wait(self.delay)
self.delay *= self.backoff


Expand All @@ -153,7 +153,7 @@ def __init__(self, huey, interval, periodic):
self._next_loop = time_clock()
self._next_periodic = time_clock()

def loop(self, now=None):
def loop(self, evt, now=None):
current = self._next_loop
self._next_loop += self.interval
if self._next_loop < time_clock():
Expand All @@ -173,7 +173,7 @@ def loop(self, now=None):
self._next_periodic += self.periodic_task_seconds
self.enqueue_periodic_tasks(now)

self.sleep_for_interval(current, self.interval)
self.sleep_for_interval(evt, current, self.interval)

def enqueue_periodic_tasks(self, now):
self._logger.debug('Checking periodic tasks')
Expand All @@ -186,7 +186,7 @@ class Environment(object):
"""
Provide a common interface to the supported concurrent environments.
"""
def get_stop_flag(self):
def create_event(self):
raise NotImplementedError

def create_process(self, runnable, name):
Expand All @@ -197,7 +197,7 @@ def is_alive(self, proc):


class ThreadEnvironment(Environment):
def get_stop_flag(self):
def create_event(self):
return threading.Event()

def create_process(self, runnable, name):
Expand All @@ -210,7 +210,7 @@ def is_alive(self, proc):


class GreenletEnvironment(Environment):
def get_stop_flag(self):
def create_event(self):
return GreenEvent()

def create_process(self, runnable, name):
Expand All @@ -225,7 +225,7 @@ def is_alive(self, proc):


class ProcessEnvironment(Environment):
def get_stop_flag(self):
def create_event(self):
return ProcessEvent()

def create_process(self, runnable, name):
Expand Down Expand Up @@ -298,7 +298,7 @@ def __init__(self, huey, workers=1, periodic=True, initial_delay=0.1,
self._received_signal = False
self._restart = False
self._graceful = True
self.stop_flag = self.environment.get_stop_flag()
self.stop_flag = self.environment.create_event()

# In the event the consumer was killed while running a task that held
# a lock, this ensures that all locks are flushed before starting.
Expand All @@ -308,18 +308,19 @@ def __init__(self, huey, workers=1, periodic=True, initial_delay=0.1,

# Create the scheduler process (but don't start it yet).
scheduler = self._create_scheduler()
self.scheduler = self._create_process(scheduler, 'Scheduler')
(self.scheduler,
self.scheduler_evt) = self._create_process(scheduler, 'Scheduler')

# Create the worker process(es) (also not started yet).
self.worker_threads = []
for i in range(workers):
worker = self._create_worker()
process = self._create_process(worker, 'Worker-%d' % (i + 1))
process, evt = self._create_process(worker, 'Worker-%d' % (i + 1))

# The worker threads are stored as [(worker impl, worker_t), ...].
# The worker impl is not currently referenced in any consumer code,
# but it is referenced in the test-suite.
self.worker_threads.append((worker, process))
self.worker_threads.append((worker, process, evt))

def flush_locks(self, *names):
self._logger.debug('Flushing locks before starting up.')
Expand Down Expand Up @@ -352,21 +353,24 @@ def _create_process(self, process, name):
Repeatedly call the `loop()` method of the given process. Unhandled
exceptions in the `loop()` method will cause the process to terminate.
"""
evt = self.environment.create_event()

def _run():
if self.worker_type == WORKER_PROCESS:
self._set_child_signal_handlers()

process.initialize()
try:
while not self.stop_flag.is_set():
process.loop()
process.loop(evt)
except KeyboardInterrupt:
pass
except:
self._logger.exception('Process %s died!', name)
finally:
process.shutdown()
return self.environment.create_process(_run, name)

return self.environment.create_process(_run, name), evt

def start(self):
"""
Expand Down Expand Up @@ -402,7 +406,7 @@ def start(self):

# Start the scheduler and workers.
self.scheduler.start()
for _, worker_process in self.worker_threads:
for _, worker_process, _ in self.worker_threads:
worker_process.start()

# Finally set the signal handlers for main process.
Expand All @@ -419,7 +423,9 @@ def stop(self, graceful=False):
if graceful:
self._logger.info('Shutting down gracefully...')
try:
for _, worker_process in self.worker_threads:
self.scheduler_evt.set()
for _, worker_process, worker_evt in self.worker_threads:
worker_evt.set()
worker_process.join()
self.scheduler.join()
except KeyboardInterrupt:
Expand Down Expand Up @@ -452,7 +458,7 @@ def run(self):
else:
self._logger.info('Consumer exiting.')

def loop(self, health_check_ts=None):
def loop(self, evt, health_check_ts=None):
try:
self.stop_flag.wait(timeout=self._stop_flag_timeout)
except KeyboardInterrupt:
Expand Down Expand Up @@ -485,14 +491,15 @@ def check_worker_health(self):
self._logger.debug('Checking worker health.')
workers = []
restart_occurred = False
for i, (worker, worker_t) in enumerate(self.worker_threads):
for i, (worker, worker_t, worker_e) in enumerate(self.worker_threads):
if not self.environment.is_alive(worker_t):
self._logger.warning('Worker %d died, restarting.', i + 1)
worker = self._create_worker()
worker_t = self._create_process(worker, 'Worker-%d' % (i + 1))
worker_t, worker_e = self._create_process(worker,
'Worker-%d' % (i + 1))
worker_t.start()
restart_occurred = True
workers.append((worker, worker_t))
workers.append((worker, worker_t, worker_e))

if restart_occurred:
self.worker_threads = workers
Expand All @@ -501,8 +508,9 @@ def check_worker_health(self):

if not self.environment.is_alive(self.scheduler):
self._logger.warning('Scheduler died, restarting.')
scheduler = self._create_scheduler()
self.scheduler = self._create_process(scheduler, 'Scheduler')
scheduler, scheduler_evt = self._create_scheduler()
(self.scheduler,
self.scheduler_evt) = self._create_process(scheduler, 'Scheduler')
self.scheduler.start()
else:
self._logger.debug('Scheduler is up and running.')
Expand Down Expand Up @@ -539,7 +547,7 @@ def _handle_stop_signal(self, sig_num, frame):
self._graceful = False
if self.worker_type == WORKER_GREENLET:
def kill_workers():
gevent.killall([t for _, t in self.worker_threads],
gevent.killall([t for _, t, _ in self.worker_threads],
KeyboardInterrupt)
gevent.spawn(kill_workers)

Expand Down
9 changes: 5 additions & 4 deletions huey/tests/test_consumer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import threading
import time

from huey.api import crontab
Expand All @@ -11,7 +12,7 @@

class TestConsumer(Consumer):
class _Scheduler(Scheduler):
def sleep_for_interval(self, current, interval):
def sleep_for_interval(self, evt, current, interval):
pass
scheduler_class = _Scheduler

Expand All @@ -29,16 +30,16 @@ def task_a(n):
self.assertEqual(result.get(blocking=True, timeout=2), 2)

def work_on_tasks(self, consumer, n=1, now=None):
worker, _ = consumer.worker_threads[0]
worker, _, evt = consumer.worker_threads[0]
for i in range(n):
self.assertEqual(len(self.huey), n - i)
worker.loop(now)
worker.loop(evt, now=now)

def schedule_tasks(self, consumer, now=None):
scheduler = consumer._create_scheduler()
scheduler._next_loop = time_clock() + 60
scheduler._next_periodic = time_clock() - 60
scheduler.loop(now)
scheduler.loop(threading.Event(), now=now)

def test_consumer_schedule_task(self):
@self.huey.task()
Expand Down
1 change: 1 addition & 0 deletions huey/tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import shutil
import threading
import time
import unittest
try:
from queue import Queue
Expand Down

0 comments on commit c3d48fb

Please sign in to comment.