diff --git a/docs/docs/examples/agent/agent_runner/query_pipeline_agent.ipynb b/docs/docs/examples/agent/agent_runner/query_pipeline_agent.ipynb index 599228c74be47..4fea28d9cb945 100644 --- a/docs/docs/examples/agent/agent_runner/query_pipeline_agent.ipynb +++ b/docs/docs/examples/agent/agent_runner/query_pipeline_agent.ipynb @@ -1169,7 +1169,8 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" + "pygments_lexer": "ipython3", + "version": "3.10.8" } }, "nbformat": 4, diff --git a/llama-index-core/llama_index/core/query_pipeline/components/loop.py b/llama-index-core/llama_index/core/query_pipeline/components/loop.py new file mode 100644 index 0000000000000..79d700072dbf9 --- /dev/null +++ b/llama-index-core/llama_index/core/query_pipeline/components/loop.py @@ -0,0 +1,92 @@ +from llama_index.core.base.query_pipeline.query import ( + InputKeys, + OutputKeys, + QueryComponent, +) +from llama_index.core.query_pipeline.query import QueryPipeline +from llama_index.core.bridge.pydantic import BaseModel, Field +from llama_index.core.callbacks.base import CallbackManager +from typing import Any, Dict, Optional, List, cast, Callable +from llama_index.core.query_pipeline.components.stateful import BaseStatefulComponent + +def _get_stateful_components(query_component: QueryComponent) -> List[BaseStatefulComponent]: + """Get stateful components.""" + stateful_components: List[BaseStatefulComponent] = [] + for c in query_component.sub_query_components: + if isinstance(c, BaseStatefulComponent): + stateful_components.append(cast(BaseStatefulComponent, c)) + + if len(c.sub_query_components) > 0: + stateful_components.extend(_get_stateful_components(c)) + + return stateful_components + +class LoopComponent(QueryComponent): + """Loop component. + + """ + + pipeline: QueryPipeline = Field(..., description="Query pipeline") + should_exit_fn: Optional[Callable] = Field(..., description="Should exit function") + add_output_to_input_fn: Optional[Callable] = Field(..., description="Add output to input function. If not provided, will reuse the original input for the next iteration. If provided, will call the function to combine the output into the input for the next iteration.") + max_iterations: Optional[int] = Field(5, description="Max iterations") + + class Config: + arbitrary_types_allowed = True + + def __init__( + self, + pipeline: QueryPipeline, + should_exit_fn: Optional[Callable] = None, + add_output_to_input_fn: Optional[Callable] = None, + max_iterations: Optional[int] = 5, + ) -> None: + """Init params.""" + super().__init__(pipeline=pipeline, should_exit_fn=should_exit_fn, add_output_to_input_fn=add_output_to_input_fn, max_iterations=max_iterations) + + def set_callback_manager(self, callback_manager: CallbackManager) -> None: + """Set callback manager.""" + # TODO: implement + + def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]: + pass + + @property + def stateful_components(self) -> List[BaseStatefulComponent]: + """Get stateful component.""" + # TODO: do this directly within the query pipeline + return _get_stateful_components(self.pipeline) + + def _run_component(self, **kwargs: Any) -> Dict: + """Run component.""" + state = {} + # partial agent output component with state + for stateful_component in self.stateful_components: + stateful_component.partial(state=state) + + current_input = kwargs + for i in range(self.max_iterations): + output = self.pipeline.run_component(**current_input) + if self.should_exit_fn: + should_exit = self.should_exit_fn(output) + if should_exit: + break + + if self.add_output_to_input_fn: + current_input = self.add_output_to_input_fn(current_input, output) + + return self.pipeline.run_component(**kwargs) + + async def _arun_component(self, **kwargs: Any) -> Any: + """Run component (async).""" + return await self.pipeline.arun_component(**kwargs) + + @property + def input_keys(self) -> InputKeys: + """Input keys.""" + return self.pipeline.input_keys + + @property + def output_keys(self) -> OutputKeys: + """Output keys.""" + return self.pipeline.output_keys \ No newline at end of file diff --git a/llama-index-core/llama_index/core/query_pipeline/components/stateful.py b/llama-index-core/llama_index/core/query_pipeline/components/stateful.py new file mode 100644 index 0000000000000..b7948b391a799 --- /dev/null +++ b/llama-index-core/llama_index/core/query_pipeline/components/stateful.py @@ -0,0 +1,63 @@ +"""Agent components.""" + +from inspect import signature +from typing import Any, Callable, Dict, Optional, Set, Tuple, cast + +from llama_index.core.base.query_pipeline.query import ( + InputKeys, + OutputKeys, + QueryComponent, +) +from llama_index.core.bridge.pydantic import Field, PrivateAttr +from llama_index.core.callbacks.base import CallbackManager +from llama_index.core.query_pipeline.components.function import FnComponent, get_parameters +# from llama_index.core.query_pipeline.components.input import InputComponent + + +class BaseStatefulComponent(QueryComponent): + """Takes in agent inputs and transforms it into desired outputs.""" + + + +class StatefulFnComponent(BaseStatefulComponent, FnComponent): + """Query component that takes in an arbitrary function. + + Stateful version of `FnComponent`. Expects functions to have `state` as the first argument. + + """ + + def __init__( + self, + fn: Callable, + req_params: Optional[Set[str]] = None, + opt_params: Optional[Set[str]] = None, + **kwargs: Any + ) -> None: + """Init params.""" + + # determine parameters + default_req_params, default_opt_params = get_parameters(fn) + # make sure task and step are part of the list, and remove them from the list + if "state" not in default_req_params: + raise ValueError( + "StatefulFnComponent must have 'state' as required parameters" + ) + + default_req_params = default_req_params - {"state"} + default_opt_params = default_opt_params - {"state"} + + super().__init__(fn=fn, req_params=req_params, opt_params=opt_params, **kwargs) + + @property + def input_keys(self) -> InputKeys: + """Input keys.""" + return InputKeys.from_keys( + required_keys={"state", *self._req_params}, + optional_keys=self._opt_params, + ) + + @property + def output_keys(self) -> OutputKeys: + """Output keys.""" + # output can be anything, overrode validate function + return OutputKeys.from_keys({self.output_key}) \ No newline at end of file diff --git a/llama-index-core/tests/query_pipeline/test_components.py b/llama-index-core/tests/query_pipeline/test_components.py index 2dfdca600c2c4..47869b0742541 100644 --- a/llama-index-core/tests/query_pipeline/test_components.py +++ b/llama-index-core/tests/query_pipeline/test_components.py @@ -1,5 +1,5 @@ """Test components.""" -from typing import Any, List, Sequence +from typing import Any, List, Sequence, Dict import pytest from llama_index.core.base.base_selector import ( @@ -15,6 +15,8 @@ ) from llama_index.core.query_pipeline.components.function import FnComponent from llama_index.core.query_pipeline.components.input import InputComponent +from llama_index.core.query_pipeline.components.stateful import StatefulFnComponent +from llama_index.core.query_pipeline.components.loop import LoopComponent from llama_index.core.query_pipeline.components.router import ( RouterComponent, SelectorComponent, @@ -155,3 +157,50 @@ def bar2_fn(a: Any) -> str: selector_c = SelectorComponent(selector=selector) output = selector_c.run_component(query="hello", choices=["t1", "t2"]) assert output["output"][0] == SingleSelection(index=1, reason="foo") + + +def stateful_foo_fn(state: Dict[str, Any], a: int, b: int = 2) -> Dict[str, Any]: + """Foo function.""" + old = state.get("prev", 0) + new = old + a + b + state["prev"] = new + return new + + +def test_stateful_fn_pipeline() -> None: + """Test pipeline with function components.""" + p = QueryPipeline() + p.add_modules( + { + "m1": StatefulFnComponent(fn=stateful_foo_fn), + "m2": StatefulFnComponent(fn=stateful_foo_fn) + } + ) + p.add_link("m1", "m2", src_key="output", dest_key="a", input_fn=lambda x: x["new"]) + # p.add_link("m1", "m2", src_key="output", dest_key="state", input_fn=lambda x: x["state"]) + # output = p.run(a=1, b=2) + # assert output == 6 + + # try one iteration + loop_component = LoopComponent( + pipeline=p, + should_exit_fn=lambda x: x > 10, + # add_output_to_input_fn=lambda cur_input, output: {"a": output}, + max_iterations=1 + ) + output = loop_component.run_component(a=1, b=2) + assert output["output"] == 6 + + # try two iterations + # loop 1: 0 + 1 + 2 = 3, 3 + 3 + 2 = 8 + # loop 2: 8 + 8 + 2 = 18, 18 + 18 + 2 = 38 + loop_component = LoopComponent( + pipeline=p, + should_exit_fn=lambda x: x > 10, + add_output_to_input_fn=lambda cur_input, output: {"a": output}, + max_iterations=5 + ) + assert loop_component.run_component(a=1, b=2)["output"] == 38 + + + # test loop component