From 0e925c437ee2f7922558099ca0539a5b3587c982 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 1 Feb 2024 15:28:12 +0100 Subject: [PATCH 01/16] Add AmazonBedrockChatGenerator, add Anthropic Claude support --- .../generators/amazon_bedrock/__init__.py | 1 + .../amazon_bedrock/chat/__init__.py | 3 + .../amazon_bedrock/chat/adapters.py | 137 +++++++++++ .../amazon_bedrock/chat/chat_generator.py | 232 ++++++++++++++++++ 4 files changed, 373 insertions(+) create mode 100644 integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/__init__.py create mode 100644 integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py create mode 100644 integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py index 236347b61..0d875716d 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from .chat.chat_generator import AmazonBedrockChatGenerator from .generator import AmazonBedrockGenerator __all__ = ["AmazonBedrockGenerator"] diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/__init__.py new file mode 100644 index 000000000..e873bc332 --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py new file mode 100644 index 000000000..cc54ba60c --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -0,0 +1,137 @@ +import json +import logging +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List + +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk + +from haystack_integrations.components.generators.amazon_bedrock.handlers import DefaultPromptHandler + +logger = logging.getLogger(__name__) + + +class BedrockModelChatAdapter(ABC): + """ + Base class for Amazon Bedrock model adapters. 
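+
+    Subclasses translate a list of Haystack ChatMessage objects into the provider-specific
+    request body (`prepare_body`) and turn the provider response back into ChatMessage
+    objects (`get_responses` / `get_stream_responses`).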
+ """ + + def __init__(self, generation_kwargs: Dict[str, Any]) -> None: + self.generation_kwargs = generation_kwargs + + @abstractmethod + def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: + """Prepares the body for the Amazon Bedrock request.""" + + def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]: + """Extracts the responses from the Amazon Bedrock response.""" + return self._extract_messages_from_response(response_body) + + def get_stream_responses(self, stream, stream_handler: Callable[[StreamingChunk], None]) -> List[str]: + tokens: List[str] = [] + for event in stream: + chunk = event.get("chunk") + if chunk: + decoded_chunk = json.loads(chunk["bytes"].decode("utf-8")) + token = self._extract_token_from_stream(decoded_chunk) + # take all the rest key/value pairs from the chunk, add them to the metadata + stream_metadata = {k: v for (k, v) in decoded_chunk.items() if v != token} + stream_chunk = StreamingChunk(content=token, meta=stream_metadata) + # callback the stream handler with StreamingChunk + stream_handler(stream_chunk) + tokens.append(token) + responses = ["".join(tokens).lstrip()] + return responses + + def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]) -> Dict[str, Any]: + """ + Merges the default params with the inference kwargs and model kwargs. + + Includes param if it's in kwargs or its default is not None (i.e. it is actually defined). + """ + kwargs = self.generation_kwargs.copy() + kwargs.update(inference_kwargs) + return { + param: kwargs.get(param, default) + for param, default in default_params.items() + if param in kwargs or default is not None + } + + @abstractmethod + def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: + """Extracts the responses from the Amazon Bedrock response.""" + + @abstractmethod + def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + """Extracts the token from a streaming chunk.""" + + +class AnthropicClaudeChatAdapter(BedrockModelChatAdapter): + """ + Model adapter for the Anthropic Claude model. 
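+
+    A rough sketch of the request body this adapter builds for Claude's Text Completions API,
+    assuming no extra generation or inference kwargs (the exact keys depend on what is passed in):
+
+    ```python
+    adapter = AnthropicClaudeChatAdapter(generation_kwargs={})
+    adapter.prepare_body([ChatMessage.from_user("Hello, how are you?")])
+    # -> {"prompt": "\n\nHuman: Hello, how are you?\n\nAssistant: ",
+    #     "max_tokens_to_sample": 512, "stop_sequences": ["\n\nHuman:"]}
+    ```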
+ """ + + ANTHROPIC_USER_TOKEN = "\n\nHuman:" + ANTHROPIC_ASSISTANT_TOKEN = "\n\nAssistant:" + + def __init__(self, generation_kwargs: Dict[str, Any]) -> None: + super().__init__(generation_kwargs) + + # We pop the model_max_length as it is not sent to the model + # but used to truncate the prompt if needed + # Anthropic Claude has a limit of at least 100000 tokens + # https://docs.anthropic.com/claude/reference/input-and-output-sizes + model_max_length = self.generation_kwargs.get("model_max_length", 100000) + + # Truncate prompt if prompt tokens > model_max_length-max_length + # (max_length is the length of the generated text) + # TODO use Anthropic tokenizer to get the precise prompt length + # See https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#token-counting + self.prompt_handler = DefaultPromptHandler( + model="gpt2", + model_max_length=model_max_length, + max_length=self.generation_kwargs.get("max_tokens_to_sample") or 512, + ) + + def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: + default_params = { + "max_tokens_to_sample": self.generation_kwargs.get("max_tokens_to_sample") or 512, + "stop_sequences": ["\n\nHuman:"], + "temperature": None, + "top_p": None, + "top_k": None, + } + + # combine stop words with default stop sequences + stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.get("stop_words", []) + if stop_sequences: + inference_kwargs["stop_sequences"] = stop_sequences + params = self._get_params(inference_kwargs, default_params) + body = {"prompt": self.prepare_chat_messages(messages=messages), **params} + return body + + def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: + conversation = [] + for index, message in enumerate(messages): + if message.is_from(ChatRole.USER): + conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_USER_TOKEN} {message.content.strip()}") + elif message.is_from(ChatRole.ASSISTANT): + conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN} {message.content.strip()}") + elif message.is_from(ChatRole.FUNCTION): + raise ValueError("anthropic does not support function calls.") + elif message.is_from(ChatRole.SYSTEM) and index == 0: + # Until we transition to the new chat message format system messages will be ignored + # see https://docs.anthropic.com/claude/reference/messages_post for more details + logger.warning( + "System messages are not fully supported by the current version of Claude and will be ignored." 
+ ) + else: + raise ValueError(f"Unsupported message role: {message.role}") + + return "".join(conversation) + AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN + " " + + def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: + metadata = {k: v for (k, v) in response_body.items() if k != "completion"} + return [ChatMessage.from_assistant(response_body["completion"], meta=metadata)] + + def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + return chunk.get("completion", "") diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py new file mode 100644 index 000000000..2e14a1d5d --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -0,0 +1,232 @@ +import json +import logging +import re +from typing import Any, Callable, ClassVar, Dict, List, Optional, Type + +import boto3 +from botocore.exceptions import BotoCoreError, ClientError +from haystack import component, default_from_dict, default_to_dict +from haystack.components.generators.utils import deserialize_callback_handler +from haystack.dataclasses import ChatMessage, StreamingChunk + +from ..errors import AmazonBedrockConfigurationError, AmazonBedrockInferenceError, AWSConfigurationError +from .adapters import AnthropicClaudeChatAdapter, BedrockModelChatAdapter + +logger = logging.getLogger(__name__) + +AWS_CONFIGURATION_KEYS = [ + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token", + "aws_region_name", + "aws_profile_name", +] + + +@component +class AmazonBedrockChatGenerator: + """ + AmazonBedrockChatGenerator enables text generation via Amazon Bedrock chat hosted models. For example, to use + the Anthropic Claude model, simply initialize the AmazonBedrockChatGenerator with the 'anthropic.claude-v2' + model name. + + ```python + from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator + from haystack.dataclasses import ChatMessage + from haystack.components.generators.utils import print_streaming_chunk + + messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"), + ChatMessage.from_user("What's Natural Language Processing?")] + + + client = AmazonBedrockChatGenerator(model="anthropic.claude-v2", streaming_callback=print_streaming_chunk) + client.warm_up() + client.run(messages, generation_kwargs={"max_tokens_to_sample": 512}) + + ``` + + If you prefer non-streaming mode, simply remove the `streaming_callback` parameter, capture the return value of the + component's run method and the AmazonBedrockChatGenerator will return the response in a non-streaming mode. + """ + + SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelChatAdapter]]] = { + r"anthropic.claude.*": AnthropicClaudeChatAdapter + } + + def __init__( + self, + model: str, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_profile_name: Optional[str] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + stop_words: Optional[List[str]] = None, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + ): + """ + Initializes the AmazonBedrockChatGenerator with the provided parameters. 
The parameters are passed to the + Amazon Bedrock client. + + Note that the AWS credentials are not required if the AWS environment is configured correctly. These are loaded + automatically from the environment or the AWS configuration file and do not need to be provided explicitly via + the constructor. + + :param model: The model to use for generation. The model must be available in Amazon Bedrock. The model has to + be specified in the format outlined in the Amazon Bedrock [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html). + :param aws_access_key_id: AWS access key ID. + :param aws_secret_access_key: AWS secret access key. + :param aws_session_token: AWS session token. + :param aws_region_name: AWS region name. + :param aws_profile_name: AWS profile name. + :param generation_kwargs: Additional generation keyword arguments passed to the model. The defined keyword + parameters are specific to a specific model and can be found in the model's documentation. For example, the + Anthropic Claude generation parameters can be found [here](https://docs.anthropic.com/claude/reference/complete_post). + :param stop_words: A list of stop words that stop model generation when encountered. They can be provided via + this parameter or via models generation_kwargs under a model's specific key for stop words. For example, the + Anthropic Claude stop words are provided via the `stop_sequences` key. + :param streaming_callback: A callback function that is called when a new chunk is received from the stream. + By default, the model is not set up for streaming. To enable streaming simply set this parameter to a callback + function that will handle the streaming chunks. The callback function will receive a StreamingChunk object and + switch the streaming mode on. + """ + if not model: + msg = "'model' cannot be None or empty string" + raise ValueError(msg) + self.model = model + try: + session = self.get_aws_session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_region_name=aws_region_name, + aws_profile_name=aws_profile_name, + ) + self.client = session.client("bedrock-runtime") + except Exception as exception: + msg = ( + "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. " + "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration" + ) + raise AmazonBedrockConfigurationError(msg) from exception + + model_adapter_cls = self.get_model_adapter(model=model) + if not model_adapter_cls: + msg = f"AmazonBedrockGenerator doesn't support the model {model}." + raise AmazonBedrockConfigurationError(msg) + self.model_adapter = model_adapter_cls(generation_kwargs or {}) + self.stop_words = stop_words or [] + self.streaming_callback = streaming_callback + + def invoke(self, *args, **kwargs): + kwargs = kwargs.copy() + messages: List[ChatMessage] = kwargs.pop("messages", []) + # check if the prompt is a list of ChatMessage objects + if not (isinstance(messages, list) and all(isinstance(message, ChatMessage) for message in messages)): + msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt." 
+ raise ValueError(msg) + + body = self.model_adapter.prepare_body(messages=messages, stop_words=self.stop_words, **kwargs) + try: + if self.streaming_callback: + response = self.client.invoke_model_with_response_stream( + body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" + ) + response_stream = response["body"] + responses = self.model_adapter.get_stream_responses( + stream=response_stream, stream_handler=self.streaming_callback + ) + else: + response = self.client.invoke_model( + body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" + ) + response_body = json.loads(response.get("body").read().decode("utf-8")) + responses = self.model_adapter.get_responses(response_body=response_body) + except ClientError as exception: + msg = ( + f"Could not connect to Amazon Bedrock model {self.model}. " + f"Make sure your AWS environment is configured correctly, " + f"the model is available in the configured AWS region, and you have access." + ) + raise AmazonBedrockInferenceError(msg) from exception + + return responses + + @component.output_types(replies=List[str], metadata=List[Dict[str, Any]]) + def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): + return {"replies": self.invoke(messages=messages, **(generation_kwargs or {}))} + + @classmethod + def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelChatAdapter]]: + for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items(): + if re.fullmatch(pattern, model): + return adapter + return None + + @classmethod + def aws_configured(cls, **kwargs) -> bool: + """ + Checks whether AWS configuration is provided. + :param kwargs: The kwargs passed down to the generator. + :return: True if AWS configuration is provided, False otherwise. + """ + aws_config_provided = any(key in kwargs for key in AWS_CONFIGURATION_KEYS) + return aws_config_provided + + @classmethod + def get_aws_session( + cls, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_profile_name: Optional[str] = None, + **kwargs, + ): + """ + Creates an AWS Session with the given parameters. + Checks if the provided AWS credentials are valid and can be used to connect to AWS. + + :param aws_access_key_id: AWS access key ID. + :param aws_secret_access_key: AWS secret access key. + :param aws_session_token: AWS session token. + :param aws_region_name: AWS region name. + :param aws_profile_name: AWS profile name. + :param kwargs: The kwargs passed down to the service client. Supported kwargs depend on the model chosen. + See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html. + :raises AWSConfigurationError: If the provided AWS credentials are invalid. + :return: The created AWS session. + """ + try: + return boto3.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + region_name=aws_region_name, + profile_name=aws_profile_name, + ) + except BotoCoreError as e: + provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS} + msg = f"Failed to initialize the session with provided AWS credentials {provided_aws_config}" + raise AWSConfigurationError(msg) from e + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. 
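+        For example (shape only, dotted "type" path abbreviated here):
+        `{"type": "...AmazonBedrockChatGenerator", "init_parameters": {"model": "anthropic.claude-v2"}}`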
+ :return: The serialized component as a dictionary. + """ + return default_to_dict(self, model=self.model) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator": + """ + Deserialize this component from a dictionary. + :param data: The dictionary representation of this component. + :return: The deserialized component instance. + """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler) + return default_from_dict(cls, data) From 94de537cdaaa8e52209eada62ba9a1b66864df17 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 1 Feb 2024 16:03:14 +0100 Subject: [PATCH 02/16] Add Meta Llama 2 chat model support --- .../amazon_bedrock/chat/adapters.py | 47 ++++++++++++++++++- .../amazon_bedrock/chat/chat_generator.py | 5 +- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index cc54ba60c..d1dc033f8 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -12,7 +12,7 @@ class BedrockModelChatAdapter(ABC): """ - Base class for Amazon Bedrock model adapters. + Base class for Amazon Bedrock chat model adapters. """ def __init__(self, generation_kwargs: Dict[str, Any]) -> None: @@ -135,3 +135,48 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: return chunk.get("completion", "") + + +class MetaLlama2ChatAdapter(BedrockModelChatAdapter): + """ + Model adapter for the Meta Llama model(s). 
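+
+    Chat messages are formatted with the Hugging Face chat template of the Llama 2 tokenizer,
+    so a single user message such as "Hello, how are you?" is rendered along the lines of
+    "[INST] Hello, how are you? [/INST]" before being sent to Bedrock.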
+ """ + + def __init__(self, generation_kwargs: Dict[str, Any]) -> None: + super().__init__(generation_kwargs) + # We pop the model_max_length as it is not sent to the model + # but used to truncate the prompt if needed + # Llama 2 has context window size of 4096 tokens + model_max_length = self.generation_kwargs.get("model_max_length", 4096) + # Truncate prompt if prompt tokens > model_max_length-max_length + self.prompt_handler = DefaultPromptHandler( + model="meta-llama/Llama-2-7b-chat-hf", + model_max_length=model_max_length, + max_length=self.generation_kwargs.get("max_gen_len") or 512, + ) + + def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: + default_params = { + "max_gen_len": self.generation_kwargs.get("max_gen_len") or 512, + "temperature": None, + "top_p": None, + } + + # combine stop words with default stop sequences + stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.get("stop_words", []) + if stop_sequences: + inference_kwargs["stop_sequences"] = stop_sequences + params = self._get_params(inference_kwargs, default_params) + body = {"prompt": self.prepare_chat_messages(messages=messages), **params} + return body + + def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: + prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(conversation=messages, tokenize=False) + return prepared_prompt + + def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: + metadata = {k: v for (k, v) in response_body.items() if k != "generation"} + return [ChatMessage.from_assistant(response_body["generation"], meta=metadata)] + + def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + return chunk.get("generation", "") diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 2e14a1d5d..8116de40c 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -10,7 +10,7 @@ from haystack.dataclasses import ChatMessage, StreamingChunk from ..errors import AmazonBedrockConfigurationError, AmazonBedrockInferenceError, AWSConfigurationError -from .adapters import AnthropicClaudeChatAdapter, BedrockModelChatAdapter +from .adapters import AnthropicClaudeChatAdapter, BedrockModelChatAdapter, MetaLlama2ChatAdapter logger = logging.getLogger(__name__) @@ -50,7 +50,8 @@ class AmazonBedrockChatGenerator: """ SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelChatAdapter]]] = { - r"anthropic.claude.*": AnthropicClaudeChatAdapter + r"anthropic.claude.*": AnthropicClaudeChatAdapter, + r"meta.llama2.*": MetaLlama2ChatAdapter, } def __init__( From 40a1e1dd5f3dff1a959edc933dcf477f52d93722 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 2 Feb 2024 16:15:55 +0100 Subject: [PATCH 03/16] Add unit tests, small fixes --- .../generators/amazon_bedrock/__init__.py | 2 +- .../amazon_bedrock/chat/adapters.py | 52 ++-- .../amazon_bedrock/chat/chat_generator.py | 15 +- .../generators/amazon_bedrock/generator.py | 2 +- .../tests/test_amazon_chat_bedrock.py | 251 ++++++++++++++++++ 5 files changed, 295 insertions(+), 27 deletions(-) create mode 100644 
integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py index 0d875716d..2d33beb42 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py @@ -4,4 +4,4 @@ from .chat.chat_generator import AmazonBedrockChatGenerator from .generator import AmazonBedrockGenerator -__all__ = ["AmazonBedrockGenerator"] +__all__ = ["AmazonBedrockGenerator", "AmazonBedrockChatGenerator"] diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index d1dc033f8..04b3e3ded 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -42,19 +42,34 @@ def get_stream_responses(self, stream, stream_handler: Callable[[StreamingChunk] responses = ["".join(tokens).lstrip()] return responses - def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]) -> Dict[str, Any]: + def _update_params(self, target_dict: Dict[str, Any], updates_dict: Dict[str, Any]) -> None: """ - Merges the default params with the inference kwargs and model kwargs. + Updates target_dict with values from updates_dict. Merges lists instead of overriding them. - Includes param if it's in kwargs or its default is not None (i.e. it is actually defined). + :param target_dict: The dictionary to update. + :param updates_dict: The dictionary with updates. """ - kwargs = self.generation_kwargs.copy() - kwargs.update(inference_kwargs) - return { - param: kwargs.get(param, default) - for param, default in default_params.items() - if param in kwargs or default is not None - } + for key, value in updates_dict.items(): + if key in target_dict and isinstance(target_dict[key], list) and isinstance(value, list): + # Merge lists and remove duplicates + target_dict[key] = list(sorted(set(target_dict[key] + value))) + else: + # Override the value in target_dict + target_dict[key] = value + + def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]) -> Dict[str, Any]: + """ + Merges params from inference_kwargs with the default params and self.generation_kwargs. + Uses a helper function to merge lists or override values as necessary. 
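+
+        For example, a default of {"stop_sequences": ["\n\nHuman:"]} merged with an inference-time
+        {"stop_sequences": ["CUSTOM_STOP"]} yields ["\n\nHuman:", "CUSTOM_STOP"] (lists are merged,
+        de-duplicated and sorted), while scalar values such as "temperature" are simply overridden
+        by the most specific caller-supplied value.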
+ """ + # Start with a copy of default_params + kwargs = default_params.copy() + + # Update the default params with self.generation_kwargs and finally inference_kwargs + self._update_params(kwargs, self.generation_kwargs) + self._update_params(kwargs, inference_kwargs) + + return kwargs @abstractmethod def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: @@ -96,13 +111,10 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ default_params = { "max_tokens_to_sample": self.generation_kwargs.get("max_tokens_to_sample") or 512, "stop_sequences": ["\n\nHuman:"], - "temperature": None, - "top_p": None, - "top_k": None, } - # combine stop words with default stop sequences - stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.get("stop_words", []) + # combine stop words with default stop sequences, remove stop_words as Anthropic does not support it + stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", []) if stop_sequences: inference_kwargs["stop_sequences"] = stop_sequences params = self._get_params(inference_kwargs, default_params) @@ -156,14 +168,10 @@ def __init__(self, generation_kwargs: Dict[str, Any]) -> None: ) def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: - default_params = { - "max_gen_len": self.generation_kwargs.get("max_gen_len") or 512, - "temperature": None, - "top_p": None, - } + default_params = {"max_gen_len": self.generation_kwargs.get("max_gen_len") or 512} - # combine stop words with default stop sequences - stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.get("stop_words", []) + # combine stop words with default stop sequences, remove stop_words as MetaLlama2 does not support it + stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", []) if stop_sequences: inference_kwargs["stop_sequences"] = stop_sequences params = self._get_params(inference_kwargs, default_params) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 8116de40c..00e77be32 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -40,7 +40,6 @@ class AmazonBedrockChatGenerator: client = AmazonBedrockChatGenerator(model="anthropic.claude-v2", streaming_callback=print_streaming_chunk) - client.warm_up() client.run(messages, generation_kwargs={"max_tokens_to_sample": 512}) ``` @@ -124,7 +123,11 @@ def invoke(self, *args, **kwargs): kwargs = kwargs.copy() messages: List[ChatMessage] = kwargs.pop("messages", []) # check if the prompt is a list of ChatMessage objects - if not (isinstance(messages, list) and all(isinstance(message, ChatMessage) for message in messages)): + if not ( + isinstance(messages, list) + and len(messages) > 0 + and all(isinstance(message, ChatMessage) for message in messages) + ): msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt." raise ValueError(msg) @@ -217,7 +220,13 @@ def to_dict(self) -> Dict[str, Any]: Serialize this component to a dictionary. :return: The serialized component as a dictionary. 
""" - return default_to_dict(self, model=self.model) + return default_to_dict( + self, + model=self.model, + stop_words=self.stop_words, + generation_kwargs=self.model_adapter.generation_kwargs, + streaming_callback=self.streaming_callback, + ) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator": diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 4c43c9a09..72314fc39 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -45,7 +45,7 @@ class AmazonBedrockGenerator: Usage example: ```python - from amazon_bedrock_haystack.generators.amazon_bedrock import AmazonBedrockGenerator + from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator generator = AmazonBedrockGenerator( model="anthropic.claude-v2", diff --git a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py new file mode 100644 index 000000000..071a5a711 --- /dev/null +++ b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py @@ -0,0 +1,251 @@ +from typing import Optional, Type +from unittest.mock import MagicMock, patch + +import pytest +from haystack.components.generators.utils import print_streaming_chunk +from haystack.dataclasses import ChatMessage + +from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator +from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import ( + MetaLlama2ChatAdapter, + AnthropicClaudeChatAdapter, BedrockModelChatAdapter, +) + + +@pytest.fixture +def mock_auto_tokenizer(): + with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained: + mock_tokenizer = MagicMock() + mock_from_pretrained.return_value = mock_tokenizer + yield mock_tokenizer + + +# create a fixture with mocked boto3 client and session +@pytest.fixture +def mock_boto3_session(): + with patch("boto3.Session") as mock_client: + yield mock_client + + +@pytest.fixture +def mock_prompt_handler(): + with patch( + "haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler" + ) as mock_prompt_handler: + yield mock_prompt_handler + + +def test_to_dict(mock_auto_tokenizer, mock_boto3_session): + """ + Test that the to_dict method returns the correct dictionary without aws credentials + """ + generator = AmazonBedrockChatGenerator( + model="anthropic.claude-v2", + aws_access_key_id="some_fake_id", + aws_secret_access_key="some_fake_key", + aws_session_token="some_fake_token", + aws_profile_name="some_fake_profile", + aws_region_name="fake_region", + generation_kwargs={"temperature": 0.7}, + streaming_callback=print_streaming_chunk, + ) + + expected_dict = { + "type": "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator", + "init_parameters": { + "model": "anthropic.claude-v2", + "generation_kwargs": {"temperature": 0.7}, + "stop_words": [], + "streaming_callback": print_streaming_chunk, + }, + } + + assert generator.to_dict() == expected_dict + + +def test_from_dict(mock_auto_tokenizer, mock_boto3_session): + """ + Test that the from_dict method returns the correct object 
+ """ + generator = AmazonBedrockChatGenerator.from_dict( + { + "type": "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator", + "init_parameters": { + "model": "anthropic.claude-v2", + "generation_kwargs": {"temperature": 0.7}, + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + }, + } + ) + assert generator.model == "anthropic.claude-v2" + assert generator.model_adapter.generation_kwargs == {"temperature": 0.7} + assert generator.streaming_callback == print_streaming_chunk + + +def test_default_constructor(mock_auto_tokenizer, mock_boto3_session): + """ + Test that the default constructor sets the correct values + """ + + layer = AmazonBedrockChatGenerator( + model="anthropic.claude-v2", + aws_access_key_id="some_fake_id", + aws_secret_access_key="some_fake_key", + aws_session_token="some_fake_token", + aws_profile_name="some_fake_profile", + aws_region_name="fake_region", + ) + + assert layer.model == "anthropic.claude-v2" + + assert layer.model_adapter.prompt_handler is not None + assert layer.model_adapter.prompt_handler.model_max_length == 100000 + + # assert mocked boto3 client called exactly once + mock_boto3_session.assert_called_once() + + # assert mocked boto3 client was called with the correct parameters + mock_boto3_session.assert_called_with( + aws_access_key_id="some_fake_id", + aws_secret_access_key="some_fake_key", + aws_session_token="some_fake_token", + profile_name="some_fake_profile", + region_name="fake_region", + ) + + +def test_constructor_with_generation_kwargs(mock_auto_tokenizer, mock_boto3_session): + """ + Test that model_kwargs are correctly set in the constructor + """ + generation_kwargs = {"temperature": 0.7} + + layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2", generation_kwargs=generation_kwargs) + assert "temperature" in layer.model_adapter.generation_kwargs + assert layer.model_adapter.generation_kwargs["temperature"] == 0.7 + + +def test_constructor_with_empty_model(): + """ + Test that the constructor raises an error when the model is empty + """ + with pytest.raises(ValueError, match="cannot be None or empty string"): + AmazonBedrockChatGenerator(model="") + + +@pytest.mark.unit +def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session): + """ + Test invoke raises an error if no messages are provided + """ + layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2") + with pytest.raises(ValueError, match="The model anthropic.claude-v2 requires"): + layer.invoke() + + +@pytest.mark.unit +@pytest.mark.parametrize( + "model, expected_model_adapter", + [ + ("anthropic.claude-v1", AnthropicClaudeChatAdapter), + ("anthropic.claude-v2", AnthropicClaudeChatAdapter), + ("anthropic.claude-instant-v1", AnthropicClaudeChatAdapter), + ("anthropic.claude-super-v5", AnthropicClaudeChatAdapter), # artificial + ("meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), + ("meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), + ("meta.llama2-130b-v5", MetaLlama2ChatAdapter), # artificial + ("unknown_model", None), + ], +) +def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[BedrockModelChatAdapter]]): + """ + Test that the correct model adapter is returned for a given model + """ + model_adapter = AmazonBedrockChatGenerator.get_model_adapter(model=model) + assert model_adapter == expected_model_adapter + + +class TestAnthropicClaudeAdapter: + def test_prepare_body_with_default_params(self) -> None: + layer = 
AnthropicClaudeChatAdapter(generation_kwargs={}) + prompt = "Hello, how are you?" + expected_body = { + "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant: ", + "max_tokens_to_sample": 512, + "stop_sequences": ["\n\nHuman:"], + } + + body = layer.prepare_body([ChatMessage.from_user(prompt)]) + + assert body == expected_body + + def test_prepare_body_with_custom_inference_params(self) -> None: + layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, + "top_p": 0.8, + "top_k": 4}) + prompt = "Hello, how are you?" + expected_body = { + "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant: ", + 'max_tokens_to_sample': 69, + "stop_sequences": ["\n\nHuman:", "CUSTOM_STOP"], + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + } + + body = layer.prepare_body( + [ChatMessage.from_user(prompt)], + top_p=0.8, + top_k=5, + max_tokens_to_sample=69, + stop_sequences=["CUSTOM_STOP"], + ) + + assert body == expected_body + + +class TestMetaLlama2ChatAdapter: + def test_prepare_body_with_default_params(self) -> None: + layer = MetaLlama2ChatAdapter(generation_kwargs={}) + prompt = "Hello, how are you?" + expected_body = { + "prompt": "[INST] Hello, how are you? [/INST]", + "max_gen_len": 512, + } + + body = layer.prepare_body([ChatMessage.from_user(prompt)]) + + assert body == expected_body + + def test_prepare_body_with_custom_inference_params(self) -> None: + layer = MetaLlama2ChatAdapter(generation_kwargs={"temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + "stop_sequences": ["CUSTOM_STOP"]}) + prompt = "Hello, how are you?" + expected_body = { + "prompt": "[INST] Hello, how are you? [/INST]", + 'max_gen_len': 69, + "stop_sequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + } + + body = layer.prepare_body( + [ChatMessage.from_user(prompt)], + temperature=0.7, + top_p=0.8, + top_k=5, + max_gen_len=69, + stop_sequences=["CUSTOM_STOP"], + ) + + assert body == expected_body + + def test_get_responses(self) -> None: + adapter = MetaLlama2ChatAdapter(generation_kwargs={}) + response_body = {"generation": "This is a single response."} + expected_response = "This is a single response." 
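+        # get_responses should wrap the raw "generation" text in a single assistant ChatMessage,
+        # keeping any remaining response keys as metadata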
+ response_message = adapter.get_responses(response_body) + assert response_message == [ChatMessage.from_assistant(expected_response)] From 86bf4175eee96a768d56146adaa67427c66a0657 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 2 Feb 2024 16:43:21 +0100 Subject: [PATCH 04/16] Rename print_streaming_chunk back to default_streaming_callback until new haystack-ai is released --- .../amazon_bedrock/tests/test_amazon_chat_bedrock.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py index 071a5a711..3c0277627 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py @@ -2,8 +2,8 @@ from unittest.mock import MagicMock, patch import pytest -from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import ChatMessage +from haystack.components.generators.utils import default_streaming_callback +from haystack.dataclasses import ChatMessage, StreamingChunk from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import ( @@ -47,7 +47,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session): aws_profile_name="some_fake_profile", aws_region_name="fake_region", generation_kwargs={"temperature": 0.7}, - streaming_callback=print_streaming_chunk, + streaming_callback=default_streaming_callback, ) expected_dict = { @@ -56,7 +56,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session): "model": "anthropic.claude-v2", "generation_kwargs": {"temperature": 0.7}, "stop_words": [], - "streaming_callback": print_streaming_chunk, + "streaming_callback": default_streaming_callback, }, } @@ -73,13 +73,13 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session): "init_parameters": { "model": "anthropic.claude-v2", "generation_kwargs": {"temperature": 0.7}, - "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "streaming_callback": "haystack.components.generators.utils.default_streaming_callback", }, } ) assert generator.model == "anthropic.claude-v2" assert generator.model_adapter.generation_kwargs == {"temperature": 0.7} - assert generator.streaming_callback == print_streaming_chunk + assert generator.streaming_callback == default_streaming_callback def test_default_constructor(mock_auto_tokenizer, mock_boto3_session): From c3e4bcd1d20761bf1e5c63430552bd77e793879f Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 2 Feb 2024 16:55:56 +0100 Subject: [PATCH 05/16] Hatch lint --- .../amazon_bedrock/chat/adapters.py | 8 +++-- .../amazon_bedrock/chat/chat_generator.py | 7 ++++- .../tests/test_amazon_chat_bedrock.py | 29 +++++++++---------- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index 04b3e3ded..fb978e092 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -52,7 +52,7 @@ def _update_params(self, target_dict: Dict[str, Any], 
updates_dict: Dict[str, An for key, value in updates_dict.items(): if key in target_dict and isinstance(target_dict[key], list) and isinstance(value, list): # Merge lists and remove duplicates - target_dict[key] = list(sorted(set(target_dict[key] + value))) + target_dict[key] = sorted(set(target_dict[key] + value)) else: # Override the value in target_dict target_dict[key] = value @@ -129,7 +129,8 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: elif message.is_from(ChatRole.ASSISTANT): conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN} {message.content.strip()}") elif message.is_from(ChatRole.FUNCTION): - raise ValueError("anthropic does not support function calls.") + error_message = "Anthropic does not support function calls." + raise ValueError(error_message) elif message.is_from(ChatRole.SYSTEM) and index == 0: # Until we transition to the new chat message format system messages will be ignored # see https://docs.anthropic.com/claude/reference/messages_post for more details @@ -137,7 +138,8 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: "System messages are not fully supported by the current version of Claude and will be ignored." ) else: - raise ValueError(f"Unsupported message role: {message.role}") + invalid_role = f"Invalid role {message.role} for message {message.content}" + raise ValueError(invalid_role) return "".join(conversation) + AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN + " " diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 00e77be32..94bec3a72 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -9,7 +9,12 @@ from haystack.components.generators.utils import deserialize_callback_handler from haystack.dataclasses import ChatMessage, StreamingChunk -from ..errors import AmazonBedrockConfigurationError, AmazonBedrockInferenceError, AWSConfigurationError +from haystack_integrations.components.generators.amazon_bedrock.errors import ( + AmazonBedrockConfigurationError, + AmazonBedrockInferenceError, + AWSConfigurationError, +) + from .adapters import AnthropicClaudeChatAdapter, BedrockModelChatAdapter, MetaLlama2ChatAdapter logger = logging.getLogger(__name__) diff --git a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py index 3c0277627..e045fb790 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py @@ -3,14 +3,17 @@ import pytest from haystack.components.generators.utils import default_streaming_callback -from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.dataclasses import ChatMessage from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import ( + AnthropicClaudeChatAdapter, + BedrockModelChatAdapter, MetaLlama2ChatAdapter, - AnthropicClaudeChatAdapter, BedrockModelChatAdapter, ) +clazz = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" 
+ @pytest.fixture def mock_auto_tokenizer(): @@ -30,7 +33,7 @@ def mock_boto3_session(): @pytest.fixture def mock_prompt_handler(): with patch( - "haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler" + "haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler" ) as mock_prompt_handler: yield mock_prompt_handler @@ -49,9 +52,8 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session): generation_kwargs={"temperature": 0.7}, streaming_callback=default_streaming_callback, ) - expected_dict = { - "type": "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator", + "type": clazz, "init_parameters": { "model": "anthropic.claude-v2", "generation_kwargs": {"temperature": 0.7}, @@ -69,7 +71,7 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session): """ generator = AmazonBedrockChatGenerator.from_dict( { - "type": "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator", + "type": clazz, "init_parameters": { "model": "anthropic.claude-v2", "generation_kwargs": {"temperature": 0.7}, @@ -180,13 +182,11 @@ def test_prepare_body_with_default_params(self) -> None: assert body == expected_body def test_prepare_body_with_custom_inference_params(self) -> None: - layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, - "top_p": 0.8, - "top_k": 4}) + layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4}) prompt = "Hello, how are you?" expected_body = { "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant: ", - 'max_tokens_to_sample': 69, + "max_tokens_to_sample": 69, "stop_sequences": ["\n\nHuman:", "CUSTOM_STOP"], "temperature": 0.7, "top_p": 0.8, @@ -218,14 +218,13 @@ def test_prepare_body_with_default_params(self) -> None: assert body == expected_body def test_prepare_body_with_custom_inference_params(self) -> None: - layer = MetaLlama2ChatAdapter(generation_kwargs={"temperature": 0.7, - "top_p": 0.8, - "top_k": 5, - "stop_sequences": ["CUSTOM_STOP"]}) + layer = MetaLlama2ChatAdapter( + generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 5, "stop_sequences": ["CUSTOM_STOP"]} + ) prompt = "Hello, how are you?" expected_body = { "prompt": "[INST] Hello, how are you? [/INST]", - 'max_gen_len': 69, + "max_gen_len": 69, "stop_sequences": ["CUSTOM_STOP"], "temperature": 0.7, "top_p": 0.8, From 091d190ce98b7d463b4d952f7ad0c2c38b3ee157 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Sat, 3 Feb 2024 10:21:51 +0100 Subject: [PATCH 06/16] Test updates --- .../tests/test_amazon_chat_bedrock.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py index e045fb790..d0bcf3424 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py @@ -168,7 +168,7 @@ def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[Bed class TestAnthropicClaudeAdapter: - def test_prepare_body_with_default_params(self) -> None: + def test_prepare_body_with_default_params(self, mock_auto_tokenizer) -> None: layer = AnthropicClaudeChatAdapter(generation_kwargs={}) prompt = "Hello, how are you?" 
expected_body = { @@ -181,7 +181,7 @@ def test_prepare_body_with_default_params(self) -> None: assert body == expected_body - def test_prepare_body_with_custom_inference_params(self) -> None: + def test_prepare_body_with_custom_inference_params(self, mock_auto_tokenizer) -> None: layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4}) prompt = "Hello, how are you?" expected_body = { @@ -205,7 +205,11 @@ def test_prepare_body_with_custom_inference_params(self) -> None: class TestMetaLlama2ChatAdapter: - def test_prepare_body_with_default_params(self) -> None: + + @pytest.mark.integration + def test_prepare_body_with_default_params(self, mock_auto_tokenizer) -> None: + # leave this test as integration because we really need only tokenizer from HF + # that way we can ensure prompt chat message formatting layer = MetaLlama2ChatAdapter(generation_kwargs={}) prompt = "Hello, how are you?" expected_body = { @@ -217,7 +221,10 @@ def test_prepare_body_with_default_params(self) -> None: assert body == expected_body + @pytest.mark.integration def test_prepare_body_with_custom_inference_params(self) -> None: + # leave this test as integration because we really need only tokenizer from HF + # that way we can ensure prompt chat message formatting layer = MetaLlama2ChatAdapter( generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 5, "stop_sequences": ["CUSTOM_STOP"]} ) @@ -242,7 +249,7 @@ def test_prepare_body_with_custom_inference_params(self) -> None: assert body == expected_body - def test_get_responses(self) -> None: + def test_get_responses(self, mock_auto_tokenizer) -> None: adapter = MetaLlama2ChatAdapter(generation_kwargs={}) response_body = {"generation": "This is a single response."} expected_response = "This is a single response." 
From b41cc5db151ff05f67c6cbc3e702ff2978f48e91 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Sat, 3 Feb 2024 10:31:16 +0100 Subject: [PATCH 07/16] Minor fix --- integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py index d0bcf3424..affcc078f 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py @@ -207,7 +207,7 @@ def test_prepare_body_with_custom_inference_params(self, mock_auto_tokenizer) -> class TestMetaLlama2ChatAdapter: @pytest.mark.integration - def test_prepare_body_with_default_params(self, mock_auto_tokenizer) -> None: + def test_prepare_body_with_default_params(self) -> None: # leave this test as integration because we really need only tokenizer from HF # that way we can ensure prompt chat message formatting layer = MetaLlama2ChatAdapter(generation_kwargs={}) From a38e5b08a4fa1a99fc9e11e514a02e81d0cee67f Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 6 Feb 2024 13:25:30 +0100 Subject: [PATCH 08/16] Use gpt2 tokenizer and llama 2 chat template --- .../amazon_bedrock/chat/adapters.py | 34 +++++++++++++++++-- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index fb978e092..dce52ac03 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -153,7 +153,33 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: class MetaLlama2ChatAdapter(BedrockModelChatAdapter): """ - Model adapter for the Meta Llama model(s). + Model adapter for the Meta Llama 2 models. 
+ """ + + # Llama 2 chat template + chat_template = """ + {% if messages[0]['role'] == 'system' %} + {% set loop_messages = messages[1:] %} + {% set system_message = messages[0]['content'] %} + {% else %} + {% set loop_messages = messages %} + {% set system_message = false %} + {% endif %} + {% for message in loop_messages %} + {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {% endif %} + {% if loop.index0 == 0 and system_message != false %} + {% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %} + {% else %} + {% set content = message['content'] %} + {% endif %} + {% if message['role'] == 'user' %} + {{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }} + {% elif message['role'] == 'assistant' %} + {{ ' ' + content.strip() + ' ' + eos_token }} + {% endif %} + {% endfor %} """ def __init__(self, generation_kwargs: Dict[str, Any]) -> None: @@ -164,7 +190,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]) -> None: model_max_length = self.generation_kwargs.get("model_max_length", 4096) # Truncate prompt if prompt tokens > model_max_length-max_length self.prompt_handler = DefaultPromptHandler( - model="meta-llama/Llama-2-7b-chat-hf", + model="gpt2", # use gpt2 tokenizer to estimate prompt length model_max_length=model_max_length, max_length=self.generation_kwargs.get("max_gen_len") or 512, ) @@ -181,7 +207,9 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ return body def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: - prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(conversation=messages, tokenize=False) + prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( + conversation=messages, tokenize=False, chat_template=self.chat_template + ) return prepared_prompt def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: From e603bc193662f33aadcbfe06e754a1d81b794314 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 6 Feb 2024 13:25:44 +0100 Subject: [PATCH 09/16] Pylint --- .../generators/amazon_bedrock/adapters.py | 13 +- .../generators/amazon_bedrock/errors.py | 5 +- .../generators/amazon_bedrock/generator.py | 32 +---- .../tests/test_amazon_bedrock.py | 127 ++++-------------- .../tests/test_amazon_chat_bedrock.py | 12 +- 5 files changed, 34 insertions(+), 155 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py index 40ba0bc67..eca81c3f1 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py @@ -150,12 +150,7 @@ class AmazonTitanAdapter(BedrockModelAdapter): """ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: - default_params = { - "maxTokenCount": self.max_length, - "stopSequences": None, - "temperature": None, - "topP": None, - } + default_params = {"maxTokenCount": self.max_length, "stopSequences": None, "temperature": None, "topP": None} params = self._get_params(inference_kwargs, default_params) body = {"inputText": prompt, "textGenerationConfig": params} @@ -175,11 +170,7 @@ class MetaLlama2ChatAdapter(BedrockModelAdapter): 
""" def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: - default_params = { - "max_gen_len": self.max_length, - "temperature": None, - "top_p": None, - } + default_params = {"max_gen_len": self.max_length, "temperature": None, "top_p": None} params = self._get_params(inference_kwargs, default_params) body = {"prompt": prompt, **params} diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py index aa8a3f6e4..53c28ad1d 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py @@ -10,10 +10,7 @@ class AmazonBedrockError(Exception): `AmazonBedrockError.message` will exist and have the expected content. """ - def __init__( - self, - message: Optional[str] = None, - ): + def __init__(self, message: Optional[str] = None): super().__init__() if message: self.message = message diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 72314fc39..2d19159f9 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -15,16 +15,8 @@ CohereCommandAdapter, MetaLlama2ChatAdapter, ) -from .errors import ( - AmazonBedrockConfigurationError, - AmazonBedrockInferenceError, - AWSConfigurationError, -) -from .handlers import ( - DefaultPromptHandler, - DefaultTokenStreamingHandler, - TokenStreamingHandler, -) +from .errors import AmazonBedrockConfigurationError, AmazonBedrockInferenceError, AWSConfigurationError +from .handlers import DefaultPromptHandler, DefaultTokenStreamingHandler, TokenStreamingHandler logger = logging.getLogger(__name__) @@ -112,9 +104,7 @@ def __init__( # It is hard to determine which tokenizer to use for the SageMaker model # so we use GPT2 tokenizer which will likely provide good token count approximation self.prompt_handler = DefaultPromptHandler( - model="gpt2", - model_max_length=model_max_length, - max_length=self.max_length or 100, + model="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100 ) model_adapter_cls = self.get_model_adapter(model=model) @@ -203,10 +193,7 @@ def invoke(self, *args, **kwargs): try: if stream: response = self.client.invoke_model_with_response_stream( - body=json.dumps(body), - modelId=self.model, - accept="application/json", - contentType="application/json", + body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" ) response_stream = response["body"] handler: TokenStreamingHandler = kwargs.get( @@ -216,10 +203,7 @@ def invoke(self, *args, **kwargs): responses = self.model_adapter.get_stream_responses(stream=response_stream, stream_handler=handler) else: response = self.client.invoke_model( - body=json.dumps(body), - modelId=self.model, - accept="application/json", - contentType="application/json", + body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" ) response_body = json.loads(response.get("body").read().decode("utf-8")) responses = 
self.model_adapter.get_responses(response_body=response_body) @@ -296,11 +280,7 @@ def to_dict(self) -> Dict[str, Any]: Serialize this component to a dictionary. :return: The serialized component as a dictionary. """ - return default_to_dict( - self, - model=self.model, - max_length=self.max_length, - ) + return default_to_dict(self, model=self.model, max_length=self.max_length) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockGenerator": diff --git a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py index b08e9dfd5..6be07b06a 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py @@ -56,10 +56,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session): expected_dict = { "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", - "init_parameters": { - "model": "anthropic.claude-v2", - "max_length": 99, - }, + "init_parameters": {"model": "anthropic.claude-v2", "max_length": 99}, } assert generator.to_dict() == expected_dict @@ -73,10 +70,7 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session): generator = AmazonBedrockGenerator.from_dict( { "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", - "init_parameters": { - "model": "anthropic.claude-v2", - "max_length": 99, - }, + "init_parameters": {"model": "anthropic.claude-v2", "max_length": 99}, } ) @@ -181,9 +175,7 @@ def test_short_prompt_is_not_truncated(mock_boto3_session): with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): layer = AmazonBedrockGenerator( - "anthropic.claude-v2", - max_length=max_length_generated_text, - model_max_length=total_model_max_length, + "anthropic.claude-v2", max_length=max_length_generated_text, model_max_length=total_model_max_length ) prompt_after_resize = layer._ensure_token_limit(mock_prompt_text) @@ -216,9 +208,7 @@ def test_long_prompt_is_truncated(mock_boto3_session): with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): layer = AmazonBedrockGenerator( - "anthropic.claude-v2", - max_length=max_length_generated_text, - model_max_length=total_model_max_length, + "anthropic.claude-v2", max_length=max_length_generated_text, model_max_length=total_model_max_length ) prompt_after_resize = layer._ensure_token_limit(long_prompt_text) @@ -238,10 +228,7 @@ def test_supports_for_valid_aws_configuration(): "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ): - supported = AmazonBedrockGenerator.supports( - model="anthropic.claude-v2", - aws_profile_name="some_real_profile", - ) + supported = AmazonBedrockGenerator.supports(model="anthropic.claude-v2", aws_profile_name="some_real_profile") args, kwargs = mock_session.client("bedrock").list_foundation_models.call_args assert kwargs["byOutputModality"] == "TEXT" @@ -253,10 +240,7 @@ def test_supports_raises_on_invalid_aws_profile_name(): with patch("boto3.Session") as mock_boto3_session: mock_boto3_session.side_effect = BotoCoreError() with pytest.raises(AmazonBedrockConfigurationError, match="Failed to initialize the session"): - AmazonBedrockGenerator.supports( - model="anthropic.claude-v2", - aws_profile_name="some_fake_profile", - ) + AmazonBedrockGenerator.supports(model="anthropic.claude-v2", aws_profile_name="some_fake_profile") @pytest.mark.unit 
@@ -269,10 +253,7 @@ def test_supports_for_invalid_bedrock_config(): "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."): - AmazonBedrockGenerator.supports( - model="anthropic.claude-v2", - aws_profile_name="some_real_profile", - ) + AmazonBedrockGenerator.supports(model="anthropic.claude-v2", aws_profile_name="some_real_profile") @pytest.mark.unit @@ -285,10 +266,7 @@ def test_supports_for_invalid_bedrock_config_error_on_list_models(): "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."): - AmazonBedrockGenerator.supports( - model="anthropic.claude-v2", - aws_profile_name="some_real_profile", - ) + AmazonBedrockGenerator.supports(model="anthropic.claude-v2", aws_profile_name="some_real_profile") @pytest.mark.unit @@ -318,9 +296,7 @@ def test_supports_with_stream_true_for_model_that_supports_streaming(): return_value=mock_session, ): supported = AmazonBedrockGenerator.supports( - model="anthropic.claude-v2", - aws_profile_name="some_real_profile", - stream=True, + model="anthropic.claude-v2", aws_profile_name="some_real_profile", stream=True ) assert supported @@ -337,15 +313,8 @@ def test_supports_with_stream_true_for_model_that_does_not_support_streaming(): with patch( "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, - ), pytest.raises( - AmazonBedrockConfigurationError, - match="The model ai21.j2-mid-v1 doesn't support streaming.", - ): - AmazonBedrockGenerator.supports( - model="ai21.j2-mid-v1", - aws_profile_name="some_real_profile", - stream=True, - ) + ), pytest.raises(AmazonBedrockConfigurationError, match="The model ai21.j2-mid-v1 doesn't support streaming."): + AmazonBedrockGenerator.supports(model="ai21.j2-mid-v1", aws_profile_name="some_real_profile", stream=True) @pytest.mark.unit @@ -665,15 +634,9 @@ def test_get_responses_leading_whitespace(self) -> None: def test_get_responses_multiple_responses(self) -> None: adapter = CohereCommandAdapter(model_kwargs={}, max_length=99) response_body = { - "generations": [ - {"text": "This is a single response."}, - {"text": "This is a second response."}, - ] + "generations": [{"text": "This is a single response."}, {"text": "This is a second response."}] } - expected_responses = [ - "This is a single response.", - "This is a second response.", - ] + expected_responses = ["This is a single response.", "This is a second response."] assert adapter.get_responses(response_body) == expected_responses def test_get_stream_responses(self) -> None: @@ -854,10 +817,7 @@ def test_get_responses_multiple_responses(self) -> None: {"data": {"text": "This is a second response."}}, ] } - expected_responses = [ - "This is a single response.", - "This is a second response.", - ] + expected_responses = ["This is a single response.", "This is a second response."] assert adapter.get_responses(response_body) == expected_responses @@ -865,10 +825,7 @@ class TestAmazonTitanAdapter: def test_prepare_body_with_default_params(self) -> None: layer = AmazonTitanAdapter(model_kwargs={}, max_length=99) prompt = "Hello, how are you?" 
- expected_body = { - "inputText": "Hello, how are you?", - "textGenerationConfig": {"maxTokenCount": 99}, - } + expected_body = {"inputText": "Hello, how are you?", "textGenerationConfig": {"maxTokenCount": 99}} body = layer.prepare_body(prompt) @@ -964,15 +921,9 @@ def test_get_responses_leading_whitespace(self) -> None: def test_get_responses_multiple_responses(self) -> None: adapter = AmazonTitanAdapter(model_kwargs={}, max_length=99) response_body = { - "results": [ - {"outputText": "This is a single response."}, - {"outputText": "This is a second response."}, - ] + "results": [{"outputText": "This is a single response."}, {"outputText": "This is a second response."}] } - expected_responses = [ - "This is a single response.", - "This is a second response.", - ] + expected_responses = ["This is a single response.", "This is a second response."] assert adapter.get_responses(response_body) == expected_responses def test_get_stream_responses(self) -> None: @@ -1031,40 +982,19 @@ def test_prepare_body_with_default_params(self) -> None: def test_prepare_body_with_custom_inference_params(self) -> None: layer = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) prompt = "Hello, how are you?" - expected_body = { - "prompt": "Hello, how are you?", - "max_gen_len": 50, - "temperature": 0.7, - "top_p": 0.8, - } + expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 50, "temperature": 0.7, "top_p": 0.8} - body = layer.prepare_body( - prompt, - temperature=0.7, - top_p=0.8, - max_gen_len=50, - unknown_arg="unknown_value", - ) + body = layer.prepare_body(prompt, temperature=0.7, top_p=0.8, max_gen_len=50, unknown_arg="unknown_value") assert body == expected_body def test_prepare_body_with_model_kwargs(self) -> None: layer = MetaLlama2ChatAdapter( - model_kwargs={ - "temperature": 0.7, - "top_p": 0.8, - "max_gen_len": 50, - "unknown_arg": "unknown_value", - }, + model_kwargs={"temperature": 0.7, "top_p": 0.8, "max_gen_len": 50, "unknown_arg": "unknown_value"}, max_length=99, ) prompt = "Hello, how are you?" - expected_body = { - "prompt": "Hello, how are you?", - "max_gen_len": 50, - "temperature": 0.7, - "top_p": 0.8, - } + expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 50, "temperature": 0.7, "top_p": 0.8} body = layer.prepare_body(prompt) @@ -1072,21 +1002,10 @@ def test_prepare_body_with_model_kwargs(self) -> None: def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None: layer = MetaLlama2ChatAdapter( - model_kwargs={ - "temperature": 0.6, - "top_p": 0.7, - "top_k": 4, - "max_gen_len": 49, - }, - max_length=99, + model_kwargs={"temperature": 0.6, "top_p": 0.7, "top_k": 4, "max_gen_len": 49}, max_length=99 ) prompt = "Hello, how are you?" 
- expected_body = { - "prompt": "Hello, how are you?", - "max_gen_len": 50, - "temperature": 0.7, - "top_p": 0.7, - } + expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 50, "temperature": 0.7, "top_p": 0.7} body = layer.prepare_body(prompt, temperature=0.7, max_gen_len=50) diff --git a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py index affcc078f..866a5a99d 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py @@ -194,28 +194,20 @@ def test_prepare_body_with_custom_inference_params(self, mock_auto_tokenizer) -> } body = layer.prepare_body( - [ChatMessage.from_user(prompt)], - top_p=0.8, - top_k=5, - max_tokens_to_sample=69, - stop_sequences=["CUSTOM_STOP"], + [ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69, stop_sequences=["CUSTOM_STOP"] ) assert body == expected_body class TestMetaLlama2ChatAdapter: - @pytest.mark.integration def test_prepare_body_with_default_params(self) -> None: # leave this test as integration because we really need only tokenizer from HF # that way we can ensure prompt chat message formatting layer = MetaLlama2ChatAdapter(generation_kwargs={}) prompt = "Hello, how are you?" - expected_body = { - "prompt": "[INST] Hello, how are you? [/INST]", - "max_gen_len": 512, - } + expected_body = {"prompt": "[INST] Hello, how are you? [/INST]", "max_gen_len": 512} body = layer.prepare_body([ChatMessage.from_user(prompt)]) From baac9fe925bf7e0a2355254dcda83fecf6b9809b Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 6 Feb 2024 13:35:19 +0100 Subject: [PATCH 10/16] Revert back to llama2 tokenizer --- .../amazon_bedrock/chat/adapters.py | 30 ++----------------- .../tests/test_amazon_chat_bedrock.py | 10 +++---- 2 files changed, 7 insertions(+), 33 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index dce52ac03..1b1287c59 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -156,32 +156,6 @@ class MetaLlama2ChatAdapter(BedrockModelChatAdapter): Model adapter for the Meta Llama 2 models. 
""" - # Llama 2 chat template - chat_template = """ - {% if messages[0]['role'] == 'system' %} - {% set loop_messages = messages[1:] %} - {% set system_message = messages[0]['content'] %} - {% else %} - {% set loop_messages = messages %} - {% set system_message = false %} - {% endif %} - {% for message in loop_messages %} - {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} - {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} - {% endif %} - {% if loop.index0 == 0 and system_message != false %} - {% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %} - {% else %} - {% set content = message['content'] %} - {% endif %} - {% if message['role'] == 'user' %} - {{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }} - {% elif message['role'] == 'assistant' %} - {{ ' ' + content.strip() + ' ' + eos_token }} - {% endif %} - {% endfor %} - """ - def __init__(self, generation_kwargs: Dict[str, Any]) -> None: super().__init__(generation_kwargs) # We pop the model_max_length as it is not sent to the model @@ -190,7 +164,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]) -> None: model_max_length = self.generation_kwargs.get("model_max_length", 4096) # Truncate prompt if prompt tokens > model_max_length-max_length self.prompt_handler = DefaultPromptHandler( - model="gpt2", # use gpt2 tokenizer to estimate prompt length + model="meta-llama/Llama-2-7b-chat-hf", model_max_length=model_max_length, max_length=self.generation_kwargs.get("max_gen_len") or 512, ) @@ -208,7 +182,7 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( - conversation=messages, tokenize=False, chat_template=self.chat_template + conversation=messages, tokenize=False ) return prepared_prompt diff --git a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py index 866a5a99d..622bae4ef 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock, patch import pytest -from haystack.components.generators.utils import default_streaming_callback +from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ChatMessage from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator @@ -50,7 +50,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session): aws_profile_name="some_fake_profile", aws_region_name="fake_region", generation_kwargs={"temperature": 0.7}, - streaming_callback=default_streaming_callback, + streaming_callback=print_streaming_chunk, ) expected_dict = { "type": clazz, @@ -58,7 +58,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session): "model": "anthropic.claude-v2", "generation_kwargs": {"temperature": 0.7}, "stop_words": [], - "streaming_callback": default_streaming_callback, + "streaming_callback": print_streaming_chunk, }, } @@ -75,13 +75,13 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session): "init_parameters": { "model": "anthropic.claude-v2", "generation_kwargs": {"temperature": 0.7}, - "streaming_callback": "haystack.components.generators.utils.default_streaming_callback", + "streaming_callback": 
"haystack.components.generators.utils.print_streaming_chunk", }, } ) assert generator.model == "anthropic.claude-v2" assert generator.model_adapter.generation_kwargs == {"temperature": 0.7} - assert generator.streaming_callback == default_streaming_callback + assert generator.streaming_callback == print_streaming_chunk def test_default_constructor(mock_auto_tokenizer, mock_boto3_session): From 13279db6dd162880749e98eed4471291e144f97a Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 6 Feb 2024 14:24:58 +0100 Subject: [PATCH 11/16] Use gpt2 with special_token_map, use llama2 chat template --- .../amazon_bedrock/chat/adapters.py | 36 +++++++++++++++++-- .../generators/amazon_bedrock/generator.py | 2 +- .../generators/amazon_bedrock/handlers.py | 13 +++++-- .../tests/test_amazon_chat_bedrock.py | 3 +- 4 files changed, 46 insertions(+), 8 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index 1b1287c59..0c6335635 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk +from transformers import AutoTokenizer, PreTrainedTokenizer from haystack_integrations.components.generators.amazon_bedrock.handlers import DefaultPromptHandler @@ -102,7 +103,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]) -> None: # TODO use Anthropic tokenizer to get the precise prompt length # See https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#token-counting self.prompt_handler = DefaultPromptHandler( - model="gpt2", + tokenizer="gpt2", model_max_length=model_max_length, max_length=self.generation_kwargs.get("max_tokens_to_sample") or 512, ) @@ -156,6 +157,31 @@ class MetaLlama2ChatAdapter(BedrockModelChatAdapter): Model adapter for the Meta Llama 2 models. 
""" + chat_template = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" + "{% set system_message = messages[0]['content'] %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = false %}" + "{% endif %}" + "{% for message in loop_messages %}" + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 and system_message != false %}" + "{% set content = '<>\n' + system_message + '\n<>\n\n' + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + "{% if message['role'] == 'user' %}" + "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ ' ' + content.strip() + ' ' + eos_token }}" + "{% endif %}" + "{% endfor %}" + ) + def __init__(self, generation_kwargs: Dict[str, Any]) -> None: super().__init__(generation_kwargs) # We pop the model_max_length as it is not sent to the model @@ -163,8 +189,12 @@ def __init__(self, generation_kwargs: Dict[str, Any]) -> None: # Llama 2 has context window size of 4096 tokens model_max_length = self.generation_kwargs.get("model_max_length", 4096) # Truncate prompt if prompt tokens > model_max_length-max_length + tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer.bos_token = "" + tokenizer.eos_token = "" + tokenizer.unk_token = "" self.prompt_handler = DefaultPromptHandler( - model="meta-llama/Llama-2-7b-chat-hf", + tokenizer=tokenizer, model_max_length=model_max_length, max_length=self.generation_kwargs.get("max_gen_len") or 512, ) @@ -182,7 +212,7 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( - conversation=messages, tokenize=False + conversation=messages, tokenize=False, chat_template=self.chat_template ) return prepared_prompt diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 2d19159f9..48f22f59b 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -104,7 +104,7 @@ def __init__( # It is hard to determine which tokenizer to use for the SageMaker model # so we use GPT2 tokenizer which will likely provide good token count approximation self.prompt_handler = DefaultPromptHandler( - model="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100 + tokenizer="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100 ) model_adapter_cls = self.get_model_adapter(model=model) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py index 56dcb24d3..71450bec0 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py @@ -1,7 +1,7 @@ from 
abc import ABC, abstractmethod from typing import Dict, Union -from transformers import AutoTokenizer +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast class DefaultPromptHandler: @@ -10,8 +10,15 @@ class DefaultPromptHandler: are within the model_max_length. """ - def __init__(self, model: str, model_max_length: int, max_length: int = 100): - self.tokenizer = AutoTokenizer.from_pretrained(model) + def __init__(self, tokenizer: Union[str, PreTrainedTokenizerBase], model_max_length: int, max_length: int = 100): + if isinstance(tokenizer, str): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) + elif isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + self.tokenizer = tokenizer + else: + msg = "model must be a string or a PreTrainedTokenizer instance" + raise ValueError(msg) + self.tokenizer.model_max_length = model_max_length self.model_max_length = model_max_length self.max_length = max_length diff --git a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py index 622bae4ef..9592b5b39 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py @@ -241,7 +241,8 @@ def test_prepare_body_with_custom_inference_params(self) -> None: assert body == expected_body - def test_get_responses(self, mock_auto_tokenizer) -> None: + @pytest.mark.integration + def test_get_responses(self) -> None: adapter = MetaLlama2ChatAdapter(generation_kwargs={}) response_body = {"generation": "This is a single response."} expected_response = "This is a single response." From 2a62b8c48d30dcdbec2549cb342ef942ca13ff15 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 7 Feb 2024 11:35:09 +0100 Subject: [PATCH 12/16] Hook prompt length check --- .../amazon_bedrock/chat/adapters.py | 48 ++++++++++++++++--- .../amazon_bedrock/chat/chat_generator.py | 6 +-- .../generators/amazon_bedrock/handlers.py | 5 ++ 3 files changed, 47 insertions(+), 12 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index 0c6335635..37c490f6b 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -72,6 +72,29 @@ def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str return kwargs + def _ensure_token_limit(self, prompt: str) -> str: + resize_info = self.check_prompt(prompt) + if resize_info["prompt_length"] != resize_info["new_prompt_length"]: + logger.warning( + "The prompt was truncated from %s tokens to %s tokens so that the prompt length and " + "the answer length (%s tokens) fit within the model's max token limit (%s tokens). " + "Shorten the prompt or it will be cut off.", + resize_info["prompt_length"], + max(0, resize_info["model_max_length"] - resize_info["max_length"]), # type: ignore + resize_info["max_length"], + resize_info["model_max_length"], + ) + return str(resize_info["resized_prompt"]) + + @abstractmethod + def check_prompt(self, prompt: str) -> Dict[str, Any]: + """ + Checks the prompt length and resizes it if necessary. + + :param prompt: The prompt to check. 
+ :return: A dictionary containing the resized prompt and additional information. + """ + @abstractmethod def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: """Extracts the responses from the Amazon Bedrock response.""" @@ -89,14 +112,14 @@ class AnthropicClaudeChatAdapter(BedrockModelChatAdapter): ANTHROPIC_USER_TOKEN = "\n\nHuman:" ANTHROPIC_ASSISTANT_TOKEN = "\n\nAssistant:" - def __init__(self, generation_kwargs: Dict[str, Any]) -> None: + def __init__(self, generation_kwargs: Dict[str, Any]): super().__init__(generation_kwargs) # We pop the model_max_length as it is not sent to the model # but used to truncate the prompt if needed # Anthropic Claude has a limit of at least 100000 tokens # https://docs.anthropic.com/claude/reference/input-and-output-sizes - model_max_length = self.generation_kwargs.get("model_max_length", 100000) + model_max_length = self.generation_kwargs.pop("model_max_length", 100000) # Truncate prompt if prompt tokens > model_max_length-max_length # (max_length is the length of the generated text) @@ -142,7 +165,11 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: invalid_role = f"Invalid role {message.role} for message {message.content}" raise ValueError(invalid_role) - return "".join(conversation) + AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN + " " + prepared_prompt = "".join(conversation) + AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN + " " + return self._ensure_token_limit(prepared_prompt) + + def check_prompt(self, prompt: str) -> Dict[str, Any]: + return self.prompt_handler(prompt) def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: metadata = {k: v for (k, v) in response_body.items() if k != "completion"} @@ -187,9 +214,13 @@ def __init__(self, generation_kwargs: Dict[str, Any]) -> None: # We pop the model_max_length as it is not sent to the model # but used to truncate the prompt if needed # Llama 2 has context window size of 4096 tokens - model_max_length = self.generation_kwargs.get("model_max_length", 4096) - # Truncate prompt if prompt tokens > model_max_length-max_length - tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("gpt2") + # with some exceptions when the context window has been extended + model_max_length = self.generation_kwargs.pop("model_max_length", 4096) + + # Use `google/flan-t5-base` as it's also BPE sentencepiece tokenizer just like llama 2 + # a) we should get good estimates for the prompt length (empirically close to llama 2) + # b) we can use apply_chat_template with the template above to delineate ChatMessages + tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base") tokenizer.bos_token = "" tokenizer.eos_token = "" tokenizer.unk_token = "" @@ -214,7 +245,10 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( conversation=messages, tokenize=False, chat_template=self.chat_template ) - return prepared_prompt + return self._ensure_token_limit(prepared_prompt) + + def check_prompt(self, prompt: str) -> Dict[str, Any]: + return self.prompt_handler(prompt) def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: metadata = {k: v for (k, v) in response_body.items() if k != "generation"} diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py 
b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 94bec3a72..fda9d4fff 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -153,11 +153,7 @@ def invoke(self, *args, **kwargs): response_body = json.loads(response.get("body").read().decode("utf-8")) responses = self.model_adapter.get_responses(response_body=response_body) except ClientError as exception: - msg = ( - f"Could not connect to Amazon Bedrock model {self.model}. " - f"Make sure your AWS environment is configured correctly, " - f"the model is available in the configured AWS region, and you have access." - ) + msg = f"Could not inference Amazon Bedrock model {self.model} due: {exception}" raise AmazonBedrockInferenceError(msg) from exception return responses diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py index 71450bec0..b7b555ec0 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py @@ -11,6 +11,11 @@ class DefaultPromptHandler: """ def __init__(self, tokenizer: Union[str, PreTrainedTokenizerBase], model_max_length: int, max_length: int = 100): + """ + :param tokenizer: The tokenizer to be used to tokenize the prompt. + :param model_max_length: The maximum length of the prompt and answer tokens combined. + :param max_length: The maximum length of the answer tokens. 
+ """ if isinstance(tokenizer, str): self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) elif isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): From 11b9b2f0a608cdeb61efaa30e32058f50f3ab0fa Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 7 Feb 2024 22:43:52 +0100 Subject: [PATCH 13/16] PR feedback David --- .../amazon_bedrock/chat/adapters.py | 30 ++++++++++++------- .../amazon_bedrock/chat/chat_generator.py | 4 ++- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index 37c490f6b..a4eefe321 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List +from botocore.eventstream import EventStream from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from transformers import AutoTokenizer, PreTrainedTokenizer @@ -25,9 +26,9 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]: """Extracts the responses from the Amazon Bedrock response.""" - return self._extract_messages_from_response(response_body) + return self._extract_messages_from_response(self.response_body_message_key(), response_body) - def get_stream_responses(self, stream, stream_handler: Callable[[StreamingChunk], None]) -> List[str]: + def get_stream_responses(self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None]) -> List[str]: tokens: List[str] = [] for event in stream: chunk = event.get("chunk") @@ -43,7 +44,8 @@ def get_stream_responses(self, stream, stream_handler: Callable[[StreamingChunk] responses = ["".join(tokens).lstrip()] return responses - def _update_params(self, target_dict: Dict[str, Any], updates_dict: Dict[str, Any]) -> None: + @staticmethod + def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any]) -> None: """ Updates target_dict with values from updates_dict. Merges lists instead of overriding them. @@ -62,6 +64,10 @@ def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str """ Merges params from inference_kwargs with the default params and self.generation_kwargs. Uses a helper function to merge lists or override values as necessary. + + :param inference_kwargs: The inference kwargs to merge. + :param default_params: The default params to start with. + :return: The merged params. """ # Start with a copy of default_params kwargs = default_params.copy() @@ -95,9 +101,13 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]: :return: A dictionary containing the resized prompt and additional information. 
""" + def _extract_messages_from_response(self, message_tag: str, response_body: Dict[str, Any]) -> List[ChatMessage]: + metadata = {k: v for (k, v) in response_body.items() if k != message_tag} + return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)] + @abstractmethod - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """Extracts the responses from the Amazon Bedrock response.""" + def response_body_message_key(self) -> str: + """Returns the key for the message in the response body.""" @abstractmethod def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: @@ -171,9 +181,8 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: def check_prompt(self, prompt: str) -> Dict[str, Any]: return self.prompt_handler(prompt) - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - metadata = {k: v for (k, v) in response_body.items() if k != "completion"} - return [ChatMessage.from_assistant(response_body["completion"], meta=metadata)] + def response_body_message_key(self) -> str: + return "completion" def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: return chunk.get("completion", "") @@ -250,9 +259,8 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: def check_prompt(self, prompt: str) -> Dict[str, Any]: return self.prompt_handler(prompt) - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - metadata = {k: v for (k, v) in response_body.items() if k != "generation"} - return [ChatMessage.from_assistant(response_body["generation"], meta=metadata)] + def response_body_message_key(self) -> str: + return "generation" def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: return chunk.get("generation", "") diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index fda9d4fff..ecb0c7bb9 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -76,7 +76,9 @@ def __init__( Note that the AWS credentials are not required if the AWS environment is configured correctly. These are loaded automatically from the environment or the AWS configuration file and do not need to be provided explicitly via - the constructor. + the constructor. If the AWS environment is not configured users need to provide the AWS credentials via the + constructor. Aside from model, three required parameters are `aws_access_key_id`, `aws_secret_access_key`, + and `aws_region_name`. :param model: The model to use for generation. The model must be available in Amazon Bedrock. The model has to be specified in the format outlined in the Amazon Bedrock [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html). 
From 1c071b0bdbf429afb63f0e030843b4e5f6f04e79 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 8 Feb 2024 12:08:06 +0100 Subject: [PATCH 14/16] Reverse PR unrelated files --- .../generators/amazon_bedrock/adapters.py | 13 +- .../generators/amazon_bedrock/errors.py | 5 +- .../generators/amazon_bedrock/generator.py | 34 ++++- .../tests/test_amazon_bedrock.py | 127 ++++++++++++++---- 4 files changed, 146 insertions(+), 33 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py index eca81c3f1..40ba0bc67 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py @@ -150,7 +150,12 @@ class AmazonTitanAdapter(BedrockModelAdapter): """ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: - default_params = {"maxTokenCount": self.max_length, "stopSequences": None, "temperature": None, "topP": None} + default_params = { + "maxTokenCount": self.max_length, + "stopSequences": None, + "temperature": None, + "topP": None, + } params = self._get_params(inference_kwargs, default_params) body = {"inputText": prompt, "textGenerationConfig": params} @@ -170,7 +175,11 @@ class MetaLlama2ChatAdapter(BedrockModelAdapter): """ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: - default_params = {"max_gen_len": self.max_length, "temperature": None, "top_p": None} + default_params = { + "max_gen_len": self.max_length, + "temperature": None, + "top_p": None, + } params = self._get_params(inference_kwargs, default_params) body = {"prompt": prompt, **params} diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py index 53c28ad1d..aa8a3f6e4 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py @@ -10,7 +10,10 @@ class AmazonBedrockError(Exception): `AmazonBedrockError.message` will exist and have the expected content. 
""" - def __init__(self, message: Optional[str] = None): + def __init__( + self, + message: Optional[str] = None, + ): super().__init__() if message: self.message = message diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 48f22f59b..4c43c9a09 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -15,8 +15,16 @@ CohereCommandAdapter, MetaLlama2ChatAdapter, ) -from .errors import AmazonBedrockConfigurationError, AmazonBedrockInferenceError, AWSConfigurationError -from .handlers import DefaultPromptHandler, DefaultTokenStreamingHandler, TokenStreamingHandler +from .errors import ( + AmazonBedrockConfigurationError, + AmazonBedrockInferenceError, + AWSConfigurationError, +) +from .handlers import ( + DefaultPromptHandler, + DefaultTokenStreamingHandler, + TokenStreamingHandler, +) logger = logging.getLogger(__name__) @@ -37,7 +45,7 @@ class AmazonBedrockGenerator: Usage example: ```python - from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator + from amazon_bedrock_haystack.generators.amazon_bedrock import AmazonBedrockGenerator generator = AmazonBedrockGenerator( model="anthropic.claude-v2", @@ -104,7 +112,9 @@ def __init__( # It is hard to determine which tokenizer to use for the SageMaker model # so we use GPT2 tokenizer which will likely provide good token count approximation self.prompt_handler = DefaultPromptHandler( - tokenizer="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100 + model="gpt2", + model_max_length=model_max_length, + max_length=self.max_length or 100, ) model_adapter_cls = self.get_model_adapter(model=model) @@ -193,7 +203,10 @@ def invoke(self, *args, **kwargs): try: if stream: response = self.client.invoke_model_with_response_stream( - body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" + body=json.dumps(body), + modelId=self.model, + accept="application/json", + contentType="application/json", ) response_stream = response["body"] handler: TokenStreamingHandler = kwargs.get( @@ -203,7 +216,10 @@ def invoke(self, *args, **kwargs): responses = self.model_adapter.get_stream_responses(stream=response_stream, stream_handler=handler) else: response = self.client.invoke_model( - body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" + body=json.dumps(body), + modelId=self.model, + accept="application/json", + contentType="application/json", ) response_body = json.loads(response.get("body").read().decode("utf-8")) responses = self.model_adapter.get_responses(response_body=response_body) @@ -280,7 +296,11 @@ def to_dict(self) -> Dict[str, Any]: Serialize this component to a dictionary. :return: The serialized component as a dictionary. 
""" - return default_to_dict(self, model=self.model, max_length=self.max_length) + return default_to_dict( + self, + model=self.model, + max_length=self.max_length, + ) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockGenerator": diff --git a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py index 6be07b06a..b08e9dfd5 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py @@ -56,7 +56,10 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session): expected_dict = { "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", - "init_parameters": {"model": "anthropic.claude-v2", "max_length": 99}, + "init_parameters": { + "model": "anthropic.claude-v2", + "max_length": 99, + }, } assert generator.to_dict() == expected_dict @@ -70,7 +73,10 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session): generator = AmazonBedrockGenerator.from_dict( { "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", - "init_parameters": {"model": "anthropic.claude-v2", "max_length": 99}, + "init_parameters": { + "model": "anthropic.claude-v2", + "max_length": 99, + }, } ) @@ -175,7 +181,9 @@ def test_short_prompt_is_not_truncated(mock_boto3_session): with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): layer = AmazonBedrockGenerator( - "anthropic.claude-v2", max_length=max_length_generated_text, model_max_length=total_model_max_length + "anthropic.claude-v2", + max_length=max_length_generated_text, + model_max_length=total_model_max_length, ) prompt_after_resize = layer._ensure_token_limit(mock_prompt_text) @@ -208,7 +216,9 @@ def test_long_prompt_is_truncated(mock_boto3_session): with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): layer = AmazonBedrockGenerator( - "anthropic.claude-v2", max_length=max_length_generated_text, model_max_length=total_model_max_length + "anthropic.claude-v2", + max_length=max_length_generated_text, + model_max_length=total_model_max_length, ) prompt_after_resize = layer._ensure_token_limit(long_prompt_text) @@ -228,7 +238,10 @@ def test_supports_for_valid_aws_configuration(): "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ): - supported = AmazonBedrockGenerator.supports(model="anthropic.claude-v2", aws_profile_name="some_real_profile") + supported = AmazonBedrockGenerator.supports( + model="anthropic.claude-v2", + aws_profile_name="some_real_profile", + ) args, kwargs = mock_session.client("bedrock").list_foundation_models.call_args assert kwargs["byOutputModality"] == "TEXT" @@ -240,7 +253,10 @@ def test_supports_raises_on_invalid_aws_profile_name(): with patch("boto3.Session") as mock_boto3_session: mock_boto3_session.side_effect = BotoCoreError() with pytest.raises(AmazonBedrockConfigurationError, match="Failed to initialize the session"): - AmazonBedrockGenerator.supports(model="anthropic.claude-v2", aws_profile_name="some_fake_profile") + AmazonBedrockGenerator.supports( + model="anthropic.claude-v2", + aws_profile_name="some_fake_profile", + ) @pytest.mark.unit @@ -253,7 +269,10 @@ def test_supports_for_invalid_bedrock_config(): "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ), 
pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."): - AmazonBedrockGenerator.supports(model="anthropic.claude-v2", aws_profile_name="some_real_profile") + AmazonBedrockGenerator.supports( + model="anthropic.claude-v2", + aws_profile_name="some_real_profile", + ) @pytest.mark.unit @@ -266,7 +285,10 @@ def test_supports_for_invalid_bedrock_config_error_on_list_models(): "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."): - AmazonBedrockGenerator.supports(model="anthropic.claude-v2", aws_profile_name="some_real_profile") + AmazonBedrockGenerator.supports( + model="anthropic.claude-v2", + aws_profile_name="some_real_profile", + ) @pytest.mark.unit @@ -296,7 +318,9 @@ def test_supports_with_stream_true_for_model_that_supports_streaming(): return_value=mock_session, ): supported = AmazonBedrockGenerator.supports( - model="anthropic.claude-v2", aws_profile_name="some_real_profile", stream=True + model="anthropic.claude-v2", + aws_profile_name="some_real_profile", + stream=True, ) assert supported @@ -313,8 +337,15 @@ def test_supports_with_stream_true_for_model_that_does_not_support_streaming(): with patch( "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, - ), pytest.raises(AmazonBedrockConfigurationError, match="The model ai21.j2-mid-v1 doesn't support streaming."): - AmazonBedrockGenerator.supports(model="ai21.j2-mid-v1", aws_profile_name="some_real_profile", stream=True) + ), pytest.raises( + AmazonBedrockConfigurationError, + match="The model ai21.j2-mid-v1 doesn't support streaming.", + ): + AmazonBedrockGenerator.supports( + model="ai21.j2-mid-v1", + aws_profile_name="some_real_profile", + stream=True, + ) @pytest.mark.unit @@ -634,9 +665,15 @@ def test_get_responses_leading_whitespace(self) -> None: def test_get_responses_multiple_responses(self) -> None: adapter = CohereCommandAdapter(model_kwargs={}, max_length=99) response_body = { - "generations": [{"text": "This is a single response."}, {"text": "This is a second response."}] + "generations": [ + {"text": "This is a single response."}, + {"text": "This is a second response."}, + ] } - expected_responses = ["This is a single response.", "This is a second response."] + expected_responses = [ + "This is a single response.", + "This is a second response.", + ] assert adapter.get_responses(response_body) == expected_responses def test_get_stream_responses(self) -> None: @@ -817,7 +854,10 @@ def test_get_responses_multiple_responses(self) -> None: {"data": {"text": "This is a second response."}}, ] } - expected_responses = ["This is a single response.", "This is a second response."] + expected_responses = [ + "This is a single response.", + "This is a second response.", + ] assert adapter.get_responses(response_body) == expected_responses @@ -825,7 +865,10 @@ class TestAmazonTitanAdapter: def test_prepare_body_with_default_params(self) -> None: layer = AmazonTitanAdapter(model_kwargs={}, max_length=99) prompt = "Hello, how are you?" 
- expected_body = {"inputText": "Hello, how are you?", "textGenerationConfig": {"maxTokenCount": 99}} + expected_body = { + "inputText": "Hello, how are you?", + "textGenerationConfig": {"maxTokenCount": 99}, + } body = layer.prepare_body(prompt) @@ -921,9 +964,15 @@ def test_get_responses_leading_whitespace(self) -> None: def test_get_responses_multiple_responses(self) -> None: adapter = AmazonTitanAdapter(model_kwargs={}, max_length=99) response_body = { - "results": [{"outputText": "This is a single response."}, {"outputText": "This is a second response."}] + "results": [ + {"outputText": "This is a single response."}, + {"outputText": "This is a second response."}, + ] } - expected_responses = ["This is a single response.", "This is a second response."] + expected_responses = [ + "This is a single response.", + "This is a second response.", + ] assert adapter.get_responses(response_body) == expected_responses def test_get_stream_responses(self) -> None: @@ -982,19 +1031,40 @@ def test_prepare_body_with_default_params(self) -> None: def test_prepare_body_with_custom_inference_params(self) -> None: layer = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) prompt = "Hello, how are you?" - expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 50, "temperature": 0.7, "top_p": 0.8} + expected_body = { + "prompt": "Hello, how are you?", + "max_gen_len": 50, + "temperature": 0.7, + "top_p": 0.8, + } - body = layer.prepare_body(prompt, temperature=0.7, top_p=0.8, max_gen_len=50, unknown_arg="unknown_value") + body = layer.prepare_body( + prompt, + temperature=0.7, + top_p=0.8, + max_gen_len=50, + unknown_arg="unknown_value", + ) assert body == expected_body def test_prepare_body_with_model_kwargs(self) -> None: layer = MetaLlama2ChatAdapter( - model_kwargs={"temperature": 0.7, "top_p": 0.8, "max_gen_len": 50, "unknown_arg": "unknown_value"}, + model_kwargs={ + "temperature": 0.7, + "top_p": 0.8, + "max_gen_len": 50, + "unknown_arg": "unknown_value", + }, max_length=99, ) prompt = "Hello, how are you?" - expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 50, "temperature": 0.7, "top_p": 0.8} + expected_body = { + "prompt": "Hello, how are you?", + "max_gen_len": 50, + "temperature": 0.7, + "top_p": 0.8, + } body = layer.prepare_body(prompt) @@ -1002,10 +1072,21 @@ def test_prepare_body_with_model_kwargs(self) -> None: def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None: layer = MetaLlama2ChatAdapter( - model_kwargs={"temperature": 0.6, "top_p": 0.7, "top_k": 4, "max_gen_len": 49}, max_length=99 + model_kwargs={ + "temperature": 0.6, + "top_p": 0.7, + "top_k": 4, + "max_gen_len": 49, + }, + max_length=99, ) prompt = "Hello, how are you?" 
- expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 50, "temperature": 0.7, "top_p": 0.7} + expected_body = { + "prompt": "Hello, how are you?", + "max_gen_len": 50, + "temperature": 0.7, + "top_p": 0.7, + } body = layer.prepare_body(prompt, temperature=0.7, max_gen_len=50) From 1b2faaac258962051801bbb4abcce7d0d8f719c1 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 8 Feb 2024 12:12:16 +0100 Subject: [PATCH 15/16] First check model, then open AWS connection - David --- .../amazon_bedrock/chat/chat_generator.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index ecb0c7bb9..804d44413 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -102,6 +102,15 @@ def __init__( msg = "'model' cannot be None or empty string" raise ValueError(msg) self.model = model + + # get the model adapter for the given model + model_adapter_cls = self.get_model_adapter(model=model) + if not model_adapter_cls: + msg = f"AmazonBedrockGenerator doesn't support the model {model}." + raise AmazonBedrockConfigurationError(msg) + self.model_adapter = model_adapter_cls(generation_kwargs or {}) + + # create the AWS session and client try: session = self.get_aws_session( aws_access_key_id=aws_access_key_id, @@ -118,11 +127,6 @@ def __init__( ) raise AmazonBedrockConfigurationError(msg) from exception - model_adapter_cls = self.get_model_adapter(model=model) - if not model_adapter_cls: - msg = f"AmazonBedrockGenerator doesn't support the model {model}." - raise AmazonBedrockConfigurationError(msg) - self.model_adapter = model_adapter_cls(generation_kwargs or {}) self.stop_words = stop_words or [] self.streaming_callback = streaming_callback From 57d7d5f1e1e208b91063a3894f1da92170a9e16f Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 8 Feb 2024 12:15:53 +0100 Subject: [PATCH 16/16] Small fix --- .../components/generators/amazon_bedrock/generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 4c43c9a09..8e89dab59 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -112,7 +112,7 @@ def __init__( # It is hard to determine which tokenizer to use for the SageMaker model # so we use GPT2 tokenizer which will likely provide good token count approximation self.prompt_handler = DefaultPromptHandler( - model="gpt2", + tokenizer="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100, )
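
---

Note on the Llama 2 prompt formatting introduced in patches 11–12: the final `MetaLlama2ChatAdapter` keeps the Llama 2 chat template as a string and renders it through a freely available Hugging Face tokenizer whose special tokens are overridden to Llama 2's markers, instead of requiring the gated `meta-llama/Llama-2-7b-chat-hf` checkpoint. Below is a minimal, illustrative sketch of that idea; it uses a trimmed template (no system-message handling) and plain role/content dicts, whereas the adapter passes Haystack `ChatMessage` objects and the full template from the patch.

```python
# Sketch only: render chat messages with a Llama-2-style template using a freely
# available tokenizer, mirroring MetaLlama2ChatAdapter.prepare_chat_messages.
from transformers import AutoTokenizer

LLAMA2_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ ' ' + message['content'].strip() + ' ' + eos_token }}"
    "{% endif %}"
    "{% endfor %}"
)  # trimmed version of the template added in the patch

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
# Override the special tokens so the rendered prompt uses Llama 2's markers,
# as the adapter does after loading the tokenizer.
tokenizer.bos_token = "<s>"
tokenizer.eos_token = "</s>"
tokenizer.unk_token = "<unk>"

messages = [{"role": "user", "content": "Hello, how are you?"}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, chat_template=LLAMA2_CHAT_TEMPLATE
)
print(prompt)  # roughly: "<s>[INST] Hello, how are you? [/INST]"
```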
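Note on prompt truncation: both chat adapters now route the rendered prompt through `check_prompt` / `_ensure_token_limit`, which rely on `DefaultPromptHandler` to keep the prompt within `model_max_length - max_length` tokens so the generated answer still fits in the model's context window. The following is only a rough sketch of that accounting under a `transformers` tokenizer, not the integration's actual handler (which also reports old/new lengths for the truncation warning).

```python
# Rough sketch of the truncation budget: keep the prompt to at most
# model_max_length - max_length tokens so the answer still fits.
from transformers import AutoTokenizer

def truncate_prompt(prompt: str, model_max_length: int = 4096, max_length: int = 512) -> str:
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # rough token-count proxy
    budget = model_max_length - max_length
    token_ids = tokenizer.encode(prompt, add_special_tokens=False)
    if len(token_ids) <= budget:
        return prompt
    # Drop tokens from the end of the prompt once the budget is exceeded.
    return tokenizer.decode(token_ids[:budget])

print(truncate_prompt("Hello, how are you?"))
```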
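Taken together, the series ends with an `AmazonBedrockChatGenerator` that validates the model and picks its adapter before opening the AWS session, truncates prompts, and streams tokens through `print_streaming_chunk`. A minimal usage sketch follows; it assumes AWS credentials are configured in the environment, and the `run(messages=...)` call with a `replies` output key follows the usual Haystack chat-generator convention rather than being taken from this patch series.

```python
# Usage sketch, assuming AWS credentials are configured in the environment.
# run(...) and the "replies" key are assumed from Haystack conventions.
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage

from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

generator = AmazonBedrockChatGenerator(
    model="anthropic.claude-v2",
    generation_kwargs={"temperature": 0.7},
    streaming_callback=print_streaming_chunk,  # print chunks as they arrive
)

messages = [ChatMessage.from_user("Hello, how are you?")]
result = generator.run(messages=messages)
print(result["replies"][0].content)
```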