Skip to content

Commit

Permalink
Fix all agent tests and add create_backend_model function
Browse files Browse the repository at this point in the history
  • Loading branch information
dandansamax committed Sep 18, 2024
1 parent 138c1d8 commit aff284e
Show file tree
Hide file tree
Showing 15 changed files with 218 additions and 126 deletions.
40 changes: 40 additions & 0 deletions crab/agents/backend_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,47 @@
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
# ruff: noqa: F401
from typing import Any, Literal

from pydantic import BaseModel

from crab.core.backend_model import BackendModel

from .camel_model import CamelModel
from .claude_model import ClaudeModel
from .gemini_model import GeminiModel
from .openai_model import OpenAIModel


class BackendModelConfig(BaseModel):
model_class: Literal["openai", "claude", "gemini", "camel"]
model_name: str
history_messages_len: int = 0
parameters: dict[str, Any] = {}
tool_call_required: bool = False


def create_backend_model(model_config: BackendModelConfig) -> BackendModel:
match model_config.model_class:
case "claude":
return ClaudeModel(
model=model_config.model_name,
parameters=model_config.parameters,
history_messages_len=model_config.history_messages_len,
)
case "gemini":
return GeminiModel(
model=model_config.model_name,
parameters=model_config.parameters,
history_messages_len=model_config.history_messages_len,
)
case "openai":
return OpenAIModel(
model=model_config.model_name,
parameters=model_config.parameters,
history_messages_len=model_config.history_messages_len,
)
case "camel":
raise NotImplementedError("Cannot support camel model currently.")
case _:
raise ValueError(f"Unsupported model name: {model_config.model_name}")
9 changes: 7 additions & 2 deletions crab/agents/backend_models/claude_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
model: str,
parameters: dict[str, Any] = dict(),
history_messages_len: int = 0,
tool_call_required: bool = False,
) -> None:
if anthropic_model_enable is False:
raise ImportError("Please install anthropic to use ClaudeModel")
Expand All @@ -41,6 +42,7 @@ def __init__(
history_messages_len,
)
self.client = anthropic.Anthropic()
self.tool_call_required = tool_call_required

def reset(self, system_message: str, action_space: list[Action] | None) -> None:
self.system_message = system_message
Expand Down Expand Up @@ -93,6 +95,7 @@ def record_message(self, new_message: dict, response_message: dict) -> None:
"content": "success",
}
for call in tool_calls
if call is ToolUseBlock
],
}
)
Expand All @@ -101,12 +104,14 @@ def call_api(self, request_messages: list):
while True:
try:
if self.action_schema is not None:
response = self.client.beta.tools.messages.create(
response = self.client.messages.create(
system=self.system_message, # <-- system prompt
messages=request_messages, # type: ignore
model=self.model,
tools=self.action_schema,
tool_choice={"type": "any"},
tool_choice={
"type": "any" if self.tool_call_required else "auto"
},
**self.parameters,
)
else:
Expand Down
18 changes: 11 additions & 7 deletions crab/agents/backend_models/gemini_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
model: str,
parameters: dict[str, Any] = dict(),
history_messages_len: int = 0,
tool_call_required: bool = False,
) -> None:
if gemini_model_enable is False:
raise ImportError("Please install google.generativeai to use GeminiModel")
Expand All @@ -45,6 +46,7 @@ def __init__(
)
genai.configure(api_key=os.environ["GEMINI_API_KEY"])
self.client = genai
self.tool_call_required = tool_call_required

