Guided agent nodes #42

Merged 2 commits on Oct 3, 2024

Changes from all commits:
examples/gaia_agent/agent.py (77 additions, 47 deletions)

@@ -3,7 +3,8 @@
 from typing import Any

 from tapeagents.core import Step
-from tapeagents.guided_agent import GuidedAgent
+from tapeagents.guided_agent import GuidanceNode, GuidedAgent
+from tapeagents.llms import LLM

 from .prompts import TEMPLATES, PromptRegistry
 from .steps import (
@@ -32,71 +33,109 @@ class PlanningMode(str, Enum):
     replan_after_sources = "replan_after_sources"


+class GaiaNode(GuidanceNode):
+    def get_steps_description(self, tape: GaiaTape, agent: Any) -> str:
+        """
+        Allow different subset of steps based on the agent's configuration
+        """
+        add_plan_thoughts = not tape.has_fact_schemas()
+        allowed_steps = get_allowed_steps(agent.short_steps, agent.subtasks, add_plan_thoughts)
+        return self.steps_prompt.format(allowed_steps=allowed_steps)
+
+    def prepare_tape(self, tape: GaiaTape, max_chars: int = 200) -> GaiaTape:
+        """
+        Trim long observations except for the last 3 steps
+        """
+        steps = []
+        for step in tape.steps[:-3]:
+            if isinstance(step, PageObservation):
+                short_text = f"{step.text[:max_chars]}\n..." if len(step.text) > max_chars else step.text
+                new_step = step.model_copy(update=dict(text=short_text))
+            elif isinstance(step, ActionExecutionFailure):
+                short_error = f"{step.error[:max_chars]}\n..." if len(step.error) > max_chars else step.error
+                new_step = step.model_copy(update=dict(error=short_error))
+            else:
+                new_step = step
+            steps.append(new_step)
+        trimmed_tape = tape.model_copy(update=dict(steps=steps + tape.steps[-3:]))
+        return trimmed_tape
+
+    def postprocess_step(self, tape: GaiaTape, new_steps: list[Step], step: Step) -> Step:
+        if isinstance(step, ListOfFactsThought):
+            # remove empty facts produced by the model
+            step.given_facts = [fact for fact in step.given_facts if fact.value is not None and fact.value != ""]
+        elif isinstance(step, (UseCalculatorAction, PythonCodeAction)):
+            # if calculator or code action is used, add the facts to the action call
+            step.facts = tape.model_copy(update=dict(steps=tape.steps + new_steps)).facts()
+        return step
+
+
 class GaiaAgent(GuidedAgent):
-    max_iterations: int = 2
-    planning_mode: str = PlanningMode.simple
-    short_steps: bool = False
-    subtasks: bool = False
-    _start_step_cls: Any = GaiaQuestion
-    _agent_step_cls: Any = GaiaAgentStep
+    short_steps: bool
+    subtasks: bool

-    def prepare_guidance_prompts(self):
+    @classmethod
+    def create(
+        cls,
+        llm: LLM,
+        planning_mode: PlanningMode = PlanningMode.simple,
+        subtasks: bool = False,
+        short_steps: bool = False,
+    ):
+        guidance_prompts = cls.prepare_guidance(planning_mode, subtasks)
+        return super().create(
+            llm,
+            nodes=[GaiaNode(name=kind, guidance=guidance) for kind, guidance in guidance_prompts.items()],
+            system_prompt=PromptRegistry.system_prompt,
+            steps_prompt=PromptRegistry.allowed_steps,
+            start_step_cls=GaiaQuestion,
+            agent_step_cls=GaiaAgentStep,
+            max_iterations=2,
+            templates=TEMPLATES,
+            subtasks=subtasks,
+            short_steps=short_steps,
+        )
+
+    @classmethod
+    def prepare_guidance(cls, planning_mode: PlanningMode, subtasks: bool) -> dict[str, str]:
+        """
+        Prepare guidance prompts based on the planning mode and subtasks flag
+        """
         guidance_prompts = {}
-        if self.planning_mode == PlanningMode.simple:
+        if planning_mode == PlanningMode.simple:
             guidance_prompts = {
                 "question": PromptRegistry.plan,
                 "plan_thought": PromptRegistry.facts_survey,
                 "list_of_facts_thought": PromptRegistry.start_execution,
             }
-        elif self.planning_mode == PlanningMode.facts_and_sources:
+        elif planning_mode == PlanningMode.facts_and_sources:
             guidance_prompts = {
                 "question": PromptRegistry.plan,
                 "draft_plans_thought": PromptRegistry.facts_survey,
                 "list_of_facts_thought": PromptRegistry.sources_plan,
                 "sources_thought": PromptRegistry.start_execution,
             }
-        elif self.planning_mode == PlanningMode.multiplan:
+        elif planning_mode == PlanningMode.multiplan:
             guidance_prompts = {
                 "question": PromptRegistry.plan3,
                 "draft_plans_thought": PromptRegistry.facts_survey,
                 "list_of_facts_thought": PromptRegistry.sources_plan,
                 "sources_thought": PromptRegistry.start_execution,
             }
-        elif self.planning_mode == PlanningMode.replan_after_sources:
+        elif planning_mode == PlanningMode.replan_after_sources:
             guidance_prompts = {
                 "question": PromptRegistry.plan3,
                 "draft_plans_thought": PromptRegistry.facts_survey,
                 "list_of_facts_thought": PromptRegistry.sources_plan,
                 "sources_thought": PromptRegistry.better_plan,
                 "plan_thought": PromptRegistry.start_execution,
             }
-        if self.subtasks:
+        else:
+            raise ValueError(f"Unknown planning mode: {planning_mode}")
+        if subtasks:
             guidance_prompts["calculation_result_observation"] = PromptRegistry.is_subtask_finished
-        self.templates = TEMPLATES | guidance_prompts
-
-    def get_steps_description(self, tape: GaiaTape) -> str:
-        add_plan_thoughts = not tape.has_fact_schemas()
-        allowed_steps = get_allowed_steps(self.short_steps, self.subtasks, add_plan_thoughts)
-        return self.templates["allowed_steps"].format(allowed_steps=allowed_steps)
-
-    def prepare_tape(self, tape: GaiaTape, max_chars: int = 200) -> GaiaTape:
-        """
-        Trim all read results except in the last 3 steps
-        """
-        self.prepare_guidance_prompts()
-        steps = []
-        for step in tape.steps[:-3]:
-            if isinstance(step, PageObservation):
-                short_text = f"{step.text[:max_chars]}\n..." if len(step.text) > max_chars else step.text
-                new_step = step.model_copy(update=dict(text=short_text))
-            elif isinstance(step, ActionExecutionFailure):
-                short_error = f"{step.error[:max_chars]}\n..." if len(step.error) > max_chars else step.error
-                new_step = step.model_copy(update=dict(error=short_error))
-            else:
-                new_step = step
-            steps.append(new_step)
-        trimmed_tape = tape.model_copy(update=dict(steps=steps + tape.steps[-3:]))
-        return trimmed_tape
+        guidance_prompts["_default"] = ""
+        return guidance_prompts

     def trim_tape(self, tape: GaiaTape) -> GaiaTape:
         """

@@ -117,12 +156,3 @@ def trim_tape(self, tape: GaiaTape) -> GaiaTape:
             short_tape.steps.append(step)
         logger.info(f"Tape reduced from {len(tape)} to {len(short_tape)} steps")
         return short_tape
-
-    def postprocess_step(self, tape: GaiaTape, new_steps: list[Step], step: Step) -> Step:
-        if isinstance(step, ListOfFactsThought):
-            # remove empty facts produced by the model
-            step.given_facts = [fact for fact in step.given_facts if fact.value is not None and fact.value != ""]
-        elif isinstance(step, (UseCalculatorAction, PythonCodeAction)):
-            # if calculator or code action is used, add the facts to the action call
-            step.facts = tape.model_copy(update=dict(steps=tape.steps + new_steps)).facts()
-        return step
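Note on the refactor above: the per-step behaviors that previously lived on GaiaAgent (prepare_tape, get_steps_description, postprocess_step) now live on GaiaNode, and construction moves into a create() factory that wires up one GaiaNode per guidance prompt. A minimal usage sketch of the new entry point; the LiteLLM backend and the model name are illustrative assumptions, not part of this diff:

    from tapeagents.llms import LiteLLM  # assumed backend; any tapeagents LLM should fit the llm parameter

    llm = LiteLLM(model_name="gpt-4o-mini")  # hypothetical model choice
    agent = GaiaAgent.create(
        llm,
        planning_mode=PlanningMode.replan_after_sources,
        subtasks=True,  # also registers the is_subtask_finished guidance prompt
        short_steps=False,
    )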
examples/gaia_agent/scripts/evaluate.py (1 addition, 1 deletion)

@@ -46,7 +46,7 @@ def main(cfg: DictConfig) -> None:
                 "yellow",
             )
         )
-    agent = GaiaAgent(llms={"default": llm}, **cfg.agent)
+    agent = GaiaAgent.create(llm, **cfg.agent)
     tasks = load_dataset(cfg.data_dir)

     if cfg.task_id is not None:
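Since cfg.agent is now splatted into the factory rather than the model constructor, its keys must match the keyword parameters of create (planning_mode, subtasks, short_steps). An equivalent explicit call, with illustrative values:

    agent = GaiaAgent.create(
        llm,
        planning_mode=PlanningMode.simple,  # cfg.agent.planning_mode
        subtasks=False,  # cfg.agent.subtasks
        short_steps=False,  # cfg.agent.short_steps
    )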
examples/gaia_agent/scripts/replay.py (2 additions, 3 deletions)

@@ -27,8 +27,7 @@ def main(fname: str, dataset_path: str = ""):

     prompts = results.prompts
     llm_calls = [
-        LLMCall(prompt=Prompt.model_validate(prompt), output=LLMOutput(), cached=False)
-        for prompt in results.prompts
+        LLMCall(prompt=Prompt.model_validate(prompt), output=LLMOutput(), cached=False) for prompt in results.prompts
     ]
     model_name = results.model
     params = results.llm_config

@@ -43,7 +42,7 @@
     )
     env = GaiaEnvironment(only_cached_webpages=True)
     env.browser._cache = results.web_cache
-    agent = GaiaAgent(llms={"default": llm}, short_steps=True)
+    agent = GaiaAgent.create(llm, short_steps=True)

     logger.info(f"Web Cache {len(results.web_cache)}")
     logger.info(f"Prompts {len(prompts)}")
examples/workarena/agent.py (32 additions, 19 deletions)

@@ -3,7 +3,8 @@

 from tapeagents.core import Prompt
 from tapeagents.dialog_tape import SystemStep, UserStep
-from tapeagents.guided_agent import GuidedAgent
+from tapeagents.guided_agent import GuidanceNode, GuidedAgent
+from tapeagents.llms import LLM
 from tapeagents.utils import get_step_schemas_from_union_type

 from .prompts import TEMPLATES, PromptRegistry

@@ -17,7 +18,7 @@
 )


-class WorkArenaBaseline(GuidedAgent):
+class WorkArenaBaselineNode(GuidanceNode):
     """
     Agent that is close to the original workarena one.
     Implemented features (best feature set for gpt4o from workarena paper):

@@ -30,10 +31,7 @@ class WorkArenaBaseline(GuidedAgent):
     - long_description
     """

-    _start_step_cls: Any = PageObservation
-    _agent_step_cls: Any = WorkArenaAgentStep
-
-    def make_prompt(self, tape: WorkArenaTape) -> Prompt:
+    def make_prompt(self, agent: Any, tape: WorkArenaTape) -> Prompt:
         assert isinstance(tape.steps[1], WorkArenaTask)
         goal = PromptRegistry.goal_instructions.format(goal=tape.steps[1].task)
         obs = [s for s in tape if isinstance(s, PageObservation)][-1].text

@@ -42,10 +40,7 @@ def make_prompt(self, tape: WorkArenaTape) -> Prompt:
             allowed_steps=get_step_schemas_from_union_type(WorkArenaBaselineStep)
         )
         mac_hint = PromptRegistry.mac_hint if platform.system() == "Darwin" else ""
-        main_prompt = f"""\
-{goal}
-{obs}
-{history}
+        main_prompt = f"""{goal}\n{obs}\n{history}
 {PromptRegistry.baseline_steps_prompt}{allowed_steps}{mac_hint}
 {PromptRegistry.hints}
 {PromptRegistry.be_cautious}

@@ -75,17 +70,17 @@ def history_prompt(self, tape: WorkArenaTape) -> str:
         return prompt


-class WorkArenaAgent(GuidedAgent):
-    _start_step_cls: Any = PageObservation
-    _agent_step_cls: Any = WorkArenaAgentStep
-    templates: dict[str, str] = TEMPLATES
-
-    def get_steps_description(self, tape) -> str:
-        return self.templates["allowed_steps"].format(
-            allowed_steps=get_step_schemas_from_union_type(WorkArenaAgentStep)
-        )
-
-    def prepare_tape(self, tape, max_chars: int = 100):
+class WorkArenaBaseline(GuidedAgent):
+    @classmethod
+    def create(cls, llm: LLM):
+        return cls(llms={"default": llm}, nodes=[WorkArenaBaselineNode()])  # type: ignore
+
+
+class WorkArenaNode(GuidanceNode):
+    def get_steps_description(self, tape: WorkArenaTape) -> str:
+        return self.steps_prompt.format(allowed_steps=get_step_schemas_from_union_type(WorkArenaAgentStep))
+
+    def prepare_tape(self, tape: WorkArenaTape, max_chars: int = 100):
         """
         Trim all page observations except the last two.
         """

@@ -103,3 +98,21 @@ def prepare_tape(self, tape, max_chars: int = 100):
             steps.append(new_step)
         trimmed_tape = tape.model_copy(update=dict(steps=steps + tape.steps[prev_page_position:]))
         return trimmed_tape
+
+
+class WorkArenaAgent(GuidedAgent):
+    @classmethod
+    def create(cls, llm: LLM):
+        return super().create(
+            llm,
+            nodes=[
+                WorkArenaNode(name="task", guidance=PromptRegistry.start),
+                WorkArenaNode(name="reflection_thought", guidance=PromptRegistry.act),
+                WorkArenaNode(name="default", guidance=PromptRegistry.think),
+            ],
+            system_prompt=PromptRegistry.system_prompt,
+            steps_prompt=PromptRegistry.allowed_steps,
+            start_step_cls=PageObservation,
+            agent_step_cls=WorkArenaAgentStep,
+            max_iterations=2,
+        )
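The same pattern generalizes beyond WorkArena: a GuidanceNode appears to be selected by matching its name against the kind of the tape's last step, with "default" (or the "_default" template key in the Gaia agent) as the fallback. A hedged sketch of a custom agent built this way; the node class, guidance strings, and system prompt are made up for illustration:

    from tapeagents.guided_agent import GuidanceNode, GuidedAgent

    class EchoNode(GuidanceNode):  # hypothetical node that adds no behavior beyond its guidance text
        pass

    agent = GuidedAgent.create(
        llm,  # any tapeagents LLM, as in the scripts above
        nodes=[
            EchoNode(name="task", guidance="Plan your next action."),
            EchoNode(name="default", guidance="Think step by step, then act."),
        ],
        system_prompt="You are a careful web agent.",
        steps_prompt="Allowed steps:\n{allowed_steps}",
        start_step_cls=PageObservation,
        agent_step_cls=WorkArenaAgentStep,
        max_iterations=2,
    )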