diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index e4fef9ed10..5d10cadd45 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -77,8 +77,9 @@ def __init__( f = actual_task.lhs else: _, mod, f, _ = tracker.extract_task_module(cast(PythonFunctionTask, actual_task).task_function) + sorted_bounded_inputs = ",".join(sorted(self._bound_inputs)) h = hashlib.md5( - f"{collection_interface.__str__()}{concurrency}{min_successes}{min_success_ratio}".encode("utf-8") + f"{sorted_bounded_inputs}{concurrency}{min_successes}{min_success_ratio}".encode("utf-8") ).hexdigest() self._name = f"{mod}.map_{f}_{h}-arraynode" @@ -387,7 +388,7 @@ def load_task(self, loader_args: List[str], max_concurrency: int = 0) -> ArrayNo def loader_args(self, settings: SerializationSettings, t: ArrayNodeMapTask) -> List[str]: # type:ignore return [ "vars", - f'{",".join(t.bound_inputs)}', + f'{",".join(sorted(t.bound_inputs))}', "resolver", t.python_function_task.task_resolver.location, *t.python_function_task.task_resolver.loader_args(settings, t.python_function_task), diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 1201a3ede0..aac31a1ee9 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -92,7 +92,8 @@ def __init__( f = actual_task.lhs else: _, mod, f, _ = tracker.extract_task_module(typing.cast(PythonFunctionTask, actual_task).task_function) - h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest() + sorted_bounded_inputs = ",".join(sorted(self._bound_inputs)) + h = hashlib.md5(sorted_bounded_inputs.encode("utf-8")).hexdigest() name = f"{mod}.map_{f}_{h}" self._cmd_prefix: typing.Optional[typing.List[str]] = None @@ -404,7 +405,7 @@ def load_task(self, loader_args: List[str], max_concurrency: int = 0) -> MapPyth def loader_args(self, settings: SerializationSettings, t: MapPythonTask) -> List[str]: # type:ignore return [ "vars", - f'{",".join(t.bound_inputs)}', + f'{",".join(sorted(t.bound_inputs))}', "resolver", t.run_task.task_resolver.location, *t.run_task.task_resolver.loader_args(settings, t.run_task), diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index aeb727f5f8..40bb864c4f 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -7,7 +7,7 @@ from flytekit import task, workflow from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings -from flytekit.core.array_node_map_task import ArrayNodeMapTask +from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver from flytekit.core.task import TaskMetadata from flytekit.experimental import map_task as array_node_map_task from flytekit.tools.translator import get_serializable @@ -187,7 +187,7 @@ def many_inputs(a: int, b: str, c: float) -> str: assert m.python_interface.inputs == {"a": List[int], "b": List[str], "c": List[float]} assert ( m.name - == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_4ee240ef5cf979dbc133fb30035cb874-arraynode" + == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_bf51001578d0ae197a52c0af0a99dd89-arraynode" ) r_m = ArrayNodeMapTask(many_inputs) assert str(r_m.python_interface) == str(m.python_interface) @@ -197,7 +197,7 @@ def many_inputs(a: int, b: str, c: float) -> str: assert m.python_interface.inputs == {"a": List[int], "b": List[str], "c": float} assert ( m.name - == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_352fcdea8523a83134b51bbf5793f14e-arraynode" + == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_cb470e880fabd6265ec80e29fe60250d-arraynode" ) r_m = ArrayNodeMapTask(many_inputs, bound_inputs=set("c")) assert str(r_m.python_interface) == str(m.python_interface) @@ -207,7 +207,7 @@ def many_inputs(a: int, b: str, c: float) -> str: assert m.python_interface.inputs == {"a": List[int], "b": str, "c": float} assert ( m.name - == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_e224ba3a5b00e08083d541a6ca99b179-arraynode" + == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_316e10eb97f5d2abd585951048b807b9-arraynode" ) r_m = ArrayNodeMapTask(many_inputs, bound_inputs={"c", "b"}) assert str(r_m.python_interface) == str(m.python_interface) @@ -217,7 +217,7 @@ def many_inputs(a: int, b: str, c: float) -> str: assert m.python_interface.inputs == {"a": int, "b": str, "c": float} assert ( m.name - == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_f080e60be9d6faedeef0c74834d6812a-arraynode" + == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_758022acd59ad1c8b81670378d4de4f6-arraynode" ) r_m = ArrayNodeMapTask(many_inputs, bound_inputs={"a", "c", "b"}) assert str(r_m.python_interface) == str(m.python_interface) @@ -257,6 +257,18 @@ def task3(c: str, a: int, b: float) -> str: assert m1 == m2 == m3 == ["1 - 0.1 - c", "2 - 0.2 - c", "3 - 0.3 - c"] +def test_bounded_inputs_vars_order(serialization_settings): + @task() + def task1(a: int, b: float, c: str) -> str: + return f"{a} - {b} - {c}" + + mt = array_node_map_task(functools.partial(task1, c=1.0, b="hello", a=1)) + mtr = ArrayNodeMapTaskResolver() + args = mtr.loader_args(serialization_settings, mt) + + assert args[1] == "a,b,c" + + @pytest.mark.parametrize( "min_success_ratio, should_raise_error", [ diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py index 26d1a71c3c..c87d4c6b1f 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -192,31 +192,36 @@ def many_inputs(a: int, b: str, c: float) -> str: m = map_task(many_inputs) assert m.python_interface.inputs == {"a": typing.List[int], "b": typing.List[str], "c": typing.List[float]} - assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_24c08b3a2f9c2e389ad9fc6a03482cf9" + assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_d41d8cd98f00b204e9800998ecf8427e" r_m = MapPythonTask(many_inputs) assert str(r_m.python_interface) == str(m.python_interface) p1 = functools.partial(many_inputs, c=1.0) m = map_task(p1) assert m.python_interface.inputs == {"a": typing.List[int], "b": typing.List[str], "c": float} - assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_697aa7389996041183cf6cfd102be4f7" + assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_4a8a08f09d37b73795649038408b5f33" r_m = MapPythonTask(many_inputs, bound_inputs=set("c")) assert str(r_m.python_interface) == str(m.python_interface) p2 = functools.partial(p1, b="hello") m = map_task(p2) assert m.python_interface.inputs == {"a": typing.List[int], "b": str, "c": float} - assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_cc18607da7494024a402a5fa4b3ea5c6" + assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_74aefa13d6ab8e4bfbd241583749dfe8" r_m = MapPythonTask(many_inputs, bound_inputs={"c", "b"}) assert str(r_m.python_interface) == str(m.python_interface) p3 = functools.partial(p2, a=1) m = map_task(p3) assert m.python_interface.inputs == {"a": int, "b": str, "c": float} - assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_52fe80b04781ea77ef6f025f4b49abef" + assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_a44c56c8177e32d3613988f4dba7962e" r_m = MapPythonTask(many_inputs, bound_inputs={"a", "c", "b"}) assert str(r_m.python_interface) == str(m.python_interface) + p3_1 = functools.partial(p2, a=1) + m_1 = map_task(p3_1) + assert m_1.python_interface.inputs == {"a": int, "b": str, "c": float} + assert m_1.name == m.name + with pytest.raises(TypeError): m(a=[1, 2, 3]) @@ -348,3 +353,11 @@ def wf(x: typing.List[int]): map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image") assert wf.nodes[0].flyte_entity.run_task.container_image == "random:image" + + +def test_bounded_inputs_vars_order(serialization_settings): + mt = map_task(functools.partial(t3, c=1.0, b="hello", a=1)) + mtr = MapTaskResolver() + args = mtr.loader_args(serialization_settings, mt) + + assert args[1] == "a,b,c"