Skip to content

Commit

Permalink
worker: support prefetch multiplier and pool type config
Browse files Browse the repository at this point in the history
  • Loading branch information
pmrowla committed May 10, 2022
1 parent b2698c5 commit d4c50e5
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
12 changes: 12 additions & 0 deletions src/dvc_task/worker/temporary.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ def __init__( # pylint: disable=too-many-arguments
self,
app: Celery,
timeout: int = 60,
pool: Optional[str] = None,
concurrency: Optional[int] = None,
prefetch_multiplier: Optional[int] = None,
loglevel: Optional[str] = None,
task_events: bool = True,
):
Expand All @@ -28,13 +30,17 @@ def __init__( # pylint: disable=too-many-arguments
app: Celery application instance.
timeout: Queue timeout in seconds. Worker will be terminated if the
queue remains empty after timeout.
pool: Worker pool class.
concurrency: Worker concurrency.
prefetch_multiplier: Worker prefetch multiplier.
loglevel: Worker loglevel.
task_events: Enable worker task event monitoring.
"""
self.app = app
self.timeout = timeout
self.pool = pool
self.concurrency = concurrency
self.prefetch_multiplier = prefetch_multiplier
self.loglevel = loglevel or "info"
self.task_events = task_events

Expand All @@ -60,8 +66,14 @@ def start(self, name: str) -> None:
f"--loglevel={self.loglevel}",
f"--hostname={name}",
]
if self.pool:
argv.append(f"--pool={self.pool}")
if self.concurrency:
argv.append(f"--concurrency={self.concurrency}")
if self.prefetch_multiplier:
argv.append(
f"--prefetch-multiplier={self.prefetch_multiplier}"
)
if self.task_events:
argv.append("-E")
self.app.worker_main(argv=argv)
Expand Down
9 changes: 7 additions & 2 deletions tests/worker/test_temporary.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
from celery import Celery
from celery.concurrency.prefork import TaskPool
from celery.worker.worker import WorkController
from pytest_mock import MockerFixture

Expand All @@ -13,12 +14,16 @@ def test_start(celery_app: Celery, mocker: MockerFixture):
"""Should start underlying Celery worker."""
worker_cls = mocker.patch.object(celery_app, "Worker")
thread = mocker.patch("threading.Thread")
worker = TemporaryWorker(celery_app)
worker = TemporaryWorker(
celery_app, pool="prefork", concurrency=1, prefetch_multiplier=1
)
name = "worker1@localhost"
worker.start(name)
_args, kwargs = worker_cls.call_args
assert kwargs["hostname"] == name
assert kwargs["concurrency"] is None
assert kwargs["pool"] == TaskPool
assert kwargs["concurrency"] == 1
assert kwargs["prefetch_multiplier"] == 1
thread.assert_called_once_with(
target=worker.monitor, daemon=True, args=(name,)
)
Expand Down

0 comments on commit d4c50e5

Please sign in to comment.