Skip to content

Commit

Permalink
test: Add tests for positional arguments
Browse files Browse the repository at this point in the history
Resolves: flyteorg/flyte#5320
Signed-off-by: Chi-Sheng Liu <[email protected]>
  • Loading branch information
MortalHappiness committed Jun 21, 2024
1 parent 7fe42ac commit 3fd412a
Showing 1 changed file with 123 additions and 0 deletions.
123 changes: 123 additions & 0 deletions tests/flytekit/unit/core/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,3 +943,126 @@ def wf_with_input() -> typing.Optional[typing.List[int]]:
)

assert wf_with_input() == input_val

def test_positional_args_task():
arg1 = 5
arg2 = 6
ret = 17

@task
def t1(x: int, y: int) -> int:
return x + y * 2

@workflow
def wf_pure_positional_args() -> int:
return t1(arg1, arg2)

@workflow
def wf_mixed_positional_and_keyword_args() -> int:
return t1(arg1, y=arg2)

wf_pure_positional_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_pure_positional_args)
wf_mixed_positional_and_keyword_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_mixed_positional_and_keyword_args)

arg1_binding = Scalar(primitive=Primitive(integer=arg1))
arg2_binding = Scalar(primitive=Primitive(integer=arg2))
output_type = LiteralType(simple=SimpleType.INTEGER)

assert wf_pure_positional_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
assert wf_pure_positional_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding
assert wf_pure_positional_args_spec.template.interface.outputs["o0"].type == output_type


assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding
assert wf_mixed_positional_and_keyword_args_spec.template.interface.outputs["o0"].type == output_type

assert wf_pure_positional_args() == ret
assert wf_mixed_positional_and_keyword_args() == ret

def test_positional_args_workflow():
arg1 = 5
arg2 = 6
ret = 17

@task
def t1(x: int, y: int) -> int:
return x + y * 2

@workflow
def sub_wf(x: int, y: int) -> int:
return t1(x=x, y=y)

@workflow
def wf_pure_positional_args() -> int:
return sub_wf(arg1, arg2)

@workflow
def wf_mixed_positional_and_keyword_args() -> int:
return sub_wf(arg1, y=arg2)

wf_pure_positional_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_pure_positional_args)
wf_mixed_positional_and_keyword_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_mixed_positional_and_keyword_args)

arg1_binding = Scalar(primitive=Primitive(integer=arg1))
arg2_binding = Scalar(primitive=Primitive(integer=arg2))
output_type = LiteralType(simple=SimpleType.INTEGER)

assert wf_pure_positional_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
assert wf_pure_positional_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding
assert wf_pure_positional_args_spec.template.interface.outputs["o0"].type == output_type

assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding
assert wf_mixed_positional_and_keyword_args_spec.template.interface.outputs["o0"].type == output_type

assert wf_pure_positional_args() == ret
assert wf_mixed_positional_and_keyword_args() == ret

def test_positional_args_chained_tasks():
@task
def t1(x: int, y: int) -> int:
return x + y * 2

@workflow
def wf() -> int:
x = t1(2, y = 3)
y = t1(3, 4)
return t1(x, y = y)

assert wf() == 30

def test_positional_args_task_inputs_from_workflow_args():
@task
def t1(x: int, y: int, z: int) -> int:
return x + y * 2 + z * 3

@workflow
def wf(x: int, y: int) -> int:
return t1(x, y=y, z=3)

assert wf(1, 2) == 14

def test_unexpected_kwargs_task_raises_error():
@task
def t1(a: int) -> int:
return a

with pytest.raises(AssertionError, match="Received unexpected keyword argument"):
t1(b=6)

def test_too_many_positional_args_task_raises_error():
@task
def t1(a: int) -> int:
return a

with pytest.raises(AssertionError, match="Received more arguments than expected"):
t1(1, 2)

def test_both_positional_and_keyword_args_task_raises_error():
@task
def t1(a: int) -> int:
return a

with pytest.raises(AssertionError, match="Got multiple values for argument"):
t1(1, a=2)

0 comments on commit 3fd412a

Please sign in to comment.