Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ArrayNode mapping over Launch Plans #2480

Merged
merged 31 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
62031c3
set up array node
pvditt Jun 12, 2024
2293b6a
wip array node task wrapper
pvditt Jun 12, 2024
9579da6
support function like callability
pvditt Jun 13, 2024
683b50a
Merge branch 'master' into array-node
pvditt Jun 13, 2024
11d0dca
temp check in some progress on python func wrapper
pvditt Jun 18, 2024
4155c68
only support launch plans in new array node class for now
pvditt Jun 18, 2024
99322e9
add map task array node implementation wrapper
pvditt Jun 18, 2024
2b051c3
ArrayNode only supports LPs for now
pvditt Jun 18, 2024
2c6991e
support local execute for new array node implementation
pvditt Jun 19, 2024
80f7140
add local execute unit tests for array node
pvditt Jun 21, 2024
8faf457
set execution version in array node spec
pvditt Jun 27, 2024
52bae8f
Merge branch 'master' into array-node
pvditt Jul 22, 2024
f0bc6dc
check input types for local execute
pvditt Jul 22, 2024
5e8cd9d
remove code that is un-needed for now
pvditt Jul 23, 2024
7ff17b0
clean up array node class
pvditt Jul 24, 2024
91a6438
improve naming
pvditt Jul 24, 2024
d5b32d2
clean up
pvditt Jul 25, 2024
ced84ba
utilize enum execution mode to set array node execution path
pvditt Jul 25, 2024
45cc9ac
default execution mode to FULL_STATE for new array node class
pvditt Jul 25, 2024
d14f97e
support min_successes for new array node
pvditt Jul 26, 2024
7315ab5
add map task wrapper unit test
pvditt Jul 26, 2024
e56e57c
set min successes for array node map task wrapper
pvditt Jul 26, 2024
f12c7cc
Merge branch 'master' into array-node
pvditt Jul 26, 2024
c3f6fe7
update docstrings
pvditt Jul 26, 2024
8ae1ae2
Install flyteidl from master in plugins tests
eapolinario Jul 29, 2024
7fef6d0
Merge remote-tracking branch 'origin/install-latest-flyteidl-in-flyte…
pvditt Jul 30, 2024
34210cf
Merge branch 'master' into array-node
pvditt Jul 30, 2024
21d286a
lint
pvditt Jul 30, 2024
5b7fbcc
clean up min success/ratio setting
pvditt Jul 30, 2024
c59c427
lint
pvditt Jul 30, 2024
8340c08
make array node class callable
pvditt Jul 31, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ jobs:
uv pip install --system .
if [ -f dev-requirements.in ]; then uv pip install --system -r dev-requirements.in; fi
# TODO: move to protobuf>=5. Github issue: https://github.com/flyteorg/flyte/issues/5448
uv pip install --system -U $GITHUB_WORKSPACE "protobuf<5"
uv pip install --system -U $GITHUB_WORKSPACE "protobuf<5" "git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl"
# TODO: remove this when numpy v2 in onnx has been resolved
if [[ ${{ matrix.plugin-names }} == *"onnx"* || ${{ matrix.plugin-names }} == "flytekit-sqlalchemy" || ${{ matrix.plugin-names }} == "flytekit-pandera" ]]; then
uv pip install --system "numpy<2.0.0"
Expand Down
222 changes: 222 additions & 0 deletions flytekit/core/array_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
import math
from typing import Any, List, Optional, Set, Tuple, Union

from flyteidl.core import workflow_pb2 as _core_workflow

from flytekit.core import interface as flyte_interface
from flytekit.core.context_manager import ExecutionState, FlyteContext
from flytekit.core.interface import transform_interface_to_list_interface, transform_interface_to_typed_interface
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.node import Node
from flytekit.core.promise import (
Promise,
VoidPromise,
flyte_entity_call_handler,
translate_inputs_to_literals,
)
from flytekit.core.task import TaskMetadata
from flytekit.loggers import logger
from flytekit.models import literals as _literal_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.literals import Literal, LiteralCollection, Scalar


