Skip to content

Commit

Permalink
Log Agents In The Agent Server (#2309)
Browse files Browse the repository at this point in the history
* Log Agents In The Agent Server

Signed-off-by: Future-Outlier <[email protected]>

* add tests

Signed-off-by: Future-Outlier <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Future-Outlier <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
  • Loading branch information
Future-Outlier and pingsutw authored Apr 7, 2024
1 parent bf38b8e commit 4c6e704
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 0 deletions.
12 changes: 12 additions & 0 deletions flytekit/clis/sdk_in_container/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ async def _start_grpc_server(port: int, worker: int, timeout: int):

_start_http_server()
click.secho("Starting the agent service...", fg="blue")
print_agents_metadata()

server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=worker))

add_AsyncAgentServiceServicer_to_server(AsyncAgentService(), server)
Expand Down Expand Up @@ -96,3 +98,13 @@ def _start_health_check_server(server: grpc.Server, worker: int):

except ImportError as e:
click.secho(f"Failed to start the health check servicer with error {e}", fg="red")


def print_agents_metadata():
from flytekit.extend.backend.base_agent import AgentRegistry

agents = AgentRegistry.list_agents()
for agent in agents:
name = agent.name
metadata = [category.name for category in agent.supported_task_categories]
click.secho(f"Starting {name} that supports task categories {metadata}", fg="blue")
2 changes: 2 additions & 0 deletions plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def get_connection(metadata: SnowflakeJobMetadata) -> snowflake_connector:


class SnowflakeAgent(AsyncAgentBase):
name = "Snowflake Agent"

def __init__(self):
super().__init__(task_type_name=TASK_TYPE, metadata_type=SnowflakeJobMetadata)

Expand Down
26 changes: 26 additions & 0 deletions tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import grpc
import pytest
from flyteidl.admin.agent_pb2 import (
Agent,
CreateRequestHeader,
CreateTaskRequest,
DeleteTaskRequest,
Expand All @@ -20,6 +21,7 @@
from flyteidl.core.identifier_pb2 import ResourceType

from flytekit import PythonFunctionTask, task
from flytekit.clis.sdk_in_container.serve import print_agents_metadata
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask, kwtypes
from flytekit.core.interface import Interface
Expand Down Expand Up @@ -384,3 +386,27 @@ def test_render_task_template():
"task-name",
"simple_task",
]


@pytest.fixture
def sample_agents():
async_agent = Agent(
name="Sensor", is_sync=False, supported_task_categories=[TaskCategory(name="sensor", version=0)]
)
sync_agent = Agent(
name="ChatGPT Agent", is_sync=True, supported_task_categories=[TaskCategory(name="chatgpt", version=0)]
)
return [async_agent, sync_agent]


@patch("flytekit.clis.sdk_in_container.serve.click.secho")
@patch("flytekit.extend.backend.base_agent.AgentRegistry.list_agents")
def test_print_agents_metadata_output(list_agents_mock, mock_secho, sample_agents):
list_agents_mock.return_value = sample_agents
print_agents_metadata()
expected_calls = [
(("Starting Sensor that supports task categories ['sensor']",), {"fg": "blue"}),
(("Starting ChatGPT Agent that supports task categories ['chatgpt']",), {"fg": "blue"}),
]
mock_secho.assert_has_calls(expected_calls, any_order=True)
assert mock_secho.call_count == len(expected_calls)

0 comments on commit 4c6e704

Please sign in to comment.