Skip to content

Commit

Permalink
Setting interruptible on ArrayNode sub node metadata (flyteorg#2288)
Browse files Browse the repository at this point in the history
* setting interruptible on ArrayNode sub node metadata

Signed-off-by: Daniel Rammer <[email protected]>

* lint

Signed-off-by: Kevin Su <[email protected]>

* stuff (flyteorg#2291)

Signed-off-by: Yee Hing Tong <[email protected]>

---------

Signed-off-by: Daniel Rammer <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Yee Hing Tong <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
Co-authored-by: Yee Hing Tong <[email protected]>
  • Loading branch information
3 people authored Mar 21, 2024
1 parent 3642ec6 commit 600a1a9
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
6 changes: 3 additions & 3 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ def python_interface(self):

def construct_node_metadata(self) -> NodeMetadata:
# TODO: add support for other Flyte entities
return NodeMetadata(
name=self.name,
)
nm = super().construct_node_metadata()
nm._name = self.name
return nm

@property
def min_success_ratio(self) -> Optional[float]:
Expand Down
42 changes: 42 additions & 0 deletions tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,45 @@ def wf(x: typing.List[int]):
map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image")

assert wf.nodes[0]._container_image == "random:image"


def test_serialization_metadata(serialization_settings):
@task(interruptible=True)
def t1(a: int) -> int:
return a + 1

arraynode_maptask = map_task(t1, metadata=TaskMetadata(retries=2))
# since we manually override task metadata, the underlying task metadata will not be copied.
assert not arraynode_maptask.metadata.interruptible

@workflow
def wf(x: typing.List[int]):
return arraynode_maptask(a=x)

od = OrderedDict()
wf_spec = get_serializable(od, serialization_settings, wf)

assert not arraynode_maptask.construct_node_metadata().interruptible
assert not wf_spec.template.nodes[0].metadata.interruptible


def test_serialization_metadata2(serialization_settings):
@task
def t1(a: int) -> int:
return a + 1

arraynode_maptask = map_task(t1, metadata=TaskMetadata(retries=2, interruptible=True))
assert arraynode_maptask.metadata.interruptible

@workflow
def wf(x: typing.List[int]):
return arraynode_maptask(a=x)

od = OrderedDict()
wf_spec = get_serializable(od, serialization_settings, wf)

assert arraynode_maptask.construct_node_metadata().interruptible
assert wf_spec.template.nodes[0].metadata.interruptible
task_spec = od[arraynode_maptask]
assert task_spec.template.metadata.retries.retries == 2
assert task_spec.template.metadata.interruptible

0 comments on commit 600a1a9

Please sign in to comment.