Skip to content

Commit

Permalink
agent + query pipeline cleanups (#10563)
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryjliu authored Feb 11, 2024
1 parent e3a169b commit 5981dd3
Show file tree
Hide file tree
Showing 8 changed files with 327 additions and 243 deletions.
500 changes: 259 additions & 241 deletions docs/examples/agent/agent_runner/query_pipeline_agent.ipynb

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions llama_index/agent/custom/pipeline_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,9 @@ def finalize_task(self, task: Task, **kwargs: Any) -> None:
task.memory.set(task.memory.get() + task.extra_state["memory"].get_all())
# reset new memory
task.extra_state["memory"].reset()

def set_callback_manager(self, callback_manager: CallbackManager) -> None:
    """Attach a callback manager to this worker and its query pipeline.

    The manager is stored on the worker and forwarded to the underlying
    pipeline so that pipeline-level events share the same tracing context.
    """
    # TODO: promote to an abstractmethod once every agent implementation
    # provides an override (doing so now would break some of them).
    self.callback_manager = callback_manager
    self.pipeline.set_callback_manager(callback_manager)
5 changes: 5 additions & 0 deletions llama_index/agent/custom/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,8 @@ def finalize_task(self, task: Task, **kwargs: Any) -> None:
# reset new memory
task.extra_state["memory"].reset()
self._finalize_task(task.extra_state, **kwargs)

def set_callback_manager(self, callback_manager: CallbackManager) -> None:
    """Attach the given callback manager to this agent worker."""
    # TODO: promote to an abstractmethod once every agent implementation
    # provides an override (doing so now would break some of them).
    self.callback_manager = callback_manager
5 changes: 5 additions & 0 deletions llama_index/agent/openai/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,3 +637,8 @@ def undo_step(self, task: Task, **kwargs: Any) -> Optional[TaskStep]:
# # break

# # while cast(AgentChatResponse, last_step_output.output).response !=

def set_callback_manager(self, callback_manager: CallbackManager) -> None:
    """Attach the given callback manager to this agent worker."""
    # TODO: promote to an abstractmethod once every agent implementation
    # provides an override (doing so now would break some of them).
    self.callback_manager = callback_manager
5 changes: 5 additions & 0 deletions llama_index/agent/react/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,3 +633,8 @@ def finalize_task(self, task: Task, **kwargs: Any) -> None:
task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all())
# reset new memory
task.extra_state["new_memory"].reset()

def set_callback_manager(self, callback_manager: CallbackManager) -> None:
    """Attach the given callback manager to this agent worker."""
    # TODO: promote to an abstractmethod once every agent implementation
    # provides an override (doing so now would break some of them).
    self.callback_manager = callback_manager
5 changes: 5 additions & 0 deletions llama_index/agent/react_multimodal/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,3 +472,8 @@ def finalize_task(self, task: Task, **kwargs: Any) -> None:
task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all())
# reset new memory
task.extra_state["new_memory"].reset()

def set_callback_manager(self, callback_manager: CallbackManager) -> None:
    """Attach the given callback manager to this agent worker."""
    # TODO: promote to an abstractmethod once every agent implementation
    # provides an override (doing so now would break some of them).
    self.callback_manager = callback_manager
29 changes: 28 additions & 1 deletion llama_index/agent/runner/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,15 +208,33 @@ def __init__(
init_task_state_kwargs: Optional[dict] = None,
delete_task_on_finish: bool = False,
default_tool_choice: str = "auto",
verbose: bool = False,
) -> None:
"""Initialize."""
self.agent_worker = agent_worker
self.state = state or AgentState()
self.memory = memory or ChatMemoryBuffer.from_defaults(chat_history, llm=llm)
self.callback_manager = callback_manager or CallbackManager([])

# get and set callback manager
if callback_manager is not None:
self.agent_worker.set_callback_manager(callback_manager)
self.callback_manager = callback_manager
else:
# TODO: This is *temporary*
# Stopgap before having a callback on the BaseAgentWorker interface.
# Doing that requires a bit more refactoring to make sure existing code
# doesn't break.
if hasattr(self.agent_worker, "callback_manager"):
self.callback_manager = (
self.agent_worker.callback_manager or CallbackManager()
)
else:
self.callback_manager = CallbackManager()

self.init_task_state_kwargs = init_task_state_kwargs or {}
self.delete_task_on_finish = delete_task_on_finish
self.default_tool_choice = default_tool_choice
self.verbose = verbose

@staticmethod
def from_llm(
Expand Down Expand Up @@ -263,10 +281,13 @@ def create_task(self, input: str, **kwargs: Any) -> Task:
)
else:
extra_state = self.init_task_state_kwargs

callback_manager = kwargs.pop("callback_manager", self.callback_manager)
task = Task(
input=input,
memory=self.memory,
extra_state=extra_state,
callback_manager=callback_manager,
**kwargs,
)
# # put input into memory
Expand Down Expand Up @@ -325,6 +346,9 @@ def _run_step(
if input is not None:
step.input = input

if self.verbose:
print(f"> Running step {step.step_id}. Step input: {step.input}")

# TODO: figure out if you can dynamically swap in different step executors
# not clear when you would do that by theoretically possible

Expand Down Expand Up @@ -359,6 +383,9 @@ async def _arun_step(
if input is not None:
step.input = input

if self.verbose:
print(f"> Running step {step.step_id}. Step input: {step.input}")

# TODO: figure out if you can dynamically swap in different step executors
# not clear when you would do that by theoretically possible
if mode == ChatResponseMode.WAIT:
Expand Down
15 changes: 14 additions & 1 deletion llama_index/agent/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Dict, List, Optional

from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.callbacks import trace_method
from llama_index.callbacks import CallbackManager, trace_method
from llama_index.chat_engine.types import BaseChatEngine, StreamingAgentChatResponse
from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.core.llms.types import ChatMessage
Expand Down Expand Up @@ -147,6 +147,9 @@ class Task(BaseModel):
"""

class Config:
arbitrary_types_allowed = True

task_id: str = Field(
default_factory=lambda: str(uuid.uuid4()), type=str, description="Task ID"
)
Expand All @@ -161,6 +164,12 @@ class Task(BaseModel):
),
)

callback_manager: CallbackManager = Field(
default_factory=CallbackManager,
exclude=True,
description="Callback manager for the task.",
)

extra_state: Dict[str, Any] = Field(
default_factory=dict,
description=(
Expand Down Expand Up @@ -220,3 +229,7 @@ async def astream_step(
@abstractmethod
def finalize_task(self, task: Task, **kwargs: Any) -> None:
"""Finalize task, after all the steps are completed."""

def set_callback_manager(self, callback_manager: CallbackManager) -> None:
    """Set callback manager.

    Default implementation: store the manager on the worker as
    ``self.callback_manager``. Every concrete worker in this changeset does
    exactly this, and the agent runner reads ``agent_worker.callback_manager``
    when no manager is supplied explicitly — a silent no-op default would
    drop the manager for any subclass that does not override this method.
    Subclasses with extra propagation needs (e.g. a nested pipeline) should
    override and may call ``super().set_callback_manager(callback_manager)``.
    """
    # TODO: make this abstractmethod (right now will break some agent impls)
    self.callback_manager = callback_manager

0 comments on commit 5981dd3

Please sign in to comment.