Skip to content

Commit

Permalink
wip: refactor query pipeline agent to use stateful function components (
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryjliu authored Jun 29, 2024
1 parent 30e1118 commit 70a7adf
Show file tree
Hide file tree
Showing 7 changed files with 392 additions and 190 deletions.
301 changes: 130 additions & 171 deletions docs/docs/examples/agent/agent_runner/query_pipeline_agent.ipynb

Large diffs are not rendered by default.

85 changes: 69 additions & 16 deletions llama-index-core/llama_index/core/agent/custom/pipeline_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,15 @@ class QueryPipelineAgentWorker(BaseModel, BaseAgentWorker):
Barebones agent worker that takes in a query pipeline.
Assumes that the first component in the query pipeline is an
`AgentInputComponent` and last is `AgentFnComponent`.
**Default Workflow**: The default workflow assumes that you compose
a query pipeline with `StatefulFnComponent` objects. This allows you to store, update
and retrieve state throughout the executions of the query pipeline by the agent.
The task and step state of the agent are stored in this `state` variable via a special key.
Of course you can choose to store other variables in this state as well.
**Deprecated Workflow**: The deprecated workflow assumes that the first component in the
query pipeline is an `AgentInputComponent` and last is `AgentFnComponent`.
Args:
pipeline (QueryPipeline): Query pipeline
Expand All @@ -63,6 +70,8 @@ class QueryPipelineAgentWorker(BaseModel, BaseAgentWorker):

pipeline: QueryPipeline = Field(..., description="Query pipeline")
callback_manager: CallbackManager = Field(..., exclude=True)
task_key: str = Field("task", description="Key to store task in state")
step_state_key: str = Field("step_state", description="Key to store step in state")

class Config:
arbitrary_types_allowed = True
Expand All @@ -71,6 +80,7 @@ def __init__(
self,
pipeline: QueryPipeline,
callback_manager: Optional[CallbackManager] = None,
**kwargs: Any,
) -> None:
"""Initialize."""
if callback_manager is not None:
Expand All @@ -81,14 +91,19 @@ def __init__(
super().__init__(
pipeline=pipeline,
callback_manager=callback_manager,
**kwargs,
)
# validate query pipeline
self.agent_input_component
# self.agent_input_component
self.agent_components

@property
def agent_input_component(self) -> AgentInputComponent:
"""Get agent input component."""
"""Get agent input component.
NOTE: This is deprecated and will be removed in the future.
"""
root_key = self.pipeline.get_root_keys()[0]
if not isinstance(self.pipeline.module_dict[root_key], AgentInputComponent):
raise ValueError(
Expand All @@ -103,6 +118,26 @@ def agent_components(self) -> List[AgentFnComponent]:
"""Get agent output component."""
return _get_agent_components(self.pipeline)

def preprocess(self, task: Task, step: TaskStep) -> None:
"""Preprocessing flow.
This runs preprocessing to propagate the task and step as variables
to relevant components in the query pipeline.
Contains deprecated flow of updating agent components.
But also contains main flow of updating StatefulFnComponent components.
"""
# NOTE: this is deprecated
# partial agent output component with task and step
for agent_fn_component in self.agent_components:
agent_fn_component.partial(task=task, state=step.step_state)

# update stateful components
self.pipeline.update_state(
{self.task_key: task, self.step_state_key: step.step_state}
)

def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
"""Initialize step from task."""
sources: List[ToolOutput] = []
Expand Down Expand Up @@ -147,11 +182,21 @@ def _get_task_step_response(
@trace_method("run_step")
def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
"""Run step."""
# partial agent output component with task and step
for agent_fn_component in self.agent_components:
agent_fn_component.partial(task=task, state=step.step_state)

agent_response, is_done = self.pipeline.run(state=step.step_state, task=task)
self.preprocess(task, step)

# HACK: do a try/except for now. Fine since old agent components are deprecated
try:
self.agent_input_component
uses_deprecated = True
except ValueError:
uses_deprecated = False

if uses_deprecated:
agent_response, is_done = self.pipeline.run(
state=step.step_state, task=task
)
else:
agent_response, is_done = self.pipeline.run()
response = self._get_task_step_response(agent_response, step, is_done)
# sync step state with task state
task.extra_state.update(step.step_state)
Expand All @@ -162,13 +207,21 @@ async def arun_step(
self, step: TaskStep, task: Task, **kwargs: Any
) -> TaskStepOutput:
"""Run step (async)."""
# partial agent output component with task and step
for agent_fn_component in self.agent_components:
agent_fn_component.partial(task=task, state=step.step_state)

agent_response, is_done = await self.pipeline.arun(
state=step.step_state, task=task
)
self.preprocess(task, step)

# HACK: do a try/except for now. Fine since old agent components are deprecated
try:
self.agent_input_component
uses_deprecated = True
except ValueError:
uses_deprecated = False

if uses_deprecated:
agent_response, is_done = await self.pipeline.arun(
state=step.step_state, task=task
)
else:
agent_response, is_done = await self.pipeline.arun()
response = self._get_task_step_response(agent_response, step, is_done)
task.extra_state.update(step.step_state)
return response
Expand Down
5 changes: 5 additions & 0 deletions llama-index-core/llama_index/core/query_pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
ChainableMixin,
QueryComponent,
)
from llama_index.core.query_pipeline.components.stateful import StatefulFnComponent
from llama_index.core.query_pipeline.components.loop import LoopComponent

from llama_index.core.base.query_pipeline.query import (
CustomQueryComponent,
)
Expand All @@ -40,4 +43,6 @@
"ChainableMixin",
"QueryComponent",
"CustomQueryComponent",
"StatefulFnComponent",
"LoopComponent",
]
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ def default_agent_input_fn(task: Any, state: dict) -> dict:


class AgentInputComponent(QueryComponent):
"""Takes in agent inputs and transforms it into desired outputs."""
"""Takes in agent inputs and transforms it into desired outputs.
NOTE: this is now deprecated in favor of using `StatefulFnComponent`.
"""

fn: Callable = Field(..., description="Function to run.")
async_fn: Optional[Callable] = Field(
Expand Down Expand Up @@ -149,6 +153,8 @@ class AgentFnComponent(BaseAgentComponent):
Designed to let users easily modify state.
NOTE: this is now deprecated in favor of using `StatefulFnComponent`.
"""

fn: Callable = Field(..., description="Function to run.")
Expand Down Expand Up @@ -257,6 +263,8 @@ class CustomAgentComponent(BaseAgentComponent):
Designed to let users easily modify state.
NOTE: this is now deprecated in favor of using `StatefulFnComponent`.
"""

callback_manager: CallbackManager = Field(
Expand Down
14 changes: 12 additions & 2 deletions llama-index-core/llama_index/core/query_pipeline/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,11 @@ def update_stateful_components(
def get_and_update_stateful_components(
query_component: QueryComponent, state: Dict[str, Any]
) -> List[BaseStatefulComponent]:
"""Get and update stateful components."""
"""Get and update stateful components.
Assign all stateful components in the query component with the state.
"""
stateful_components = get_stateful_components(query_component)
update_stateful_components(stateful_components, state)
return stateful_components
Expand Down Expand Up @@ -245,10 +249,15 @@ def __init__(
get_and_update_stateful_components(self, state)

def set_state(self, state: Dict[str, Any]) -> None:
"""Update state."""
"""Set state."""
self.state = state
get_and_update_stateful_components(self, state)

def update_state(self, state: Dict[str, Any]) -> None:
"""Update state."""
self.state.update(state)
get_and_update_stateful_components(self, state)

def reset_state(self) -> None:
"""Reset state."""
# use pydantic validator to update state
Expand Down Expand Up @@ -328,6 +337,7 @@ def add(self, module_key: str, module: QUERY_COMPONENT_TYPE) -> None:

self.module_dict[module_key] = cast(QueryComponent, module)
self.dag.add_node(module_key)
# propagate state to new modules added
# TODO: there's more efficient ways to do this
get_and_update_stateful_components(self, self.state)

Expand Down
3 changes: 3 additions & 0 deletions llama-index-core/tests/agent/custom/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
python_tests(
name="tests",
)
164 changes: 164 additions & 0 deletions llama-index-core/tests/agent/custom/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""Test query pipeline worker."""

from typing import Any, Dict, Set, Tuple

from llama_index.core.agent.custom.pipeline_worker import (
QueryPipelineAgentWorker,
)
from llama_index.core.agent.runner.base import AgentRunner
from llama_index.core.agent.types import Task
from llama_index.core.bridge.pydantic import Field
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.query_pipeline import FnComponent, QueryPipeline
from llama_index.core.query_pipeline.components.agent import (
AgentFnComponent,
AgentInputComponent,
CustomAgentComponent,
)
from llama_index.core.query_pipeline.components.stateful import StatefulFnComponent


def mock_fn(a: str) -> str:
"""Mock function."""
return a + "3"


def mock_agent_input_fn(task: Task, state: dict) -> dict:
"""Mock agent input function."""
if "count" not in state:
state["count"] = 0
state["max_count"] = 2
state["input"] = task.input
return {"a": state["input"]}


def mock_agent_output_fn(
task: Task, state: dict, output: str
) -> Tuple[AgentChatResponse, bool]:
state["count"] += 1
state["input"] = output
is_done = state["count"] >= state["max_count"]
return AgentChatResponse(response=str(output)), is_done


def mock_agent_input_fn_stateful(state: Dict[str, Any]) -> str:
"""Mock agent input function (for StatefulFnComponent)."""
d = mock_agent_input_fn(state["task"], state["step_state"])
return d["a"]


def mock_agent_output_fn_stateful(
state: Dict[str, Any], output: str
) -> Tuple[AgentChatResponse, bool]:
"""Mock agent output function (for StatefulFnComponent)."""
return mock_agent_output_fn(state["task"], state["step_state"], output)


def mock_agent_output_fn(
task: Task, state: dict, output: str
) -> Tuple[AgentChatResponse, bool]:
state["count"] += 1
state["input"] = output
is_done = state["count"] >= state["max_count"]
return AgentChatResponse(response=str(output)), is_done


def test_qp_agent_fn() -> None:
"""Test query pipeline agent.
Implement via function components.
"""
agent_input = AgentInputComponent(fn=mock_agent_input_fn)
fn_component = FnComponent(fn=mock_fn)
agent_output = AgentFnComponent(fn=mock_agent_output_fn)
qp = QueryPipeline(chain=[agent_input, fn_component, agent_output])

agent_worker = QueryPipelineAgentWorker(pipeline=qp)
agent_runner = AgentRunner(agent_worker=agent_worker)

# test create_task
task = agent_runner.create_task("foo")
assert task.input == "foo"

step_output = agent_runner.run_step(task.task_id)
assert str(step_output.output) == "foo3"
assert step_output.is_last is False

step_output = agent_runner.run_step(task.task_id)
assert str(step_output.output) == "foo33"
assert step_output.is_last is True


class MyCustomAgentComponent(CustomAgentComponent):
"""Custom agent component."""

separator: str = Field(default=":", description="Separator")

def _run_component(self, **kwargs: Any) -> Dict[str, Any]:
"""Run component."""
return {"output": kwargs["a"] + self.separator + kwargs["a"]}

@property
def _input_keys(self) -> Set[str]:
"""Input keys."""
return {"a"}

@property
def _output_keys(self) -> Set[str]:
"""Output keys."""
return {"output"}


def test_qp_agent_custom() -> None:
"""Test query pipeline agent.
Implement via `AgentCustomQueryComponent` subclass.
"""
agent_input = AgentInputComponent(fn=mock_agent_input_fn)
fn_component = MyCustomAgentComponent(separator="/")
agent_output = AgentFnComponent(fn=mock_agent_output_fn)
qp = QueryPipeline(chain=[agent_input, fn_component, agent_output])

agent_worker = QueryPipelineAgentWorker(pipeline=qp)
agent_runner = AgentRunner(agent_worker=agent_worker)

# test create_task
task = agent_runner.create_task("foo")
assert task.input == "foo"

step_output = agent_runner.run_step(task.task_id)
assert str(step_output.output) == "foo/foo"
assert step_output.is_last is False

step_output = agent_runner.run_step(task.task_id)
assert str(step_output.output) == "foo/foo/foo/foo"
assert step_output.is_last is True


def test_qp_agent_stateful_fn() -> None:
"""Test query pipeline agent with stateful components.
The old flows of using `AgentInputComponent` and `AgentFnComponent` are deprecated.
"""
agent_input = StatefulFnComponent(fn=mock_agent_input_fn_stateful)
fn_component = FnComponent(fn=mock_fn)
agent_output = StatefulFnComponent(fn=mock_agent_output_fn_stateful)
qp = QueryPipeline(chain=[agent_input, fn_component, agent_output])

agent_worker = QueryPipelineAgentWorker(pipeline=qp)
agent_runner = AgentRunner(agent_worker=agent_worker)

# test create_task
task = agent_runner.create_task("foo")
assert task.input == "foo"

step_output = agent_runner.run_step(task.task_id)
assert str(step_output.output) == "foo3"
assert step_output.is_last is False

step_output = agent_runner.run_step(task.task_id)
assert str(step_output.output) == "foo33"
assert step_output.is_last is True

0 comments on commit 70a7adf

Please sign in to comment.