-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
274 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
import asyncio | ||
from typing import ( | ||
Any, | ||
AsyncIterator, | ||
Callable, | ||
Dict, | ||
Iterable, | ||
List, | ||
Optional, | ||
Sequence, | ||
Tuple, | ||
Union, | ||
) | ||
|
||
from aiostream import stream | ||
from aiostream.core import Stream | ||
from dask.distributed import Client | ||
from networkx import MultiDiGraph | ||
|
||
from cubed.core.array import Callback, Spec | ||
from cubed.core.plan import visit_node_generations, visit_nodes | ||
from cubed.primitive.types import CubedPipeline | ||
from cubed.runtime.executors.asyncio import async_map_unordered | ||
from cubed.runtime.types import DagExecutor | ||
from cubed.runtime.utils import execution_stats, gensym, handle_callbacks | ||
|
||
|
||
# NOTE: the keyword is named `pipeline_func` rather than just `func` because
# `func` clashes with `dask.distributed.Client.map`.
@execution_stats
def run_func(input, pipeline_func=None, config=None, name=None):
    """Apply ``pipeline_func`` to one ``input`` item and return its result.

    ``config`` is forwarded to ``pipeline_func`` as a keyword argument.
    ``name`` is accepted but not used in the body.
    """
    return pipeline_func(input, config=config)
|
||
|
||
async def map_unordered(
    client: Client,
    map_function: Callable[..., Any],
    map_iterdata: Iterable[Union[List[Any], Tuple[Any, ...], Dict[str, Any]]],
    retries: int = 2,
    use_backups: bool = False,
    return_stats: bool = False,
    name: Optional[str] = None,
    **kwargs,
) -> AsyncIterator[Any]:
    """Submit ``map_function`` over ``map_iterdata`` via ``client.map`` and
    yield whatever ``async_map_unordered`` yields.

    Two submission callbacks are handed to ``async_map_unordered``: one for
    the primary tasks (which forwards ``retries`` to ``client.map``) and one
    for backup tasks (which does not forward ``retries``). Extra ``kwargs``
    are passed through to ``async_map_unordered``.
    """

    def _submit(items, default_prefix, submit_kwargs):
        # dask expects a sequence (it calls `len` on it)
        items = list(items)
        # dashes are replaced, otherwise array number is not shown on dashboard
        # NOTE(review): when `name` is given, primary and backup submissions
        # share the same key prefix - confirm dask does not collapse them
        # into the same task.
        task_key = (name or gensym(default_prefix)).replace("-", "_")
        dask_futures = client.map(map_function, items, key=task_key, **submit_kwargs)
        return [
            (item, asyncio.ensure_future(dask_future))
            for item, dask_future in zip(items, dask_futures)
        ]

    def create_futures_func(input, **kwargs):
        # `dict(...)` raises TypeError on a duplicate `retries` keyword, just
        # like passing it twice in a call would.
        return _submit(input, "map", dict(retries=retries, **kwargs))

    def create_backup_futures_func(input, **kwargs):
        return _submit(input, "backup", kwargs)

    results = async_map_unordered(
        create_futures_func,
        map_iterdata,
        use_backups=use_backups,
        create_backup_futures_func=create_backup_futures_func,
        return_stats=return_stats,
        name=name,
        **kwargs,
    )
    async for item in results:
        yield item
|
||
|
||
def pipeline_to_stream(
    client: Client, name: str, pipeline: CubedPipeline, **kwargs
) -> Stream:
    """Wrap the execution of one pipeline in an aiostream ``Stream``.

    The pipeline's ``mappable`` is mapped with ``run_func`` through
    ``map_unordered`` (always with ``return_stats=True``), passing the
    pipeline's ``function`` and ``config`` along; ``kwargs`` are forwarded
    unchanged.
    """
    results = map_unordered(
        client,
        run_func,
        pipeline.mappable,
        return_stats=True,
        name=name,
        pipeline_func=pipeline.function,
        config=pipeline.config,
        **kwargs,
    )
    return stream.iterate(results)
