Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Multitool envs #214

Draft
wants to merge 2 commits into
base: tau-bench
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ pytest==7.3.2
flaky
pytest-xdist
pytest-playwright
pydantic~=2.9
dask
distributed
browsergym>=0.7.1
joblib>=1.2.0
openai>=1.7,<2
langchain_community
tiktoken
tapeagents[converters]~=0.1.4
huggingface_hub
contexttimer
ipython
Expand Down
2 changes: 1 addition & 1 deletion src/agentlab/agents/generic_agent/reproducibility_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
This module contains the classes and functions to reproduce the results of a
study. It is used to create a new study that will run the same experiments as
the original study, but with a reproducibility agent that will mimic the same
answers as the original agent.
answers as the original agent.

Stats are collected to compare the original agent's answers with the new agent's
answers. Load the this reproducibility study in agent-xray to compare the results.
Expand Down
8 changes: 6 additions & 2 deletions src/agentlab/benchmarks/abstract_env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import gym
from abc import ABC, abstractmethod

import gym


class AbstractEnvArgs(ABC):
"""Easily serialiazable class to store the arguments of an environment"""
Expand All @@ -21,7 +22,6 @@ def make_env(self, action_mapping, exp_dir, exp_task_kwargs) -> "AbstractEnv":


class AbstractEnv(gym.Env, ABC):

@abstractmethod
def reset(self, seed: int = None) -> tuple[dict[str, any], dict[str, any]]:
"""Reset the environment to the initial state, ready for an agent to start a new episode.
Expand Down Expand Up @@ -57,3 +57,7 @@ def step(self, action: str):
@abstractmethod
def close(self):
"""Close any resources used by the environment"""

@abstractmethod
def calculate_reward(self) -> float:
"""Calculate the reward obtained in the last step"""
76 changes: 76 additions & 0 deletions src/agentlab/benchmarks/multitool_gym.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Any, Literal, Union

from pydantic import Annotated, Field, TypeAdapter
from tapeagents.core import Action, Observation, Tape
from tapeagents.environment import ToolCollectionEnvironment
from tapeagents.tools.base import Multitool, Tool

from agentlab.benchmarks.abstract_env import AbstractEnv

EnvTape = Tape[None, Action | Observation]


class FunctionCall(Action):
kind: Literal["function_call_action"] = ["function_call_action"]
function_name: str
args: list[Any] | None
kwargs: dict[str, Any] | None


class FunctionCallResult(Observation):
kind: Literal["function_call_result"] = ["function_call_result"]
result: Any


class SimpleFunctionCallTool(Tool):
action = FunctionCall
observation = FunctionCallResult
function: callable
function_name: str = ""

def model_post_init(self, __context):
function_name = getattr(self.function, "__name__", "")
if not function_name and not self.function_name:
raise ValueError("Function has no name, function_name must be provided")

def execute_action(self, action: FunctionCall) -> FunctionCallResult:
if not self.function_name == action.function_name:
raise ValueError(
f"Unexpected function action {action.function_name}, expected {self.function_name}"
)
result = self.function(*action.args, **action.kwargs)
return FunctionCallResult(result=result)


class MultiToolGym(AbstractEnv):
def __init__(self, tools: list[Tool | Multitool]):
self._env = ToolCollectionEnvironment(tools)
self._actions = self._env.actions()
self._actions_parser: TypeAdapter = TypeAdapter(
Annotated[Union[self._actions], Field(discriminator="kind")]
)
self.reset()

def reset(self, seed=None):
self._tape: EnvTape = EnvTape(steps=[])

def step(self, action: str):
try:
action_step = self._actions_parser.validate_json(action)
except Exception:
raise ValueError("Action must be a valid JSON dict")
assert isinstance(action_step, Action), "{action_step.kind} is not an Action"
self._tape += [action_step]
self._tape = self._env.react(self._tape)
observation_step: Observation = self._tape.steps[-1]
reward = self.calculate_reward()
terminated = False
truncated = False
env_info = {"step_metadata": observation_step.metadata}
return observation_step.llm_dict(), reward, terminated, truncated, env_info

def calculate_reward(self) -> float:
return 0.0

def close(self):
self._env.close()