Skip to content

Commit

Permalink
Update BedrockPromptDriver to use Converse API (#834)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Jun 5, 2024
1 parent 7a137e1 commit e167885
Show file tree
Hide file tree
Showing 33 changed files with 224 additions and 1,253 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed
- **BREAKING**: `Workflow` no longer modifies task relationships when adding tasks via `tasks` init param, `add_tasks()` or `add_task()`. Previously, adding a task would automatically add the previously added task as its parent. Existing code that relies on this behavior will need to be updated to explicitly add parent/child relationships using the API offered by `BaseTask`.
- **BREAKING**: Removed `AmazonBedrockPromptDriver.prompt_model_driver` as it is no longer needed with the `AmazonBedrockPromptDriver` Converse API implementation.
- **BREAKING**: Removed `BedrockClaudePromptModelDriver`.
- **BREAKING**: Removed `BedrockJurassicPromptModelDriver`.
- **BREAKING**: Removed `BedrockLlamaPromptModelDriver`.
- **BREAKING**: Removed `BedrockTitanPromptModelDriver`.
- **BREAKING**: Removed `BedrockClaudeTokenizer`, use `SimpleTokenizer` instead.
- **BREAKING**: Removed `BedrockJurassicTokenizer`, use `SimpleTokenizer` instead.
- **BREAKING**: Removed `BedrockLlamaTokenizer`, use `SimpleTokenizer` instead.
- **BREAKING**: Removed `BedrockTitanTokenizer`, use `SimpleTokenizer` instead.
- Updated `AmazonBedrockPromptDriver` to use [Converse API](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html).
- `Structure.before_run()` now automatically resolves asymmetrically defined parent/child relationships using the new `Structure.resolve_relationships()`.
- Updated `HuggingFaceHubPromptDriver` to use `transformers`'s `apply_chat_template`.
- Updated `HuggingFacePipelinePromptDriver` to use chat features of `transformers.TextGenerationPipeline`.
Expand Down
160 changes: 42 additions & 118 deletions docs/griptape-framework/drivers/prompt-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,47 @@ agent = Agent(
agent.run('Briefly explain how a computer works to a young child.')
```

### Amazon Bedrock

!!! info
This driver requires the `drivers-prompt-amazon-bedrock` [extra](../index.md#extras).

The [AmazonBedrockPromptDriver](../../reference/griptape/drivers/prompt/amazon_bedrock_prompt_driver.md) uses [Amazon Bedrock](https://aws.amazon.com/bedrock/)'s [Converse API](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html).

All models supported by the Converse API are available for use with this driver.

```python
from griptape.structures import Agent
from griptape.drivers import AmazonBedrockPromptDriver
from griptape.rules import Rule
from griptape.config import StructureConfig

agent = Agent(
config=StructureConfig(
prompt_driver=AmazonBedrockPromptDriver(
model="anthropic.claude-3-sonnet-20240229-v1:0",
)
),
rules=[
Rule(
value="You are a customer service agent that is classifying emails by type. I want you to give your answer and then explain it."
)
],
)
agent.run(
"""How would you categorize this email?
<email>
Can I use my Mixmaster 4000 to mix paint, or is it only meant for mixing food?
</email>
Categories are:
(A) Pre-sale question
(B) Broken or defective item
(C) Billing question
(D) Other (please explain)"""
)
```

### Hugging Face Hub

!!! info
Expand Down Expand Up @@ -339,7 +380,7 @@ agent.run("How many helicopters can a human eat in one sitting?")
```

### Multi Model Prompt Drivers
Certain LLM providers such as Amazon SageMaker and Amazon Bedrock supports many types of models, each with their own slight differences in prompt structure and parameters. To support this variation across models, these Prompt Drivers takes a [Prompt Model Driver](../../reference/griptape/drivers/prompt_model/base_prompt_model_driver.md)
Certain LLM providers such as Amazon SageMaker support many types of models, each with their own slight differences in prompt structure and parameters. To support this variation across models, these Prompt Drivers takes a [Prompt Model Driver](../../reference/griptape/drivers/prompt_model/base_prompt_model_driver.md)
through the [prompt_model_driver](../../reference/griptape/drivers/prompt/base_multi_model_prompt_driver.md#griptape.drivers.prompt.base_multi_model_prompt_driver.BaseMultiModelPromptDriver.prompt_model_driver) parameter.
[Prompt Model Driver](../../reference/griptape/drivers/prompt_model/base_prompt_model_driver.md)s allows for model-specific customization for Prompt Drivers.

Expand Down Expand Up @@ -422,120 +463,3 @@ agent = Agent(
agent.run("What is a good lasagna recipe?")

```

#### Amazon Bedrock

!!! info
This driver requires the `drivers-prompt-amazon-bedrock` [extra](../index.md#extras).

The [AmazonBedrockPromptDriver](../../reference/griptape/drivers/prompt/amazon_bedrock_prompt_driver.md) uses [Amazon Bedrock](https://aws.amazon.com/bedrock/) for inference on AWS.

##### Amazon Titan

To use this model with Amazon Bedrock, use the [BedrockTitanPromptModelDriver](../../reference/griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.md).

```python
from griptape.structures import Agent
from griptape.drivers import AmazonBedrockPromptDriver, BedrockTitanPromptModelDriver
from griptape.config import StructureConfig

agent = Agent(
config=StructureConfig(
prompt_driver=AmazonBedrockPromptDriver(
model="amazon.titan-text-express-v1",
prompt_model_driver=BedrockTitanPromptModelDriver(
top_p=1,
)
)
)
)
agent.run(
"Write an informational article for children about how birds fly."
"Compare how birds fly to how airplanes fly."
'Make sure to use the word "Thrust" at least three times.'
)
```

##### Anthropic Claude

To use this model with Amazon Bedrock, use the [BedrockClaudePromptModelDriver](../../reference/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.md).

```python
from griptape.structures import Agent
from griptape.drivers import AmazonBedrockPromptDriver, BedrockClaudePromptModelDriver
from griptape.rules import Rule
from griptape.config import StructureConfig

agent = Agent(
config=StructureConfig(
prompt_driver=AmazonBedrockPromptDriver(
model="anthropic.claude-3-sonnet-20240229-v1:0",
prompt_model_driver=BedrockClaudePromptModelDriver(
top_p=1,
)
)
),
rules=[
Rule(
value="You are a customer service agent that is classifying emails by type. I want you to give your answer and then explain it."
)
],
)
agent.run(
"""How would you categorize this email?
<email>
Can I use my Mixmaster 4000 to mix paint, or is it only meant for mixing food?
</email>
Categories are:
(A) Pre-sale question
(B) Broken or defective item
(C) Billing question
(D) Other (please explain)"""
)
```
##### Meta Llama 2

To use this model with Amazon Bedrock, use the [BedrockLlamaPromptModelDriver](../../reference/griptape/drivers/prompt_model/bedrock_llama_prompt_model_driver.md).

```python
from griptape.structures import Agent
from griptape.drivers import AmazonBedrockPromptDriver, BedrockLlamaPromptModelDriver
from griptape.config import StructureConfig

agent = Agent(
config=StructureConfig(
prompt_driver=AmazonBedrockPromptDriver(
model="meta.llama2-13b-chat-v1",
prompt_model_driver=BedrockLlamaPromptModelDriver(),
)
)
)
agent.run(
"Write an article about impact of high inflation to GDP of a country"
)
```

##### Ai21 Jurassic

To use this model with Amazon Bedrock, use the [BedrockJurassicPromptModelDriver](../../reference/griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.md).

```python
from griptape.structures import Agent
from griptape.drivers import AmazonBedrockPromptDriver, BedrockJurassicPromptModelDriver
from griptape.config import StructureConfig

agent = Agent(
config=StructureConfig(
prompt_driver=AmazonBedrockPromptDriver(
model="ai21.j2-ultra-v1",
prompt_model_driver=BedrockJurassicPromptModelDriver(top_p=0.95),
temperature=0.7,
)
)
)
agent.run(
"Suggest an outline for a blog post based on a title. "
"Title: How I put the pro in prompt engineering."
)
```
2 changes: 0 additions & 2 deletions docs/griptape-framework/structures/task-memory.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ from griptape.config import (
)
from griptape.drivers import (
AmazonBedrockPromptDriver,
BedrockTitanPromptModelDriver,
AmazonBedrockTitanEmbeddingDriver,
LocalVectorStoreDriver,
OpenAiChatPromptDriver,
Expand All @@ -227,7 +226,6 @@ agent = Agent(
query_engine=VectorQueryEngine(
prompt_driver=AmazonBedrockPromptDriver(
model="amazon.titan-text-express-v1",
prompt_model_driver=BedrockTitanPromptModelDriver(),
),
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=AmazonBedrockTitanEmbeddingDriver()
Expand Down
7 changes: 1 addition & 6 deletions griptape/config/amazon_bedrock_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
BasePromptDriver,
BaseVectorStoreDriver,
BedrockClaudeImageQueryModelDriver,
BedrockClaudePromptModelDriver,
BedrockTitanImageGenerationModelDriver,
LocalVectorStoreDriver,
)
Expand All @@ -21,11 +20,7 @@
class AmazonBedrockStructureConfig(StructureConfig):
prompt_driver: BasePromptDriver = field(
default=Factory(
lambda: AmazonBedrockPromptDriver(
model="anthropic.claude-3-sonnet-20240229-v1:0",
stream=False,
prompt_model_driver=BedrockClaudePromptModelDriver(),
)
lambda: AmazonBedrockPromptDriver(model="anthropic.claude-3-sonnet-20240229-v1:0", stream=False)
),
metadata={"serializable": True},
)
Expand Down
8 changes: 0 additions & 8 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,6 @@
from .prompt_model.base_prompt_model_driver import BasePromptModelDriver
from .prompt_model.sagemaker_llama_prompt_model_driver import SageMakerLlamaPromptModelDriver
from .prompt_model.sagemaker_falcon_prompt_model_driver import SageMakerFalconPromptModelDriver
from .prompt_model.bedrock_titan_prompt_model_driver import BedrockTitanPromptModelDriver
from .prompt_model.bedrock_claude_prompt_model_driver import BedrockClaudePromptModelDriver
from .prompt_model.bedrock_jurassic_prompt_model_driver import BedrockJurassicPromptModelDriver
from .prompt_model.bedrock_llama_prompt_model_driver import BedrockLlamaPromptModelDriver

from .image_generation_model.base_image_generation_model_driver import BaseImageGenerationModelDriver
from .image_generation_model.bedrock_stable_diffusion_image_generation_model_driver import (
Expand Down Expand Up @@ -165,10 +161,6 @@
"BasePromptModelDriver",
"SageMakerLlamaPromptModelDriver",
"SageMakerFalconPromptModelDriver",
"BedrockTitanPromptModelDriver",
"BedrockClaudePromptModelDriver",
"BedrockJurassicPromptModelDriver",
"BedrockLlamaPromptModelDriver",
"BaseImageGenerationModelDriver",
"BedrockStableDiffusionImageGenerationModelDriver",
"BedrockTitanImageGenerationModelDriver",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, TYPE_CHECKING
from attrs import define, field, Factory
from griptape.drivers import BaseEmbeddingDriver
from griptape.tokenizers import BedrockCohereTokenizer
from griptape.tokenizers.simple_tokenizer import SimpleTokenizer
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
Expand All @@ -28,8 +28,8 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
model: str = field(default=DEFAULT_MODEL, kw_only=True)
input_type: str = field(default="search_query", kw_only=True)
session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
tokenizer: BedrockCohereTokenizer = field(
default=Factory(lambda self: BedrockCohereTokenizer(model=self.model), takes_self=True), kw_only=True
tokenizer: SimpleTokenizer = field(
default=Factory(lambda self: SimpleTokenizer(characters_per_token=4), takes_self=True), kw_only=True
)
bedrock_client: Any = field(
default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, TYPE_CHECKING
from attrs import define, field, Factory
from griptape.drivers import BaseEmbeddingDriver
from griptape.tokenizers import BedrockTitanTokenizer
from griptape.tokenizers import SimpleTokenizer
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
Expand All @@ -24,8 +24,8 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):

model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
tokenizer: BedrockTitanTokenizer = field(
default=Factory(lambda self: BedrockTitanTokenizer(model=self.model), takes_self=True), kw_only=True
tokenizer: SimpleTokenizer = field(
default=Factory(lambda self: SimpleTokenizer(characters_per_token=4), takes_self=True), kw_only=True
)
bedrock_client: Any = field(
default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
Expand Down
70 changes: 41 additions & 29 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,67 @@
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any
from collections.abc import Iterator
from attrs import define, field, Factory
from griptape.drivers import BasePromptDriver
from griptape.artifacts import TextArtifact
from griptape.utils import import_optional_dependency
from .base_multi_model_prompt_driver import BaseMultiModelPromptDriver
from griptape.tokenizers import SimpleTokenizer, BaseTokenizer

if TYPE_CHECKING:
from griptape.utils import PromptStack
import boto3


@define
class AmazonBedrockPromptDriver(BaseMultiModelPromptDriver):
class AmazonBedrockPromptDriver(BasePromptDriver):
session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
bedrock_client: Any = field(
default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
)
additional_model_request_fields: dict = field(default=Factory(dict), kw_only=True)
tokenizer: BaseTokenizer = field(default=Factory(lambda: SimpleTokenizer(characters_per_token=4)), kw_only=True)

def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack)
payload = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)}
if isinstance(model_input, dict):
payload.update(model_input)
response = self.bedrock_client.converse(**self._base_params(prompt_stack))

response = self.bedrock_client.invoke_model(
modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
)
output_message = response["output"]["message"]
output_content = output_message["content"][0]["text"]

response_body = response["body"].read()
return TextArtifact(output_content)

if response_body:
return self.prompt_model_driver.process_output(response_body)
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
response = self.bedrock_client.converse_stream(**self._base_params(prompt_stack))

stream = response.get("stream")
if stream is not None:
for event in stream:
if "contentBlockDelta" in event:
yield TextArtifact(event["contentBlockDelta"]["delta"]["text"])
else:
raise Exception("model response is empty")

def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack)
payload = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)}
if isinstance(model_input, dict):
payload.update(model_input)

response = self.bedrock_client.invoke_model_with_response_stream(
modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
)

response_body = response["body"]
if response_body:
for chunk in response["body"]:
chunk_bytes = chunk["chunk"]["bytes"]
yield self.prompt_model_driver.process_output(chunk_bytes)
def _base_params(self, prompt_stack: PromptStack) -> dict:
system_messages = [
{"text": input.content} for input in prompt_stack.inputs if input.is_system() and input.content
]
messages = [
{"role": self.__to_amazon_bedrock_role(input), "content": [{"text": input.content}]}
for input in prompt_stack.inputs
if not input.is_system()
]

return {
"modelId": self.model,
"messages": messages,
"system": system_messages,
"inferenceConfig": {"temperature": self.temperature},
"additionalModelRequestFields": self.additional_model_request_fields,
}

def __to_amazon_bedrock_role(self, prompt_input: PromptStack.Input) -> str:
if prompt_input.is_system():
return "system"
elif prompt_input.is_assistant():
return "assistant"
else:
raise Exception("model response is empty")
return "user"
Loading

0 comments on commit e167885

Please sign in to comment.