Skip to content

Commit

Permalink
cr
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryjliu committed Jun 19, 2024
1 parent 555a8d3 commit 48116f0
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from llama_index.core.query_pipeline.components.tool_runner import ToolRunnerComponent
from llama_index.core.query_pipeline.components.stateful import StatefulFnComponent
from llama_index.core.query_pipeline.components.loop import LoopComponent

__all__ = [
"AgentFnComponent",
Expand All @@ -29,5 +30,6 @@
"RouterComponent",
"SelectorComponent",
"ToolRunnerComponent",
"StatefulFnComponent"
"StatefulFnComponent",
"LoopComponent"
]
45 changes: 23 additions & 22 deletions llama-index-core/llama_index/core/query_pipeline/components/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@
from typing import Any, Dict, Optional, List, cast, Callable
from llama_index.core.query_pipeline.components.stateful import BaseStatefulComponent

def _get_stateful_components(query_component: QueryComponent) -> List[BaseStatefulComponent]:
"""Get stateful components."""
stateful_components: List[BaseStatefulComponent] = []
for c in query_component.sub_query_components:
if isinstance(c, BaseStatefulComponent):
stateful_components.append(cast(BaseStatefulComponent, c))
# def _get_stateful_components(query_component: QueryComponent) -> List[BaseStatefulComponent]:
# """Get stateful components."""
# stateful_components: List[BaseStatefulComponent] = []
# for c in query_component.sub_query_components:
# if isinstance(c, BaseStatefulComponent):
# stateful_components.append(cast(BaseStatefulComponent, c))

if len(c.sub_query_components) > 0:
stateful_components.extend(_get_stateful_components(c))
# if len(c.sub_query_components) > 0:
# stateful_components.extend(_get_stateful_components(c))

return stateful_components
# return stateful_components

