Fix missing maxTokens in AmazonBedrockPromptDriver (#1123)
collindutter authored Aug 29, 2024
1 parent ba47112 · commit ab25735
Showing 3 changed files with 4 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Parsing streaming response with some OpenAi compatible services.
 - Issue in `PromptSummaryEngine` if there are no artifacts during recursive summarization.
 - Issue in `GooglePromptDriver` using Tools with no schema.
+- Missing `maxTokens` inference parameter in `AmazonBedrockPromptDriver`.
 - Incorrect model in `OpenAiDriverConfig`'s `text_to_speech_driver`.
 
 **Note**: This release includes breaking changes. Please refer to the [Migration Guide](./MIGRATION.md#030x-to-031x) for details.
2 changes: 1 addition & 1 deletion griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@@ -98,7 +98,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"modelId": self.model,
"messages": messages,
"system": system_messages,
"inferenceConfig": {"temperature": self.temperature},
"inferenceConfig": {"temperature": self.temperature, "maxTokens": self.max_tokens},
"additionalModelRequestFields": self.additional_model_request_fields,
**(
{"toolConfig": {"tools": self.__to_bedrock_tools(prompt_stack.tools), "toolChoice": self.tool_choice}}
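For context, a rough usage sketch (the model id and numbers below are illustrative placeholders, not taken from this commit): with this fix, a `max_tokens` value configured on the driver is forwarded to Bedrock as `inferenceConfig["maxTokens"]` instead of being silently dropped, so Bedrock no longer falls back to its per-model default cap.

```python
from griptape.drivers import AmazonBedrockPromptDriver

# Hypothetical configuration; model id and limits are placeholders.
driver = AmazonBedrockPromptDriver(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    temperature=0.5,
    max_tokens=512,  # now included in the Converse request as maxTokens
)

# _base_params() builds the Bedrock Converse request; after this change it
# contains: "inferenceConfig": {"temperature": 0.5, "maxTokens": 512}
```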
4 changes: 2 additions & 2 deletions
@@ -344,7 +344,7 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools):
         mock_converse.assert_called_once_with(
             modelId=driver.model,
             messages=messages,
-            inferenceConfig={"temperature": driver.temperature},
+            inferenceConfig={"temperature": driver.temperature, "maxTokens": driver.max_tokens},
             additionalModelRequestFields={},
             **({"system": [{"text": "system-input"}]} if prompt_stack.system_messages else {"system": []}),
             **(
@@ -376,7 +376,7 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_native_tools):
         mock_converse_stream.assert_called_once_with(
             modelId=driver.model,
             messages=messages,
-            inferenceConfig={"temperature": driver.temperature},
+            inferenceConfig={"temperature": driver.temperature, "maxTokens": driver.max_tokens},
             additionalModelRequestFields={},
             **({"system": [{"text": "system-input"}]} if prompt_stack.system_messages else {"system": []}),
             **(
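Both updated tests assert on the exact keyword arguments passed to the mocked Bedrock Runtime client. A minimal, self-contained sketch of that assertion pattern (a simplified stand-in, not the project's actual fixtures):

```python
from unittest import mock

# Illustrative stand-in for the patched Bedrock Runtime client.
client = mock.Mock()

# What the driver now sends (all values are placeholders).
client.converse(
    modelId="model-id",
    messages=[{"role": "user", "content": [{"text": "hi"}]}],
    inferenceConfig={"temperature": 0.5, "maxTokens": 512},
    additionalModelRequestFields={},
)

# The regression check: maxTokens must appear inside inferenceConfig.
_, kwargs = client.converse.call_args
assert kwargs["inferenceConfig"] == {"temperature": 0.5, "maxTokens": 512}
```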
