diff --git a/CHANGELOG.md b/CHANGELOG.md
index ad589408c..38a9d53b2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - **BREAKING**: Removed `BedrockJurassicTokenizer`, use `SimpleTokenizer` instead.
 - **BREAKING**: Removed `BedrockLlamaTokenizer`, use `SimpleTokenizer` instead.
 - **BREAKING**: Removed `BedrockTitanTokenizer`, use `SimpleTokenizer` instead.
+- **BREAKING**: Removed `SagemakerFalconPromptDriver`, use `AmazonSageMakerJumpstartPromptDriver` instead.
+- **BREAKING**: Removed `SagemakerLlamaPromptDriver`, use `AmazonSageMakerJumpstartPromptDriver` instead.
+- **BREAKING**: Renamed `AmazonSagemakerPromptDriver` to `AmazonSageMakerJumpstartPromptDriver`.
 - Updated `AmazonBedrockPromptDriver` to use [Converse API](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html).
 - `Structure.before_run()` now automatically resolves asymmetrically defined parent/child relationships using the new `Structure.resolve_relationships()`.
 - Updated `HuggingFaceHubPromptDriver` to use `transformers`'s `apply_chat_template`.
diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md
index 859dcec03..434166e6a 100644
--- a/docs/griptape-framework/drivers/prompt-drivers.md
+++ b/docs/griptape-framework/drivers/prompt-drivers.md
@@ -379,87 +379,30 @@ agent = Agent(
 agent.run("How many helicopters can a human eat in one sitting?")
 ```
 
-### Multi Model Prompt Drivers
-Certain LLM providers such as Amazon SageMaker support many types of models, each with their own slight differences in prompt structure and parameters. To support this variation across models, these Prompt Drivers takes a [Prompt Model Driver](../../reference/griptape/drivers/prompt_model/base_prompt_model_driver.md)
-through the [prompt_model_driver](../../reference/griptape/drivers/prompt/base_multi_model_prompt_driver.md#griptape.drivers.prompt.base_multi_model_prompt_driver.BaseMultiModelPromptDriver.prompt_model_driver) parameter.
-[Prompt Model Driver](../../reference/griptape/drivers/prompt_model/base_prompt_model_driver.md)s allows for model-specific customization for Prompt Drivers.
-
-
-#### Amazon SageMaker
+#### Amazon SageMaker Jumpstart
 
 !!! info
     This driver requires the `drivers-prompt-amazon-sagemaker` [extra](../index.md#extras).
 
-The [AmazonSageMakerPromptDriver](../../reference/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.md) uses [Amazon SageMaker Endpoints](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) for inference on AWS.
-
-!!! info
-    For single model endpoints, the `model` parameter does not need to be specified.
-    For multi-model endpoints, the `model` parameter should be the inference component name.
-
-!!! warning
-    Make sure that the selected prompt model driver is compatible with the selected model. Note that even the same
-    logical model can require different prompt model drivers depending on how it is bundled in the endpoint. For
-    example, the reponse format are different for `Meta-Llama-3-8B-Instruct` when deployed via
-    "Amazon SageMaker JumpStart" and "Hugging Face on Amazon SageMaker".
-
-##### Llama
-
-!!! info
-    `SageMakerLlamaPromptModelDriver` requires a tokenizer corresponding to a [Gated Model](https://huggingface.co/docs/hub/en/models-gated) on Hugging Face.
-
-    Make sure to request access to the [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model on Hugging Face and configure your environment for hugging face use.
+The [AmazonSageMakerJumpstartPromptDriver](../../reference/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.md) uses [Amazon SageMaker Jumpstart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html) for inference on AWS.
 
 ```python title="PYTEST_IGNORE"
 import os
 from griptape.structures import Agent
 from griptape.drivers import (
-    AmazonSageMakerPromptDriver,
-    SageMakerLlamaPromptModelDriver,
-)
-from griptape.rules import Rule
-from griptape.config import StructureConfig
-
-agent = Agent(
-    config=StructureConfig(
-        prompt_driver=AmazonSageMakerPromptDriver(
-            endpoint=os.environ["SAGEMAKER_LLAMA_3_INSTRUCT_ENDPOINT_NAME"],
-            model=os.environ["SAGEMAKER_LLAMA_3_INSTRUCT_INFERENCE_COMPONENT_NAME"],
-            prompt_model_driver=SageMakerLlamaPromptModelDriver(),
-            temperature=0.75,
-        )
-    ),
-    rules=[
-        Rule(
-            value="You are a helpful, respectful and honest assistant who is also a swarthy pirate."
-            "You only speak like a pirate and you never break character."
-        )
-    ],
-)
-
-agent.run("Hello!")
-```
-
-##### Falcon
-
-```python title="PYTEST_IGNORE"
-import os
-from griptape.structures import Agent
-from griptape.drivers import (
-    AmazonSageMakerPromptDriver,
+    AmazonSageMakerJumpstartPromptDriver,
     SageMakerFalconPromptModelDriver,
 )
 from griptape.config import StructureConfig
 
 agent = Agent(
     config=StructureConfig(
-        prompt_driver=AmazonSageMakerPromptDriver(
+        prompt_driver=AmazonSageMakerJumpstartPromptDriver(
             endpoint=os.environ["SAGEMAKER_FALCON_ENDPOINT_NAME"],
             model=os.environ["SAGEMAKER_FALCON_INFERENCE_COMPONENT_NAME"],
-            prompt_model_driver=SageMakerFalconPromptModelDriver(),
         )
     )
 )
 
 agent.run("What is a good lasagna recipe?")
-
 ```
diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py
index 02a3882fb..73230a5e4 100644
--- a/griptape/drivers/__init__.py
+++ b/griptape/drivers/__init__.py
@@ -7,10 +7,9 @@
 from .prompt.huggingface_pipeline_prompt_driver import HuggingFacePipelinePromptDriver
 from .prompt.huggingface_hub_prompt_driver import HuggingFaceHubPromptDriver
 from .prompt.anthropic_prompt_driver import AnthropicPromptDriver
-from .prompt.amazon_sagemaker_prompt_driver import AmazonSageMakerPromptDriver
+from .prompt.amazon_sagemaker_jumpstart_prompt_driver import AmazonSageMakerJumpstartPromptDriver
 from .prompt.amazon_bedrock_prompt_driver import AmazonBedrockPromptDriver
 from .prompt.google_prompt_driver import GooglePromptDriver
-from .prompt.base_multi_model_prompt_driver import BaseMultiModelPromptDriver
 from .prompt.dummy_prompt_driver import DummyPromptDriver
 
 from .memory.conversation.base_conversation_memory_driver import BaseConversationMemoryDriver
@@ -52,10 +51,6 @@
 from .sql.snowflake_sql_driver import SnowflakeSqlDriver
 from .sql.sql_driver import SqlDriver
 
-from .prompt_model.base_prompt_model_driver import BasePromptModelDriver
-from .prompt_model.sagemaker_llama_prompt_model_driver import SageMakerLlamaPromptModelDriver
-from .prompt_model.sagemaker_falcon_prompt_model_driver import SageMakerFalconPromptModelDriver
-
 from .image_generation_model.base_image_generation_model_driver import BaseImageGenerationModelDriver
 from .image_generation_model.bedrock_stable_diffusion_image_generation_model_driver import (
     BedrockStableDiffusionImageGenerationModelDriver,
@@ -119,10 +114,9 @@
     "HuggingFacePipelinePromptDriver",
"HuggingFaceHubPromptDriver", "AnthropicPromptDriver", - "AmazonSageMakerPromptDriver", + "AmazonSageMakerJumpstartPromptDriver", "AmazonBedrockPromptDriver", "GooglePromptDriver", - "BaseMultiModelPromptDriver", "DummyPromptDriver", "BaseConversationMemoryDriver", "LocalConversationMemoryDriver", @@ -158,9 +152,6 @@ "AmazonRedshiftSqlDriver", "SnowflakeSqlDriver", "SqlDriver", - "BasePromptModelDriver", - "SageMakerLlamaPromptModelDriver", - "SageMakerFalconPromptModelDriver", "BaseImageGenerationModelDriver", "BedrockStableDiffusionImageGenerationModelDriver", "BedrockTitanImageGenerationModelDriver", diff --git a/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py similarity index 58% rename from griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py rename to griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py index 2934ea642..af5dbdf5a 100644 --- a/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py @@ -3,9 +3,10 @@ from typing import TYPE_CHECKING, Any from collections.abc import Iterator from attrs import define, field, Factory +from griptape.tokenizers import HuggingFaceTokenizer from griptape.artifacts import TextArtifact +from griptape.drivers.prompt.base_prompt_driver import BasePromptDriver from griptape.utils import import_optional_dependency -from .base_multi_model_prompt_driver import BaseMultiModelPromptDriver if TYPE_CHECKING: from griptape.utils import PromptStack @@ -13,7 +14,7 @@ @define -class AmazonSageMakerPromptDriver(BaseMultiModelPromptDriver): +class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver): session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) sagemaker_client: Any = field( default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), kw_only=True @@ -22,6 +23,17 @@ class AmazonSageMakerPromptDriver(BaseMultiModelPromptDriver): model: str = field(default=None, kw_only=True, metadata={"serializable": True}) custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True}) stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) + tokenizer: HuggingFaceTokenizer = field( + default=Factory( + lambda self: HuggingFaceTokenizer( + tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model), + max_output_tokens=self.max_tokens, + ), + takes_self=True, + ), + kw_only=True, + ) @stream.validator # pyright: ignore def validate_stream(self, _, stream): @@ -29,10 +41,17 @@ def validate_stream(self, _, stream): raise ValueError("streaming is not supported") def try_run(self, prompt_stack: PromptStack) -> TextArtifact: + prompt = self.tokenizer.tokenizer.apply_chat_template( + [{"role": i.role, "content": i.content} for i in prompt_stack.inputs], + tokenize=False, + add_generation_prompt=True, + ) + payload = { - "inputs": self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack), - "parameters": self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack), + "inputs": prompt, + "parameters": {"temperature": self.temperature, "max_new_tokens": self.max_tokens, "do_sample": True}, } + response = self.sagemaker_client.invoke_endpoint( EndpointName=self.endpoint, ContentType="application/json", @@ -43,10 
+62,13 @@ def try_run(self, prompt_stack: PromptStack) -> TextArtifact: decoded_body = json.loads(response["Body"].read().decode("utf8")) - if decoded_body: - return self.prompt_model_driver.process_output(decoded_body) + if isinstance(decoded_body, list): + if decoded_body: + return TextArtifact(decoded_body[0]["generated_text"]) + else: + raise ValueError("model response is empty") else: - raise Exception("model response is empty") + return TextArtifact(decoded_body["generated_text"]) def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: raise NotImplementedError("streaming is not supported") diff --git a/griptape/drivers/prompt/base_multi_model_prompt_driver.py b/griptape/drivers/prompt/base_multi_model_prompt_driver.py deleted file mode 100644 index 5411ea730..000000000 --- a/griptape/drivers/prompt/base_multi_model_prompt_driver.py +++ /dev/null @@ -1,38 +0,0 @@ -from __future__ import annotations -from attrs import define, field -from abc import ABC -from .base_prompt_driver import BasePromptDriver -from typing import TYPE_CHECKING, Optional - -if TYPE_CHECKING: - from griptape.tokenizers import BaseTokenizer - from griptape.drivers import BasePromptModelDriver - - -@define -class BaseMultiModelPromptDriver(BasePromptDriver, ABC): - """Prompt Driver for platforms like Amazon SageMaker, and Amazon Bedrock that host many LLM models. - - Instances of this Prompt Driver require a Prompt Model Driver which is used to convert the prompt stack - into a model input and parameters, and to process the model output. - - Attributes: - model: Name of the model to use. - tokenizer: Tokenizer to use. Defaults to the Tokenizer of the Prompt Model Driver. - prompt_model_driver: Prompt Model Driver to use. - """ - - tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True) - prompt_model_driver: BasePromptModelDriver = field(kw_only=True, metadata={"serializable": True}) - stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) - - @stream.validator # pyright: ignore - def validate_stream(self, _, stream): - if stream and not self.prompt_model_driver.supports_streaming: - raise ValueError(f"{self.prompt_model_driver.__class__.__name__} does not support streaming") - - def __attrs_post_init__(self) -> None: - self.prompt_model_driver.prompt_driver = self - - if not self.tokenizer: - self.tokenizer = self.prompt_model_driver.tokenizer diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 096035f8b..5dbe9c7e3 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -39,7 +39,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): default=Factory(lambda: (ImportError, ValueError)), kw_only=True ) model: str = field(metadata={"serializable": True}) - tokenizer: BaseTokenizer + tokenizer: BaseTokenizer = field(kw_only=True) stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) def max_output_tokens(self, text: str | list) -> int: diff --git a/griptape/drivers/prompt_model/__init__.py b/griptape/drivers/prompt_model/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/griptape/drivers/prompt_model/base_prompt_model_driver.py b/griptape/drivers/prompt_model/base_prompt_model_driver.py deleted file mode 100644 index 096802370..000000000 --- a/griptape/drivers/prompt_model/base_prompt_model_driver.py +++ /dev/null @@ -1,29 +0,0 @@ -from __future__ import 
annotations -from abc import ABC, abstractmethod -from typing import Optional -from attrs import define, field -from griptape.artifacts import TextArtifact -from griptape.utils import PromptStack -from griptape.drivers import BasePromptDriver -from griptape.tokenizers import BaseTokenizer -from griptape.mixins import SerializableMixin - - -@define -class BasePromptModelDriver(SerializableMixin, ABC): - max_tokens: Optional[int] = field(default=None, kw_only=True) - prompt_driver: Optional[BasePromptDriver] = field(default=None, kw_only=True) - supports_streaming: bool = field(default=True, kw_only=True) - - @property - @abstractmethod - def tokenizer(self) -> BaseTokenizer: ... - - @abstractmethod - def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str | list | dict: ... - - @abstractmethod - def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict: ... - - @abstractmethod - def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact: ... diff --git a/griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py b/griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py deleted file mode 100644 index a5a8a4dc9..000000000 --- a/griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import annotations -from attrs import define, field -from griptape.artifacts import TextArtifact -from griptape.utils import PromptStack, import_optional_dependency -from griptape.drivers import BasePromptModelDriver -from griptape.tokenizers import HuggingFaceTokenizer - - -@define -class SageMakerFalconPromptModelDriver(BasePromptModelDriver): - DEFAULT_MAX_TOKENS = 600 - - _tokenizer: HuggingFaceTokenizer = field(default=None, kw_only=True) - - @property - def tokenizer(self) -> HuggingFaceTokenizer: - if self._tokenizer is None: - self._tokenizer = HuggingFaceTokenizer( - tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained("tiiuae/falcon-40b"), - max_output_tokens=self.max_tokens or self.DEFAULT_MAX_TOKENS, - ) - return self._tokenizer - - def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str: - return self.prompt_driver.prompt_stack_to_string(prompt_stack) - - def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict: - prompt = self.prompt_stack_to_model_input(prompt_stack) - stop_sequences = self.prompt_driver.tokenizer.stop_sequences - - return { - "max_new_tokens": self.prompt_driver.max_output_tokens(prompt), - "temperature": self.prompt_driver.temperature, - "do_sample": True, - "stop": stop_sequences, - } - - def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact: - if isinstance(output, list): - return TextArtifact(output[0]["generated_text"].strip()) - else: - raise ValueError("output must be an instance of 'list'") diff --git a/griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py b/griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py deleted file mode 100644 index 7e934d4a6..000000000 --- a/griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations -from attrs import define, field -from griptape.artifacts import TextArtifact -from griptape.utils import PromptStack, import_optional_dependency -from griptape.drivers import BasePromptModelDriver -from griptape.tokenizers import HuggingFaceTokenizer - - -@define -class 
SageMakerLlamaPromptModelDriver(BasePromptModelDriver): - # Default context length for all Llama 3 models is 8K as per https://huggingface.co/blog/llama3 - DEFAULT_MAX_INPUT_TOKENS = 8000 - - _tokenizer: HuggingFaceTokenizer = field(default=None, kw_only=True) - - @property - def tokenizer(self) -> HuggingFaceTokenizer: - if self._tokenizer is None: - self._tokenizer = HuggingFaceTokenizer( - tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained( - "meta-llama/Meta-Llama-3-8B-Instruct", model_max_length=self.DEFAULT_MAX_INPUT_TOKENS - ), - max_output_tokens=self.max_tokens or self.DEFAULT_MAX_INPUT_TOKENS, - ) - return self._tokenizer - - def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str: - return self.tokenizer.tokenizer.apply_chat_template( # pyright: ignore - [{"role": i.role, "content": i.content} for i in prompt_stack.inputs], - tokenize=False, - add_generation_prompt=True, - ) - - def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict: - prompt = self.prompt_driver.prompt_stack_to_string(prompt_stack) - return { - "max_new_tokens": self.prompt_driver.max_output_tokens(prompt), - "temperature": self.prompt_driver.temperature, - "stop": self.tokenizer.tokenizer.eos_token, - } - - def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact: - # This output format is specific to the Llama 3 Instruct models when deployed via SageMaker JumpStart. - if isinstance(output, dict): - return TextArtifact(output["generated_text"]) - else: - raise ValueError("Invalid output format.") diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 2dd454b53..6ba23d6fe 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -103,7 +103,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: from griptape.utils.import_utils import import_optional_dependency, is_dependency_installed # These modules are required to avoid `NameError`s when resolving types. 
-        from griptape.drivers import BaseConversationMemoryDriver, BasePromptDriver, BasePromptModelDriver
+        from griptape.drivers import BaseConversationMemoryDriver, BasePromptDriver
         from griptape.structures import Structure
         from griptape.utils import PromptStack
         from griptape.tokenizers.base_tokenizer import BaseTokenizer
@@ -121,7 +121,6 @@ def _resolve_types(cls, attrs_cls: type) -> None:
                 "BaseConversationMemoryDriver": BaseConversationMemoryDriver,
                 "BasePromptDriver": BasePromptDriver,
                 "BaseTokenizer": BaseTokenizer,
-                "BasePromptModelDriver": BasePromptModelDriver,
                 "boto3": boto3,
                 "Client": Client,
             },
diff --git a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py
new file mode 100644
index 000000000..6ad50bb52
--- /dev/null
+++ b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py
@@ -0,0 +1,93 @@
+from typing import Any
+from botocore.response import StreamingBody
+from griptape.tokenizers import HuggingFaceTokenizer
+from griptape.drivers.prompt.amazon_sagemaker_jumpstart_prompt_driver import AmazonSageMakerJumpstartPromptDriver
+from griptape.utils import PromptStack
+from io import BytesIO
+import json
+import pytest
+
+
+def to_streaming_body(data: Any) -> StreamingBody:
+    bytes = json.dumps(data).encode("utf-8")
+
+    return StreamingBody(BytesIO(bytes), len(bytes))
+
+
+class TestAmazonSageMakerJumpstartPromptDriver:
+    @pytest.fixture(autouse=True)
+    def tokenizer(self, mocker):
+        from_pretrained = tokenizer = mocker.patch("transformers.AutoTokenizer").from_pretrained
+        from_pretrained.return_value.apply_chat_template.return_value = "foo\n\nUser: bar"
+        from_pretrained.return_value.model_max_length = 8000
+
+        return tokenizer
+
+    @pytest.fixture(autouse=True)
+    def mock_client(self, mocker):
+        return mocker.patch("boto3.Session").return_value.client.return_value
+
+    def test_init(self):
+        assert AmazonSageMakerJumpstartPromptDriver(endpoint="foo")
+
+    def test_try_run(self, mock_client):
+        # Given
+        driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model")
+        prompt_stack = PromptStack()
+        prompt_stack.add_user_input("prompt-stack")
+
+        # When
+        response_body = [{"generated_text": "foobar"}]
+        mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body(response_body)}
+        text_artifact = driver.try_run(prompt_stack)
+        assert isinstance(driver.tokenizer, HuggingFaceTokenizer)
+
+        # Then
+        mock_client.invoke_endpoint.assert_called_with(
+            EndpointName=driver.endpoint,
+            ContentType="application/json",
+            Body=json.dumps(
+                {
+                    "inputs": "foo\n\nUser: bar",
+                    "parameters": {"temperature": driver.temperature, "max_new_tokens": 250, "do_sample": True},
+                }
+            ),
+            CustomAttributes="accept_eula=true",
+        )
+
+        assert text_artifact.value == "foobar"
+
+        # When
+        response_body = {"generated_text": "foobar"}
+        mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body(response_body)}
+        text_artifact = driver.try_run(prompt_stack)
+        assert isinstance(driver.tokenizer, HuggingFaceTokenizer)
+
+        # Then
+        mock_client.invoke_endpoint.assert_called_with(
+            EndpointName=driver.endpoint,
+            ContentType="application/json",
+            Body=json.dumps(
+                {
+                    "inputs": "foo\n\nUser: bar",
+                    "parameters": {"temperature": driver.temperature, "max_new_tokens": 250, "do_sample": True},
+                }
+            ),
+            CustomAttributes="accept_eula=true",
+        )
+
+        assert text_artifact.value == "foobar"
+
+    def test_try_run_throws_on_empty_response(self, mock_client):
+        # Given
+        driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model")
+        mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body([])}
+        prompt_stack = PromptStack()
+        prompt_stack.add_user_input("prompt-stack")
+
+        # When
+        with pytest.raises(Exception) as e:
+            driver.try_run(prompt_stack)
+
+        # Then
+        assert e.value.args[0] == "model response is empty"
diff --git a/tests/unit/drivers/prompt/test_amazon_sagemaker_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_sagemaker_prompt_driver.py
deleted file mode 100644
index c6692e1ba..000000000
--- a/tests/unit/drivers/prompt/test_amazon_sagemaker_prompt_driver.py
+++ /dev/null
@@ -1,90 +0,0 @@
-from botocore.response import StreamingBody
-from griptape.artifacts import TextArtifact
-from griptape.drivers import AmazonSageMakerPromptDriver, SageMakerFalconPromptModelDriver
-from griptape.tokenizers import HuggingFaceTokenizer, OpenAiTokenizer
-from griptape.utils import PromptStack
-from io import BytesIO
-from unittest.mock import Mock
-import json
-import pytest
-
-
-class TestAmazonSageMakerPromptDriver:
-    @pytest.fixture
-    def mock_model_driver(self):
-        mock_model_driver = Mock()
-        mock_model_driver.prompt_stack_to_model_input.return_value = "model-inputs"
-        mock_model_driver.prompt_stack_to_model_params.return_value = "model-params"
-        mock_model_driver.process_output.return_value = TextArtifact("model-output")
-        return mock_model_driver
-
-    @pytest.fixture(autouse=True)
-    def mock_client(self, mocker):
-        return mocker.patch("boto3.Session").return_value.client.return_value
-
-    def test_init(self):
-        assert AmazonSageMakerPromptDriver(endpoint="foo", prompt_model_driver=SageMakerFalconPromptModelDriver())
-
-    def test_custom_tokenizer(self):
-        assert isinstance(
-            AmazonSageMakerPromptDriver(
-                endpoint="foo", prompt_model_driver=SageMakerFalconPromptModelDriver()
-            ).tokenizer,
-            HuggingFaceTokenizer,
-        )
-
-        assert isinstance(
-            AmazonSageMakerPromptDriver(
-                endpoint="foo",
-                tokenizer=OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL),
-                prompt_model_driver=SageMakerFalconPromptModelDriver(),
-            ).tokenizer,
-            OpenAiTokenizer,
-        )
-
-    def test_try_run(self, mock_model_driver, mock_client):
-        # Given
-        driver = AmazonSageMakerPromptDriver(endpoint="model", prompt_model_driver=mock_model_driver)
-        prompt_stack = PromptStack()
-        prompt_stack.add_user_input("prompt-stack")
-        response_body = "invoke-endpoint-response-body"
-        mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body(response_body)}
-
-        # When
-        text_artifact = driver.try_run(prompt_stack)
-
-        # Then
-        mock_model_driver.prompt_stack_to_model_input.assert_called_once_with(prompt_stack)
-        mock_model_driver.prompt_stack_to_model_params.assert_called_once_with(prompt_stack)
-        mock_client.invoke_endpoint.assert_called_once_with(
-            EndpointName=driver.endpoint,
-            ContentType="application/json",
-            Body=json.dumps(
-                {
-                    "inputs": mock_model_driver.prompt_stack_to_model_input.return_value,
-                    "parameters": mock_model_driver.prompt_stack_to_model_params.return_value,
-                }
-            ),
-            CustomAttributes="accept_eula=true",
-        )
-        mock_model_driver.process_output.assert_called_once_with(response_body)
-        assert text_artifact == mock_model_driver.process_output.return_value
-
-    def test_try_run_throws_on_empty_response(self, mock_model_driver, mock_client):
-        # Given
-        driver = AmazonSageMakerPromptDriver(endpoint="model", prompt_model_driver=mock_model_driver)
-        mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body("")}
-        prompt_stack = PromptStack()
-        prompt_stack.add_user_input("prompt-stack")
-
-        # When
-        with pytest.raises(Exception) as e:
-            driver.try_run(prompt_stack)
-
-        # Then
-        assert e.value.args[0] == "model response is empty"
-
-
-def to_streaming_body(text: str) -> StreamingBody:
-    bytes = json.dumps(text).encode("utf-8")
-    return StreamingBody(BytesIO(bytes), len(bytes))
diff --git a/tests/unit/drivers/prompt_models/__init__.py b/tests/unit/drivers/prompt_models/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/tests/unit/drivers/prompt_models/test_sagemaker_falcon_prompt_model_driver.py b/tests/unit/drivers/prompt_models/test_sagemaker_falcon_prompt_model_driver.py
deleted file mode 100644
index 78d990229..000000000
--- a/tests/unit/drivers/prompt_models/test_sagemaker_falcon_prompt_model_driver.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import boto3
-import pytest
-from griptape.utils import PromptStack
-from griptape.drivers import AmazonSageMakerPromptDriver, SageMakerFalconPromptModelDriver
-
-
-class TestSageMakerFalconPromptModelDriver:
-    @pytest.fixture
-    def driver(self):
-        return AmazonSageMakerPromptDriver(
-            endpoint="endpoint-name",
-            session=boto3.Session(region_name="us-east-1"),
-            prompt_model_driver=SageMakerFalconPromptModelDriver(),
-            temperature=0.12345,
-        ).prompt_model_driver
-
-    @pytest.fixture
-    def stack(self):
-        stack = PromptStack()
-
-        stack.add_system_input("foo")
-        stack.add_user_input("bar")
-
-        return stack
-
-    def test_init(self, driver):
-        assert driver.prompt_driver is not None
-
-    def test_prompt_stack_to_model_input(self, driver, stack):
-        model_input = driver.prompt_stack_to_model_input(stack)
-
-        assert isinstance(model_input, str)
-        assert model_input.startswith("foo\n\nUser: bar")
-
-    def test_prompt_stack_to_model_params(self, driver, stack):
-        assert driver.prompt_stack_to_model_params(stack)["max_new_tokens"] == 590
-        assert driver.prompt_stack_to_model_params(stack)["temperature"] == 0.12345
-
-    def test_process_output(self, driver, stack):
-        assert driver.process_output([{"generated_text": "foobar"}]).value == "foobar"
-
-    def test_tokenizer_max_model_length(self, driver):
-        assert driver.tokenizer.tokenizer.model_max_length == 2048
diff --git a/tests/unit/drivers/prompt_models/test_sagemaker_llama_prompt_model_driver.py b/tests/unit/drivers/prompt_models/test_sagemaker_llama_prompt_model_driver.py
deleted file mode 100644
index b39ce458e..000000000
--- a/tests/unit/drivers/prompt_models/test_sagemaker_llama_prompt_model_driver.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import boto3
-import pytest
-from griptape.utils import PromptStack
-from griptape.drivers import AmazonSageMakerPromptDriver, SageMakerLlamaPromptModelDriver
-
-
-class TestSageMakerLlamaPromptModelDriver:
-    @pytest.fixture(autouse=True)
-    def llama3_instruct_tokenizer(self, mocker):
-        tokenizer = mocker.patch("transformers.AutoTokenizer").return_value
-        tokenizer.model_max_length = 8000
-
-        return tokenizer
-
-    @pytest.fixture(autouse=True)
-    def hugging_face_tokenizer(self, mocker, llama3_instruct_tokenizer):
-        tokenizer = mocker.patch(
-            "griptape.drivers.prompt_model.sagemaker_llama_prompt_model_driver.HuggingFaceTokenizer"
-        ).return_value
-        tokenizer.count_output_tokens_left.return_value = 7991
-        tokenizer.tokenizer = llama3_instruct_tokenizer
-        return tokenizer
-
-    @pytest.fixture
-    def driver(self):
-        return AmazonSageMakerPromptDriver(
-            endpoint="endpoint-name",
-            model="inference-component-name",
-            session=boto3.Session(region_name="us-east-1"),
-            prompt_model_driver=SageMakerLlamaPromptModelDriver(),
-            temperature=0.12345,
-        ).prompt_model_driver
-
-    @pytest.fixture
-    def stack(self):
-        stack = PromptStack()
-
-        stack.add_system_input("foo")
-        stack.add_user_input("bar")
-
-        return stack
-
-    def test_init(self, driver):
-        assert driver.prompt_driver is not None
-
-    def test_prompt_stack_to_model_input(self, driver, stack, hugging_face_tokenizer):
-        driver.prompt_stack_to_model_input(stack)
-
-        hugging_face_tokenizer.tokenizer.apply_chat_template.assert_called_once_with(
-            [{"role": "system", "content": "foo"}, {"role": "user", "content": "bar"}],
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-
-    def test_prompt_stack_to_model_params(self, driver, stack):
-        assert driver.prompt_stack_to_model_params(stack)["max_new_tokens"] == 7991
-        assert driver.prompt_stack_to_model_params(stack)["temperature"] == 0.12345
-
-    def test_process_output(self, driver, stack):
-        assert driver.process_output({"generated_text": "foobar"}).value == "foobar"
-
-    def test_process_output_invalid_format(self, driver, stack):
-        with pytest.raises(ValueError):
-            assert driver.process_output([{"generated_text": "foobar"}])
-
-    def test_tokenizer_max_model_length(self, driver):
-        assert driver.tokenizer.tokenizer.model_max_length == 8000
diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py
index 592843279..352ae089d 100644
--- a/tests/utils/structure_tester.py
+++ b/tests/utils/structure_tester.py
@@ -18,9 +18,7 @@
     OpenAiChatPromptDriver,
     OpenAiCompletionPromptDriver,
     AzureOpenAiChatPromptDriver,
-    AmazonSageMakerPromptDriver,
-    SageMakerLlamaPromptModelDriver,
-    SageMakerFalconPromptModelDriver,
+    AmazonSageMakerJumpstartPromptDriver,
     GooglePromptDriver,
 )
 
@@ -199,17 +197,11 @@ class TesterPromptDriverOption:
        prompt_driver=AmazonBedrockPromptDriver(model="mistral.mistral-small-2402-v1:0"), enabled=True
    ),
    "SAGEMAKER_LLAMA_7B": TesterPromptDriverOption(
-        prompt_driver=AmazonSageMakerPromptDriver(
-            endpoint=os.environ["SAGEMAKER_LLAMA_ENDPOINT_NAME"],
-            prompt_model_driver=SageMakerLlamaPromptModelDriver(max_tokens=4096),
-        ),
+        prompt_driver=AmazonSageMakerJumpstartPromptDriver(endpoint=os.environ["SAGEMAKER_LLAMA_ENDPOINT_NAME"]),
        enabled=False,
    ),
    "SAGEMAKER_FALCON_7b": TesterPromptDriverOption(
-        prompt_driver=AmazonSageMakerPromptDriver(
-            endpoint=os.environ["SAGEMAKER_FALCON_ENDPOINT_NAME"],
-            prompt_model_driver=SageMakerFalconPromptModelDriver(),
-        ),
+        prompt_driver=AmazonSageMakerJumpstartPromptDriver(endpoint=os.environ["SAGEMAKER_FALCON_ENDPOINT_NAME"]),
        enabled=False,
    ),
    "GOOGLE_GEMINI_PRO": TesterPromptDriverOption(
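
Below is a minimal usage sketch of the renamed driver, assembled from the docs example and driver code in this diff; the endpoint environment variable and the Hugging Face model ID are illustrative placeholders, not values this patch prescribes.

```python
# Usage sketch for AmazonSageMakerJumpstartPromptDriver (assumes griptape at
# this commit, the `drivers-prompt-amazon-sagemaker` extra, and `transformers`).
import os

from griptape.config import StructureConfig
from griptape.drivers import AmazonSageMakerJumpstartPromptDriver
from griptape.structures import Agent

agent = Agent(
    config=StructureConfig(
        prompt_driver=AmazonSageMakerJumpstartPromptDriver(
            # Name of the SageMaker endpoint to invoke (placeholder).
            endpoint=os.environ["SAGEMAKER_LLAMA_3_INSTRUCT_ENDPOINT_NAME"],
            # The tokenizer default calls AutoTokenizer.from_pretrained(self.model),
            # so `model` must also resolve as a Hugging Face model ID (placeholder).
            model="meta-llama/Meta-Llama-3-8B-Instruct",
        )
    )
)

agent.run("Hello!")
```

Note that `try_run` now builds the prompt with the tokenizer's `apply_chat_template` and sends `max_new_tokens=250` by default, so pass `max_tokens=...` to the driver to raise the output budget.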