class ArrayNode:
    """
    Node-like wrapper that maps a Flyte entity over lists of inputs.

    Only LaunchPlan targets are accepted for now. The object exposes the members used during
    node creation and serialization (``name``, ``python_interface``, ``bindings``,
    ``upstream_nodes``, ``construct_node_metadata``) and supports local execution by invoking
    the target once per mapped element.
    """

    def __init__(
        self,
        target: LaunchPlan,
        execution_mode: _core_workflow.ArrayNode.ExecutionMode = _core_workflow.ArrayNode.FULL_STATE,
        concurrency: Optional[int] = None,
        min_successes: Optional[int] = None,
        min_success_ratio: float = 1.0,
        bound_inputs: Optional[Set[str]] = None,
        metadata: Optional[Union[_workflow_model.NodeMetadata, TaskMetadata]] = None,
    ):
        """
        :param target: The target Flyte entity to map over
        :param concurrency: If specified, this limits the number of mapped tasks that can run in parallel to the given
            batch size. If the size of the input exceeds the concurrency value, then multiple batches will be run
            serially until all inputs are processed. If set to 0, this means unbounded concurrency. If left
            unspecified, this means the array node will inherit parallelism from the workflow
        :param min_successes: The minimum number of successful executions. If set, this takes precedence over
            min_success_ratio
        :param min_success_ratio: The minimum ratio of successful executions.
        :param bound_inputs: The set of inputs that should be bound to the map task
        :param execution_mode: The execution mode for propeller to use when handling ArrayNode
        :param metadata: The metadata for the underlying entity
        :raises ValueError: If the target declares more than one output, or if a LaunchPlan target is combined
            with an execution mode other than FULL_STATE
        :raises Exception: If the target is not a LaunchPlan, or the metadata is not a NodeMetadata
        """
        self.target = target
        self._concurrency = concurrency
        self._min_successes = min_successes
        self._min_success_ratio = min_success_ratio
        self._execution_mode = execution_mode
        self.id = target.name

        n_outputs = len(self.target.python_interface.outputs)
        if n_outputs > 1:
            raise ValueError("Only tasks with a single output are supported in map tasks.")

        # Always keep a private set so later membership checks never hit None.
        self._bound_inputs: Set[str] = set(bound_inputs) if bound_inputs else set()

        # With a ratio other than 1 some elements may fail, so the single output becomes a list of optionals.
        output_as_list_of_optionals = min_success_ratio is not None and min_success_ratio != 1 and n_outputs == 1
        self._collection_interface = transform_interface_to_list_interface(
            self.target.python_interface, self._bound_inputs, output_as_list_of_optionals
        )

        self.metadata = None
        if not isinstance(target, LaunchPlan):
            raise Exception("Only LaunchPlans are supported for now.")
        if self._execution_mode != _core_workflow.ArrayNode.FULL_STATE:
            raise ValueError("Only execution version 1 is supported for LaunchPlans.")
        if metadata:
            if not isinstance(metadata, _workflow_model.NodeMetadata):
                raise Exception("Invalid metadata for LaunchPlan. Should be NodeMetadata.")
            self.metadata = metadata

    def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
        """Return node metadata named after the target (part of the SupportsNodeCreation interface)."""
        # TODO - include passed in metadata
        return _workflow_model.NodeMetadata(name=self.target.name)

    @property
    def name(self) -> str:
        """Name of the target entity (part of the SupportsNodeCreation interface)."""
        return self.target.name

    @property
    def python_interface(self) -> flyte_interface.Interface:
        """List-transformed interface of the target (part of the SupportsNodeCreation interface)."""
        return self._collection_interface

    @property
    def bindings(self) -> List[_literal_models.Binding]:
        """Always empty; required in get_serializable_node."""
        return []

    @property
    def upstream_nodes(self) -> List[Node]:
        """Always empty; required in get_serializable_node."""
        return []

    @property
    def flyte_entity(self) -> Any:
        """The entity that is being mapped over."""
        return self.target

    @staticmethod
    def _matches_declared_type(value: Any, declared_type: Any) -> bool:
        """Return whether value is an instance of declared_type, tolerating non-class annotations."""
        try:
            return isinstance(value, declared_type)
        except TypeError:
            # Subscripted generics (e.g. List[int]) cannot be used with isinstance. Accept the value here;
            # it is still passed through translate_inputs_to_literals before the target is invoked.
            return True

    def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
        """
        Run the target locally once per mapped element and gather the results.

        :param ctx: The Flyte context used when translating native inputs to literals
        :param kwargs: Native inputs; every input not listed in bound_inputs must be a non-empty list
        :return: A Promise named "o0" wrapping a literal collection with one entry per element (a none literal
            for tolerated failures), or a VoidPromise when the interface declares no outputs
        :raises ValueError: If the first unbound input is not a non-empty list of the declared element type
        :raises Exception: Re-raises the failure of an element once fewer than the required number of
            successes can still be reached
        """
        target_interface = self.target.python_interface
        outputs_expected = bool(self.python_interface.outputs)

        # The first unbound input determines how many times the target is invoked.
        mapped_entity_count = 0
        for k in self.python_interface.inputs.keys():
            if k in self._bound_inputs:
                continue
            v = kwargs[k]
            if isinstance(v, list) and len(v) > 0 and self._matches_declared_type(v[0], target_interface.inputs[k]):
                mapped_entity_count = len(v)
                break
            raise ValueError(f"Expected a list of {target_interface.inputs[k]} but got {type(v)} instead.")

        failed_count = 0
        min_successes = mapped_entity_count
        if self._min_successes:
            min_successes = self._min_successes
        elif self._min_success_ratio:
            min_successes = math.ceil(min_successes * self._min_success_ratio)

        # The typed interface does not depend on the element index, so derive it once.
        typed_interface = transform_interface_to_typed_interface(target_interface)
        flyte_input_types = {} if typed_interface is None else typed_interface.inputs

        literals = []
        for i in range(mapped_entity_count):
            # Bound inputs are passed whole; mapped inputs contribute their i-th element.
            single_instance_inputs = {
                k: kwargs[k] if k in self._bound_inputs else kwargs[k][i] for k in self.python_interface.inputs.keys()
            }

            # translate Python native inputs to Flyte literals
            literal_map = translate_inputs_to_literals(
                ctx,
                incoming_values=single_instance_inputs,
                flyte_interface_types=flyte_input_types,
                native_types=target_interface.inputs,
            )
            kwargs_literals = {k1: Promise(var=k1, val=v1) for k1, v1 in literal_map.items()}

            try:
                output = self.target(**kwargs_literals)
                if outputs_expected:
                    literals.append(output.val)
            except Exception:
                # Count every failure, even when no outputs are expected, so the success threshold
                # is enforced for output-less targets as well.
                failed_count += 1
                if outputs_expected:
                    literals.append(Literal(scalar=Scalar(none_type=_literal_models.Void())))
                if mapped_entity_count - failed_count < min_successes:
                    logger.error("The number of successful tasks is lower than the minimum")
                    raise

        if outputs_expected:
            return Promise(var="o0", val=Literal(collection=LiteralCollection(literals=literals)))
        return VoidPromise(self.name)

    def local_execution_mode(self):
        """Return the execution-state mode used for local runs (LOCAL_TASK_EXECUTION)."""
        return ExecutionState.Mode.LOCAL_TASK_EXECUTION

    @property
    def min_success_ratio(self) -> Optional[float]:
        """The minimum ratio of successful executions passed at construction."""
        return self._min_success_ratio

    @property
    def min_successes(self) -> Optional[int]:
        """The minimum number of successful executions passed at construction, if any."""
        return self._min_successes

    @property
    def concurrency(self) -> Optional[int]:
        """The concurrency limit passed at construction, if any."""
        return self._concurrency

    @property
    def execution_mode(self) -> _core_workflow.ArrayNode.ExecutionMode:
        """The execution mode passed at construction."""
        return self._execution_mode


