Skip to content

Commit 23b51cb

Browse files
authored
Merge pull request #781 from MolSSI/fix_snowflake
Fix issues with snowflake shutdowns
2 parents bcefe2c + 43bc166 commit 23b51cb

File tree

1 file changed

+94
-20
lines changed

1 file changed

+94
-20
lines changed

qcfractal/qcfractal/snowflake.py

+94-20
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,34 @@ def _api_process(
3939
started_event: multiprocessing.Event,
4040
) -> None:
4141

42+
import signal
43+
4244
qh = logging.handlers.QueueHandler(logging_queue)
4345
logger = logging.getLogger()
4446
logger.handlers.clear()
4547
logger.addHandler(qh)
4648

49+
early_stop = False
50+
51+
def signal_handler(signum, frame):
52+
nonlocal early_stop
53+
early_stop = True
54+
55+
signal.signal(signal.SIGINT, signal_handler)
56+
signal.signal(signal.SIGTERM, signal_handler)
57+
4758
api = FractalGunicornApp(qcf_config, finished_queue, started_event)
48-
api.run()
59+
60+
if early_stop:
61+
logging_queue.close()
62+
logging_queue.join_thread()
63+
return
64+
65+
try:
66+
api.run()
67+
finally:
68+
logging_queue.close()
69+
logging_queue.join_thread()
4970

5071

5172
def _compute_process(compute_config: FractalComputeConfig, logging_queue: multiprocessing.Queue) -> None:
@@ -57,15 +78,32 @@ def _compute_process(compute_config: FractalComputeConfig, logging_queue: multip
5778
logger.handlers.clear()
5879
logger.addHandler(qh)
5980

81+
early_stop = False
82+
83+
def signal_handler(signum, frame):
84+
nonlocal early_stop
85+
early_stop = True
86+
87+
signal.signal(signal.SIGINT, signal_handler)
88+
signal.signal(signal.SIGTERM, signal_handler)
89+
6090
compute = ComputeManager(compute_config)
91+
if early_stop:
92+
logging_queue.close()
93+
logging_queue.join_thread()
94+
return
6195

6296
def signal_handler(signum, frame):
6397
compute.stop()
6498

6599
signal.signal(signal.SIGINT, signal_handler)
66100
signal.signal(signal.SIGTERM, signal_handler)
67101

68-
compute.start()
102+
try:
103+
compute.start()
104+
finally:
105+
logging_queue.close()
106+
logging_queue.join_thread()
69107

70108

71109
def _job_runner_process(
@@ -79,25 +117,46 @@ def _job_runner_process(
79117
logger.handlers.clear()
80118
logger.addHandler(qh)
81119

120+
early_stop = False
121+
122+
def signal_handler(signum, frame):
123+
nonlocal early_stop
124+
early_stop = True
125+
126+
signal.signal(signal.SIGINT, signal_handler)
127+
signal.signal(signal.SIGTERM, signal_handler)
128+
82129
job_runner = FractalJobRunner(qcf_config, finished_queue)
83130

131+
if early_stop:
132+
logging_queue.close()
133+
logging_queue.join_thread()
134+
return
135+
84136
def signal_handler(signum, frame):
85137
job_runner.stop()
86138

87139
signal.signal(signal.SIGINT, signal_handler)
88140
signal.signal(signal.SIGTERM, signal_handler)
89141

90-
job_runner.start()
142+
try:
143+
job_runner.start()
144+
finally:
145+
logging_queue.close()
146+
logging_queue.join_thread()
91147

92148

93-
def _logging_thread(logging_queue):
149+
def _logging_thread(logging_queue, logging_thread_stop):
94150
while True:
95-
record = logging_queue.get()
96-
if record is None:
97-
break
98-
logger = logging.getLogger(record.name)
99-
100-
logger.handle(record)
151+
try:
152+
record = logging_queue.get(timeout=0.5)
153+
logger = logging.getLogger(record.name)
154+
logger.handle(record)
155+
except Empty:
156+
if logging_thread_stop.is_set():
157+
break
158+
else:
159+
continue
101160

102161

103162
class FractalSnowflake:
@@ -128,7 +187,10 @@ def __init__(
128187
# See https://docs.python.org/3/howto/logging-cookbook.html#logging-to-a-single-file-from-multiple-processes
129188

130189
self._logging_queue = self._mp_context.Queue()
131-
self._logging_thread = threading.Thread(target=_logging_thread, args=(self._logging_queue,), daemon=True)
190+
self._logging_thread_stop = threading.Event()
191+
self._logging_thread = threading.Thread(
192+
target=_logging_thread, args=(self._logging_queue, self._logging_thread_stop), daemon=True
193+
)
132194
self._logging_thread.start()
133195

134196
# Create a temporary directory for everything
@@ -235,8 +297,9 @@ def _update_finalizer(self):
235297
self._compute_proc,
236298
self._api_proc,
237299
self._job_runner_proc,
238-
self._logging_thread,
239300
self._logging_queue,
301+
self._logging_thread,
302+
self._logging_thread_stop,
240303
)
241304

242305
def _start_api(self):
@@ -292,7 +355,7 @@ def _stop_job_runner(self):
292355
self._update_finalizer()
293356

294357
@classmethod
295-
def _stop(cls, compute_proc, api_proc, job_runner_proc, logging_thread, logging_queue):
358+
def _stop(cls, compute_proc, api_proc, job_runner_proc, logging_queue, logging_thread, logging_thread_stop):
296359
####################################################################################
297360
# This is written as a class method so that it can be called by a weakref finalizer
298361
####################################################################################
@@ -301,7 +364,6 @@ def _stop(cls, compute_proc, api_proc, job_runner_proc, logging_thread, logging_
301364
# First the compute, since it will communicate its demise to the api server
302365
# Flask must be last. It was started first and owns the db
303366

304-
# First, stop all, then join all for better performance
305367
if compute_proc is not None:
306368
compute_proc.terminate()
307369
compute_proc.join()
@@ -314,8 +376,10 @@ def _stop(cls, compute_proc, api_proc, job_runner_proc, logging_thread, logging_
314376
api_proc.terminate()
315377
api_proc.join()
316378

317-
logging_queue.put(None)
379+
logging_thread_stop.set()
318380
logging_thread.join()
381+
logging_queue.close()
382+
logging_queue.join_thread()
319383

320384
def wait_for_api(self):
321385
"""
@@ -363,12 +427,22 @@ def stop(self):
363427
Stops all components of the snowflake
364428
"""
365429

366-
if self._finalizer is not None:
367-
self._finalizer()
430+
if self._compute_proc is not None:
431+
self._compute_proc.terminate()
432+
self._compute_proc.join()
433+
self._compute_proc = None
368434

369-
self._api_proc = None
370-
self._compute_proc = None
371-
self._job_runner_proc = None
435+
if self._job_runner_proc is not None:
436+
self._job_runner_proc.terminate()
437+
self._job_runner_proc.join()
438+
self._job_runner_proc = None
439+
440+
if self._api_proc is not None:
441+
self._api_proc.terminate()
442+
self._api_proc.join()
443+
self._api_proc = None
444+
445+
self._update_finalizer()
372446

373447
def get_uri(self) -> str:
374448
"""

0 commit comments

Comments
 (0)