Commit

Add async Dask distributed executor
tomwhite committed Jul 28, 2023
1 parent 4ca3bce commit 0f8a0b0
Showing 5 changed files with 274 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .coveragerc
@@ -5,7 +5,7 @@ omit =
    cubed/extensions/*
    cubed/runtime/executors/beam.py
    cubed/runtime/executors/coiled.py
-    cubed/runtime/executors/dask.py
+    cubed/runtime/executors/dask*.py
    cubed/runtime/executors/lithops.py
    cubed/runtime/executors/modal*.py
    cubed/vendor/*
154 changes: 154 additions & 0 deletions cubed/runtime/executors/dask_distributed_async.py
@@ -0,0 +1,154 @@
import asyncio
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)

from aiostream import stream
from aiostream.core import Stream
from dask.distributed import Client
from networkx import MultiDiGraph

from cubed.core.array import Callback, Spec
from cubed.core.plan import visit_node_generations, visit_nodes
from cubed.primitive.types import CubedPipeline
from cubed.runtime.executors.asyncio import async_map_unordered
from cubed.runtime.types import DagExecutor
from cubed.runtime.utils import execution_stats, gensym, handle_callbacks


# Note we can't call `pipeline_func` just `func` here as it clashes with `dask.distributed.Client.map`
@execution_stats
def run_func(input, pipeline_func=None, config=None, name=None):
    result = pipeline_func(input, config=config)
    return result


async def map_unordered(
    client: Client,
    map_function: Callable[..., Any],
    map_iterdata: Iterable[Union[List[Any], Tuple[Any, ...], Dict[str, Any]]],
    retries: int = 2,
    use_backups: bool = False,
    return_stats: bool = False,
    name: Optional[str] = None,
    **kwargs,
) -> AsyncIterator[Any]:
    def create_futures_func(input, **kwargs):
        input = list(input)  # dask expects a sequence (it calls `len` on it)
        key = name or gensym("map")
        key = key.replace("-", "_")  # otherwise array number is not shown on dashboard
        return [
            (i, asyncio.ensure_future(f))
            for i, f in zip(
                input,
                client.map(map_function, input, key=key, retries=retries, **kwargs),
            )
        ]

    def create_backup_futures_func(input, **kwargs):
        input = list(input)  # dask expects a sequence (it calls `len` on it)
        key = name or gensym("backup")
        key = key.replace("-", "_")  # otherwise array number is not shown on dashboard
        return [
            (i, asyncio.ensure_future(f))
            for i, f in zip(input, client.map(map_function, input, key=key, **kwargs))
        ]

    async for result in async_map_unordered(
        create_futures_func,
        map_iterdata,
        use_backups=use_backups,
        create_backup_futures_func=create_backup_futures_func,
        return_stats=return_stats,
        name=name,
        **kwargs,
    ):
        yield result


def pipeline_to_stream(
    client: Client, name: str, pipeline: CubedPipeline, **kwargs
) -> Stream:
    return stream.iterate(
        map_unordered(
            client,
            run_func,
            pipeline.mappable,
            return_stats=True,
            name=name,
            pipeline_func=pipeline.function,
            config=pipeline.config,
            **kwargs,
        )
    )


async def async_execute_dag(
    dag: MultiDiGraph,
    callbacks: Optional[Sequence[Callback]] = None,
    array_names: Optional[Sequence[str]] = None,
    resume: Optional[bool] = None,
    spec: Optional[Spec] = None,
    compute_arrays_in_parallel: Optional[bool] = None,
    compute_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> None:
    compute_kwargs = compute_kwargs or {}
    async with Client(asynchronous=True, **compute_kwargs) as client:
        if not compute_arrays_in_parallel:
            # run one pipeline at a time
            for name, node in visit_nodes(dag, resume=resume):
                st = pipeline_to_stream(client, name, node["pipeline"], **kwargs)
                async with st.stream() as streamer:
                    async for _, stats in streamer:
                        handle_callbacks(callbacks, stats)
        else:
            for gen in visit_node_generations(dag, resume=resume):
                # run pipelines in the same topological generation in parallel by merging their streams
                streams = [
                    pipeline_to_stream(client, name, node["pipeline"], **kwargs)
                    for name, node in gen
                ]
                merged_stream = stream.merge(*streams)
                async with merged_stream.stream() as streamer:
                    async for _, stats in streamer:
                        handle_callbacks(callbacks, stats)


class AsyncDaskDistributedExecutor(DagExecutor):
    """An execution engine that uses Dask Distributed's async API."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def execute_dag(
        self,
        dag: MultiDiGraph,
        callbacks: Optional[Sequence[Callback]] = None,
        array_names: Optional[Sequence[str]] = None,
        resume: Optional[bool] = None,
        spec: Optional[Spec] = None,
        compute_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        merged_kwargs = {**self.kwargs, **kwargs}
        asyncio.run(
            async_execute_dag(
                dag,
                callbacks=callbacks,
                array_names=array_names,
                resume=resume,
                spec=spec,
                compute_kwargs=compute_kwargs,
                **merged_kwargs,
            )
        )
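
A minimal usage sketch of the new executor (not part of this commit): it follows the pattern of cubed's other executors, and the array shape, chunking, and operation below are placeholder assumptions.

import cubed.array_api as xp
from cubed.runtime.executors.dask_distributed_async import AsyncDaskDistributedExecutor

# Sketch only: a small cubed computation run on the async Dask distributed executor.
a = xp.ones((1000, 1000), chunks=(250, 250))
b = xp.negative(a)

# The executor creates an asynchronous dask.distributed.Client internally
# (see async_execute_dag above), so no client has to be set up beforehand.
result = b.compute(executor=AsyncDaskDistributedExecutor())
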
100 changes: 100 additions & 0 deletions cubed/tests/runtime/test_dask_distributed_async.py
@@ -0,0 +1,100 @@
import asyncio
from functools import partial

import pytest

from cubed.tests.runtime.utils import check_invocation_counts, deterministic_failure

pytest.importorskip("dask.distributed")

from dask.distributed import Client

from cubed.runtime.executors.dask_distributed_async import map_unordered


async def run_test(function, input, retries, use_backups=False):
    outputs = set()
    async with Client(asynchronous=True) as client:
        async for output in map_unordered(
            client,
            function,
            input,
            retries=retries,
            use_backups=use_backups,
        ):
            outputs.add(output)
    return outputs


# fmt: off
@pytest.mark.parametrize(
    "timing_map, n_tasks, retries",
    [
        # no failures
        ({}, 3, 2),
        # first invocation fails
        ({0: [-1], 1: [-1], 2: [-1]}, 3, 2),
        # first two invocations fail
        ({0: [-1, -1], 1: [-1, -1], 2: [-1, -1]}, 3, 2),
        # first input sleeps once (not tested since timeout is not supported)
        # ({0: [20]}, 3, 2),
    ],
)
# fmt: on
def test_success(tmp_path, timing_map, n_tasks, retries):
    outputs = asyncio.run(
        run_test(
            function=partial(deterministic_failure, tmp_path, timing_map),
            input=range(n_tasks),
            retries=retries,
        )
    )

    assert outputs == set(range(n_tasks))

    check_invocation_counts(tmp_path, timing_map, n_tasks, retries)


# fmt: off
@pytest.mark.parametrize(
    "timing_map, n_tasks, retries",
    [
        # too many failures
        ({0: [-1], 1: [-1], 2: [-1, -1, -1]}, 3, 2),
    ],
)
# fmt: on
def test_failure(tmp_path, timing_map, n_tasks, retries):
    with pytest.raises(RuntimeError):
        asyncio.run(
            run_test(
                function=partial(deterministic_failure, tmp_path, timing_map),
                input=range(n_tasks),
                retries=retries,
            )
        )

    check_invocation_counts(tmp_path, timing_map, n_tasks, retries)


# fmt: off
@pytest.mark.parametrize(
    "timing_map, n_tasks, retries",
    [
        ({0: [60]}, 10, 2),
    ],
)
# fmt: on
def test_stragglers(tmp_path, timing_map, n_tasks, retries):
    outputs = asyncio.run(
        run_test(
            function=partial(deterministic_failure, tmp_path, timing_map),
            input=range(n_tasks),
            retries=retries,
            use_backups=True,
        )
    )

    assert outputs == set(range(n_tasks))

    check_invocation_counts(tmp_path, timing_map, n_tasks, retries)
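
As a reading aid for the parametrized cases above (an interpretation inferred from the case comments, not new test code): timing_map keys are task input indices, each list entry describes one successive invocation of that input, -1 means that invocation fails, and a positive number means it sleeps for that many seconds, producing a straggler that use_backups=True is meant to work around.

# Illustrative reading of the timing_map convention used above:
timing_map = {
    0: [-1, -1],  # input 0 fails on its first two invocations, then succeeds
    1: [60],      # input 1 sleeps for 60s on its first invocation: a straggler
                  # that a backup task should finish ahead of when use_backups=True
}
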
11 changes: 11 additions & 0 deletions cubed/tests/utils.py
@@ -40,6 +40,17 @@
except ImportError:
    pass

try:
    from cubed.runtime.executors.dask_distributed_async import (
        AsyncDaskDistributedExecutor,
    )

    ALL_EXECUTORS.append(AsyncDaskDistributedExecutor())

    MAIN_EXECUTORS.append(AsyncDaskDistributedExecutor())
except ImportError:
    pass

try:
    from cubed.runtime.executors.lithops import LithopsDagExecutor

8 changes: 8 additions & 0 deletions pyproject.toml
@@ -43,6 +43,7 @@ diagnostics = [
]
beam = ["apache-beam", "gcsfs"]
dask = ["dask"]
dask-distributed = ["distributed"]
lithops = ["lithops[aws] >= 2.7.0"]
modal = [
    "cubed[diagnostics]",
@@ -70,6 +71,13 @@ test-dask = [
    "pytest-cov",
    "pytest-mock",
]
test-dask-distributed = [
    "cubed[dask-distributed,diagnostics]",
    "dill",
    "pytest",
    "pytest-cov",
    "pytest-mock",
]
test-modal = [
    "cubed[modal]",
    "dill",
