diff --git a/src/smolagents/memory.py b/src/smolagents/memory.py index 0f9658ff6..b3ba1c9c8 100644 --- a/src/smolagents/memory.py +++ b/src/smolagents/memory.py @@ -1,17 +1,12 @@ from dataclasses import asdict, dataclass from logging import getLogger -from typing import TYPE_CHECKING, Any, Dict, List, TypedDict, Union +from typing import Any, List, TypedDict, Union from smolagents.models import ChatMessage, MessageRole from smolagents.monitoring import AgentLogger from smolagents.utils import AgentError, make_json_serializable -if TYPE_CHECKING: - from smolagents.models import ChatMessage - from smolagents.monitoring import AgentLogger - - logger = getLogger(__name__) @@ -42,7 +37,7 @@ class MemoryStep: def dict(self): return asdict(self) - def to_messages(self, **kwargs) -> List[Dict[str, Any]]: + def to_messages(self, **kwargs) -> List[Message]: raise NotImplementedError @@ -55,7 +50,7 @@ class ActionStep(MemoryStep): step_number: int | None = None error: AgentError | None = None duration: float | None = None - model_output_message: ChatMessage = None + model_output_message: ChatMessage | None = None model_output: str | None = None observations: str | None = None observations_images: List[str] | None = None @@ -78,7 +73,7 @@ def dict(self): } def to_messages(self, summary_mode: bool = False, show_model_input_messages: bool = False) -> List[Message]: - messages = [] + messages: List[Message] = [] if self.model_input_messages is not None and show_model_input_messages: messages.append(Message(role=MessageRole.SYSTEM, content=self.model_input_messages)) if self.model_output is not None and not summary_mode: diff --git a/src/smolagents/models.py b/src/smolagents/models.py index b5768469b..57fe29243 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -106,7 +106,7 @@ def from_hf_api(cls, message, raw) -> "ChatMessage": return cls(role=message.role, content=message.content, tool_calls=tool_calls, raw=raw) @classmethod - def from_dict(cls, data: dict) -> "ChatMessage": + def from_dict(cls, data: dict[str, Any]) -> "ChatMessage": if data.get("tool_calls"): tool_calls = [ ChatMessageToolCall( @@ -124,11 +124,11 @@ def dict(self): def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]: if isinstance(arguments, dict): return arguments - else: - try: - return json.loads(arguments) - except Exception: - return arguments + + try: + return json.loads(arguments) + except Exception: + return arguments def parse_tool_args_if_needed(message: ChatMessage) -> ChatMessage: @@ -155,7 +155,7 @@ def roles(cls): } -def get_tool_json_schema(tool: Tool) -> Dict: +def get_tool_json_schema(tool: Tool) -> dict[str, Any]: properties = deepcopy(tool.inputs) required = [] for key, value in properties.items(): @@ -298,7 +298,7 @@ def _prepare_completion_kwargs( return completion_kwargs - def get_token_counts(self) -> Dict[str, int]: + def get_token_counts(self) -> dict[str, int | None]: return { "input_token_count": self.last_input_token_count, "output_token_count": self.last_output_token_count, @@ -331,11 +331,11 @@ def __call__( """ pass # To be implemented in child classes! - def to_dict(self) -> Dict: + def to_dict(self) -> dict[str, Any]: """ Converts the model into a JSON-compatible dictionary. """ - model_dictionary = { + model_dictionary: dict[str, Any] = { **self.kwargs, "last_input_token_count": self.last_input_token_count, "last_output_token_count": self.last_output_token_count,