From 12e194a308f0eacc622d7358ea5698ac741aedaf Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Thu, 21 Nov 2024 17:29:42 -0500 Subject: [PATCH] Add `is_map_task` to `_dispatch_execute` Signed-off-by: Eduardo Apolinario --- flytekit/bin/entrypoint.py | 19 +++- .../unit/bin/test_python_entrypoint.py | 90 +++++++++++++++++++ 2 files changed, 107 insertions(+), 2 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index ed04335b00..084e8f733b 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -136,6 +136,7 @@ def _dispatch_execute( load_task: Callable[[], PythonTask], inputs_path: str, output_prefix: str, + is_map_task: bool = False, ): """ Dispatches execute to PythonTask @@ -145,6 +146,12 @@ def _dispatch_execute( a: [Optional] Record outputs to output_prefix b: OR if IgnoreOutputs is raised, then ignore uploading outputs c: OR if an unhandled exception is retrieved - record it as an errors.pb + + :param ctx: FlyteContext + :param load_task: Callable[[], PythonTask] + :param inputs: Where to read inputs + :param output_prefix: Where to write primitive outputs + :param is_map_task: Whether this task is executing as part of a map task """ error_file_name = _build_error_file_name() worker_name = _get_worker_name() @@ -206,6 +213,14 @@ def _dispatch_execute( if min_offloaded_size != -1 and lit.ByteSize() >= min_offloaded_size: logger.debug(f"Literal {k} is too large to be inlined, offloading to metadata bucket") + inferred_type = task_def.interface.outputs[k].type + + # In the case of map tasks we need to use the type of the collection as inferred type as the task + # typed interface of the offloaded literal. This is done because the map task interface present in + # the task template contains the (correct) type for the entire map task, not the single node execution. + # For that reason we "unwrap" the collection type and use it as the inferred type of the offloaded literal. + if is_map_task: + inferred_type = inferred_type.collection_type # This file will hold the offloaded literal and will be written to the output prefix # alongside the regular outputs.pb, deck.pb, etc. @@ -216,7 +231,7 @@ def _dispatch_execute( uri=f"{output_prefix}/{offloaded_filename}", size_bytes=lit.ByteSize(), # TODO: remove after https://github.com/flyteorg/flyte/pull/5909 is merged - inferred_type=task_def.interface.outputs[k].type, + inferred_type=inferred_type, ), hash=v.hash if v.hash is not None else compute_hash_string(lit), ) @@ -633,7 +648,7 @@ def load_task(): ) return - _dispatch_execute(ctx, load_task, inputs, output_prefix) + _dispatch_execute(ctx, load_task, inputs, output_prefix, is_map_task=True) def normalize_inputs( diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index d4dc88b9d1..3955019cd0 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -20,6 +20,7 @@ from flytekit.bin.entrypoint import _dispatch_execute, get_container_error_timestamp, normalize_inputs, setup_execution, get_traceback_str from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core import mock_stats +from flytekit.core.array_node_map_task import ArrayNodeMapTask from flytekit.core.hash import HashMethod from flytekit.models.core import identifier as id_models from flytekit.core import context_manager @@ -893,3 +894,92 @@ def t1(a: typing.List[int]) -> typing.List[typing.List[str]]: assert lit.literals["o0"].HasField("offloaded_metadata") == False else: assert False, f"Unexpected file {ff}" + + + +def test_dispatch_execute_offloaded_map_task(tmp_path_factory): + @task + def t1(n: int) -> int: + return n + 1 + + inputs: typing.List[int] = [1, 2, 3, 4] + for i, v in enumerate(inputs): + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + ctx = context_manager.FlyteContext.current_context() + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: + input_literal_map = _literal_models.LiteralMap( + { + "n": TypeEngine.to_literal(ctx, inputs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])), + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + with mock.patch.dict( + os.environ, + { + "_F_L_MIN_SIZE_MB": "0", # Always offload + "BATCH_JOB_ARRAY_INDEX_OFFSET": str(i), + }): + _dispatch_execute(ctx, lambda: ArrayNodeMapTask(python_function_task=t1), str(inputs_path/"inputs.pb"), str(outputs_path.absolute()), is_map_task=True) + + assert "error.pb" not in os.listdir(outputs_path) + + for ff in os.listdir(outputs_path): + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 1 + assert "o0" in lit.literals + assert lit.literals["o0"].HasField("offloaded_metadata") == True + assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") + assert lit.literals["o0"].offloaded_metadata.inferred_type == LiteralType(simple=SimpleType.INTEGER).to_flyte_idl() + elif ff == "o0_offloaded_metadata.pb": + lit = literals_pb2.Literal() + lit.ParseFromString(f.read()) + expected_output = v + 1 + assert lit == TypeEngine.to_literal(ctx, expected_output, int, TypeEngine.to_literal_type(int)).to_flyte_idl() + else: + assert False, f"Unexpected file {ff}" + + +def test_dispatch_execute_offloaded_nested_lists_of_literals_offloading_disabled(tmp_path_factory): + @task + def t1(a: typing.List[int]) -> typing.List[typing.List[str]]: + return [[f"string is: {x}" for x in a] for _ in range(len(a))] + + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + ctx = context_manager.FlyteContext.current_context() + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: + xs: typing.List[int] = [1, 2, 3] + input_literal_map = _literal_models.LiteralMap( + { + "a": TypeEngine.to_literal(ctx, xs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])), + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + # Ensure that this is not set by an external source + assert os.environ.get("_F_L_MIN_SIZE_MB") is None + + # Notice how we're setting the env var to None, which disables offloading completely + _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + + assert "error.pb" not in os.listdir(outputs_path) + + for ff in os.listdir(outputs_path): + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 1 + assert "o0" in lit.literals + assert lit.literals["o0"].HasField("offloaded_metadata") == False + else: + assert False, f"Unexpected file {ff}"