diff --git a/flytekit/core/array_node.py b/flytekit/core/array_node.py index 0cb2c8d25c..466058a791 100644 --- a/flytekit/core/array_node.py +++ b/flytekit/core/array_node.py @@ -19,6 +19,7 @@ flyte_entity_call_handler, translate_inputs_to_literals, ) +from flytekit.core.task import ReferenceTask from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models @@ -34,8 +35,7 @@ class ArrayNode: def __init__( self, - target: Union[LaunchPlan, "FlyteLaunchPlan"], - execution_mode: _core_workflow.ArrayNode.ExecutionMode = _core_workflow.ArrayNode.FULL_STATE, + target: Union[LaunchPlan, ReferenceTask, "FlyteLaunchPlan"], bindings: Optional[List[_literal_models.Binding]] = None, concurrency: Optional[int] = None, min_successes: Optional[int] = None, @@ -51,17 +51,17 @@ def __init__( :param min_successes: The minimum number of successful executions. If set, this takes precedence over min_success_ratio :param min_success_ratio: The minimum ratio of successful executions. - :param execution_mode: The execution mode for propeller to use when handling ArrayNode :param metadata: The metadata for the underlying node """ from flytekit.remote import FlyteLaunchPlan self.target = target self._concurrency = concurrency - self._execution_mode = execution_mode self.id = target.name self._bindings = bindings or [] self.metadata = metadata + self._data_mode = None + self._execution_mode = None if min_successes is not None: self._min_successes = min_successes @@ -92,9 +92,12 @@ def __init__( else: raise ValueError("No interface found for the target entity.") - if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan): - if self._execution_mode != _core_workflow.ArrayNode.FULL_STATE: - raise ValueError("Only execution version 1 is supported for LaunchPlans.") + if isinstance(target, (LaunchPlan, FlyteLaunchPlan)): + self._data_mode = _core_workflow.ArrayNode.SINGLE_INPUT_FILE + self._execution_mode = _core_workflow.ArrayNode.FULL_STATE + elif isinstance(target, ReferenceTask): + self._data_mode = _core_workflow.ArrayNode.INDIVIDUAL_INPUT_FILES + self._execution_mode = _core_workflow.ArrayNode.MINIMAL_STATE else: raise ValueError(f"Only LaunchPlans are supported for now, but got {type(target)}") @@ -133,6 +136,10 @@ def upstream_nodes(self) -> List[Node]: def flyte_entity(self) -> Any: return self.target + @property + def data_mode(self) -> _core_workflow.ArrayNode.DataMode: + return self._data_mode + def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]: if self._remote_interface: raise ValueError("Mapping over remote entities is not supported in local execution.") @@ -254,7 +261,7 @@ def __call__(self, *args, **kwargs): def array_node( - target: Union[LaunchPlan, "FlyteLaunchPlan"], + target: Union[LaunchPlan, ReferenceTask, "FlyteLaunchPlan"], concurrency: Optional[int] = None, min_success_ratio: Optional[float] = None, min_successes: Optional[int] = None, @@ -275,8 +282,8 @@ def array_node( """ from flytekit.remote import FlyteLaunchPlan - if not isinstance(target, LaunchPlan) and not isinstance(target, FlyteLaunchPlan): - raise ValueError("Only LaunchPlans are supported for now.") + if not isinstance(target, (LaunchPlan, FlyteLaunchPlan, ReferenceTask)): + raise ValueError("Only LaunchPlans and ReferenceTasks are supported for now.") node = ArrayNode( target=target, diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index 44458a53d2..78b9611651 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -18,6 +18,7 @@ from flytekit.core.interface import transform_interface_to_list_interface from flytekit.core.launch_plan import LaunchPlan from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask +from flytekit.core.task import ReferenceTask from flytekit.core.type_engine import TypeEngine from flytekit.core.utils import timeit from flytekit.loggers import logger @@ -390,7 +391,7 @@ def map_task( """ from flytekit.remote import FlyteLaunchPlan - if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan): + if isinstance(target, (LaunchPlan, FlyteLaunchPlan, ReferenceTask)): return array_node( target=target, concurrency=concurrency, diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index 8d8bf9c9ef..f3fed3d4f3 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -390,6 +390,7 @@ def __init__( min_success_ratio=None, execution_mode=None, is_original_sub_node_interface=False, + data_mode=None, ) -> None: """ TODO: docstring @@ -401,6 +402,7 @@ def __init__( self._min_success_ratio = min_success_ratio self._execution_mode = execution_mode self._is_original_sub_node_interface = is_original_sub_node_interface + self._data_mode = data_mode @property def node(self) -> "Node": @@ -414,6 +416,7 @@ def to_flyte_idl(self) -> _core_workflow.ArrayNode: min_success_ratio=self._min_success_ratio, execution_mode=self._execution_mode, is_original_sub_node_interface=BoolValue(value=self._is_original_sub_node_interface), + data_mode=self._data_mode, ) @classmethod diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 5c7a6d5eb4..ee905a4218 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -601,6 +601,7 @@ def get_serializable_array_node( min_success_ratio=array_node.min_success_ratio, execution_mode=array_node.execution_mode, is_original_sub_node_interface=array_node.is_original_sub_node_interface, + data_mode=array_node.data_mode, ) diff --git a/pyproject.toml b/pyproject.toml index 58c107cdc3..3dc782c507 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "diskcache>=5.2.1", "docker>=4.0.0", "docstring-parser>=0.9.0", - "flyteidl>=1.13.9", + "flyteidl>=1.14.1", "fsspec>=2023.3.0", "gcsfs>=2023.3.0", "googleapis-common-protos>=1.57",