Skip to content

Commit

Permalink
Fix Agent unintentionally modifying stream for all Prompt Drivers
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Dec 13, 2024
1 parent 7073c50 commit c3722bc
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
9 changes: 5 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

## Fixed

- Exception when calling `Structure.to_json()` after it has run.

### Added

- `PromptTask.conversation_memory` for setting the Conversation Memory on a Prompt Task.
Expand All @@ -19,6 +15,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `BranchTask` for selecting which Tasks (if any) to run based on a condition.
- Support for `BranchTask` in `StructureVisualizer`.

### Fixed

- Exception when calling `Structure.to_json()` after it has run.
- `Agent` unintentionally modifying `stream` for all Prompt Drivers.

## [1.0.0] - 2024-12-09

### Added
Expand Down
9 changes: 6 additions & 3 deletions griptape/structures/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from copy import copy
from typing import TYPE_CHECKING, Callable, Optional, Union

from attrs import Attribute, Factory, define, field
Expand Down Expand Up @@ -38,18 +39,20 @@ def validate_fail_fast(self, _: Attribute, fail_fast: bool) -> None: # noqa: FB
def __attrs_post_init__(self) -> None:
super().__attrs_post_init__()

self.prompt_driver.stream = self.stream
# Make a copy to avoid unintentionally modifying the global default Prompt Driver's stream setting
prompt_driver = copy(self.prompt_driver)
prompt_driver.stream = self.stream
if len(self.tasks) == 0:
if self.tools:
task = ToolkitTask(
self.input,
prompt_driver=self.prompt_driver,
prompt_driver=prompt_driver,
tools=self.tools,
max_meta_memory_entries=self.max_meta_memory_entries,
)
else:
task = PromptTask(
self.input, prompt_driver=self.prompt_driver, max_meta_memory_entries=self.max_meta_memory_entries
self.input, prompt_driver=prompt_driver, max_meta_memory_entries=self.max_meta_memory_entries
)

self.add_task(task)
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/structures/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,3 +276,11 @@ def test_is_running(self):
task.state = BaseTask.State.RUNNING

assert agent.is_running()

def test_stream_mutation(self):
prompt_driver = MockPromptDriver()
agent = Agent(prompt_driver=MockPromptDriver(), stream=True)

assert isinstance(agent.tasks[0], PromptTask)
assert agent.tasks[0].prompt_driver.stream is True
assert agent.tasks[0].prompt_driver is not prompt_driver

0 comments on commit c3722bc

Please sign in to comment.