diff --git a/CHANGELOG.md b/CHANGELOG.md
index 189dd785d2..ad589408cd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- **BREAKING**: `Workflow` no longer modifies task relationships when adding tasks via `tasks` init param, `add_tasks()` or `add_task()`. Previously, adding a task would automatically add the previously added task as its parent. Existing code that relies on this behavior will need to be updated to explicitly add parent/child relationships using the API offered by `BaseTask`.
+- **BREAKING**: Removed `AmazonBedrockPromptDriver.prompt_model_driver`; it is no longer needed now that `AmazonBedrockPromptDriver` uses the Converse API.
+- **BREAKING**: Removed `BedrockClaudePromptModelDriver`.
+- **BREAKING**: Removed `BedrockJurassicPromptModelDriver`.
+- **BREAKING**: Removed `BedrockLlamaPromptModelDriver`.
+- **BREAKING**: Removed `BedrockTitanPromptModelDriver`.
+- **BREAKING**: Removed `BedrockClaudeTokenizer`, use `SimpleTokenizer` instead.
+- **BREAKING**: Removed `BedrockJurassicTokenizer`, use `SimpleTokenizer` instead.
+- **BREAKING**: Removed `BedrockLlamaTokenizer`, use `SimpleTokenizer` instead.
+- **BREAKING**: Removed `BedrockTitanTokenizer`, use `SimpleTokenizer` instead.
+- Updated `AmazonBedrockPromptDriver` to use the [Converse API](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html).
- `Structure.before_run()` now automatically resolves asymmetrically defined parent/child relationships using the new `Structure.resolve_relationships()`.
- Updated `HuggingFaceHubPromptDriver` to use `transformers`'s `apply_chat_template`.
- Updated `HuggingFacePipelinePromptDriver` to use chat features of `transformers.TextGenerationPipeline`.
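
For callers, the migration is mechanical: drop the `prompt_model_driver` argument and let the driver talk to the Converse API directly. A minimal before/after sketch (model id taken from the docs changes below; constructing the driver still requires the `drivers-prompt-amazon-bedrock` extra):

```python
from griptape.drivers import AmazonBedrockPromptDriver

# Before: each model family needed a dedicated Prompt Model Driver.
# driver = AmazonBedrockPromptDriver(
#     model="anthropic.claude-3-sonnet-20240229-v1:0",
#     prompt_model_driver=BedrockClaudePromptModelDriver(),
# )

# After: the Converse API normalizes request/response shapes across models.
driver = AmazonBedrockPromptDriver(model="anthropic.claude-3-sonnet-20240229-v1:0")
```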
diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md
index c659c7e82c..859dcec030 100644
--- a/docs/griptape-framework/drivers/prompt-drivers.md
+++ b/docs/griptape-framework/drivers/prompt-drivers.md
@@ -232,6 +232,47 @@ agent = Agent(
agent.run('Briefly explain how a computer works to a young child.')
```
+### Amazon Bedrock
+
+!!! info
+ This driver requires the `drivers-prompt-amazon-bedrock` [extra](../index.md#extras).
+
+The [AmazonBedrockPromptDriver](../../reference/griptape/drivers/prompt/amazon_bedrock_prompt_driver.md) uses [Amazon Bedrock](https://aws.amazon.com/bedrock/)'s [Converse API](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html).
+
+All models supported by the Converse API are available for use with this driver.
+
+```python
+from griptape.structures import Agent
+from griptape.drivers import AmazonBedrockPromptDriver
+from griptape.rules import Rule
+from griptape.config import StructureConfig
+
+agent = Agent(
+ config=StructureConfig(
+ prompt_driver=AmazonBedrockPromptDriver(
+ model="anthropic.claude-3-sonnet-20240229-v1:0",
+ )
+ ),
+ rules=[
+ Rule(
+ value="You are a customer service agent that is classifying emails by type. I want you to give your answer and then explain it."
+ )
+ ],
+)
+agent.run(
+ """How would you categorize this email?
+
+ Can I use my Mixmaster 4000 to mix paint, or is it only meant for mixing food?
+
+
+ Categories are:
+ (A) Pre-sale question
+ (B) Broken or defective item
+ (C) Billing question
+ (D) Other (please explain)"""
+)
+```
+
### Hugging Face Hub
!!! info
@@ -339,7 +380,7 @@ agent.run("How many helicopters can a human eat in one sitting?")
```
### Multi Model Prompt Drivers
-Certain LLM providers such as Amazon SageMaker and Amazon Bedrock supports many types of models, each with their own slight differences in prompt structure and parameters. To support this variation across models, these Prompt Drivers takes a [Prompt Model Driver](../../reference/griptape/drivers/prompt_model/base_prompt_model_driver.md)
+Certain LLM providers, such as Amazon SageMaker, support many types of models, each with its own slight differences in prompt structure and parameters. To support this variation across models, these Prompt Drivers take a [Prompt Model Driver](../../reference/griptape/drivers/prompt_model/base_prompt_model_driver.md)
through the [prompt_model_driver](../../reference/griptape/drivers/prompt/base_multi_model_prompt_driver.md#griptape.drivers.prompt.base_multi_model_prompt_driver.BaseMultiModelPromptDriver.prompt_model_driver) parameter.
[Prompt Model Driver](../../reference/griptape/drivers/prompt_model/base_prompt_model_driver.md)s allow for model-specific customization of Prompt Drivers.
@@ -422,120 +463,3 @@ agent = Agent(
agent.run("What is a good lasagna recipe?")
```
-
-#### Amazon Bedrock
-
-!!! info
- This driver requires the `drivers-prompt-amazon-bedrock` [extra](../index.md#extras).
-
-The [AmazonBedrockPromptDriver](../../reference/griptape/drivers/prompt/amazon_bedrock_prompt_driver.md) uses [Amazon Bedrock](https://aws.amazon.com/bedrock/) for inference on AWS.
-
-##### Amazon Titan
-
-To use this model with Amazon Bedrock, use the [BedrockTitanPromptModelDriver](../../reference/griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.md).
-
-```python
-from griptape.structures import Agent
-from griptape.drivers import AmazonBedrockPromptDriver, BedrockTitanPromptModelDriver
-from griptape.config import StructureConfig
-
-agent = Agent(
- config=StructureConfig(
- prompt_driver=AmazonBedrockPromptDriver(
- model="amazon.titan-text-express-v1",
- prompt_model_driver=BedrockTitanPromptModelDriver(
- top_p=1,
- )
- )
- )
-)
-agent.run(
- "Write an informational article for children about how birds fly."
- "Compare how birds fly to how airplanes fly."
- 'Make sure to use the word "Thrust" at least three times.'
-)
-```
-
-##### Anthropic Claude
-
-To use this model with Amazon Bedrock, use the [BedrockClaudePromptModelDriver](../../reference/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.md).
-
-```python
-from griptape.structures import Agent
-from griptape.drivers import AmazonBedrockPromptDriver, BedrockClaudePromptModelDriver
-from griptape.rules import Rule
-from griptape.config import StructureConfig
-
-agent = Agent(
- config=StructureConfig(
- prompt_driver=AmazonBedrockPromptDriver(
- model="anthropic.claude-3-sonnet-20240229-v1:0",
- prompt_model_driver=BedrockClaudePromptModelDriver(
- top_p=1,
- )
- )
- ),
- rules=[
- Rule(
- value="You are a customer service agent that is classifying emails by type. I want you to give your answer and then explain it."
- )
- ],
-)
-agent.run(
- """How would you categorize this email?
-
- Can I use my Mixmaster 4000 to mix paint, or is it only meant for mixing food?
-
-
- Categories are:
- (A) Pre-sale question
- (B) Broken or defective item
- (C) Billing question
- (D) Other (please explain)"""
-)
-```
-##### Meta Llama 2
-
-To use this model with Amazon Bedrock, use the [BedrockLlamaPromptModelDriver](../../reference/griptape/drivers/prompt_model/bedrock_llama_prompt_model_driver.md).
-
-```python
-from griptape.structures import Agent
-from griptape.drivers import AmazonBedrockPromptDriver, BedrockLlamaPromptModelDriver
-from griptape.config import StructureConfig
-
-agent = Agent(
- config=StructureConfig(
- prompt_driver=AmazonBedrockPromptDriver(
- model="meta.llama2-13b-chat-v1",
- prompt_model_driver=BedrockLlamaPromptModelDriver(),
- )
- )
-)
-agent.run(
- "Write an article about impact of high inflation to GDP of a country"
-)
-```
-
-##### Ai21 Jurassic
-
-To use this model with Amazon Bedrock, use the [BedrockJurassicPromptModelDriver](../../reference/griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.md).
-
-```python
-from griptape.structures import Agent
-from griptape.drivers import AmazonBedrockPromptDriver, BedrockJurassicPromptModelDriver
-from griptape.config import StructureConfig
-
-agent = Agent(
- config=StructureConfig(
- prompt_driver=AmazonBedrockPromptDriver(
- model="ai21.j2-ultra-v1",
- prompt_model_driver=BedrockJurassicPromptModelDriver(top_p=0.95),
- temperature=0.7,
- )
- )
-)
-agent.run(
- "Suggest an outline for a blog post based on a title. "
- "Title: How I put the pro in prompt engineering."
-)
-```
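
Model-specific sampling parameters that the removed Prompt Model Drivers used to own (`top_p`, `top_k`, and the like) now travel through the driver's new `additional_model_request_fields` field, which is forwarded verbatim as the Converse API's `additionalModelRequestFields` parameter. A sketch under that assumption (the `top_k` key name follows Anthropic's request schema and is not pinned down by this change):

```python
from griptape.drivers import AmazonBedrockPromptDriver

driver = AmazonBedrockPromptDriver(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    # Forwarded as-is in the Converse request; valid keys depend on the model.
    additional_model_request_fields={"top_k": 250},
)
```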
diff --git a/docs/griptape-framework/structures/task-memory.md b/docs/griptape-framework/structures/task-memory.md
index f76094dad7..219fd4412d 100644
--- a/docs/griptape-framework/structures/task-memory.md
+++ b/docs/griptape-framework/structures/task-memory.md
@@ -206,7 +206,6 @@ from griptape.config import (
)
from griptape.drivers import (
AmazonBedrockPromptDriver,
- BedrockTitanPromptModelDriver,
AmazonBedrockTitanEmbeddingDriver,
LocalVectorStoreDriver,
OpenAiChatPromptDriver,
@@ -227,7 +226,6 @@ agent = Agent(
query_engine=VectorQueryEngine(
prompt_driver=AmazonBedrockPromptDriver(
model="amazon.titan-text-express-v1",
- prompt_model_driver=BedrockTitanPromptModelDriver(),
),
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=AmazonBedrockTitanEmbeddingDriver()
diff --git a/griptape/config/amazon_bedrock_structure_config.py b/griptape/config/amazon_bedrock_structure_config.py
index cefb97f574..e70d9c8196 100644
--- a/griptape/config/amazon_bedrock_structure_config.py
+++ b/griptape/config/amazon_bedrock_structure_config.py
@@ -11,7 +11,6 @@
BasePromptDriver,
BaseVectorStoreDriver,
BedrockClaudeImageQueryModelDriver,
- BedrockClaudePromptModelDriver,
BedrockTitanImageGenerationModelDriver,
LocalVectorStoreDriver,
)
@@ -21,11 +20,7 @@
class AmazonBedrockStructureConfig(StructureConfig):
prompt_driver: BasePromptDriver = field(
default=Factory(
- lambda: AmazonBedrockPromptDriver(
- model="anthropic.claude-3-sonnet-20240229-v1:0",
- stream=False,
- prompt_model_driver=BedrockClaudePromptModelDriver(),
- )
+ lambda: AmazonBedrockPromptDriver(model="anthropic.claude-3-sonnet-20240229-v1:0", stream=False)
),
metadata={"serializable": True},
)
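
With the factory simplified, the default config can be built with no model-driver wiring at all; a quick sketch (assuming the `drivers-prompt-amazon-bedrock` extra is installed so `boto3` is importable):

```python
from griptape.config import AmazonBedrockStructureConfig

config = AmazonBedrockStructureConfig()
# Claude 3 Sonnet via the Converse API, no prompt_model_driver attached.
print(config.prompt_driver.model)  # anthropic.claude-3-sonnet-20240229-v1:0
```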
diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py
index 3050989fe9..02a3882fbd 100644
--- a/griptape/drivers/__init__.py
+++ b/griptape/drivers/__init__.py
@@ -55,10 +55,6 @@
from .prompt_model.base_prompt_model_driver import BasePromptModelDriver
from .prompt_model.sagemaker_llama_prompt_model_driver import SageMakerLlamaPromptModelDriver
from .prompt_model.sagemaker_falcon_prompt_model_driver import SageMakerFalconPromptModelDriver
-from .prompt_model.bedrock_titan_prompt_model_driver import BedrockTitanPromptModelDriver
-from .prompt_model.bedrock_claude_prompt_model_driver import BedrockClaudePromptModelDriver
-from .prompt_model.bedrock_jurassic_prompt_model_driver import BedrockJurassicPromptModelDriver
-from .prompt_model.bedrock_llama_prompt_model_driver import BedrockLlamaPromptModelDriver
from .image_generation_model.base_image_generation_model_driver import BaseImageGenerationModelDriver
from .image_generation_model.bedrock_stable_diffusion_image_generation_model_driver import (
@@ -165,10 +161,6 @@
"BasePromptModelDriver",
"SageMakerLlamaPromptModelDriver",
"SageMakerFalconPromptModelDriver",
- "BedrockTitanPromptModelDriver",
- "BedrockClaudePromptModelDriver",
- "BedrockJurassicPromptModelDriver",
- "BedrockLlamaPromptModelDriver",
"BaseImageGenerationModelDriver",
"BedrockStableDiffusionImageGenerationModelDriver",
"BedrockTitanImageGenerationModelDriver",
diff --git a/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py b/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
index 15ce67c4cc..54ac78b452 100644
--- a/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
+++ b/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
@@ -3,7 +3,7 @@
from typing import Any, TYPE_CHECKING
from attrs import define, field, Factory
from griptape.drivers import BaseEmbeddingDriver
-from griptape.tokenizers import BedrockCohereTokenizer
+from griptape.tokenizers.simple_tokenizer import SimpleTokenizer
from griptape.utils import import_optional_dependency
if TYPE_CHECKING:
@@ -28,8 +28,8 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
model: str = field(default=DEFAULT_MODEL, kw_only=True)
input_type: str = field(default="search_query", kw_only=True)
session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
- tokenizer: BedrockCohereTokenizer = field(
- default=Factory(lambda self: BedrockCohereTokenizer(model=self.model), takes_self=True), kw_only=True
+ tokenizer: SimpleTokenizer = field(
+ default=Factory(lambda self: SimpleTokenizer(characters_per_token=4), takes_self=True), kw_only=True
)
bedrock_client: Any = field(
default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
diff --git a/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py b/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
index a510c618c2..c6a692bf3b 100644
--- a/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
+++ b/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
@@ -3,7 +3,7 @@
from typing import Any, TYPE_CHECKING
from attrs import define, field, Factory
from griptape.drivers import BaseEmbeddingDriver
-from griptape.tokenizers import BedrockTitanTokenizer
+from griptape.tokenizers import SimpleTokenizer
from griptape.utils import import_optional_dependency
if TYPE_CHECKING:
@@ -24,8 +24,8 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
- tokenizer: BedrockTitanTokenizer = field(
- default=Factory(lambda self: BedrockTitanTokenizer(model=self.model), takes_self=True), kw_only=True
+ tokenizer: SimpleTokenizer = field(
+ default=Factory(lambda self: SimpleTokenizer(characters_per_token=4), takes_self=True), kw_only=True
)
bedrock_client: Any = field(
default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
index 0675e7f92d..f3d03a6844 100644
--- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
+++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@@ -1,11 +1,11 @@
from __future__ import annotations
-import json
from typing import TYPE_CHECKING, Any
from collections.abc import Iterator
from attrs import define, field, Factory
+from griptape.drivers import BasePromptDriver
from griptape.artifacts import TextArtifact
from griptape.utils import import_optional_dependency
-from .base_multi_model_prompt_driver import BaseMultiModelPromptDriver
+from griptape.tokenizers import SimpleTokenizer, BaseTokenizer
if TYPE_CHECKING:
from griptape.utils import PromptStack
@@ -13,43 +13,55 @@
@define
-class AmazonBedrockPromptDriver(BaseMultiModelPromptDriver):
+class AmazonBedrockPromptDriver(BasePromptDriver):
session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
bedrock_client: Any = field(
default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
)
+ additional_model_request_fields: dict = field(default=Factory(dict), kw_only=True)
+ tokenizer: BaseTokenizer = field(default=Factory(lambda: SimpleTokenizer(characters_per_token=4)), kw_only=True)
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
- model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack)
- payload = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)}
- if isinstance(model_input, dict):
- payload.update(model_input)
+ response = self.bedrock_client.converse(**self._base_params(prompt_stack))
- response = self.bedrock_client.invoke_model(
- modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
- )
+ output_message = response["output"]["message"]
+ output_content = output_message["content"][0]["text"]
- response_body = response["body"].read()
+ return TextArtifact(output_content)
- if response_body:
- return self.prompt_model_driver.process_output(response_body)
+ def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
+ response = self.bedrock_client.converse_stream(**self._base_params(prompt_stack))
+
+ stream = response.get("stream")
+ if stream is not None:
+ for event in stream:
+ if "contentBlockDelta" in event:
+ yield TextArtifact(event["contentBlockDelta"]["delta"]["text"])
else:
raise Exception("model response is empty")
- def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
- model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack)
- payload = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)}
- if isinstance(model_input, dict):
- payload.update(model_input)
-
- response = self.bedrock_client.invoke_model_with_response_stream(
- modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
- )
-
- response_body = response["body"]
- if response_body:
- for chunk in response["body"]:
- chunk_bytes = chunk["chunk"]["bytes"]
- yield self.prompt_model_driver.process_output(chunk_bytes)
+ def _base_params(self, prompt_stack: PromptStack) -> dict:
+ system_messages = [
+ {"text": input.content} for input in prompt_stack.inputs if input.is_system() and input.content
+ ]
+ messages = [
+ {"role": self.__to_amazon_bedrock_role(input), "content": [{"text": input.content}]}
+ for input in prompt_stack.inputs
+ if not input.is_system()
+ ]
+
+ return {
+ "modelId": self.model,
+ "messages": messages,
+ "system": system_messages,
+ "inferenceConfig": {"temperature": self.temperature},
+ "additionalModelRequestFields": self.additional_model_request_fields,
+ }
+
+ def __to_amazon_bedrock_role(self, prompt_input: PromptStack.Input) -> str:
+ if prompt_input.is_system():
+ return "system"
+ elif prompt_input.is_assistant():
+ return "assistant"
else:
- raise Exception("model response is empty")
+ return "user"
diff --git a/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py b/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py
deleted file mode 100644
index 2b4c547a98..0000000000
--- a/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py
+++ /dev/null
@@ -1,87 +0,0 @@
-from __future__ import annotations
-from typing import Optional
-import json
-from attrs import define, field
-from griptape.artifacts import TextArtifact
-from griptape.utils import PromptStack
-from griptape.drivers import BasePromptModelDriver, AmazonBedrockPromptDriver
-from griptape.tokenizers import BedrockClaudeTokenizer
-
-
-@define
-class BedrockClaudePromptModelDriver(BasePromptModelDriver):
- ANTHROPIC_VERSION = "bedrock-2023-05-31" # static string for AWS: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html#api-inference-examples-claude-multimodal-code-example
-
- top_p: float = field(default=0.999, kw_only=True, metadata={"serializable": True})
- top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
- _tokenizer: BedrockClaudeTokenizer = field(default=None, kw_only=True)
- prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True)
-
- @property
- def tokenizer(self) -> BedrockClaudeTokenizer:
- """Returns the tokenizer for this driver.
-
- We need to pass the `session` field from the Prompt Driver to the
- Tokenizer. However, the Prompt Driver is not initialized until after
- the Prompt Model Driver is initialized. To resolve this, we make the `tokenizer`
- field a @property that is only initialized when it is first accessed.
- This ensures that by the time we need to initialize the Tokenizer, the
- Prompt Driver has already been initialized.
-
- See this thread for more information: https://github.com/griptape-ai/griptape/issues/244
-
- Returns:
- BedrockClaudeTokenizer: The tokenizer for this driver.
- """
- if self._tokenizer:
- return self._tokenizer
- else:
- self._tokenizer = BedrockClaudeTokenizer(model=self.prompt_driver.model)
- return self._tokenizer
-
- def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
- messages = [
- {"role": self.__to_anthropic_role(prompt_input), "content": prompt_input.content}
- for prompt_input in prompt_stack.inputs
- if not prompt_input.is_system()
- ]
- system = next((i for i in prompt_stack.inputs if i.is_system()), None)
-
- if system is None:
- return {"messages": messages}
- else:
- return {"messages": messages, "system": system.content}
-
- def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
- input = self.prompt_stack_to_model_input(prompt_stack)
-
- return {
- "stop_sequences": self.tokenizer.stop_sequences,
- "temperature": self.prompt_driver.temperature,
- "top_p": self.top_p,
- "top_k": self.top_k,
- "max_tokens": self.prompt_driver.max_output_tokens(self.prompt_driver.prompt_stack_to_string(prompt_stack)),
- "anthropic_version": self.ANTHROPIC_VERSION,
- **input,
- }
-
- def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact:
- if isinstance(output, bytes):
- body = json.loads(output.decode())
- else:
- raise Exception("Output must be bytes.")
-
- if body["type"] == "content_block_delta":
- return TextArtifact(value=body["delta"]["text"])
- elif body["type"] == "message":
- return TextArtifact(value=body["content"][0]["text"])
- else:
- return TextArtifact(value="")
-
- def __to_anthropic_role(self, prompt_input: PromptStack.Input) -> str:
- if prompt_input.is_system():
- return "system"
- elif prompt_input.is_assistant():
- return "assistant"
- else:
- return "user"
diff --git a/griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.py b/griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.py
deleted file mode 100644
index 4da99e88f9..0000000000
--- a/griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.py
+++ /dev/null
@@ -1,76 +0,0 @@
-from __future__ import annotations
-from typing import Optional
-import json
-from attrs import define, field
-from griptape.artifacts import TextArtifact
-from griptape.utils import PromptStack
-from griptape.drivers import BasePromptModelDriver
-from griptape.tokenizers import BedrockJurassicTokenizer
-from griptape.drivers import AmazonBedrockPromptDriver
-
-
-@define
-class BedrockJurassicPromptModelDriver(BasePromptModelDriver):
- top_p: float = field(default=0.9, kw_only=True, metadata={"serializable": True})
- _tokenizer: BedrockJurassicTokenizer = field(default=None, kw_only=True)
- prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True)
- supports_streaming: bool = field(default=False, kw_only=True)
-
- @property
- def tokenizer(self) -> BedrockJurassicTokenizer:
- """Returns the tokenizer for this driver.
-
- We need to pass the `session` field from the Prompt Driver to the
- Tokenizer. However, the Prompt Driver is not initialized until after
- the Prompt Model Driver is initialized. To resolve this, we make the `tokenizer`
- field a @property that is only initialized when it is first accessed.
- This ensures that by the time we need to initialize the Tokenizer, the
- Prompt Driver has already been initialized.
-
- See this thread for more information: https://github.com/griptape-ai/griptape/issues/244
-
- Returns:
- BedrockJurassicTokenizer: The tokenizer for this driver.
- """
- if self._tokenizer:
- return self._tokenizer
- else:
- self._tokenizer = BedrockJurassicTokenizer(model=self.prompt_driver.model)
- return self._tokenizer
-
- def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
- prompt_lines = []
-
- for i in prompt_stack.inputs:
- if i.is_user():
- prompt_lines.append(f"User: {i.content}")
- elif i.is_assistant():
- prompt_lines.append(f"Assistant: {i.content}")
- elif i.is_system():
- prompt_lines.append(f"System: {i.content}")
- else:
- prompt_lines.append(i.content)
- prompt_lines.append("Assistant:")
-
- prompt = "\n".join(prompt_lines)
-
- return {"prompt": prompt}
-
- def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
- prompt = self.prompt_stack_to_model_input(prompt_stack)["prompt"]
-
- return {
- "maxTokens": self.prompt_driver.max_output_tokens(prompt),
- "temperature": self.prompt_driver.temperature,
- "stopSequences": self.tokenizer.stop_sequences,
- "countPenalty": {"scale": 0},
- "presencePenalty": {"scale": 0},
- "frequencyPenalty": {"scale": 0},
- }
-
- def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact:
- if isinstance(output, bytes):
- body = json.loads(output.decode())
- else:
- raise Exception("Output must be bytes.")
- return TextArtifact(body["completions"][0]["data"]["text"])
diff --git a/griptape/drivers/prompt_model/bedrock_llama_prompt_model_driver.py b/griptape/drivers/prompt_model/bedrock_llama_prompt_model_driver.py
deleted file mode 100644
index 951583c51d..0000000000
--- a/griptape/drivers/prompt_model/bedrock_llama_prompt_model_driver.py
+++ /dev/null
@@ -1,101 +0,0 @@
-from __future__ import annotations
-import json
-import itertools as it
-from typing import Optional
-from attrs import define, field
-from griptape.artifacts import TextArtifact
-from griptape.utils import PromptStack
-from griptape.drivers import BasePromptModelDriver
-from griptape.tokenizers import BedrockLlamaTokenizer
-from griptape.drivers import AmazonBedrockPromptDriver
-
-
-@define
-class BedrockLlamaPromptModelDriver(BasePromptModelDriver):
- top_p: float = field(default=0.9, kw_only=True)
- _tokenizer: BedrockLlamaTokenizer = field(default=None, kw_only=True)
- prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True)
-
- @property
- def tokenizer(self) -> BedrockLlamaTokenizer:
- """Returns the tokenizer for this driver.
-
- We need to pass the `session` field from the Prompt Driver to the
- Tokenizer. However, the Prompt Driver is not initialized until after
- the Prompt Model Driver is initialized. To resolve this, we make the `tokenizer`
- field a @property that is only initialized when it is first accessed.
- This ensures that by the time we need to initialize the Tokenizer, the
- Prompt Driver has already been initialized.
-
- See this thread for more information: https://github.com/griptape-ai/griptape/issues/244
-
- Returns:
- BedrockLlamaTokenizer: The tokenizer for this driver.
- """
- if self._tokenizer:
- return self._tokenizer
- else:
- self._tokenizer = BedrockLlamaTokenizer(model=self.prompt_driver.model)
- return self._tokenizer
-
- def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str:
- """
- Converts a `PromptStack` to a string that can be used as the input to the model.
-
- Prompt structure adapted from https://huggingface.co/blog/llama2#how-to-prompt-llama-2
-
- Args:
- prompt_stack: The `PromptStack` to convert.
- """
- prompt_lines = []
-
- inputs = iter(prompt_stack.inputs)
- input_pairs: list[tuple] = list(it.zip_longest(inputs, inputs))
- for input_pair in input_pairs:
- first_input: PromptStack.Input = input_pair[0]
- second_input: Optional[PromptStack.Input] = input_pair[1]
-
- if first_input.is_system():
- prompt_lines.append(f"[INST] <>\n{first_input.content}\n<>\n\n")
- if second_input:
- if second_input.is_user():
- prompt_lines.append(f"{second_input.content} [/INST]")
- else:
- raise Exception("System input must be followed by user input.")
- elif first_input.is_assistant():
- prompt_lines.append(f" {first_input.content} ")
- if second_input:
- if second_input.is_user():
- prompt_lines.append(f"[INST] {second_input.content} [/INST]")
- else:
- raise Exception("Assistant input must be followed by user input.")
- elif first_input.is_user():
- prompt_lines.append(f"[INST] {first_input.content} [/INST]")
- if second_input:
- if second_input.is_assistant():
- prompt_lines.append(f" {second_input.content} ")
- else:
- raise Exception("User input must be followed by assistant input.")
-
- return "".join(prompt_lines)
-
- def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
- prompt = self.prompt_stack_to_model_input(prompt_stack)
-
- return {
- "prompt": prompt,
- "max_gen_len": self.prompt_driver.max_output_tokens(prompt),
- "temperature": self.prompt_driver.temperature,
- "top_p": self.top_p,
- }
-
- def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact:
- # When streaming, the response body comes back as bytes.
- if isinstance(output, bytes):
- output = output.decode()
- elif isinstance(output, list) or isinstance(output, dict):
- raise Exception("Invalid output format.")
-
- body = json.loads(output)
-
- return TextArtifact(body["generation"])
diff --git a/griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.py b/griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.py
deleted file mode 100644
index 5f5bbc1d26..0000000000
--- a/griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.py
+++ /dev/null
@@ -1,83 +0,0 @@
-from __future__ import annotations
-from typing import Optional
-import json
-from attrs import define, field
-from griptape.artifacts import TextArtifact
-from griptape.utils import PromptStack
-from griptape.drivers import BasePromptModelDriver
-from griptape.tokenizers import BedrockTitanTokenizer
-from griptape.drivers import AmazonBedrockPromptDriver
-
-
-@define
-class BedrockTitanPromptModelDriver(BasePromptModelDriver):
- top_p: float = field(default=0.9, kw_only=True, metadata={"serializable": True})
- _tokenizer: BedrockTitanTokenizer = field(default=None, kw_only=True)
- prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True)
-
- @property
- def tokenizer(self) -> BedrockTitanTokenizer:
- """Returns the tokenizer for this driver.
-
- We need to pass the `session` field from the Prompt Driver to the
- Tokenizer. However, the Prompt Driver is not initialized until after
- the Prompt Model Driver is initialized. To resolve this, we make the `tokenizer`
- field a @property that is only initialized when it is first accessed.
- This ensures that by the time we need to initialize the Tokenizer, the
- Prompt Driver has already been initialized.
-
- See this thread for more information: https://github.com/griptape-ai/griptape/issues/244
-
- Returns:
- BedrockTitanTokenizer: The tokenizer for this driver.
- """
- if self._tokenizer:
- return self._tokenizer
- else:
- self._tokenizer = BedrockTitanTokenizer(model=self.prompt_driver.model)
- return self._tokenizer
-
- def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
- prompt_lines = []
-
- for i in prompt_stack.inputs:
- if i.is_user():
- prompt_lines.append(f"User: {i.content}")
- elif i.is_assistant():
- prompt_lines.append(f"Bot: {i.content}")
- elif i.is_system():
- prompt_lines.append(f"Instructions: {i.content}")
- else:
- prompt_lines.append(i.content)
- prompt_lines.append("Bot:")
-
- prompt = "\n\n".join(prompt_lines)
-
- return {"inputText": prompt}
-
- def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
- prompt = self.prompt_stack_to_model_input(prompt_stack)["inputText"]
-
- return {
- "textGenerationConfig": {
- "maxTokenCount": self.prompt_driver.max_output_tokens(prompt),
- "stopSequences": self.tokenizer.stop_sequences,
- "temperature": self.prompt_driver.temperature,
- "topP": self.top_p,
- }
- }
-
- def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact:
- # When streaming, the response body comes back as bytes.
- if isinstance(output, str) or isinstance(output, bytes):
- if isinstance(output, bytes):
- output = output.decode()
-
- body = json.loads(output)
-
- if self.prompt_driver.stream:
- return TextArtifact(body["outputText"])
- else:
- return TextArtifact(body["results"][0]["outputText"])
- else:
- raise ValueError("output must be an instance of 'str' or 'bytes'")
diff --git a/griptape/tokenizers/__init__.py b/griptape/tokenizers/__init__.py
index b116f9fb07..e69473acd4 100644
--- a/griptape/tokenizers/__init__.py
+++ b/griptape/tokenizers/__init__.py
@@ -3,11 +3,6 @@
from griptape.tokenizers.cohere_tokenizer import CohereTokenizer
from griptape.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer
from griptape.tokenizers.anthropic_tokenizer import AnthropicTokenizer
-from griptape.tokenizers.bedrock_titan_tokenizer import BedrockTitanTokenizer
-from griptape.tokenizers.bedrock_cohere_tokenizer import BedrockCohereTokenizer
-from griptape.tokenizers.bedrock_jurassic_tokenizer import BedrockJurassicTokenizer
-from griptape.tokenizers.bedrock_claude_tokenizer import BedrockClaudeTokenizer
-from griptape.tokenizers.bedrock_llama_tokenizer import BedrockLlamaTokenizer
from griptape.tokenizers.google_tokenizer import GoogleTokenizer
from griptape.tokenizers.voyageai_tokenizer import VoyageAiTokenizer
from griptape.tokenizers.simple_tokenizer import SimpleTokenizer
@@ -20,11 +15,6 @@
"CohereTokenizer",
"HuggingFaceTokenizer",
"AnthropicTokenizer",
- "BedrockTitanTokenizer",
- "BedrockCohereTokenizer",
- "BedrockJurassicTokenizer",
- "BedrockClaudeTokenizer",
- "BedrockLlamaTokenizer",
"GoogleTokenizer",
"VoyageAiTokenizer",
"SimpleTokenizer",
diff --git a/griptape/tokenizers/bedrock_claude_tokenizer.py b/griptape/tokenizers/bedrock_claude_tokenizer.py
deleted file mode 100644
index d44116e2ce..0000000000
--- a/griptape/tokenizers/bedrock_claude_tokenizer.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from attrs import define
-from griptape.tokenizers import AnthropicTokenizer
-
-
-@define()
-class BedrockClaudeTokenizer(AnthropicTokenizer):
- MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {
- "anthropic.claude-3": 200000,
- "anthropic.claude-v2:1": 200000,
- "anthropic.claude": 100000,
- }
- MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"anthropic.claude": 4096}
diff --git a/griptape/tokenizers/bedrock_cohere_tokenizer.py b/griptape/tokenizers/bedrock_cohere_tokenizer.py
deleted file mode 100644
index 44ccb4ac66..0000000000
--- a/griptape/tokenizers/bedrock_cohere_tokenizer.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from __future__ import annotations
-from attrs import define, field
-from .simple_tokenizer import SimpleTokenizer
-
-
-@define()
-class BedrockCohereTokenizer(SimpleTokenizer):
- # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html
- DEFAULT_CHARACTERS_PER_TOKEN = 4
- MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {"cohere": 1024}
- MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"cohere": 4096}
-
- model: str = field(kw_only=True)
- characters_per_token: int = field(default=DEFAULT_CHARACTERS_PER_TOKEN, kw_only=True)
diff --git a/griptape/tokenizers/bedrock_jurassic_tokenizer.py b/griptape/tokenizers/bedrock_jurassic_tokenizer.py
deleted file mode 100644
index 7511138b3c..0000000000
--- a/griptape/tokenizers/bedrock_jurassic_tokenizer.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from __future__ import annotations
-from attrs import define, field, Factory
-from .simple_tokenizer import SimpleTokenizer
-
-
-@define()
-class BedrockJurassicTokenizer(SimpleTokenizer):
- DEFAULT_CHARACTERS_PER_TOKEN = 6 # https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-prepare.html#model-customization-prepare-finetuning
- MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {"ai21": 8192}
- MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {
- "ai21.j2-mid-v1": 8191,
- "ai21.j2-ultra-v1": 8191,
- "ai21.j2-large-v1": 8191,
- "ai21": 2048,
- }
-
- model: str = field(kw_only=True)
- characters_per_token: int = field(
- default=Factory(lambda self: self.DEFAULT_CHARACTERS_PER_TOKEN, takes_self=True), kw_only=True
- )
diff --git a/griptape/tokenizers/bedrock_llama_tokenizer.py b/griptape/tokenizers/bedrock_llama_tokenizer.py
deleted file mode 100644
index e7d1ec8295..0000000000
--- a/griptape/tokenizers/bedrock_llama_tokenizer.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from __future__ import annotations
-from attrs import define, field
-from .simple_tokenizer import SimpleTokenizer
-
-
-@define()
-class BedrockLlamaTokenizer(SimpleTokenizer):
- DEFAULT_CHARACTERS_PER_TOKEN = 6 # https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-prepare.html#model-customization-prepare-finetuning
- MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {"meta": 2048}
- MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"meta": 2048}
-
- model: str = field(kw_only=True)
- characters_per_token: int = field(default=DEFAULT_CHARACTERS_PER_TOKEN, kw_only=True)
- stop_sequences: list[str] = field(factory=list, kw_only=True)
diff --git a/griptape/tokenizers/bedrock_titan_tokenizer.py b/griptape/tokenizers/bedrock_titan_tokenizer.py
deleted file mode 100644
index 0d8ba0273b..0000000000
--- a/griptape/tokenizers/bedrock_titan_tokenizer.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from __future__ import annotations
-from attrs import define, field, Factory
-from .simple_tokenizer import SimpleTokenizer
-
-
-@define()
-class BedrockTitanTokenizer(SimpleTokenizer):
- DEFAULT_CHARACTERS_PER_TOKEN = 6 # https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-prepare.html#model-customization-prepare-finetuning
- MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {"amazon": 4096}
- MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"amazon": 8000}
-
- model: str = field(kw_only=True)
- characters_per_token: int = field(default=DEFAULT_CHARACTERS_PER_TOKEN, kw_only=True)
- stop_sequences: list[str] = field(default=Factory(lambda: ["User:"]), kw_only=True)
diff --git a/griptape/tokenizers/simple_tokenizer.py b/griptape/tokenizers/simple_tokenizer.py
index 484afe69f0..864b92f962 100644
--- a/griptape/tokenizers/simple_tokenizer.py
+++ b/griptape/tokenizers/simple_tokenizer.py
@@ -7,6 +7,8 @@
@define()
class SimpleTokenizer(BaseTokenizer):
model: Optional[str] = field(init=False, kw_only=True, default=None)
+ max_input_tokens: int = field(kw_only=True, default=0)
+ max_output_tokens: int = field(kw_only=True, default=0)
characters_per_token: int = field(kw_only=True)
def count_tokens(self, text: str | list) -> int:
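
`SimpleTokenizer` estimates counts from a fixed characters-per-token ratio rather than a real vocabulary, which is why the Bedrock tokenizers above could be dropped in favor of it. A usage sketch (the ratio of 4 matches the new Bedrock driver defaults; the exact rounding is an implementation detail):

```python
from griptape.tokenizers import SimpleTokenizer

tokenizer = SimpleTokenizer(characters_per_token=4)
# Roughly one token per four characters of input.
print(tokenizer.count_tokens("Hello, Bedrock!"))  # ~4 for this 15-character string
```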
diff --git a/poetry.lock b/poetry.lock
index 65e767cc5c..87b88792ad 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -278,17 +278,17 @@ lxml = ["lxml"]
[[package]]
name = "boto3"
-version = "1.34.106"
+version = "1.34.119"
description = "The AWS SDK for Python"
optional = false
python-versions = ">=3.8"
files = [
- {file = "boto3-1.34.106-py3-none-any.whl", hash = "sha256:d3be4e1dd5d546a001cd4da805816934cbde9d395316546e9411fec341ade5cf"},
- {file = "boto3-1.34.106.tar.gz", hash = "sha256:6165b8cf1c7e625628ab28b32f9027064c8f5e5fca1c38d7fc228cd22069a19f"},
+ {file = "boto3-1.34.119-py3-none-any.whl", hash = "sha256:8f9c43c54b3dfaa36c4a0d7b42c417227a515bc7a2e163e62802780000a5a3e2"},
+ {file = "boto3-1.34.119.tar.gz", hash = "sha256:cea2365a25b2b83a97e77f24ac6f922ef62e20636b42f9f6ee9f97188f9c1c03"},
]
[package.dependencies]
-botocore = ">=1.34.106,<1.35.0"
+botocore = ">=1.34.119,<1.35.0"
jmespath = ">=0.7.1,<2.0.0"
s3transfer = ">=0.10.0,<0.11.0"
@@ -706,13 +706,13 @@ xray = ["mypy-boto3-xray (>=1.34.0,<1.35.0)"]
[[package]]
name = "botocore"
-version = "1.34.106"
+version = "1.34.119"
description = "Low-level, data-driven core of boto 3."
optional = false
python-versions = ">=3.8"
files = [
- {file = "botocore-1.34.106-py3-none-any.whl", hash = "sha256:4baf0e27c2dfc4f4d0dee7c217c716e0782f9b30e8e1fff983fce237d88f73ae"},
- {file = "botocore-1.34.106.tar.gz", hash = "sha256:921fa5202f88c3e58fdcb4b3acffd56d65b24bca47092ee4b27aa988556c0be6"},
+ {file = "botocore-1.34.119-py3-none-any.whl", hash = "sha256:4bdf7926a1290b2650d62899ceba65073dd2693e61c35f5cdeb3a286a0aaa27b"},
+ {file = "botocore-1.34.119.tar.gz", hash = "sha256:b253f15b24b87b070e176af48e8ef146516090429d30a7d8b136a4c079b28008"},
]
[package.dependencies]
@@ -6096,4 +6096,4 @@ loaders-pdf = ["pypdf"]
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
-content-hash = "7d073f2b3d8b7e16b3315c7c59925884fadda9c0b0c179b191a748de669f2e0c"
+content-hash = "d0913fb4119f352710d722c55029eaa74e948d57e3fa5ffb345ca4e690f20ec2"
diff --git a/pyproject.toml b/pyproject.toml
index b750051e07..53c07257ed 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,7 @@ cohere = { version = "^5.5.4", optional = true }
anthropic = { version = "^0.20.0", optional = true }
transformers = { version = "^4.39.3", optional = true }
huggingface-hub = { version = ">=0.13", optional = true }
-boto3 = { version = "^1.28.2", optional = true }
+boto3 = { version = "^1.34.119", optional = true }
sqlalchemy-redshift = { version = "*", optional = true }
snowflake-sqlalchemy = { version = "^1.4.7", optional = true }
pinecone-client = { version = "^3", optional = true }
diff --git a/tests/unit/config/test_amazon_bedrock_structure_config.py b/tests/unit/config/test_amazon_bedrock_structure_config.py
index 5b8c63a987..33b286f940 100644
--- a/tests/unit/config/test_amazon_bedrock_structure_config.py
+++ b/tests/unit/config/test_amazon_bedrock_structure_config.py
@@ -38,7 +38,6 @@ def test_to_dict(self, config):
"prompt_driver": {
"max_tokens": None,
"model": "anthropic.claude-3-sonnet-20240229-v1:0",
- "prompt_model_driver": {"type": "BedrockClaudePromptModelDriver", "top_k": 250, "top_p": 0.999},
"stream": False,
"temperature": 0.1,
"type": "AmazonBedrockPromptDriver",
diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py
index 1bd94d3e9a..8dc9ecdbbf 100644
--- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py
+++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py
@@ -1,117 +1,83 @@
-from botocore.response import StreamingBody
-from griptape.artifacts import TextArtifact
-from griptape.drivers import AmazonBedrockPromptDriver
-from griptape.drivers import BedrockClaudePromptModelDriver, BedrockTitanPromptModelDriver
-from griptape.tokenizers import AnthropicTokenizer, BedrockTitanTokenizer
-from io import StringIO
-from unittest.mock import Mock
-import json
import pytest
+from griptape.utils import PromptStack
+from griptape.drivers import AmazonBedrockPromptDriver
+
class TestAmazonBedrockPromptDriver:
+ @pytest.fixture()
+ def mock_converse(self, mocker):
+        mock_converse = mocker.patch("boto3.Session").return_value.client.return_value.converse
+
+ mock_converse.return_value = {"output": {"message": {"content": [{"text": "model-output"}]}}}
+
+ return mock_converse
+
+ @pytest.fixture()
+ def mock_converse_stream(self, mocker):
+ mock_converse_stream = mocker.patch("boto3.Session").return_value.client.return_value.converse_stream
+
+ mock_converse_stream.return_value = {"stream": [{"contentBlockDelta": {"delta": {"text": "model-output"}}}]}
+
+ return mock_converse_stream
+
@pytest.fixture
- def mock_prompt_model_driver(self):
- mock_prompt_model_driver = Mock()
- mock_prompt_model_driver.prompt_stack_to_model_params.return_value = {"model-param-key": "model-param-value"}
- mock_prompt_model_driver.process_output.return_value = TextArtifact("model-output")
- return mock_prompt_model_driver
-
- @pytest.fixture(autouse=True)
- def mock_client(self, mocker):
- return mocker.patch("boto3.Session").return_value.client.return_value
-
- def test_init(self):
- assert AmazonBedrockPromptDriver(model="anthropic.claude", prompt_model_driver=BedrockClaudePromptModelDriver())
-
- def test_custom_tokenizer(self):
- assert isinstance(
- AmazonBedrockPromptDriver(
- model="anthropic.claude", prompt_model_driver=BedrockClaudePromptModelDriver()
- ).tokenizer,
- AnthropicTokenizer,
- )
+ def prompt_stack(self):
+ prompt_stack = PromptStack()
+ prompt_stack.add_generic_input("generic-input")
+ prompt_stack.add_system_input("system-input")
+ prompt_stack.add_user_input("user-input")
+ prompt_stack.add_assistant_input("assistant-input")
- assert isinstance(
- AmazonBedrockPromptDriver(
- model="titan",
- tokenizer=BedrockTitanTokenizer(model="amazon"),
- prompt_model_driver=BedrockTitanPromptModelDriver(),
- ).tokenizer,
- BedrockTitanTokenizer,
- )
+ return prompt_stack
- @pytest.mark.parametrize("model_inputs", [{"model-input-key": "model-input-value"}, "not-a-dict"])
- def test_try_run(self, model_inputs, mock_prompt_model_driver, mock_client):
+ @pytest.fixture
+ def messages(self):
+ return [
+ {"role": "user", "content": [{"text": "generic-input"}]},
+ {"role": "system", "content": [{"text": "system-input"}]},
+ {"role": "user", "content": [{"text": "user-input"}]},
+ {"role": "assistant", "content": [{"text": "assistant-input"}]},
+ ]
+
+ def test_try_run(self, mock_converse, prompt_stack, messages):
# Given
- driver = AmazonBedrockPromptDriver(model="model", prompt_model_driver=mock_prompt_model_driver)
- prompt_stack = "prompt-stack"
- response_body = "invoke-model-response-body"
- mock_prompt_model_driver.prompt_stack_to_model_input.return_value = model_inputs
- mock_client.invoke_model.return_value = {"body": to_streaming_body(response_body)}
+ driver = AmazonBedrockPromptDriver(model="foo bar")
# When
text_artifact = driver.try_run(prompt_stack)
# Then
- mock_prompt_model_driver.prompt_stack_to_model_input.assert_called_once_with(prompt_stack)
- mock_prompt_model_driver.prompt_stack_to_model_params.assert_called_once_with(prompt_stack)
- mock_client.invoke_model.assert_called_once_with(
+ mock_converse.assert_called_once_with(
modelId=driver.model,
- contentType="application/json",
- accept="application/json",
- body=json.dumps(
- {
- **mock_prompt_model_driver.prompt_stack_to_model_params.return_value,
- **(model_inputs if isinstance(model_inputs, dict) else {}),
- }
- ),
+        messages=messages,
+ system=[{"text": "system-input"}],
+ inferenceConfig={"temperature": driver.temperature},
+ additionalModelRequestFields={},
)
- mock_prompt_model_driver.process_output.assert_called_once_with(response_body)
- assert text_artifact == mock_prompt_model_driver.process_output.return_value
+ assert text_artifact.value == "model-output"
- @pytest.mark.parametrize("model_inputs", [{"model-input-key": "model-input-value"}, "not-a-dict"])
- def test_try_stream_run(self, model_inputs, mock_prompt_model_driver, mock_client):
+ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages):
# Given
- driver = AmazonBedrockPromptDriver(model="model", prompt_model_driver=mock_prompt_model_driver, stream=True)
- prompt_stack = "prompt-stack"
- model_response = "invoke-model-response-body"
- response_body = [{"chunk": {"bytes": model_response}}]
- mock_prompt_model_driver.prompt_stack_to_model_input.return_value = model_inputs
- mock_client.invoke_model_with_response_stream.return_value = {"body": response_body}
+ driver = AmazonBedrockPromptDriver(model="foo bar", stream=True)
# When
text_artifact = next(driver.try_stream(prompt_stack))
# Then
- mock_prompt_model_driver.prompt_stack_to_model_input.assert_called_once_with(prompt_stack)
- mock_prompt_model_driver.prompt_stack_to_model_params.assert_called_once_with(prompt_stack)
- mock_client.invoke_model_with_response_stream.assert_called_once_with(
+ mock_converse_stream.assert_called_once_with(
modelId=driver.model,
- contentType="application/json",
- accept="application/json",
- body=json.dumps(
- {
- **mock_prompt_model_driver.prompt_stack_to_model_params.return_value,
- **(model_inputs if isinstance(model_inputs, dict) else {}),
- }
- ),
+        messages=messages,
+ system=[{"text": "system-input"}],
+ inferenceConfig={"temperature": driver.temperature},
+ additionalModelRequestFields={},
)
- mock_prompt_model_driver.process_output.assert_called_once_with(model_response)
- assert text_artifact.value == mock_prompt_model_driver.process_output.return_value.value
-
- def test_try_run_throws_on_empty_response(self, mock_prompt_model_driver, mock_client):
- # Given
- driver = AmazonBedrockPromptDriver(model="model", prompt_model_driver=mock_prompt_model_driver)
- mock_client.invoke_model.return_value = {"body": to_streaming_body("")}
-
- # When
- with pytest.raises(Exception) as e:
- driver.try_run("prompt-stack")
-
- # Then
- assert e.value.args[0] == "model response is empty"
-
-
-def to_streaming_body(text: str) -> StreamingBody:
- return StreamingBody(StringIO(text), len(text))
+ assert text_artifact.value == "model-output"
diff --git a/tests/unit/drivers/prompt_models/test_bedrock_claude_prompt_model_driver.py b/tests/unit/drivers/prompt_models/test_bedrock_claude_prompt_model_driver.py
deleted file mode 100644
index 94bc971223..0000000000
--- a/tests/unit/drivers/prompt_models/test_bedrock_claude_prompt_model_driver.py
+++ /dev/null
@@ -1,129 +0,0 @@
-from unittest import mock
-import json
-import boto3
-import pytest
-from griptape.utils import PromptStack
-from griptape.drivers import AmazonBedrockPromptDriver, BedrockClaudePromptModelDriver
-
-
-class TestBedrockClaudePromptModelDriver:
- @pytest.fixture(autouse=True)
- def mock_session(self, mocker):
- mock_session_class = mocker.patch("boto3.Session")
-
- mock_session_object = mock.Mock()
- mock_client = mock.Mock()
-
- mock_session_object.client.return_value = mock_client
- mock_session_class.return_value = mock_session_object
-
- @pytest.fixture
- def driver(self, request):
- return AmazonBedrockPromptDriver(
- model=request.param,
- session=boto3.Session(region_name="us-east-1"),
- prompt_model_driver=BedrockClaudePromptModelDriver(),
- temperature=0.12345,
- ).prompt_model_driver
-
- @pytest.mark.parametrize(
- "driver,",
- [
- ("anthropic.claude-v2"),
- ("anthropic.claude-v2:1"),
- ("anthropic.claude-3-sonnet-20240229-v1:0"),
- ("anthropic.claude-3-haiku-20240307-v1:0"),
- ],
- indirect=["driver"],
- )
- def test_init(self, driver):
- assert driver.prompt_driver is not None
-
- @pytest.mark.parametrize(
- "driver,",
- [
- ("anthropic.claude-v2"),
- ("anthropic.claude-v2:1"),
- ("anthropic.claude-3-sonnet-20240229-v1:0"),
- ("anthropic.claude-3-haiku-20240307-v1:0"),
- ],
- indirect=["driver"],
- )
- @pytest.mark.parametrize("system_enabled", [True, False])
- def test_prompt_stack_to_model_input(self, driver, system_enabled):
- stack = PromptStack()
- if system_enabled:
- stack.add_system_input("foo")
- stack.add_user_input("bar")
- stack.add_assistant_input("baz")
- stack.add_generic_input("qux")
-
- expected_messages = [
- {"role": "user", "content": "bar"},
- {"role": "assistant", "content": "baz"},
- {"role": "user", "content": "qux"},
- ]
- actual = driver.prompt_stack_to_model_input(stack)
- expected = {"messages": expected_messages, **({"system": "foo"} if system_enabled else {})}
-
- assert actual == expected
-
- @pytest.mark.parametrize(
- "driver,",
- [
- ("anthropic.claude-v2"),
- ("anthropic.claude-v2:1"),
- ("anthropic.claude-3-sonnet-20240229-v1:0"),
- ("anthropic.claude-3-haiku-20240307-v1:0"),
- ],
- indirect=["driver"],
- )
- @pytest.mark.parametrize("system_enabled", [True, False])
- def test_prompt_stack_to_model_params(self, driver, system_enabled):
- stack = PromptStack()
- if system_enabled:
- stack.add_system_input("foo")
- stack.add_user_input("bar")
- stack.add_assistant_input("baz")
- stack.add_generic_input("qux")
-
- max_tokens = driver.prompt_driver.max_output_tokens(driver.prompt_driver.prompt_stack_to_string(stack))
-
- expected = {
- "temperature": 0.12345,
- "max_tokens": max_tokens,
- "anthropic_version": driver.ANTHROPIC_VERSION,
- "messages": [
- {"role": "user", "content": "bar"},
- {"role": "assistant", "content": "baz"},
- {"role": "user", "content": "qux"},
- ],
- "top_p": 0.999,
- "top_k": 250,
- "stop_sequences": ["<|Response|>"],
- **({"system": "foo"} if system_enabled else {}),
- }
-
- assert driver.prompt_stack_to_model_params(stack) == expected
-
- @pytest.mark.parametrize(
- "driver,",
- [
- ("anthropic.claude-v2"),
- ("anthropic.claude-v2:1"),
- ("anthropic.claude-3-sonnet-20240229-v1:0"),
- ("anthropic.claude-3-haiku-20240307-v1:0"),
- ],
- indirect=["driver"],
- )
- def test_process_output(self, driver):
- assert (
- driver.process_output(json.dumps({"type": "message", "content": [{"text": "foobar"}]}).encode()).value
- == "foobar"
- )
- assert (
- driver.process_output(
- json.dumps({"type": "content_block_delta", "delta": {"text": "foobar"}}).encode()
- ).value
- == "foobar"
- )
diff --git a/tests/unit/drivers/prompt_models/test_bedrock_jurassic_prompt_model_driver.py b/tests/unit/drivers/prompt_models/test_bedrock_jurassic_prompt_model_driver.py
deleted file mode 100644
index e0d6f1f02a..0000000000
--- a/tests/unit/drivers/prompt_models/test_bedrock_jurassic_prompt_model_driver.py
+++ /dev/null
@@ -1,71 +0,0 @@
-from unittest import mock
-import json
-import boto3
-import pytest
-from griptape.utils import PromptStack
-from griptape.drivers import AmazonBedrockPromptDriver, BedrockJurassicPromptModelDriver
-
-
-class TestBedrockJurassicPromptModelDriver:
- @pytest.fixture(autouse=True)
- def mock_session(self, mocker):
- fake_tokenization = '{"prompt": {"tokens": [{}, {}, {}]}}'
- mock_session_class = mocker.patch("boto3.Session")
-
- mock_session_object = mock.Mock()
- mock_client = mock.Mock()
- mock_response = mock.Mock()
-
- mock_response.get().read.return_value = fake_tokenization
- mock_client.invoke_model.return_value = mock_response
- mock_session_object.client.return_value = mock_client
- mock_session_class.return_value = mock_session_object
-
- return mock_session_object
-
- @pytest.fixture
- def driver(self):
- return AmazonBedrockPromptDriver(
- model="ai21.j2-ultra",
- session=boto3.Session(region_name="us-east-1"),
- prompt_model_driver=BedrockJurassicPromptModelDriver(),
- temperature=0.12345,
- ).prompt_model_driver
-
- @pytest.fixture
- def stack(self):
- stack = PromptStack()
-
- stack.add_system_input("foo")
- stack.add_user_input("bar")
-
- return stack
-
- def test_driver_stream(self):
- with pytest.raises(ValueError):
- AmazonBedrockPromptDriver(
- model="ai21.j2-ultra",
- session=boto3.Session(region_name="us-east-1"),
- prompt_model_driver=BedrockJurassicPromptModelDriver(),
- temperature=0.12345,
- stream=True,
- ).prompt_model_driver
-
- def test_init(self, driver):
- assert driver.prompt_driver is not None
-
- def test_prompt_stack_to_model_input(self, driver, stack):
- model_input = driver.prompt_stack_to_model_input(stack)
-
- assert isinstance(model_input, dict)
- assert model_input["prompt"].startswith("System: foo\nUser: bar\nAssistant:")
-
- def test_prompt_stack_to_model_params(self, driver, stack):
- assert driver.prompt_stack_to_model_params(stack)["maxTokens"] == 2042
- assert driver.prompt_stack_to_model_params(stack)["temperature"] == 0.12345
-
- def test_process_output(self, driver):
- assert (
- driver.process_output(json.dumps({"completions": [{"data": {"text": "foobar"}}]}).encode()).value
- == "foobar"
- )
diff --git a/tests/unit/drivers/prompt_models/test_bedrock_llama_prompt_model_driver.py b/tests/unit/drivers/prompt_models/test_bedrock_llama_prompt_model_driver.py
deleted file mode 100644
index 8cb4b2c941..0000000000
--- a/tests/unit/drivers/prompt_models/test_bedrock_llama_prompt_model_driver.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from unittest import mock
-import json
-import boto3
-import pytest
-from griptape.tokenizers import BedrockLlamaTokenizer
-from griptape.utils import PromptStack
-from griptape.drivers import AmazonBedrockPromptDriver, BedrockLlamaPromptModelDriver
-
-
-class TestBedrockLlamaPromptModelDriver:
- @pytest.fixture(autouse=True)
- def mock_session(self, mocker):
- fake_tokenization = '{"generation_token_count": 13}'
- mock_session_class = mocker.patch("boto3.Session")
-
- mock_session_object = mock.Mock()
- mock_client = mock.Mock()
- mock_response = mock.Mock()
-
- mock_response.get().read.return_value = fake_tokenization
- mock_client.invoke_model.return_value = mock_response
- mock_session_object.client.return_value = mock_client
- mock_session_class.return_value = mock_session_object
-
- return mock_session_object
-
- @pytest.fixture
- def driver(self):
- return AmazonBedrockPromptDriver(
- model="meta.llama",
- session=boto3.Session(region_name="us-east-1"),
- prompt_model_driver=BedrockLlamaPromptModelDriver(),
- temperature=0.12345,
- ).prompt_model_driver
-
- @pytest.fixture
- def stack(self):
- stack = PromptStack()
-
- stack.add_system_input("{{ system_prompt }}")
- stack.add_user_input("{{ usr_msg_1 }}")
- stack.add_assistant_input("{{ model_msg_1 }}")
- stack.add_user_input("{{ usr_msg_2 }}")
-
- return stack
-
- def test_init(self, driver):
- assert driver.prompt_driver is not None
-
- def test_prompt_stack_to_model_input(self, driver, stack):
- model_input = driver.prompt_stack_to_model_input(stack)
-
- assert isinstance(model_input, str)
- assert (
- model_input
- == "[INST] <>\n{{ system_prompt }}\n<>\n\n{{ usr_msg_1 }} [/INST] {{ model_msg_1 }} [INST] {{ usr_msg_2 }} [/INST]"
- )
-
- def test_prompt_stack_to_model_params(self, driver, stack):
- assert driver.prompt_stack_to_model_params(stack)["max_gen_len"] == 2026
- assert driver.prompt_stack_to_model_params(stack)["temperature"] == 0.12345
-
- def test_process_output(self, driver):
- assert driver.process_output(json.dumps({"generation": "foobar"})).value == "foobar"
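The expected string in `test_prompt_stack_to_model_input` above is the Llama 2 chat template, which the removed driver assembled client-side; under Converse, that per-model formatting moves behind the API. For reference, a minimal sketch of the template logic the assertion encodes (the helper name and signature are illustrative, not the removed driver's internals):

```python
def to_llama2_prompt(system: str, turns: list[tuple[str, str | None]]) -> str:
    """Render a system prompt plus (user, assistant) turns into the
    Llama 2 chat template asserted above. Illustrative only."""
    prompt = f"[INST] <<SYS>>\n{system}\n<</SYS>>\n\n"
    for i, (user, assistant) in enumerate(turns):
        # The first user turn shares the opening [INST] with the system block.
        prompt += f"{user} [/INST]" if i == 0 else f" [INST] {user} [/INST]"
        if assistant is not None:
            prompt += f" {assistant}"
    return prompt


# Reproduces the deleted test's expected value:
assert to_llama2_prompt(
    "{{ system_prompt }}",
    [("{{ usr_msg_1 }}", "{{ model_msg_1 }}"), ("{{ usr_msg_2 }}", None)],
) == (
    "[INST] <<SYS>>\n{{ system_prompt }}\n<</SYS>>\n\n"
    "{{ usr_msg_1 }} [/INST] {{ model_msg_1 }} [INST] {{ usr_msg_2 }} [/INST]"
)
```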
diff --git a/tests/unit/drivers/prompt_models/test_bedrock_titan_prompt_model_driver.py b/tests/unit/drivers/prompt_models/test_bedrock_titan_prompt_model_driver.py
deleted file mode 100644
index ae6d436b5b..0000000000
--- a/tests/unit/drivers/prompt_models/test_bedrock_titan_prompt_model_driver.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from unittest import mock
-import json
-import boto3
-import pytest
-from griptape.utils import PromptStack
-from griptape.drivers import AmazonBedrockPromptDriver, BedrockTitanPromptModelDriver
-
-
-class TestBedrockTitanPromptModelDriver:
- @pytest.fixture(autouse=True)
- def mock_session(self, mocker):
- fake_tokenization = '{"inputTextTokenCount": 13}'
- mock_session_class = mocker.patch("boto3.Session")
-
- mock_session_object = mock.Mock()
- mock_client = mock.Mock()
- mock_response = mock.Mock()
-
- mock_response.get().read.return_value = fake_tokenization
- mock_client.invoke_model.return_value = mock_response
- mock_session_object.client.return_value = mock_client
- mock_session_class.return_value = mock_session_object
-
- return mock_session_object
-
- @pytest.fixture
- def driver(self):
- return AmazonBedrockPromptDriver(
- model="amazon.titan",
- session=boto3.Session(region_name="us-east-1"),
- prompt_model_driver=BedrockTitanPromptModelDriver(),
- temperature=0.12345,
- ).prompt_model_driver
-
- @pytest.fixture
- def stack(self):
- stack = PromptStack()
-
- stack.add_system_input("foo")
- stack.add_user_input("bar")
-
- return stack
-
- def test_init(self, driver):
- assert driver.prompt_driver is not None
-
- def test_prompt_stack_to_model_input(self, driver, stack):
- model_input = driver.prompt_stack_to_model_input(stack)
-
- assert isinstance(model_input, dict)
- assert model_input["inputText"].startswith("Instructions: foo\n\nUser: bar\n\nBot:")
-
- def test_prompt_stack_to_model_params(self, driver, stack):
- assert driver.prompt_stack_to_model_params(stack)["textGenerationConfig"]["maxTokenCount"] == 7994
- assert driver.prompt_stack_to_model_params(stack)["textGenerationConfig"]["temperature"] == 0.12345
-
- def test_process_output(self, driver):
- assert driver.process_output(json.dumps({"results": [{"outputText": "foobar"}]})).value == "foobar"
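Taken together, the three deleted prompt-model-driver tests show why this layer existed: each model family had its own parameter shape, with `maxTokens` for Jurassic, `max_gen_len` for Llama, and `textGenerationConfig.maxTokenCount` here for Titan. The Converse API normalizes all of these into a single `inferenceConfig`, which is what makes the per-model drivers removable. A minimal sketch of the unified request (model id and values illustrative):

```python
import boto3

client = boto3.Session(region_name="us-east-1").client("bedrock-runtime")

# One request shape for every Converse-capable model, regardless of family.
response = client.converse(
    modelId="anthropic.claude-3-haiku-20240307-v1:0",  # illustrative model id
    messages=[{"role": "user", "content": [{"text": "bar"}]}],
    system=[{"text": "foo"}],
    inferenceConfig={"maxTokens": 512, "temperature": 0.12345},
)

print(response["output"]["message"]["content"][0]["text"])
```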
diff --git a/tests/unit/tokenizers/test_bedrock_claude_tokenizer.py b/tests/unit/tokenizers/test_bedrock_claude_tokenizer.py
deleted file mode 100644
index 4f6e6723f7..0000000000
--- a/tests/unit/tokenizers/test_bedrock_claude_tokenizer.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import pytest
-from griptape.tokenizers import BedrockClaudeTokenizer
-
-
-class TestBedrockClaudeTokenizer:
- @pytest.fixture
- def tokenizer(self, request):
- return BedrockClaudeTokenizer(model=request.param)
-
- @pytest.mark.parametrize(
- "tokenizer,expected",
- [
- ("anthropic.claude-v2:1", 5),
- ("anthropic.claude-v2", 5),
- ("anthropic.claude-3-sonnet-20240229-v1:0", 5),
- ("anthropic.claude-3-haiku-20240307-v1:0", 5),
- ],
- indirect=["tokenizer"],
- )
- def test_token_count(self, tokenizer, expected):
- assert tokenizer.count_tokens("foo bar huzzah") == expected
-
- @pytest.mark.parametrize(
- "tokenizer,expected",
- [
- ("anthropic.claude-v2", 99995),
- ("anthropic.claude-v2:1", 199995),
- ("anthropic.claude-3-sonnet-20240229-v1:0", 199995),
- ("anthropic.claude-3-haiku-20240307-v1:0", 199995),
- ],
- indirect=["tokenizer"],
- )
- def test_input_tokens_left(self, tokenizer, expected):
- assert tokenizer.count_input_tokens_left("foo bar huzzah") == expected
-
- @pytest.mark.parametrize(
- "tokenizer,expected",
- [
- ("anthropic.claude-v2", 4091),
- ("anthropic.claude-v2:1", 4091),
- ("anthropic.claude-3-sonnet-20240229-v1:0", 4091),
- ("anthropic.claude-3-haiku-20240307-v1:0", 4091),
- ],
- indirect=["tokenizer"],
- )
- def test_output_tokens_left(self, tokenizer, expected):
- assert tokenizer.count_output_tokens_left("foo bar huzzah") == expected
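The exact counts asserted above required a model-specific tokenizer. Per the changelog, `SimpleTokenizer` is the replacement; it approximates counts from a characters-per-token ratio. A minimal sketch of the substitution (the ratio and token limits are illustrative assumptions, not framework defaults, and ceiling-division rounding is assumed):

```python
from griptape.tokenizers import SimpleTokenizer

# Character-ratio estimate in place of a model-specific vocabulary.
tokenizer = SimpleTokenizer(
    characters_per_token=6,   # illustrative ratio
    max_input_tokens=200000,  # illustrative limits for a Claude 3 class model
    max_output_tokens=4096,
)

print(tokenizer.count_tokens("foo bar huzzah"))             # 3, assuming ceil(14 / 6)
print(tokenizer.count_input_tokens_left("foo bar huzzah"))  # 199997 under the same assumption
print(tokenizer.count_output_tokens_left("foo bar huzzah")) # 4093
```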
diff --git a/tests/unit/tokenizers/test_bedrock_cohere_tokenizer.py b/tests/unit/tokenizers/test_bedrock_cohere_tokenizer.py
deleted file mode 100644
index 6238b0e54a..0000000000
--- a/tests/unit/tokenizers/test_bedrock_cohere_tokenizer.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import pytest
-from unittest import mock
-from griptape.tokenizers import BedrockCohereTokenizer
-
-
-class TestBedrockCohereTokenizer:
- @pytest.fixture(autouse=True)
- def mock_session(self, mocker):
- fake_tokenization = '{"inputTextTokenCount": 2}'
- mock_session_class = mocker.patch("boto3.Session")
-
- mock_session_object = mock.Mock()
- mock_client = mock.Mock()
- mock_response = mock.Mock()
-
- mock_response.get().read.return_value = fake_tokenization
- mock_client.invoke_model.return_value = mock_response
- mock_session_object.client.return_value = mock_client
- mock_session_class.return_value = mock_session_object
-
- def test_input_tokens_left(self):
- assert BedrockCohereTokenizer(model="cohere").count_input_tokens_left("foo bar") == 1022
-
- def test_output_tokens_left(self):
- assert BedrockCohereTokenizer(model="cohere").count_output_tokens_left("foo bar") == 4094
diff --git a/tests/unit/tokenizers/test_bedrock_jurassic_tokenizer.py b/tests/unit/tokenizers/test_bedrock_jurassic_tokenizer.py
deleted file mode 100644
index 59c42493b5..0000000000
--- a/tests/unit/tokenizers/test_bedrock_jurassic_tokenizer.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import pytest
-from unittest import mock
-from griptape.tokenizers import BedrockJurassicTokenizer
-
-
-class TestBedrockJurassicTokenizer:
- @pytest.fixture(autouse=True)
- def mock_session(self, mocker):
- fake_tokenization = '{"prompt": {"tokens": [{}, {}, {}]}}'
- mock_session_class = mocker.patch("boto3.Session")
-
- mock_session_object = mock.Mock()
- mock_client = mock.Mock()
- mock_response = mock.Mock()
-
- mock_response.get().read.return_value = fake_tokenization
- mock_client.invoke_model.return_value = mock_response
- mock_session_object.client.return_value = mock_client
- mock_session_class.return_value = mock_session_object
-
- @pytest.fixture
- def tokenizer(self, request):
- return BedrockJurassicTokenizer(model=request.param)
-
- @pytest.mark.parametrize(
- "tokenizer,expected",
- [("ai21.j2-mid-v1", 8186), ("ai21.j2-ultra-v1", 8186), ("ai21.j2-large-v1", 8186), ("ai21.j2-large-v2", 8186)],
- indirect=["tokenizer"],
- )
- def test_input_tokens_left(self, tokenizer, expected):
- assert tokenizer.count_input_tokens_left("System: foo\nUser: bar\nAssistant:") == expected
-
- @pytest.mark.parametrize(
- "tokenizer,expected",
- [("ai21.j2-mid-v1", 8185), ("ai21.j2-ultra-v1", 8185), ("ai21.j2-large-v1", 8185), ("ai21.j2-large-v2", 2042)],
- indirect=["tokenizer"],
- )
- def test_output_tokens_left(self, tokenizer, expected):
- assert tokenizer.count_output_tokens_left("System: foo\nUser: bar\nAssistant:") == expected
diff --git a/tests/unit/tokenizers/test_bedrock_llama_tokenizer.py b/tests/unit/tokenizers/test_bedrock_llama_tokenizer.py
deleted file mode 100644
index da842f0f35..0000000000
--- a/tests/unit/tokenizers/test_bedrock_llama_tokenizer.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import pytest
-from unittest import mock
-from griptape.tokenizers import BedrockLlamaTokenizer
-
-
-class TestBedrockLlamaTokenizer:
- @pytest.fixture(autouse=True)
- def mock_session(self, mocker):
- fake_tokenization = '{"generation_token_count": 13}'
- mock_session_class = mocker.patch("boto3.Session")
-
- mock_session_object = mock.Mock()
- mock_client = mock.Mock()
- mock_response = mock.Mock()
-
- mock_response.get().read.return_value = fake_tokenization
- mock_client.invoke_model.return_value = mock_response
- mock_session_object.client.return_value = mock_client
- mock_session_class.return_value = mock_session_object
-
- def test_input_tokens_left(self):
- assert (
- BedrockLlamaTokenizer(model="meta.llama").count_input_tokens_left(
- "[INST] <>\n{{ system_prompt }}\n<>\n\n{{ usr_msg_1 }} [/INST] {{ model_msg_1 }} [INST] {{ usr_msg_2 }} [/INST]"
- )
- == 2026
- )
-
- def test_output_tokens_left(self):
- assert (
- BedrockLlamaTokenizer(model="meta.llama").count_output_tokens_left(
- "[INST] <>\n{{ system_prompt }}\n<>\n\n{{ usr_msg_1 }} [/INST] {{ model_msg_1 }} [INST] {{ usr_msg_2 }} [/INST]"
- )
- == 2026
- )
diff --git a/tests/unit/tokenizers/test_bedrock_titan_tokenizer.py b/tests/unit/tokenizers/test_bedrock_titan_tokenizer.py
deleted file mode 100644
index c4f4f42ad1..0000000000
--- a/tests/unit/tokenizers/test_bedrock_titan_tokenizer.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import pytest
-from unittest import mock
-from griptape.tokenizers import BedrockTitanTokenizer
-
-
-class TestBedrockTitanTokenizer:
- @pytest.fixture(autouse=True)
- def mock_session(self, mocker):
- fake_tokenization = '{"inputTextTokenCount": 13}'
- mock_session_class = mocker.patch("boto3.Session")
-
- mock_session_object = mock.Mock()
- mock_client = mock.Mock()
- mock_response = mock.Mock()
-
- mock_response.get().read.return_value = fake_tokenization
- mock_client.invoke_model.return_value = mock_response
- mock_session_object.client.return_value = mock_client
- mock_session_class.return_value = mock_session_object
-
- def test_input_tokens_left(self):
- assert (
- BedrockTitanTokenizer(model="amazon.titan").count_input_tokens_left("Instructions: foo\nUser: bar\nBot:")
- == 4090
- )
-
- def test_output_tokens_left(self):
- assert (
- BedrockTitanTokenizer(model="amazon.titan").count_output_tokens_left("Instructions: foo\nUser: bar\nBot:")
- == 7994
- )
diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py
index 4f111b8d8e..bdf8f09bed 100644
--- a/tests/utils/structure_tester.py
+++ b/tests/utils/structure_tester.py
@@ -14,10 +14,6 @@
BasePromptDriver,
AmazonBedrockPromptDriver,
AnthropicPromptDriver,
- BedrockClaudePromptModelDriver,
- BedrockJurassicPromptModelDriver,
- BedrockTitanPromptModelDriver,
- BedrockLlamaPromptModelDriver,
CoherePromptDriver,
OpenAiChatPromptDriver,
OpenAiCompletionPromptDriver,
@@ -143,69 +139,42 @@ class TesterPromptDriverOption:
prompt_driver=CoherePromptDriver(model="command", api_key=os.environ["COHERE_API_KEY"]), enabled=True
),
"BEDROCK_TITAN": TesterPromptDriverOption(
- prompt_driver=AmazonBedrockPromptDriver(
- model="amazon.titan-tg1-large", prompt_model_driver=BedrockTitanPromptModelDriver()
- ),
- enabled=True,
+ prompt_driver=AmazonBedrockPromptDriver(model="amazon.titan-tg1-large"), enabled=True
),
"BEDROCK_CLAUDE_INSTANT": TesterPromptDriverOption(
- prompt_driver=AmazonBedrockPromptDriver(
- model="anthropic.claude-instant-v1", prompt_model_driver=BedrockClaudePromptModelDriver()
- ),
- enabled=True,
+ prompt_driver=AmazonBedrockPromptDriver(model="anthropic.claude-instant-v1"), enabled=True
),
"BEDROCK_CLAUDE_2": TesterPromptDriverOption(
- prompt_driver=AmazonBedrockPromptDriver(
- model="anthropic.claude-v2", prompt_model_driver=BedrockClaudePromptModelDriver()
- ),
- enabled=True,
+ prompt_driver=AmazonBedrockPromptDriver(model="anthropic.claude-v2"), enabled=True
),
"BEDROCK_CLAUDE_2.1": TesterPromptDriverOption(
- prompt_driver=AmazonBedrockPromptDriver(
- model="anthropic.claude-v2:1", prompt_model_driver=BedrockClaudePromptModelDriver()
- ),
- enabled=True,
+ prompt_driver=AmazonBedrockPromptDriver(model="anthropic.claude-v2:1"), enabled=True
),
"BEDROCK_CLAUDE_3_SONNET": TesterPromptDriverOption(
- prompt_driver=AmazonBedrockPromptDriver(
- model="anthropic.claude-3-sonnet-20240229-v1:0", prompt_model_driver=BedrockClaudePromptModelDriver()
- ),
- enabled=True,
+ prompt_driver=AmazonBedrockPromptDriver(model="anthropic.claude-3-sonnet-20240229-v1:0"), enabled=True
),
"BEDROCK_CLAUDE_3_HAIKU": TesterPromptDriverOption(
- prompt_driver=AmazonBedrockPromptDriver(
- model="anthropic.claude-3-haiku-20240307-v1:0", prompt_model_driver=BedrockClaudePromptModelDriver()
- ),
- enabled=True,
+ prompt_driver=AmazonBedrockPromptDriver(model="anthropic.claude-3-haiku-20240307-v1:0"), enabled=True
),
"BEDROCK_J2": TesterPromptDriverOption(
- prompt_driver=AmazonBedrockPromptDriver(
- model="ai21.j2-ultra", prompt_model_driver=BedrockJurassicPromptModelDriver()
- ),
- enabled=True,
+ prompt_driver=AmazonBedrockPromptDriver(model="ai21.j2-ultra"), enabled=True
),
"BEDROCK_LLAMA2_13B": TesterPromptDriverOption(
- prompt_driver=AmazonBedrockPromptDriver(
- model="meta.llama2-13b-chat-v1", prompt_model_driver=BedrockLlamaPromptModelDriver(), max_attempts=1
- ),
- enabled=True,
+ prompt_driver=AmazonBedrockPromptDriver(model="meta.llama2-13b-chat-v1"), enabled=True
),
"BEDROCK_LLAMA2_70B": TesterPromptDriverOption(
- prompt_driver=AmazonBedrockPromptDriver(
- model="meta.llama2-70b-chat-v1", prompt_model_driver=BedrockLlamaPromptModelDriver(), max_attempts=1
- ),
- enabled=True,
+ prompt_driver=AmazonBedrockPromptDriver(model="meta.llama2-70b-chat-v1"), enabled=True
),
"SAGEMAKER_LLAMA_7B": TesterPromptDriverOption(
prompt_driver=AmazonSageMakerPromptDriver(
- model=os.environ["SAGEMAKER_LLAMA_ENDPOINT_NAME"],
+ endpoint=os.environ["SAGEMAKER_LLAMA_ENDPOINT_NAME"],
prompt_model_driver=SageMakerLlamaPromptModelDriver(max_tokens=4096),
),
enabled=False,
),
"SAGEMAKER_FALCON_7b": TesterPromptDriverOption(
prompt_driver=AmazonSageMakerPromptDriver(
- model=os.environ["SAGEMAKER_FALCON_ENDPOINT_NAME"],
+ endpoint=os.environ["SAGEMAKER_FALCON_ENDPOINT_NAME"],
prompt_model_driver=SageMakerFalconPromptModelDriver(),
),
enabled=False,