From c3722bcfb15acf43fe37c675402028b5309a75b9 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 13 Dec 2024 10:16:17 -0800 Subject: [PATCH] Fix `Agent` unintentionally modifying `stream` for all Prompt Drivers --- CHANGELOG.md | 9 +++++---- griptape/structures/agent.py | 9 ++++++--- tests/unit/structures/test_agent.py | 8 ++++++++ 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d81a27aa..fe7b7707b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,10 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased -## Fixed - -- Exception when calling `Structure.to_json()` after it has run. - ### Added - `PromptTask.conversation_memory` for setting the Conversation Memory on a Prompt Task. @@ -19,6 +15,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `BranchTask` for selecting which Tasks (if any) to run based on a condition. - Support for `BranchTask` in `StructureVisualizer`. +### Fixed + +- Exception when calling `Structure.to_json()` after it has run. +- `Agent` unintentionally modifying `stream` for all Prompt Drivers. + ## [1.0.0] - 2024-12-09 ### Added diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index be1b73e34..de4d848ea 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -1,5 +1,6 @@ from __future__ import annotations +from copy import copy from typing import TYPE_CHECKING, Callable, Optional, Union from attrs import Attribute, Factory, define, field @@ -38,18 +39,20 @@ def validate_fail_fast(self, _: Attribute, fail_fast: bool) -> None: # noqa: FB def __attrs_post_init__(self) -> None: super().__attrs_post_init__() - self.prompt_driver.stream = self.stream + # Make a copy to avoid unintentionally modifying the global default Prompt Driver's stream setting + prompt_driver = copy(self.prompt_driver) + prompt_driver.stream = self.stream if len(self.tasks) == 0: if self.tools: task = ToolkitTask( self.input, - prompt_driver=self.prompt_driver, + prompt_driver=prompt_driver, tools=self.tools, max_meta_memory_entries=self.max_meta_memory_entries, ) else: task = PromptTask( - self.input, prompt_driver=self.prompt_driver, max_meta_memory_entries=self.max_meta_memory_entries + self.input, prompt_driver=prompt_driver, max_meta_memory_entries=self.max_meta_memory_entries ) self.add_task(task) diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index 27211f29e..387910f40 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -276,3 +276,11 @@ def test_is_running(self): task.state = BaseTask.State.RUNNING assert agent.is_running() + + def test_stream_mutation(self): + prompt_driver = MockPromptDriver() + agent = Agent(prompt_driver=MockPromptDriver(), stream=True) + + assert isinstance(agent.tasks[0], PromptTask) + assert agent.tasks[0].prompt_driver.stream is True + assert agent.tasks[0].prompt_driver is not prompt_driver