From 824353d784f16ce6c4330e91fb9fada15191c754 Mon Sep 17 00:00:00 2001 From: Paul Dittamo <37558497+pvditt@users.noreply.github.com> Date: Thu, 18 Apr 2024 14:17:37 -0700 Subject: [PATCH] Feature/array node workflow parallelism (#2268) * default array node concurrency to -1 Signed-off-by: Paul Dittamo * typo Signed-off-by: Paul Dittamo * set default concurrency to None for backwards compatibility Signed-off-by: Paul Dittamo * update unit test - hashed name Signed-off-by: Paul Dittamo --------- Signed-off-by: Paul Dittamo --- flytekit/core/array_node_map_task.py | 5 +++-- tests/flytekit/unit/core/test_array_node_map_task.py | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index fc35dfa62f..94cba1426c 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -314,7 +314,7 @@ def _raw_execute(self, **kwargs) -> Any: def map_task( task_function: PythonFunctionTask, - concurrency: int = 0, + concurrency: Optional[int] = None, # TODO why no min_successes? min_success_ratio: float = 1.0, **kwargs, @@ -328,7 +328,8 @@ def map_task( :param task_function: This argument is implicitly passed and represents the repeatable function :param concurrency: If specified, this limits the number of mapped tasks than can run in parallel to the given batch size. If the size of the input exceeds the concurrency value, then multiple batches will be run serially until - all inputs are processed. If left unspecified, this means unbounded concurrency. + all inputs are processed. If set to 0, this means unbounded concurrency. If left unspecified, this means the + array node will inherit parallelism from the workflow :param min_success_ratio: If specified, this determines the minimum fraction of total jobs which can complete successfully before terminating this task and marking it successful. """ diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index 8b078cdf2a..5c84c60984 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -184,7 +184,7 @@ def many_inputs(a: int, b: str, c: float) -> str: assert m.python_interface.inputs == {"a": List[int], "b": List[str], "c": List[float]} assert ( m.name - == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_bf51001578d0ae197a52c0af0a99dd89-arraynode" + == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_6b3bd0353da5de6e84d7982921ead2b3-arraynode" ) r_m = ArrayNodeMapTask(many_inputs) assert str(r_m.python_interface) == str(m.python_interface) @@ -194,7 +194,7 @@ def many_inputs(a: int, b: str, c: float) -> str: assert m.python_interface.inputs == {"a": List[int], "b": List[str], "c": float} assert ( m.name - == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_cb470e880fabd6265ec80e29fe60250d-arraynode" + == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_7df6892fe8ce5343c76197a0b6127e80-arraynode" ) r_m = ArrayNodeMapTask(many_inputs, bound_inputs=set("c")) assert str(r_m.python_interface) == str(m.python_interface) @@ -204,7 +204,7 @@ def many_inputs(a: int, b: str, c: float) -> str: assert m.python_interface.inputs == {"a": List[int], "b": str, "c": float} assert ( m.name - == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_316e10eb97f5d2abd585951048b807b9-arraynode" + == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_80fd21f14571026755b99d6b1c045089-arraynode" ) r_m = ArrayNodeMapTask(many_inputs, bound_inputs={"c", "b"}) assert str(r_m.python_interface) == str(m.python_interface) @@ -214,7 +214,7 @@ def many_inputs(a: int, b: str, c: float) -> str: assert m.python_interface.inputs == {"a": int, "b": str, "c": float} assert ( m.name - == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_758022acd59ad1c8b81670378d4de4f6-arraynode" + == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_5d2500dc176052a030efda3b8c283f96-arraynode" ) r_m = ArrayNodeMapTask(many_inputs, bound_inputs={"a", "c", "b"}) assert str(r_m.python_interface) == str(m.python_interface)