-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6433e24
commit 9cf7425
Showing
3 changed files
with
236 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
from naptha_sdk.client.node import Node, HTTP_TIMEOUT | ||
from naptha_sdk.schemas import ToolsetLoadRepoRequest, ToolsetListRequest, ToolsetList, SetToolsetRequest, ToolsetDetails, ToolsetRequest, ToolRunRequest, ToolRunResult | ||
from naptha_sdk.utils import get_logger | ||
from httpx import HTTPStatusError, RemoteProtocolError | ||
import httpx | ||
import json | ||
|
||
logger = get_logger(__name__) | ||
|
||
class Toolset: | ||
def __init__(self, | ||
worker_node_url, | ||
agent_id, | ||
*args, | ||
**kwargs | ||
): | ||
|
||
self.agent_id = agent_id | ||
self.worker_node_url = worker_node_url | ||
|
||
async def load_or_add_tool_repo_to_toolset(self, toolset_name, repo_url): | ||
logger.info(f"Loading tool repo to toolset on worker node {self.worker_node_url}") | ||
try: | ||
request = ToolsetLoadRepoRequest( | ||
agent_id=self.agent_id, | ||
repo_url=repo_url, | ||
toolset_name=toolset_name) | ||
|
||
async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client: | ||
load_repo_response = await client.post( | ||
f"{self.worker_node_url}/tool/add_tool_repo_to_toolset", | ||
json=request.model_dump() | ||
) | ||
load_repo_response.raise_for_status() | ||
logger.info(f"Loaded repo {repo_url} into toolset {toolset_name}") | ||
except (HTTPStatusError, RemoteProtocolError) as e: | ||
print(f"Failed to load repo: {e}") | ||
raise | ||
except Exception as e: | ||
print(f"Error loading repo: {e}") | ||
raise | ||
|
||
async def get_toolset_list(self): | ||
logger.info(f"Getting toolset list from worker node") | ||
try: | ||
request = ToolsetListRequest(agent_id=self.agent_id) | ||
|
||
async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client: | ||
# Send agent_id as a query parameter | ||
toolset_list_response = await client.post( | ||
f"{self.worker_node_url}/tool/get_toolset_list", | ||
json=request.model_dump() | ||
) | ||
toolset_list_response.raise_for_status() | ||
result = ToolsetList(**json.loads(toolset_list_response.text)) | ||
logger.info(result) | ||
except (HTTPStatusError, RemoteProtocolError) as e: | ||
print(f"Failed to get toolset list: {e}") | ||
raise | ||
except Exception as e: | ||
print(f"Error getting toolset list: {e}") | ||
raise | ||
|
||
async def set_toolset(self, toolset_name): | ||
logger.info(f"Setting toolset") | ||
try: | ||
request = SetToolsetRequest(agent_id=self.agent_id, toolset_name=toolset_name) | ||
|
||
async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client: | ||
toolset_response = await client.post( | ||
f"{self.worker_node_url}/tool/set_toolset", | ||
json=request.model_dump() | ||
) | ||
toolset_response.raise_for_status() | ||
result = ToolsetDetails(**json.loads(toolset_response.text)) | ||
logger.info("~"*50) | ||
logger.info(result) | ||
except (HTTPStatusError, RemoteProtocolError) as e: | ||
print(f"Failed to set toolset: {e}") | ||
raise | ||
except Exception as e: | ||
print(f"Error setting toolset: {e}") | ||
raise | ||
|
||
async def get_current_toolset(self): | ||
logger.info(f"Getting toolset") | ||
try: | ||
request = ToolsetRequest(agent_id=self.agent_id) | ||
|
||
async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client: | ||
toolset_list_response = await client.post( | ||
f"{self.worker_node_url}/tool/get_current_toolset", | ||
json=request.model_dump() | ||
) | ||
toolset_list_response.raise_for_status() | ||
result = ToolsetDetails(**json.loads(toolset_list_response.text)) | ||
logger.info(result) | ||
except (HTTPStatusError, RemoteProtocolError) as e: | ||
print(f"Failed to get toolset: {e}") | ||
raise | ||
except Exception as e: | ||
print(f"Error getting toolset: {e}") | ||
raise | ||
|
||
async def run_tool(self, toolset_name, tool_name, params): | ||
logger.info(f"Running tool {tool_name}") | ||
try: | ||
print("~"*50) | ||
print(f"Running tool {tool_name}") | ||
print(f"Params: {params}") | ||
print(f"Toolset: {toolset_name}") | ||
print(f"Agent ID: {self.agent_id}") | ||
|
||
print("~"*50) | ||
request = ToolRunRequest( | ||
tool_run_id="1", | ||
agent_id=self.agent_id, | ||
toolset_id=toolset_name, | ||
tool_id=tool_name, | ||
params=params | ||
) | ||
|
||
async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client: | ||
tool_run_response = await client.post( | ||
f"{self.worker_node_url}/tool/run_tool", | ||
json=request.model_dump() | ||
) | ||
tool_run_response.raise_for_status() | ||
logger.info(f"Ran tool {tool_name}") | ||
result = ToolRunResult(**json.loads(tool_run_response.text)) | ||
logger.info(result) | ||
except (HTTPStatusError, RemoteProtocolError) as e: | ||
print(f"Failed to run tool: {e}") | ||
raise | ||
except Exception as e: | ||
print(f"Error running tool: {e}") | ||
raise |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from naptha_sdk.toolset import Toolset | ||
import asyncio | ||
|
||
async def main(): | ||
# Initialize Toolset | ||
worker_node_url = "http://0.0.0.0:7001" # Replace with your worker node URL | ||
agent_id = "test-agent-id" # Replace with your agent ID | ||
|
||
toolset = Toolset(worker_node_url=worker_node_url, agent_id=agent_id) | ||
|
||
try: | ||
# Test loading a tool repository | ||
print("\n1. Testing load_or_add_tool_repo_to_toolset:") | ||
await toolset.load_or_add_tool_repo_to_toolset( | ||
toolset_name="test-toolset", | ||
repo_url="https://github.com/C0deMunk33/test_toolset" | ||
) | ||
|
||
# Test getting toolset list | ||
print("\n2. Testing get_toolset_list:") | ||
await toolset.get_toolset_list() | ||
|
||
# Test setting a toolset | ||
print("\n3. Testing set_toolset:") | ||
await toolset.set_toolset(toolset_name="test-toolset") | ||
|
||
# Test getting current toolset | ||
print("\n4. Testing get_current_toolset:") | ||
await toolset.get_current_toolset() | ||
|
||
# Test running a tool | ||
print("\n5. Testing run_tool:") | ||
test_params = { | ||
"a": "1", | ||
"b": "1" | ||
} | ||
result = await toolset.run_tool( | ||
toolset_name="test-toolset", | ||
tool_name="add", | ||
params=test_params | ||
) | ||
|
||
print(f"\nTool run result: {result}") | ||
|
||
except Exception as e: | ||
print(f"Test failed with error: {e}") | ||
raise | ||
else: | ||
print("\nAll tests completed successfully!") | ||
|
||
if __name__ == "__main__": | ||
asyncio.run(main()) |