diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a8cb3296..6374a2c0e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support for `BranchTask` in `StructureVisualizer`. - `EvalEngine` for evaluating the performance of an LLM's output against a given input. - `BaseFileLoader.save()` method for saving an Artifact to a destination. +- `Structure.run_stream()` for streaming Events from a Structure as an iterator. ### Changed diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 49a1f05ce..90df9434e 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -74,6 +74,18 @@ Handler 1 ``` +## Stream Iterator + +You can use `Structure.run_stream()` for streaming Events from the `Structure` in the form of an iterator. + +!!! tip + +Set `stream=True` on your [Prompt Driver](../drivers/prompt-drivers.md) in order to receive completion chunk events. + +```python +--8<-- "docs/griptape-framework/misc/src/events_streaming.py" +``` + ## Context Managers You can also use [EventListener](../../reference/griptape/events/event_listener.md)s as a Python Context Manager. diff --git a/docs/griptape-framework/misc/src/events_streaming.py b/docs/griptape-framework/misc/src/events_streaming.py new file mode 100644 index 000000000..18927269b --- /dev/null +++ b/docs/griptape-framework/misc/src/events_streaming.py @@ -0,0 +1,7 @@ +from griptape.events import BaseEvent +from griptape.structures import Agent + +agent = Agent() + +for event in agent.run_stream("Hi!", event_types=[BaseEvent]): # All Events + print(type(event)) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index fe93e1cce..24d9d5233 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -2,20 +2,27 @@ import uuid from abc import ABC, abstractmethod +from queue import Queue +from threading import Thread from typing import TYPE_CHECKING, Any, Literal, Optional, Union from attrs import Factory, define, field from griptape.common import observable from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent +from griptape.events.base_event import BaseEvent +from griptape.events.event_listener import EventListener from griptape.memory import TaskMemory from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory, Run from griptape.mixins.rule_mixin import RuleMixin from griptape.mixins.runnable_mixin import RunnableMixin from griptape.mixins.serializable_mixin import SerializableMixin +from griptape.utils.contextvars_utils import with_contextvars if TYPE_CHECKING: + from collections.abc import Iterator + from griptape.artifacts import BaseArtifact from griptape.memory.structure import BaseConversationMemory from griptape.tasks import BaseTask @@ -42,6 +49,7 @@ class Structure(RuleMixin, SerializableMixin, RunnableMixin["Structure"], ABC): meta_memory: MetaMemory = field(default=Factory(lambda: MetaMemory()), kw_only=True) fail_fast: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _execution_args: tuple = () + _event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue()), init=False) def __attrs_post_init__(self) -> None: tasks = self._tasks.copy() @@ -198,5 +206,25 @@ def run(self, *args) -> Structure: return result + @observable + def run_stream(self, *args, event_types: Optional[list[type[BaseEvent]]] = None) -> Iterator[BaseEvent]: + if event_types is None: + event_types = [BaseEvent] + else: + if FinishStructureRunEvent not in event_types: + event_types = [*event_types, FinishStructureRunEvent] + + with EventListener(self._event_queue.put, event_types=event_types): + t = Thread(target=with_contextvars(self.run), args=args) + t.start() + + while True: + event = self._event_queue.get() + if isinstance(event, FinishStructureRunEvent): + break + else: + yield event + t.join() + @abstractmethod def try_run(self, *args) -> Structure: ... diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 7d90d26e2..22ff5245b 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -1,56 +1,40 @@ from __future__ import annotations import json -from queue import Queue -from threading import Thread from typing import TYPE_CHECKING -from attrs import Factory, define, field +from attrs import define, field from griptape.artifacts.text_artifact import TextArtifact from griptape.events import ( ActionChunkEvent, - BaseChunkEvent, - EventBus, - EventListener, FinishPromptEvent, FinishStructureRunEvent, TextChunkEvent, ) -from griptape.utils.contextvars_utils import with_contextvars if TYPE_CHECKING: from collections.abc import Iterator - from griptape.events.base_event import BaseEvent from griptape.structures import Structure @define class Stream: - """A wrapper for Structures that converts `BaseChunkEvent`s into an iterator of TextArtifacts. - - It achieves this by running the Structure in a separate thread, listening for events from the Structure, - and yielding those events. - - See relevant Stack Overflow post: https://stackoverflow.com/questions/9968592/turn-functions-with-a-callback-into-python-generators + """A wrapper for Structures filters Events relevant to text output and converts them to TextArtifacts. Attributes: structure: The Structure to wrap. - _event_queue: A queue to hold events from the Structure. """ structure: Structure = field() - _event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue())) - def run(self, *args) -> Iterator[TextArtifact]: - t = Thread(target=with_contextvars(self._run_structure), args=args) - t.start() - action_str = "" - while True: - event = self._event_queue.get() + + for event in self.structure.run_stream( + *args, event_types=[TextChunkEvent, ActionChunkEvent, FinishPromptEvent, FinishStructureRunEvent] + ): if isinstance(event, FinishStructureRunEvent): break elif isinstance(event, FinishPromptEvent): @@ -67,18 +51,3 @@ def run(self, *args) -> Iterator[TextArtifact]: action_str = "" except Exception: pass - t.join() - - def _run_structure(self, *args) -> None: - def event_handler(event: BaseEvent) -> None: - self._event_queue.put(event) - - stream_event_listener = EventListener( - on_event=event_handler, - event_types=[BaseChunkEvent, FinishPromptEvent, FinishStructureRunEvent], - ) - EventBus.add_event_listener(stream_event_listener) - - self.structure.run(*args) - - EventBus.remove_event_listener(stream_event_listener) diff --git a/tests/unit/structures/test_structure.py b/tests/unit/structures/test_structure.py index 3b0e508e0..21a637ff6 100644 --- a/tests/unit/structures/test_structure.py +++ b/tests/unit/structures/test_structure.py @@ -1,5 +1,6 @@ import pytest +from griptape.events import FinishStructureRunEvent, FinishTaskEvent, StartTaskEvent from griptape.structures import Agent, Pipeline from griptape.tasks import PromptTask from tests.mocks.mock_prompt_driver import MockPromptDriver @@ -130,3 +131,39 @@ def test_from_dict(self): assert len(deserialized_agent.task_outputs) == 1 assert deserialized_agent.task_outputs[task.id].value == "mock output" + + def test_run_stream(self): + from griptape.events import ( + EventBus, + FinishPromptEvent, + FinishStructureRunEvent, + StartPromptEvent, + StartStructureRunEvent, + ) + + agent = Agent() + event_types = [ + StartStructureRunEvent, + StartTaskEvent, + StartPromptEvent, + FinishPromptEvent, + FinishTaskEvent, + FinishStructureRunEvent, + ] + events = agent.run_stream() + + for idx, event in enumerate(events): + assert isinstance(event, event_types[idx]) + assert len(EventBus.event_listeners) == 0 + + def test_run_stream_custom_event_types(self): + from griptape.events import EventBus, FinishPromptEvent, StartPromptEvent, StartStructureRunEvent + + agent = Agent() + event_types = [StartStructureRunEvent, StartPromptEvent, FinishPromptEvent] + expected_event_types = [StartStructureRunEvent, StartPromptEvent, FinishPromptEvent, FinishStructureRunEvent] + events = agent.run_stream(event_types=event_types) + + for idx, event in enumerate(events): + assert isinstance(event, expected_event_types[idx]) + assert len(EventBus.event_listeners) == 0 diff --git a/tests/unit/utils/test_stream.py b/tests/unit/utils/test_stream.py index 223d5a9e7..9c1b78801 100644 --- a/tests/unit/utils/test_stream.py +++ b/tests/unit/utils/test_stream.py @@ -31,6 +31,6 @@ def test_init(self, agent): with pytest.raises(StopIteration): next(chat_stream_run) else: - next(chat_stream.run()) + assert next(chat_stream.run()).value == "\n" with pytest.raises(StopIteration): next(chat_stream.run())