diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index 5dc0facf36..6eeec468ad 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -268,13 +268,15 @@ def _raw_execute(self, **kwargs) -> Any: outputs_expected = False outputs = [] - any_input_key = ( - list(self.python_function_task.interface.inputs.keys())[0] - if self.python_function_task.interface.inputs.items() is not None - else None - ) + mapped_input_value_len = 0 + if self._run_task.interface.inputs.items(): + for k in self._run_task.interface.inputs.keys(): + v = kwargs[k] + if isinstance(v, list) and k not in self.bound_inputs: + mapped_input_value_len = len(v) + break - for i in range(len(kwargs[any_input_key])): + for i in range(mapped_input_value_len): single_instance_inputs = {} for k in self.interface.inputs.keys(): v = kwargs[k] diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index e47b731ac6..6462488639 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -263,13 +263,15 @@ def _raw_execute(self, **kwargs) -> Any: outputs_expected = False outputs = [] - any_input_key = ( - list(self._run_task.interface.inputs.keys())[0] - if self._run_task.interface.inputs.items() is not None - else None - ) + mapped_input_value_len = 0 + if self._run_task.interface.inputs.items(): + for k in self._run_task.interface.inputs.keys(): + v = kwargs[k] + if isinstance(v, list) and k not in self.bound_inputs: + mapped_input_value_len = len(v) + break - for i in range(len(kwargs[any_input_key])): + for i in range(mapped_input_value_len): single_instance_inputs = {} for k in self.interface.inputs.keys(): v = kwargs[k] diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index 2de15667d5..38f02c0c11 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -230,3 +230,27 @@ def many_outputs(a: int) -> (int, str): with pytest.raises(ValueError): _ = array_node_map_task(many_outputs) + + +def test_parameter_order(): + @task() + def task1(a: int, b: float, c: str) -> str: + return f"{a} - {b} - {c}" + + @task() + def task2(b: float, c: str, a: int) -> str: + return f"{a} - {b} - {c}" + + @task() + def task3(c: str, a: int, b: float) -> str: + return f"{a} - {b} - {c}" + + param_a = [1, 2, 3] + param_b = [0.1, 0.2, 0.3] + param_c = "c" + + m1 = array_node_map_task(functools.partial(task1, c=param_c))(a=param_a, b=param_b) + m2 = array_node_map_task(functools.partial(task2, c=param_c))(a=param_a, b=param_b) + m3 = array_node_map_task(functools.partial(task3, c=param_c))(a=param_a, b=param_b) + + assert m1 == m2 == m3 == ["1 - 0.1 - c", "2 - 0.2 - c", "3 - 0.3 - c"] diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py index f66ab7bd49..82b530443f 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -282,3 +282,27 @@ def my_wf1() -> typing.List[type_t]: return map_task(some_task1, min_success_ratio=min_success_ratio)(inputs=[1, 2, 3, 4]) my_wf1() + + +def test_map_task_parameter_order(): + @task() + def task1(a: int, b: float, c: str) -> str: + return f"{a} - {b} - {c}" + + @task() + def task2(b: float, c: str, a: int) -> str: + return f"{a} - {b} - {c}" + + @task() + def task3(c: str, a: int, b: float) -> str: + return f"{a} - {b} - {c}" + + param_a = [1, 2, 3] + param_b = [0.1, 0.2, 0.3] + param_c = "c" + + m1 = map_task(functools.partial(task1, c=param_c))(a=param_a, b=param_b) + m2 = map_task(functools.partial(task2, c=param_c))(a=param_a, b=param_b) + m3 = map_task(functools.partial(task3, c=param_c))(a=param_a, b=param_b) + + assert m1 == m2 == m3 == ["1 - 0.1 - c", "2 - 0.2 - c", "3 - 0.3 - c"]