fix agent streaming? (#11675)
logan-markewich authored Mar 6, 2024
1 parent 8d0607e commit 935c5f6
Showing 6 changed files with 57 additions and 10 deletions.
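The thread running through all six files: the streaming response writers gain an optional `on_stream_end_fn` callback, and each agent step passes `partial(self.finalize_task, task)` so the task's new messages are committed to memory only after the stream has been fully written to history. A minimal, self-contained sketch of that callback pattern (the `StreamWriter` class and token loop are illustrative, not from the diff):

```python
from functools import partial
from typing import Callable, List, Optional

class StreamWriter:
    """Illustrative stand-in for the real StreamingAgentChatResponse."""

    def write_response_to_history(
        self,
        memory: List[str],
        on_stream_end_fn: Optional[Callable[[], None]] = None,
    ) -> None:
        for token in ("This ", "is ", "a ", "test!"):
            memory.append(token)  # stream tokens into the history buffer
        # fire the callback only once the stream is fully written
        if on_stream_end_fn is not None:
            on_stream_end_fn()

def finalize_task(task_id: str) -> None:
    print(f"finalized task {task_id}")

memory: List[str] = []
writer = StreamWriter()
# partial() binds the task ahead of time, as the diff does with `task`
writer.write_response_to_history(
    memory, on_stream_end_fn=partial(finalize_task, "task-1")
)
```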
9 changes: 7 additions & 2 deletions llama-index-core/llama_index/core/agent/react/step.py
@@ -2,6 +2,7 @@

import asyncio
import uuid
from functools import partial
from itertools import chain
from threading import Thread
from typing import (
@@ -529,6 +530,7 @@ def _run_step_stream(
thread = Thread(
target=agent_response.write_response_to_history,
args=(task.extra_state["new_memory"],),
kwargs={"on_stream_end_fn": partial(self.finalize_task, task)},
)
thread.start()
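For reference, `threading.Thread` forwards `args` and `kwargs` to its target, so the new callback rides along without changing how the writer thread is launched. A small runnable sketch of the same call shape (the `write_to_history` target is a hypothetical stand-in):

```python
from functools import partial
from threading import Thread
from typing import Callable, List, Optional

def write_to_history(
    memory: List[str], on_stream_end_fn: Optional[Callable[[], None]] = None
) -> None:
    memory.append("streamed response")
    if on_stream_end_fn is not None:
        on_stream_end_fn()  # runs on the worker thread, after the stream ends

memory: List[str] = []
thread = Thread(
    target=write_to_history,
    args=(memory,),
    kwargs={"on_stream_end_fn": partial(print, "stream ended")},
)
thread.start()
thread.join()  # by the time join() returns, the callback has fired
```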

@@ -592,7 +594,8 @@ async def _arun_step_stream(
# create task to write chat response to history
asyncio.create_task(
agent_response.awrite_response_to_history(
task.extra_state["new_memory"]
task.extra_state["new_memory"],
on_stream_end_fn=partial(self.finalize_task, task),
)
)
# wait until response writing is done
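The async path does the same job with a background task instead of a thread; awaiting the task is what "wait until response writing is done" refers to. A sketch under the same assumptions (hypothetical `awrite_to_history`):

```python
import asyncio
from functools import partial
from typing import Callable, List, Optional

async def awrite_to_history(
    memory: List[str], on_stream_end_fn: Optional[Callable[[], None]] = None
) -> None:
    memory.append("streamed response")
    if on_stream_end_fn is not None:
        on_stream_end_fn()

async def main() -> None:
    memory: List[str] = []
    writer_task = asyncio.create_task(
        awrite_to_history(memory, on_stream_end_fn=partial(print, "stream ended"))
    )
    await writer_task  # wait until response writing is done

asyncio.run(main())
```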
@@ -628,7 +631,9 @@ async def astream_step(
def finalize_task(self, task: Task, **kwargs: Any) -> None:
"""Finalize task, after all the steps are completed."""
# add new messages to memory
task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all())
task.memory.set(
task.memory.get_all() + task.extra_state["new_memory"].get_all()
)
# reset new memory
task.extra_state["new_memory"].reset()

10 changes: 8 additions & 2 deletions llama-index-core/llama_index/core/agent/runner/base.py
@@ -528,7 +528,10 @@ def _chat(
# ensure tool_choice does not cause endless loops
tool_choice = "auto"

return self.finalize_response(task.task_id, result_output)
return self.finalize_response(
task.task_id,
result_output,
)

async def _achat(
self,
@@ -556,7 +559,10 @@ async def _achat(
# ensure tool_choice does not cause endless loops
tool_choice = "auto"

return self.finalize_response(task.task_id, result_output)
return self.finalize_response(
task.task_id,
result_output,
)

@trace_method("chat")
def chat(
10 changes: 8 additions & 2 deletions llama-index-core/llama_index/core/agent/runner/parallel.py
@@ -361,7 +361,10 @@ def _chat(
result_output = cur_step_output
break

return self.finalize_response(task.task_id, result_output)
return self.finalize_response(
task.task_id,
result_output,
)

async def _achat(
self,
@@ -393,7 +396,10 @@ async def _achat(
result_output = cur_step_output
break

return self.finalize_response(task.task_id, result_output)
return self.finalize_response(
task.task_id,
result_output,
)

@trace_method("chat")
def chat(
10 changes: 9 additions & 1 deletion llama-index-core/llama_index/core/chat_engine/types.py
@@ -100,7 +100,10 @@ def aput_in_queue(self, delta: Optional[str]) -> None:
self._new_item_event.set()

def write_response_to_history(
self, memory: BaseMemory, raise_error: bool = False
self,
memory: BaseMemory,
on_stream_end_fn: Optional[callable] = None,
raise_error: bool = False,
) -> None:
if self.chat_stream is None:
raise ValueError(
@@ -131,10 +134,13 @@ def write_response_to_history(

# This acts as an is_done event for any consumers waiting
self._is_function_not_none_thread_event.set()
if on_stream_end_fn is not None and not self._is_function:
on_stream_end_fn()

async def awrite_response_to_history(
self,
memory: BaseMemory,
on_stream_end_fn: Optional[callable] = None,
) -> None:
if self.achat_stream is None:
raise ValueError(
@@ -164,6 +170,8 @@ async def awrite_response_to_history(
# These act as is_done events for any consumers waiting
self._is_function_false_event.set()
self._new_item_event.set()
if on_stream_end_fn is not None and not self._is_function:
on_stream_end_fn()

@property
def response_gen(self) -> Generator[str, None, None]:
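Both writers guard the new callback with `not self._is_function`, so finalization fires only when the streamed message is a final assistant answer; a function/tool call means another agent step follows, and memory must not be finalized yet. The guard in isolation (surrounding names are illustrative):

```python
from typing import Callable, Optional

def maybe_finalize(
    is_function: Optional[bool],
    on_stream_end_fn: Optional[Callable[[], None]],
) -> None:
    # Same shape as the guard added to both writers: finalize only when the
    # streamed message is a final answer, not an intermediate function call.
    if on_stream_end_fn is not None and not is_function:
        on_stream_end_fn()

maybe_finalize(False, lambda: print("task finalized"))  # fires
maybe_finalize(True, lambda: print("task finalized"))   # skipped: tool call pending
```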
@@ -4,6 +4,7 @@
import json
import logging
import uuid
from functools import partial
from threading import Thread
from typing import Any, Dict, List, Optional, Tuple, Union, cast, get_args

@@ -285,6 +286,7 @@ def _get_stream_ai_response(
thread = Thread(
target=chat_stream_response.write_response_to_history,
args=(task.extra_state["new_memory"],),
kwargs={"on_stream_end_fn": partial(self.finalize_task, task)},
)
thread.start()
# Wait for the event to be set
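The OpenAI agent's step code mirrors the ReAct change, and the comment above points at the coordination primitive: the writer thread signals consumers through events (`_is_function_not_none_thread_event` in the real class). A minimal sketch of that handshake, with hypothetical names:

```python
from threading import Event, Thread

done = Event()

def writer() -> None:
    # ... stream tokens to the response queue ...
    done.set()  # signal any consumers blocked on the event

Thread(target=writer).start()
done.wait()  # wait for the event to be set
print("writer finished")
```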
@@ -306,7 +308,8 @@ async def _get_async_stream_ai_response(
# create task to write chat response to history
asyncio.create_task(
chat_stream_response.awrite_response_to_history(
task.extra_state["new_memory"]
task.extra_state["new_memory"],
on_stream_end_fn=partial(self.finalize_task, task),
)
)
# wait until openAI functions stop executing
@@ -605,7 +608,9 @@ async def astream_step(
def finalize_task(self, task: Task, **kwargs: Any) -> None:
"""Finalize task, after all the steps are completed."""
# add new messages to memory
task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all())
task.memory.set(
task.memory.get_all() + task.extra_state["new_memory"].get_all()
)
# reset new memory
task.extra_state["new_memory"].reset()

@@ -148,6 +148,9 @@ def test_chat_basic(MockSyncOpenAI: MagicMock, add_tool: FunctionTool) -> None:
response = agent.chat("What is 1 + 1?")
assert isinstance(response, AgentChatResponse)
assert response.response == "\n\nThis is a test!"
assert len(agent.chat_history) == 2
assert agent.chat_history[0].content == "What is 1 + 1?"
assert agent.chat_history[1].content == "\n\nThis is a test!"


@patch("llama_index.llms.openai.base.AsyncOpenAI")
@@ -165,6 +168,9 @@ async def test_achat_basic(MockAsyncOpenAI: MagicMock, add_tool: FunctionTool) -
response = await agent.achat("What is 1 + 1?")
assert isinstance(response, AgentChatResponse)
assert response.response == "\n\nThis is a test!"
assert len(agent.chat_history) == 2
assert agent.chat_history[0].content == "What is 1 + 1?"
assert agent.chat_history[1].content == "\n\nThis is a test!"


@patch("llama_index.llms.openai.base.SyncOpenAI")
@@ -182,6 +188,9 @@ def test_stream_chat_basic(MockSyncOpenAI: MagicMock, add_tool: FunctionTool) ->
assert isinstance(response, StreamingAgentChatResponse)
# str() strips newline values
assert str(response) == "This is a test!"
assert len(agent.chat_history) == 2
assert agent.chat_history[0].content == "What is 1 + 1?"
assert agent.chat_history[1].content == "This is a test!"


@patch("llama_index.llms.openai.base.AsyncOpenAI")
@@ -204,6 +213,9 @@ async def test_astream_chat_basic(
assert isinstance(response_stream, StreamingAgentChatResponse)
# str() strips newline values
assert response == "\n\nThis is a test!"
assert len(agent.chat_history) == 2
assert agent.chat_history[0].content == "What is 1 + 1?"
assert agent.chat_history[1].content == "This is a test!"


@patch("llama_index.llms.openai.base.SyncOpenAI")
@@ -319,6 +331,11 @@ async def test_async_add_step(
# add human input (not used but should be in memory)
task = agent.create_task("What is 1 + 1?")
mock_instance.chat.completions.create.side_effect = mock_achat_stream

chat_history = task.extra_state["new_memory"].get_all()

# stream the output to ensure it gets written to memory
step_output = await agent.astream_step(task.task_id, input="tmp")
async for _ in step_output.output.async_response_gen():
    pass

chat_history = task.memory.get_all()
assert "tmp" in [m.content for m in chat_history]
