From 48116f0069d22256255d1bf937ef9c503e3d8b4c Mon Sep 17 00:00:00 2001 From: Jerry Liu Date: Wed, 19 Jun 2024 15:28:56 -0700 Subject: [PATCH] cr --- .../query_pipeline/components/__init__.py | 4 +- .../core/query_pipeline/components/loop.py | 45 +++++++-------- .../query_pipeline/components/stateful.py | 53 ++++++++++++------ .../llama_index/core/query_pipeline/query.py | 55 +++++++++++++++++++ .../tests/query_pipeline/test_components.py | 20 ++++--- 5 files changed, 130 insertions(+), 47 deletions(-) diff --git a/llama-index-core/llama_index/core/query_pipeline/components/__init__.py b/llama-index-core/llama_index/core/query_pipeline/components/__init__.py index a5523b251d955..aebf32f30e5fa 100644 --- a/llama-index-core/llama_index/core/query_pipeline/components/__init__.py +++ b/llama-index-core/llama_index/core/query_pipeline/components/__init__.py @@ -16,6 +16,7 @@ ) from llama_index.core.query_pipeline.components.tool_runner import ToolRunnerComponent from llama_index.core.query_pipeline.components.stateful import StatefulFnComponent +from llama_index.core.query_pipeline.components.loop import LoopComponent __all__ = [ "AgentFnComponent", @@ -29,5 +30,6 @@ "RouterComponent", "SelectorComponent", "ToolRunnerComponent", - "StatefulFnComponent" + "StatefulFnComponent", + "LoopComponent" ] diff --git a/llama-index-core/llama_index/core/query_pipeline/components/loop.py b/llama-index-core/llama_index/core/query_pipeline/components/loop.py index 79d700072dbf9..3dbc3daf2d291 100644 --- a/llama-index-core/llama_index/core/query_pipeline/components/loop.py +++ b/llama-index-core/llama_index/core/query_pipeline/components/loop.py @@ -9,17 +9,17 @@ from typing import Any, Dict, Optional, List, cast, Callable from llama_index.core.query_pipeline.components.stateful import BaseStatefulComponent -def _get_stateful_components(query_component: QueryComponent) -> List[BaseStatefulComponent]: - """Get stateful components.""" - stateful_components: List[BaseStatefulComponent] = [] - for c in query_component.sub_query_components: - if isinstance(c, BaseStatefulComponent): - stateful_components.append(cast(BaseStatefulComponent, c)) +# def _get_stateful_components(query_component: QueryComponent) -> List[BaseStatefulComponent]: +# """Get stateful components.""" +# stateful_components: List[BaseStatefulComponent] = [] +# for c in query_component.sub_query_components: +# if isinstance(c, BaseStatefulComponent): +# stateful_components.append(cast(BaseStatefulComponent, c)) - if len(c.sub_query_components) > 0: - stateful_components.extend(_get_stateful_components(c)) +# if len(c.sub_query_components) > 0: +# stateful_components.extend(_get_stateful_components(c)) - return stateful_components +# return stateful_components class LoopComponent(QueryComponent): """Loop component. @@ -49,20 +49,10 @@ def set_callback_manager(self, callback_manager: CallbackManager) -> None: # TODO: implement def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]: - pass - - @property - def stateful_components(self) -> List[BaseStatefulComponent]: - """Get stateful component.""" - # TODO: do this directly within the query pipeline - return _get_stateful_components(self.pipeline) + return input def _run_component(self, **kwargs: Any) -> Dict: """Run component.""" - state = {} - # partial agent output component with state - for stateful_component in self.stateful_components: - stateful_component.partial(state=state) current_input = kwargs for i in range(self.max_iterations): @@ -75,11 +65,22 @@ def _run_component(self, **kwargs: Any) -> Dict: if self.add_output_to_input_fn: current_input = self.add_output_to_input_fn(current_input, output) - return self.pipeline.run_component(**kwargs) + return output async def _arun_component(self, **kwargs: Any) -> Any: """Run component (async).""" - return await self.pipeline.arun_component(**kwargs) + current_input = kwargs + for i in range(self.max_iterations): + output = await self.pipeline.arun_component(**current_input) + if self.should_exit_fn: + should_exit = self.should_exit_fn(output) + if should_exit: + break + + if self.add_output_to_input_fn: + current_input = self.add_output_to_input_fn(current_input, output) + + return output @property def input_keys(self) -> InputKeys: diff --git a/llama-index-core/llama_index/core/query_pipeline/components/stateful.py b/llama-index-core/llama_index/core/query_pipeline/components/stateful.py index b7948b391a799..da2114d5f024b 100644 --- a/llama-index-core/llama_index/core/query_pipeline/components/stateful.py +++ b/llama-index-core/llama_index/core/query_pipeline/components/stateful.py @@ -17,7 +17,13 @@ class BaseStatefulComponent(QueryComponent): """Takes in agent inputs and transforms it into desired outputs.""" + state: Dict[str, Any] = Field( + default_factory=dict, description="State of the pipeline." + ) + def reset_state(self) -> None: + """Reset state.""" + self.state = {} class StatefulFnComponent(BaseStatefulComponent, FnComponent): """Query component that takes in an arbitrary function. @@ -25,12 +31,12 @@ class StatefulFnComponent(BaseStatefulComponent, FnComponent): Stateful version of `FnComponent`. Expects functions to have `state` as the first argument. """ - def __init__( self, fn: Callable, req_params: Optional[Set[str]] = None, opt_params: Optional[Set[str]] = None, + state: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> None: """Init params.""" @@ -45,19 +51,34 @@ def __init__( default_req_params = default_req_params - {"state"} default_opt_params = default_opt_params - {"state"} + + if req_params is None: + req_params = default_req_params + if opt_params is None: + opt_params = default_opt_params - super().__init__(fn=fn, req_params=req_params, opt_params=opt_params, **kwargs) - - @property - def input_keys(self) -> InputKeys: - """Input keys.""" - return InputKeys.from_keys( - required_keys={"state", *self._req_params}, - optional_keys=self._opt_params, - ) - - @property - def output_keys(self) -> OutputKeys: - """Output keys.""" - # output can be anything, overrode validate function - return OutputKeys.from_keys({self.output_key}) \ No newline at end of file + super().__init__(fn=fn, req_params=req_params, opt_params=opt_params, state=state or {}, **kwargs) + + def _run_component(self, **kwargs: Any) -> Dict: + """Run component.""" + kwargs.update({"state": self.state}) + return super()._run_component(**kwargs) + + async def _arun_component(self, **kwargs: Any) -> Any: + """Async run component.""" + kwargs.update({"state": self.state}) + return await super()._arun_component(**kwargs) + + # @property + # def input_keys(self) -> InputKeys: + # """Input keys.""" + # return InputKeys.from_keys( + # required_keys={"state", *self._req_params}, + # optional_keys=self._opt_params, + # ) + + # @property + # def output_keys(self) -> OutputKeys: + # """Output keys.""" + # # output can be anything, overrode validate function + # return OutputKeys.from_keys({self.output_key}) \ No newline at end of file diff --git a/llama-index-core/llama_index/core/query_pipeline/query.py b/llama-index-core/llama_index/core/query_pipeline/query.py index b28907f41454b..f502a80c36c60 100644 --- a/llama-index-core/llama_index/core/query_pipeline/query.py +++ b/llama-index-core/llama_index/core/query_pipeline/query.py @@ -32,6 +32,9 @@ ComponentIntermediates, ) from llama_index.core.utils import print_text +from llama_index.core.query_pipeline.components.stateful import BaseStatefulComponent +from llama_index.core.bridge.pydantic import root_validator, validator + # TODO: Make this (safely) pydantic? @@ -153,6 +156,32 @@ def clean_graph_attributes_copy(graph: networkx.MultiDiGraph) -> networkx.MultiD return graph_copy +def get_stateful_components(query_component: QueryComponent) -> List[BaseStatefulComponent]: + """Get stateful components.""" + stateful_components: List[BaseStatefulComponent] = [] + for c in query_component.sub_query_components: + if isinstance(c, BaseStatefulComponent): + stateful_components.append(cast(BaseStatefulComponent, c)) + + if len(c.sub_query_components) > 0: + stateful_components.extend(get_stateful_components(c)) + + return stateful_components + + +def update_stateful_components(stateful_components: List[BaseStatefulComponent], state: Dict[str, Any]) -> None: + """Update stateful components.""" + for stateful_component in stateful_components: + # stateful_component.partial(state=state) + stateful_component.state = state + + +def get_and_update_stateful_components(query_component: QueryComponent, state: Dict[str, Any]) -> List[BaseStatefulComponent]: + """Get and update stateful components.""" + stateful_components = get_stateful_components(query_component) + update_stateful_components(stateful_components, state) + return stateful_components + CHAIN_COMPONENT_TYPE = Union[QUERY_COMPONENT_TYPE, str] @@ -184,6 +213,9 @@ class QueryPipeline(QueryComponent): num_workers: int = Field( default=4, description="Number of workers to use (currently async only)." ) + state: Dict[str, Any] = Field( + default_factory=dict, description="State of the pipeline." + ) class Config: arbitrary_types_allowed = True @@ -194,14 +226,29 @@ def __init__( chain: Optional[Sequence[CHAIN_COMPONENT_TYPE]] = None, modules: Optional[Dict[str, QUERY_COMPONENT_TYPE]] = None, links: Optional[List[Link]] = None, + state: Optional[Dict[str, Any]] = None, **kwargs: Any, ): super().__init__( callback_manager=callback_manager or CallbackManager([]), + state=state or {}, **kwargs, ) self._init_graph(chain=chain, modules=modules, links=links) + # Pydantic validator isn't called for __init__ so we need to call it manually + get_and_update_stateful_components(self, state) + + def set_state(self, state: Dict[str, Any]) -> None: + """Update state.""" + self.state = state + get_and_update_stateful_components(self, state) + + + def reset_state(self) -> None: + """Reset state.""" + # use pydantic validator to update state + self.set_state({}) def _init_graph( self, @@ -243,6 +290,11 @@ def add_chain(self, chain: Sequence[CHAIN_COMPONENT_TYPE]) -> None: for i in range(len(chain) - 1): self.add_link(src=module_keys[i], dest=module_keys[i + 1]) + @property + def stateful_components(self) -> List[BaseStatefulComponent]: + """Get stateful component.""" + return get_stateful_components(self) + def add_links( self, links: List[Link], @@ -272,6 +324,8 @@ def add(self, module_key: str, module: QUERY_COMPONENT_TYPE) -> None: self.module_dict[module_key] = cast(QueryComponent, module) self.dag.add_node(module_key) + # TODO: there's more efficient ways to do this + get_and_update_stateful_components(self, self.state) def add_link( self, @@ -811,6 +865,7 @@ def get_next_module_keys(self, run_state: RunState) -> List[str]: if module_key in run_state.executed_modules: continue # Module already executed + if all( key in module_input for key in self.module_dict[module_key].free_req_input_keys diff --git a/llama-index-core/tests/query_pipeline/test_components.py b/llama-index-core/tests/query_pipeline/test_components.py index 47869b0742541..bca2fb200135a 100644 --- a/llama-index-core/tests/query_pipeline/test_components.py +++ b/llama-index-core/tests/query_pipeline/test_components.py @@ -176,28 +176,32 @@ def test_stateful_fn_pipeline() -> None: "m2": StatefulFnComponent(fn=stateful_foo_fn) } ) - p.add_link("m1", "m2", src_key="output", dest_key="a", input_fn=lambda x: x["new"]) - # p.add_link("m1", "m2", src_key="output", dest_key="state", input_fn=lambda x: x["state"]) - # output = p.run(a=1, b=2) - # assert output == 6 + p.add_link("m1", "m2", src_key="output", dest_key="a") + output = p.run(a=1, b=2) + assert output == 8 + p.reset_state() + output = p.run(a=1, b=2) + assert output == 8 # try one iteration + p.reset_state() loop_component = LoopComponent( pipeline=p, - should_exit_fn=lambda x: x > 10, + should_exit_fn=lambda x: x["output"] > 10, # add_output_to_input_fn=lambda cur_input, output: {"a": output}, max_iterations=1 ) output = loop_component.run_component(a=1, b=2) - assert output["output"] == 6 + assert output["output"] == 8 # try two iterations + p.reset_state() # loop 1: 0 + 1 + 2 = 3, 3 + 3 + 2 = 8 # loop 2: 8 + 8 + 2 = 18, 18 + 18 + 2 = 38 loop_component = LoopComponent( pipeline=p, - should_exit_fn=lambda x: x > 10, - add_output_to_input_fn=lambda cur_input, output: {"a": output}, + should_exit_fn=lambda x: x["output"] > 10, + add_output_to_input_fn=lambda cur_input, output: {"a": output["output"]}, max_iterations=5 ) assert loop_component.run_component(a=1, b=2)["output"] == 38