diff --git a/agency_swarm/agency/genesis/AgentCreator/AgentCreator.py b/agency_swarm/agency/genesis/AgentCreator/AgentCreator.py index 35385d1f..d094f4cc 100644 --- a/agency_swarm/agency/genesis/AgentCreator/AgentCreator.py +++ b/agency_swarm/agency/genesis/AgentCreator/AgentCreator.py @@ -1,13 +1,15 @@ from agency_swarm import Agent +from agency_swarm.agents.agent import DEFAULT_MODEL from .tools.ImportAgent import ImportAgent from .tools.CreateAgentTemplate import CreateAgentTemplate from .tools.ReadManifesto import ReadManifesto class AgentCreator(Agent): - def __init__(self): + def __init__(self, model=DEFAULT_MODEL): super().__init__( description="This agent is responsible for creating new agents for the agency.", instructions="./instructions.md", tools=[ImportAgent, CreateAgentTemplate, ReadManifesto], temperature=0.3, + model=model, ) \ No newline at end of file diff --git a/agency_swarm/agency/genesis/GenesisAgency.py b/agency_swarm/agency/genesis/GenesisAgency.py index e3b0014f..7da29ba4 100644 --- a/agency_swarm/agency/genesis/GenesisAgency.py +++ b/agency_swarm/agency/genesis/GenesisAgency.py @@ -5,17 +5,17 @@ from .OpenAPICreator import OpenAPICreator from .ToolCreator import ToolCreator from agency_swarm.util.helpers import get_available_agent_descriptions - +from agency_swarm.agents.agent import DEFAULT_MODEL class GenesisAgency(Agency): - def __init__(self, with_browsing=True, **kwargs): + def __init__(self, with_browsing=True, model=DEFAULT_MODEL, **kwargs): if "max_prompt_tokens" not in kwargs: kwargs["max_prompt_tokens"] = 25000 if 'agency_chart' not in kwargs: - agent_creator = AgentCreator() - genesis_ceo = GenesisCEO() - tool_creator = ToolCreator() - openapi_creator = OpenAPICreator() + agent_creator = AgentCreator(model=model) + genesis_ceo = GenesisCEO(model=model) + tool_creator = ToolCreator(model=model) + openapi_creator = OpenAPICreator(model=model) kwargs['agency_chart'] = [ genesis_ceo, tool_creator, agent_creator, [genesis_ceo, agent_creator], diff --git a/agency_swarm/agency/genesis/GenesisCEO/GenesisCEO.py b/agency_swarm/agency/genesis/GenesisCEO/GenesisCEO.py index c01ae104..b358af6f 100644 --- a/agency_swarm/agency/genesis/GenesisCEO/GenesisCEO.py +++ b/agency_swarm/agency/genesis/GenesisCEO/GenesisCEO.py @@ -1,19 +1,21 @@ from pathlib import Path from agency_swarm import Agent +from agency_swarm.agents.agent import DEFAULT_MODEL from .tools.CreateAgencyFolder import CreateAgencyFolder from .tools.FinalizeAgency import FinalizeAgency from .tools.ReadRequirements import ReadRequirements class GenesisCEO(Agent): - def __init__(self): + def __init__(self, model=DEFAULT_MODEL): super().__init__( description="Acts as the overseer and communicator across the agency, ensuring alignment with the " "agency's goals.", instructions="./instructions.md", tools=[CreateAgencyFolder, FinalizeAgency, ReadRequirements], temperature=0.4, + model=model, ) diff --git a/agency_swarm/agency/genesis/OpenAPICreator/OpenAPICreator.py b/agency_swarm/agency/genesis/OpenAPICreator/OpenAPICreator.py index 96f106fd..3e16ff63 100644 --- a/agency_swarm/agency/genesis/OpenAPICreator/OpenAPICreator.py +++ b/agency_swarm/agency/genesis/OpenAPICreator/OpenAPICreator.py @@ -1,11 +1,13 @@ from agency_swarm import Agent +from agency_swarm.agents.agent import DEFAULT_MODEL from .tools.CreateToolsFromOpenAPISpec import CreateToolsFromOpenAPISpec class OpenAPICreator(Agent): - def __init__(self): + def __init__(self, model=DEFAULT_MODEL): super().__init__( description="This agent is responsible for creating new tools from an OpenAPI specifications.", instructions="./instructions.md", - tools=[CreateToolsFromOpenAPISpec] - ) \ No newline at end of file + tools=[CreateToolsFromOpenAPISpec], + model=model, + ) diff --git a/agency_swarm/agency/genesis/ToolCreator/ToolCreator.py b/agency_swarm/agency/genesis/ToolCreator/ToolCreator.py index f9f3dc7d..9f876768 100644 --- a/agency_swarm/agency/genesis/ToolCreator/ToolCreator.py +++ b/agency_swarm/agency/genesis/ToolCreator/ToolCreator.py @@ -1,15 +1,17 @@ from agency_swarm import Agent +from agency_swarm.agents.agent import DEFAULT_MODEL from .tools.CreateTool import CreateTool from .tools.TestTool import TestTool class ToolCreator(Agent): - def __init__(self): + def __init__(self, model=DEFAULT_MODEL): super().__init__( description="This agent is responsible for creating new tools for the agency using python code.", instructions="./instructions.md", tools=[CreateTool, TestTool], temperature=0, + model=model, ) diff --git a/agency_swarm/agents/agent.py b/agency_swarm/agents/agent.py index 05864172..fc09eca6 100644 --- a/agency_swarm/agents/agent.py +++ b/agency_swarm/agents/agent.py @@ -8,11 +8,13 @@ from deepdiff import DeepDiff from openai import NotFoundError from openai.types.beta.assistant import ToolResources +from astra_assistants import patch, OpenAI +from agency_swarm.agents.config import DEFAULT_MODEL from agency_swarm.tools import BaseTool, ToolFactory, Retrieval from agency_swarm.tools import FileSearch, CodeInterpreter from agency_swarm.tools.oai.FileSearch import FileSearchConfig -from agency_swarm.util.oai import get_openai_client +from agency_swarm.util.oai import get_openai_client, set_openai_client from agency_swarm.util.openapi import validate_openapi_spec from agency_swarm.util.shared_state import SharedState from pydantic import BaseModel @@ -84,7 +86,7 @@ def __init__( api_params: Dict[str, Dict[str, str]] = None, file_ids: List[str] = None, metadata: Dict[str, str] = None, - model: str = "gpt-4o-2024-08-06", + model: str = DEFAULT_MODEL, validation_attempts: int = 1, max_prompt_tokens: int = None, max_completion_tokens: int = None, @@ -159,7 +161,7 @@ def __init__( self._shared_instructions = None # init methods - self.client = get_openai_client() + self.client = get_openai_client(model=model) self._read_instructions() # upload files @@ -226,6 +228,7 @@ def init_oai(self): if assistant_settings['name'] == self.name: try: self.assistant = self.client.beta.assistants.retrieve(assistant_settings['id']) + self.id = assistant_settings['id'] # update assistant if parameters are different @@ -239,6 +242,7 @@ def init_oai(self): self._update_settings() return self except NotFoundError: + print(f"Assistant not found. {assistant_settings} consider deleting your settings.json file and starting over.") continue # create assistant if settings.json does not exist or assistant with the same name does not exist diff --git a/agency_swarm/agents/config.py b/agency_swarm/agents/config.py new file mode 100644 index 00000000..4d95d206 --- /dev/null +++ b/agency_swarm/agents/config.py @@ -0,0 +1 @@ +DEFAULT_MODEL = "gpt-4o-2024-08-06" diff --git a/agency_swarm/cli.py b/agency_swarm/cli.py index 6a746e76..668e39e4 100644 --- a/agency_swarm/cli.py +++ b/agency_swarm/cli.py @@ -2,7 +2,7 @@ import os from dotenv import load_dotenv from agency_swarm.util.helpers import list_available_agents - +from agents.agent import DEFAULT_MODEL def main(): parser = argparse.ArgumentParser(description='Agency Swarm CLI.') @@ -23,6 +23,7 @@ def main(): genesis_parser.add_argument('--openai_key', default=None, type=str, help='OpenAI API key.') genesis_parser.add_argument('--with_browsing', default=False, action='store_true', help='Enable browsing agent.') + genesis_parser.add_argument('--model', default=DEFAULT_MODEL, type=str, help='Model to use for the agency.') # import-agent import_parser = subparsers.add_parser('import-agent', help='Import pre-made agent by name to a local directory.') @@ -47,7 +48,7 @@ def main(): set_openai_key(args.openai_key) from agency_swarm.agency.genesis import GenesisAgency - agency = GenesisAgency(with_browsing=args.with_browsing) + agency = GenesisAgency(with_browsing=args.with_browsing, model=args.model) agency.run_demo() elif args.command == "import-agent": from agency_swarm.util import import_agent diff --git a/agency_swarm/util/oai.py b/agency_swarm/util/oai.py index 38e22d70..2bcdec56 100644 --- a/agency_swarm/util/oai.py +++ b/agency_swarm/util/oai.py @@ -3,6 +3,9 @@ import threading import os +from astra_assistants import patch +from openai.types.chat_model import ChatModel + from dotenv import load_dotenv load_dotenv() @@ -11,7 +14,7 @@ client = None -def get_openai_client(): +def get_openai_client(model:str=None): global client with client_lock: if client is None: @@ -23,6 +26,9 @@ def get_openai_client(): timeout=httpx.Timeout(60.0, read=40, connect=5.0), max_retries=10, default_headers={"OpenAI-Beta": "assistants=v2"}) + if model is not None and model not in ChatModel: + print(f"Using astra-assistants for non OpenAI model {model}, note thread URLs will not work.") + client = patch(client) return client diff --git a/pyproject.toml b/pyproject.toml index 9d2ea2f2..f5fd3bd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,12 +23,13 @@ dependencies = [ "termcolor==2.4.0", "python-dotenv==1.0.1", "rich==13.7.1", - "jsonref==1.1.0" + "jsonref==1.1.0", + "astra-assistants>=2.1.2", ] -requires-python = ">=3.7" +requires-python = ">=3.10" urls = { homepage = "https://github.com/VRSEN/agency-swarm" } [project.scripts] agency-swarm = "agency_swarm.cli:main" -[tool.setuptools_scm] \ No newline at end of file +[tool.setuptools_scm] diff --git a/requirements.txt b/requirements.txt index 89458b32..7719ab43 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ deepdiff==6.7.1 termcolor==2.4.0 python-dotenv==1.0.1 rich==13.7.1 -jsonref==1.1.0 \ No newline at end of file +jsonref==1.1.0 +astra-assistants==2.1.2