diff --git a/naptha_sdk/cli.py b/naptha_sdk/cli.py index 483f6e7..ad65b4b 100644 --- a/naptha_sdk/cli.py +++ b/naptha_sdk/cli.py @@ -1,6 +1,11 @@ import argparse import asyncio -import json +from dotenv import load_dotenv +from naptha_sdk.client.naptha import Naptha +from naptha_sdk.toolset import Toolset +from naptha_sdk.client.hub import user_setup_flow +from naptha_sdk.user import get_public_key +from naptha_sdk.schemas import AgentConfig, AgentDeployment, EnvironmentDeployment, OrchestratorDeployment, OrchestratorRunInput, EnvironmentRunInput import os import shlex from textwrap import wrap @@ -387,7 +392,6 @@ async def read_storage(naptha, hash_or_name, output_dir='./files', ipfs=False): except Exception as err: print(f"Error: {err}") - async def write_storage(naptha, storage_input, ipfs=False, publish_to_ipns=False, update_ipns_name=None): """Write to storage, optionally to IPFS and/or IPNS.""" try: @@ -486,6 +490,31 @@ async def main(): write_storage_parser.add_argument("--publish_to_ipns", help="Publish to IPNS", action="store_true") write_storage_parser.add_argument("--update_ipns_name", help="Update IPNS name") + # Toolset commands + toolset_parser = subparsers.add_parser("toolset", help="List available tools.") + toolset_parser.add_argument('-n', '--node_url', help='Node URL to connect to', type=str, default=None) + + #i'm here and trying to figure out where the url comes from elsewhere in this file, because i don't think it should be a required argument + toolset_parser.add_argument( + '-lr', '--load_repo', + help='Load a github repository into the given toolset. will create a new toolset if not present.', + type=str, + nargs=2, + metavar=('toolset_name', 'repo_url') + ) + # set toolset + toolset_parser.add_argument('-s', '--set_toolset', help='Set a toolset by name', type=str) + # get current toolset + toolset_parser.add_argument('-g', '--current_toolset', help='Get the current toolset', action='store_true') + # run tool + toolset_parser.add_argument('-r', '--run_tool', + help='Run a tool. Provide the toolset name, tool name, and parameters in "key=value" format.', + type=str, + nargs=3, + metavar=('toolset_name', 'tool_name', 'params') + ) + + # Signup command signup_parser = subparsers.add_parser("signup", help="Sign up a new user.") @@ -497,7 +526,7 @@ async def main(): args = _parse_str_args(args) if args.command == "signup": _, user_id = await user_setup_flow(hub_url, public_key) - elif args.command in ["nodes", "agents", "orchestrators", "environments", "personas", "run", "inference", "read_storage", "write_storage", "publish", "create"]: + elif args.command in ["nodes", "agents", "orchestrators", "environments", "personas", "run", "inference", "read_storage", "write_storage", "publish", "create", "toolset"]: if not naptha.hub.is_authenticated: if not hub_username or not hub_password: print( @@ -656,6 +685,37 @@ async def main(): await read_storage(naptha, args.agent_run_id, args.output_dir, args.ipfs) elif args.command == "write_storage": await write_storage(naptha, args.storage_input, args.ipfs, args.publish_to_ipns, args.update_ipns_name) + elif args.command == "toolset": + + toolset_node_url = naptha.node.node_url + agent_id = "1" # TODO: get agent id from user + if hasattr(args, 'node_url') and args.node_url is not None: + toolset_node_url = args.node_url + + toolset = Toolset(toolset_node_url, agent_id) + + if hasattr(args, 'load_repo') and args.load_repo is not None: + toolset_name, repo_url = args.load_repo + await toolset.load_or_add_tool_repo_to_toolset(toolset_name, repo_url) + elif hasattr(args, 'set_toolset') and args.set_toolset is not None: + toolset_name = args.set_toolset + await toolset.set_toolset(toolset_name) + elif hasattr(args, 'current_toolset') and args.current_toolset: + await toolset.get_current_toolset() + elif hasattr(args, 'run_tool') and args.run_tool is not None: + toolset_name, tool_name, params = args.run_tool + # parse params p1=v1,p2=v2 to dict + params = dict(param.split('=') for param in params.split(',')) + + await toolset.run_tool(toolset_name, tool_name, params) + else: + # Get toolset list + result = await toolset.get_toolset_list() + print("~"*50) + print(result) + print("~"*50) + + elif args.command == "publish": await naptha.publish_agents() else: diff --git a/naptha_sdk/toolset.py b/naptha_sdk/toolset.py index 0431eaa..68c002f 100644 --- a/naptha_sdk/toolset.py +++ b/naptha_sdk/toolset.py @@ -26,6 +26,7 @@ async def load_or_add_tool_repo_to_toolset(self, toolset_name, repo_url): repo_url=repo_url, toolset_name=toolset_name) + async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client: load_repo_response = await client.post( f"{self.worker_node_url}/tool/add_tool_repo_to_toolset", @@ -54,6 +55,7 @@ async def get_toolset_list(self): toolset_list_response.raise_for_status() result = ToolsetList(**json.loads(toolset_list_response.text)) logger.info(result) + return result except (HTTPStatusError, RemoteProtocolError) as e: print(f"Failed to get toolset list: {e}") raise @@ -73,8 +75,12 @@ async def set_toolset(self, toolset_name): ) toolset_response.raise_for_status() result = ToolsetDetails(**json.loads(toolset_response.text)) - logger.info("~"*50) - logger.info(result) + + logger.info(f'Toolset {result.name} loaded') + # print description with newlines + for line in result.description.split("\n"): + logger.info(f" {line}") + except (HTTPStatusError, RemoteProtocolError) as e: print(f"Failed to set toolset: {e}") raise @@ -94,7 +100,10 @@ async def get_current_toolset(self): ) toolset_list_response.raise_for_status() result = ToolsetDetails(**json.loads(toolset_list_response.text)) - logger.info(result) + logger.info(f'Toolset {result.name} loaded') + # print description with newlines + for line in result.description.split("\n"): + logger.info(f" {line}") except (HTTPStatusError, RemoteProtocolError) as e: print(f"Failed to get toolset: {e}") raise @@ -103,15 +112,8 @@ async def get_current_toolset(self): raise async def run_tool(self, toolset_name, tool_name, params): - logger.info(f"Running tool {tool_name}") + logger.info(f"Running Tool: {toolset_name}.{tool_name}({params})") try: - print("~"*50) - print(f"Running tool {tool_name}") - print(f"Params: {params}") - print(f"Toolset: {toolset_name}") - print(f"Agent ID: {self.agent_id}") - - print("~"*50) request = ToolRunRequest( tool_run_id="1", agent_id=self.agent_id, @@ -126,9 +128,9 @@ async def run_tool(self, toolset_name, tool_name, params): json=request.model_dump() ) tool_run_response.raise_for_status() - logger.info(f"Ran tool {tool_name}") + logger.info(f"{toolset_name}.{tool_name}({params}):") result = ToolRunResult(**json.loads(tool_run_response.text)) - logger.info(result) + logger.info(result.result) except (HTTPStatusError, RemoteProtocolError) as e: print(f"Failed to run tool: {e}") raise