diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 0970e2b6ab..90e54ad3b6 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -150,8 +150,9 @@ def _dispatch_execute( # Go over each output and create a separate offloaded in case its size is too large for k, v in outputs.literals.items(): + literal_map_copy[k] = v + if not offloading_enabled: - literal_map_copy[k] = v continue lit = v.to_flyte_idl() diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 7e58799caf..b62bfd8e34 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -501,15 +501,7 @@ def t1(a: typing.List[int]) -> typing.List[str]: xs: typing.List[int] = [1, 2, 3] input_literal_map = _literal_models.LiteralMap( { - "a": _literal_models.Literal( - collection=_literal_models.LiteralCollection( - literals=[ - _literal_models.Literal( - scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=x)), - ) for x in xs - ] - ) - ) + "a": TypeEngine.to_literal(ctx, xs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])), } ) @@ -527,7 +519,7 @@ def t1(a: typing.List[int]) -> typing.List[str]: lit.ParseFromString(f.read()) assert len(lit.literals) == 1 assert "o0" in lit.literals - assert lit.literals["o0"].offloaded_metadata is not None + assert lit.literals["o0"].HasField("offloaded_metadata") == True assert lit.literals["o0"].offloaded_metadata.size_bytes == 62 assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") assert lit.literals["o0"].offloaded_metadata.inferred_type == LiteralType(collection_type=LiteralType(simple=SimpleType.STRING)).to_flyte_idl() @@ -592,12 +584,12 @@ def t1(xs: typing.List[int]) -> typing.Tuple[int, typing.List[str]]: lit.ParseFromString(f.read()) assert len(lit.literals) == 2 assert "o0" in lit.literals - assert lit.literals["o0"].offloaded_metadata is not None + assert lit.literals["o0"].HasField("offloaded_metadata") == True assert lit.literals["o0"].offloaded_metadata.size_bytes == 6 assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") assert lit.literals["o0"].offloaded_metadata.inferred_type == LiteralType(simple=SimpleType.INTEGER).to_flyte_idl() assert "o1" in lit.literals - assert lit.literals["o1"].offloaded_metadata is not None + assert lit.literals["o1"].HasField("offloaded_metadata") == True assert lit.literals["o1"].offloaded_metadata.size_bytes == 82 assert lit.literals["o1"].offloaded_metadata.uri.endswith("/o1_offloaded_metadata.pb") assert lit.literals["o1"].offloaded_metadata.inferred_type == LiteralType(collection_type=LiteralType(simple=SimpleType.STRING)).to_flyte_idl() @@ -749,3 +741,86 @@ def t1(n: int) -> typing.Annotated[A, HashMethod(lambda x: str(x.a))]: assert a.a == 1234 else: assert False, f"Unexpected file {ff}" + + +def test_dispatch_execute_offloaded_nested_lists_of_literals(tmp_path_factory): + @task + def t1(a: typing.List[int]) -> typing.List[typing.List[str]]: + return [[f"string is: {x}" for x in a] for _ in range(len(a))] + + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + ctx = context_manager.FlyteContext.current_context() + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: + xs: typing.List[int] = [1, 2, 3] + input_literal_map = _literal_models.LiteralMap( + { + "a": TypeEngine.to_literal(ctx, xs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])), + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + with mock.patch.dict(os.environ, {"_F_L_MIN_SIZE_MB": "0"}): + _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + + assert "error.pb" not in os.listdir(outputs_path) + + for ff in os.listdir(outputs_path): + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 1 + assert "o0" in lit.literals + assert lit.literals["o0"].HasField("offloaded_metadata") == True + assert lit.literals["o0"].offloaded_metadata.size_bytes == 195 + assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") + assert lit.literals["o0"].offloaded_metadata.inferred_type == LiteralType(collection_type=LiteralType(collection_type=LiteralType(simple=SimpleType.STRING))).to_flyte_idl() + elif ff == "o0_offloaded_metadata.pb": + lit = literals_pb2.Literal() + lit.ParseFromString(f.read()) + expected_output = [[f"string is: {x}" for x in xs] for _ in range(len(xs))] + assert lit == TypeEngine.to_literal(ctx, expected_output, typing.List[typing.List[str]], TypeEngine.to_literal_type(typing.List[typing.List[str]])).to_flyte_idl() + else: + assert False, f"Unexpected file {ff}" + + +def test_dispatch_execute_offloaded_nested_lists_of_literals_offloading_disabled(tmp_path_factory): + @task + def t1(a: typing.List[int]) -> typing.List[typing.List[str]]: + return [[f"string is: {x}" for x in a] for _ in range(len(a))] + + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + ctx = context_manager.FlyteContext.current_context() + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: + xs: typing.List[int] = [1, 2, 3] + input_literal_map = _literal_models.LiteralMap( + { + "a": TypeEngine.to_literal(ctx, xs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])), + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + # Ensure that this is not set by an external source + assert os.environ.get("_F_L_MIN_SIZE_MB") is None + + # Notice how we're setting the env var to None, which disables offloading completely + _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + + assert "error.pb" not in os.listdir(outputs_path) + + for ff in os.listdir(outputs_path): + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 1 + assert "o0" in lit.literals + assert lit.literals["o0"].HasField("offloaded_metadata") == False + else: + assert False, f"Unexpected file {ff}"