From aff284e9f85334f2e5b0815f4e0f8f3ba63b79cc Mon Sep 17 00:00:00 2001 From: Tianqi Xu Date: Wed, 18 Sep 2024 19:08:41 +0300 Subject: [PATCH] Fix all agent tests and add create_backend_model function --- crab/agents/backend_models/__init__.py | 40 +++++++++++++ crab/agents/backend_models/claude_model.py | 9 ++- crab/agents/backend_models/gemini_model.py | 18 +++--- crab/agents/backend_models/openai_model.py | 15 ++++- crab/agents/policies/multi_agent_by_env.py | 23 ++++---- crab/agents/policies/multi_agent_by_func.py | 29 +++++----- crab/agents/policies/single_agent.py | 19 ++++--- crab/agents/utils.py | 56 +++++++++++++++++++ crab/core/agent_policy.py | 53 +----------------- .../backend_models/test_claude_model.py | 21 ++++--- .../backend_models/test_gemini_model.py | 14 +++-- .../backend_models/test_openai_model.py | 18 +++--- .../policies/test_multi_agent_by_func.py | 11 ++-- .../policies/test_mutli_agent_by_env.py | 11 ++-- test/agents/policies/test_single_agent.py | 7 ++- 15 files changed, 218 insertions(+), 126 deletions(-) create mode 100644 crab/agents/utils.py diff --git a/crab/agents/backend_models/__init__.py b/crab/agents/backend_models/__init__.py index 5f36882..c087ca0 100644 --- a/crab/agents/backend_models/__init__.py +++ b/crab/agents/backend_models/__init__.py @@ -12,7 +12,47 @@ # limitations under the License. # =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== # ruff: noqa: F401 +from typing import Any, Literal + +from pydantic import BaseModel + +from crab.core.backend_model import BackendModel + from .camel_model import CamelModel from .claude_model import ClaudeModel from .gemini_model import GeminiModel from .openai_model import OpenAIModel + + +class BackendModelConfig(BaseModel): + model_class: Literal["openai", "claude", "gemini", "camel"] + model_name: str + history_messages_len: int = 0 + parameters: dict[str, Any] = {} + tool_call_required: bool = False + + +def create_backend_model(model_config: BackendModelConfig) -> BackendModel: + match model_config.model_class: + case "claude": + return ClaudeModel( + model=model_config.model_name, + parameters=model_config.parameters, + history_messages_len=model_config.history_messages_len, + ) + case "gemini": + return GeminiModel( + model=model_config.model_name, + parameters=model_config.parameters, + history_messages_len=model_config.history_messages_len, + ) + case "openai": + return OpenAIModel( + model=model_config.model_name, + parameters=model_config.parameters, + history_messages_len=model_config.history_messages_len, + ) + case "camel": + raise NotImplementedError("Cannot support camel model currently.") + case _: + raise ValueError(f"Unsupported model name: {model_config.model_name}") diff --git a/crab/agents/backend_models/claude_model.py b/crab/agents/backend_models/claude_model.py index 7ffc4c2..cf03e55 100644 --- a/crab/agents/backend_models/claude_model.py +++ b/crab/agents/backend_models/claude_model.py @@ -32,6 +32,7 @@ def __init__( model: str, parameters: dict[str, Any] = dict(), history_messages_len: int = 0, + tool_call_required: bool = False, ) -> None: if anthropic_model_enable is False: raise ImportError("Please install anthropic to use ClaudeModel") @@ -41,6 +42,7 @@ def __init__( history_messages_len, ) self.client = anthropic.Anthropic() + self.tool_call_required = tool_call_required def reset(self, system_message: str, action_space: list[Action] | None) -> None: self.system_message = system_message @@ -93,6 +95,7 @@ def record_message(self, new_message: dict, response_message: dict) -> None: "content": "success", } for call in tool_calls + if call is ToolUseBlock ], } ) @@ -101,12 +104,14 @@ def call_api(self, request_messages: list): while True: try: if self.action_schema is not None: - response = self.client.beta.tools.messages.create( + response = self.client.messages.create( system=self.system_message, # <-- system prompt messages=request_messages, # type: ignore model=self.model, tools=self.action_schema, - tool_choice={"type": "any"}, + tool_choice={ + "type": "any" if self.tool_call_required else "auto" + }, **self.parameters, ) else: diff --git a/crab/agents/backend_models/gemini_model.py b/crab/agents/backend_models/gemini_model.py index 663aba2..26123b6 100644 --- a/crab/agents/backend_models/gemini_model.py +++ b/crab/agents/backend_models/gemini_model.py @@ -35,6 +35,7 @@ def __init__( model: str, parameters: dict[str, Any] = dict(), history_messages_len: int = 0, + tool_call_required: bool = False, ) -> None: if gemini_model_enable is False: raise ImportError("Please install google.generativeai to use GeminiModel") @@ -45,6 +46,7 @@ def __init__( ) genai.configure(api_key=os.environ["GEMINI_API_KEY"]) self.client = genai + self.tool_call_required = tool_call_required def reset(self, system_message: str, action_space: list[Action] | None) -> None: self.system_message = system_message @@ -98,7 +100,11 @@ def call_api(self, request_messages: list): try: if self.action_schema is not None: tool_config = content_types.to_tool_config( - {"function_calling_config": {"mode": "ANY"}} + { + "function_calling_config": { + "mode": "ANY" if self.tool_call_required else "AUTO" + } + } ) response = self.client.GenerativeModel( self.model, system_instruction=self.system_message @@ -141,9 +147,7 @@ def _convert_action_to_schema(cls, action_space): return None actions = [] for action in action_space: - actions.append( - Tool(function_declarations=[cls._action_to_funcdec_policy(action)]) - ) + actions.append(Tool(function_declarations=[cls._action_to_funcdec(action)])) return actions @staticmethod @@ -171,14 +175,14 @@ def _clear_schema(cls, schema_dict: dict): cls._clear_schema(schema_dict["items"]) @classmethod - def _action_to_funcdec(cls, action: Action, env: str): + def _action_to_funcdec(cls, action: Action) -> FunctionDeclaration: "Converts crab Action to google FunctionDeclaration" p_schema = action.parameters.model_json_schema() if "$defs" in p_schema: p_schema = json_expand_refs(p_schema) cls._clear_schema(p_schema) return FunctionDeclaration( - name=action.name + "__in__" + env, - description="In {} environment, {}".format(env, action.description), + name=action.name, + description=action.description, parameters=p_schema, ) diff --git a/crab/agents/backend_models/openai_model.py b/crab/agents/backend_models/openai_model.py index 5e95535..c7ba157 100644 --- a/crab/agents/backend_models/openai_model.py +++ b/crab/agents/backend_models/openai_model.py @@ -31,6 +31,7 @@ def __init__( model: str, parameters: dict[str, Any] = dict(), history_messages_len: int = 0, + tool_call_required: bool = False, ) -> None: if not openai_model_enable: raise ImportError("Please install openai to use OpenAIModel") @@ -40,6 +41,16 @@ def __init__( history_messages_len, ) self.client = openai.OpenAI() + self.tool_call_required = tool_call_required + self.system_message = "You are a helpful assistant." + self.openai_system_message = { + "role": "system", + "content": self.system_message, + } + self.action_space = None + self.action_schema = None + self.token_usage = 0 + self.chat_history: list[list[ChatCompletionMessage | dict]] = [] def reset(self, system_message: str, action_space: list[Action] | None) -> None: self.system_message = system_message @@ -88,12 +99,12 @@ def call_api(self, request_messages: list) -> ChatCompletionMessage: messages=request_messages, model=self.model, tools=self.action_schema, - tool_choice="required", + tool_choice="required" if self.tool_call_required else "auto", **self.parameters, ) else: response = self.client.chat.completions.create( - messages=request_messages, + messages=request_messages, # type: ignore model=self.model, **self.parameters, ) diff --git a/crab/agents/policies/multi_agent_by_env.py b/crab/agents/policies/multi_agent_by_env.py index d2cfc2c..b72a535 100644 --- a/crab/agents/policies/multi_agent_by_env.py +++ b/crab/agents/policies/multi_agent_by_env.py @@ -11,9 +11,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== -from copy import copy - from crab import Action, ActionOutput +from crab.agents.backend_models import BackendModelConfig, create_backend_model +from crab.agents.utils import generate_action_prompt from crab.core.agent_policy import AgentPolicy from crab.core.backend_model import ( BackendModel, @@ -57,12 +57,12 @@ class MultiAgentByEnvPolicy(AgentPolicy): def __init__( self, - main_agent_model_backend: BackendModel, - env_agent_model_backend: BackendModel, + main_agent_model_backend: BackendModelConfig, + env_agent_model_backend: BackendModelConfig, ): - self.main_agent_model_backend = copy(main_agent_model_backend) - self.env_agent_model_backend = env_agent_model_backend - self.reset(task_description="", action_spaces=None, env_descriptions={}) + self.main_agent_model_backend = create_backend_model(main_agent_model_backend) + self.env_agent_model_backend_config = env_agent_model_backend + self.reset(task_description="", action_spaces={}, env_descriptions={}) def reset( self, @@ -82,15 +82,16 @@ def reset( ) self.env_agent_model_backends: dict[str, BackendModel] = {} for env in action_spaces: - backend = copy(self.env_agent_model_backend) + backend = create_backend_model(self.env_agent_model_backend_config) if env == "root": backend.reset(root_agent_system_message, action_spaces[env]) else: + backend.require_tool = True env_agent_system_message = self._env_agent_prompt.format( task_description=task_description, environment=env, env_description=env_descriptions[env], - action_descriptions=self.generate_action_prompt(action_spaces[env]), + action_descriptions=generate_action_prompt(action_spaces[env]), ) backend.reset(env_agent_system_message, action_spaces[env]) self.env_agent_model_backends[env] = backend @@ -140,5 +141,7 @@ def chat( ) else: output = backend.chat((main_agent_message, MessageType.TEXT)) + for action in output.action_list: + action.env = env tool_calls.extend(output.action_list) - return self.decode_combined_action(tool_calls) + return tool_calls diff --git a/crab/agents/policies/multi_agent_by_func.py b/crab/agents/policies/multi_agent_by_func.py index 8f95b72..eec0159 100644 --- a/crab/agents/policies/multi_agent_by_func.py +++ b/crab/agents/policies/multi_agent_by_func.py @@ -11,14 +11,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== -from copy import copy - -from crab import Action, ActionOutput -from crab.core.agent_policy import AgentPolicy -from crab.core.backend_model import ( - BackendModel, - MessageType, +from crab.agents.backend_models import BackendModelConfig, create_backend_model +from crab.agents.utils import ( + combine_multi_env_action_space, + decode_combined_action, + generate_action_prompt, ) +from crab.core import Action, ActionOutput +from crab.core.agent_policy import AgentPolicy +from crab.core.backend_model import MessageType class MultiAgentByFuncPolicy(AgentPolicy): @@ -40,11 +41,11 @@ class MultiAgentByFuncPolicy(AgentPolicy): def __init__( self, - main_agent_model_backend: BackendModel, - tool_agent_model_backend: BackendModel, + main_agent_model_backend: BackendModelConfig, + tool_agent_model_backend: BackendModelConfig, ): - self.main_agent_model_backend = copy(main_agent_model_backend) - self.tool_agent_model_backend = copy(tool_agent_model_backend) + self.main_agent_model_backend = create_backend_model(main_agent_model_backend) + self.tool_agent_model_backend = create_backend_model(tool_agent_model_backend) self.reset(task_description="", action_spaces=None, env_descriptions={}) def reset( @@ -54,11 +55,11 @@ def reset( env_descriptions: dict[str, str], ) -> list[ActionOutput]: self.task_description = task_description - self.action_space = self.combine_multi_env_action_space(action_spaces) + self.action_space = combine_multi_env_action_space(action_spaces) main_agent_system_message = self._system_prompt.format( task_description=task_description, - action_descriptions=self.generate_action_prompt(self.action_space), + action_descriptions=generate_action_prompt(self.action_space), env_description=str(env_descriptions), ) self.main_agent_model_backend.reset(main_agent_system_message, None) @@ -95,4 +96,4 @@ def chat( tool_output = self.tool_agent_model_backend.chat( (output.message, MessageType.TEXT) ) - return self.decode_combined_action(tool_output.action_list) + return decode_combined_action(tool_output.action_list) diff --git a/crab/agents/policies/single_agent.py b/crab/agents/policies/single_agent.py index 7003c11..7746c53 100644 --- a/crab/agents/policies/single_agent.py +++ b/crab/agents/policies/single_agent.py @@ -11,12 +11,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== -from copy import copy - from crab import Action, ActionOutput +from crab.agents.backend_models import BackendModelConfig, create_backend_model +from crab.agents.utils import ( + combine_multi_env_action_space, + decode_combined_action, + generate_action_prompt, +) from crab.core.agent_policy import AgentPolicy from crab.core.backend_model import ( - BackendModel, MessageType, ) from crab.utils.measure import timed @@ -46,9 +49,9 @@ class SingleAgentPolicy(AgentPolicy): def __init__( self, - model_backend: BackendModel, + model_backend: BackendModelConfig, ): - self.model_backend = copy(model_backend) + self.model_backend = create_backend_model(model_backend) self.reset(task_description="", action_spaces=None, env_descriptions={}) def reset( @@ -58,10 +61,10 @@ def reset( env_descriptions: dict[str, str], ) -> list: self.task_description = task_description - self.action_space = self.combine_multi_env_action_space(action_spaces) + self.action_space = combine_multi_env_action_space(action_spaces) system_message = self._system_prompt.format( task_description=task_description, - action_descriptions=self.generate_action_prompt(self.action_space), + action_descriptions=generate_action_prompt(self.action_space), env_description=str(env_descriptions), ) self.model_backend.reset(system_message, self.action_space) @@ -87,4 +90,4 @@ def chat( ) ) output = self.model_backend.chat(prompt) - return self.decode_combined_action(output.action_list) + return decode_combined_action(output.action_list) diff --git a/crab/agents/utils.py b/crab/agents/utils.py new file mode 100644 index 0000000..e3a18c7 --- /dev/null +++ b/crab/agents/utils.py @@ -0,0 +1,56 @@ +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +from crab.core import Action, ActionOutput + + +def combine_multi_env_action_space( + action_space: dict[str, list[Action]] | None, +) -> list[Action]: + """Combine multi-env action space together to fit in a single agent.""" + result = [] + if action_space is None: + return result + for env in action_space: + for action in action_space[env]: + new_action = action.model_copy() + new_action.name = new_action.name + "__in__" + env + new_action.description = f"In {env} environment, " + new_action.description + result.append(new_action) + return result + + +def decode_combined_action( + output_actions: list[ActionOutput], +) -> list[ActionOutput]: + """Decode combined action output to action output with the corresponding + environment. + """ + result = [] + for output in output_actions: + name_env = output.name.split("__in__") + if len(name_env) != 2: + raise RuntimeError( + 'The decoded action name should contain the splitter "__in__".' + ) + new_output = output.model_copy() + new_output.name = name_env[0] + new_output.env = name_env[1] + result.append(new_output) + return result + + +def generate_action_prompt(action_space: list[Action]) -> str: + return "".join( + [f"[{action.name}: {action.description}]\n" for action in action_space] + ) diff --git a/crab/core/agent_policy.py b/crab/core/agent_policy.py index 7f460ff..baea1ad 100644 --- a/crab/core/agent_policy.py +++ b/crab/core/agent_policy.py @@ -13,14 +13,14 @@ # =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== from abc import ABC, abstractmethod -from .models import Action, ActionOutput, MessageType +from .models import Action, ActionOutput, Message class AgentPolicy(ABC): @abstractmethod def chat( self, - observation: dict[str, list[tuple[str, MessageType]]], + observation: dict[str, list[Message]], ) -> list[ActionOutput]: ... @abstractmethod @@ -32,54 +32,7 @@ def reset( ) -> None: ... @abstractmethod - def get_token_usage(self): ... + def get_token_usage(self) -> int: ... @abstractmethod def get_backend_model_name(self) -> str: ... - - @staticmethod - def combine_multi_env_action_space( - action_space: dict[str, list[Action]] | None, - ) -> list[Action]: - """Combine multi-env action space together to fit in a single agent.""" - result = [] - if action_space is None: - return result - for env in action_space: - for action in action_space[env]: - new_action = action.model_copy() - new_action.name = new_action.name + "__in__" + env - new_action.description = ( - f"In {env} environment, " + new_action.description - ) - result.append(new_action) - return result - - @staticmethod - def decode_combined_action( - output_actions: list[ActionOutput], - ) -> list[ActionOutput]: - """Decode combined action output to action output with the corresponding - environment. - """ - result = [] - for output in output_actions: - name_env = output.name.split("__in__") - if len(name_env) != 2: - raise RuntimeError( - 'The decoded action name should contain the splitter "__in__".' - ) - new_output = output.model_copy() - new_output.name = name_env[0] - new_output.env = name_env[1] - result.append(new_output) - return result - - @staticmethod - def generate_action_prompt(actions: list[Action] | None): - if actions is None: - return None - result = "" - for action in actions: - result += f"[{action.name}: {action.description}]\n" - return result diff --git a/test/agents/backend_models/test_claude_model.py b/test/agents/backend_models/test_claude_model.py index ace602a..be3ddb8 100644 --- a/test/agents/backend_models/test_claude_model.py +++ b/test/agents/backend_models/test_claude_model.py @@ -14,17 +14,20 @@ import pytest from crab import MessageType, action -from crab.agents.backend_models.claude_model import ClaudeModel +from crab.agents.backend_models import BackendModelConfig, create_backend_model # TODO: Add mock data @pytest.fixture def claude_model_text(): - return ClaudeModel( - model="claude-3-opus-20240229", - parameters={"max_tokens": 3000}, - history_messages_len=1, + return create_backend_model( + BackendModelConfig( + model_class="claude", + model_name="claude-3-opus-20240229", + parameters={"max_tokens": 3000}, + history_messages_len=1, + ) ) @@ -39,7 +42,7 @@ def add(a: int, b: int): return a + b -@pytest.mark.skip(reason="Mock data to be added") +# @pytest.mark.skip(reason="Mock data to be added") def test_text_chat(claude_model_text): message = ("Hello!", MessageType.TEXT) output = claude_model_text.chat(message) @@ -60,7 +63,7 @@ def test_text_chat(claude_model_text): assert len(claude_model_text.chat_history) == 3 -@pytest.mark.skip(reason="Mock data to be added") +# @pytest.mark.skip(reason="Mock data to be added") def test_action_chat(claude_model_text): claude_model_text.reset("You are a helpful assistant.", [add]) message = ( @@ -71,8 +74,8 @@ def test_action_chat(claude_model_text): 0, ) output = claude_model_text.chat(message) - assert output.message is None assert len(output.action_list) == 1 - assert output.action_list[0].arguments == {"a": 10, "b": 15} + args = output.action_list[0].arguments + assert args["a"] + args["b"] == 25 assert output.action_list[0].name == "add" assert claude_model_text.token_usage > 0 diff --git a/test/agents/backend_models/test_gemini_model.py b/test/agents/backend_models/test_gemini_model.py index 86ece01..1ab7877 100644 --- a/test/agents/backend_models/test_gemini_model.py +++ b/test/agents/backend_models/test_gemini_model.py @@ -14,17 +14,21 @@ import pytest from crab import MessageType, action -from crab.agents.backend_models.gemini_model import GeminiModel +from crab.agents.backend_models import BackendModelConfig, create_backend_model # TODO: Add mock data @pytest.fixture def gemini_model_text(): - return GeminiModel( - model="gemini-1.5-pro-latest", - parameters={"max_tokens": 3000}, - history_messages_len=1, + return create_backend_model( + BackendModelConfig( + model_class="gemini", + model_name="gemini-1.5-pro-latest", + parameters={"max_tokens": 3000}, + history_messages_len=1, + tool_call_required=False, + ) ) diff --git a/test/agents/backend_models/test_openai_model.py b/test/agents/backend_models/test_openai_model.py index 51e56ab..57c9b72 100644 --- a/test/agents/backend_models/test_openai_model.py +++ b/test/agents/backend_models/test_openai_model.py @@ -18,10 +18,8 @@ from openai.types.chat.chat_completion_message_tool_call import Function from crab import action -from crab.agents.backend_models.openai_model import ( - MessageType, - OpenAIModel, -) +from crab.agents.backend_models import BackendModelConfig, create_backend_model +from crab.agents.backend_models.openai_model import MessageType # Mock data for the OpenAI API response openai_mock_response = MagicMock( @@ -91,10 +89,14 @@ @pytest.fixture def openai_model_text(): os.environ["OPENAI_API_KEY"] = "MOCK" - return OpenAIModel( - model="gpt-4o", - parameters={"max_tokens": 3000}, - history_messages_len=1, + return create_backend_model( + BackendModelConfig( + model_class="openai", + model_name="gpt-4o", + parameters={"max_tokens": 3000}, + history_messages_len=1, + tool_call_required=False, + ) ) diff --git a/test/agents/policies/test_multi_agent_by_func.py b/test/agents/policies/test_multi_agent_by_func.py index d319488..b7e31af 100644 --- a/test/agents/policies/test_multi_agent_by_func.py +++ b/test/agents/policies/test_multi_agent_by_func.py @@ -14,15 +14,16 @@ import pytest from crab import create_benchmark -from crab.agents.backend_models.openai_model import OpenAIModel +from crab.agents.backend_models import BackendModelConfig from crab.agents.policies.multi_agent_by_func import MultiAgentByFuncPolicy from crab.benchmarks.template import multienv_template_benchmark_config @pytest.fixture def policy_fixture(): - model = OpenAIModel( - model="gpt-4o", + model = BackendModelConfig( + model_class="openai", + model_name="gpt-4o", parameters={"max_tokens": 3000}, history_messages_len=1, ) @@ -30,9 +31,11 @@ def policy_fixture(): benchmark = create_benchmark(benchmark_config) task, action_spaces = benchmark.start_task("0") policy = MultiAgentByFuncPolicy( - task_description=task.description, main_agent_model_backend=model, tool_agent_model_backend=model, + ) + policy.reset( + task_description=task.description, action_spaces=action_spaces, env_descriptions=benchmark.get_env_descriptions(), ) diff --git a/test/agents/policies/test_mutli_agent_by_env.py b/test/agents/policies/test_mutli_agent_by_env.py index 1f1e791..318e677 100644 --- a/test/agents/policies/test_mutli_agent_by_env.py +++ b/test/agents/policies/test_mutli_agent_by_env.py @@ -14,15 +14,16 @@ import pytest from crab import create_benchmark -from crab.agents.backend_models.openai_model import OpenAIModel +from crab.agents.backend_models import BackendModelConfig from crab.agents.policies.multi_agent_by_env import MultiAgentByEnvPolicy from crab.benchmarks.template import multienv_template_benchmark_config @pytest.fixture def policy_fixture(): - model = OpenAIModel( - model="gpt-4o", + model = BackendModelConfig( + model_class="openai", + model_name="gpt-4o", parameters={"max_tokens": 3000}, history_messages_len=1, ) @@ -30,9 +31,11 @@ def policy_fixture(): benchmark = create_benchmark(benchmark_config) task, action_spaces = benchmark.start_task("0") policy = MultiAgentByEnvPolicy( - task_description=task.description, main_agent_model_backend=model, env_agent_model_backend=model, + ) + policy.reset( + task_description=task.description, action_spaces=action_spaces, env_descriptions=benchmark.get_env_descriptions(), ) diff --git a/test/agents/policies/test_single_agent.py b/test/agents/policies/test_single_agent.py index 56f0bfa..440893c 100644 --- a/test/agents/policies/test_single_agent.py +++ b/test/agents/policies/test_single_agent.py @@ -26,7 +26,7 @@ ) from crab import create_benchmark -from crab.agents.backend_models.openai_model import OpenAIModel +from crab.agents.backend_models import BackendModelConfig from crab.agents.policies.single_agent import SingleAgentPolicy from crab.benchmarks.template import multienv_template_benchmark_config @@ -75,8 +75,9 @@ @pytest.fixture def policy_fixture(): os.environ["OPENAI_API_KEY"] = "MOCK" - model = OpenAIModel( - model="gpt-4o", + model = BackendModelConfig( + model_class="openai", + model_name="gpt-4o", parameters={"max_tokens": 3000}, history_messages_len=1, )