def reset(self, system_message: str, action_space: list[Action] | None) -> None:
self.system_message = system_message
Expand Down Expand Up @@ -98,7 +100,11 @@ def call_api(self, request_messages: list):
try:
if self.action_schema is not None:
tool_config = content_types.to_tool_config(
{"function_calling_config": {"mode": "ANY"}}
{
"function_calling_config": {
"mode": "ANY" if self.tool_call_required else "AUTO"
}
}
)
response = self.client.GenerativeModel(
self.model, system_instruction=self.system_message
Expand Down Expand Up @@ -141,9 +147,7 @@ def _convert_action_to_schema(cls, action_space):
return None
actions = []
for action in action_space:
actions.append(
Tool(function_declarations=[cls._action_to_funcdec_policy(action)])
)
actions.append(Tool(function_declarations=[cls._action_to_funcdec(action)]))
return actions

@staticmethod
Expand Down Expand Up @@ -171,14 +175,14 @@ def _clear_schema(cls, schema_dict: dict):
cls._clear_schema(schema_dict["items"])

@classmethod
def _action_to_funcdec(cls, action: Action, env: str):
def _action_to_funcdec(cls, action: Action) -> FunctionDeclaration:
"Converts crab Action to google FunctionDeclaration"
p_schema = action.parameters.model_json_schema()
if "$defs" in p_schema:
p_schema = json_expand_refs(p_schema)
cls._clear_schema(p_schema)
return FunctionDeclaration(
name=action.name + "__in__" + env,
description="In {} environment, {}".format(env, action.description),
name=action.name,
description=action.description,
parameters=p_schema,
)
15 changes: 13 additions & 2 deletions crab/agents/backend_models/openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
model: str,
parameters: dict[str, Any] = dict(),
history_messages_len: int = 0,
tool_call_required: bool = False,
) -> None:
if not openai_model_enable:
raise ImportError("Please install openai to use OpenAIModel")
Expand All @@ -40,6 +41,16 @@ def __init__(
history_messages_len,
)
self.client = openai.OpenAI()
self.tool_call_required = tool_call_required
self.system_message = "You are a helpful assistant."
self.openai_system_message = {
"role": "system",
"content": self.system_message,
}
self.action_space = None
self.action_schema = None
self.token_usage = 0
self.chat_history: list[list[ChatCompletionMessage | dict]] = []

def reset(self, system_message: str, action_space: list[Action] | None) -> None:
self.system_message = system_message
Expand Down Expand Up @@ -88,12 +99,12 @@ def call_api(self, request_messages: list) -> ChatCompletionMessage:
messages=request_messages,
model=self.model,
tools=self.action_schema,
tool_choice="required",
tool_choice="required" if self.tool_call_required else "auto",
**self.parameters,
)
else:
response = self.client.chat.completions.create(
messages=request_messages,
messages=request_messages, # type: ignore
model=self.model,
**self.parameters,
)
Expand Down
23 changes: 13 additions & 10 deletions crab/agents/policies/multi_agent_by_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
from copy import copy

