From f23893b2734f0c15140a35ace30a84fd72f47fa3 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Thu, 27 Apr 2023 17:36:47 +0900 Subject: [PATCH] worker: move clean to after worker_main --- src/dvc_task/worker/temporary.py | 39 +++++++++++++++++++------------- tests/worker/test_temporary.py | 4 +--- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/src/dvc_task/worker/temporary.py b/src/dvc_task/worker/temporary.py index d70434f..28e4e77 100644 --- a/src/dvc_task/worker/temporary.py +++ b/src/dvc_task/worker/temporary.py @@ -3,7 +3,7 @@ import os import threading import time -from typing import Any, List, Mapping +from typing import Any, Dict, List, Mapping, Optional from celery import Celery from celery.utils.nodenames import default_nodename @@ -36,6 +36,15 @@ def __init__( # pylint: disable=too-many-arguments self.timeout = timeout self.config = kwargs + def ping(self, name: str, timeout: float = 1.0) -> Optional[List[Dict[str, Any]]]: + """Ping the specified worker.""" + return self._ping(destination=[default_nodename(name)], timeout=timeout) + + def _ping( + self, *, destination: Optional[List[str]] = None, timeout: float = 1.0 + ) -> Optional[List[Dict[str, Any]]]: + return self.app.control.ping(destination=destination, timeout=timeout) + def start(self, name: str, fsapp_clean: bool = False) -> None: """Start the worker if it does not already exist. @@ -50,12 +59,11 @@ def start(self, name: str, fsapp_clean: bool = False) -> None: # see https://github.com/celery/billiard/issues/247 os.environ["FORKED_BY_MULTIPROCESSING"] = "1" - if not self.app.control.ping(destination=[name]): + if not self.ping(name): monitor = threading.Thread( target=self.monitor, daemon=True, args=(name,), - kwargs={"fsapp_clean": fsapp_clean}, ) monitor.start() config = dict(self.config) @@ -63,6 +71,10 @@ def start(self, name: str, fsapp_clean: bool = False) -> None: argv = ["worker"] argv.extend(self._parse_config(config)) self.app.worker_main(argv=argv) + if fsapp_clean and isinstance(self.app, FSApp): # type: ignore[unreachable] + logger.info("cleaning up FSApp broker.") + self.app.clean() + logger.info("done") @staticmethod def _parse_config(config: Mapping[str, Any]) -> List[str]: @@ -85,13 +97,9 @@ def _parse_config(config: Mapping[str, Any]) -> List[str]: argv.append("-E") return argv - def monitor(self, name: str, fsapp_clean: bool = False) -> None: + def monitor(self, name: str) -> None: """Monitor the worker and stop it when the queue is empty.""" - logger.debug("monitor: waiting for worker to start") nodename = default_nodename(name) - while not self.app.control.ping(destination=[nodename]): - # wait for worker to start - time.sleep(1) def _tasksets(nodes): for taskset in ( @@ -105,17 +113,16 @@ def _tasksets(nodes): if isinstance(self.app, FSApp): yield from self.app.iter_queued() - logger.info("monitor: watching celery worker '%s'", nodename) - while self.app.control.ping(destination=[nodename]): + logger.debug("monitor: watching celery worker '%s'", nodename) + while True: time.sleep(self.timeout) nodes = self.app.control.inspect( # type: ignore[call-arg] - destination=[nodename] + destination=[nodename], + limit=1, ) if nodes is None or not any(tasks for tasks in _tasksets(nodes)): logger.info("monitor: shutting down due to empty queue.") - self.app.control.shutdown(destination=[nodename]) break - if fsapp_clean and isinstance(self.app, FSApp): - logger.info("monitor: cleanup FSApp broker.") - self.app.clean() - logger.info("monitor: done") + logger.debug("monitor: sending shutdown to '%s'.", nodename) + self.app.control.shutdown() + logger.debug("monitor: done") diff --git a/tests/worker/test_temporary.py b/tests/worker/test_temporary.py index 7391fa8..2e35e7b 100644 --- a/tests/worker/test_temporary.py +++ b/tests/worker/test_temporary.py @@ -24,9 +24,7 @@ def test_start(celery_app: Celery, mocker: MockerFixture): assert kwargs["pool"] == TaskPool assert kwargs["concurrency"] == 1 assert kwargs["prefetch_multiplier"] == 1 - thread.assert_called_once_with( - target=worker.monitor, daemon=True, args=(name,), kwargs={"fsapp_clean": False} - ) + thread.assert_called_once_with(target=worker.monitor, daemon=True, args=(name,)) @pytest.mark.flaky(