Skip to content

Commit

Permalink
cr
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryjliu committed Jun 19, 2024
1 parent 3ff0576 commit 6758300
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,8 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from llama_index.core.base.query_pipeline.query import (
InputKeys,
OutputKeys,
QueryComponent,
)
from llama_index.core.query_pipeline.query import QueryPipeline
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.core.callbacks.base import CallbackManager
from typing import Any, Dict, Optional, List, cast, Callable
from llama_index.core.query_pipeline.components.stateful import BaseStatefulComponent

def _get_stateful_components(query_component: QueryComponent) -> List[BaseStatefulComponent]:
"""Get stateful components."""
stateful_components: List[BaseStatefulComponent] = []
for c in query_component.sub_query_components:
if isinstance(c, BaseStatefulComponent):
stateful_components.append(cast(BaseStatefulComponent, c))

if len(c.sub_query_components) > 0:
stateful_components.extend(_get_stateful_components(c))

return stateful_components

class LoopComponent(QueryComponent):
"""Loop component.
"""

pipeline: QueryPipeline = Field(..., description="Query pipeline")
should_exit_fn: Optional[Callable] = Field(..., description="Should exit function")
add_output_to_input_fn: Optional[Callable] = Field(..., description="Add output to input function. If not provided, will reuse the original input for the next iteration. If provided, will call the function to combine the output into the input for the next iteration.")
max_iterations: Optional[int] = Field(5, description="Max iterations")

class Config:
arbitrary_types_allowed = True

def __init__(
self,
pipeline: QueryPipeline,
should_exit_fn: Optional[Callable] = None,
add_output_to_input_fn: Optional[Callable] = None,
max_iterations: Optional[int] = 5,
) -> None:
"""Init params."""
super().__init__(pipeline=pipeline, should_exit_fn=should_exit_fn, add_output_to_input_fn=add_output_to_input_fn, max_iterations=max_iterations)

def set_callback_manager(self, callback_manager: CallbackManager) -> None:
"""Set callback manager."""
# TODO: implement

def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
pass

@property
def stateful_components(self) -> List[BaseStatefulComponent]:
"""Get stateful component."""
# TODO: do this directly within the query pipeline
return _get_stateful_components(self.pipeline)

def _run_component(self, **kwargs: Any) -> Dict:
"""Run component."""
state = {}
# partial agent output component with state
for stateful_component in self.stateful_components:
stateful_component.partial(state=state)

current_input = kwargs
for i in range(self.max_iterations):
output = self.pipeline.run_component(**current_input)
if self.should_exit_fn:
should_exit = self.should_exit_fn(output)
if should_exit:
break

if self.add_output_to_input_fn:
current_input = self.add_output_to_input_fn(current_input, output)

return self.pipeline.run_component(**kwargs)

async def _arun_component(self, **kwargs: Any) -> Any:
"""Run component (async)."""
return await self.pipeline.arun_component(**kwargs)

@property
def input_keys(self) -> InputKeys:
"""Input keys."""
return self.pipeline.input_keys

@property
def output_keys(self) -> OutputKeys:
"""Output keys."""
return self.pipeline.output_keys
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Agent components."""

from inspect import signature
from typing import Any, Callable, Dict, Optional, Set, Tuple, cast

from llama_index.core.base.query_pipeline.query import (
InputKeys,
OutputKeys,
QueryComponent,
)
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.query_pipeline.components.function import FnComponent, get_parameters
# from llama_index.core.query_pipeline.components.input import InputComponent


class BaseStatefulComponent(QueryComponent):
"""Takes in agent inputs and transforms it into desired outputs."""



class StatefulFnComponent(BaseStatefulComponent, FnComponent):
"""Query component that takes in an arbitrary function.
Stateful version of `FnComponent`. Expects functions to have `state` as the first argument.
"""

def __init__(
self,
fn: Callable,
req_params: Optional[Set[str]] = None,
opt_params: Optional[Set[str]] = None,
**kwargs: Any
) -> None:
"""Init params."""

# determine parameters
default_req_params, default_opt_params = get_parameters(fn)
# make sure task and step are part of the list, and remove them from the list
if "state" not in default_req_params:
raise ValueError(
"StatefulFnComponent must have 'state' as required parameters"
)

default_req_params = default_req_params - {"state"}
default_opt_params = default_opt_params - {"state"}

super().__init__(fn=fn, req_params=req_params, opt_params=opt_params, **kwargs)

@property
def input_keys(self) -> InputKeys:
"""Input keys."""
return InputKeys.from_keys(
required_keys={"state", *self._req_params},
optional_keys=self._opt_params,
)

@property
def output_keys(self) -> OutputKeys:
"""Output keys."""
# output can be anything, overrode validate function
return OutputKeys.from_keys({self.output_key})
51 changes: 50 additions & 1 deletion llama-index-core/tests/query_pipeline/test_components.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Test components."""
from typing import Any, List, Sequence
from typing import Any, List, Sequence, Dict

import pytest
from llama_index.core.base.base_selector import (
Expand All @@ -15,6 +15,8 @@
)
from llama_index.core.query_pipeline.components.function import FnComponent
from llama_index.core.query_pipeline.components.input import InputComponent
from llama_index.core.query_pipeline.components.stateful import StatefulFnComponent
from llama_index.core.query_pipeline.components.loop import LoopComponent
from llama_index.core.query_pipeline.components.router import (
RouterComponent,
SelectorComponent,
Expand Down Expand Up @@ -155,3 +157,50 @@ def bar2_fn(a: Any) -> str:
selector_c = SelectorComponent(selector=selector)
output = selector_c.run_component(query="hello", choices=["t1", "t2"])
assert output["output"][0] == SingleSelection(index=1, reason="foo")


def stateful_foo_fn(state: Dict[str, Any], a: int, b: int = 2) -> Dict[str, Any]:
"""Foo function."""
old = state.get("prev", 0)
new = old + a + b
state["prev"] = new
return new


def test_stateful_fn_pipeline() -> None:
"""Test pipeline with function components."""
p = QueryPipeline()
p.add_modules(
{
"m1": StatefulFnComponent(fn=stateful_foo_fn),
"m2": StatefulFnComponent(fn=stateful_foo_fn)
}
)
p.add_link("m1", "m2", src_key="output", dest_key="a", input_fn=lambda x: x["new"])
# p.add_link("m1", "m2", src_key="output", dest_key="state", input_fn=lambda x: x["state"])
# output = p.run(a=1, b=2)
# assert output == 6

# try one iteration
loop_component = LoopComponent(
pipeline=p,
should_exit_fn=lambda x: x > 10,
# add_output_to_input_fn=lambda cur_input, output: {"a": output},
max_iterations=1
)
output = loop_component.run_component(a=1, b=2)
assert output["output"] == 6

# try two iterations
# loop 1: 0 + 1 + 2 = 3, 3 + 3 + 2 = 8
# loop 2: 8 + 8 + 2 = 18, 18 + 18 + 2 = 38
loop_component = LoopComponent(
pipeline=p,
should_exit_fn=lambda x: x > 10,
add_output_to_input_fn=lambda cur_input, output: {"a": output},
max_iterations=5
)
assert loop_component.run_component(a=1, b=2)["output"] == 38


# test loop component

0 comments on commit 6758300

Please sign in to comment.