def array_node(
    target: Union[LaunchPlan],
    concurrency: Optional[int] = None,
    min_success_ratio: float = 1.0,
    min_successes: Optional[int] = None,
    **kwargs,
):
    """
    Wrap a LaunchPlan in an ArrayNode and return a callable that invokes it.

    :param target: The target Flyte entity to map over; only LaunchPlans are accepted for now
    :param concurrency: If specified, this limits the number of mapped tasks that can run in parallel to the given
        batch size. If the size of the input exceeds the concurrency value, then multiple batches will be run
        serially until all inputs are processed. If set to 0, this means unbounded concurrency. If left unspecified,
        this means the array node will inherit parallelism from the workflow
    :param min_successes: The minimum number of successful executions. If set, this takes precedence over
        min_success_ratio
    :param min_success_ratio: The minimum ratio of successful executions
    :param kwargs: Keyword arguments forwarded on every call of the returned callable; arguments supplied at
        call time override entries with the same name
    :return: A callable function that takes in keyword arguments and returns a Promise created by
        flyte_entity_call_handler
    :raises ValueError: If target is not a LaunchPlan
    """
    if not isinstance(target, LaunchPlan):
        raise ValueError("Only LaunchPlans are supported for now.")

    mapped_node = ArrayNode(
        target=target,
        concurrency=concurrency,
        min_successes=min_successes,
        min_success_ratio=min_success_ratio,
    )

    def callable_entity(**inner_kwargs):
        # Start from the preset arguments, then let call-time arguments take precedence.
        call_arguments = dict(kwargs)
        call_arguments.update(inner_kwargs)
        return flyte_entity_call_handler(mapped_node, **call_arguments)

    return callable_entity
