From 5aa854dc5c16fc0cfc572f142b1b9de88f59470b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Wed, 7 Feb 2024 23:06:20 +0100 Subject: [PATCH] Handle overriding of container image in backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- flytekit/core/node.py | 3 ++- flytekit/models/core/workflow.py | 16 +++++++++++++--- flytekit/tools/translator.py | 18 +++++++++++++++--- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/flytekit/core/node.py b/flytekit/core/node.py index f5a3db4afa..e8a37bf3f0 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -65,6 +65,7 @@ def __init__( self._outputs = None self._resources: typing.Optional[_resources_model] = None self._extended_resources: typing.Optional[tasks_pb2.ExtendedResources] = None + self._container_image: typing.Optional[str] = None def runs_before(self, other: Node): """ @@ -193,7 +194,7 @@ def with_overrides(self, *args, **kwargs): if "container_image" in kwargs: v = kwargs["container_image"] assert_not_promise(v, "container_image") - self.run_entity._container_image = v + self._container_image = v if "accelerator" in kwargs: v = kwargs["accelerator"] diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index 62636d1420..aef5d3c46f 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -595,10 +595,14 @@ def from_flyte_idl(cls, pb2_object): class TaskNodeOverrides(_common.FlyteIdlEntity): def __init__( - self, resources: typing.Optional[Resources], extended_resources: typing.Optional[tasks_pb2.ExtendedResources] + self, + resources: typing.Optional[Resources], + extended_resources: typing.Optional[tasks_pb2.ExtendedResources], + container_image: typing.Optional[str], ): self._resources = resources self._extended_resources = extended_resources + self._container_image = container_image @property def resources(self) -> Resources: @@ -608,19 +612,25 @@ def resources(self) -> Resources: def extended_resources(self) -> tasks_pb2.ExtendedResources: return self._extended_resources + @property + def container_image(self) -> str: + return self._container_image + def to_flyte_idl(self): return _core_workflow.TaskNodeOverrides( resources=self.resources.to_flyte_idl() if self.resources is not None else None, extended_resources=self.extended_resources, + container_image=self.container_image, ) @classmethod def from_flyte_idl(cls, pb2_object): resources = Resources.from_flyte_idl(pb2_object.resources) extended_resources = pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None + container_image = pb2_object.container_image if len(pb2_object.container_image) > 0 else None if bool(resources.requests) or bool(resources.limits): - return cls(resources=resources, extended_resources=extended_resources) - return cls(resources=None, extended_resources=extended_resources) + return cls(resources=resources, extended_resources=extended_resources, container_image=container_image) + return cls(resources=None, extended_resources=extended_resources, container_image=container_image) class TaskNode(_common.FlyteIdlEntity): diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 7bc719cef8..2847ff1b3d 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -477,7 +477,11 @@ def get_serializable_node( output_aliases=[], task_node=workflow_model.TaskNode( reference_id=task_spec.template.id, - overrides=TaskNodeOverrides(resources=entity._resources, extended_resources=entity._extended_resources), + overrides=TaskNodeOverrides( + resources=entity._resources, + extended_resources=entity._extended_resources, + container_image=entity._container_image, + ), ), ) if entity._aliases: @@ -554,7 +558,11 @@ def get_serializable_node( output_aliases=[], task_node=workflow_model.TaskNode( reference_id=entity.flyte_entity.id, - overrides=TaskNodeOverrides(resources=entity._resources, extended_resources=entity._extended_resources), + overrides=TaskNodeOverrides( + resources=entity._resources, + extended_resources=entity._extended_resources, + container_image=entity._container_image, + ), ), ) elif isinstance(entity.flyte_entity, FlyteWorkflow): @@ -603,7 +611,11 @@ def get_serializable_array_node( task_spec = get_serializable(entity_mapping, settings, entity, options) task_node = workflow_model.TaskNode( reference_id=task_spec.template.id, - overrides=TaskNodeOverrides(resources=node._resources, extended_resources=node._extended_resources), + overrides=TaskNodeOverrides( + resources=node._resources, + extended_resources=node._extended_resources, + container_image=node._container_image, + ), ) node = workflow_model.Node( id=entity.name,