diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index 2fcf8bbd94..725e3e14fc 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -943,3 +943,126 @@ def wf_with_input() -> typing.Optional[typing.List[int]]: ) assert wf_with_input() == input_val + +def test_positional_args_task(): + arg1 = 5 + arg2 = 6 + ret = 17 + + @task + def t1(x: int, y: int) -> int: + return x + y * 2 + + @workflow + def wf_pure_positional_args() -> int: + return t1(arg1, arg2) + + @workflow + def wf_mixed_positional_and_keyword_args() -> int: + return t1(arg1, y=arg2) + + wf_pure_positional_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_pure_positional_args) + wf_mixed_positional_and_keyword_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_mixed_positional_and_keyword_args) + + arg1_binding = Scalar(primitive=Primitive(integer=arg1)) + arg2_binding = Scalar(primitive=Primitive(integer=arg2)) + output_type = LiteralType(simple=SimpleType.INTEGER) + + assert wf_pure_positional_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding + assert wf_pure_positional_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding + assert wf_pure_positional_args_spec.template.interface.outputs["o0"].type == output_type + + + assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding + assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding + assert wf_mixed_positional_and_keyword_args_spec.template.interface.outputs["o0"].type == output_type + + assert wf_pure_positional_args() == ret + assert wf_mixed_positional_and_keyword_args() == ret + +def test_positional_args_workflow(): + arg1 = 5 + arg2 = 6 + ret = 17 + + @task + def t1(x: int, y: int) -> int: + return x + y * 2 + + @workflow + def sub_wf(x: int, y: int) -> int: + return t1(x=x, y=y) + + @workflow + def wf_pure_positional_args() -> int: + return sub_wf(arg1, arg2) + + @workflow + def wf_mixed_positional_and_keyword_args() -> int: + return sub_wf(arg1, y=arg2) + + wf_pure_positional_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_pure_positional_args) + wf_mixed_positional_and_keyword_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_mixed_positional_and_keyword_args) + + arg1_binding = Scalar(primitive=Primitive(integer=arg1)) + arg2_binding = Scalar(primitive=Primitive(integer=arg2)) + output_type = LiteralType(simple=SimpleType.INTEGER) + + assert wf_pure_positional_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding + assert wf_pure_positional_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding + assert wf_pure_positional_args_spec.template.interface.outputs["o0"].type == output_type + + assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding + assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding + assert wf_mixed_positional_and_keyword_args_spec.template.interface.outputs["o0"].type == output_type + + assert wf_pure_positional_args() == ret + assert wf_mixed_positional_and_keyword_args() == ret + +def test_positional_args_chained_tasks(): + @task + def t1(x: int, y: int) -> int: + return x + y * 2 + + @workflow + def wf() -> int: + x = t1(2, y = 3) + y = t1(3, 4) + return t1(x, y = y) + + assert wf() == 30 + +def test_positional_args_task_inputs_from_workflow_args(): + @task + def t1(x: int, y: int, z: int) -> int: + return x + y * 2 + z * 3 + + @workflow + def wf(x: int, y: int) -> int: + return t1(x, y=y, z=3) + + assert wf(1, 2) == 14 + +def test_unexpected_kwargs_task_raises_error(): + @task + def t1(a: int) -> int: + return a + + with pytest.raises(AssertionError, match="Received unexpected keyword argument"): + t1(b=6) + +def test_too_many_positional_args_task_raises_error(): + @task + def t1(a: int) -> int: + return a + + with pytest.raises(AssertionError, match="Received more arguments than expected"): + t1(1, 2) + +def test_both_positional_and_keyword_args_task_raises_error(): + @task + def t1(a: int) -> int: + return a + + with pytest.raises(AssertionError, match="Got multiple values for argument"): + t1(1, a=2)