diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 683cb52f..53ac374f 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -140,7 +140,7 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: if results_truncated: logger.debug("message_index=<%s> | tool results truncated", last_message_idx_with_tool_results) return - + # Try to trim index id when tool result cannot be truncated anymore # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size diff --git a/src/strands/tools/__init__.py b/src/strands/tools/__init__.py index b3ee1566..35391114 100644 --- a/src/strands/tools/__init__.py +++ b/src/strands/tools/__init__.py @@ -3,6 +3,7 @@ This module provides the core functionality for creating, managing, and executing tools through agents. """ +from .class_loader import load_tools_from_instance from .decorator import tool from .thread_pool_executor import ThreadPoolExecutorWrapper from .tools import FunctionTool, InvalidToolUseNameException, PythonAgentTool, normalize_schema, normalize_tool_spec @@ -15,4 +16,5 @@ "normalize_schema", "normalize_tool_spec", "ThreadPoolExecutorWrapper", + "load_tools_from_instance", ] diff --git a/src/strands/tools/class_loader.py b/src/strands/tools/class_loader.py new file mode 100644 index 00000000..bfeec0b2 --- /dev/null +++ b/src/strands/tools/class_loader.py @@ -0,0 +1,144 @@ +"""This module defines a method for accessing tools from an instance. + +It exposes: +- `load_tools_from_instance`: loads all public methods from an instance as AgentTool objects, with automatic name + disambiguation for instance methods. + +It will load instance, class, and static methods from the class, including inherited methods. + +By default, all public methods (not starting with _) will be loaded as AgentTool objects, even if not decorated. + +Note: + Tool names must be unique within an agent. If you load tools from multiple instances of the same class, + you MUST provide a unique label for each instance, or tools will overwrite each other in the registry. + The registry does not warn or error on duplicates; the last tool registered with a given name wins. + +The `load_tools_from_instance` function will return a list of `AgentTool` objects. +""" + +import inspect +import logging +from typing import Any, Callable, List, Optional + +from ..types.tools import AgentTool, ToolResult, ToolSpec, ToolUse +from .decorator import FunctionToolMetadata + +logger = logging.getLogger(__name__) + + +class GenericFunctionTool(AgentTool): + """Wraps any callable (instance, static, or class method) as an AgentTool. + + Uses FunctionToolMetadata for metadata extraction and input validation. + """ + + def __init__(self, func: Callable, name: Optional[str] = None, description: Optional[str] = None): + """Initialize a GenericFunctionTool.""" + super().__init__() + self._func = func + try: + self._meta = FunctionToolMetadata(func) + self._tool_spec = self._meta.extract_metadata() + if name: + self._tool_spec["name"] = name + if description: + self._tool_spec["description"] = description + except Exception as e: + logger.warning("Could not convert %s to AgentTool: %s", getattr(func, "__name__", str(func)), str(e)) + raise + + @property + def tool_name(self) -> str: + """Return the tool's name.""" + return str(self._tool_spec["name"]) + + @property + def tool_spec(self) -> ToolSpec: + """Return the tool's specification.""" + return self._tool_spec + + @property + def tool_type(self) -> str: + """Return the tool's type.""" + return "function" + + def invoke(self, tool: ToolUse, *args: Any, **kwargs: Any) -> ToolResult: + """Invoke the tool with validated input.""" + try: + validated_input = self._meta.validate_input(tool["input"]) + result = self._func(**validated_input) + return { + "toolUseId": tool.get("toolUseId", "unknown"), + "status": "success", + "content": [{"text": str(result)}], + } + except Exception as e: + return { + "toolUseId": tool.get("toolUseId", "unknown"), + "status": "error", + "content": [{"text": f"Error: {e}"}], + } + + +def load_tools_from_instance( + instance: object, + disambiguator: Optional[str] = None, +) -> List[AgentTool]: + """Load all public methods from an instance as AgentTool objects with name disambiguation. + + Instance methods are bound to the given instance and are disambiguated by suffixing the tool name + with the given label (or the instance id if no prefix is provided). Static and class methods are + not disambiguated, as they do not depend on instance state. + + Args: + instance: The instance to inspect. + disambiguator: Optional string to disambiguate instance method tool names. If not provided, uses id(instance). + + Returns: + List of AgentTool objects (GenericFunctionTool wrappers). + + Note: + Tool names must be unique within an agent. If you load tools from multiple instances of the same + class, you MUST provide a unique label for each instance, or tools will overwrite each + other in the registry. The registry does not warn or error on duplicates; the last tool registered + with a given name wins. This function will log a warning if a duplicate tool name is detected in + the returned list. + + Example: + from strands.tools.class_loader import load_tools_from_instance + + class MyClass: + def foo(self, x: int) -> int: + return x + 1 + + @staticmethod + def bar(y: int) -> int: + return y * 2 + + instance = MyClass() + tools = load_tools_from_instance(instance, disambiguator="special") + # tools is a list of AgentTool objects for foo and bar, with foo disambiguated as 'myclass_foo_special' + """ + methods: List[AgentTool] = [] + class_name = instance.__class__.__name__.lower() + func: Any + for name, _member in inspect.getmembers(instance.__class__): + if name.startswith("_"): + continue + tool_name = f"{class_name}_{name}" + raw_attr = instance.__class__.__dict__.get(name, None) + if isinstance(raw_attr, staticmethod): + func = raw_attr.__func__ + elif isinstance(raw_attr, classmethod): + func = raw_attr.__func__.__get__(instance, instance.__class__) + else: + # Instance method: bind to instance and disambiguate + func = getattr(instance, name, None) + tool_name += f"_{str(id(instance))}" if disambiguator is None else f"_{disambiguator}" + if callable(func): + try: + methods.append(GenericFunctionTool(func, name=tool_name)) + except Exception: + # Warning already logged in GenericFunctionTool + pass + return methods diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index f15a8e4f..0ab976b7 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -47,6 +47,8 @@ def my_tool(param1: str, param2: int = 42) -> dict: import docstring_parser from pydantic import BaseModel, Field, create_model +from ..types.tools import ToolSpec + # Type for wrapped function T = TypeVar("T", bound=Callable[..., Any]) @@ -124,7 +126,7 @@ def _create_input_model(self) -> Type[BaseModel]: # Handle case with no parameters return create_model(model_name) - def extract_metadata(self) -> Dict[str, Any]: + def extract_metadata(self) -> ToolSpec: """Extract metadata from the function to create a tool specification. This method analyzes the function to create a standardized tool specification that Strands Agent can use to @@ -155,7 +157,7 @@ def extract_metadata(self) -> Dict[str, Any]: self._clean_pydantic_schema(input_schema) # Create tool specification - tool_spec = {"name": func_name, "description": description, "inputSchema": {"json": input_schema}} + tool_spec: ToolSpec = {"name": func_name, "description": description, "inputSchema": {"json": input_schema}} return tool_spec @@ -288,7 +290,7 @@ def decorator(f: T) -> T: tool_spec = tool_meta.extract_metadata() # Update with any additional kwargs - tool_spec.update(tool_kwargs) + tool_spec.update(tool_kwargs) # type: ignore # Attach TOOL_SPEC directly to the original function (critical for backward compatibility) f.TOOL_SPEC = tool_spec # type: ignore diff --git a/tests-integ/test_class_loader.py b/tests-integ/test_class_loader.py new file mode 100644 index 00000000..4eb93221 --- /dev/null +++ b/tests-integ/test_class_loader.py @@ -0,0 +1,26 @@ +from strands.agent.agent import Agent +from strands.tools.class_loader import load_tools_from_instance + + +class WeatherTimeTool: + def get_weather_in_paris(self) -> str: + return "sunny" + + @staticmethod + def get_time_in_paris(r) -> str: + return "15:00" + + +def test_agent_weather_and_time(): + tool = WeatherTimeTool() + tools = load_tools_from_instance(tool) + prompt = ( + "What is the time and weather in paris?" + "return only with the weather and time for example 'rainy 04:00'" + "if you cannot respond with 'FAILED'" + ) + agent = Agent(tools=tools) + response = agent(prompt) + text = str(response).lower() + assert "sunny" in text + assert "15:00" in text diff --git a/tests-integ/test_mcp_client.py b/tests-integ/test_mcp_client.py index f0669284..8b1dade3 100644 --- a/tests-integ/test_mcp_client.py +++ b/tests-integ/test_mcp_client.py @@ -104,8 +104,8 @@ def test_can_reuse_mcp_client(): @pytest.mark.skipif( - condition=os.environ.get("GITHUB_ACTIONS") == 'true', - reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue" + condition=os.environ.get("GITHUB_ACTIONS") == "true", + reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue", ) def test_streamable_http_mcp_client(): server_thread = threading.Thread( diff --git a/tests/strands/tools/test_class_loader.py b/tests/strands/tools/test_class_loader.py new file mode 100644 index 00000000..d84dd22a --- /dev/null +++ b/tests/strands/tools/test_class_loader.py @@ -0,0 +1,131 @@ +import pytest + +from strands.agent.agent import Agent +from strands.tools.class_loader import load_tools_from_instance + + +class MyTestClass: + def foo(self, x: int) -> int: + """Add 1 to x.""" + return x + 1 + + @staticmethod + def bar(y: int) -> int: + """Multiply y by 2.""" + return y * 2 + + @classmethod + def baz(cls, z: int) -> int: + """Subtract 1 from z.""" + return z - 1 + + not_a_method = 42 + + +def test_agent_tool_invocation_for_all_method_types(): + """Test that agent.tool.{tool_name} works for instance, static, and class methods.""" + instance = MyTestClass() + tools = load_tools_from_instance(instance, disambiguator="agenttest") + agent = Agent(tools=tools) + # Instance method + result_foo = agent.tool.mytestclass_foo_agenttest(x=5) + assert result_foo["status"] == "success" + assert result_foo["content"][0]["text"] == "6" + # Static method + result_bar = agent.tool.mytestclass_bar(y=3) + assert result_bar["status"] == "success" + assert result_bar["content"][0]["text"] == "6" + # Class method + result_baz = agent.tool.mytestclass_baz(z=10) + assert result_baz["status"] == "success" + assert result_baz["content"][0]["text"] == "9" + + +def test_non_callable_attributes_are_skipped(): + """Test that non-callable attributes are not loaded as tools.""" + + class ClassWithNonCallable: + foo = 123 + + def bar(self): + return 1 + + instance = ClassWithNonCallable() + tools = load_tools_from_instance(instance, disambiguator="nc") + tool_names = {tool.tool_name for tool in tools} + assert "classwithnoncallable_foo" not in tool_names + assert "classwithnoncallable_bar_nc" in tool_names + + +def test_error_handling_for_unconvertible_methods(monkeypatch): + """Test that a warning is logged and method is skipped if it cannot be converted.""" + + class BadClass: + def bad(self, x): + return x + + instance = BadClass() + # Patch FunctionToolMetadata to raise Exception + from strands.tools import class_loader + + orig = class_loader.FunctionToolMetadata.__init__ + + def fail_init(self, func): + raise ValueError("fail") + + monkeypatch.setattr(class_loader.FunctionToolMetadata, "__init__", fail_init) + with pytest.raises(ValueError): + # Direct instantiation should raise + class_loader.GenericFunctionTool(instance.bad) + # But loader should skip and not raise + tools = load_tools_from_instance(instance, disambiguator="bad") + assert tools == [] + # Restore + monkeypatch.setattr(class_loader.FunctionToolMetadata, "__init__", orig) + + +def test_default_prefix_is_instance_id(): + """Test that the default prefix is id(instance) when no prefix is provided.""" + instance = MyTestClass() + tools = load_tools_from_instance(instance) + tool_names = {tool.tool_name for tool in tools} + assert f"mytestclass_foo_{str(id(instance))}" in tool_names + assert "mytestclass_bar" in tool_names + assert "mytestclass_baz" in tool_names + + +def test_multiple_instances_of_same_class(): + """Test loading tools from multiple instances of the same class, including a static method.""" + + class Counter: + def __init__(self, start): + self.start = start + + def increment(self, x: int) -> int: + return self.start + x + + @staticmethod + def double_static(y: int) -> int: + return y * 2 + + a = Counter(10) + b = Counter(100) + tools_a = load_tools_from_instance(a, disambiguator="a") + tools_b = load_tools_from_instance(b, disambiguator="b") + agent = Agent(tools=tools_a + tools_b) + # Call increment for each instance + result_a = agent.tool.counter_increment_a(x=5) + result_b = agent.tool.counter_increment_b(x=5) + assert result_a["status"] == "success" + assert result_b["status"] == "success" + assert result_a["content"][0]["text"] == "15" + assert result_b["content"][0]["text"] == "105" + # Static method should be available (not disambiguated) + result_static = agent.tool.counter_double_static(y=7) + assert result_static["status"] == "success" + assert result_static["content"][0]["text"] == "14" + # Tool names are unique for instance methods, static method is shared + tool_names = set(agent.tool_names) + assert "counter_increment_a" in tool_names + assert "counter_increment_b" in tool_names + assert "counter_double_static" in tool_names