diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index eb2838ca41..1be433cf94 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -16,7 +16,6 @@ GetTaskResponse, ListAgentsRequest, ListAgentsResponse, - Resource, ) from flyteidl.service.agent_pb2_grpc import ( AgentMetadataServiceServicer, @@ -25,8 +24,7 @@ ) from prometheus_client import Counter, Summary -from flytekit import FlyteContext, logger -from flytekit.core.type_engine import TypeEngine +from flytekit import logger from flytekit.exceptions.system import FlyteAgentNotFound from flytekit.extend.backend.base_agent import AgentRegistry, SyncAgentBase, mirror_async_methods from flytekit.models.literals import LiteralMap @@ -136,16 +134,7 @@ async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) logger.info(f"{agent.name} start checking the status of the job") res = await mirror_async_methods(agent.get, resource_meta=agent.metadata_type.decode(request.resource_meta)) - if res.outputs is None: - outputs = None - elif isinstance(res.outputs, LiteralMap): - outputs = res.outputs.to_flyte_idl() - else: - ctx = FlyteContext.current_context() - outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs) - return GetTaskResponse( - resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs) - ) + return GetTaskResponse(resource=res.to_flyte_idl()) @record_agent_metrics async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse: @@ -175,17 +164,7 @@ async def ExecuteTaskSync( literal_map = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None res = await mirror_async_methods(agent.do, task_template=template, inputs=literal_map) - if res.outputs is None: - outputs = None - elif isinstance(res.outputs, LiteralMap): - outputs = res.outputs.to_flyte_idl() - else: - ctx = FlyteContext.current_context() - outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs) - - header = ExecuteTaskSyncResponseHeader( - resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs) - ) + header = ExecuteTaskSyncResponseHeader(resource=res.to_flyte_idl()) yield ExecuteTaskSyncResponse(header=header) request_success_count.labels(task_type=task_type, operation=do_operation).inc() except Exception as e: diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 4d1d8956da..a84ea487b2 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -12,9 +12,12 @@ from typing import Any, Dict, List, Optional, Union from flyteidl.admin.agent_pb2 import Agent +from flyteidl.admin.agent_pb2 import Resource as _Resource from flyteidl.admin.agent_pb2 import TaskCategory as _TaskCategory from flyteidl.core import literals_pb2 from flyteidl.core.execution_pb2 import TaskExecution, TaskLog +from google.protobuf import json_format +from google.protobuf.struct_pb2 import Struct from rich.logging import RichHandler from rich.progress import Progress @@ -90,6 +93,38 @@ class Resource: message: Optional[str] = None log_links: Optional[List[TaskLog]] = None outputs: Optional[Union[LiteralMap, typing.Dict[str, Any]]] = None + custom_info: Optional[typing.Dict[str, Any]] = None + + def to_flyte_idl(self) -> _Resource: + if self.outputs is None: + outputs = None + elif isinstance(self.outputs, LiteralMap): + outputs = self.outputs.to_flyte_idl() + else: + ctx = FlyteContext.current_context() + outputs = TypeEngine.dict_to_literal_map_pb(ctx, self.outputs) + + return Agent.Resource( + phase=self.phase, + message=self.message, + log_links=self.log_links, + outputs=outputs, + custom_info=( + json_format.Parse(json.dumps(self.custom_info), Struct.Struct()) if self.custom_info else None + ), + ) + + @classmethod + def from_flyte_idl(cls, pb2_object: _Resource): + return cls( + phase=pb2_object.phase, + message=pb2_object.message if pb2_object.HasField("message") else None, + log_links=(pb2_object.log_links if pb2_object.HasField("log_links") else None), + outputs=(LiteralMap.from_flyte_idl(pb2_object.outputs) if pb2_object.outputs else None), + custom_info=( + json_format.MessageToDict(pb2_object.custom_info) if pb2_object.HasField("custom_info") else None + ), + ) class AgentBase(ABC): diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 17db5c2788..4697278499 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -70,7 +70,11 @@ def create(self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap return DummyMetadata(job_id=dummy_id) def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource: - return Resource(phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000")]) + return Resource( + phase=TaskExecution.SUCCEEDED, + log_links=[TaskLog(name="console", uri="localhost:3000")], + custom_info={"custom": "info"}, + ) def delete(self, resource_meta: DummyMetadata, **kwargs): ... @@ -95,7 +99,11 @@ async def create( return DummyMetadata(job_id=dummy_id, output_path=output_path, task_name=task_name) async def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource: - return Resource(phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000")]) + return Resource( + phase=TaskExecution.SUCCEEDED, + log_links=[TaskLog(name="console", uri="localhost:3000")], + custom_info={"custom": "info"}, + ) async def delete(self, resource_meta: DummyMetadata, **kwargs): ... @@ -172,6 +180,7 @@ def test_dummy_agent(): assert resource.phase == TaskExecution.SUCCEEDED assert resource.log_links[0].name == "console" assert resource.log_links[0].uri == "localhost:3000" + assert resource.custom_info["custom"] == "info" assert agent.delete(metadata) is None class DummyTask(AsyncAgentExecutorMixin, PythonFunctionTask):