|
||
|
||
async def async_execute_dag(
    dag: MultiDiGraph,
    callbacks: Optional[Sequence[Callback]] = None,
    array_names: Optional[Sequence[str]] = None,
    resume: Optional[bool] = None,
    spec: Optional[Spec] = None,
    compute_arrays_in_parallel: Optional[bool] = None,
    compute_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> None:
    """Execute every pipeline in ``dag`` on an asynchronous Dask ``Client``.

    ``compute_kwargs`` (if any) are passed to the ``Client`` constructor.
    When ``compute_arrays_in_parallel`` is falsy, pipelines run one at a
    time in the order given by ``visit_nodes``; otherwise the pipelines of
    each generation from ``visit_node_generations`` are merged into a single
    stream and run together. ``array_names`` and ``spec`` are accepted but
    not used in the body.
    """

    async def _report(results: Stream) -> None:
        # Items are unpacked as two-element tuples; only the second element
        # (the stats) is handed to the callbacks.
        async with results.stream() as streamer:
            async for _, stats in streamer:
                handle_callbacks(callbacks, stats)

    client_options = compute_kwargs or {}
    async with Client(asynchronous=True, **client_options) as client:
        if compute_arrays_in_parallel:
            for generation in visit_node_generations(dag, resume=resume):
                # run pipelines in the same topological generation in
                # parallel by merging their streams
                pipeline_streams = [
                    pipeline_to_stream(client, name, node["pipeline"], **kwargs)
                    for name, node in generation
                ]
                await _report(stream.merge(*pipeline_streams))
        else:
            # run one pipeline at a time
            for name, node in visit_nodes(dag, resume=resume):
                await _report(
                    pipeline_to_stream(client, name, node["pipeline"], **kwargs)
                )
|
||
|
||
class AsyncDaskDistributedExecutor(DagExecutor):
    """An execution engine that uses Dask Distributed's async API."""

    def __init__(self, **kwargs):
        """Store ``kwargs`` as default options for every ``execute_dag`` call."""
        self.kwargs = kwargs

    def execute_dag(
        self,
        dag: MultiDiGraph,
        callbacks: Optional[Sequence[Callback]] = None,
        array_names: Optional[Sequence[str]] = None,
        resume: Optional[bool] = None,
        spec: Optional[Spec] = None,
        compute_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        """Run ``async_execute_dag`` for ``dag`` to completion with ``asyncio.run``.

        Options are resolved in this order of precedence: explicit arguments
        that are not ``None``, then call-time ``kwargs``, then the kwargs
        given to the constructor.
        """
        merged_kwargs = {**self.kwargs, **kwargs}
        explicit_options = {
            "callbacks": callbacks,
            "array_names": array_names,
            "resume": resume,
            "spec": spec,
            "compute_kwargs": compute_kwargs,
        }
        # Fold the explicit arguments into the merged mapping instead of
        # passing them as separate keywords: a constructor option with the
        # same name (e.g. `compute_kwargs`) would otherwise be supplied twice
        # and raise "TypeError: got multiple values for keyword argument".
        for option, value in explicit_options.items():
            if value is not None or option not in merged_kwargs:
                merged_kwargs[option] = value
        asyncio.run(async_execute_dag(dag, **merged_kwargs))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import asyncio | ||
from functools import partial | ||
|
||
import pytest | ||
|
||
from cubed.tests.runtime.utils import check_invocation_counts, deterministic_failure | ||
|
||
pytest.importorskip("dask.distributed") | ||
|
||
from dask.distributed import Client | ||
|
||
from cubed.runtime.executors.dask_distributed_async import map_unordered | ||
|
||
|
||
async def run_test(function, input, retries, use_backups=False):
    """Run ``function`` over ``input`` with ``map_unordered`` on a fresh
    asynchronous ``Client`` and return the set of outputs produced."""
    async with Client(asynchronous=True) as client:
        results = map_unordered(
            client,
            function,
            input,
            retries=retries,
            use_backups=use_backups,
        )
        outputs = {output async for output in results}
    return outputs
|
||
|
||
# fmt: off
@pytest.mark.parametrize(
    "timing_map, n_tasks, retries",
    [
        # no failures
        ({}, 3, 2),
        # first invocation fails
        ({0: [-1], 1: [-1], 2: [-1]}, 3, 2),
        # first two invocations fail
        ({0: [-1, -1], 1: [-1, -1], 2: [-1, -1]}, 3, 2),
        # first input sleeps once (not tested since timeout is not supported)
        # ({0: [20]}, 3, 2),
    ],
)
# fmt: on
def test_success(tmp_path, timing_map, n_tasks, retries):
    """Every input is returned when failures stay within the retry budget."""
    flaky_function = partial(deterministic_failure, tmp_path, timing_map)
    outputs = asyncio.run(
        run_test(function=flaky_function, input=range(n_tasks), retries=retries)
    )

    expected = set(range(n_tasks))
    assert outputs == expected

    check_invocation_counts(tmp_path, timing_map, n_tasks, retries)
|
||
|
||
# fmt: off
@pytest.mark.parametrize(
    "timing_map, n_tasks, retries",
    [
        # too many failures
        ({0: [-1], 1: [-1], 2: [-1, -1, -1]}, 3, 2),
    ],
)
# fmt: on
def test_failure(tmp_path, timing_map, n_tasks, retries):
    """A ``RuntimeError`` surfaces when an input fails too many times."""
    flaky_function = partial(deterministic_failure, tmp_path, timing_map)

    with pytest.raises(RuntimeError):
        asyncio.run(
            run_test(function=flaky_function, input=range(n_tasks), retries=retries)
        )

    check_invocation_counts(tmp_path, timing_map, n_tasks, retries)
|
||
|
||
# fmt: off
@pytest.mark.parametrize(
    "timing_map, n_tasks, retries",
    [
        # NOTE(review): presumably input 0 sleeps on its first invocation -
        # confirm against `deterministic_failure`
        ({0: [60]}, 10, 2),
    ],
)
# fmt: on
def test_stragglers(tmp_path, timing_map, n_tasks, retries):
    """Every input is returned when backup tasks are enabled."""
    slow_function = partial(deterministic_failure, tmp_path, timing_map)
    outputs = asyncio.run(
        run_test(
            function=slow_function,
            input=range(n_tasks),
            retries=retries,
            use_backups=True,
        )
    )

    expected = set(range(n_tasks))
    assert outputs == expected

    check_invocation_counts(tmp_path, timing_map, n_tasks, retries)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters