Skip to content

Commit

Permalink
Chat's .append_message_stream() now returns an ExtendedTask (#1846)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpsievert authored Feb 14, 2025
1 parent 33bf25d commit 1e9b868
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 3 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

* Added a new `.add_sass_layer_file()` method to `ui.Theme` that supports reading a Sass file with layer boundary comments, e.g. `/*-- scss:defaults --*/`. This format [is supported by Quarto](https://quarto.org/docs/output-formats/html-themes-more.html#bootstrap-bootswatch-layering) and makes it easier to store Sass rules and declarations that need to be woven into Shiny's Sass Bootstrap files. (#1790)

* The `ui.Chat()` component's `.on_user_submit()` decorator method now passes the user input to the decorated function. This makes it a bit more obvious how to access the user input inside the decorated function. See the new templates (mentioned below) for examples. (#1801)
* The `ui.Chat()` component gains the following:
* The `.on_user_submit()` decorator method now passes the user input to the decorated function. This makes it a bit easier to access the user input. See the new templates (mentioned below) for examples. (#1801)
* A new `get_latest_stream_result()` method was added for an easy way to access the final result of the stream when it completes. (#1846)
* The `.append_message_stream()` method now returns the `reactive.extended_task` instance that it launches. (#1846)

* `shiny create` includes new and improved `ui.Chat()` template options. Most of these templates leverage the new [`{chatlas}` package](https://posit-dev.github.io/chatlas/), our opinionated approach to interfacing with various LLM. (#1806)

Expand Down
45 changes: 44 additions & 1 deletion shiny/ui/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ def __init__(
reactive.Value(None)
)

self._latest_stream: reactive.Value[
reactive.ExtendedTask[[], str] | None
] = reactive.Value(None)

# TODO: deprecate messages once we start promoting managing LLM message
# state through other means
@reactive.effect
Expand Down Expand Up @@ -607,17 +611,26 @@ async def append_message_stream(self, message: Iterable[Any] | AsyncIterable[Any
Use this method (over `.append_message()`) when `stream=True` (or similar) is
specified in model's completion method.
```
Returns
-------
:
An extended task that represents the streaming task. The `.result()` method
of the task can be called in a reactive context to get the final state of the
stream.
"""

message = _utils.wrap_async_iterable(message)

# Run the stream in the background to get non-blocking behavior
@reactive.extended_task
async def _stream_task():
await self._append_message_stream(message)
return await self._append_message_stream(message)

_stream_task()

self._latest_stream.set(_stream_task)

# Since the task runs in the background (outside/beyond the current context,
# if any), we need to manually raise any exceptions that occur
@reactive.effect
Expand All @@ -627,6 +640,35 @@ async def _handle_error():
await self._raise_exception(e)
_handle_error.destroy() # type: ignore

return _stream_task

def get_latest_stream_result(self) -> str | None:
"""
Reactively read the latest message stream result.
This method reads a reactive value containing the result of the latest
`.append_message_stream()`. Therefore, this method must be called in a reactive
context (e.g., a render function, a :func:`~shiny.reactive.calc`, or a
:func:`~shiny.reactive.effect`).
Returns
-------
:
The result of the latest stream (a string).
Raises
------
:
A silent exception if no stream has completed yet.
"""
stream = self._latest_stream()
if stream is None:
from .. import req

req(False)
else:
return stream.result()

async def _append_message_stream(self, message: AsyncIterable[Any]):
id = _utils.private_random_id()

Expand All @@ -636,6 +678,7 @@ async def _append_message_stream(self, message: AsyncIterable[Any]):
try:
async for msg in message:
await self._append_message(msg, chunk=True, stream_id=id)
return self._current_stream_message
finally:
await self._append_message(empty, chunk="end", stream_id=id)
await self._flush_pending_messages()
Expand Down
41 changes: 40 additions & 1 deletion shiny/ui/_markdown_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .._docstring import add_example
from .._namespaces import resolve_id
from .._typing_extensions import TypedDict
from ..session import require_active_session
from ..session import require_active_session, session_context
from ..types import NotifyException
from ..ui.css import CssUnit, as_css_unit
from . import Tag
Expand Down Expand Up @@ -86,6 +86,11 @@ def __init__(

self.on_error = on_error

with session_context(self._session):
self._latest_stream: reactive.Value[
Union[reactive.ExtendedTask[[], str], None]
] = reactive.Value(None)

async def stream(
self,
content: Union[Iterable[str], AsyncIterable[str]],
Expand All @@ -109,6 +114,13 @@ async def stream(
----
If you already have the content available as a string, you can do
`.stream([content])` to set the content.
Returns
-------
:
An extended task that represents the streaming task. The `.result()` method
of the task can be called in a reactive context to get the final state of the
stream.
"""

content = _utils.wrap_async_iterable(content)
Expand Down Expand Up @@ -139,6 +151,33 @@ async def _handle_error():

return _task

def get_latest_stream_result(self) -> Union[str, None]:
"""
Reactively read the latest stream result.
This method reads a reactive value containing the result of the latest
`.stream()`. Therefore, this method must be called in a reactive context (e.g.,
a render function, a :func:`~shiny.reactive.calc`, or a
:func:`~shiny.reactive.effect`).
Returns
-------
:
The result of the latest stream (a string).
Raises
------
:
A silent exception if no stream has completed yet.
"""
stream = self._latest_stream()
if stream is None:
from .. import req

req(False)
else:
return stream.result()

async def clear(self):
"""
Empty the UI element of the `MarkdownStream`.
Expand Down
24 changes: 24 additions & 0 deletions tests/playwright/shiny/components/chat/stream-result/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import asyncio

from shiny.express import render, ui

chat = ui.Chat("chat")

chat.ui()
chat.update_user_input(value="Press Enter to start the stream")


async def stream_generator():
for i in range(10):
await asyncio.sleep(0.25)
yield f"Message {i} \n\n"


@chat.on_user_submit
async def _(message: str):
await chat.append_message_stream(stream_generator())


@render.code
async def stream_result_ui():
return chat.get_latest_stream_result()
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import re

from playwright.sync_api import Page, expect
from utils.deploy_utils import skip_on_webkit

from shiny.playwright import controller
from shiny.run import ShinyAppProc


@skip_on_webkit
def test_validate_chat_stream_result(page: Page, local_app: ShinyAppProc) -> None:
page.goto(local_app.url)

chat = controller.Chat(page, "chat")
stream_result_ui = controller.OutputCode(page, "stream_result_ui")

expect(chat.loc).to_be_visible(timeout=10 * 1000)

chat.send_user_input()

messages = [
"Message 0",
"Message 1",
"Message 2",
"Message 3",
"Message 4",
"Message 5",
"Message 6",
"Message 7",
"Message 8",
"Message 9",
]
# Allow for any whitespace between messages
chat.expect_messages(re.compile(r"\s*".join(messages)), timeout=30 * 1000)

# Verify that the stream result is as expected
stream_result_ui.expect.to_contain_text("Message 9")

0 comments on commit 1e9b868

Please sign in to comment.