From 5479d770e1ac1ec1cc89b6f8136d86e381d89de8 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 29 Aug 2024 15:35:45 -0700 Subject: [PATCH] Fix missing maxTokens in AmazonBedrockPromptDriver --- CHANGELOG.md | 1 + griptape/drivers/prompt/amazon_bedrock_prompt_driver.py | 2 +- .../unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c351588c4..7d2e987da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Parsing streaming response with some OpenAi compatible services. - Issue in `PromptSummaryEngine` if there are no artifacts during recursive summarization. - Issue in `GooglePromptDriver` using Tools with no schema. +- Missing `maxTokens` inference parameter in `AmazonBedrockPromptDriver`. **Note**: This release includes breaking changes. Please refer to the [Migration Guide](./MIGRATION.md#030x-to-031x) for details. diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index b663d06fd..bc339f618 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -98,7 +98,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "modelId": self.model, "messages": messages, "system": system_messages, - "inferenceConfig": {"temperature": self.temperature}, + "inferenceConfig": {"temperature": self.temperature, "maxTokens": self.max_tokens}, "additionalModelRequestFields": self.additional_model_request_fields, **( {"toolConfig": {"tools": self.__to_bedrock_tools(prompt_stack.tools), "toolChoice": self.tool_choice}} diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index ebe25bb28..c36c46074 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -344,7 +344,7 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): mock_converse.assert_called_once_with( modelId=driver.model, messages=messages, - inferenceConfig={"temperature": driver.temperature}, + inferenceConfig={"temperature": driver.temperature, "maxTokens": driver.max_tokens}, additionalModelRequestFields={}, **({"system": [{"text": "system-input"}]} if prompt_stack.system_messages else {"system": []}), **( @@ -376,7 +376,7 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_ mock_converse_stream.assert_called_once_with( modelId=driver.model, messages=messages, - inferenceConfig={"temperature": driver.temperature}, + inferenceConfig={"temperature": driver.temperature, "maxTokens": driver.max_tokens}, additionalModelRequestFields={}, **({"system": [{"text": "system-input"}]} if prompt_stack.system_messages else {"system": []}), **(