diff --git a/CHANGELOG.md b/CHANGELOG.md index b450c4ac1..0af22517a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Added a new `.add_sass_layer_file()` method to `ui.Theme` that supports reading a Sass file with layer boundary comments, e.g. `/*-- scss:defaults --*/`. This format [is supported by Quarto](https://quarto.org/docs/output-formats/html-themes-more.html#bootstrap-bootswatch-layering) and makes it easier to store Sass rules and declarations that need to be woven into Shiny's Sass Bootstrap files. (#1790) -* The `ui.Chat()` component's `.on_user_submit()` decorator method now passes the user input to the decorated function. This makes it a bit more obvious how to access the user input inside the decorated function. See the new templates (mentioned below) for examples. (#1801) +* The `ui.Chat()` component gains the following: + * The `.on_user_submit()` decorator method now passes the user input to the decorated function. This makes it a bit easier to access the user input. See the new templates (mentioned below) for examples. (#1801) + * A new `get_latest_stream_result()` method was added for an easy way to access the final result of the stream when it completes. (#1846) + * The `.append_message_stream()` method now returns the `reactive.extended_task` instance that it launches. (#1846) * `shiny create` includes new and improved `ui.Chat()` template options. Most of these templates leverage the new [`{chatlas}` package](https://posit-dev.github.io/chatlas/), our opinionated approach to interfacing with various LLM. (#1806) diff --git a/shiny/ui/_chat.py b/shiny/ui/_chat.py index 0ae300122..f9cd874c9 100644 --- a/shiny/ui/_chat.py +++ b/shiny/ui/_chat.py @@ -210,6 +210,10 @@ def __init__( reactive.Value(None) ) + self._latest_stream: reactive.Value[ + reactive.ExtendedTask[[], str] | None + ] = reactive.Value(None) + # TODO: deprecate messages once we start promoting managing LLM message # state through other means @reactive.effect @@ -607,6 +611,13 @@ async def append_message_stream(self, message: Iterable[Any] | AsyncIterable[Any Use this method (over `.append_message()`) when `stream=True` (or similar) is specified in model's completion method. ``` + + Returns + ------- + : + An extended task that represents the streaming task. The `.result()` method + of the task can be called in a reactive context to get the final state of the + stream. """ message = _utils.wrap_async_iterable(message) @@ -614,10 +625,12 @@ async def append_message_stream(self, message: Iterable[Any] | AsyncIterable[Any # Run the stream in the background to get non-blocking behavior @reactive.extended_task async def _stream_task(): - await self._append_message_stream(message) + return await self._append_message_stream(message) _stream_task() + self._latest_stream.set(_stream_task) + # Since the task runs in the background (outside/beyond the current context, # if any), we need to manually raise any exceptions that occur @reactive.effect @@ -627,6 +640,35 @@ async def _handle_error(): await self._raise_exception(e) _handle_error.destroy() # type: ignore + return _stream_task + + def get_latest_stream_result(self) -> str | None: + """ + Reactively read the latest message stream result. + + This method reads a reactive value containing the result of the latest + `.append_message_stream()`. Therefore, this method must be called in a reactive + context (e.g., a render function, a :func:`~shiny.reactive.calc`, or a + :func:`~shiny.reactive.effect`). + + Returns + ------- + : + The result of the latest stream (a string). + + Raises + ------ + : + A silent exception if no stream has completed yet. + """ + stream = self._latest_stream() + if stream is None: + from .. import req + + req(False) + else: + return stream.result() + async def _append_message_stream(self, message: AsyncIterable[Any]): id = _utils.private_random_id() @@ -636,6 +678,7 @@ async def _append_message_stream(self, message: AsyncIterable[Any]): try: async for msg in message: await self._append_message(msg, chunk=True, stream_id=id) + return self._current_stream_message finally: await self._append_message(empty, chunk="end", stream_id=id) await self._flush_pending_messages() diff --git a/shiny/ui/_markdown_stream.py b/shiny/ui/_markdown_stream.py index 52e11a7a3..fe993922e 100644 --- a/shiny/ui/_markdown_stream.py +++ b/shiny/ui/_markdown_stream.py @@ -7,7 +7,7 @@ from .._docstring import add_example from .._namespaces import resolve_id from .._typing_extensions import TypedDict -from ..session import require_active_session +from ..session import require_active_session, session_context from ..types import NotifyException from ..ui.css import CssUnit, as_css_unit from . import Tag @@ -86,6 +86,11 @@ def __init__( self.on_error = on_error + with session_context(self._session): + self._latest_stream: reactive.Value[ + Union[reactive.ExtendedTask[[], str], None] + ] = reactive.Value(None) + async def stream( self, content: Union[Iterable[str], AsyncIterable[str]], @@ -109,6 +114,13 @@ async def stream( ---- If you already have the content available as a string, you can do `.stream([content])` to set the content. + + Returns + ------- + : + An extended task that represents the streaming task. The `.result()` method + of the task can be called in a reactive context to get the final state of the + stream. """ content = _utils.wrap_async_iterable(content) @@ -139,6 +151,33 @@ async def _handle_error(): return _task + def get_latest_stream_result(self) -> Union[str, None]: + """ + Reactively read the latest stream result. + + This method reads a reactive value containing the result of the latest + `.stream()`. Therefore, this method must be called in a reactive context (e.g., + a render function, a :func:`~shiny.reactive.calc`, or a + :func:`~shiny.reactive.effect`). + + Returns + ------- + : + The result of the latest stream (a string). + + Raises + ------ + : + A silent exception if no stream has completed yet. + """ + stream = self._latest_stream() + if stream is None: + from .. import req + + req(False) + else: + return stream.result() + async def clear(self): """ Empty the UI element of the `MarkdownStream`. diff --git a/tests/playwright/shiny/components/chat/stream-result/app.py b/tests/playwright/shiny/components/chat/stream-result/app.py new file mode 100644 index 000000000..34e254408 --- /dev/null +++ b/tests/playwright/shiny/components/chat/stream-result/app.py @@ -0,0 +1,24 @@ +import asyncio + +from shiny.express import render, ui + +chat = ui.Chat("chat") + +chat.ui() +chat.update_user_input(value="Press Enter to start the stream") + + +async def stream_generator(): + for i in range(10): + await asyncio.sleep(0.25) + yield f"Message {i} \n\n" + + +@chat.on_user_submit +async def _(message: str): + await chat.append_message_stream(stream_generator()) + + +@render.code +async def stream_result_ui(): + return chat.get_latest_stream_result() diff --git a/tests/playwright/shiny/components/chat/stream-result/test_chat_stream_result.py b/tests/playwright/shiny/components/chat/stream-result/test_chat_stream_result.py new file mode 100644 index 000000000..8d6c216da --- /dev/null +++ b/tests/playwright/shiny/components/chat/stream-result/test_chat_stream_result.py @@ -0,0 +1,37 @@ +import re + +from playwright.sync_api import Page, expect +from utils.deploy_utils import skip_on_webkit + +from shiny.playwright import controller +from shiny.run import ShinyAppProc + + +@skip_on_webkit +def test_validate_chat_stream_result(page: Page, local_app: ShinyAppProc) -> None: + page.goto(local_app.url) + + chat = controller.Chat(page, "chat") + stream_result_ui = controller.OutputCode(page, "stream_result_ui") + + expect(chat.loc).to_be_visible(timeout=10 * 1000) + + chat.send_user_input() + + messages = [ + "Message 0", + "Message 1", + "Message 2", + "Message 3", + "Message 4", + "Message 5", + "Message 6", + "Message 7", + "Message 8", + "Message 9", + ] + # Allow for any whitespace between messages + chat.expect_messages(re.compile(r"\s*".join(messages)), timeout=30 * 1000) + + # Verify that the stream result is as expected + stream_result_ui.expect.to_contain_text("Message 9")