class LoopComponent(QueryComponent):
"""Loop component.
Expand Down Expand Up @@ -49,20 +49,10 @@ def set_callback_manager(self, callback_manager: CallbackManager) -> None:
# TODO: implement

def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
pass

@property
def stateful_components(self) -> List[BaseStatefulComponent]:
"""Get stateful component."""
# TODO: do this directly within the query pipeline
return _get_stateful_components(self.pipeline)
return input

def _run_component(self, **kwargs: Any) -> Dict:
"""Run component."""
state = {}
# partial agent output component with state
for stateful_component in self.stateful_components:
stateful_component.partial(state=state)

current_input = kwargs
for i in range(self.max_iterations):
Expand All @@ -75,11 +65,22 @@ def _run_component(self, **kwargs: Any) -> Dict:
if self.add_output_to_input_fn:
current_input = self.add_output_to_input_fn(current_input, output)

return self.pipeline.run_component(**kwargs)
return output

async def _arun_component(self, **kwargs: Any) -> Any:
"""Run component (async)."""
return await self.pipeline.arun_component(**kwargs)
current_input = kwargs
for i in range(self.max_iterations):
output = await self.pipeline.arun_component(**current_input)
if self.should_exit_fn:
should_exit = self.should_exit_fn(output)
if should_exit:
break

if self.add_output_to_input_fn:
current_input = self.add_output_to_input_fn(current_input, output)

return output

@property
def input_keys(self) -> InputKeys:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,26 @@
class BaseStatefulComponent(QueryComponent):
    """Base class for query components that carry a ``state`` dict.

    Subclasses read and write ``self.state`` while running; ``reset_state``
    swaps it for a new empty dict.
    """

    # Defaults to a new empty dict per instance (via ``default_factory``).
    state: Dict[str, Any] = Field(
        default_factory=dict,
        description="State of the pipeline.",
    )

    def reset_state(self) -> None:
        """Replace the current state with a new empty dict."""
        self.state = {}

class StatefulFnComponent(BaseStatefulComponent, FnComponent):
"""Query component that takes in an arbitrary function.
Stateful version of `FnComponent`. Expects functions to have `state` as the first argument.
"""

def __init__(
self,
fn: Callable,
req_params: Optional[Set[str]] = None,
opt_params: Optional[Set[str]] = None,
state: Optional[Dict[str, Any]] = None,
**kwargs: Any
) -> None:
"""Init params."""
Expand All @@ -45,19 +51,34 @@ def __init__(

default_req_params = default_req_params - {"state"}
default_opt_params = default_opt_params - {"state"}

if req_params is None:
req_params = default_req_params
if opt_params is None:
opt_params = default_opt_params

super().__init__(fn=fn, req_params=req_params, opt_params=opt_params, **kwargs)

@property
def input_keys(self) -> InputKeys:
"""Input keys."""
return InputKeys.from_keys(
required_keys={"state", *self._req_params},
optional_keys=self._opt_params,
)

@property
def output_keys(self) -> OutputKeys:
"""Output keys."""
# output can be anything, overrode validate function
return OutputKeys.from_keys({self.output_key})
super().__init__(fn=fn, req_params=req_params, opt_params=opt_params, state=state or {}, **kwargs)

def _run_component(self, **kwargs: Any) -> Dict:
    """Run the component with this instance's state injected.

    Any ``state`` entry already present in ``kwargs`` is overwritten with
    ``self.state`` before delegating to the parent implementation.
    """
    kwargs["state"] = self.state
    return super()._run_component(**kwargs)

async def _arun_component(self, **kwargs: Any) -> Any:
    """Async counterpart of ``_run_component``.

    Any ``state`` entry already present in ``kwargs`` is overwritten with
    ``self.state`` before awaiting the parent implementation.
    """
    kwargs["state"] = self.state
    return await super()._arun_component(**kwargs)

# @property
# def input_keys(self) -> InputKeys:
# """Input keys."""
# return InputKeys.from_keys(
# required_keys={"state", *self._req_params},
# optional_keys=self._opt_params,
# )

# @property
# def output_keys(self) -> OutputKeys:
# """Output keys."""
# # output can be anything, overrode validate function
# return OutputKeys.from_keys({self.output_key})
55 changes: 55 additions & 0 deletions llama-index-core/llama_index/core/query_pipeline/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
ComponentIntermediates,
)
from llama_index.core.utils import print_text
from llama_index.core.query_pipeline.components.stateful import BaseStatefulComponent
from llama_index.core.bridge.pydantic import root_validator, validator



# TODO: Make this (safely) pydantic?
Expand Down Expand Up @@ -153,6 +156,32 @@ def clean_graph_attributes_copy(graph: networkx.MultiDiGraph) -> networkx.MultiD

return graph_copy

def get_stateful_components(query_component: QueryComponent) -> List[BaseStatefulComponent]:
    """Collect the stateful components nested under ``query_component``.

    Walks ``sub_query_components`` depth-first. Each child that is a
    ``BaseStatefulComponent`` is recorded, followed by the stateful
    components found beneath that child. ``query_component`` itself is not
    included in the result.
    """
    found: List[BaseStatefulComponent] = []
    for child in query_component.sub_query_components:
        if isinstance(child, BaseStatefulComponent):
            found.append(cast(BaseStatefulComponent, child))

        # Only descend when the child actually has sub-components.
        if not child.sub_query_components:
            continue
        found.extend(get_stateful_components(child))

    return found


def update_stateful_components(stateful_components: List[BaseStatefulComponent], state: Dict[str, Any]) -> None:
    """Assign ``state`` to the ``state`` attribute of every given component.

    The ``state`` object is passed to each component as-is, so all of them
    are handed the same dict rather than separate copies.
    """
    for component in stateful_components:
        component.state = state


def get_and_update_stateful_components(query_component: QueryComponent, state: Dict[str, Any]) -> List[BaseStatefulComponent]:
    """Find the stateful components under ``query_component`` and set their state.

    Returns:
        The stateful components that were found, each with ``state`` assigned.
    """
    components = get_stateful_components(query_component)
    update_stateful_components(components, state)
    return components


CHAIN_COMPONENT_TYPE = Union[QUERY_COMPONENT_TYPE, str]

Expand Down Expand Up @@ -184,6 +213,9 @@ class QueryPipeline(QueryComponent):
num_workers: int = Field(
default=4, description="Number of workers to use (currently async only)."
)
state: Dict[str, Any] = Field(
default_factory=dict, description="State of the pipeline."
)

class Config:
arbitrary_types_allowed = True
Expand All @@ -194,14 +226,29 @@ def __init__(
chain: Optional[Sequence[CHAIN_COMPONENT_TYPE]] = None,
modules: Optional[Dict[str, QUERY_COMPONENT_TYPE]] = None,
links: Optional[List[Link]] = None,
state: Optional[Dict[str, Any]] = None,
**kwargs: Any,
):
super().__init__(
callback_manager=callback_manager or CallbackManager([]),
state=state or {},
**kwargs,
)

self._init_graph(chain=chain, modules=modules, links=links)
# Pydantic validator isn't called for __init__ so we need to call it manually
get_and_update_stateful_components(self, state)

def set_state(self, state: Dict[str, Any]) -> None:
    """Replace the pipeline state and hand it to nested stateful components.

    Stores ``state`` on the pipeline, then assigns the same object to every
    stateful component found under this pipeline.
    """
    self.state = state
    get_and_update_stateful_components(self, state)


def reset_state(self) -> None:
    """Reset the pipeline state to a new empty dict."""
    # Go through set_state rather than assigning self.state directly so the
    # reset is applied the same way as any other state update.
    self.set_state({})

def _init_graph(
self,
Expand Down Expand Up @@ -243,6 +290,11 @@ def add_chain(self, chain: Sequence[CHAIN_COMPONENT_TYPE]) -> None:
for i in range(len(chain) - 1):
self.add_link(src=module_keys[i], dest=module_keys[i + 1])

@property
def stateful_components(self) -> List[BaseStatefulComponent]:
    """Stateful components nested under this pipeline.

    The list is recomputed on every access by walking the pipeline's
    sub-components.
    """
    return get_stateful_components(self)

def add_links(
self,
links: List[Link],
Expand Down Expand Up @@ -272,6 +324,8 @@ def add(self, module_key: str, module: QUERY_COMPONENT_TYPE) -> None:

self.module_dict[module_key] = cast(QueryComponent, module)
self.dag.add_node(module_key)
# TODO: there's more efficient ways to do this
get_and_update_stateful_components(self, self.state)

def add_link(
self,
Expand Down Expand Up @@ -811,6 +865,7 @@ def get_next_module_keys(self, run_state: RunState) -> List[str]:
if module_key in run_state.executed_modules:
continue # Module already executed


if all(
key in module_input
for key in self.module_dict[module_key].free_req_input_keys
Expand Down
20 changes: 12 additions & 8 deletions llama-index-core/tests/query_pipeline/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,28 +176,32 @@ def test_stateful_fn_pipeline() -> None:
"m2": StatefulFnComponent(fn=stateful_foo_fn)
}
)
p.add_link("m1", "m2", src_key="output", dest_key="a", input_fn=lambda x: x["new"])
# p.add_link("m1", "m2", src_key="output", dest_key="state", input_fn=lambda x: x["state"])
# output = p.run(a=1, b=2)
# assert output == 6
p.add_link("m1", "m2", src_key="output", dest_key="a")
output = p.run(a=1, b=2)
assert output == 8
p.reset_state()
output = p.run(a=1, b=2)
assert output == 8

# try one iteration
p.reset_state()
loop_component = LoopComponent(
pipeline=p,
should_exit_fn=lambda x: x > 10,
should_exit_fn=lambda x: x["output"] > 10,
# add_output_to_input_fn=lambda cur_input, output: {"a": output},
max_iterations=1
)
output = loop_component.run_component(a=1, b=2)
assert output["output"] == 6
assert output["output"] == 8

# try two iterations
p.reset_state()
# loop 1: 0 + 1 + 2 = 3, 3 + 3 + 2 = 8
# loop 2: 8 + 8 + 2 = 18, 18 + 18 + 2 = 38
loop_component = LoopComponent(
pipeline=p,
should_exit_fn=lambda x: x > 10,
add_output_to_input_fn=lambda cur_input, output: {"a": output},
should_exit_fn=lambda x: x["output"] > 10,
add_output_to_input_fn=lambda cur_input, output: {"a": output["output"]},
max_iterations=5
)
assert loop_component.run_component(a=1, b=2)["output"] == 38
Expand Down

0 comments on commit 48116f0

Please sign in to comment.