Skip to content

Commit

Permalink
automatically identify communication protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
moarshy committed Dec 10, 2024
1 parent 6433e24 commit 4d62aaa
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
33 changes: 31 additions & 2 deletions naptha_sdk/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,35 @@ def __init__(self,
):
self.orchestrator_run = orchestrator_run
self.agent_index = agent_index
self.worker_node = None

async def initialize(self):
"""Initialize the agent by setting up the worker node connection"""
worker_node_url = self.orchestrator_run.agent_deployments[self.agent_index].worker_node_url
worker_node_url = await self.identify_communication_type(worker_node_url)
self.worker_node = Node(worker_node_url)

async def identify_communication_type(self, agent_node_url):
if "ws://" in agent_node_url:
return agent_node_url
elif "grpc://" in agent_node_url:
return agent_node_url
else:
# Create temporary node for health check
temp_node = Node(f"http://{agent_node_url}")
ws_health_url = f"http://{agent_node_url}/health"

try:
ws_health = await temp_node.check_health_ws(ws_health_url)
return f"ws://{agent_node_url}" if ws_health else f"grpc://{agent_node_url}"
except Exception as e:
logger.error(f"Error checking node health: {e}")
raise

async def call_agent_func(self, *args, **kwargs):
if self.worker_node is None:
await self.initialize()

logger.info(f"Running agent on worker node {self.worker_node.node_url}")

agent_run_input = AgentRunInput(
Expand All @@ -25,5 +50,9 @@ async def call_agent_func(self, *args, **kwargs):
agent_deployment=self.orchestrator_run.agent_deployments[self.agent_index].model_dump(),
)

agent_run = await self.worker_node.run_agent_in_node(agent_run_input)
return agent_run
try:
agent_run = await self.worker_node.run_agent_in_node(agent_run_input)
return agent_run
except Exception as e:
logger.error(f"Error running agent: {e}")
raise
11 changes: 11 additions & 0 deletions naptha_sdk/client/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,17 @@ async def run_environment_and_poll(self, environment_input: EnvironmentRunInput)
"""Run an environment and poll for results until completion."""
return await self._run_and_poll(environment_input, 'environment')

async def check_health_ws(self, health_url: str):
try:
async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client:
response = await client.get(health_url)
response.raise_for_status()
return True
except Exception as e:
logger.info(f"An unexpected error occurred: {e}")
logger.info(f"Full traceback: {traceback.format_exc()}")
return False

async def check_user_ws(self, user_input: Dict[str, str]):
response = await self.send_receive_ws(user_input, "check_user")
logger.info(f"Check user response: {response}")
Expand Down

0 comments on commit 4d62aaa

Please sign in to comment.