Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for toggling data mode for array node #2940

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions flytekit/core/array_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
self.id = target.name
self._bindings = bindings or []
self.metadata = metadata
self._data_mode = None

if min_successes is not None:
self._min_successes = min_successes
Expand Down Expand Up @@ -93,10 +94,12 @@ def __init__(
raise ValueError("No interface found for the target entity.")

if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan):
self._data_mode = _core_workflow.ArrayNode.SINGLE_INPUT_FILE
if self._execution_mode != _core_workflow.ArrayNode.FULL_STATE:
raise ValueError("Only execution version 1 is supported for LaunchPlans.")
else:
raise ValueError(f"Only LaunchPlans are supported for now, but got {type(target)}")
self._data_mode = _core_workflow.ArrayNode.INDIVIDUAL_INPUT_FILES
# raise ValueError(f"Only LaunchPlans are supported for now, but got {type(target)}")

def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
# Part of SupportsNodeCreation interface
Expand Down Expand Up @@ -133,6 +136,10 @@ def upstream_nodes(self) -> List[Node]:
def flyte_entity(self) -> Any:
return self.target

@property
def data_mode(self) -> _core_workflow.ArrayNode.DataMode:
return self._data_mode

def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
if self._remote_interface:
raise ValueError("Mapping over remote entities is not supported in local execution.")
Expand Down Expand Up @@ -269,10 +276,9 @@ def array_node(
:return: A callable function that takes in keyword arguments and returns a Promise created by
flyte_entity_call_handler
"""
from flytekit.remote import FlyteLaunchPlan

if not isinstance(target, LaunchPlan) and not isinstance(target, FlyteLaunchPlan):
raise ValueError("Only LaunchPlans are supported for now.")
# if not isinstance(target, LaunchPlan) and not isinstance(target, FlyteLaunchPlan):
# raise ValueError("Only LaunchPlans are supported for now.")

node = ArrayNode(
target=target,
Expand Down
21 changes: 10 additions & 11 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,22 +367,21 @@ def map_task(
:param min_successes: The minimum number of successful executions
:param min_success_ratio: The minimum ratio of successful executions
"""
from flytekit.remote import FlyteLaunchPlan

if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan):
return array_node(
target=target,
concurrency=concurrency,
min_successes=min_successes,
min_success_ratio=min_success_ratio,
)
return array_node_map_task(
task_function=target,
# if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan):
return array_node(
target=target,
concurrency=concurrency,
min_successes=min_successes,
min_success_ratio=min_success_ratio,
**kwargs,
)
# return array_node_map_task(
# task_function=target,
# concurrency=concurrency,
# min_successes=min_successes,
# min_success_ratio=min_success_ratio,
# **kwargs,
# )


def array_node_map_task(
Expand Down
10 changes: 9 additions & 1 deletion flytekit/models/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,13 @@ def from_flyte_idl(cls, pb2_object: _core_workflow.GateNode) -> "GateNode":

class ArrayNode(_common.FlyteIdlEntity):
def __init__(
self, node: "Node", parallelism=None, min_successes=None, min_success_ratio=None, execution_mode=None
self,
node: "Node",
parallelism=None,
min_successes=None,
min_success_ratio=None,
execution_mode=None,
data_mode=None,
) -> None:
"""
TODO: docstring
Expand All @@ -393,6 +399,7 @@ def __init__(
self._min_successes = min_successes
self._min_success_ratio = min_success_ratio
self._execution_mode = execution_mode
self._data_mode = data_mode

@property
def node(self) -> "Node":
Expand All @@ -405,6 +412,7 @@ def to_flyte_idl(self) -> _core_workflow.ArrayNode:
min_successes=self._min_successes,
min_success_ratio=self._min_success_ratio,
execution_mode=self._execution_mode,
data_mode=self._data_mode,
)

@classmethod
Expand Down
1 change: 1 addition & 0 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,7 @@ def get_serializable_array_node(
min_successes=array_node.min_successes,
min_success_ratio=array_node.min_success_ratio,
execution_mode=array_node.execution_mode,
data_mode=array_node.data_mode,
)


Expand Down
Loading