From 4d62aaa120a946a11997e07c252d2c640b8d7370 Mon Sep 17 00:00:00 2001 From: Mohamed Arshath Date: Tue, 10 Dec 2024 02:09:34 +0000 Subject: [PATCH] automatically identify communication protocol --- naptha_sdk/agent.py | 33 +++++++++++++++++++++++++++++++-- naptha_sdk/client/node.py | 11 +++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/naptha_sdk/agent.py b/naptha_sdk/agent.py index 1b7c4bf..aab9eef 100644 --- a/naptha_sdk/agent.py +++ b/naptha_sdk/agent.py @@ -13,10 +13,35 @@ def __init__(self, ): self.orchestrator_run = orchestrator_run self.agent_index = agent_index + self.worker_node = None + + async def initialize(self): + """Initialize the agent by setting up the worker node connection""" worker_node_url = self.orchestrator_run.agent_deployments[self.agent_index].worker_node_url + worker_node_url = await self.identify_communication_type(worker_node_url) self.worker_node = Node(worker_node_url) + async def identify_communication_type(self, agent_node_url): + if "ws://" in agent_node_url: + return agent_node_url + elif "grpc://" in agent_node_url: + return agent_node_url + else: + # Create temporary node for health check + temp_node = Node(f"http://{agent_node_url}") + ws_health_url = f"http://{agent_node_url}/health" + + try: + ws_health = await temp_node.check_health_ws(ws_health_url) + return f"ws://{agent_node_url}" if ws_health else f"grpc://{agent_node_url}" + except Exception as e: + logger.error(f"Error checking node health: {e}") + raise + async def call_agent_func(self, *args, **kwargs): + if self.worker_node is None: + await self.initialize() + logger.info(f"Running agent on worker node {self.worker_node.node_url}") agent_run_input = AgentRunInput( @@ -25,5 +50,9 @@ async def call_agent_func(self, *args, **kwargs): agent_deployment=self.orchestrator_run.agent_deployments[self.agent_index].model_dump(), ) - agent_run = await self.worker_node.run_agent_in_node(agent_run_input) - return agent_run + try: + agent_run = await self.worker_node.run_agent_in_node(agent_run_input) + return agent_run + except Exception as e: + logger.error(f"Error running agent: {e}") + raise \ No newline at end of file diff --git a/naptha_sdk/client/node.py b/naptha_sdk/client/node.py index cf4ba8c..0f2710e 100644 --- a/naptha_sdk/client/node.py +++ b/naptha_sdk/client/node.py @@ -101,6 +101,17 @@ async def run_environment_and_poll(self, environment_input: EnvironmentRunInput) """Run an environment and poll for results until completion.""" return await self._run_and_poll(environment_input, 'environment') + async def check_health_ws(self, health_url: str): + try: + async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client: + response = await client.get(health_url) + response.raise_for_status() + return True + except Exception as e: + logger.info(f"An unexpected error occurred: {e}") + logger.info(f"Full traceback: {traceback.format_exc()}") + return False + async def check_user_ws(self, user_input: Dict[str, str]): response = await self.send_receive_ws(user_input, "check_user") logger.info(f"Check user response: {response}")