37 changes: 37 additions & 0 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@

from flytekit.configuration import SerializationSettings
from flytekit.core import tracker
from flytekit.core.array_node import array_node
from flytekit.core.base_task import PythonTask, TaskResolverMixin
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.type_engine import TypeEngine, is_annotated
from flytekit.core.utils import timeit
Expand Down Expand Up @@ -347,6 +349,41 @@ def _raw_execute(self, **kwargs) -> Any:


def map_task(
    target: Union[LaunchPlan, PythonFunctionTask],
    concurrency: Optional[int] = None,
    min_successes: Optional[int] = None,
    min_success_ratio: float = 1.0,
    **kwargs,
):
    """
    Create a map task, dispatching on the kind of entity being mapped over.

    LaunchPlan targets go through the ArrayNode implementation (array_node); every other target is
    handed to the existing ArrayNodeMapTask implementation (array_node_map_task).

    :param target: The Flyte entity that will be mapped over
    :param concurrency: If specified, this limits the number of mapped tasks that can run in parallel to the given
        batch size. If the size of the input exceeds the concurrency value, then multiple batches will be run
        serially until all inputs are processed. If set to 0, this means unbounded concurrency. If left unspecified,
        this means the array node will inherit parallelism from the workflow
    :param min_successes: The minimum number of successful executions
    :param min_success_ratio: The minimum ratio of successful executions
    :param kwargs: Additional keyword arguments forwarded unchanged to the selected implementation
    :return: Whatever the selected implementation returns
    """
    if not isinstance(target, LaunchPlan):
        return array_node_map_task(
            task_function=target,
            concurrency=concurrency,
            min_successes=min_successes,
            min_success_ratio=min_success_ratio,
            **kwargs,
        )

    return array_node(
        target=target,
        concurrency=concurrency,
        min_successes=min_successes,
        min_success_ratio=min_success_ratio,
        **kwargs,
    )


def array_node_map_task(
task_function: PythonFunctionTask,
concurrency: Optional[int] = None,
# TODO why no min_successes?
Expand Down
6 changes: 5 additions & 1 deletion flytekit/models/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,9 @@ def from_flyte_idl(cls, pb2_object: _core_workflow.GateNode) -> "GateNode":


class ArrayNode(_common.FlyteIdlEntity):
def __init__(self, node: "Node", parallelism=None, min_successes=None, min_success_ratio=None) -> None:
def __init__(
self, node: "Node", parallelism=None, min_successes=None, min_success_ratio=None, execution_mode=None
) -> None:
"""
TODO: docstring
"""
Expand All @@ -390,6 +392,7 @@ def __init__(self, node: "Node", parallelism=None, min_successes=None, min_succe
# TODO either min_successes or min_success_ratio should be set
self._min_successes = min_successes
self._min_success_ratio = min_success_ratio
self._execution_mode = execution_mode

@property
def node(self) -> "Node":
Expand All @@ -401,6 +404,7 @@ def to_flyte_idl(self) -> _core_workflow.ArrayNode:
parallelism=self._parallelism,
min_successes=self._min_successes,
min_success_ratio=self._min_success_ratio,
execution_mode=self._execution_mode,
)

@classmethod
Expand Down
1 change: 1 addition & 0 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,7 @@ def raw_register(
workflow_model.WorkflowNode,
workflow_model.BranchNode,
workflow_model.TaskNode,
workflow_model.ArrayNode,
),
):
return None
Expand Down
Loading
Loading