
Commit

Add type hints for runtime
tomwhite committed Jul 13, 2023
1 parent be21230 commit 4e6fa7a
Showing 10 changed files with 171 additions and 62 deletions.
19 changes: 11 additions & 8 deletions cubed/runtime/backup.py
@@ -1,15 +1,18 @@
 import math
+from typing import Dict, TypeVar
+
+T = TypeVar("T")
 
 
 def should_launch_backup(
-    task,
-    now,
-    start_times,
-    end_times,
-    min_tasks=10,
-    min_completed_fraction=0.5,
-    slow_factor=3.0,
-):
+    task: T,
+    now: float,
+    start_times: Dict[T, float],
+    end_times: Dict[T, float],
+    min_tasks: int = 10,
+    min_completed_fraction: float = 0.5,
+    slow_factor: float = 3.0,
+) -> bool:
     """
     Determine whether to launch a backup task.
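
A minimal usage sketch (not part of this commit) of the newly typed signature: the task key type T is inferred from the caller, so the task argument and the keys of start_times/end_times must agree. All values below are made up for illustration, and min_tasks is lowered from its default of 10 so the example stays small.

    from cubed.runtime.backup import should_launch_backup

    # str is bound to T here, so every task-keyed argument uses str keys
    start_times = {"task-0": 0.0, "task-1": 0.2, "task-2": 0.3}
    end_times = {"task-0": 1.1, "task-2": 1.4}  # task-1 has not finished yet

    if should_launch_backup(
        "task-1",
        now=30.0,
        start_times=start_times,
        end_times=end_times,
        min_tasks=2,
    ):
        print("task-1 looks like a straggler; launch a backup copy")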
65 changes: 51 additions & 14 deletions cubed/runtime/executors/lithops.py
@@ -1,13 +1,31 @@
 import copy
 import logging
 import time
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 from lithops.executors import FunctionExecutor
 from lithops.wait import ALWAYS, ANY_COMPLETED
+from networkx import MultiDiGraph
 
+from cubed.core.array import Callback
 from cubed.core.plan import visit_nodes
 from cubed.runtime.backup import should_launch_backup
-from cubed.runtime.executors.lithops_retries import map_with_retries, wait_with_retries
+from cubed.runtime.executors.lithops_retries import (
+    RetryingFuture,
+    map_with_retries,
+    wait_with_retries,
+)
 from cubed.runtime.types import DagExecutor
 from cubed.runtime.utils import handle_callbacks
 
@@ -20,16 +38,16 @@ def run_func(input, func=None, config=None, name=None):


 def map_unordered(
-    lithops_function_executor,
-    map_function,
-    map_iterdata,
-    include_modules=[],
-    timeout=None,
-    retries=2,
-    use_backups=False,
-    return_stats=False,
+    lithops_function_executor: FunctionExecutor,
+    map_function: Callable[..., Any],
+    map_iterdata: Iterable[Union[List[Any], Tuple[Any, ...], Dict[str, Any]]],
+    include_modules: List[str] = [],
+    timeout: Optional[int] = None,
+    retries: int = 2,
+    use_backups: bool = False,
+    return_stats: bool = False,
     **kwargs,
-):
+) -> Iterator[Any]:
     """
     Apply a function to items of an input list, yielding results as they are completed
     (which may be different to the input order).
@@ -52,7 +70,7 @@ def map_unordered(
     tasks = {}
     start_times = {}
     end_times = {}
-    backups = {}
+    backups: Dict[RetryingFuture, RetryingFuture] = {}
     pending = []
 
     # can't use functools.partial here as we get an error in lithops
@@ -128,7 +146,13 @@ def map_unordered(
         time.sleep(1)
 
 
-def execute_dag(dag, callbacks=None, array_names=None, resume=None, **kwargs):
+def execute_dag(
+    dag: MultiDiGraph,
+    callbacks: Optional[Sequence[Callback]] = None,
+    array_names: Optional[Sequence[str]] = None,
+    resume: Optional[bool] = None,
+    **kwargs,
+) -> None:
     use_backups = kwargs.pop("use_backups", False)
     with FunctionExecutor(**kwargs) as executor:
         for name, node in visit_nodes(dag, resume=resume):
@@ -168,6 +192,19 @@ class LithopsDagExecutor(DagExecutor):
     def __init__(self, **kwargs):
         self.kwargs = kwargs
 
-    def execute_dag(self, dag, callbacks=None, array_names=None, **kwargs):
+    def execute_dag(
+        self,
+        dag: MultiDiGraph,
+        callbacks: Optional[Sequence[Callback]] = None,
+        array_names: Optional[Sequence[str]] = None,
+        resume: Optional[bool] = None,
+        **kwargs,
+    ) -> None:
         merged_kwargs = {**self.kwargs, **kwargs}
-        execute_dag(dag, callbacks=callbacks, array_names=array_names, **merged_kwargs)
+        execute_dag(
+            dag,
+            callbacks=callbacks,
+            array_names=array_names,
+            resume=resume,
+            **merged_kwargs,
+        )
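
For orientation, an assumed usage sketch (not shown in the diff): keyword arguments given to LithopsDagExecutor act as defaults that call-time arguments override, via the merged_kwargs dict above, and are ultimately forwarded to Lithops' FunctionExecutor.

    # Hypothetical settings; anything FunctionExecutor(**kwargs) accepts can go here.
    executor = LithopsDagExecutor(runtime="cubed-runtime", runtime_memory=2048)

    # `dag` is assumed to be the networkx MultiDiGraph produced by cubed's planner;
    # per-call options such as use_backups are popped off before FunctionExecutor is created.
    # executor.execute_dag(dag, use_backups=True)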
4 changes: 2 additions & 2 deletions cubed/runtime/executors/lithops_retries.py
@@ -14,7 +14,7 @@ class RetryingFuture:
     def __init__(
         self,
         response_future: ResponseFuture,
-        map_function: Callable,
+        map_function: Callable[..., Any],
         input: Any,
         map_kwargs: Any = None,
         retries: Optional[int] = None,
@@ -86,7 +86,7 @@ def result(self, throw_except: bool = True, internal_storage: Any = None):

 def map_with_retries(
     function_executor: FunctionExecutor,
-    map_function: Callable,
+    map_function: Callable[..., Any],
     map_iterdata: List[Union[List[Any], Tuple[Any, ...], Dict[str, Any]]],
     timeout: Optional[int] = None,
     include_modules: Optional[List[str]] = [],
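
The widened Callable[..., Any] hint spells out the existing contract explicitly: map functions here are called with extra keyword arguments (run_func in lithops.py above accepts func=, config= and name=), so no particular signature can be pinned down. A hedged sketch of a conforming map function, using only parameters visible in this diff:

    def double(x, name=None):
        # extra keyword arguments are tolerated, matching how run_func is invoked
        return 2 * x

    # futures = map_with_retries(
    #     function_executor,  # a lithops FunctionExecutor
    #     double,
    #     [(1,), (2,), (3,)],
    #     timeout=60,
    # )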
29 changes: 25 additions & 4 deletions cubed/runtime/executors/modal.py
@@ -1,11 +1,14 @@
 import os
 import time
 from asyncio.exceptions import TimeoutError
+from typing import Optional, Sequence
 
 import modal
 from modal.exception import ConnectionError
+from networkx import MultiDiGraph
 from tenacity import retry, retry_if_exception_type, stop_after_attempt
 
+from cubed.core.array import Callback
 from cubed.core.plan import visit_nodes
 from cubed.runtime.types import DagExecutor
 from cubed.runtime.utils import execute_with_stats, handle_callbacks
@@ -90,8 +93,13 @@ def run_remotely(self, input, func=None, config=None):
     stop=stop_after_attempt(3),
 )
 def execute_dag(
-    dag, callbacks=None, array_names=None, resume=None, cloud=None, **kwargs
-):
+    dag: MultiDiGraph,
+    callbacks: Optional[Sequence[Callback]] = None,
+    array_names: Optional[Sequence[str]] = None,
+    resume: Optional[bool] = None,
+    cloud: Optional[str] = None,
+    **kwargs,
+) -> None:
     with stub.run():
         cloud = cloud or "aws"
         if cloud == "aws":
@@ -124,6 +132,19 @@ class ModalDagExecutor(DagExecutor):
     def __init__(self, **kwargs):
         self.kwargs = kwargs
 
-    def execute_dag(self, dag, callbacks=None, array_names=None, **kwargs):
+    def execute_dag(
+        self,
+        dag: MultiDiGraph,
+        callbacks: Optional[Sequence[Callback]] = None,
+        array_names: Optional[Sequence[str]] = None,
+        resume: Optional[bool] = None,
+        **kwargs,
+    ) -> None:
         merged_kwargs = {**self.kwargs, **kwargs}
-        execute_dag(dag, callbacks=callbacks, array_names=array_names, **merged_kwargs)
+        execute_dag(
+            dag,
+            callbacks=callbacks,
+            array_names=array_names,
+            resume=resume,
+            **merged_kwargs,
+        )
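
An assumed usage sketch (not part of the commit): constructor kwargs are merged into every call, so settings such as the target cloud can be fixed once when the executor is created.

    # cloud="aws" matches the default applied inside execute_dag (cloud = cloud or "aws");
    # other values would need a corresponding branch in the module, which this diff does not show.
    executor = ModalDagExecutor(cloud="aws")

    # `dag` is assumed to be a networkx MultiDiGraph from cubed's planner.
    # executor.execute_dag(dag)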
38 changes: 27 additions & 11 deletions cubed/runtime/executors/modal_async.py
@@ -2,10 +2,14 @@
 import copy
 import time
 from asyncio.exceptions import TimeoutError
+from typing import Any, AsyncIterator, Dict, Iterable, Optional, Sequence
 
 from modal.exception import ConnectionError
+from modal.function import Function
+from networkx import MultiDiGraph
 from tenacity import retry, retry_if_exception_type, stop_after_attempt
 
+from cubed.core.array import Callback
 from cubed.core.plan import visit_nodes
 from cubed.runtime.backup import should_launch_backup
 from cubed.runtime.executors.modal import Container, run_remotely, stub
@@ -15,14 +19,14 @@

 # We need map_unordered for the use_backups implementation
 async def map_unordered(
-    app_function,
-    input,
-    use_backups=False,
-    backup_function=None,
-    return_stats=False,
-    name=None,
+    app_function: Function,
+    input: Iterable[Any],
+    use_backups: bool = False,
+    backup_function: Optional[Function] = None,
+    return_stats: bool = False,
+    name: Optional[str] = None,
     **kwargs,
-):
+) -> AsyncIterator[Any]:
     """
     Apply a function to items of an input list, yielding results as they are completed
     (which may be different to the input order).
@@ -57,7 +61,7 @@ async def map_unordered(
     t = time.monotonic()
     start_times = {f: t for f in pending}
     end_times = {}
-    backups = {}
+    backups: Dict[asyncio.Future, asyncio.Future] = {}
 
     while pending:
         finished, pending = await asyncio.wait(
@@ -118,8 +122,13 @@ async def map_unordered(
     stop=stop_after_attempt(3),
 )
 async def async_execute_dag(
-    dag, callbacks=None, array_names=None, resume=None, cloud=None, **kwargs
-):
+    dag: MultiDiGraph,
+    callbacks: Optional[Sequence[Callback]] = None,
+    array_names: Optional[Sequence[str]] = None,
+    resume: Optional[bool] = None,
+    cloud: Optional[str] = None,
+    **kwargs,
+) -> None:
     async with stub.run():
         cloud = cloud or "aws"
         if cloud == "aws":
@@ -153,7 +162,14 @@ class AsyncModalDagExecutor(DagExecutor):
     def __init__(self, **kwargs):
         self.kwargs = kwargs
 
-    def execute_dag(self, dag, callbacks=None, array_names=None, resume=None, **kwargs):
+    def execute_dag(
+        self,
+        dag: MultiDiGraph,
+        callbacks: Optional[Sequence[Callback]] = None,
+        array_names: Optional[Sequence[str]] = None,
+        resume: Optional[bool] = None,
+        **kwargs,
+    ) -> None:
         merged_kwargs = {**self.kwargs, **kwargs}
         asyncio.run(
             async_execute_dag(
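
A rough sketch (assumptions flagged in the comments) of how the annotated async generator is consumed: map_unordered yields results in completion order, and synchronous callers drive it with asyncio.run, as AsyncModalDagExecutor.execute_dag does above.

    import asyncio

    async def consume(app_function, inputs):
        # app_function is assumed to be a modal.function.Function handle;
        # results arrive as calls finish, not in input order.
        async for result in map_unordered(app_function, inputs):
            print(result)

    # asyncio.run(consume(app_function, inputs))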
19 changes: 15 additions & 4 deletions cubed/runtime/executors/python.py
@@ -1,21 +1,32 @@
+from typing import Any, Callable, Optional, Sequence
+
+from networkx import MultiDiGraph
 from tenacity import retry, stop_after_attempt
 
-from cubed.core.array import TaskEndEvent
+from cubed.core.array import Callback, TaskEndEvent
 from cubed.core.plan import visit_nodes
+from cubed.primitive.types import CubedPipeline
 from cubed.runtime.types import DagExecutor
 
 
 @retry(reraise=True, stop=stop_after_attempt(3))
-def exec_stage_func(func, *args, **kwargs):
+def exec_stage_func(func: Callable[..., Any], *args, **kwargs):
     return func(*args, **kwargs)
 
 
 class PythonDagExecutor(DagExecutor):
     """The default execution engine that runs tasks sequentially uses Python loops."""
 
-    def execute_dag(self, dag, callbacks=None, array_names=None, resume=None, **kwargs):
+    def execute_dag(
+        self,
+        dag: MultiDiGraph,
+        callbacks: Optional[Sequence[Callback]] = None,
+        array_names: Optional[Sequence[str]] = None,
+        resume: Optional[bool] = None,
+        **kwargs,
+    ) -> None:
         for name, node in visit_nodes(dag, resume=resume):
-            pipeline = node["pipeline"]
+            pipeline: CubedPipeline = node["pipeline"]
             for stage in pipeline.stages:
                 if stage.mappable is not None:
                     for m in stage.mappable:
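
A small illustration (not part of the commit) of the retry behaviour wrapped around exec_stage_func: stop_after_attempt(3) re-invokes the function up to three times, and reraise=True surfaces the original exception if every attempt fails.

    attempts = 0

    def flaky(x: int) -> int:
        global attempts
        attempts += 1
        if attempts < 3:
            raise RuntimeError("transient failure")
        return x * 2

    print(exec_stage_func(flaky, 21))  # prints 42 after two failed attempts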