Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stateful and loop components #14235

Merged
merged 6 commits into from
Jun 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 130 additions & 171 deletions docs/docs/examples/agent/agent_runner/query_pipeline_agent.ipynb

Large diffs are not rendered by default.

85 changes: 69 additions & 16 deletions llama-index-core/llama_index/core/agent/custom/pipeline_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,15 @@ class QueryPipelineAgentWorker(BaseModel, BaseAgentWorker):

Barebones agent worker that takes in a query pipeline.

Assumes that the first component in the query pipeline is an
`AgentInputComponent` and last is `AgentFnComponent`.
**Default Workflow**: The default workflow assumes that you compose
a query pipeline with `StatefulFnComponent` objects. This allows you to store, update
and retrieve state throughout the executions of the query pipeline by the agent.

The task and step state of the agent are stored in this `state` variable via a special key.
Of course you can choose to store other variables in this state as well.

**Deprecated Workflow**: The deprecated workflow assumes that the first component in the
query pipeline is an `AgentInputComponent` and last is `AgentFnComponent`.

Args:
pipeline (QueryPipeline): Query pipeline
Expand All @@ -63,6 +70,8 @@ class QueryPipelineAgentWorker(BaseModel, BaseAgentWorker):

pipeline: QueryPipeline = Field(..., description="Query pipeline")
callback_manager: CallbackManager = Field(..., exclude=True)
task_key: str = Field("task", description="Key to store task in state")
step_state_key: str = Field("step_state", description="Key to store step in state")

class Config:
arbitrary_types_allowed = True
Expand All @@ -71,6 +80,7 @@ def __init__(
self,
pipeline: QueryPipeline,
callback_manager: Optional[CallbackManager] = None,
**kwargs: Any,
) -> None:
"""Initialize."""
if callback_manager is not None:
Expand All @@ -81,14 +91,19 @@ def __init__(
super().__init__(
pipeline=pipeline,
callback_manager=callback_manager,
**kwargs,
)
# validate query pipeline
self.agent_input_component
# self.agent_input_component
self.agent_components

@property
def agent_input_component(self) -> AgentInputComponent:
"""Get agent input component."""
"""Get agent input component.

NOTE: This is deprecated and will be removed in the future.

"""
root_key = self.pipeline.get_root_keys()[0]
if not isinstance(self.pipeline.module_dict[root_key], AgentInputComponent):
raise ValueError(
Expand All @@ -103,6 +118,26 @@ def agent_components(self) -> List[AgentFnComponent]:
"""Get agent output component."""
return _get_agent_components(self.pipeline)

def preprocess(self, task: Task, step: TaskStep) -> None:
"""Preprocessing flow.

This runs preprocessing to propagate the task and step as variables
to relevant components in the query pipeline.

Contains deprecated flow of updating agent components.
But also contains main flow of updating StatefulFnComponent components.

"""
# NOTE: this is deprecated
# partial agent output component with task and step
for agent_fn_component in self.agent_components:
agent_fn_component.partial(task=task, state=step.step_state)

# update stateful components
self.pipeline.update_state(
{self.task_key: task, self.step_state_key: step.step_state}
)

def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
"""Initialize step from task."""
sources: List[ToolOutput] = []
Expand Down Expand Up @@ -147,11 +182,21 @@ def _get_task_step_response(
@trace_method("run_step")
def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
"""Run step."""
# partial agent output component with task and step
for agent_fn_component in self.agent_components:
agent_fn_component.partial(task=task, state=step.step_state)

agent_response, is_done = self.pipeline.run(state=step.step_state, task=task)
self.preprocess(task, step)

# HACK: do a try/except for now. Fine since old agent components are deprecated
try:
self.agent_input_component
uses_deprecated = True
except ValueError:
uses_deprecated = False

if uses_deprecated:
agent_response, is_done = self.pipeline.run(
state=step.step_state, task=task
)
else:
agent_response, is_done = self.pipeline.run()
response = self._get_task_step_response(agent_response, step, is_done)
# sync step state with task state
task.extra_state.update(step.step_state)
Expand All @@ -162,13 +207,21 @@ async def arun_step(
self, step: TaskStep, task: Task, **kwargs: Any
) -> TaskStepOutput:
"""Run step (async)."""
# partial agent output component with task and step
for agent_fn_component in self.agent_components:
agent_fn_component.partial(task=task, state=step.step_state)

agent_response, is_done = await self.pipeline.arun(
state=step.step_state, task=task
)
self.preprocess(task, step)

# HACK: do a try/except for now. Fine since old agent components are deprecated
try:
self.agent_input_component
uses_deprecated = True
except ValueError:
uses_deprecated = False

if uses_deprecated:
agent_response, is_done = await self.pipeline.arun(
state=step.step_state, task=task
)
else:
agent_response, is_done = await self.pipeline.arun()
response = self._get_task_step_response(agent_response, step, is_done)
task.extra_state.update(step.step_state)
return response
Expand Down
5 changes: 5 additions & 0 deletions llama-index-core/llama_index/core/query_pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
ChainableMixin,
QueryComponent,
)
from llama_index.core.query_pipeline.components.stateful import StatefulFnComponent
from llama_index.core.query_pipeline.components.loop import LoopComponent

from llama_index.core.base.query_pipeline.query import (
CustomQueryComponent,
)
Expand All @@ -40,4 +43,6 @@
"ChainableMixin",
"QueryComponent",
"CustomQueryComponent",
"StatefulFnComponent",
"LoopComponent",
]
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
SelectorComponent,
)
from llama_index.core.query_pipeline.components.tool_runner import ToolRunnerComponent
from llama_index.core.query_pipeline.components.stateful import StatefulFnComponent
from llama_index.core.query_pipeline.components.loop import LoopComponent

