Skip to content

Commit

Permalink
basic dask executor
Browse files Browse the repository at this point in the history
  • Loading branch information
TomNicholas committed Jun 28, 2023
1 parent 0110a40 commit 1cb0ba2
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions cubed/runtime/executors/dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import dask

from cubed.core.array import TaskEndEvent
from cubed.core.plan import visit_nodes
from cubed.runtime.types import DagExecutor


def exec_stage_func(func, *args, **kwargs):
# TODO would be good to give the dask tasks useful names
return dask.delayed(func(*args, **kwargs)) # should we add pure=True?


class DaskDelayedExecutor(DagExecutor):
"""Executes each stage using dask.Delayed functions."""

def execute_dag(self, dag, callbacks=None, array_names=None, resume=None, **compute_kwargs):
# Note this currently only builds the task graph for each stage once it gets to that stage in computation

for name, node in visit_nodes(dag, resume=resume):
pipeline = node["pipeline"]
for stage in pipeline.stages:
if stage.mappable is not None:
stage_delayed_funcs = []
for m in stage.mappable:
delayed_func = exec_stage_func(stage.function, m, config=pipeline.config)
stage_delayed_funcs.append(delayed_func)
if callbacks is not None:
event = TaskEndEvent(array_name=name)
[callback.on_task_end(event) for callback in callbacks]
else:
delayed_func = exec_stage_func(stage.function, config=pipeline.config)
stage_delayed_funcs = [delayed_func]
if callbacks is not None:
event = TaskEndEvent(array_name=name)
[callback.on_task_end(event) for callback in callbacks]

dask.persist(stage_delayed_funcs, **compute_kwargs)

0 comments on commit 1cb0ba2

Please sign in to comment.