Reduce scope of Amazon SageMaker Prompt Driver to only Jumpstart #836

Closed · wants to merge 1 commit
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- **BREAKING**: Removed `BedrockJurassicTokenizer`, use `SimpleTokenizer` instead.
- **BREAKING**: Removed `BedrockLlamaTokenizer`, use `SimpleTokenizer` instead.
- **BREAKING**: Removed `BedrockTitanTokenizer`, use `SimpleTokenizer` instead.
- **BREAKING**: Removed `SagemakerFalconPromptDriver`, use `AmazonSageMakerJumpstartPromptDriver` instead.
- **BREAKING**: Removed `SagemakerLlamaPromptDriver`, use `AmazonSageMakerJumpstartPromptDriver` instead.
- **BREAKING**: Renamed `AmazonSagemakerPromptDriver` to `AmazonSageMakerJumpstartPromptDriver`.
- Updated `AmazonBedrockPromptDriver` to use [Converse API](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html).
- `Structure.before_run()` now automatically resolves asymmetrically defined parent/child relationships using the new `Structure.resolve_relationships()`.
- Updated `HuggingFaceHubPromptDriver` to use `transformers`'s `apply_chat_template`.
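The breaking rename above implies a one-line migration for most users. A minimal before/after sketch (assuming an existing SageMaker JumpStart endpoint; the environment variable and model id below are illustrative placeholders, not part of this PR):

```python
# Migration sketch for the rename in this PR -- names below are placeholders.
import os

from griptape.config import StructureConfig
from griptape.drivers import AmazonSageMakerJumpstartPromptDriver
from griptape.structures import Agent

# Before: AmazonSagemakerPromptDriver(..., prompt_model_driver=SageMakerLlamaPromptModelDriver())
# After: no prompt model driver; `model` is a Hugging Face model id that the
# driver also uses to load the matching chat-template tokenizer.
agent = Agent(
    config=StructureConfig(
        prompt_driver=AmazonSageMakerJumpstartPromptDriver(
            endpoint=os.environ["SAGEMAKER_ENDPOINT_NAME"],  # placeholder
            model="meta-llama/Meta-Llama-3-8B-Instruct",
        )
    )
)

agent.run("Hello!")
```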
65 changes: 4 additions & 61 deletions docs/griptape-framework/drivers/prompt-drivers.md
@@ -379,87 +379,30 @@ agent = Agent(
agent.run("How many helicopters can a human eat in one sitting?")
```

### Multi Model Prompt Drivers
Certain LLM providers, such as Amazon SageMaker, support many types of models, each with slight differences in prompt structure and parameters. To support this variation across models, these Prompt Drivers take a [Prompt Model Driver](../../reference/griptape/drivers/prompt_model/base_prompt_model_driver.md)
through the [prompt_model_driver](../../reference/griptape/drivers/prompt/base_multi_model_prompt_driver.md#griptape.drivers.prompt.base_multi_model_prompt_driver.BaseMultiModelPromptDriver.prompt_model_driver) parameter.
[Prompt Model Driver](../../reference/griptape/drivers/prompt_model/base_prompt_model_driver.md)s allow for model-specific customization of Prompt Drivers.


#### Amazon SageMaker
#### Amazon SageMaker Jumpstart

!!! info
    This driver requires the `drivers-prompt-amazon-sagemaker` [extra](../index.md#extras).

The [AmazonSageMakerPromptDriver](../../reference/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.md) uses [Amazon SageMaker Endpoints](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) for inference on AWS.

!!! info
    For single model endpoints, the `model` parameter does not need to be specified.
    For multi-model endpoints, the `model` parameter should be the inference component name.

!!! warning
    Make sure that the selected prompt model driver is compatible with the selected model. Note that even the same
    logical model can require different prompt model drivers depending on how it is bundled in the endpoint. For
    example, the response format differs for `Meta-Llama-3-8B-Instruct` when deployed via
    "Amazon SageMaker JumpStart" and "Hugging Face on Amazon SageMaker".

##### Llama

!!! info
    `SageMakerLlamaPromptModelDriver` requires a tokenizer corresponding to a [Gated Model](https://huggingface.co/docs/hub/en/models-gated) on Hugging Face.

    Make sure to request access to the [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model on Hugging Face and configure your environment for Hugging Face use.
The [AmazonSageMakerJumpstartPromptDriver](../../reference/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.md) uses [Amazon SageMaker Jumpstart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html) for inference on AWS.

```python title="PYTEST_IGNORE"
import os
from griptape.structures import Agent
from griptape.drivers import (
    AmazonSageMakerPromptDriver,
    SageMakerLlamaPromptModelDriver,
)
from griptape.rules import Rule
from griptape.config import StructureConfig

agent = Agent(
    config=StructureConfig(
        prompt_driver=AmazonSageMakerPromptDriver(
            endpoint=os.environ["SAGEMAKER_LLAMA_3_INSTRUCT_ENDPOINT_NAME"],
            model=os.environ["SAGEMAKER_LLAMA_3_INSTRUCT_INFERENCE_COMPONENT_NAME"],
            prompt_model_driver=SageMakerLlamaPromptModelDriver(),
            temperature=0.75,
        )
    ),
    rules=[
        Rule(
            value="You are a helpful, respectful and honest assistant who is also a swarthy pirate. "
            "You only speak like a pirate and you never break character."
        )
    ],
)

agent.run("Hello!")
```

##### Falcon

```python title="PYTEST_IGNORE"
import os
from griptape.structures import Agent
from griptape.drivers import (
    AmazonSageMakerPromptDriver,
    AmazonSageMakerJumpstartPromptDriver,
    SageMakerFalconPromptModelDriver,
)
from griptape.config import StructureConfig

agent = Agent(
    config=StructureConfig(
        prompt_driver=AmazonSageMakerPromptDriver(
        prompt_driver=AmazonSageMakerJumpstartPromptDriver(
            endpoint=os.environ["SAGEMAKER_FALCON_ENDPOINT_NAME"],
            model=os.environ["SAGEMAKER_FALCON_INFERENCE_COMPONENT_NAME"],
            prompt_model_driver=SageMakerFalconPromptModelDriver(),
        )
    )
)

agent.run("What is a good lasagna recipe?")
```
13 changes: 2 additions & 11 deletions griptape/drivers/__init__.py
@@ -7,10 +7,9 @@
from .prompt.huggingface_pipeline_prompt_driver import HuggingFacePipelinePromptDriver
from .prompt.huggingface_hub_prompt_driver import HuggingFaceHubPromptDriver
from .prompt.anthropic_prompt_driver import AnthropicPromptDriver
from .prompt.amazon_sagemaker_prompt_driver import AmazonSageMakerPromptDriver
from .prompt.amazon_sagemaker_jumpstart_prompt_driver import AmazonSageMakerJumpstartPromptDriver
from .prompt.amazon_bedrock_prompt_driver import AmazonBedrockPromptDriver
from .prompt.google_prompt_driver import GooglePromptDriver
from .prompt.base_multi_model_prompt_driver import BaseMultiModelPromptDriver
from .prompt.dummy_prompt_driver import DummyPromptDriver

from .memory.conversation.base_conversation_memory_driver import BaseConversationMemoryDriver
@@ -52,10 +51,6 @@
from .sql.snowflake_sql_driver import SnowflakeSqlDriver
from .sql.sql_driver import SqlDriver

from .prompt_model.base_prompt_model_driver import BasePromptModelDriver
from .prompt_model.sagemaker_llama_prompt_model_driver import SageMakerLlamaPromptModelDriver
from .prompt_model.sagemaker_falcon_prompt_model_driver import SageMakerFalconPromptModelDriver

from .image_generation_model.base_image_generation_model_driver import BaseImageGenerationModelDriver
from .image_generation_model.bedrock_stable_diffusion_image_generation_model_driver import (
    BedrockStableDiffusionImageGenerationModelDriver,
@@ -119,10 +114,9 @@
"HuggingFacePipelinePromptDriver",
"HuggingFaceHubPromptDriver",
"AnthropicPromptDriver",
"AmazonSageMakerPromptDriver",
"AmazonSageMakerJumpstartPromptDriver",
"AmazonBedrockPromptDriver",
"GooglePromptDriver",
"BaseMultiModelPromptDriver",
"DummyPromptDriver",
"BaseConversationMemoryDriver",
"LocalConversationMemoryDriver",
@@ -158,9 +152,6 @@
"AmazonRedshiftSqlDriver",
"SnowflakeSqlDriver",
"SqlDriver",
"BasePromptModelDriver",
"SageMakerLlamaPromptModelDriver",
"SageMakerFalconPromptModelDriver",
"BaseImageGenerationModelDriver",
"BedrockStableDiffusionImageGenerationModelDriver",
"BedrockTitanImageGenerationModelDriver",
griptape/drivers/prompt/{amazon_sagemaker_prompt_driver.py → amazon_sagemaker_jumpstart_prompt_driver.py}
@@ -3,17 +3,18 @@
from typing import TYPE_CHECKING, Any
from collections.abc import Iterator
from attrs import define, field, Factory
from griptape.tokenizers import HuggingFaceTokenizer
from griptape.artifacts import TextArtifact
from griptape.drivers.prompt.base_prompt_driver import BasePromptDriver
from griptape.utils import import_optional_dependency
from .base_multi_model_prompt_driver import BaseMultiModelPromptDriver

if TYPE_CHECKING:
    from griptape.utils import PromptStack
    import boto3


@define
class AmazonSageMakerPromptDriver(BaseMultiModelPromptDriver):
class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver):
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    sagemaker_client: Any = field(
        default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), kw_only=True
@@ -22,17 +23,35 @@ class AmazonSageMakerPromptDriver(BaseMultiModelPromptDriver):
    model: str = field(default=None, kw_only=True, metadata={"serializable": True})
    custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
    stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
    tokenizer: HuggingFaceTokenizer = field(
        default=Factory(
            lambda self: HuggingFaceTokenizer(
                tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model),
                max_output_tokens=self.max_tokens,
            ),
            takes_self=True,
        ),
        kw_only=True,
    )

    @stream.validator  # pyright: ignore
    def validate_stream(self, _, stream):
        if stream:
            raise ValueError("streaming is not supported")

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        prompt = self.tokenizer.tokenizer.apply_chat_template(
            [{"role": i.role, "content": i.content} for i in prompt_stack.inputs],
            tokenize=False,
            add_generation_prompt=True,
        )

        payload = {
            "inputs": self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack),
            "parameters": self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack),
            "inputs": prompt,
            "parameters": {"temperature": self.temperature, "max_new_tokens": self.max_tokens, "do_sample": True},
        }

        response = self.sagemaker_client.invoke_endpoint(
            EndpointName=self.endpoint,
            ContentType="application/json",
@@ -43,10 +62,13 @@ def try_run(self, prompt_stack: PromptStack) -> TextArtifact:

        decoded_body = json.loads(response["Body"].read().decode("utf8"))

        if decoded_body:
            return self.prompt_model_driver.process_output(decoded_body)
        if isinstance(decoded_body, list):
            if decoded_body:
                return TextArtifact(decoded_body[0]["generated_text"])
            else:
                raise ValueError("model response is empty")
        else:
            raise Exception("model response is empty")
            return TextArtifact(decoded_body["generated_text"])

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        raise NotImplementedError("streaming is not supported")
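For reference, a minimal usage sketch of the reworked driver outside an `Agent` (assumptions: a deployed JumpStart endpoint behind the placeholder `SAGEMAKER_LLAMA_3_ENDPOINT` variable, and a `model` id whose tokenizer is downloadable from Hugging Face):

```python
# Usage sketch only; the endpoint name is a placeholder, not part of this PR.
import os

from griptape.drivers import AmazonSageMakerJumpstartPromptDriver
from griptape.utils import PromptStack

driver = AmazonSageMakerJumpstartPromptDriver(
    endpoint=os.environ["SAGEMAKER_LLAMA_3_ENDPOINT"],
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # also used to load the tokenizer
)

prompt_stack = PromptStack()
prompt_stack.add_user_input("Write a haiku about lasagna.")

# try_run() renders the stack with the tokenizer's chat template and posts
# {"inputs": ..., "parameters": ...} to the endpoint; the response body may be
# a list of generations or a single dict, as handled above.
print(driver.try_run(prompt_stack).value)
```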
38 changes: 0 additions & 38 deletions griptape/drivers/prompt/base_multi_model_prompt_driver.py

This file was deleted.

2 changes: 1 addition & 1 deletion griptape/drivers/prompt/base_prompt_driver.py
@@ -39,7 +39,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
        default=Factory(lambda: (ImportError, ValueError)), kw_only=True
    )
    model: str = field(metadata={"serializable": True})
    tokenizer: BaseTokenizer
    tokenizer: BaseTokenizer = field(kw_only=True)
    stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})

    def max_output_tokens(self, text: str | list) -> int:
Empty file.
29 changes: 0 additions & 29 deletions griptape/drivers/prompt_model/base_prompt_model_driver.py

This file was deleted.

This file was deleted.

This file was deleted.
