Commit
More flexible verbosity level
aymeric-roucher committed Jan 10, 2025
1 parent 6743d01 commit 667e378
Showing 3 changed files with 84 additions and 69 deletions.
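The commit replaces the boolean verbose flag on MultiStepAgent with an integer verbose_level backed by a new LogLevel enum (ERROR=0, INFO=1, DEBUG=2) and an AgentLogger that routes all console output through a single level check. A minimal usage sketch, assuming CodeAgent forwards extra keyword arguments to MultiStepAgent.__init__ and that the HfApiModel wrapper is used (both assumptions, not shown in this diff):

from smolagents import CodeAgent, HfApiModel

# verbose_level=0 -> errors only, 1 -> normal output (default), 2 -> detailed output
agent = CodeAgent(
    tools=[],
    model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"),
    verbose_level=2,  # DEBUG: also prints the raw LLM output before parsing
)
agent.run("What is 2 + 2?")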
27 changes: 6 additions & 21 deletions examples/benchmark.ipynb
@@ -21,7 +21,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -172,7 +172,7 @@
"[132 rows x 4 columns]"
]
},
"execution_count": 21,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@@ -195,7 +195,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -398,23 +398,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluating 'meta-llama/Llama-3.3-70B-Instruct'...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 132/132 [00:00<00:00, 27836.90it/s]\n",
" 16%|█▌ | 21/132 [02:18<07:35, 4.11s/it]"
]
}
],
"outputs": [],
"source": [
"open_model_ids = [\n",
" \"meta-llama/Llama-3.3-70B-Instruct\",\n",
@@ -423,6 +407,7 @@
" \"Qwen/Qwen2.5-Coder-32B-Instruct\",\n",
" \"meta-llama/Llama-3.2-3B-Instruct\",\n",
" \"meta-llama/Llama-3.1-8B-Instruct\",\n",
" \"mistralai/Mistral-Nemo-Instruct-2407\",\n",
" # \"HuggingFaceTB/SmolLM2-1.7B-Instruct\",\n",
" # \"meta-llama/Llama-3.1-70B-Instruct\",\n",
"]\n",
@@ -1010,7 +995,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "test",
"display_name": "compare-agents",
"language": "python",
"name": "python3"
},
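In the notebook, the stale stdout/stderr captured in the evaluation cell (a tqdm progress bar over the 132 benchmark rows) is cleared, mistralai/Mistral-Nemo-Instruct-2407 is added to the list of open models, and the kernel name is updated. The evaluation loop itself sits outside this hunk; a hypothetical sketch of how such a list is typically consumed — the evaluate_on_benchmark helper and its arguments are illustrative, not code from the notebook:

for model_id in open_model_ids:
    print(f"Evaluating '{model_id}'...")
    # hypothetical helper: runs the 132 benchmark tasks against one model and stores the results
    evaluate_on_benchmark(model_id)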
119 changes: 75 additions & 44 deletions src/smolagents/agents.py
@@ -18,12 +18,14 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from enum import IntEnum
from rich import box
from rich.console import Group
from rich.panel import Panel
from rich.rule import Rule
from rich.syntax import Syntax
from rich.text import Text
from rich.console import Console

from .default_tools import FinalAnswerTool, TOOL_MAPPING
from .e2b_executor import E2BExecutor
@@ -164,6 +166,22 @@ def format_prompt_with_managed_agents_descriptions(
YELLOW_HEX = "#d4b702"


class LogLevel(IntEnum):
ERROR = 0 # Only errors
INFO = 1 # Normal output (default)
DEBUG = 2 # Detailed output


class AgentLogger:
def __init__(self, level: LogLevel = LogLevel.INFO):
self.level = level
self.console = Console()

def log(self, *args, level: LogLevel = LogLevel.INFO, **kwargs):
if level <= self.level:
            self.console.print(*args, **kwargs)


class MultiStepAgent:
"""
Agent class that solves the given task step by step, using the ReAct framework:
@@ -179,7 +197,7 @@ def __init__(
max_steps: int = 6,
tool_parser: Optional[Callable] = None,
add_base_tools: bool = False,
verbose: bool = False,
verbose_level: int = 1,
grammar: Optional[Dict[str, str]] = None,
managed_agents: Optional[List] = None,
step_callbacks: Optional[List[Callable]] = None,
@@ -205,7 +223,6 @@ def __init__(

self.managed_agents = {}
if managed_agents is not None:
print("NOTNONE")
self.managed_agents = {agent.name: agent for agent in managed_agents}

self.tools = {tool.name: tool for tool in tools}
@@ -222,8 +239,8 @@ def __init__(
self.input_messages = None
self.logs = []
self.task = None
self.verbose = verbose
self.monitor = Monitor(self.model)
self.logger = AgentLogger(level=verbose_level)
self.monitor = Monitor(self.model, self.logger)
self.step_callbacks = step_callbacks if step_callbacks is not None else []
self.step_callbacks.append(self.monitor.update_metrics)

@@ -485,14 +502,15 @@ def run(
else:
self.logs.append(system_prompt_step)

console.print(
self.logger.log(
Panel(
f"\n[bold]{self.task.strip()}\n",
title="[bold]New run",
subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, 'model_id') else '')}",
border_style=YELLOW_HEX,
subtitle_align="left",
)
),
level=LogLevel.INFO,
)

self.logs.append(TaskStep(task=self.task))
@@ -531,12 +549,13 @@ def stream_run(self, task: str):
is_first_step=(self.step_number == 0),
step=self.step_number,
)
console.print(
self.logger.log(
Rule(
f"[bold]Step {self.step_number}",
characters="━",
style=YELLOW_HEX,
)
),
level=LogLevel.INFO,
)

# Run one step!
@@ -557,7 +576,7 @@ def stream_run(self, task: str):
final_step_log = ActionStep(error=AgentMaxStepsError(error_message))
self.logs.append(final_step_log)
final_answer = self.provide_final_answer(task)
console.print(Text(f"Final answer: {final_answer}"))
self.logger.log(Text(f"Final answer: {final_answer}"), level=LogLevel.INFO)
final_step_log.action_output = final_answer
final_step_log.end_time = time.time()
final_step_log.duration = step_log.end_time - step_start_time
@@ -586,12 +605,13 @@ def direct_run(self, task: str):
is_first_step=(self.step_number == 0),
step=self.step_number,
)
console.print(
self.logger.log(
Rule(
f"[bold]Step {self.step_number}",
characters="━",
style=YELLOW_HEX,
)
),
level=LogLevel.INFO,
)

# Run one step!
@@ -613,7 +633,7 @@ def direct_run(self, task: str):
final_step_log = ActionStep(error=AgentMaxStepsError(error_message))
self.logs.append(final_step_log)
final_answer = self.provide_final_answer(task)
console.print(Text(f"Final answer: {final_answer}"))
self.logger.log(Text(f"Final answer: {final_answer}"), level=LogLevel.INFO)
final_step_log.action_output = final_answer
final_step_log.duration = 0
for callback in self.step_callbacks:
@@ -679,8 +699,10 @@ def planning_step(self, task, is_first_step: bool, step: int):
self.logs.append(
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
)
console.print(
Rule("[bold]Initial plan", style="orange"), Text(final_plan_redaction)
self.logger.log(
Rule("[bold]Initial plan", style="orange"),
Text(final_plan_redaction),
level=LogLevel.INFO,
)
else: # update plan
agent_memory = self.write_inner_memory_from_logs(
@@ -735,8 +757,10 @@ def planning_step(self, task, is_first_step: bool, step: int):
self.logs.append(
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
)
console.print(
Rule("[bold]Updated plan", style="orange"), Text(final_plan_redaction)
self.logger.log(
Rule("[bold]Updated plan", style="orange"),
Text(final_plan_redaction),
level=LogLevel.INFO,
)


@@ -795,8 +819,11 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
)

# Execute
console.print(
Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}"))
self.logger.log(
Panel(
Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")
),
level=LogLevel.INFO,
)
if tool_name == "final_answer":
if isinstance(tool_arguments, dict):
@@ -810,13 +837,15 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
isinstance(answer, str) and answer in self.state.keys()
): # if the answer is a state variable, return the value
final_answer = self.state[answer]
console.print(
f"[bold {YELLOW_HEX}]Final answer:[/bold {YELLOW_HEX}] Extracting key '{answer}' from state to return value '{final_answer}'."
self.logger.log(
f"[bold {YELLOW_HEX}]Final answer:[/bold {YELLOW_HEX}] Extracting key '{answer}' from state to return value '{final_answer}'.",
level=LogLevel.INFO,
)
else:
final_answer = answer
console.print(
Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}")
self.logger.log(
Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}"),
level=LogLevel.INFO,
)

log_entry.action_output = final_answer
@@ -837,7 +866,7 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
updated_information = f"Stored '{observation_name}' in memory."
else:
updated_information = str(observation).strip()
console.print(f"Observations: {updated_information}")
self.logger.log(f"Observations: {updated_information}", level=LogLevel.INFO)
log_entry.observations = updated_information
return None

@@ -922,22 +951,22 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
except Exception as e:
raise AgentGenerationError(f"Error in generating model output:\n{e}")

if self.verbose:
console.print(
Group(
Rule(
"[italic]Output message of the LLM:",
align="left",
style="orange",
),
Syntax(
llm_output,
lexer="markdown",
theme="github-dark",
word_wrap=True,
),
)
)
self.logger.log(
Group(
Rule(
"[italic]Output message of the LLM:",
align="left",
style="orange",
),
Syntax(
llm_output,
lexer="markdown",
theme="github-dark",
word_wrap=True,
),
),
level=LogLevel.DEBUG,
)

# Parse
try:
@@ -955,7 +984,7 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
)

# Execute
console.print(
self.logger.log(
Panel(
Syntax(
code_action,
@@ -966,7 +995,8 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
title="[bold]Executing this code:",
title_align="left",
box=box.HORIZONTALS,
)
),
level=LogLevel.INFO,
)
observation = ""
is_final_answer = False
@@ -993,8 +1023,9 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
else:
error_msg = str(e)
if "Import of " in str(e) and " is not allowed" in str(e):
console.print(
"[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent."
self.logger.log(
"[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.",
level=LogLevel.INFO,
)
raise AgentExecutionError(error_msg)

@@ -1008,7 +1039,7 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
style=(f"bold {YELLOW_HEX}" if is_final_answer else ""),
),
]
console.print(Group(*execution_outputs_console))
self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
log_entry.action_output = output
return output if is_final_answer else None

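Taken on its own, the new logging path works like this: every call site passes a LogLevel, and AgentLogger prints only when that level is at or below its configured threshold, so ERROR messages always appear while DEBUG output (such as the raw LLM message) appears only at verbose_level=2. A standalone sketch of that filtering, assuming the class is importable from smolagents.agents as defined above:

from smolagents.agents import AgentLogger, LogLevel

logger = AgentLogger(level=LogLevel.INFO)             # threshold = 1
logger.log("error text", level=LogLevel.ERROR)        # 0 <= 1 -> printed
logger.log("step banner", level=LogLevel.INFO)        # 1 <= 1 -> printed
logger.log("raw LLM output", level=LogLevel.DEBUG)    # 2 <= 1 is false -> suppressed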
7 changes: 3 additions & 4 deletions src/smolagents/monitoring.py
@@ -16,13 +16,12 @@
# limitations under the License.
from rich.text import Text

from .utils import console


class Monitor:
def __init__(self, tracked_model):
def __init__(self, tracked_model, logger):
self.step_durations = []
self.tracked_model = tracked_model
self.logger = logger
if (
getattr(self.tracked_model, "last_input_token_count", "Not found")
!= "Not found"
@@ -53,7 +52,7 @@ def update_metrics(self, step_log):
self.total_output_token_count += self.tracked_model.last_output_token_count
console_outputs += f"| Input tokens: {self.total_input_token_count:,} | Output tokens: {self.total_output_token_count:,}"
console_outputs += "]"
console.print(Text(console_outputs, style="dim"))
self.logger.log(Text(console_outputs, style="dim"), level=1)


__all__ = ["Monitor"]

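In monitoring.py, Monitor now receives the agent's logger instead of printing directly to the shared console, and emits its token/duration summary with level=1. Because LogLevel is an IntEnum, the plain integer 1 compares equal to LogLevel.INFO, so the summary shows at the default verbosity and is suppressed at verbose_level=0. A small sketch of that equivalence (the enum body is copied from the agents.py hunk above):

from enum import IntEnum

class LogLevel(IntEnum):
    ERROR = 0
    INFO = 1
    DEBUG = 2

assert LogLevel.INFO == 1     # IntEnum members compare equal to plain ints
assert LogLevel.DEBUG <= 2    # so call sites may pass either the enum member or the int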