Skip to content

Commit

Permalink
Lithops retries and backups (#217)
Browse files Browse the repository at this point in the history
* Lithops map_with_retries and wait_with_retries

* Use map_with_retries

* Handle backups properly
  • Loading branch information
tomwhite authored Jul 3, 2023
1 parent 738c6ee commit 25f5909
Show file tree
Hide file tree
Showing 3 changed files with 259 additions and 50 deletions.
79 changes: 29 additions & 50 deletions cubed/runtime/executors/lithops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

from lithops.executors import FunctionExecutor
from lithops.wait import ALWAYS, ANY_COMPLETED
from six import reraise

from cubed.core.plan import visit_nodes
from cubed.runtime.backup import should_launch_backup
from cubed.runtime.executors.lithops_retries import map_with_retries, wait_with_retries
from cubed.runtime.types import DagExecutor
from cubed.runtime.utils import handle_callbacks
from cubed.vendor.rechunker.types import ParallelPipelines, PipelineExecutor
Expand Down Expand Up @@ -95,39 +95,39 @@ def map_unordered(
# also, lithops extra_args doesn't work for this case
partial_map_function = lambda x: map_function(x, **kwargs)

futures = lithops_function_executor.map(
partial_map_function, inputs, timeout=timeout, include_modules=include_modules
futures = map_with_retries(
lithops_function_executor,
partial_map_function,
inputs,
timeout=timeout,
include_modules=include_modules,
retries=retries,
)
tasks.update({k: v for (k, v) in zip(futures, inputs)})
start_times.update({k: time.monotonic() for k in futures})
pending.extend(futures)

future_to_index = {f: i for i, f in enumerate(futures)}
failure_counts = [0] * len(inputs)

while pending:
finished, pending = lithops_function_executor.wait(
pending, throw_except=False, return_when=return_when, show_progressbar=False
finished, pending = wait_with_retries(
lithops_function_executor,
pending,
throw_except=False,
return_when=return_when,
show_progressbar=False,
)
failed = []
for future in finished:
if future.error:
index = future_to_index[future]
failure_counts[index] = failure_counts[index] + 1
failure_count = failure_counts[index]

if failure_count > retries:
# re-raise exception
# TODO: why does calling status not raise the exception?
future.status(throw_except=True)
reraise(*future._exception)
failed.append(future)
# if the task has a backup that is not done, or is done with no exception, then don't raise this exception
backup = backups.get(future, None)
if backup:
if not backup.done or not backup.error:
continue
future.status(throw_except=True)
end_times[future] = time.monotonic()
if return_stats:
yield future.result(), standardise_lithops_stats(future.stats)
else:
end_times[future] = time.monotonic()
if return_stats:
yield future.result(), standardise_lithops_stats(future.stats)
else:
yield future.result()
yield future.result()

# remove any backup task
if use_backups:
Expand All @@ -137,30 +137,7 @@ def map_unordered(
pending.remove(backup)
del backups[future]
del backups[backup]

if failed:
# rerun and add to pending
inputs_to_rerun = []
input_indexes_to_rerun = []
for fut in failed:
index = future_to_index[fut]
inputs_to_rerun.append(inputs[index])
input_indexes_to_rerun.append(index)
del future_to_index[fut]
# TODO: de-duplicate code from above
futures = lithops_function_executor.map(
partial_map_function,
inputs_to_rerun,
timeout=timeout,
include_modules=include_modules,
)
tasks.update({k: v for (k, v) in zip(futures, inputs_to_rerun)})
start_times.update({k: time.monotonic() for k in futures})
pending.extend(futures)

future_to_index.update(
{f: i for i, f in zip(input_indexes_to_rerun, futures)}
)
backup.cancel()

if use_backups:
now = time.monotonic()
Expand All @@ -170,16 +147,18 @@ def map_unordered(
):
input = tasks[future]
logger.info("Running backup task for %s", input)
futures = lithops_function_executor.map(
futures = map_with_retries(
lithops_function_executor,
partial_map_function,
[input],
timeout=timeout,
include_modules=include_modules,
retries=0, # don't retry backup tasks
)
tasks.update({k: v for (k, v) in zip(futures, [input])})
start_times.update({k: time.monotonic() for k in futures})
pending.extend(futures)
backup = futures[0] # TODO: launch multiple backups at once
backup = futures[0]
backups[future] = backup
backups[backup] = future
time.sleep(1)
Expand Down
161 changes: 161 additions & 0 deletions cubed/runtime/executors/lithops_retries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from lithops import FunctionExecutor
from lithops.future import ResponseFuture
from lithops.wait import ALL_COMPLETED, ALWAYS, ANY_COMPLETED
from six import reraise


class RetryingFuture:
    """
    A wrapper around a Lithops `ResponseFuture` that takes care of retries.

    Tracks how many times the underlying task has failed and, in cooperation
    with `wait_with_retries`, resubmits the original input until the task
    succeeds, is cancelled, or the failure count exceeds ``retries``.
    """

    def __init__(
        self,
        response_future: ResponseFuture,
        map_function: Callable,
        input: Any,
        map_kwargs: Optional[Dict[str, Any]] = None,
        retries: Optional[int] = None,
    ):
        # The current underlying future; replaced with a fresh one on retry.
        self.response_future = response_future
        self.map_function = map_function
        self.input = input
        self.map_kwargs = map_kwargs or {}
        self.retries = retries or 0
        self.failure_count = 0
        self.cancelled = False

    def __repr__(self):
        return (
            f"{type(self).__name__}(input={self.input!r}, "
            f"failure_count={self.failure_count}, retries={self.retries}, "
            f"cancelled={self.cancelled})"
        )

    def _inc_failure_count(self):
        # Called by wait_with_retries each time the underlying task errors.
        self.failure_count += 1

    def _should_retry(self):
        # A cancelled future is never retried; otherwise retry until the
        # failure count exceeds the configured number of retries.
        return not self.cancelled and self.failure_count <= self.retries

    def _retry(self, function_executor: FunctionExecutor):
        # Resubmit the original input and swap in the new underlying future.
        inputs = [self.input]
        futures_list = function_executor.map(
            self.map_function, inputs, **self.map_kwargs
        )
        self.response_future = futures_list[0]

    def _reraise_error(self):
        # Re-raise the task's exception with its original traceback.
        # Native Python 3 replacement for six.reraise (six is a Python 2/3
        # compatibility shim and is unnecessary here).
        exc_type, exc_value, exc_tb = self.response_future._exception
        if exc_value is None:
            exc_value = exc_type()
        raise exc_value.with_traceback(exc_tb)

    def cancel(self):
        # cancelling will prevent any further retries, but won't affect any running tasks
        self.cancelled = True

    @property
    def done(self):
        return self.response_future.done

    @property
    def error(self):
        return self.response_future.error

    @property
    def _exception(self):
        return self.response_future._exception

    @property
    def stats(self):
        return self.response_future.stats

    def status(
        self,
        throw_except: bool = True,
        internal_storage: Any = None,
        check_only: bool = False,
    ):
        """
        Return the status of the underlying future.

        NOTE(review): if the task errored this re-raises its exception even
        when ``throw_except`` is False — this mirrors the original behaviour,
        which callers appear to rely on to surface failures.
        """
        stat = self.response_future.status(
            throw_except=throw_except,
            internal_storage=internal_storage,
            check_only=check_only,
        )
        if self.response_future.error:
            self._reraise_error()
        return stat

    def result(self, throw_except: bool = True, internal_storage: Any = None):
        """
        Return the result of the underlying future, re-raising its exception
        (with the original traceback) if it failed.
        """
        res = self.response_future.result(
            throw_except=throw_except, internal_storage=internal_storage
        )
        if self.response_future.error:
            self._reraise_error()
        return res


def map_with_retries(
    function_executor: FunctionExecutor,
    map_function: Callable,
    map_iterdata: List[Union[List[Any], Tuple[Any, ...], Dict[str, Any]]],
    timeout: Optional[int] = None,
    include_modules: Optional[List[str]] = [],
    retries: Optional[int] = None,
) -> List[RetryingFuture]:
    """
    A generalisation of Lithops `map`, with retries.

    Submits each input via the executor and wraps every resulting future in a
    `RetryingFuture` that remembers how to resubmit its input.
    """

    # Materialize the iterable so inputs can be paired with their futures.
    materialized_inputs = list(map_iterdata)
    raw_futures = function_executor.map(
        map_function,
        materialized_inputs,
        timeout=timeout,
        include_modules=include_modules,
    )
    wrapped = []
    for item, raw_future in zip(materialized_inputs, raw_futures):
        wrapped.append(
            RetryingFuture(
                raw_future,
                map_function=map_function,
                input=item,
                map_kwargs=dict(timeout=timeout, include_modules=include_modules),
                retries=retries,
            )
        )
    return wrapped


def wait_with_retries(
    function_executor: FunctionExecutor,
    fs: List[RetryingFuture],
    throw_except: Optional[bool] = True,
    return_when: Optional[Any] = ALL_COMPLETED,
    show_progressbar: Optional[bool] = True,
) -> Tuple[List[RetryingFuture], List[RetryingFuture]]:
    """
    A generalisation of Lithops `wait`, with retries.

    Waits on the underlying Lithops futures; any that finished with an error
    and still have retry budget are resubmitted and treated as pending again.
    """

    # Map each underlying Lithops future back to its retrying wrapper.
    wrapper_for = {f.response_future: f for f in fs}

    while True:
        raw_done, raw_pending = function_executor.wait(
            [f.response_future for f in fs],
            throw_except=throw_except,
            return_when=return_when,
            show_progressbar=show_progressbar,
        )

        completed: List[RetryingFuture] = []
        still_pending = [wrapper_for[raw] for raw in raw_pending]

        for raw in raw_done:
            wrapper = wrapper_for[raw]
            if not raw.error:
                completed.append(wrapper)
                continue
            wrapper._inc_failure_count()
            if not wrapper._should_retry():
                # Out of retries (or cancelled): surface as done.
                completed.append(wrapper)
                continue
            # Resubmit the input and put the wrapper back into pending.
            wrapper._retry(function_executor)
            still_pending.append(wrapper)
            wrapper_for[wrapper.response_future] = wrapper

        if (
            return_when == ALWAYS
            or (return_when == ANY_COMPLETED and completed)
            or (return_when == ALL_COMPLETED and not still_pending)
        ):
            return completed, still_pending
69 changes: 69 additions & 0 deletions cubed/tests/runtime/test_lithops_retries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest
from lithops.executors import LocalhostExecutor

from cubed.runtime.executors.lithops_retries import map_with_retries, wait_with_retries
from cubed.tests.runtime.utils import check_invocation_counts, deterministic_failure


def run_test(function, input, retries, timeout=10):
    """Run `function` over `input` with the given retries; return the result set."""
    with LocalhostExecutor() as executor:
        futures = map_with_retries(
            executor,
            function,
            input,
            timeout=timeout,
            retries=retries,
        )
        completed, outstanding = wait_with_retries(
            executor, futures, throw_except=False
        )
        assert len(outstanding) == 0
        return {f.result() for f in completed}


# fmt: off
@pytest.mark.parametrize(
    "timing_map, n_tasks, retries",
    [
        # no failures
        ({}, 3, 2),
        # first invocation fails
        ({0: [-1], 1: [-1], 2: [-1]}, 3, 2),
        # first two invocations fail
        ({0: [-1, -1], 1: [-1, -1], 2: [-1, -1]}, 3, 2),
        # first input sleeps once
        ({0: [20]}, 3, 2),
    ],
)
# fmt: on
def test_success(tmp_path, timing_map, n_tasks, retries):
    # Bind the failure schedule for this test run.
    def invoke(x):
        return deterministic_failure(tmp_path, timing_map, x)

    results = run_test(function=invoke, input=range(n_tasks), retries=retries)

    assert results == set(range(n_tasks))

    check_invocation_counts(tmp_path, timing_map, n_tasks, retries)


# fmt: off
@pytest.mark.parametrize(
    "timing_map, n_tasks, retries",
    [
        # too many failures
        ({0: [-1], 1: [-1], 2: [-1, -1, -1]}, 3, 2),
    ],
)
# fmt: on
def test_failure(tmp_path, timing_map, n_tasks, retries):
    # Bind the failure schedule for this test run.
    def invoke(x):
        return deterministic_failure(tmp_path, timing_map, x)

    with pytest.raises(RuntimeError):
        run_test(function=invoke, input=range(n_tasks), retries=retries)

    check_invocation_counts(tmp_path, timing_map, n_tasks, retries)

0 comments on commit 25f5909

Please sign in to comment.