from typing import (
    Any,
    Callable,
    Optional
)

import yaml
import json
from rich.panel import Panel
from rich.text import Text

from src.tools import AsyncTool
from src.exception import (
    AgentGenerationError,
    AgentParsingError,
    AgentToolExecutionError,
    AgentToolCallError
)
from src.base.async_multistep_agent import (
    PromptTemplates,
    populate_template,
    AsyncMultiStepAgent
)
from src.memory import (
    ActionStep,
    ToolCall,
    AgentMemory
)
from src.logger import (
    LogLevel,
    YELLOW_HEX,
    logger
)
from src.models import Model, parse_json_if_needed, ChatMessage
from src.utils.agent_types import (
    AgentAudio,
    AgentImage,
)
from src.utils import assemble_project_path


class BaseAgent(AsyncMultiStepAgent):
    """Base class for agents with common logic."""
    AGENT_NAME = "base_agent"  # Must be overridden by subclasses

    def __init__(
        self,
        config,  # Specific configuration object for the agent
        tools: list[AsyncTool],
        model: Model,
        prompt_templates_path: str,  # Path to the prompt templates file
        prompt_templates: PromptTemplates | None = None,  # For preloaded templates
        max_steps: int = 20,
        add_base_tools: bool = False,
        verbosity_level: LogLevel = LogLevel.INFO,
        grammar: dict[str, str] | None = None,
        managed_agents: list | None = None,
        step_callbacks: list[Callable] | None = None,
        planning_interval: int | None = None,
        name: str | None = None,  # AGENT_NAME will be used if not specified
        description: str | None = None,
        provide_run_summary: bool = False,
        final_answer_checks: list[Callable] | None = None,
        **kwargs
    ):
        self.config = config  # Save config for possible access by subclasses

        agent_name_to_use = name if name is not None else self.AGENT_NAME

        super().__init__(
            tools=tools,
            model=model,
            prompt_templates=None,  # Initialize as None, load later
            max_steps=max_steps,
            add_base_tools=add_base_tools,
            verbosity_level=verbosity_level,
            grammar=grammar,
            managed_agents=managed_agents,
            step_callbacks=step_callbacks,
            planning_interval=planning_interval,
            name=agent_name_to_use,  # Use the defined agent name
            description=description,
            provide_run_summary=provide_run_summary,
            final_answer_checks=final_answer_checks,
            **kwargs  # Pass remaining arguments to the parent class
        )

        # Loading prompt_templates
        if prompt_templates:
            self.prompt_templates = prompt_templates
        else:
            abs_template_path = assemble_project_path(prompt_templates_path)
            with open(abs_template_path, "r", encoding="utf-8") as f:
                self.prompt_templates = yaml.safe_load(f)
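
        # Whether preloaded or read from YAML, the templates are expected to
        # provide at least the keys used below: "system_prompt", "user_prompt"
        # and "task_instruction". An illustrative minimal YAML file (assuming
        # populate_template accepts Jinja-style placeholders) might look like:
        #
        #   system_prompt: |
        #     You can use these tools: {{ tools }} and these agents: {{ managed_agents }}
        #   user_prompt: |
        #     Work step by step and report your result.
        #   task_instruction: |
        #     Your task is: {{ task }}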

        self.system_prompt = self.initialize_system_prompt()
        self.user_prompt = self.initialize_user_prompt()

        self.memory = AgentMemory(
            system_prompt=self.system_prompt,
            user_prompt=self.user_prompt,
        )

    def initialize_system_prompt(self) -> str:
        """Initialize the system prompt for the agent."""
        system_prompt = populate_template(
            self.prompt_templates["system_prompt"],
            variables={"tools": self.tools, "managed_agents": self.managed_agents},
        )
        return system_prompt

    def initialize_user_prompt(self) -> str:
        """Initialize the user prompt for the agent."""
        user_prompt = populate_template(
            self.prompt_templates["user_prompt"],
            variables={},
        )
        return user_prompt

    def initialize_task_instruction(self) -> str:
        """Initialize the task instruction for the agent."""
        task_instruction = populate_template(
            self.prompt_templates["task_instruction"],
            variables={"task": self.task},
        )
        return task_instruction

    def _substitute_state_variables(self, arguments: dict[str, str] | str) -> dict[str, Any] | str:
        """Replace string values in arguments with their corresponding state values if they exist."""
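        # Example (illustrative): with self.state == {"image.png": <AgentImage>},
        # arguments {"path": "image.png"} become {"path": <AgentImage>}, while
        # strings that do not match a state key are passed through unchanged.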
        if isinstance(arguments, dict):
            return {
                key: self.state.get(value, value) if isinstance(value, str) else value
                for key, value in arguments.items()
            }
        return arguments

    async def execute_tool_call(self, tool_name: str, arguments: dict[str, str] | str) -> Any:
        """
        Execute a tool or managed agent with the provided arguments.

        The arguments are replaced with the actual values from the state if they refer to state variables.

        Args:
            tool_name (`str`): Name of the tool or managed agent to execute.
            arguments (`dict[str, str] | str`): Arguments passed to the tool call.
        """
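        # Example (illustrative): `await self.execute_tool_call("search", {"query": "latest release"})`
        # would dispatch to a hypothetical "search" tool registered in self.tools,
        # after routing string argument values through _substitute_state_variables.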
        # Check if the tool exists
        available_tools = {**self.tools, **self.managed_agents}
        if tool_name not in available_tools:
            raise AgentToolExecutionError(
                f"Unknown tool {tool_name}, should be one of: {', '.join(available_tools)}.", self.logger
            )

        # Get the tool and substitute state variables in arguments
        tool = available_tools[tool_name]
        arguments = self._substitute_state_variables(arguments)
        is_managed_agent = tool_name in self.managed_agents

        try:
            # Call the tool with the appropriate arguments
            if isinstance(arguments, dict):
                return await tool(**arguments) if is_managed_agent else await tool(**arguments, sanitize_inputs_outputs=True)
            elif isinstance(arguments, str):
                return await tool(arguments) if is_managed_agent else await tool(arguments, sanitize_inputs_outputs=True)
            else:
                raise TypeError(f"Unsupported arguments type: {type(arguments)}")

        except TypeError as e:
            # Handle invalid arguments
            description = getattr(tool, "description", "No description")
            if is_managed_agent:
                error_msg = (
                    f"Invalid request to team member '{tool_name}' with arguments {json.dumps(arguments, ensure_ascii=False)}: {e}\n"
                    "You should call this team member with a valid request.\n"
                    f"Team member description: {description}"
                )
            else:
                error_msg = (
                    f"Invalid call to tool '{tool_name}' with arguments {json.dumps(arguments, ensure_ascii=False)}: {e}\n"
                    "You should call this tool with correct input arguments.\n"
                    f"Expected inputs: {json.dumps(tool.parameters)}\n"
                    f"Returns output type: {tool.output_type}\n"
                    f"Tool description: '{description}'"
                )
            raise AgentToolCallError(error_msg, self.logger) from e

        except Exception as e:
            # Handle execution errors
            if is_managed_agent:
                error_msg = (
                    f"Error executing request to team member '{tool_name}' with arguments {json.dumps(arguments)}: {e}\n"
                    "Please try again or send the request to another team member."
                )
            else:
                error_msg = (
                    f"Error executing tool '{tool_name}' with arguments {json.dumps(arguments)}: {type(e).__name__}: {e}\n"
                    "Please try again or use another tool."
                )
            raise AgentToolExecutionError(error_msg, self.logger) from e

    async def step(self, memory_step: ActionStep) -> None | Any:
        """
        Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
        Returns None if the step is not final.
        """
        memory_messages = await self.write_memory_to_messages()

        input_messages = memory_messages.copy()

        # Add new step in logs
        memory_step.model_input_messages = input_messages

        try:
            chat_message: ChatMessage = await self.model(
                input_messages,
                stop_sequences=["Observation:", "Calling tools:"],
                tools_to_call_from=list(self.tools.values()),
            )
            memory_step.model_output_message = chat_message
            model_output = chat_message.content
            self.logger.log_markdown(
                content=model_output if model_output else str(chat_message.raw),
                title="Output message of the LLM:",
                level=LogLevel.DEBUG,
            )

            memory_step.model_output_message.content = model_output
            memory_step.model_output = model_output
        except Exception as e:
            raise AgentGenerationError(f"Error while generating output:\n{e}", self.logger) from e

        if chat_message.tool_calls is None or len(chat_message.tool_calls) == 0:
            try:
                chat_message = self.model.parse_tool_calls(chat_message)
            except Exception as e:
                raise AgentParsingError(f"Error while parsing tool call from model output: {e}", self.logger) from e
        else:
            for tool_call in chat_message.tool_calls:
                tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)

        tool_call = chat_message.tool_calls[0]
        tool_name, tool_call_id = tool_call.function.name, tool_call.id
        tool_arguments = tool_call.function.arguments
        memory_step.model_output = f"Called Tool: '{tool_name}' with arguments: {tool_arguments}"
        memory_step.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]

        # Execute
        self.logger.log(
            Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")),
            level=LogLevel.INFO,
        )
        if tool_name == "final_answer":
            if isinstance(tool_arguments, dict):
                if "result" in tool_arguments:
                    result = tool_arguments["result"]
                else:
                    result = tool_arguments
            else:
                result = tool_arguments
            if (
                isinstance(result, str) and result in self.state.keys()
            ):  # if the answer is a state variable, return the value
                final_result = self.state[result]
                self.logger.log(
                    f"[bold {YELLOW_HEX}]Final answer:[/bold {YELLOW_HEX}] Extracting key '{result}' from state to return value '{final_result}'.",
                    level=LogLevel.INFO,
                )
            else:
                final_result = result
                self.logger.log(
                    Text(f"Final result: {final_result}", style=f"bold {YELLOW_HEX}"),
                    level=LogLevel.INFO,
                )

            memory_step.action_output = final_result
            return final_result
        else:
            if tool_arguments is None:
                tool_arguments = {}
            observation = await self.execute_tool_call(tool_name, tool_arguments)
            observation_type = type(observation)
            if observation_type in [AgentImage, AgentAudio]:
                if observation_type == AgentImage:
                    observation_name = "image.png"
                elif observation_type == AgentAudio:
                    observation_name = "audio.mp3"
                # TODO: observation naming could allow for different names of same type

                self.state[observation_name] = observation
                updated_information = f"Stored '{observation_name}' in memory."
            else:
                updated_information = str(observation).strip()
            self.logger.log(
                f"Observations: {updated_information.replace('[', '|')}",  # escape potential rich-tag-like components
                level=LogLevel.INFO,
            )
            memory_step.observations = updated_information
            return None
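

# --- Usage sketch (illustrative only) ---
# BaseAgent is not meant to be instantiated directly; concrete agents override
# AGENT_NAME and supply their own config, tools and prompt templates. The names
# below (SearchAgentConfig, WebSearchTool, the YAML path) are hypothetical, and
# the async `run` entry point is assumed to be provided by AsyncMultiStepAgent:
#
#   import asyncio
#
#   class SearchAgent(BaseAgent):
#       AGENT_NAME = "search_agent"
#
#   agent = SearchAgent(
#       config=SearchAgentConfig(),
#       tools=[WebSearchTool()],
#       model=Model(...),
#       prompt_templates_path="prompts/search_agent.yaml",
#   )
#   result = asyncio.run(agent.run("Summarize the latest release notes."))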