Skip to content

Commit

Permalink
Update Attrs (#1092)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Aug 21, 2024
1 parent 92269e1 commit 3ae3714
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 21 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased
### Added
- `BaseConversationMemory.prompt_driver` for use with autopruning.

### Fixed
- Parsing streaming response with some OpenAI compatible services.
Expand Down
14 changes: 7 additions & 7 deletions griptape/memory/structure/base_conversation_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,19 @@
from griptape.mixins import SerializableMixin

if TYPE_CHECKING:
from griptape.drivers import BaseConversationMemoryDriver
from griptape.drivers import BaseConversationMemoryDriver, BasePromptDriver
from griptape.memory.structure import Run
from griptape.structures import Structure


@define
class BaseConversationMemory(SerializableMixin, ABC):
driver: Optional[BaseConversationMemoryDriver] = field(
default=Factory(lambda: Defaults.drivers_config.conversation_memory_driver), kw_only=True
)
prompt_driver: BasePromptDriver = field(
default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True
)
runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True})
structure: Structure = field(init=False)
autoload: bool = field(default=True, kw_only=True)
autoprune: bool = field(default=True, kw_only=True)
max_runs: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
Expand Down Expand Up @@ -65,9 +66,8 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] =
"""
num_runs_to_fit_in_prompt = len(self.runs)

if self.autoprune and hasattr(self, "structure"):
if self.autoprune:
should_prune = True
prompt_driver = Defaults.drivers_config.prompt_driver
temp_stack = PromptStack()

# Try to determine how many Conversation Memory runs we can
Expand All @@ -82,8 +82,8 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] =
temp_stack.messages.extend(memory_inputs)

# Convert the Prompt Stack into tokens left.
tokens_left = prompt_driver.tokenizer.count_input_tokens_left(
prompt_driver.prompt_stack_to_string(temp_stack),
tokens_left = self.prompt_driver.tokenizer.count_input_tokens_left(
self.prompt_driver.prompt_stack_to_string(temp_stack),
)
if tokens_left > 0:
# There are still tokens left, no need to prune.
Expand Down
3 changes: 0 additions & 3 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ def validate_rules(self, _: Attribute, rules: list[Rule]) -> None:
raise ValueError("can't have both rules and rulesets specified")

def __attrs_post_init__(self) -> None:
if self.conversation_memory is not None:
self.conversation_memory.structure = self

tasks = self.tasks.copy()
self.tasks.clear()
self.add_tasks(*tasks)
Expand Down
21 changes: 11 additions & 10 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ packages = [
[tool.poetry.dependencies]
python = "^3.9"
openai = "^1.1.1"
attrs = "^23.2.0"
attrs = "^24.2.0"
jinja2 = "^3.1.4"
marshmallow = "^3.21.3"
marshmallow-enum = "^1.5.1"
Expand Down

0 comments on commit 3ae3714

Please sign in to comment.