Skip to content

Commit

Permalink
Add more unit tests including a negative test
Browse files Browse the repository at this point in the history
Signed-off-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
eapolinario committed Nov 19, 2024
1 parent 9276e96 commit a8bdbca
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 13 deletions.
3 changes: 2 additions & 1 deletion flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,9 @@ def _dispatch_execute(

# Go over each output and create a separate offloaded in case its size is too large
for k, v in outputs.literals.items():
literal_map_copy[k] = v

if not offloading_enabled:
literal_map_copy[k] = v
continue

lit = v.to_flyte_idl()
Expand Down
99 changes: 87 additions & 12 deletions tests/flytekit/unit/bin/test_python_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,15 +501,7 @@ def t1(a: typing.List[int]) -> typing.List[str]:
xs: typing.List[int] = [1, 2, 3]
input_literal_map = _literal_models.LiteralMap(
{
"a": _literal_models.Literal(
collection=_literal_models.LiteralCollection(
literals=[
_literal_models.Literal(
scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=x)),
) for x in xs
]
)
)
"a": TypeEngine.to_literal(ctx, xs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])),
}
)

Expand All @@ -527,7 +519,7 @@ def t1(a: typing.List[int]) -> typing.List[str]:
lit.ParseFromString(f.read())
assert len(lit.literals) == 1
assert "o0" in lit.literals
assert lit.literals["o0"].offloaded_metadata is not None
assert lit.literals["o0"].HasField("offloaded_metadata") == True
assert lit.literals["o0"].offloaded_metadata.size_bytes == 62
assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb")
assert lit.literals["o0"].offloaded_metadata.inferred_type == LiteralType(collection_type=LiteralType(simple=SimpleType.STRING)).to_flyte_idl()
Expand Down Expand Up @@ -592,12 +584,12 @@ def t1(xs: typing.List[int]) -> typing.Tuple[int, typing.List[str]]:
lit.ParseFromString(f.read())
assert len(lit.literals) == 2
assert "o0" in lit.literals
assert lit.literals["o0"].offloaded_metadata is not None
assert lit.literals["o0"].HasField("offloaded_metadata") == True
assert lit.literals["o0"].offloaded_metadata.size_bytes == 6
assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb")
assert lit.literals["o0"].offloaded_metadata.inferred_type == LiteralType(simple=SimpleType.INTEGER).to_flyte_idl()
assert "o1" in lit.literals
assert lit.literals["o1"].offloaded_metadata is not None
assert lit.literals["o1"].HasField("offloaded_metadata") == True
assert lit.literals["o1"].offloaded_metadata.size_bytes == 82
assert lit.literals["o1"].offloaded_metadata.uri.endswith("/o1_offloaded_metadata.pb")
assert lit.literals["o1"].offloaded_metadata.inferred_type == LiteralType(collection_type=LiteralType(simple=SimpleType.STRING)).to_flyte_idl()
Expand Down Expand Up @@ -749,3 +741,86 @@ def t1(n: int) -> typing.Annotated[A, HashMethod(lambda x: str(x.a))]:
assert a.a == 1234
else:
assert False, f"Unexpected file {ff}"


def test_dispatch_execute_offloaded_nested_lists_of_literals(tmp_path_factory):
@task
def t1(a: typing.List[int]) -> typing.List[typing.List[str]]:
return [[f"string is: {x}" for x in a] for _ in range(len(a))]

inputs_path = tmp_path_factory.mktemp("inputs")
outputs_path = tmp_path_factory.mktemp("outputs")

ctx = context_manager.FlyteContext.current_context()
with get_flyte_context(tmp_path_factory, outputs_path) as ctx:
xs: typing.List[int] = [1, 2, 3]
input_literal_map = _literal_models.LiteralMap(
{
"a": TypeEngine.to_literal(ctx, xs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])),
}
)

write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb"))

with mock.patch.dict(os.environ, {"_F_L_MIN_SIZE_MB": "0"}):
_dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute()))

assert "error.pb" not in os.listdir(outputs_path)

for ff in os.listdir(outputs_path):
with open(outputs_path/ff, "rb") as f:
if ff == "outputs.pb":
lit = literals_pb2.LiteralMap()
lit.ParseFromString(f.read())
assert len(lit.literals) == 1
assert "o0" in lit.literals
assert lit.literals["o0"].HasField("offloaded_metadata") == True
assert lit.literals["o0"].offloaded_metadata.size_bytes == 195
assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb")
assert lit.literals["o0"].offloaded_metadata.inferred_type == LiteralType(collection_type=LiteralType(collection_type=LiteralType(simple=SimpleType.STRING))).to_flyte_idl()
elif ff == "o0_offloaded_metadata.pb":
lit = literals_pb2.Literal()
lit.ParseFromString(f.read())
expected_output = [[f"string is: {x}" for x in xs] for _ in range(len(xs))]
assert lit == TypeEngine.to_literal(ctx, expected_output, typing.List[typing.List[str]], TypeEngine.to_literal_type(typing.List[typing.List[str]])).to_flyte_idl()
else:
assert False, f"Unexpected file {ff}"


def test_dispatch_execute_offloaded_nested_lists_of_literals_offloading_disabled(tmp_path_factory):
@task
def t1(a: typing.List[int]) -> typing.List[typing.List[str]]:
return [[f"string is: {x}" for x in a] for _ in range(len(a))]

inputs_path = tmp_path_factory.mktemp("inputs")
outputs_path = tmp_path_factory.mktemp("outputs")

ctx = context_manager.FlyteContext.current_context()
with get_flyte_context(tmp_path_factory, outputs_path) as ctx:
xs: typing.List[int] = [1, 2, 3]
input_literal_map = _literal_models.LiteralMap(
{
"a": TypeEngine.to_literal(ctx, xs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])),
}
)

write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb"))

# Ensure that this is not set by an external source
assert os.environ.get("_F_L_MIN_SIZE_MB") is None

# Notice how we're setting the env var to None, which disables offloading completely
_dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute()))

assert "error.pb" not in os.listdir(outputs_path)

for ff in os.listdir(outputs_path):
with open(outputs_path/ff, "rb") as f:
if ff == "outputs.pb":
lit = literals_pb2.LiteralMap()
lit.ParseFromString(f.read())
assert len(lit.literals) == 1
assert "o0" in lit.literals
assert lit.literals["o0"].HasField("offloaded_metadata") == False
else:
assert False, f"Unexpected file {ff}"

0 comments on commit a8bdbca

Please sign in to comment.