diff --git a/examples/gaia_agent/agent.py b/examples/gaia_agent/agent.py
index 9c7d8c8..c1af5d2 100644
--- a/examples/gaia_agent/agent.py
+++ b/examples/gaia_agent/agent.py
@@ -3,7 +3,8 @@
 from typing import Any
 
 from tapeagents.core import Step
-from tapeagents.guided_agent import GuidedAgent
+from tapeagents.guided_agent import GuidanceNode, GuidedAgent
+from tapeagents.llms import LLM
 
 from .prompts import TEMPLATES, PromptRegistry
 from .steps import (
@@ -32,37 +33,96 @@ class PlanningMode(str, Enum):
     replan_after_sources = "replan_after_sources"
 
 
+class GaiaNode(GuidanceNode):
+    def get_steps_description(self, tape: GaiaTape, agent: Any) -> str:
+        """
+        Allow different subset of steps based on the agent's configuration
+        """
+        add_plan_thoughts = not tape.has_fact_schemas()
+        allowed_steps = get_allowed_steps(agent.short_steps, agent.subtasks, add_plan_thoughts)
+        return self.steps_prompt.format(allowed_steps=allowed_steps)
+
+    def prepare_tape(self, tape: GaiaTape, max_chars: int = 200) -> GaiaTape:
+        """
+        Trim long observations except for the last 3 steps
+        """
+        steps = []
+        for step in tape.steps[:-3]:
+            if isinstance(step, PageObservation):
+                short_text = f"{step.text[:max_chars]}\n..." if len(step.text) > max_chars else step.text
+                new_step = step.model_copy(update=dict(text=short_text))
+            elif isinstance(step, ActionExecutionFailure):
+                short_error = f"{step.error[:max_chars]}\n..." if len(step.error) > max_chars else step.error
+                new_step = step.model_copy(update=dict(error=short_error))
+            else:
+                new_step = step
+            steps.append(new_step)
+        trimmed_tape = tape.model_copy(update=dict(steps=steps + tape.steps[-3:]))
+        return trimmed_tape
+
+    def postprocess_step(self, tape: GaiaTape, new_steps: list[Step], step: Step) -> Step:
+        if isinstance(step, ListOfFactsThought):
+            # remove empty facts produced by the model
+            step.given_facts = [fact for fact in step.given_facts if fact.value is not None and fact.value != ""]
+        elif isinstance(step, (UseCalculatorAction, PythonCodeAction)):
+            # if calculator or code action is used, add the facts to the action call
+            step.facts = tape.model_copy(update=dict(steps=tape.steps + new_steps)).facts()
+        return step
+
+
 class GaiaAgent(GuidedAgent):
-    max_iterations: int = 2
-    planning_mode: str = PlanningMode.simple
-    short_steps: bool = False
-    subtasks: bool = False
-    _start_step_cls: Any = GaiaQuestion
-    _agent_step_cls: Any = GaiaAgentStep
+    short_steps: bool
+    subtasks: bool
 
-    def prepare_guidance_prompts(self):
+    @classmethod
+    def create(
+        cls,
+        llm: LLM,
+        planning_mode: PlanningMode = PlanningMode.simple,
+        subtasks: bool = False,
+        short_steps: bool = False,
+    ):
+        guidance_prompts = cls.prepare_guidance(planning_mode, subtasks)
+        return super().create(
+            llm,
+            nodes=[GaiaNode(name=kind, guidance=guidance) for kind, guidance in guidance_prompts.items()],
+            system_prompt=PromptRegistry.system_prompt,
+            steps_prompt=PromptRegistry.allowed_steps,
+            start_step_cls=GaiaQuestion,
+            agent_step_cls=GaiaAgentStep,
+            max_iterations=2,
+            templates=TEMPLATES,
+            subtasks=subtasks,
+            short_steps=short_steps,
+        )
+
+    @classmethod
+    def prepare_guidance(cls, planning_mode: PlanningMode, subtasks: bool) -> dict[str, str]:
+        """
+        Prepare guidance prompts based on the planning mode and subtasks flag
+        """
         guidance_prompts = {}
-        if self.planning_mode == PlanningMode.simple:
+        if planning_mode == PlanningMode.simple:
             guidance_prompts = {
                 "question": PromptRegistry.plan,
                 "plan_thought": PromptRegistry.facts_survey,
                 "list_of_facts_thought": PromptRegistry.start_execution,
             }
-        elif self.planning_mode == PlanningMode.facts_and_sources:
+        elif planning_mode == PlanningMode.facts_and_sources:
             guidance_prompts = {
                 "question": PromptRegistry.plan,
                 "draft_plans_thought": PromptRegistry.facts_survey,
                 "list_of_facts_thought": PromptRegistry.sources_plan,
                 "sources_thought": PromptRegistry.start_execution,
             }
-        elif self.planning_mode == PlanningMode.multiplan:
+        elif planning_mode == PlanningMode.multiplan:
             guidance_prompts = {
                 "question": PromptRegistry.plan3,
                 "draft_plans_thought": PromptRegistry.facts_survey,
                 "list_of_facts_thought": PromptRegistry.sources_plan,
                 "sources_thought": PromptRegistry.start_execution,
             }
-        elif self.planning_mode == PlanningMode.replan_after_sources:
+        elif planning_mode == PlanningMode.replan_after_sources:
             guidance_prompts = {
                 "question": PromptRegistry.plan3,
                 "draft_plans_thought": PromptRegistry.facts_survey,
@@ -70,33 +130,12 @@ def prepare_guidance_prompts(self):
                 "sources_thought": PromptRegistry.better_plan,
                 "plan_thought": PromptRegistry.start_execution,
             }
-        if self.subtasks:
+        else:
+            raise ValueError(f"Unknown planning mode: {planning_mode}")
+        if subtasks:
             guidance_prompts["calculation_result_observation"] = PromptRegistry.is_subtask_finished
-        self.templates = TEMPLATES | guidance_prompts
-
-    def get_steps_description(self, tape: GaiaTape) -> str:
-        add_plan_thoughts = not tape.has_fact_schemas()
-        allowed_steps = get_allowed_steps(self.short_steps, self.subtasks, add_plan_thoughts)
-        return self.templates["allowed_steps"].format(allowed_steps=allowed_steps)
-
-    def prepare_tape(self, tape: GaiaTape, max_chars: int = 200) -> GaiaTape:
-        """
-        Trim all read results except in the last 3 steps
-        """
-        self.prepare_guidance_prompts()
-        steps = []
-        for step in tape.steps[:-3]:
-            if isinstance(step, PageObservation):
-                short_text = f"{step.text[:max_chars]}\n..." if len(step.text) > max_chars else step.text
-                new_step = step.model_copy(update=dict(text=short_text))
-            elif isinstance(step, ActionExecutionFailure):
-                short_error = f"{step.error[:max_chars]}\n..." if len(step.error) > max_chars else step.error
-                new_step = step.model_copy(update=dict(error=short_error))
-            else:
-                new_step = step
-            steps.append(new_step)
-        trimmed_tape = tape.model_copy(update=dict(steps=steps + tape.steps[-3:]))
-        return trimmed_tape
+        guidance_prompts["_default"] = ""
+        return guidance_prompts
 
     def trim_tape(self, tape: GaiaTape) -> GaiaTape:
         """
@@ -117,12 +156,3 @@ def trim_tape(self, tape: GaiaTape) -> GaiaTape:
             short_tape.steps.append(step)
         logger.info(f"Tape reduced from {len(tape)} to {len(short_tape)} steps")
         return short_tape
-
-    def postprocess_step(self, tape: GaiaTape, new_steps: list[Step], step: Step) -> Step:
-        if isinstance(step, ListOfFactsThought):
-            # remove empty facts produced by the model
-            step.given_facts = [fact for fact in step.given_facts if fact.value is not None and fact.value != ""]
-        elif isinstance(step, (UseCalculatorAction, PythonCodeAction)):
-            # if calculator or code action is used, add the facts to the action call
-            step.facts = tape.model_copy(update=dict(steps=tape.steps + new_steps)).facts()
-        return step
diff --git a/examples/gaia_agent/scripts/evaluate.py b/examples/gaia_agent/scripts/evaluate.py
index 4f4038f..3dfeb5e 100644
--- a/examples/gaia_agent/scripts/evaluate.py
+++ b/examples/gaia_agent/scripts/evaluate.py
@@ -46,7 +46,7 @@ def main(cfg: DictConfig) -> None:
             "yellow",
         )
     )
-    agent = GaiaAgent(llms={"default": llm}, **cfg.agent)
+    agent = GaiaAgent.create(llm, **cfg.agent)
 
     tasks = load_dataset(cfg.data_dir)
     if cfg.task_id is not None:
diff --git a/examples/gaia_agent/scripts/replay.py b/examples/gaia_agent/scripts/replay.py
index e34de33..cb2e076 100644
--- a/examples/gaia_agent/scripts/replay.py
+++ b/examples/gaia_agent/scripts/replay.py
@@ -27,8 +27,7 @@ def main(fname: str, dataset_path: str = ""):
     prompts = results.prompts
     llm_calls = [
-        LLMCall(prompt=Prompt.model_validate(prompt), output=LLMOutput(), cached=False)
-        for prompt in results.prompts
+        LLMCall(prompt=Prompt.model_validate(prompt), output=LLMOutput(), cached=False) for prompt in results.prompts
     ]
     model_name = results.model
     params = results.llm_config
@@ -43,7 +42,7 @@ def main(fname: str, dataset_path: str = ""):
     )
     env = GaiaEnvironment(only_cached_webpages=True)
     env.browser._cache = results.web_cache
-    agent = GaiaAgent(llms={"default": llm}, short_steps=True)
+    agent = GaiaAgent.create(llm, short_steps=True)
 
     logger.info(f"Web Cache {len(results.web_cache)}")
     logger.info(f"Prompts {len(prompts)}")
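# --- review note, not part of the patch ---------------------------------------
# A minimal sketch of the factory-style construction that replaces the old
# GaiaAgent(llms={"default": llm}, ...) calls above. LiteLLM and the model
# name are illustrative assumptions; any tapeagents LLM instance should work.
from examples.gaia_agent.agent import GaiaAgent, PlanningMode
from tapeagents.llms import LiteLLM

llm = LiteLLM(model_name="gpt-4o-mini")  # assumed model, for illustration only
agent = GaiaAgent.create(llm, planning_mode=PlanningMode.multiplan, short_steps=True)
# create() expands the planning mode into one GaiaNode per guidance prompt, so
# node selection replaces the branching previously done in prepare_guidance_prompts().
# -------------------------------------------------------------------------------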
diff --git a/examples/workarena/agent.py b/examples/workarena/agent.py
index 6f6915b..8e24773 100644
--- a/examples/workarena/agent.py
+++ b/examples/workarena/agent.py
@@ -3,7 +3,8 @@
 from tapeagents.core import Prompt
 from tapeagents.dialog_tape import SystemStep, UserStep
-from tapeagents.guided_agent import GuidedAgent
+from tapeagents.guided_agent import GuidanceNode, GuidedAgent
+from tapeagents.llms import LLM
 from tapeagents.utils import get_step_schemas_from_union_type
 
 from .prompts import TEMPLATES, PromptRegistry
@@ -17,7 +18,7 @@
 )
 
 
-class WorkArenaBaseline(GuidedAgent):
+class WorkArenaBaselineNode(GuidanceNode):
     """
     Agent that is close to the original workarena one.
     Implemented features (best feature set for gpt4o from workarena paper):
@@ -30,10 +31,7 @@ class WorkArenaBaseline(GuidedAgent):
     - long_description
     """
 
-    _start_step_cls: Any = PageObservation
-    _agent_step_cls: Any = WorkArenaAgentStep
-
-    def make_prompt(self, tape: WorkArenaTape) -> Prompt:
+    def make_prompt(self, agent: Any, tape: WorkArenaTape) -> Prompt:
         assert isinstance(tape.steps[1], WorkArenaTask)
         goal = PromptRegistry.goal_instructions.format(goal=tape.steps[1].task)
         obs = [s for s in tape if isinstance(s, PageObservation)][-1].text
@@ -42,10 +40,7 @@ class WorkArenaBaseline(GuidedAgent):
             allowed_steps=get_step_schemas_from_union_type(WorkArenaBaselineStep)
         )
         mac_hint = PromptRegistry.mac_hint if platform.system() == "Darwin" else ""
-        main_prompt = f"""\
-{goal}
-{obs}
-{history}
+        main_prompt = f"""{goal}\n{obs}\n{history}
 {PromptRegistry.baseline_steps_prompt}{allowed_steps}{mac_hint}
 {PromptRegistry.hints}
 {PromptRegistry.be_cautious}
@@ -75,17 +70,17 @@ def history_prompt(self, tape: WorkArenaTape) -> str:
         return prompt
 
 
-class WorkArenaAgent(GuidedAgent):
-    _start_step_cls: Any = PageObservation
-    _agent_step_cls: Any = WorkArenaAgentStep
-    templates: dict[str, str] = TEMPLATES
+class WorkArenaBaseline(GuidedAgent):
+    @classmethod
+    def create(cls, llm: LLM):
+        return cls(llms={"default": llm}, nodes=[WorkArenaBaselineNode()])  # type: ignore
 
-    def get_steps_description(self, tape) -> str:
-        return self.templates["allowed_steps"].format(
-            allowed_steps=get_step_schemas_from_union_type(WorkArenaAgentStep)
-        )
-
-    def prepare_tape(self, tape, max_chars: int = 100):
+
+class WorkArenaNode(GuidanceNode):
+    def get_steps_description(self, tape: WorkArenaTape) -> str:
+        return self.steps_prompt.format(allowed_steps=get_step_schemas_from_union_type(WorkArenaAgentStep))
+
+    def prepare_tape(self, tape: WorkArenaTape, max_chars: int = 100):
         """
         Trim all page observations except the last two.
         """
@@ -103,3 +98,21 @@ def prepare_tape(self, tape, max_chars: int = 100):
             steps.append(new_step)
         trimmed_tape = tape.model_copy(update=dict(steps=steps + tape.steps[prev_page_position:]))
         return trimmed_tape
+
+
+class WorkArenaAgent(GuidedAgent):
+    @classmethod
+    def create(cls, llm: LLM):
+        return super().create(
+            llm,
+            nodes=[
+                WorkArenaNode(name="task", guidance=PromptRegistry.start),
+                WorkArenaNode(name="reflection_thought", guidance=PromptRegistry.act),
+                WorkArenaNode(name="default", guidance=PromptRegistry.think),
+            ],
+            system_prompt=PromptRegistry.system_prompt,
+            steps_prompt=PromptRegistry.allowed_steps,
+            start_step_cls=PageObservation,
+            agent_step_cls=WorkArenaAgentStep,
+            max_iterations=2,
+        )
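# --- review note, not part of the patch ---------------------------------------
# Node routing sketch for the agent above. GuidedAgent.select_node (added in
# tapeagents/guided_agent.py below) matches node names against the kind of the
# tape's last step and falls back to the last node in the list:
#
#   last step kind          selected node            guidance appended
#   "task"                  "task"                   PromptRegistry.start
#   "reflection_thought"    "reflection_thought"     PromptRegistry.act
#   anything else           "default"                PromptRegistry.think
#
# So placing the "default" node last is what makes it the catch-all.
# -------------------------------------------------------------------------------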
diff --git a/tapeagents/guided_agent.py b/tapeagents/guided_agent.py
index 535beeb..4be859a 100644
--- a/tapeagents/guided_agent.py
+++ b/tapeagents/guided_agent.py
@@ -3,8 +3,9 @@
 from typing import Any, Generator, Generic
 
 from pydantic import TypeAdapter, ValidationError
+from typing_extensions import Self
 
-from .agent import Agent
+from .agent import Agent, Node
 from .core import (
     AgentResponseParsingFailureAction,
     AgentStep,
@@ -15,68 +16,61 @@
     Tape,
     TapeType,
 )
-from .llms import LLMStream
+from .llms import LLM, LLMStream
 from .utils import FatalError, sanitize_json_completion
 
 logger = logging.getLogger(__name__)
 
 
-class GuidedAgent(Agent, Generic[TapeType]):
+class GuidanceNode(Node):
     """
-    Generic agent class which renders all tape steps into prompt and parses the llm completion into a sequence of steps.
-    Main features:
-    - validates that the tape starts with a specific step class.
-    - attaches a guidance prompt text to the end of the prompt after rendering the tape.
-    - selects guidance based on the kind of the last step in the tape from the templates dictionary.
-    - trims the tape if the total token count exceeds the context size.
+    A node for the guided agent.
+    Validates that the tape starts with a specific step class.
+    Attaches a guidance text to the end of the prompt after rendering the tape.
+    Parses the llm output into provided step classes (class provided in a form of annotated union).
+    Trims the tape if needed.
     """
 
-    _start_step_cls: Any
-    _agent_step_cls: Any
-    templates: dict[str, str] = {}
-    max_iterations: int = 2
-
-    def delegate(self, tape: TapeType):
-        return self
-
-    def get_steps_description(self, tape) -> str:
-        return self.templates["allowed_steps"]
-
-    def make_prompt(self, tape: Tape) -> Prompt:
-        assert isinstance(tape.steps[0], self._start_step_cls)
+    guidance: str
+    system_prompt: str = ""
+    steps_prompt: str = ""
+    agent_step_cls: Any = None
+    start_step_cls: Any = None
 
+    def make_prompt(self, agent: Any, tape: Tape) -> Prompt:
+        assert isinstance(tape.steps[0], self.start_step_cls)
         cleaned_tape = self.prepare_tape(tape)
-        messages = self.tape_to_messages(cleaned_tape)
-        if self.llm.count_tokens(messages) > (self.llm.context_size - 500):
-            cleaned_tape = self.trim_tape(cleaned_tape)
-            messages = self.tape_to_messages(cleaned_tape)
+        steps_description = self.get_steps_description(tape, agent)
+        messages = self.tape_to_messages(cleaned_tape, steps_description)
+        if agent.llm.count_tokens(messages) > (agent.llm.context_size - 500):
+            cleaned_tape = agent.trim_tape(cleaned_tape)
+            messages = self.tape_to_messages(cleaned_tape, steps_description)
         return Prompt(messages=messages)
 
-    def make_llm_output(self, tape: TapeType, index: int) -> LLMOutput:
-        return LLMOutput(role="assistant", content=tape.steps[index].llm_view())
-
     def prepare_tape(self, tape: Tape) -> Tape:
         return tape
 
-    def trim_tape(self, tape: Tape) -> Tape:
-        return tape
+    def make_llm_output(self, tape: Tape, index: int) -> LLMOutput:
+        return LLMOutput(role="assistant", content=tape.steps[index].llm_view())
 
-    def tape_to_messages(self, tape: Tape) -> list[dict]:
+    def tape_to_messages(self, tape: Tape, steps_description: str) -> list[dict]:
         messages: list[dict] = [
-            {"role": "system", "content": self.templates["system_prompt"]},
-            {"role": "user", "content": self.get_steps_description(tape)},
+            {"role": "system", "content": self.system_prompt},
+            {"role": "user", "content": steps_description},
         ]
         for step in tape:
             role = "assistant" if isinstance(step, AgentStep) else "user"
             messages.append({"role": role, "content": step.llm_view()})
-        if tape.steps[-1].kind in self.templates:
-            guidance = self.templates[tape.steps[-1].kind]
-            messages.append({"role": "user", "content": guidance})
-        elif "_default" in self.templates:
-            messages.append({"role": "user", "content": self.templates["_default"]})
+        if self.guidance:
+            messages.append({"role": "user", "content": self.guidance})
         return messages
 
-    def generate_steps(self, tape: Tape, llm_stream: LLMStream) -> Generator[Step | PartialStep, None, None]:
+    def get_steps_description(self, tape: Tape, agent: Any) -> str:
+        return self.steps_prompt
+
+    def generate_steps(
+        self, agent: Any, tape: Tape, llm_stream: LLMStream
+    ) -> Generator[Step | PartialStep, None, None]:
         new_steps = []
         try:
             cnt = 0
@@ -106,7 +100,7 @@ def parse_completion(self, completion: str, prompt_id: str) -> Generator[Step, None, None]:
             yield AgentResponseParsingFailureAction(error=f"Failed to parse agent output: {completion}\n\nError: {e}")
             return
         try:
-            steps = [TypeAdapter(self._agent_step_cls).validate_python(step_dict) for step_dict in step_dicts]
+            steps = [TypeAdapter(self.agent_step_cls).validate_python(step_dict) for step_dict in step_dicts]
         except ValidationError as e:
             err_text = ""
             for err in e.errors():
@@ -126,3 +120,51 @@ def parse_completion(self, completion: str, prompt_id: str) -> Generator[Step, None, None]:
         for step in steps:
             step.metadata.prompt_id = prompt_id
             yield step
+
+
+class GuidedAgent(Agent, Generic[TapeType]):
+    """
+    Generic agent class which renders all tape steps into prompt and parses the llm completion into a sequence of steps.
+    Main features:
+    - selects guidance node based on the kind of the last step in the tape.
+    - selected node does the following:
+        - validates that the tape starts with a specific step class.
+        - attaches a guidance prompt text to the end of the prompt after rendering the tape.
+        - trims the tape if the total token count exceeds the context size.
+    """
+
+    nodes: list[GuidanceNode]  # type: ignore
+
+    def select_node(self, tape: TapeType) -> Node:
+        last_kind = tape.steps[-1].kind
+        for node in self.nodes:
+            if last_kind == node.name:
+                return node
+        return self.nodes[-1]  # default to the last node
+
+    @classmethod
+    def create(
+        cls,
+        llm: LLM,
+        nodes: list[GuidanceNode],
+        system_prompt: str,
+        steps_prompt: str,
+        start_step_cls: Any,
+        agent_step_cls: Any,
+        **kwargs,
+    ) -> Self:
+        prepared_nodes = []
+        for node in nodes:
+            # set common default values
+            node.system_prompt = node.system_prompt or system_prompt
+            node.steps_prompt = node.steps_prompt or steps_prompt
+            node.start_step_cls = node.start_step_cls or start_step_cls
+            node.agent_step_cls = node.agent_step_cls or agent_step_cls
+            prepared_nodes.append(node)
+        return super().create(llm, nodes=prepared_nodes, **kwargs)
+
+    def delegate(self, tape: TapeType):
+        return self
+
+    def trim_tape(self, tape: Tape) -> Tape:
+        return tape
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 9e212ae..6e9cbfd 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -128,7 +128,7 @@ def test_gaia_agent():
     llm = ReplayLLM(llm_calls=[LLMCall.model_validate(p) for p in results.prompts], model_name=results.model)
     env = GaiaEnvironment(only_cached_webpages=True, safe_calculator=False)
    env.browser.set_web_cache(results.web_cache)
-    agent = GaiaAgent(llms={"default": llm}, short_steps=True)
+    agent = GaiaAgent.create(llm, short_steps=True)
     tapes = [GaiaTape.model_validate(tape) for tape in results.tapes]
 
     logger.info(f"Validate {len(tapes)} tapes")
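# --- review note, not part of the patch ---------------------------------------
# End-to-end sketch of the refactored flow using the gaia example above.
# LiteLLM and the model name are illustrative assumptions, and `tape` stands
# for a GaiaTape produced by the environment.
from examples.gaia_agent.agent import GaiaAgent
from tapeagents.llms import LiteLLM

agent = GaiaAgent.create(LiteLLM(model_name="gpt-4o-mini"), short_steps=True)
node = agent.select_node(tape)  # GaiaNode named after the last step's kind
# - node.make_prompt(agent, tape) checks the start step, trims long
#   observations, and appends node.guidance as the final user message;
# - node.generate_steps(agent, tape, llm_stream) parses the completion into
#   GaiaAgentStep objects; postprocess_step then drops empty facts and
#   attaches known facts to calculator/code actions.
# node.system_prompt, node.steps_prompt, and the step classes were inherited
# from the GuidedAgent.create defaults rather than set per node.
# -------------------------------------------------------------------------------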