diff --git a/flytekit/clis/sdk_in_container/serve.py b/flytekit/clis/sdk_in_container/serve.py index 6a7e5c3c28..efe7086126 100644 --- a/flytekit/clis/sdk_in_container/serve.py +++ b/flytekit/clis/sdk_in_container/serve.py @@ -57,6 +57,8 @@ async def _start_grpc_server(port: int, worker: int, timeout: int): _start_http_server() click.secho("Starting the agent service...", fg="blue") + print_agents_metadata() + server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=worker)) add_AsyncAgentServiceServicer_to_server(AsyncAgentService(), server) @@ -96,3 +98,13 @@ def _start_health_check_server(server: grpc.Server, worker: int): except ImportError as e: click.secho(f"Failed to start the health check servicer with error {e}", fg="red") + + +def print_agents_metadata(): + from flytekit.extend.backend.base_agent import AgentRegistry + + agents = AgentRegistry.list_agents() + for agent in agents: + name = agent.name + metadata = [category.name for category in agent.supported_task_categories] + click.secho(f"Starting {name} that supports task categories {metadata}", fg="blue") diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py index 8cb38662e3..71eba91186 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py @@ -59,6 +59,8 @@ def get_connection(metadata: SnowflakeJobMetadata) -> snowflake_connector: class SnowflakeAgent(AsyncAgentBase): + name = "Snowflake Agent" + def __init__(self): super().__init__(task_type_name=TASK_TYPE, metadata_type=SnowflakeJobMetadata) diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 2bf23abb25..17db5c2788 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -6,6 +6,7 @@ import grpc import pytest from flyteidl.admin.agent_pb2 import ( + Agent, CreateRequestHeader, CreateTaskRequest, DeleteTaskRequest, @@ -20,6 +21,7 @@ from flyteidl.core.identifier_pb2 import ResourceType from flytekit import PythonFunctionTask, task +from flytekit.clis.sdk_in_container.serve import print_agents_metadata from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask, kwtypes from flytekit.core.interface import Interface @@ -384,3 +386,27 @@ def test_render_task_template(): "task-name", "simple_task", ] + + +@pytest.fixture +def sample_agents(): + async_agent = Agent( + name="Sensor", is_sync=False, supported_task_categories=[TaskCategory(name="sensor", version=0)] + ) + sync_agent = Agent( + name="ChatGPT Agent", is_sync=True, supported_task_categories=[TaskCategory(name="chatgpt", version=0)] + ) + return [async_agent, sync_agent] + + +@patch("flytekit.clis.sdk_in_container.serve.click.secho") +@patch("flytekit.extend.backend.base_agent.AgentRegistry.list_agents") +def test_print_agents_metadata_output(list_agents_mock, mock_secho, sample_agents): + list_agents_mock.return_value = sample_agents + print_agents_metadata() + expected_calls = [ + (("Starting Sensor that supports task categories ['sensor']",), {"fg": "blue"}), + (("Starting ChatGPT Agent that supports task categories ['chatgpt']",), {"fg": "blue"}), + ] + mock_secho.assert_has_calls(expected_calls, any_order=True) + assert mock_secho.call_count == len(expected_calls)