Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Nanny and Worker plugins #115

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions cads_broker/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
import distributed
import sqlalchemy as sa
import structlog
from dask.typing import Key
from typing_extensions import Iterable

try:
from cads_worker import worker
import cads_worker.worker
except ModuleNotFoundError:
pass

from cads_broker import Environment, config, factory
from cads_broker import Environment, config, factory, utils
from cads_broker import database as db
from cads_broker.qos import QoS

Expand Down Expand Up @@ -197,6 +198,34 @@ def __init__(self, number_of_workers) -> None:
parser.parse_rules(self.rules, self.environment)


class TempDirNannyPlugin(distributed.NannyPlugin):
def setup(self, nanny: distributed.Nanny) -> None:
path = utils.rm_task_path(nanny, None)
path.mkdir()

def teardown(self, nanny: distributed.Nanny) -> None:
utils.rm_task_path(nanny, None)


class TempDirsWorkerPlugin(distributed.WorkerPlugin):
def setup(self, worker) -> None:
self.worker = worker

def teardown(self, worker: distributed.Worker) -> None:
for key in worker.state.tasks:
utils.rm_task_path(worker, key)

def transition(
self,
key: Key,
start: distributed.worker_state_machine.TaskStateState,
finish: distributed.worker_state_machine.TaskStateState,
**kwargs: Any,
) -> None:
if finish in ("memory", "error"):
utils.rm_task_path(self.worker, key)


@attrs.define
class Broker:
client: distributed.Client
Expand All @@ -218,6 +247,10 @@ class Broker:
internal_scheduler: Scheduler = Scheduler()
queue: Queue = Queue()

def __attrs_post_init__(self):
self.client.register_plugin(TempDirNannyPlugin())
self.client.register_plugin(TempDirsWorkerPlugin())

@classmethod
def from_address(
cls,
Expand Down Expand Up @@ -563,7 +596,7 @@ def submit_request(
)
self.queue.pop(request.request_uid)
future = self.client.submit(
worker.submit_workflow,
cads_worker.worker.submit_workflow,
key=request.request_uid,
setup_code=request.request_body.get("setup_code", ""),
entry_point=request.entry_point,
Expand Down
35 changes: 35 additions & 0 deletions cads_broker/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pathlib
import shutil
from typing import Any

import distributed
from dask.typing import Key


def get_task_path(
worker_or_nanny: distributed.Worker | distributed.Nanny, key: Key | None
) -> pathlib.Path:
if isinstance(worker_or_nanny, distributed.Worker):
root = worker_or_nanny.local_directory
elif isinstance(worker_or_nanny, distributed.Nanny):
root = worker_or_nanny.worker_dir
else:
raise TypeError(
f"`worker_or_nanny` is of the wrong type: {type(worker_or_nanny)}"
)
path = pathlib.Path(root) / "tasks_working_dir"
if key is not None:
path /= str(key)
return path


def rm_task_path(
worker_or_nanny: distributed.Worker | distributed.Nanny,
key: Key | None,
**kwargs: Any,
) -> pathlib.Path:
# This function is used by cads-worker as well.
path = get_task_path(worker_or_nanny, key)
if path.exists():
shutil.rmtree(path, **kwargs)
return path
33 changes: 33 additions & 0 deletions tests/test_20_dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import pathlib
import uuid
from typing import Any

Expand Down Expand Up @@ -120,3 +121,35 @@ def mock_get_tasks() -> dict[str, str]:
# with pytest.raises(db.NoResultFound):
# with session_obj() as session:
# db.get_request(dismissed_request_uid, session=session)


def test_plugins(
mocker: pytest_mock.plugin.MockerFixture, session_obj: sa.orm.sessionmaker
) -> None:
environment = Environment.Environment()
qos = QoS.QoS(rules=Rule.RuleSet(), environment=environment, rules_hash="")
broker = dispatcher.Broker(
client=CLIENT,
environment=environment,
qos=qos,
address="scheduler-address",
session_maker_read=session_obj,
session_maker_write=session_obj,
)

def func() -> pathlib.Path:
worker = distributed.get_worker()
key = worker.get_current_task()
task_path = (
pathlib.Path(worker.local_directory) / "tasks_working_dir" / str(key)
)
task_path.mkdir()
return task_path

future = broker.client.submit(func)
task_path = future.result()
assert not task_path.exists()

assert task_path.parent.exists()
broker.client.shutdown()
assert not task_path.parent.exists()
Loading