Skip to content

Commit

Permalink
Change map_unordered to run groups (of functions/mappables) in parallel.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Jul 13, 2023
1 parent 4e4cd90 commit c376a18
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 46 deletions.
106 changes: 62 additions & 44 deletions cubed/runtime/executors/lithops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import collections
import copy
import logging
import time
from functools import partial
from typing import (
Any,
Callable,
Expand All @@ -19,7 +21,7 @@
from networkx import MultiDiGraph

from cubed.core.array import Callback
from cubed.core.plan import visit_nodes
from cubed.core.plan import visit_node_generations
from cubed.runtime.backup import should_launch_backup
from cubed.runtime.executors.lithops_retries import (
RetryingFuture,
Expand All @@ -39,8 +41,11 @@ def run_func(input, func=None, config=None, name=None):

def map_unordered(
lithops_function_executor: FunctionExecutor,
map_function: Callable[..., Any],
map_iterdata: Iterable[Union[List[Any], Tuple[Any, ...], Dict[str, Any]]],
group_map_functions: Sequence[Callable[..., Any]],
group_map_iterdata: Sequence[
Iterable[Union[List[Any], Tuple[Any, ...], Dict[str, Any]]]
],
group_names: Sequence[str],
include_modules: List[str] = [],
timeout: Optional[int] = None,
retries: int = 2,
Expand All @@ -66,27 +71,32 @@ def map_unordered(
"""
return_when = ALWAYS if use_backups else ANY_COMPLETED

group_name = kwargs.get("name", None)
start_times = {}
end_times = {}
group_name_to_function: Dict[str, Callable[..., Any]] = {}
# backups are launched based on task start and end times for the group
start_times: Dict[str, Dict[RetryingFuture, float]] = {}
end_times: Dict[str, Dict[RetryingFuture, float]] = collections.defaultdict(dict)
backups: Dict[RetryingFuture, RetryingFuture] = {}
pending = []

# can't use functools.partial here as we get an error in lithops
# also, lithops extra_args doesn't work for this case
partial_map_function = lambda x: map_function(x, **kwargs)

futures = map_with_retries(
lithops_function_executor,
partial_map_function,
map_iterdata,
timeout=timeout,
include_modules=include_modules,
retries=retries,
group_name=group_name,
)
start_times.update({k: time.monotonic() for k in futures})
pending.extend(futures)
pending: List[RetryingFuture] = []

for map_function, map_iterdata, group_name in zip(
group_map_functions, group_map_iterdata, group_names
):
# can't use functools.partial here as we get an error in lithops
# also, lithops extra_args doesn't work for this case
partial_map_function = lambda x: map_function(x, **kwargs)
group_name_to_function[group_name] = partial_map_function

futures = map_with_retries(
lithops_function_executor,
partial_map_function,
map_iterdata,
timeout=timeout,
include_modules=include_modules,
retries=retries,
group_name=group_name,
)
start_times[group_name] = {k: time.monotonic() for k in futures}
pending.extend(futures)

while pending:
finished, pending = wait_with_retries(
Expand All @@ -104,7 +114,8 @@ def map_unordered(
if not backup.done or not backup.error:
continue
future.status(throw_except=True)
end_times[future] = time.monotonic()
group_name = future.group_name
end_times[group_name][future] = time.monotonic()
if return_stats:
yield future.result(), standardise_lithops_stats(future)
else:
Expand All @@ -123,20 +134,24 @@ def map_unordered(
if use_backups:
now = time.monotonic()
for future in copy.copy(pending):
group_name = future.group_name
if future not in backups and should_launch_backup(
future, now, start_times, end_times
future, now, start_times[group_name], end_times[group_name]
):
input = future.input
logger.info("Running backup task for %s", input)
futures = map_with_retries(
lithops_function_executor,
partial_map_function,
group_name_to_function[group_name],
[input],
timeout=timeout,
include_modules=include_modules,
retries=0, # don't retry backup tasks
group_name=group_name,
)
start_times[group_name].update(
{k: time.monotonic() for k in futures}
)
start_times.update({k: time.monotonic() for k in futures})
pending.extend(futures)
backup = futures[0]
backups[future] = backup
Expand All @@ -153,23 +168,26 @@ def execute_dag(
) -> None:
use_backups = kwargs.pop("use_backups", False)
with FunctionExecutor(**kwargs) as executor:
for name, node in visit_nodes(dag, resume=resume):
pipeline = node["pipeline"]
for stage in pipeline.stages:
if stage.mappable is not None:
for _, stats in map_unordered(
executor,
run_func,
stage.mappable,
func=stage.function,
config=pipeline.config,
name=name,
use_backups=use_backups,
return_stats=True,
):
handle_callbacks(callbacks, stats)
else:
raise NotImplementedError()
for gen in visit_node_generations(dag, resume=resume):
group_map_functions = []
group_map_iterdata = []
group_names = []
for name, node in gen:
pipeline = node["pipeline"]
stage = pipeline.stages[0] # assume one
f = partial(run_func, func=stage.function, config=pipeline.config)
group_map_functions.append(f)
group_map_iterdata.append(stage.mappable)
group_names.append(name)
for _, stats in map_unordered(
executor,
group_map_functions,
group_map_iterdata,
group_names,
use_backups=use_backups,
return_stats=True,
):
handle_callbacks(callbacks, stats)


def standardise_lithops_stats(future: RetryingFuture) -> Dict[str, Any]:
Expand Down
5 changes: 3 additions & 2 deletions cubed/tests/runtime/test_lithops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ def run_test(function, input, retries, timeout=10, use_backups=False):
with LocalhostExecutor() as executor:
for output in map_unordered(
executor,
function,
input,
[function],
[input],
["group0"],
timeout=timeout,
retries=retries,
use_backups=use_backups,
Expand Down

0 comments on commit c376a18

Please sign in to comment.