__all__ = [
"AgentFnComponent",
Expand All @@ -28,4 +30,6 @@
"RouterComponent",
"SelectorComponent",
"ToolRunnerComponent",
"StatefulFnComponent",
"LoopComponent",
]
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ def default_agent_input_fn(task: Any, state: dict) -> dict:


class AgentInputComponent(QueryComponent):
"""Takes in agent inputs and transforms it into desired outputs."""
"""Takes in agent inputs and transforms it into desired outputs.

NOTE: this is now deprecated in favor of using `StatefulFnComponent`.

"""

fn: Callable = Field(..., description="Function to run.")
async_fn: Optional[Callable] = Field(
Expand Down Expand Up @@ -149,6 +153,8 @@ class AgentFnComponent(BaseAgentComponent):

Designed to let users easily modify state.

NOTE: this is now deprecated in favor of using `StatefulFnComponent`.

"""

fn: Callable = Field(..., description="Function to run.")
Expand Down Expand Up @@ -257,6 +263,8 @@ class CustomAgentComponent(BaseAgentComponent):

Designed to let users easily modify state.

NOTE: this is now deprecated in favor of using `StatefulFnComponent`.

"""

callback_manager: CallbackManager = Field(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from llama_index.core.base.query_pipeline.query import (
InputKeys,
OutputKeys,
QueryComponent,
)
from llama_index.core.query_pipeline.query import QueryPipeline
from llama_index.core.bridge.pydantic import Field
from llama_index.core.callbacks.base import CallbackManager
from typing import Any, Dict, Optional, Callable


class LoopComponent(QueryComponent):
"""Loop component."""

pipeline: QueryPipeline = Field(..., description="Query pipeline")
should_exit_fn: Optional[Callable] = Field(..., description="Should exit function")
add_output_to_input_fn: Optional[Callable] = Field(
...,
description="Add output to input function. If not provided, will reuse the original input for the next iteration. If provided, will call the function to combine the output into the input for the next iteration.",
)
max_iterations: Optional[int] = Field(5, description="Max iterations")

class Config:
arbitrary_types_allowed = True

def __init__(
self,
pipeline: QueryPipeline,
should_exit_fn: Optional[Callable] = None,
add_output_to_input_fn: Optional[Callable] = None,
max_iterations: Optional[int] = 5,
) -> None:
"""Init params."""
super().__init__(
pipeline=pipeline,
should_exit_fn=should_exit_fn,
add_output_to_input_fn=add_output_to_input_fn,
max_iterations=max_iterations,
)

def set_callback_manager(self, callback_manager: CallbackManager) -> None:
"""Set callback manager."""
# TODO: implement

def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
return input

def _run_component(self, **kwargs: Any) -> Dict:
"""Run component."""
current_input = kwargs
for i in range(self.max_iterations):
output = self.pipeline.run_component(**current_input)
if self.should_exit_fn:
should_exit = self.should_exit_fn(output)
if should_exit:
break

if self.add_output_to_input_fn:
current_input = self.add_output_to_input_fn(current_input, output)

return output

async def _arun_component(self, **kwargs: Any) -> Any:
"""Run component (async)."""
current_input = kwargs
for i in range(self.max_iterations):
output = await self.pipeline.arun_component(**current_input)
if self.should_exit_fn:
should_exit = self.should_exit_fn(output)
if should_exit:
break

if self.add_output_to_input_fn:
current_input = self.add_output_to_input_fn(current_input, output)

return output

@property
def input_keys(self) -> InputKeys:
"""Input keys."""
return self.pipeline.input_keys

@property
def output_keys(self) -> OutputKeys:
"""Output keys."""
return self.pipeline.output_keys
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Agent components."""

from typing import Any, Callable, Dict, Optional, Set

from llama_index.core.base.query_pipeline.query import (
QueryComponent,
)
from llama_index.core.bridge.pydantic import Field
from llama_index.core.query_pipeline.components.function import (
FnComponent,
get_parameters,
)

# from llama_index.core.query_pipeline.components.input import InputComponent


class BaseStatefulComponent(QueryComponent):
"""Takes in agent inputs and transforms it into desired outputs."""

state: Dict[str, Any] = Field(
default_factory=dict, description="State of the pipeline."
)

def reset_state(self) -> None:
"""Reset state."""
self.state = {}


class StatefulFnComponent(BaseStatefulComponent, FnComponent):
"""Query component that takes in an arbitrary function.

Stateful version of `FnComponent`. Expects functions to have `state` as the first argument.

"""

def __init__(
self,
fn: Callable,
req_params: Optional[Set[str]] = None,
opt_params: Optional[Set[str]] = None,
state: Optional[Dict[str, Any]] = None,
**kwargs: Any
) -> None:
"""Init params."""
# determine parameters
default_req_params, default_opt_params = get_parameters(fn)
# make sure task and step are part of the list, and remove them from the list
if "state" not in default_req_params:
raise ValueError(
"StatefulFnComponent must have 'state' as required parameters"
)

default_req_params = default_req_params - {"state"}
default_opt_params = default_opt_params - {"state"}

if req_params is None:
req_params = default_req_params
if opt_params is None:
opt_params = default_opt_params

super().__init__(
fn=fn,
req_params=req_params,
opt_params=opt_params,
state=state or {},
**kwargs
)

def _run_component(self, **kwargs: Any) -> Dict:
"""Run component."""
kwargs.update({"state": self.state})
return super()._run_component(**kwargs)

async def _arun_component(self, **kwargs: Any) -> Any:
"""Async run component."""
kwargs.update({"state": self.state})
return await super()._arun_component(**kwargs)

# @property
# def input_keys(self) -> InputKeys:
# """Input keys."""
# return InputKeys.from_keys(
# required_keys={"state", *self._req_params},
# optional_keys=self._opt_params,
# )

# @property
# def output_keys(self) -> OutputKeys:
# """Output keys."""
# # output can be anything, overrode validate function
# return OutputKeys.from_keys({self.output_key})
Loading
Loading