diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e5436df66..2665d282cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Structure.resolve_relationships()` to resolve asymmetrically defined parent/child relationships. In other words, if a parent declares a child, but the child does not declare the parent, the parent will automatically be added as a parent of the child when running this method. The method is invoked automatically by `Structure.before_run()`. - `CohereEmbeddingDriver` for using Cohere's embeddings API. - `CohereStructureConfig` for providing Structures with quick Cohere configuration. +- `BaseTokenizer.prompt_stack_to_string()` to convert a Prompt Stack to a string. +- `BaseTokenizer.prompt_stack_input_to_message()` to convert a Prompt Stack Input to a ChatML-style message dictionary. ### Changed - **BREAKING**: `Workflow` no longer modifies task relationships when adding tasks via `tasks` init param, `add_tasks()` or `add_task()`. Previously, adding a task would automatically add the previously added task as its parent. Existing code that relies on this behavior will need to be updated to explicitly add parent/child relationships using the API offered by `BaseTask`. @@ -27,6 +29,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed `BedrockLlamaTokenizer`, use `SimpleTokenizer` instead. - **BREAKING**: Removed `BedrockTitanTokenizer`, use `SimpleTokenizer` instead. - **BREAKING**: Removed `OpenAiChatCompletionPromptDriver` as it uses the legacy [OpenAi Completions API](https://platform.openai.com/docs/api-reference/completions). +- **BREAKING**: Removed `BasePromptDriver.count_tokens()`. +- **BREAKING**: Removed `BasePromptDriver.max_output_tokens()`. +- **BREAKING**: Moved `BasePromptDriver.prompt_stack_to_string()` to `BaseTokenizer`. +- **BREAKING**: Moved/renamed `PromptStack.add_conversation_memory` to `BaseConversationMemory.add_to_prompt_stack`. +- **BREAKING**: Moved `griptape.utils.constants.RESPONSE_STOP_SEQUENCE` to `ToolkitTask`. +- `ToolkitTask.RESPONSE_STOP_SEQUENCE` is now only added when using `ToolkitTask`. +- `BaseTokenizer.count_tokens()` can now approximate token counts given a Prompt Stack. +- Updated Prompt Drivers to use `BasePromptDriver.max_tokens` instead of `BasePromptDriver.max_output_tokens()`. - Improved error message when `GriptapeCloudKnowledgeBaseClient` does not have a description set. - Updated `AmazonBedrockPromptDriver` to use [Converse API](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html). - `Structure.before_run()` now automatically resolves asymmetrically defined parent/child relationships using the new `Structure.resolve_relationships()`.
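For context on the reworked tokenizer API described in the changelog above, a minimal usage sketch based on the methods added in this diff (`count_tokens()`, `prompt_stack_to_string()`, and `prompt_stack_input_to_message()`); the Bedrock model id below is only illustrative and not part of this change:

```python
from griptape.tokenizers import AmazonBedrockTokenizer
from griptape.utils import PromptStack

# Build a Prompt Stack with a system prompt and a single user turn.
stack = PromptStack()
stack.add_system_input("You are a helpful assistant.")
stack.add_user_input("Summarize the Bedrock Converse API in one sentence.")

# AmazonBedrockTokenizer extends SimpleTokenizer, so token counts are
# character-based approximations (characters_per_token defaults to 4).
tokenizer = AmazonBedrockTokenizer(model="anthropic.claude-3-sonnet-20240229-v1:0")  # illustrative model id

# count_tokens() now accepts a Prompt Stack directly; it renders the stack
# to a string via prompt_stack_to_string() before counting.
approximate_tokens = tokenizer.count_tokens(stack)

# Inputs can also be converted to message dictionaries for the model API.
messages = [tokenizer.prompt_stack_input_to_message(i) for i in stack.inputs]
```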
diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index f3d03a6844..f64b88dc4f 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -1,16 +1,20 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any + from collections.abc import Iterator -from attrs import define, field, Factory -from griptape.drivers import BasePromptDriver +from typing import TYPE_CHECKING, Any + +from attrs import Factory, define, field + from griptape.artifacts import TextArtifact +from griptape.drivers import BasePromptDriver +from griptape.tokenizers import AmazonBedrockTokenizer, BaseTokenizer from griptape.utils import import_optional_dependency -from griptape.tokenizers import SimpleTokenizer, BaseTokenizer if TYPE_CHECKING: - from griptape.utils import PromptStack import boto3 + from griptape.utils import PromptStack + @define class AmazonBedrockPromptDriver(BasePromptDriver): @@ -19,7 +23,7 @@ class AmazonBedrockPromptDriver(BasePromptDriver): default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True ) additional_model_request_fields: dict = field(default=Factory(dict), kw_only=True) - tokenizer: BaseTokenizer = field(default=Factory(lambda: SimpleTokenizer(characters_per_token=4)), kw_only=True) + tokenizer: BaseTokenizer = field(default=Factory(lambda: AmazonBedrockTokenizer()), kw_only=True) def try_run(self, prompt_stack: PromptStack) -> TextArtifact: response = self.bedrock_client.converse(**self._base_params(prompt_stack)) @@ -40,12 +44,24 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: else: raise Exception("model response is empty") + def prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: + content = [{"text": prompt_input.content}] + + if prompt_input.is_system(): + return {"text": prompt_input.content} + elif prompt_input.is_assistant(): + return {"role": "assistant", "content": content} + else: + return {"role": "user", "content": content} + def _base_params(self, prompt_stack: PromptStack) -> dict: system_messages = [ - {"text": input.content} for input in prompt_stack.inputs if input.is_system() and input.content + self.tokenizer.prompt_stack_input_to_message(input) + for input in prompt_stack.inputs + if input.is_system() and input.content ] messages = [ - {"role": self.__to_amazon_bedrock_role(input), "content": [{"text": input.content}]} + self.tokenizer.prompt_stack_input_to_message(input) for input in prompt_stack.inputs if not input.is_system() ] @@ -57,11 +73,3 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "inferenceConfig": {"temperature": self.temperature}, "additionalModelRequestFields": self.additional_model_request_fields, } - - def __to_amazon_bedrock_role(self, prompt_input: PromptStack.Input) -> str: - if prompt_input.is_system(): - return "system" - elif prompt_input.is_assistant(): - return "assistant" - else: - return "user" diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 486233643c..1c7d973266 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -5,7 +5,7 @@ from griptape.artifacts import TextArtifact from griptape.utils import PromptStack, import_optional_dependency from griptape.drivers import BasePromptDriver -from griptape.tokenizers import 
AnthropicTokenizer +from griptape.tokenizers import AnthropicTokenizer, BaseTokenizer @define @@ -15,7 +15,6 @@ class AnthropicPromptDriver(BasePromptDriver): api_key: Anthropic API key. model: Anthropic model name. client: Custom `Anthropic` client. - tokenizer: Custom `AnthropicTokenizer`. """ api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) @@ -26,11 +25,12 @@ class AnthropicPromptDriver(BasePromptDriver): ), kw_only=True, ) - tokenizer: AnthropicTokenizer = field( + tokenizer: BaseTokenizer = field( default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True ) top_p: float = field(default=0.999, kw_only=True, metadata={"serializable": True}) top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) + max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True}) def try_run(self, prompt_stack: PromptStack) -> TextArtifact: response = self.client.messages.create(**self._base_params(prompt_stack)) @@ -46,7 +46,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: def _prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict: messages = [ - {"role": self.__to_anthropic_role(prompt_input), "content": prompt_input.content} + self.tokenizer.prompt_stack_input_to_message(prompt_input) for prompt_input in prompt_stack.inputs if not prompt_input.is_system() ] @@ -62,16 +62,8 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "model": self.model, "temperature": self.temperature, "stop_sequences": self.tokenizer.stop_sequences, - "max_tokens": self.max_output_tokens(self.prompt_stack_to_string(prompt_stack)), "top_p": self.top_p, "top_k": self.top_k, + "max_tokens": self.max_tokens, **self._prompt_stack_to_model_input(prompt_stack), } - - def __to_anthropic_role(self, prompt_input: PromptStack.Input) -> str: - if prompt_input.is_system(): - return "system" - elif prompt_input.is_assistant(): - return "assistant" - else: - return "user" diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 096035f8be..0acb86bafa 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -1,6 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional, Callable +from typing import TYPE_CHECKING, Optional from collections.abc import Iterator from attrs import define, field, Factory from griptape.events import StartPromptEvent, FinishPromptEvent, CompletionChunkEvent @@ -32,9 +32,6 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): temperature: float = field(default=0.1, kw_only=True, metadata={"serializable": True}) max_tokens: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) structure: Optional[Structure] = field(default=None, kw_only=True) - prompt_stack_to_string: Callable[[PromptStack], str] = field( - default=Factory(lambda self: self.default_prompt_stack_to_string_converter, takes_self=True), kw_only=True - ) ignored_exception_types: tuple[type[Exception], ...] 
= field( default=Factory(lambda: (ImportError, ValueError)), kw_only=True ) @@ -42,32 +39,23 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): tokenizer: BaseTokenizer stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) - def max_output_tokens(self, text: str | list) -> int: - tokens_left = self.tokenizer.count_output_tokens_left(text) - - if self.max_tokens: - return min(self.max_tokens, tokens_left) - else: - return tokens_left - - def token_count(self, prompt_stack: PromptStack) -> int: - return self.tokenizer.count_tokens(self.prompt_stack_to_string(prompt_stack)) - def before_run(self, prompt_stack: PromptStack) -> None: if self.structure: self.structure.publish_event( StartPromptEvent( model=self.model, - token_count=self.token_count(prompt_stack), + token_count=self.tokenizer.count_tokens(prompt_stack), prompt_stack=prompt_stack, - prompt=self.prompt_stack_to_string(prompt_stack), + prompt=self.tokenizer.prompt_stack_to_string(prompt_stack), ) ) def after_run(self, result: TextArtifact) -> None: if self.structure: self.structure.publish_event( - FinishPromptEvent(model=self.model, token_count=result.token_count(self.tokenizer), result=result.value) + FinishPromptEvent( + model=self.model, result=result.value, token_count=self.tokenizer.count_tokens(result.value) + ) ) def run(self, prompt_stack: PromptStack) -> TextArtifact: @@ -92,21 +80,6 @@ def run(self, prompt_stack: PromptStack) -> TextArtifact: else: raise Exception("prompt driver failed after all retry attempts") - def default_prompt_stack_to_string_converter(self, prompt_stack: PromptStack) -> str: - prompt_lines = [] - - for i in prompt_stack.inputs: - if i.is_user(): - prompt_lines.append(f"User: {i.content}") - elif i.is_assistant(): - prompt_lines.append(f"Assistant: {i.content}") - else: - prompt_lines.append(i.content) - - prompt_lines.append("Assistant:") - - return "\n\n".join(prompt_lines) - @abstractmethod def try_run(self, prompt_stack: PromptStack) -> TextArtifact: ... diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 7a2f39cd6d..4b67ae51ee 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -1,11 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from collections.abc import Iterator from attrs import define, field, Factory from griptape.artifacts import TextArtifact from griptape.drivers import BasePromptDriver -from griptape.tokenizers import CohereTokenizer from griptape.utils import PromptStack, import_optional_dependency +from griptape.tokenizers import BaseTokenizer, CohereTokenizer if TYPE_CHECKING: from cohere import Client @@ -18,7 +18,6 @@ class CoherePromptDriver(BasePromptDriver): api_key: Cohere API key. model: Cohere model name. client: Custom `cohere.Client`. - tokenizer: Custom `CohereTokenizer`. 
""" api_key: str = field(kw_only=True, metadata={"serializable": False}) @@ -27,7 +26,7 @@ class CoherePromptDriver(BasePromptDriver): default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True), kw_only=True, ) - tokenizer: CohereTokenizer = field( + tokenizer: BaseTokenizer = field( default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), kw_only=True, ) @@ -44,26 +43,23 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: if event.event_type == "text-generation": yield TextArtifact(value=event.text) + def prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: + if prompt_input.is_system(): + return {"role": "SYSTEM", "text": prompt_input.content} + elif prompt_input.is_user(): + return {"role": "USER", "text": prompt_input.content} + else: + return {"role": "ASSISTANT", "text": prompt_input.content} + def _base_params(self, prompt_stack: PromptStack) -> dict: user_message = prompt_stack.inputs[-1].content - history_messages = [self.__to_cohere_message(input) for input in prompt_stack.inputs[:-1]] + + history_messages = [self.tokenizer.prompt_stack_input_to_message(input) for input in prompt_stack.inputs[:-1]] return { "message": user_message, "chat_history": history_messages, "temperature": self.temperature, "stop_sequences": self.tokenizer.stop_sequences, + "max_tokens": self.max_tokens, } - - def __to_cohere_message(self, input: PromptStack.Input) -> dict[str, Any]: - return {"role": self.__to_cohere_role(input.role), "text": input.content} - - def __to_cohere_role(self, role: str) -> str: - if role == PromptStack.SYSTEM_ROLE: - return "SYSTEM" - if role == PromptStack.USER_ROLE: - return "USER" - elif role == PromptStack.ASSISTANT_ROLE: - return "CHATBOT" - else: - return "USER" diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 9f833c0354..c9e365c486 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -19,7 +19,6 @@ class GooglePromptDriver(BasePromptDriver): api_key: Google API key. model: Google model name. model_client: Custom `GenerativeModel` client. - tokenizer: Custom `GoogleTokenizer`. top_p: Optional value for top_p. top_k: Optional value for top_k. 
""" @@ -42,7 +41,7 @@ def try_run(self, prompt_stack: PromptStack) -> TextArtifact: inputs, generation_config=GenerationConfig( stop_sequences=self.tokenizer.stop_sequences, - max_output_tokens=self.max_output_tokens(inputs), + max_output_tokens=self.max_tokens, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k, @@ -60,7 +59,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: stream=True, generation_config=GenerationConfig( stop_sequences=self.tokenizer.stop_sequences, - max_output_tokens=self.max_output_tokens(inputs), + max_output_tokens=self.max_tokens, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k, @@ -70,6 +69,14 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: for chunk in response: yield TextArtifact(value=chunk.text) + def prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: + parts = [prompt_input.content] + + if prompt_input.is_assistant(): + return {"role": "model", "parts": parts} + else: + return {"role": "user", "parts": parts} + def _default_model_client(self) -> GenerativeModel: genai = import_optional_dependency("google.generativeai") genai.configure(api_key=self.api_key) @@ -90,13 +97,6 @@ def _prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> list[Conten def __to_content_dict(self, prompt_input: PromptStack.Input) -> ContentDict: ContentDict = import_optional_dependency("google.generativeai.types").ContentDict + message = self.tokenizer.prompt_stack_input_to_message(prompt_input) - return ContentDict({"role": self.__to_google_role(prompt_input), "parts": [prompt_input.content]}) - - def __to_google_role(self, prompt_input: PromptStack.Input) -> str: - if prompt_input.is_system(): - return "user" - elif prompt_input.is_assistant(): - return "model" - else: - return "user" + return ContentDict(message) diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index d6b021f299..c74e00609b 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -52,27 +52,20 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): ) def try_run(self, prompt_stack: PromptStack) -> TextArtifact: - prompt = self.__to_prompt(prompt_stack) + prompt = self.tokenizer.prompt_stack_to_string(prompt_stack) response = self.client.text_generation( - prompt, return_full_text=False, max_new_tokens=self.max_output_tokens(prompt), **self.params + prompt, return_full_text=False, max_new_tokens=self.max_tokens, **self.params ) return TextArtifact(value=response) def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: - prompt = self.__to_prompt(prompt_stack) + prompt = self.tokenizer.prompt_stack_to_string(prompt_stack) response = self.client.text_generation( - prompt, return_full_text=False, max_new_tokens=self.max_output_tokens(prompt), stream=True, **self.params + prompt, return_full_text=False, max_new_tokens=self.max_tokens, stream=True, **self.params ) for token in response: yield TextArtifact(value=token) - - def __to_prompt(self, prompt_stack: PromptStack) -> str: - tokens = self.tokenizer.tokenizer.apply_chat_template( - [{"role": i.role, "content": i.content} for i in prompt_stack.inputs], add_generation_prompt=True - ) - - return self.tokenizer.tokenizer.decode(tokens) diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index 
345545b6ff..dea63f552f 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Optional, Any, Literal +from typing import Optional, Literal from collections.abc import Iterator import openai from attrs import define, field, Factory @@ -79,15 +79,6 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: yield TextArtifact(value=delta_content) - def token_count(self, prompt_stack: PromptStack) -> int: - if isinstance(self.tokenizer, OpenAiTokenizer): - return self.tokenizer.count_tokens(self._prompt_stack_to_messages(prompt_stack)) - else: - return self.tokenizer.count_tokens(self.prompt_stack_to_string(prompt_stack)) - - def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict[str, Any]]: - return [{"role": self.__to_openai_role(i), "content": i.content} for i in prompt_stack.inputs] - def _base_params(self, prompt_stack: PromptStack) -> dict: params = { "model": self.model, @@ -102,7 +93,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: # JSON mode still requires a system input instructing the LLM to output JSON. prompt_stack.add_system_input("Provide your response as a valid JSON object.") - messages = self._prompt_stack_to_messages(prompt_stack) + messages = [self.tokenizer.prompt_stack_input_to_message(input) for input in prompt_stack.inputs] if self.max_tokens is not None: params["max_tokens"] = self.max_tokens @@ -110,11 +101,3 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: params["messages"] = messages return params - - def __to_openai_role(self, prompt_input: PromptStack.Input) -> str: - if prompt_input.is_system(): - return "system" - elif prompt_input.is_assistant(): - return "assistant" - else: - return "user" diff --git a/griptape/drivers/prompt_model/base_prompt_model_driver.py b/griptape/drivers/prompt_model/base_prompt_model_driver.py index 096802370d..eb473f2456 100644 --- a/griptape/drivers/prompt_model/base_prompt_model_driver.py +++ b/griptape/drivers/prompt_model/base_prompt_model_driver.py @@ -11,7 +11,6 @@ @define class BasePromptModelDriver(SerializableMixin, ABC): - max_tokens: Optional[int] = field(default=None, kw_only=True) prompt_driver: Optional[BasePromptDriver] = field(default=None, kw_only=True) supports_streaming: bool = field(default=True, kw_only=True) diff --git a/griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py b/griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py index a5a8a4dc95..4ed7f3ddf2 100644 --- a/griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py +++ b/griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py @@ -1,42 +1,45 @@ from __future__ import annotations + from attrs import define, field + from griptape.artifacts import TextArtifact -from griptape.utils import PromptStack, import_optional_dependency from griptape.drivers import BasePromptModelDriver from griptape.tokenizers import HuggingFaceTokenizer +from griptape.tokenizers.base_tokenizer import BaseTokenizer +from griptape.utils import PromptStack, import_optional_dependency @define class SageMakerFalconPromptModelDriver(BasePromptModelDriver): - DEFAULT_MAX_TOKENS = 600 + DEFAULT_MAX_INPUT_TOKENS = 600 _tokenizer: HuggingFaceTokenizer = field(default=None, kw_only=True) @property - def tokenizer(self) -> HuggingFaceTokenizer: - if self._tokenizer is None: + def tokenizer(self) -> BaseTokenizer: + if 
self._tokenizer is None and self.prompt_driver is not None: self._tokenizer = HuggingFaceTokenizer( tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained("tiiuae/falcon-40b"), - max_output_tokens=self.max_tokens or self.DEFAULT_MAX_TOKENS, + max_output_tokens=self.prompt_driver.max_tokens or self.DEFAULT_MAX_INPUT_TOKENS, ) return self._tokenizer def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str: - return self.prompt_driver.prompt_stack_to_string(prompt_stack) + return self.tokenizer.prompt_stack_to_string(prompt_stack) def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict: - prompt = self.prompt_stack_to_model_input(prompt_stack) - stop_sequences = self.prompt_driver.tokenizer.stop_sequences - return { - "max_new_tokens": self.prompt_driver.max_output_tokens(prompt), + "max_new_tokens": self.prompt_driver.max_tokens, "temperature": self.prompt_driver.temperature, "do_sample": True, - "stop": stop_sequences, + "stop": [ + *(self.tokenizer.tokenizer.eos_token if isinstance(self.tokenizer, HuggingFaceTokenizer) else []), + *self.prompt_driver.tokenizer.stop_sequences, + ], } def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact: if isinstance(output, list): return TextArtifact(output[0]["generated_text"].strip()) else: - raise ValueError("output must be an instance of 'list'") + raise ValueError("Invalid output format.") diff --git a/griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py b/griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py index 7e934d4a65..d02e6155a9 100644 --- a/griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py +++ b/griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py @@ -15,28 +15,23 @@ class SageMakerLlamaPromptModelDriver(BasePromptModelDriver): @property def tokenizer(self) -> HuggingFaceTokenizer: - if self._tokenizer is None: + if self._tokenizer is None and self.prompt_driver is not None: self._tokenizer = HuggingFaceTokenizer( tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained( - "meta-llama/Meta-Llama-3-8B-Instruct", model_max_length=self.DEFAULT_MAX_INPUT_TOKENS + "meta-llama/Meta-Llama-3-8B-Instruct" ), - max_output_tokens=self.max_tokens or self.DEFAULT_MAX_INPUT_TOKENS, + max_output_tokens=self.prompt_driver.max_tokens or self.DEFAULT_MAX_INPUT_TOKENS, ) return self._tokenizer def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str: - return self.tokenizer.tokenizer.apply_chat_template( # pyright: ignore - [{"role": i.role, "content": i.content} for i in prompt_stack.inputs], - tokenize=False, - add_generation_prompt=True, - ) + return self.tokenizer.prompt_stack_to_string(prompt_stack) def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict: - prompt = self.prompt_driver.prompt_stack_to_string(prompt_stack) return { - "max_new_tokens": self.prompt_driver.max_output_tokens(prompt), + "max_new_tokens": self.prompt_driver.max_tokens, "temperature": self.prompt_driver.temperature, - "stop": self.tokenizer.tokenizer.eos_token, + "stop": [self.tokenizer.tokenizer.eos_token, *self.prompt_driver.tokenizer.stop_sequences], } def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact: diff --git a/griptape/engines/query/vector_query_engine.py b/griptape/engines/query/vector_query_engine.py index adfa4b2db1..3b9c9ddef9 100644 --- a/griptape/engines/query/vector_query_engine.py +++ 
b/griptape/engines/query/vector_query_engine.py @@ -49,7 +49,7 @@ def query( ) user_message = self.user_template_generator.render(query=query) - message_token_count = self.prompt_driver.token_count( + message_token_count = self.prompt_driver.tokenizer.count_input_tokens_left( PromptStack( inputs=[ PromptStack.Input(system_message, role=PromptStack.SYSTEM_ROLE), diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index 6db05c92c2..d5d32cb0dd 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -45,3 +45,50 @@ def try_add_run(self, run: Run) -> None: ... @abstractmethod def to_prompt_stack(self, last_n: Optional[int] = None) -> PromptStack: ... + + def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = None) -> PromptStack: + """Add the Conversation Memory runs to the Prompt Stack by modifying the inputs in place. + + If autoprune is enabled, this will fit as many Conversation Memory runs into the Prompt Stack + as possible without exceeding the token limit. + + Args: + prompt_stack: The Prompt Stack to add the Conversation Memory runs to. + index: Optional index to insert the Conversation Memory runs at. + Defaults to appending to the end of the Prompt Stack. + """ + num_runs_to_fit_in_prompt = len(self.runs) + + if self.autoprune and hasattr(self, "structure"): + should_prune = True + prompt_driver = self.structure.config.prompt_driver + temp_stack = PromptStack() + + # Try to determine how many Conversation Memory runs we can + # fit into the Prompt Stack without exceeding the token limit. + while should_prune and num_runs_to_fit_in_prompt > 0: + temp_stack.inputs = prompt_stack.inputs.copy() + + # Add n runs from Conversation Memory. + # Where we insert into the Prompt Stack doesn't matter here + # since we only care about the total token count. + memory_inputs = self.to_prompt_stack(num_runs_to_fit_in_prompt).inputs + temp_stack.inputs.extend(memory_inputs) + + # Check how many input tokens are left after adding the memory runs. + tokens_left = prompt_driver.tokenizer.count_input_tokens_left(temp_stack) + if tokens_left > 0: + # There are still tokens left, no need to prune. + should_prune = False + else: + # There were not any tokens left, prune one run and try again. + num_runs_to_fit_in_prompt -= 1 + + if num_runs_to_fit_in_prompt: + memory_inputs = self.to_prompt_stack(num_runs_to_fit_in_prompt).inputs + if index: + prompt_stack.inputs[index:index] = memory_inputs + else: + prompt_stack.inputs.extend(memory_inputs) + + return prompt_stack diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 75051db740..694a5050d8 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -34,7 +34,7 @@ def prompt_stack(self) -> PromptStack: if memory: # inserting at index 1 to place memory right after system prompt - stack.add_conversation_memory(memory, 1) + memory.add_to_prompt_stack(stack, 1) return stack diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index 3a37274343..e88f572912 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -21,6 +21,9 @@ @define class ToolkitTask(PromptTask, ActionsSubtaskOriginMixin): DEFAULT_MAX_STEPS = 20 + # Stop sequence for chain-of-thought in the framework. Using this "token-like" string to make it more unique, + # so that it doesn't trigger by accident. 
+ RESPONSE_STOP_SEQUENCE = "<|Response|>" tools: list[BaseTool] = field(factory=list, kw_only=True) max_subtasks: int = field(default=DEFAULT_MAX_STEPS, kw_only=True) @@ -74,7 +77,7 @@ def prompt_stack(self) -> PromptStack: if memory: # inserting at index 1 to place memory right after system prompt - stack.add_conversation_memory(memory, 1) + memory.add_to_prompt_stack(stack, 1) return stack @@ -95,17 +98,17 @@ def default_system_template_generator(self, _: PromptTask) -> str: action_names=str.join(", ", [tool.name for tool in self.tools]), actions_schema=utils.minify_json(json.dumps(schema)), meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories), - stop_sequence=utils.constants.RESPONSE_STOP_SEQUENCE, + stop_sequence=self.RESPONSE_STOP_SEQUENCE, ) def default_assistant_subtask_template_generator(self, subtask: ActionsSubtask) -> str: return J2("tasks/toolkit_task/assistant_subtask.j2").render( - stop_sequence=utils.constants.RESPONSE_STOP_SEQUENCE, subtask=subtask + stop_sequence=self.RESPONSE_STOP_SEQUENCE, subtask=subtask ) def default_user_subtask_template_generator(self, subtask: ActionsSubtask) -> str: return J2("tasks/toolkit_task/user_subtask.j2").render( - stop_sequence=utils.constants.RESPONSE_STOP_SEQUENCE, subtask=subtask + stop_sequence=self.RESPONSE_STOP_SEQUENCE, subtask=subtask ) def actions_schema(self) -> Schema: @@ -126,6 +129,7 @@ def run(self) -> BaseArtifact: self.subtasks.clear() + self.prompt_driver.tokenizer.stop_sequences.extend([self.RESPONSE_STOP_SEQUENCE]) subtask = self.add_subtask(ActionsSubtask(self.prompt_driver.run(prompt_stack=self.prompt_stack).to_text())) while True: diff --git a/griptape/tokenizers/__init__.py b/griptape/tokenizers/__init__.py index e69473acd4..03b0aefe5b 100644 --- a/griptape/tokenizers/__init__.py +++ b/griptape/tokenizers/__init__.py @@ -7,6 +7,7 @@ from griptape.tokenizers.voyageai_tokenizer import VoyageAiTokenizer from griptape.tokenizers.simple_tokenizer import SimpleTokenizer from griptape.tokenizers.dummy_tokenizer import DummyTokenizer +from griptape.tokenizers.amazon_bedrock_tokenizer import AmazonBedrockTokenizer __all__ = [ @@ -19,4 +20,5 @@ "VoyageAiTokenizer", "SimpleTokenizer", "DummyTokenizer", + "AmazonBedrockTokenizer", ] diff --git a/griptape/tokenizers/amazon_bedrock_tokenizer.py b/griptape/tokenizers/amazon_bedrock_tokenizer.py new file mode 100644 index 0000000000..e3720dea27 --- /dev/null +++ b/griptape/tokenizers/amazon_bedrock_tokenizer.py @@ -0,0 +1,48 @@ +from __future__ import annotations +from attrs import define, field +from typing import TYPE_CHECKING +from griptape.utils import PromptStack +from griptape.tokenizers import SimpleTokenizer + +if TYPE_CHECKING: + pass + + +@define() +class AmazonBedrockTokenizer(SimpleTokenizer): + MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = { + "anthropic.claude-3": 200000, + "anthropic.claude-v2:1": 200000, + "anthropic.claude": 100000, + "cohere.command-r": 128000, + "cohere.embed": 512, + "cohere.command": 4000, + "cohere": 1024, + "ai21": 8192, + "meta-llama3": 8000, + "meta-llama2": 4096, + "mistral": 32000, + "amazon": 4096, + } + MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = { + "anthropic.claude": 4096, + "cohere": 4096, + "ai21.j2": 8191, + "meta": 2048, + "amazon.titan-text-lite": 4096, + "amazon.titan-text-express": 8192, + "amazon.titan-text-premier": 3072, + "mistral": 8192, + } + + characters_per_token: int = field(default=4, kw_only=True) + + def prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: + content = 
[{"text": prompt_input.content}] + + if prompt_input.is_system(): + return {"text": prompt_input.content} + elif prompt_input.is_assistant(): + return {"role": "assistant", "content": content} + else: + return {"role": "user", "content": content} diff --git a/griptape/tokenizers/anthropic_tokenizer.py b/griptape/tokenizers/anthropic_tokenizer.py index 577df7b93a..3cb10bbd40 100644 --- a/griptape/tokenizers/anthropic_tokenizer.py +++ b/griptape/tokenizers/anthropic_tokenizer.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING from griptape.utils import import_optional_dependency from griptape.tokenizers import BaseTokenizer +from griptape.utils import PromptStack if TYPE_CHECKING: from anthropic import Anthropic @@ -17,8 +18,18 @@ class AnthropicTokenizer(BaseTokenizer): default=Factory(lambda: import_optional_dependency("anthropic").Anthropic()), kw_only=True ) - def count_tokens(self, text: str | list) -> int: + def try_count_tokens(self, text: str) -> int: if isinstance(text, str): return self.client.count_tokens(text) else: raise ValueError("Text must be a string.") + + def prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: + content = prompt_input.content + + if prompt_input.is_system(): + return {"role": "system", "content": content} + elif prompt_input.is_assistant(): + return {"role": "assistant", "content": content} + else: + return {"role": "user", "content": content} diff --git a/griptape/tokenizers/base_tokenizer.py b/griptape/tokenizers/base_tokenizer.py index 179d2fb59e..fbffcef365 100644 --- a/griptape/tokenizers/base_tokenizer.py +++ b/griptape/tokenizers/base_tokenizer.py @@ -1,7 +1,8 @@ from __future__ import annotations +from typing import Any from abc import ABC, abstractmethod from attrs import define, field, Factory -from griptape import utils +from griptape.utils import PromptStack @define() @@ -10,7 +11,7 @@ class BaseTokenizer(ABC): MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {} model: str = field(kw_only=True) - stop_sequences: list[str] = field(default=Factory(lambda: [utils.constants.RESPONSE_STOP_SEQUENCE]), kw_only=True) + stop_sequences: list[str] = field(default=Factory(list), kw_only=True) max_input_tokens: int = field(kw_only=True, default=None) max_output_tokens: int = field(kw_only=True, default=None) @@ -21,7 +22,7 @@ def __attrs_post_init__(self) -> None: if self.max_output_tokens is None: self.max_output_tokens = self._default_max_output_tokens() - def count_input_tokens_left(self, text: str | list) -> int: + def count_input_tokens_left(self, text: str | PromptStack) -> int: diff = self.max_input_tokens - self.count_tokens(text) if diff > 0: @@ -29,7 +30,7 @@ def count_input_tokens_left(self, text: str | list) -> int: else: return 0 - def count_output_tokens_left(self, text: str | list) -> int: + def count_output_tokens_left(self, text: str | PromptStack) -> int: diff = self.max_output_tokens - self.count_tokens(text) if diff > 0: @@ -37,8 +38,56 @@ def count_output_tokens_left(self, text: str | list) -> int: else: return 0 + def count_tokens(self, text: str | PromptStack) -> int: + if isinstance(text, PromptStack): + return self.try_count_tokens(self.prompt_stack_to_string(text)) + else: + return self.try_count_tokens(text) + @abstractmethod - def count_tokens(self, text: str | list[dict]) -> int: ... + def try_count_tokens(self, text: Any) -> int: ... + + def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: + """Converts a Prompt Stack to a string for token counting or model input. 
+ This base implementation will not be very accurate, and should be overridden by subclasses with model-specific tokens. + + Args: + prompt_stack: The Prompt Stack to convert to a string. + + Returns: + A single string representation of the Prompt Stack. + """ + prompt_lines = [] + + for i in prompt_stack.inputs: + if i.is_user(): + prompt_lines.append(f"User: {i.content}") + elif i.is_assistant(): + prompt_lines.append(f"Assistant: {i.content}") + else: + prompt_lines.append(i.content) + + prompt_lines.append("Assistant:") + + return "\n\n".join(prompt_lines) + + def prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: + """Converts a PromptStack Input to a ChatML-style message dictionary for token counting or model input. + + Args: + prompt_input: The PromptStack Input to convert. + + Returns: + A dictionary with the role and content of the input. + """ + content = prompt_input.content + + if prompt_input.is_system(): + return {"role": "system", "content": content} + elif prompt_input.is_assistant(): + return {"role": "assistant", "content": content} + else: + return {"role": "user", "content": content} def _default_max_input_tokens(self) -> int: tokens = next((v for k, v in self.MODEL_PREFIXES_TO_MAX_INPUT_TOKENS.items() if self.model.startswith(k)), None) diff --git a/griptape/tokenizers/cohere_tokenizer.py b/griptape/tokenizers/cohere_tokenizer.py index e6845d9ca5..2c53415108 100644 --- a/griptape/tokenizers/cohere_tokenizer.py +++ b/griptape/tokenizers/cohere_tokenizer.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from attrs import define, field from griptape.tokenizers import BaseTokenizer +from griptape.utils import PromptStack if TYPE_CHECKING: from cohere import Client @@ -14,8 +15,13 @@ class CohereTokenizer(BaseTokenizer): client: Client = field(kw_only=True) - def count_tokens(self, text: str | list) -> int: - if isinstance(text, str): - return len(self.client.tokenize(text=text, model=self.model).tokens) + def try_count_tokens(self, text: str) -> int: + return len(self.client.tokenize(text=text, model=self.model).tokens) + + def prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: + if prompt_input.is_system(): + return {"role": "SYSTEM", "text": prompt_input.content} + elif prompt_input.is_user(): + return {"role": "USER", "text": prompt_input.content} else: - raise ValueError("Text must be a string.") + return {"role": "ASSISTANT", "text": prompt_input.content} diff --git a/griptape/tokenizers/dummy_tokenizer.py b/griptape/tokenizers/dummy_tokenizer.py index 74f6d104ce..8bab32f065 100644 --- a/griptape/tokenizers/dummy_tokenizer.py +++ b/griptape/tokenizers/dummy_tokenizer.py @@ -10,5 +10,5 @@ class DummyTokenizer(BaseTokenizer): max_input_tokens: int = field(init=False, default=0, kw_only=True) max_output_tokens: int = field(init=False, default=0, kw_only=True) - def count_tokens(self, text: str | list) -> int: + def try_count_tokens(self, text: str) -> int: raise DummyException(__class__.__name__, "count_tokens") diff --git a/griptape/tokenizers/google_tokenizer.py b/griptape/tokenizers/google_tokenizer.py index 55942f5977..cb6acd7086 100644 --- a/griptape/tokenizers/google_tokenizer.py +++ b/griptape/tokenizers/google_tokenizer.py @@ -3,9 +3,11 @@ from typing import TYPE_CHECKING from griptape.utils import import_optional_dependency from griptape.tokenizers import BaseTokenizer +from griptape.utils import PromptStack if TYPE_CHECKING: from google.generativeai import GenerativeModel + from google.generativeai.types 
import ContentDict @define() @@ -18,14 +20,32 @@ class GoogleTokenizer(BaseTokenizer): default=Factory(lambda self: self._default_model_client(), takes_self=True), kw_only=True ) - def count_tokens(self, text: str | list) -> int: - if isinstance(text, str) or isinstance(text, list): - return self.model_client.count_tokens(text).total_tokens + def count_tokens(self, text: str | PromptStack) -> int: + ContentDict = import_optional_dependency("google.generativeai.types").ContentDict + + if isinstance(text, PromptStack): + messages = [ContentDict(self.prompt_stack_input_to_message(input)) for input in text.inputs] + + return self.try_count_tokens(messages) else: - raise ValueError("Text must be a string or a list.") + return self.try_count_tokens( + ContentDict(self.prompt_stack_input_to_message(PromptStack.Input(content=text, role="user"))) + ) + + def try_count_tokens(self, text: ContentDict | list[ContentDict]) -> int: + print(text) + return self.model_client.count_tokens(text).total_tokens def _default_model_client(self) -> GenerativeModel: genai = import_optional_dependency("google.generativeai") genai.configure(api_key=self.api_key) return genai.GenerativeModel(self.model) + + def prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: + parts = [prompt_input.content] + + if prompt_input.is_assistant(): + return {"role": "model", "parts": parts} + else: + return {"role": "user", "parts": parts} diff --git a/griptape/tokenizers/huggingface_tokenizer.py b/griptape/tokenizers/huggingface_tokenizer.py index 84023a378e..663bd3cc0d 100644 --- a/griptape/tokenizers/huggingface_tokenizer.py +++ b/griptape/tokenizers/huggingface_tokenizer.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING from attrs import define, field, Factory +from griptape.utils import PromptStack from griptape.tokenizers import BaseTokenizer if TYPE_CHECKING: @@ -16,5 +17,31 @@ class HuggingFaceTokenizer(BaseTokenizer): ) max_output_tokens: int = field(kw_only=True) # pyright: ignore[reportGeneralTypeIssues] - def count_tokens(self, text: str | list) -> int: + def count_tokens(self, text: str | PromptStack) -> int: + if isinstance(text, PromptStack): + tokens = self.__prompt_stack_to_tokens(text) + + return len(tokens) + else: + return self.try_count_tokens(text) + + def try_count_tokens(self, text: str) -> int: return len(self.tokenizer.encode(text)) + + def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: + return self.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) + + def prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: + return {"role": prompt_input.role, "content": prompt_input.content} + + def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]: + tokens = self.tokenizer.apply_chat_template( + [self.prompt_stack_input_to_message(i) for i in prompt_stack.inputs], + add_generation_prompt=True, + tokenize=True, + ) + + if isinstance(tokens, list): + return tokens + else: + raise ValueError("Invalid output type.") diff --git a/griptape/tokenizers/openai_tokenizer.py b/griptape/tokenizers/openai_tokenizer.py index ec127ca1a0..f4c6b4d51a 100644 --- a/griptape/tokenizers/openai_tokenizer.py +++ b/griptape/tokenizers/openai_tokenizer.py @@ -2,8 +2,9 @@ import logging from attrs import define import tiktoken -from griptape.tokenizers import BaseTokenizer from typing import Optional +from griptape.tokenizers import BaseTokenizer +from griptape.utils import PromptStack @define() @@ -64,10 +65,24 
@@ def _default_max_output_tokens(self) -> int: else: return tokens - def count_tokens(self, text: str | list[dict], model: Optional[str] = None) -> int: - """ - Handles the special case of ChatML. Implementation adopted from the official OpenAI notebook: + def count_tokens(self, text: str | PromptStack) -> int: + if isinstance(text, PromptStack): + messages = [self.prompt_stack_input_to_message(i) for i in text.inputs] + + return self.__count_tokens_messages(messages, self.model) + else: + return self.try_count_tokens(text) + + def try_count_tokens(self, text: str) -> int: + return len(self.encoding.encode(text, allowed_special=set(self.stop_sequences))) + + def __count_tokens_messages(self, text: str | list[dict], model: Optional[str] = None) -> int: + """Handles the special case of ChatML. Implementation adopted from the official OpenAI notebook: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + + Args: + text: A single message or a list of messages. + model: The model to use. Defaults to None. """ if isinstance(text, list): model = model if model else self.model @@ -97,16 +112,16 @@ def count_tokens(self, text: str | list[dict], model: Optional[str] = None) -> i tokens_per_name = -1 elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model: logging.info("gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.") - return self.count_tokens(text, model="gpt-3.5-turbo-0613") + return self.__count_tokens_messages(text, model="gpt-3.5-turbo-0613") elif "gpt-4o" in model: logging.info("gpt-4o may update over time. Returning num tokens assuming gpt-4o-2024-05-13.") - return self.count_tokens(text, model="gpt-4o-2024-05-13") + return self.__count_tokens_messages(text, model="gpt-4o-2024-05-13") elif "gpt-4" in model: logging.info("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.") - return self.count_tokens(text, model="gpt-4-0613") + return self.__count_tokens_messages(text, model="gpt-4-0613") else: raise NotImplementedError( - f"""token_count() is not implemented for model {model}. + f"""count_tokens() is not implemented for model {model}. 
See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""" ) diff --git a/griptape/tokenizers/simple_tokenizer.py b/griptape/tokenizers/simple_tokenizer.py index 864b92f962..06deefbb6f 100644 --- a/griptape/tokenizers/simple_tokenizer.py +++ b/griptape/tokenizers/simple_tokenizer.py @@ -11,10 +11,7 @@ class SimpleTokenizer(BaseTokenizer): max_output_tokens: int = field(kw_only=True, default=0) characters_per_token: int = field(kw_only=True) - def count_tokens(self, text: str | list) -> int: - if isinstance(text, str): - num_tokens = (len(text) + self.characters_per_token - 1) // self.characters_per_token + def try_count_tokens(self, text: str) -> int: + num_tokens = (len(text) + self.characters_per_token - 1) // self.characters_per_token - return num_tokens - else: - raise ValueError("Text must be a string.") + return num_tokens diff --git a/griptape/tokenizers/voyageai_tokenizer.py b/griptape/tokenizers/voyageai_tokenizer.py index 565e53faac..922e6e6862 100644 --- a/griptape/tokenizers/voyageai_tokenizer.py +++ b/griptape/tokenizers/voyageai_tokenizer.py @@ -26,8 +26,5 @@ class VoyageAiTokenizer(BaseTokenizer): kw_only=True, ) - def count_tokens(self, text: str | list) -> int: - if isinstance(text, str): - return self.client.count_tokens([text]) - else: - raise ValueError("Text must be a str.") + def try_count_tokens(self, text: str) -> int: + return self.client.count_tokens([text]) diff --git a/griptape/utils/__init__.py b/griptape/utils/__init__.py index 64ca9a9f78..2c4eb431d6 100644 --- a/griptape/utils/__init__.py +++ b/griptape/utils/__init__.py @@ -14,7 +14,6 @@ from .import_utils import import_optional_dependency from .import_utils import is_dependency_installed from .stream import Stream -from .constants import Constants as constants from .load_artifact_from_memory import load_artifact_from_memory from .deprecation import deprecation_warn @@ -40,7 +39,6 @@ def minify_json(value: str) -> str: "remove_null_values_in_dict_recursively", "dict_merge", "Stream", - "constants", "load_artifact_from_memory", "deprecation_warn", "load_file", diff --git a/griptape/utils/constants.py b/griptape/utils/constants.py deleted file mode 100644 index 7bee76750b..0000000000 --- a/griptape/utils/constants.py +++ /dev/null @@ -1,4 +0,0 @@ -class Constants: - # Stop sequence for chain-of-thought in the framework. Using this "token-like" string to make it more unique, - # so that it doesn't trigger on accident. 
- RESPONSE_STOP_SEQUENCE = "<|Response|>" diff --git a/griptape/utils/prompt_stack.py b/griptape/utils/prompt_stack.py index f04cef4863..fcb7e74d22 100644 --- a/griptape/utils/prompt_stack.py +++ b/griptape/utils/prompt_stack.py @@ -5,7 +5,7 @@ from griptape.mixins import SerializableMixin if TYPE_CHECKING: - from griptape.memory.structure import BaseConversationMemory + from griptape.tokenizers import BaseTokenizer @define @@ -15,6 +15,8 @@ class PromptStack(SerializableMixin): ASSISTANT_ROLE = "assistant" SYSTEM_ROLE = "system" + tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True) # TODO: This should not belong here + @define class Input(SerializableMixin): content: str = field(metadata={"serializable": True}) @@ -50,50 +52,3 @@ def add_user_input(self, content: str) -> Input: def add_assistant_input(self, content: str) -> Input: return self.add_input(content, self.ASSISTANT_ROLE) - - def add_conversation_memory(self, memory: BaseConversationMemory, index: Optional[int] = None) -> list[Input]: - """Add the Conversation Memory runs to the Prompt Stack. - - If autoprune is enabled, this will fit as many Conversation Memory runs into the Prompt Stack - as possible without exceeding the token limit. - - Args: - memory: The Conversation Memory to add the Prompt Stack to. - index: Optional index to insert the Conversation Memory runs at. - Defaults to appending to the end of the Prompt Stack. - """ - num_runs_to_fit_in_prompt = len(memory.runs) - - if memory.autoprune and hasattr(memory, "structure"): - should_prune = True - prompt_driver = memory.structure.config.prompt_driver - temp_stack = PromptStack() - - # Try to determine how many Conversation Memory runs we can - # fit into the Prompt Stack without exceeding the token limit. - while should_prune and num_runs_to_fit_in_prompt > 0: - temp_stack.inputs = self.inputs.copy() - - # Add n runs from Conversation Memory. - # Where we insert into the Prompt Stack doesn't matter here - # since we only care about the total token count. - memory_inputs = memory.to_prompt_stack(num_runs_to_fit_in_prompt).inputs - temp_stack.inputs.extend(memory_inputs) - - # Convert the prompt stack into tokens left. - prompt_string = prompt_driver.prompt_stack_to_string(temp_stack) - tokens_left = prompt_driver.tokenizer.count_input_tokens_left(prompt_string) - if tokens_left > 0: - # There are still tokens left, no need to prune. - should_prune = False - else: - # There were not any tokens left, prune one run and try again. 
- num_runs_to_fit_in_prompt -= 1 - - if num_runs_to_fit_in_prompt: - memory_inputs = memory.to_prompt_stack(num_runs_to_fit_in_prompt).inputs - if index: - self.inputs[index:index] = memory_inputs - else: - self.inputs.extend(memory_inputs) - return self.inputs diff --git a/tests/mocks/mock_tokenizer.py b/tests/mocks/mock_tokenizer.py index a333f9a13b..09b4711d61 100644 --- a/tests/mocks/mock_tokenizer.py +++ b/tests/mocks/mock_tokenizer.py @@ -9,5 +9,5 @@ class MockTokenizer(BaseTokenizer): max_input_tokens: int = field(default=1000, kw_only=True) max_output_tokens: int = field(default=1000, kw_only=True) - def count_tokens(self, text: str | list[dict]) -> int: + def try_count_tokens(self, text: str) -> int: return len(text) diff --git a/tests/unit/config/test_anthropic_structure_config.py b/tests/unit/config/test_anthropic_structure_config.py index 8279fb091c..1dd83f96c8 100644 --- a/tests/unit/config/test_anthropic_structure_config.py +++ b/tests/unit/config/test_anthropic_structure_config.py @@ -18,7 +18,7 @@ def test_to_dict(self, config): "prompt_driver": { "type": "AnthropicPromptDriver", "temperature": 0.1, - "max_tokens": None, + "max_tokens": 1000, "stream": False, "model": "claude-3-opus-20240229", "top_p": 0.999, diff --git a/tests/unit/drivers/prompt/test_amazon_sagemaker_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_sagemaker_prompt_driver.py index c6692e1ba1..d1b621750a 100644 --- a/tests/unit/drivers/prompt/test_amazon_sagemaker_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_sagemaker_prompt_driver.py @@ -1,7 +1,6 @@ from botocore.response import StreamingBody from griptape.artifacts import TextArtifact from griptape.drivers import AmazonSageMakerPromptDriver, SageMakerFalconPromptModelDriver -from griptape.tokenizers import HuggingFaceTokenizer, OpenAiTokenizer from griptape.utils import PromptStack from io import BytesIO from unittest.mock import Mock @@ -25,23 +24,6 @@ def mock_client(self, mocker): def test_init(self): assert AmazonSageMakerPromptDriver(endpoint="foo", prompt_model_driver=SageMakerFalconPromptModelDriver()) - def test_custom_tokenizer(self): - assert isinstance( - AmazonSageMakerPromptDriver( - endpoint="foo", prompt_model_driver=SageMakerFalconPromptModelDriver() - ).tokenizer, - HuggingFaceTokenizer, - ) - - assert isinstance( - AmazonSageMakerPromptDriver( - endpoint="foo", - tokenizer=OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL), - prompt_model_driver=SageMakerFalconPromptModelDriver(), - ).tokenizer, - OpenAiTokenizer, - ) - def test_try_run(self, mock_model_driver, mock_client): # Given driver = AmazonSageMakerPromptDriver(endpoint="model", prompt_model_driver=mock_model_driver) diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index c5009afac1..22178bbf39 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -63,9 +63,9 @@ def test_try_run(self, mock_client, model, system_enabled): # Then mock_client.return_value.messages.create.assert_called_once_with( messages=expected_messages, - stop_sequences=["<|Response|>"], + stop_sequences=[], model=driver.model, - max_tokens=4091, + max_tokens=1000, temperature=0.1, top_p=0.999, top_k=250, @@ -106,9 +106,9 @@ def test_try_stream_run(self, mock_stream_client, model, system_enabled): # Then mock_stream_client.return_value.messages.create.assert_called_once_with( messages=expected_messages, - 
stop_sequences=["<|Response|>"], + stop_sequences=[], model=driver.model, - max_tokens=4091, + max_tokens=1000, temperature=0.1, stream=True, top_p=0.999, diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 1a06b5907b..0743402aa6 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -39,35 +39,6 @@ def test_run_via_pipeline_publishes_events(self, mocker): def test_run(self): assert isinstance(MockPromptDriver().run(PromptStack(inputs=[])), TextArtifact) - def test_token_count(self): - assert ( - MockPromptDriver().token_count( - PromptStack(inputs=[PromptStack.Input("foobar", role=PromptStack.USER_ROLE)]) - ) - == 24 - ) - - def test_max_output_tokens(self): - assert MockPromptDriver().max_output_tokens("foobar") == 4090 - assert MockPromptDriver(max_tokens=5000).max_output_tokens("foobar") == 4090 - assert MockPromptDriver(max_tokens=100).max_output_tokens("foobar") == 100 - - def test_prompt_stack_to_string(self): - assert ( - MockPromptDriver().prompt_stack_to_string( - PromptStack(inputs=[PromptStack.Input("foobar", role=PromptStack.USER_ROLE)]) - ) - == "User: foobar\n\nAssistant:" - ) - - def test_custom_prompt_stack_to_string(self): - assert ( - MockPromptDriver( - prompt_stack_to_string=lambda stack: f"Foo: {stack.inputs[0].content}" - ).prompt_stack_to_string(PromptStack(inputs=[PromptStack.Input("foobar", role=PromptStack.USER_ROLE)])) - == "Foo: foobar" - ) - def instance_count(instances, clazz): return len([instance for instance in instances if isinstance(instance, clazz)]) diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index 6e38bd503d..f655d3e516 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -2,7 +2,6 @@ from griptape.drivers import GooglePromptDriver from griptape.utils import PromptStack from unittest.mock import Mock -from tests.mocks.mock_tokenizer import MockTokenizer import pytest @@ -32,9 +31,7 @@ def test_try_run(self, mock_generative_model): prompt_stack.add_user_input("user-input") prompt_stack.add_assistant_input("assistant-input") prompt_stack.add_generic_input("generic-input") - driver = GooglePromptDriver( - model="gemini-pro", api_key="api-key", tokenizer=MockTokenizer(model="gemini-pro"), top_p=0.5, top_k=50 - ) + driver = GooglePromptDriver(model="gemini-pro", api_key="api-key", top_p=0.5, top_k=50) # When text_artifact = driver.try_run(prompt_stack) @@ -47,7 +44,7 @@ def test_try_run(self, mock_generative_model): {"parts": ["generic-input"], "role": "user"}, ], generation_config=GenerationConfig( - max_output_tokens=997, temperature=0.1, top_p=0.5, top_k=50, stop_sequences=["<|Response|>"] + max_output_tokens=None, temperature=0.1, top_p=0.5, top_k=50, stop_sequences=[] ), ) assert text_artifact.value == "model-output" @@ -59,14 +56,7 @@ def test_try_stream(self, mock_stream_generative_model): prompt_stack.add_user_input("user-input") prompt_stack.add_assistant_input("assistant-input") prompt_stack.add_generic_input("generic-input") - driver = GooglePromptDriver( - model="gemini-pro", - api_key="api-key", - stream=True, - tokenizer=MockTokenizer(model="gemini-pro"), - top_p=0.5, - top_k=50, - ) + driver = GooglePromptDriver(model="gemini-pro", api_key="api-key", stream=True, top_p=0.5, top_k=50) # When text_artifact_stream = 
driver.try_stream(prompt_stack) @@ -80,9 +70,7 @@ def test_try_stream(self, mock_stream_generative_model): {"parts": ["generic-input"], "role": "user"}, ], stream=True, - generation_config=GenerationConfig( - max_output_tokens=997, temperature=0.1, top_p=0.5, top_k=50, stop_sequences=["<|Response|>"] - ), + generation_config=GenerationConfig(temperature=0.1, top_p=0.5, top_k=50, stop_sequences=[]), ) assert text_artifact.value == "model-output" @@ -108,26 +96,3 @@ def test_prompt_stack_to_model_input(self): {"role": "model", "parts": ["assistant-input"]}, {"role": "user", "parts": ["user-input"]}, ] - - def test_to_content_dict(self): - # Given - driver = GooglePromptDriver(model="gemini-pro", api_key="1234") - - # When - assert driver._GooglePromptDriver__to_content_dict(PromptStack.Input("system-input", "system")) == { - "role": "user", - "parts": ["system-input"], - } - assert driver._GooglePromptDriver__to_content_dict(PromptStack.Input("user-input", "user")) == { - "role": "user", - "parts": ["user-input"], - } - assert driver._GooglePromptDriver__to_content_dict(PromptStack.Input("assistant-input", "assistant")) == { - "role": "model", - "parts": ["assistant-input"], - } - - assert driver._GooglePromptDriver__to_content_dict(PromptStack.Input("generic-input", "generic")) == { - "role": "user", - "parts": ["generic-input"], - } diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index 6b91b56cbe..15bbb4ead0 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -10,6 +10,13 @@ def mock_client(self, mocker): mock_client.text_generation.return_value = "model-output" return mock_client + @pytest.fixture(autouse=True) + def tokenizer(self, mocker): + from_pretrained = tokenizer = mocker.patch("transformers.AutoTokenizer").from_pretrained + from_pretrained.return_value.apply_chat_template.return_value = [1, 2, 3] + + return tokenizer + @pytest.fixture def mock_client_stream(self, mocker): mock_client = mocker.patch("huggingface_hub.InferenceClient").return_value diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index 1f8f07cf5e..05188d7caf 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -161,30 +161,6 @@ def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack ) assert text_artifact.value == "model-output" - def test_try_run_max_tokens_limited_by_tokenizer(self, mock_chat_completion_create, prompt_stack, messages): - # Given - max_tokens_request = 9999999 - driver = OpenAiChatPromptDriver( - model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, max_tokens=max_tokens_request - ) - tokens_left = driver.tokenizer.count_input_tokens_left(driver._prompt_stack_to_messages(prompt_stack)) - - # When - text_artifact = driver.try_run(prompt_stack) - - # Then - mock_chat_completion_create.assert_called_once_with( - model=driver.model, - temperature=driver.temperature, - stop=driver.tokenizer.stop_sequences, - user=driver.user, - messages=messages, - max_tokens=max_tokens_request, - seed=driver.seed, - ) - assert max_tokens_request > tokens_left - assert text_artifact.value == "model-output" - def test_try_run_throws_when_prompt_stack_is_string(self): # Given driver = 
OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) @@ -209,53 +185,6 @@ def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_chat_ # Then e.value.args[0] == "Completion with more than one choice is not supported yet." - def test_token_count(self, prompt_stack, messages): - # Given - mock_tokenizer = Mock(spec=OpenAiTokenizer) - mock_tokenizer.count_tokens.return_value = 42 - driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, tokenizer=mock_tokenizer) - - # When - token_count = driver.token_count(prompt_stack) - - # Then - mock_tokenizer.count_tokens.assert_called_once_with(messages) - assert token_count == 42 - - # Given - mock_tokenizer = Mock() - mock_tokenizer.count_tokens.return_value = 42 - driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, tokenizer=mock_tokenizer) - - # When - token_count = driver.token_count(prompt_stack) - - # Then - mock_tokenizer.count_tokens.assert_called_once_with(driver.prompt_stack_to_string(prompt_stack)) - assert token_count == 42 - - def test_max_output_tokens(self, messages): - # Given - mock_tokenizer = Mock() - mock_tokenizer.count_output_tokens_left.return_value = 42 - driver = OpenAiChatPromptDriver( - model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, tokenizer=mock_tokenizer, max_tokens=45 - ) - - # When - max_output_tokens = driver.max_output_tokens(messages) - - # Then - mock_tokenizer.count_output_tokens_left.assert_called_once_with(messages) - assert max_output_tokens == 42 - - def test_max_output_tokens_with_max_tokens(self, messages): - max_tokens = OpenAiChatPromptDriver( - model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, max_tokens=42 - ).max_output_tokens(messages) - - assert max_tokens == 42 - def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messages): driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, @@ -272,7 +201,12 @@ def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messa temperature=driver.temperature, stop=driver.tokenizer.stop_sequences, user=driver.user, - messages=messages, + messages=[ + {"role": "generic", "content": "generic-input"}, + {"role": "system", "content": "system-input"}, + {"role": "user", "content": "user-input"}, + {"role": "assistant", "content": "assistant-input"}, + ], seed=driver.seed, max_tokens=1, ) diff --git a/tests/unit/drivers/prompt_models/test_sagemaker_falcon_prompt_model_driver.py b/tests/unit/drivers/prompt_models/test_sagemaker_falcon_prompt_model_driver.py index 78d9902299..2dad689ff2 100644 --- a/tests/unit/drivers/prompt_models/test_sagemaker_falcon_prompt_model_driver.py +++ b/tests/unit/drivers/prompt_models/test_sagemaker_falcon_prompt_model_driver.py @@ -5,6 +5,15 @@ class TestSageMakerFalconPromptModelDriver: + @pytest.fixture(autouse=True) + def tokenizer(self, mocker): + from_pretrained = tokenizer = mocker.patch("transformers.AutoTokenizer").from_pretrained + from_pretrained.return_value.apply_chat_template.return_value = [1, 2, 3] + from_pretrained.return_value.decode.return_value = "foo\n\nUser: bar" + from_pretrained.return_value.model_max_length = 8000 + + return tokenizer + @pytest.fixture def driver(self): return AmazonSageMakerPromptDriver( @@ -12,6 +21,7 @@ def driver(self): session=boto3.Session(region_name="us-east-1"), prompt_model_driver=SageMakerFalconPromptModelDriver(), temperature=0.12345, + max_tokens=590, ).prompt_model_driver @pytest.fixture @@ -40,4 
+50,4 @@ def test_process_output(self, driver, stack): assert driver.process_output([{"generated_text": "foobar"}]).value == "foobar" def test_tokenizer_max_model_length(self, driver): - assert driver.tokenizer.tokenizer.model_max_length == 2048 + assert driver.tokenizer.tokenizer.model_max_length == 8000 diff --git a/tests/unit/drivers/prompt_models/test_sagemaker_llama_prompt_model_driver.py b/tests/unit/drivers/prompt_models/test_sagemaker_llama_prompt_model_driver.py index b39ce458e0..986d106765 100644 --- a/tests/unit/drivers/prompt_models/test_sagemaker_llama_prompt_model_driver.py +++ b/tests/unit/drivers/prompt_models/test_sagemaker_llama_prompt_model_driver.py @@ -6,21 +6,14 @@ class TestSageMakerLlamaPromptModelDriver: @pytest.fixture(autouse=True) - def llama3_instruct_tokenizer(self, mocker): - tokenizer = mocker.patch("transformers.AutoTokenizer").return_value - tokenizer.model_max_length = 8000 + def tokenizer(self, mocker): + from_pretrained = tokenizer = mocker.patch("transformers.AutoTokenizer").from_pretrained + from_pretrained.return_value.apply_chat_template.return_value = [1, 2, 3] + from_pretrained.return_value.decode.return_value = "model-output" + from_pretrained.return_value.model_max_length = 8000 return tokenizer - @pytest.fixture(autouse=True) - def hugging_face_tokenizer(self, mocker, llama3_instruct_tokenizer): - tokenizer = mocker.patch( - "griptape.drivers.prompt_model.sagemaker_llama_prompt_model_driver.HuggingFaceTokenizer" - ).return_value - tokenizer.count_output_tokens_left.return_value = 7991 - tokenizer.tokenizer = llama3_instruct_tokenizer - return tokenizer - @pytest.fixture def driver(self): return AmazonSageMakerPromptDriver( @@ -29,6 +22,7 @@ def driver(self): session=boto3.Session(region_name="us-east-1"), prompt_model_driver=SageMakerLlamaPromptModelDriver(), temperature=0.12345, + max_tokens=7991, ).prompt_model_driver @pytest.fixture @@ -43,15 +37,17 @@ def stack(self): def test_init(self, driver): assert driver.prompt_driver is not None - def test_prompt_stack_to_model_input(self, driver, stack, hugging_face_tokenizer): - driver.prompt_stack_to_model_input(stack) + def test_prompt_stack_to_model_input(self, driver, stack): + result = driver.prompt_stack_to_model_input(stack) - hugging_face_tokenizer.tokenizer.apply_chat_template.assert_called_once_with( + driver.tokenizer.tokenizer.apply_chat_template.assert_called_once_with( [{"role": "system", "content": "foo"}, {"role": "user", "content": "bar"}], - tokenize=False, + tokenize=True, add_generation_prompt=True, ) + assert result == "model-output" + def test_prompt_stack_to_model_params(self, driver, stack): assert driver.prompt_stack_to_model_params(stack)["max_new_tokens"] == 7991 assert driver.prompt_stack_to_model_params(stack)["temperature"] == 0.12345 diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 665dca9b29..c6eba3df3d 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -1,7 +1,10 @@ import json +from griptape.structures import Agent +from griptape.utils import PromptStack from griptape.memory.structure import ConversationMemory, Run, BaseConversationMemory from griptape.structures import Pipeline from tests.mocks.mock_prompt_driver import MockPromptDriver +from tests.mocks.mock_tokenizer import MockTokenizer from griptape.tasks import PromptTask @@ -69,3 +72,97 @@ def test_buffering(self): assert 
len(pipeline.conversation_memory.runs) == 2 assert pipeline.conversation_memory.runs[0].input == "run4" assert pipeline.conversation_memory.runs[1].input == "run5" + + def test_add_to_prompt_stack_autopruning_disabled(self): + agent = Agent(prompt_driver=MockPromptDriver()) + memory = ConversationMemory( + autoprune=False, + runs=[ + Run(input="foo1", output="bar1"), + Run(input="foo2", output="bar2"), + Run(input="foo3", output="bar3"), + Run(input="foo4", output="bar4"), + Run(input="foo5", output="bar5"), + ], + ) + memory.structure = agent + prompt_stack = PromptStack() + prompt_stack.add_user_input("foo") + prompt_stack.add_assistant_input("bar") + memory.add_to_prompt_stack(prompt_stack) + + assert len(prompt_stack.inputs) == 12 + + def test_add_to_prompt_stack_autopruning_enabled(self): + # All memory is pruned. + agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0))) + memory = ConversationMemory( + autoprune=True, + runs=[ + Run(input="foo1", output="bar1"), + Run(input="foo2", output="bar2"), + Run(input="foo3", output="bar3"), + Run(input="foo4", output="bar4"), + Run(input="foo5", output="bar5"), + ], + ) + memory.structure = agent + prompt_stack = PromptStack() + prompt_stack.add_system_input("fizz") + prompt_stack.add_user_input("foo") + prompt_stack.add_assistant_input("bar") + memory.add_to_prompt_stack(prompt_stack) + + assert len(prompt_stack.inputs) == 3 + + # No memory is pruned. + agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=1000))) + memory = ConversationMemory( + autoprune=True, + runs=[ + Run(input="foo1", output="bar1"), + Run(input="foo2", output="bar2"), + Run(input="foo3", output="bar3"), + Run(input="foo4", output="bar4"), + Run(input="foo5", output="bar5"), + ], + ) + memory.structure = agent + prompt_stack = PromptStack() + prompt_stack.add_system_input("fizz") + prompt_stack.add_user_input("foo") + prompt_stack.add_assistant_input("bar") + memory.add_to_prompt_stack(prompt_stack) + + assert len(prompt_stack.inputs) == 13 + + # One memory is pruned. + # MockTokenizer's max_input_tokens set to one below the sum of memory + system prompt tokens + # so that a single memory is pruned. + agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160))) + memory = ConversationMemory( + autoprune=True, + runs=[ + # All of these sum to 155 tokens with the MockTokenizer. + Run(input="foo1", output="bar1"), + Run(input="foo2", output="bar2"), + Run(input="foo3", output="bar3"), + Run(input="foo4", output="bar4"), + Run(input="foo5", output="bar5"), + ], + ) + memory.structure = agent + prompt_stack = PromptStack() + # And then another 6 tokens from fizz for a total of 161 tokens. + prompt_stack.add_system_input("fizz") + prompt_stack.add_user_input("foo") + prompt_stack.add_assistant_input("bar") + memory.add_to_prompt_stack(prompt_stack, 1) + + # We expect one run (2 prompt stack inputs) to be pruned.
+ assert len(prompt_stack.inputs) == 11 + assert prompt_stack.inputs[0].content == "fizz" + assert prompt_stack.inputs[1].content == "foo2" + assert prompt_stack.inputs[2].content == "bar2" + assert prompt_stack.inputs[-2].content == "foo" + assert prompt_stack.inputs[-1].content == "bar" diff --git a/tests/unit/tokenizers/test_openai_tokenizer.py b/tests/unit/tokenizers/test_openai_tokenizer.py index b27080aa30..034ee68fe9 100644 --- a/tests/unit/tokenizers/test_openai_tokenizer.py +++ b/tests/unit/tokenizers/test_openai_tokenizer.py @@ -1,5 +1,6 @@ import pytest from griptape.tokenizers import OpenAiTokenizer +from griptape.utils.prompt_stack import PromptStack class TestOpenAiTokenizer: @@ -46,10 +47,15 @@ def test_initialize_with_unknown_model(self): ], indirect=["tokenizer"], ) - def test_token_count_for_messages(self, tokenizer, expected): + def test_token_count_for_prompt_stack(self, tokenizer, expected): assert ( tokenizer.count_tokens( - [{"role": "system", "content": "foobar baz"}, {"role": "user", "content": "how foobar am I?"}] + PromptStack( + inputs=[ + PromptStack.Input("foobar baz", role=PromptStack.SYSTEM_ROLE), + PromptStack.Input("how foobar am I?", role=PromptStack.USER_ROLE), + ] + ) ) == expected ) diff --git a/tests/unit/utils/test_base_tokenizer.py b/tests/unit/utils/test_base_tokenizer.py new file mode 100644 index 0000000000..9486159e5f --- /dev/null +++ b/tests/unit/utils/test_base_tokenizer.py @@ -0,0 +1,20 @@ +from griptape.utils import PromptStack +from tests.mocks.mock_tokenizer import MockTokenizer + + +class TestBaseTokenizer: + def test_token_count(self): + assert ( + MockTokenizer(model="foo bar").count_tokens( + PromptStack(inputs=[PromptStack.Input("foobar", role=PromptStack.USER_ROLE)]) + ) + == 24 + ) + + def test_prompt_stack_to_string(self): + assert ( + MockTokenizer(model="foo bar").prompt_stack_to_string( + PromptStack(inputs=[PromptStack.Input("foobar", role=PromptStack.USER_ROLE)]) + ) + == "User: foobar\n\nAssistant:" + ) diff --git a/tests/unit/utils/test_prompt_stack.py b/tests/unit/utils/test_prompt_stack.py index 253e8cd44b..80010abec7 100644 --- a/tests/unit/utils/test_prompt_stack.py +++ b/tests/unit/utils/test_prompt_stack.py @@ -1,9 +1,5 @@ import pytest from griptape.utils import PromptStack -from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_tokenizer import MockTokenizer -from griptape.structures.agent import Agent -from griptape.memory.structure import ConversationMemory, Run class TestPromptStack: @@ -43,97 +39,3 @@ def test_add_assistant_input(self, prompt_stack): assert prompt_stack.inputs[0].role == "assistant" assert prompt_stack.inputs[0].content == "foo" - - def test_add_conversation_memory_autopruing_disabled(self): - agent = Agent(prompt_driver=MockPromptDriver()) - memory = ConversationMemory( - autoprune=False, - runs=[ - Run(input="foo1", output="bar1"), - Run(input="foo2", output="bar2"), - Run(input="foo3", output="bar3"), - Run(input="foo4", output="bar4"), - Run(input="foo5", output="bar5"), - ], - ) - memory.structure = agent - prompt_stack = PromptStack() - prompt_stack.add_user_input("foo") - prompt_stack.add_assistant_input("bar") - prompt_stack.add_conversation_memory(memory) - - assert len(prompt_stack.inputs) == 12 - - def test_add_conversation_memory_autopruing_enabled(self): - # All memory is pruned. 
- agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0))) - memory = ConversationMemory( - autoprune=True, - runs=[ - Run(input="foo1", output="bar1"), - Run(input="foo2", output="bar2"), - Run(input="foo3", output="bar3"), - Run(input="foo4", output="bar4"), - Run(input="foo5", output="bar5"), - ], - ) - memory.structure = agent - prompt_stack = PromptStack() - prompt_stack.add_system_input("fizz") - prompt_stack.add_user_input("foo") - prompt_stack.add_assistant_input("bar") - prompt_stack.add_conversation_memory(memory) - - assert len(prompt_stack.inputs) == 3 - - # No memory is pruned. - agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=1000))) - memory = ConversationMemory( - autoprune=True, - runs=[ - Run(input="foo1", output="bar1"), - Run(input="foo2", output="bar2"), - Run(input="foo3", output="bar3"), - Run(input="foo4", output="bar4"), - Run(input="foo5", output="bar5"), - ], - ) - memory.structure = agent - prompt_stack = PromptStack() - prompt_stack.add_system_input("fizz") - prompt_stack.add_user_input("foo") - prompt_stack.add_assistant_input("bar") - prompt_stack.add_conversation_memory(memory) - - assert len(prompt_stack.inputs) == 13 - - # One memory is pruned. - # MockTokenizer's max_input_tokens set to one below the sum of memory + system prompt tokens - # so that a single memory is pruned. - agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160))) - memory = ConversationMemory( - autoprune=True, - runs=[ - # All of these sum to 155 tokens with the MockTokenizer. - Run(input="foo1", output="bar1"), - Run(input="foo2", output="bar2"), - Run(input="foo3", output="bar3"), - Run(input="foo4", output="bar4"), - Run(input="foo5", output="bar5"), - ], - ) - memory.structure = agent - prompt_stack = PromptStack() - # And then another 6 tokens from fizz for a total of 161 tokens. - prompt_stack.add_system_input("fizz") - prompt_stack.add_user_input("foo") - prompt_stack.add_assistant_input("bar") - prompt_stack.add_conversation_memory(memory, 1) - - # We expect one run (2 prompt stack inputs) to be pruned. - assert len(prompt_stack.inputs) == 11 - assert prompt_stack.inputs[0].content == "fizz" - assert prompt_stack.inputs[1].content == "foo2" - assert prompt_stack.inputs[2].content == "bar2" - assert prompt_stack.inputs[-2].content == "foo" - assert prompt_stack.inputs[-1].content == "bar" diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index d21d52c93e..12934382fa 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -190,7 +190,7 @@ class TesterPromptDriverOption: "SAGEMAKER_LLAMA_7B": TesterPromptDriverOption( prompt_driver=AmazonSageMakerPromptDriver( endpoint=os.environ["SAGEMAKER_LLAMA_ENDPOINT_NAME"], - prompt_model_driver=SageMakerLlamaPromptModelDriver(max_tokens=4096), + prompt_model_driver=SageMakerLlamaPromptModelDriver(), ), enabled=False, ),
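A minimal usage sketch of the API surface exercised by the updated tests above, assuming the changes in this diff are installed; the model constant and inputs are illustrative only:

    from griptape.tokenizers import OpenAiTokenizer
    from griptape.utils import PromptStack

    # BaseTokenizer.count_tokens() now accepts a Prompt Stack directly and
    # returns an approximate token count for its inputs.
    stack = PromptStack(
        inputs=[
            PromptStack.Input("foobar baz", role=PromptStack.SYSTEM_ROLE),
            PromptStack.Input("how foobar am I?", role=PromptStack.USER_ROLE),
        ]
    )
    tokenizer = OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)
    print(tokenizer.count_tokens(stack))

    # Conversation memory is now merged from the memory side: the former
    # prompt_stack.add_conversation_memory(memory) call becomes, for a memory
    # attached to a structure (memory.structure = agent, as in the tests above):
    # memory.add_to_prompt_stack(stack, 1)  # optional second arg caps how many runs are included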