diff --git a/.coveragerc b/.coveragerc index 04163484..f7743db3 100644 --- a/.coveragerc +++ b/.coveragerc @@ -5,7 +5,7 @@ omit = cubed/extensions/* cubed/runtime/executors/beam.py cubed/runtime/executors/coiled.py - cubed/runtime/executors/dask.py + cubed/runtime/executors/dask*.py cubed/runtime/executors/lithops.py cubed/runtime/executors/modal*.py cubed/vendor/* diff --git a/cubed/runtime/executors/dask_distributed_async.py b/cubed/runtime/executors/dask_distributed_async.py new file mode 100644 index 00000000..7de39827 --- /dev/null +++ b/cubed/runtime/executors/dask_distributed_async.py @@ -0,0 +1,154 @@ +import asyncio +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + Union, +) + +from aiostream import stream +from aiostream.core import Stream +from dask.distributed import Client +from networkx import MultiDiGraph + +from cubed.core.array import Callback, Spec +from cubed.core.plan import visit_node_generations, visit_nodes +from cubed.primitive.types import CubedPipeline +from cubed.runtime.executors.asyncio import async_map_unordered +from cubed.runtime.types import DagExecutor +from cubed.runtime.utils import execution_stats, gensym, handle_callbacks + + +# note we can't call `pipeline_func` just `func` here as it clashes with `dask.distributed.Client.map`` +@execution_stats +def run_func(input, pipeline_func=None, config=None, name=None): + result = pipeline_func(input, config=config) + return result + + +async def map_unordered( + client: Client, + map_function: Callable[..., Any], + map_iterdata: Iterable[Union[List[Any], Tuple[Any, ...], Dict[str, Any]]], + retries: int = 2, + use_backups: bool = False, + return_stats: bool = False, + name: Optional[str] = None, + **kwargs, +) -> AsyncIterator[Any]: + def create_futures_func(input, **kwargs): + input = list(input) # dask expects a sequence (it calls `len` on it) + key = name or gensym("map") + key = key.replace("-", "_") # otherwise array number is not shown on dashboard + return [ + (i, asyncio.ensure_future(f)) + for i, f in zip( + input, + client.map(map_function, input, key=key, retries=retries, **kwargs), + ) + ] + + def create_backup_futures_func(input, **kwargs): + input = list(input) # dask expects a sequence (it calls `len` on it) + key = name or gensym("backup") + key = key.replace("-", "_") # otherwise array number is not shown on dashboard + return [ + (i, asyncio.ensure_future(f)) + for i, f in zip(input, client.map(map_function, input, key=key, **kwargs)) + ] + + async for result in async_map_unordered( + create_futures_func, + map_iterdata, + use_backups=use_backups, + create_backup_futures_func=create_backup_futures_func, + return_stats=return_stats, + name=name, + **kwargs, + ): + yield result + + +def pipeline_to_stream( + client: Client, name: str, pipeline: CubedPipeline, **kwargs +) -> Stream: + return stream.iterate( + map_unordered( + client, + run_func, + pipeline.mappable, + return_stats=True, + name=name, + pipeline_func=pipeline.function, + config=pipeline.config, + **kwargs, + ) + ) + + +async def async_execute_dag( + dag: MultiDiGraph, + callbacks: Optional[Sequence[Callback]] = None, + array_names: Optional[Sequence[str]] = None, + resume: Optional[bool] = None, + spec: Optional[Spec] = None, + compute_arrays_in_parallel: Optional[bool] = None, + compute_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, +) -> None: + compute_kwargs = compute_kwargs or {} + async with Client(asynchronous=True, **compute_kwargs) as client: + if not compute_arrays_in_parallel: + # run one pipeline at a time + for name, node in visit_nodes(dag, resume=resume): + st = pipeline_to_stream(client, name, node["pipeline"], **kwargs) + async with st.stream() as streamer: + async for _, stats in streamer: + handle_callbacks(callbacks, stats) + else: + for gen in visit_node_generations(dag, resume=resume): + # run pipelines in the same topological generation in parallel by merging their streams + streams = [ + pipeline_to_stream(client, name, node["pipeline"], **kwargs) + for name, node in gen + ] + merged_stream = stream.merge(*streams) + async with merged_stream.stream() as streamer: + async for _, stats in streamer: + handle_callbacks(callbacks, stats) + + +class AsyncDaskDistributedExecutor(DagExecutor): + """An execution engine that uses Dask Distributed's async API.""" + + def __init__(self, **kwargs): + self.kwargs = kwargs + + def execute_dag( + self, + dag: MultiDiGraph, + callbacks: Optional[Sequence[Callback]] = None, + array_names: Optional[Sequence[str]] = None, + resume: Optional[bool] = None, + spec: Optional[Spec] = None, + compute_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> None: + merged_kwargs = {**self.kwargs, **kwargs} + asyncio.run( + async_execute_dag( + dag, + callbacks=callbacks, + array_names=array_names, + resume=resume, + spec=spec, + compute_kwargs=compute_kwargs, + **merged_kwargs, + ) + ) diff --git a/cubed/tests/runtime/test_dask_distributed_async.py b/cubed/tests/runtime/test_dask_distributed_async.py new file mode 100644 index 00000000..74baee84 --- /dev/null +++ b/cubed/tests/runtime/test_dask_distributed_async.py @@ -0,0 +1,100 @@ +import asyncio +from functools import partial + +import pytest + +from cubed.tests.runtime.utils import check_invocation_counts, deterministic_failure + +pytest.importorskip("dask.distributed") + +from dask.distributed import Client + +from cubed.runtime.executors.dask_distributed_async import map_unordered + + +async def run_test(function, input, retries, use_backups=False): + outputs = set() + async with Client(asynchronous=True) as client: + async for output in map_unordered( + client, + function, + input, + retries=retries, + use_backups=use_backups, + ): + outputs.add(output) + return outputs + + +# fmt: off +@pytest.mark.parametrize( + "timing_map, n_tasks, retries", + [ + # no failures + ({}, 3, 2), + # first invocation fails + ({0: [-1], 1: [-1], 2: [-1]}, 3, 2), + # first two invocations fail + ({0: [-1, -1], 1: [-1, -1], 2: [-1, -1]}, 3, 2), + # first input sleeps once (not tested since timeout is not supported) + # ({0: [20]}, 3, 2), + ], +) +# fmt: on +def test_success(tmp_path, timing_map, n_tasks, retries): + outputs = asyncio.run( + run_test( + function=partial(deterministic_failure, tmp_path, timing_map), + input=range(n_tasks), + retries=retries, + ) + ) + + assert outputs == set(range(n_tasks)) + + check_invocation_counts(tmp_path, timing_map, n_tasks, retries) + + +# fmt: off +@pytest.mark.parametrize( + "timing_map, n_tasks, retries", + [ + # too many failures + ({0: [-1], 1: [-1], 2: [-1, -1, -1]}, 3, 2), + ], +) +# fmt: on +def test_failure(tmp_path, timing_map, n_tasks, retries): + with pytest.raises(RuntimeError): + asyncio.run( + run_test( + function=partial(deterministic_failure, tmp_path, timing_map), + input=range(n_tasks), + retries=retries, + ) + ) + + check_invocation_counts(tmp_path, timing_map, n_tasks, retries) + + +# fmt: off +@pytest.mark.parametrize( + "timing_map, n_tasks, retries", + [ + ({0: [60]}, 10, 2), + ], +) +# fmt: on +def test_stragglers(tmp_path, timing_map, n_tasks, retries): + outputs = asyncio.run( + run_test( + function=partial(deterministic_failure, tmp_path, timing_map), + input=range(n_tasks), + retries=retries, + use_backups=True, + ) + ) + + assert outputs == set(range(n_tasks)) + + check_invocation_counts(tmp_path, timing_map, n_tasks, retries) diff --git a/cubed/tests/utils.py b/cubed/tests/utils.py index d5e53460..ff62cb6d 100644 --- a/cubed/tests/utils.py +++ b/cubed/tests/utils.py @@ -40,6 +40,17 @@ except ImportError: pass +try: + from cubed.runtime.executors.dask_distributed_async import ( + AsyncDaskDistributedExecutor, + ) + + ALL_EXECUTORS.append(AsyncDaskDistributedExecutor()) + + MAIN_EXECUTORS.append(AsyncDaskDistributedExecutor()) +except ImportError: + pass + try: from cubed.runtime.executors.lithops import LithopsDagExecutor diff --git a/pyproject.toml b/pyproject.toml index eba75555..962a08f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ diagnostics = [ ] beam = ["apache-beam", "gcsfs"] dask = ["dask"] +dask-distributed = ["distributed"] lithops = ["lithops[aws] >= 2.7.0"] modal = [ "cubed[diagnostics]", @@ -70,6 +71,13 @@ test-dask = [ "pytest-cov", "pytest-mock", ] +test-dask-distributed = [ + "cubed[dask-distributed,diagnostics]", + "dill", + "pytest", + "pytest-cov", + "pytest-mock", +] test-modal = [ "cubed[modal]", "dill",