diff --git a/CHANGELOG.md b/CHANGELOG.md index 98d5f7380..0528b1a46 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## Unreleased +### Added +- `BaseConversationMemory.prompt_driver` for use with autopruning. ### Fixed - Parsing streaming response with some OpenAi compatible services. diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index 44c053dc4..15d0a9e99 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -10,9 +10,8 @@ from griptape.mixins import SerializableMixin if TYPE_CHECKING: - from griptape.drivers import BaseConversationMemoryDriver + from griptape.drivers import BaseConversationMemoryDriver, BasePromptDriver from griptape.memory.structure import Run - from griptape.structures import Structure @define @@ -20,8 +19,10 @@ class BaseConversationMemory(SerializableMixin, ABC): driver: Optional[BaseConversationMemoryDriver] = field( default=Factory(lambda: Defaults.drivers_config.conversation_memory_driver), kw_only=True ) + prompt_driver: BasePromptDriver = field( + default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True + ) runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True}) - structure: Structure = field(init=False) autoload: bool = field(default=True, kw_only=True) autoprune: bool = field(default=True, kw_only=True) max_runs: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) @@ -65,9 +66,8 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = """ num_runs_to_fit_in_prompt = len(self.runs) - if self.autoprune and hasattr(self, "structure"): + if self.autoprune: should_prune = True - prompt_driver = Defaults.drivers_config.prompt_driver temp_stack = PromptStack() # Try to determine how many Conversation Memory runs we can @@ -82,8 +82,8 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = temp_stack.messages.extend(memory_inputs) # Convert the Prompt Stack into tokens left. - tokens_left = prompt_driver.tokenizer.count_input_tokens_left( - prompt_driver.prompt_stack_to_string(temp_stack), + tokens_left = self.prompt_driver.tokenizer.count_input_tokens_left( + self.prompt_driver.prompt_stack_to_string(temp_stack), ) if tokens_left > 0: # There are still tokens left, no need to prune. diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 0572e289d..63ba02373 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -54,9 +54,6 @@ def validate_rules(self, _: Attribute, rules: list[Rule]) -> None: raise ValueError("can't have both rules and rulesets specified") def __attrs_post_init__(self) -> None: - if self.conversation_memory is not None: - self.conversation_memory.structure = self - tasks = self.tasks.copy() self.tasks.clear() self.add_tasks(*tasks) diff --git a/poetry.lock b/poetry.lock index 32c3964ef..fc0118d0a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -265,22 +265,22 @@ files = [ [[package]] name = "attrs" -version = "23.2.0" +version = "24.2.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.7" files = [ - {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, - {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"}, + {file = "attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2"}, + {file = "attrs-24.2.0.tar.gz", hash = "sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346"}, ] [package.extras] -cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] -dev = ["attrs[tests]", "pre-commit"] -docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] -tests = ["attrs[tests-no-zope]", "zope-interface"] -tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] -tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] +benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] [[package]] name = "babel" @@ -3669,6 +3669,7 @@ description = "Nvidia JIT LTO Library" optional = true python-versions = ">=3" files = [ + {file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_aarch64.whl", hash = "sha256:84fb38465a5bc7c70cbc320cfd0963eb302ee25a5e939e9f512bbba55b6072fb"}, {file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl", hash = "sha256:562ab97ea2c23164823b2a89cb328d01d45cb99634b8c65fe7cd60d14562bd79"}, {file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-win_amd64.whl", hash = "sha256:ed3c43a17f37b0c922a919203d2d36cbef24d41cc3e6b625182f8b58203644f6"}, ] @@ -6988,4 +6989,4 @@ loaders-sql = ["sqlalchemy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "ee23a885217a5285e3a33cac221c55f011cd4ce428b33cd8abfbdac38a27a638" +content-hash = "d368587717dd8496f0db30403afa59ca6ff9e0b4e2d747f2b4c703e832d904c3" diff --git a/pyproject.toml b/pyproject.toml index 6c50013ad..e02c08b6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ packages = [ [tool.poetry.dependencies] python = "^3.9" openai = "^1.1.1" -attrs = "^23.2.0" +attrs = "^24.2.0" jinja2 = "^3.1.4" marshmallow = "^3.21.3" marshmallow-enum = "^1.5.1"