from crab import Action, ActionOutput
from crab.agents.backend_models import BackendModelConfig, create_backend_model
from crab.agents.utils import generate_action_prompt
from crab.core.agent_policy import AgentPolicy
from crab.core.backend_model import (
BackendModel,
Expand Down Expand Up @@ -57,12 +57,12 @@ class MultiAgentByEnvPolicy(AgentPolicy):

def __init__(
self,
main_agent_model_backend: BackendModel,
env_agent_model_backend: BackendModel,
main_agent_model_backend: BackendModelConfig,
env_agent_model_backend: BackendModelConfig,
):
self.main_agent_model_backend = copy(main_agent_model_backend)
self.env_agent_model_backend = env_agent_model_backend
self.reset(task_description="", action_spaces=None, env_descriptions={})
self.main_agent_model_backend = create_backend_model(main_agent_model_backend)
self.env_agent_model_backend_config = env_agent_model_backend
self.reset(task_description="", action_spaces={}, env_descriptions={})

def reset(
self,
Expand All @@ -82,15 +82,16 @@ def reset(
)
self.env_agent_model_backends: dict[str, BackendModel] = {}
for env in action_spaces:
backend = copy(self.env_agent_model_backend)
backend = create_backend_model(self.env_agent_model_backend_config)
if env == "root":
backend.reset(root_agent_system_message, action_spaces[env])
else:
backend.require_tool = True
env_agent_system_message = self._env_agent_prompt.format(
task_description=task_description,
environment=env,
env_description=env_descriptions[env],
action_descriptions=self.generate_action_prompt(action_spaces[env]),
action_descriptions=generate_action_prompt(action_spaces[env]),
)
backend.reset(env_agent_system_message, action_spaces[env])
self.env_agent_model_backends[env] = backend
Expand Down Expand Up @@ -140,5 +141,7 @@ def chat(
)
else:
output = backend.chat((main_agent_message, MessageType.TEXT))
for action in output.action_list:
action.env = env
tool_calls.extend(output.action_list)
return self.decode_combined_action(tool_calls)
return tool_calls
29 changes: 15 additions & 14 deletions crab/agents/policies/multi_agent_by_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
from copy import copy

from crab import Action, ActionOutput
from crab.core.agent_policy import AgentPolicy
from crab.core.backend_model import (
BackendModel,
MessageType,
from crab.agents.backend_models import BackendModelConfig, create_backend_model
from crab.agents.utils import (
combine_multi_env_action_space,
decode_combined_action,
generate_action_prompt,
)
from crab.core import Action, ActionOutput
from crab.core.agent_policy import AgentPolicy
from crab.core.backend_model import MessageType


class MultiAgentByFuncPolicy(AgentPolicy):
Expand All @@ -40,11 +41,11 @@ class MultiAgentByFuncPolicy(AgentPolicy):

def __init__(
self,
main_agent_model_backend: BackendModel,
tool_agent_model_backend: BackendModel,
main_agent_model_backend: BackendModelConfig,
tool_agent_model_backend: BackendModelConfig,
):
self.main_agent_model_backend = copy(main_agent_model_backend)
self.tool_agent_model_backend = copy(tool_agent_model_backend)
self.main_agent_model_backend = create_backend_model(main_agent_model_backend)
self.tool_agent_model_backend = create_backend_model(tool_agent_model_backend)
self.reset(task_description="", action_spaces=None, env_descriptions={})

def reset(
Expand All @@ -54,11 +55,11 @@ def reset(
env_descriptions: dict[str, str],
) -> list[ActionOutput]:
self.task_description = task_description
self.action_space = self.combine_multi_env_action_space(action_spaces)
self.action_space = combine_multi_env_action_space(action_spaces)

main_agent_system_message = self._system_prompt.format(
task_description=task_description,
action_descriptions=self.generate_action_prompt(self.action_space),
action_descriptions=generate_action_prompt(self.action_space),
env_description=str(env_descriptions),
)
self.main_agent_model_backend.reset(main_agent_system_message, None)
Expand Down Expand Up @@ -95,4 +96,4 @@ def chat(
tool_output = self.tool_agent_model_backend.chat(
(output.message, MessageType.TEXT)
)
return self.decode_combined_action(tool_output.action_list)
return decode_combined_action(tool_output.action_list)
19 changes: 11 additions & 8 deletions crab/agents/policies/single_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
from copy import copy

from crab import Action, ActionOutput
from crab.agents.backend_models import BackendModelConfig, create_backend_model
from crab.agents.utils import (
combine_multi_env_action_space,
decode_combined_action,
generate_action_prompt,
)
from crab.core.agent_policy import AgentPolicy
from crab.core.backend_model import (
BackendModel,
MessageType,
)
from crab.utils.measure import timed
Expand Down Expand Up @@ -46,9 +49,9 @@ class SingleAgentPolicy(AgentPolicy):

def __init__(
self,
model_backend: BackendModel,
model_backend: BackendModelConfig,
):
self.model_backend = copy(model_backend)
self.model_backend = create_backend_model(model_backend)
self.reset(task_description="", action_spaces=None, env_descriptions={})

def reset(
Expand All @@ -58,10 +61,10 @@ def reset(
env_descriptions: dict[str, str],
) -> list:
self.task_description = task_description
self.action_space = self.combine_multi_env_action_space(action_spaces)
self.action_space = combine_multi_env_action_space(action_spaces)
system_message = self._system_prompt.format(
task_description=task_description,
action_descriptions=self.generate_action_prompt(self.action_space),
action_descriptions=generate_action_prompt(self.action_space),
env_description=str(env_descriptions),
)
self.model_backend.reset(system_message, self.action_space)
Expand All @@ -87,4 +90,4 @@ def chat(
)
)
output = self.model_backend.chat(prompt)
return self.decode_combined_action(output.action_list)
return decode_combined_action(output.action_list)
Loading

0 comments on commit aff284e

Please sign in to comment.