Skip to content

Commit

Permalink
[Bug] Map task caching failures (#2113)
Browse files Browse the repository at this point in the history
* only utilize bounded inputs for map task names instead of entire interface

Signed-off-by: Paul Dittamo <[email protected]>

* add test

Signed-off-by: Paul Dittamo <[email protected]>

* lint

Signed-off-by: Paul Dittamo <[email protected]>

* order container vars for map tasks

Signed-off-by: Paul Dittamo <[email protected]>

---------

Signed-off-by: Paul Dittamo <[email protected]>
  • Loading branch information
pvditt authored Jan 22, 2024
1 parent 5ba545c commit cba830e
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 13 deletions.
5 changes: 3 additions & 2 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ def __init__(
f = actual_task.lhs
else:
_, mod, f, _ = tracker.extract_task_module(cast(PythonFunctionTask, actual_task).task_function)
sorted_bounded_inputs = ",".join(sorted(self._bound_inputs))
h = hashlib.md5(
f"{collection_interface.__str__()}{concurrency}{min_successes}{min_success_ratio}".encode("utf-8")
f"{sorted_bounded_inputs}{concurrency}{min_successes}{min_success_ratio}".encode("utf-8")
).hexdigest()
self._name = f"{mod}.map_{f}_{h}-arraynode"

Expand Down Expand Up @@ -387,7 +388,7 @@ def load_task(self, loader_args: List[str], max_concurrency: int = 0) -> ArrayNo
def loader_args(self, settings: SerializationSettings, t: ArrayNodeMapTask) -> List[str]: # type:ignore
return [
"vars",
f'{",".join(t.bound_inputs)}',
f'{",".join(sorted(t.bound_inputs))}',
"resolver",
t.python_function_task.task_resolver.location,
*t.python_function_task.task_resolver.loader_args(settings, t.python_function_task),
Expand Down
5 changes: 3 additions & 2 deletions flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def __init__(
f = actual_task.lhs
else:
_, mod, f, _ = tracker.extract_task_module(typing.cast(PythonFunctionTask, actual_task).task_function)
h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest()
sorted_bounded_inputs = ",".join(sorted(self._bound_inputs))
h = hashlib.md5(sorted_bounded_inputs.encode("utf-8")).hexdigest()
name = f"{mod}.map_{f}_{h}"

self._cmd_prefix: typing.Optional[typing.List[str]] = None
Expand Down Expand Up @@ -404,7 +405,7 @@ def load_task(self, loader_args: List[str], max_concurrency: int = 0) -> MapPyth
def loader_args(self, settings: SerializationSettings, t: MapPythonTask) -> List[str]: # type:ignore
return [
"vars",
f'{",".join(t.bound_inputs)}',
f'{",".join(sorted(t.bound_inputs))}',
"resolver",
t.run_task.task_resolver.location,
*t.run_task.task_resolver.loader_args(settings, t.run_task),
Expand Down
22 changes: 17 additions & 5 deletions tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from flytekit import task, workflow
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
from flytekit.core.array_node_map_task import ArrayNodeMapTask
from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver
from flytekit.core.task import TaskMetadata
from flytekit.experimental import map_task as array_node_map_task
from flytekit.tools.translator import get_serializable
Expand Down Expand Up @@ -187,7 +187,7 @@ def many_inputs(a: int, b: str, c: float) -> str:
assert m.python_interface.inputs == {"a": List[int], "b": List[str], "c": List[float]}
assert (
m.name
== "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_4ee240ef5cf979dbc133fb30035cb874-arraynode"
== "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_bf51001578d0ae197a52c0af0a99dd89-arraynode"
)
r_m = ArrayNodeMapTask(many_inputs)
assert str(r_m.python_interface) == str(m.python_interface)
Expand All @@ -197,7 +197,7 @@ def many_inputs(a: int, b: str, c: float) -> str:
assert m.python_interface.inputs == {"a": List[int], "b": List[str], "c": float}
assert (
m.name
== "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_352fcdea8523a83134b51bbf5793f14e-arraynode"
== "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_cb470e880fabd6265ec80e29fe60250d-arraynode"
)
r_m = ArrayNodeMapTask(many_inputs, bound_inputs=set("c"))
assert str(r_m.python_interface) == str(m.python_interface)
Expand All @@ -207,7 +207,7 @@ def many_inputs(a: int, b: str, c: float) -> str:
assert m.python_interface.inputs == {"a": List[int], "b": str, "c": float}
assert (
m.name
== "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_e224ba3a5b00e08083d541a6ca99b179-arraynode"
== "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_316e10eb97f5d2abd585951048b807b9-arraynode"
)
r_m = ArrayNodeMapTask(many_inputs, bound_inputs={"c", "b"})
assert str(r_m.python_interface) == str(m.python_interface)
Expand All @@ -217,7 +217,7 @@ def many_inputs(a: int, b: str, c: float) -> str:
assert m.python_interface.inputs == {"a": int, "b": str, "c": float}
assert (
m.name
== "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_f080e60be9d6faedeef0c74834d6812a-arraynode"
== "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_758022acd59ad1c8b81670378d4de4f6-arraynode"
)
r_m = ArrayNodeMapTask(many_inputs, bound_inputs={"a", "c", "b"})
assert str(r_m.python_interface) == str(m.python_interface)
Expand Down Expand Up @@ -257,6 +257,18 @@ def task3(c: str, a: int, b: float) -> str:
assert m1 == m2 == m3 == ["1 - 0.1 - c", "2 - 0.2 - c", "3 - 0.3 - c"]


def test_bounded_inputs_vars_order(serialization_settings):
@task()
def task1(a: int, b: float, c: str) -> str:
return f"{a} - {b} - {c}"

mt = array_node_map_task(functools.partial(task1, c=1.0, b="hello", a=1))
mtr = ArrayNodeMapTaskResolver()
args = mtr.loader_args(serialization_settings, mt)

assert args[1] == "a,b,c"


@pytest.mark.parametrize(
"min_success_ratio, should_raise_error",
[
Expand Down
21 changes: 17 additions & 4 deletions tests/flytekit/unit/core/test_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,31 +192,36 @@ def many_inputs(a: int, b: str, c: float) -> str:

m = map_task(many_inputs)
assert m.python_interface.inputs == {"a": typing.List[int], "b": typing.List[str], "c": typing.List[float]}
assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_24c08b3a2f9c2e389ad9fc6a03482cf9"
assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_d41d8cd98f00b204e9800998ecf8427e"
r_m = MapPythonTask(many_inputs)
assert str(r_m.python_interface) == str(m.python_interface)

p1 = functools.partial(many_inputs, c=1.0)
m = map_task(p1)
assert m.python_interface.inputs == {"a": typing.List[int], "b": typing.List[str], "c": float}
assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_697aa7389996041183cf6cfd102be4f7"
assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_4a8a08f09d37b73795649038408b5f33"
r_m = MapPythonTask(many_inputs, bound_inputs=set("c"))
assert str(r_m.python_interface) == str(m.python_interface)

p2 = functools.partial(p1, b="hello")
m = map_task(p2)
assert m.python_interface.inputs == {"a": typing.List[int], "b": str, "c": float}
assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_cc18607da7494024a402a5fa4b3ea5c6"
assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_74aefa13d6ab8e4bfbd241583749dfe8"
r_m = MapPythonTask(many_inputs, bound_inputs={"c", "b"})
assert str(r_m.python_interface) == str(m.python_interface)

p3 = functools.partial(p2, a=1)
m = map_task(p3)
assert m.python_interface.inputs == {"a": int, "b": str, "c": float}
assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_52fe80b04781ea77ef6f025f4b49abef"
assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_a44c56c8177e32d3613988f4dba7962e"
r_m = MapPythonTask(many_inputs, bound_inputs={"a", "c", "b"})
assert str(r_m.python_interface) == str(m.python_interface)

p3_1 = functools.partial(p2, a=1)
m_1 = map_task(p3_1)
assert m_1.python_interface.inputs == {"a": int, "b": str, "c": float}
assert m_1.name == m.name

with pytest.raises(TypeError):
m(a=[1, 2, 3])

Expand Down Expand Up @@ -348,3 +353,11 @@ def wf(x: typing.List[int]):
map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image")

assert wf.nodes[0].flyte_entity.run_task.container_image == "random:image"


def test_bounded_inputs_vars_order(serialization_settings):
mt = map_task(functools.partial(t3, c=1.0, b="hello", a=1))
mtr = MapTaskResolver()
args = mtr.loader_args(serialization_settings, mt)

assert args[1] == "a,b,c"

0 comments on commit cba830e

Please sign in to comment.