From 83120e357007c7783fd9126f5df86de72ed4bcc7 Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Fri, 29 Mar 2024 11:04:13 +0100
Subject: [PATCH] Double check everything, minor adjustments

---
 .../generators/amazon_bedrock/adapters.py      |  8 ++++----
 .../generators/amazon_bedrock/chat/adapters.py | 15 +++++++++------
 2 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py
index a2e206b76..f5bd4aa07 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py
@@ -162,9 +162,9 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
             "top_k": None,
         }
         params = self._get_params(inference_kwargs, default_params)
-        prompt = f"[INST] {prompt} [/INST]" if "INST" not in prompt else prompt
-        body = {"prompt": prompt, **params}
-        return body
+        # Add the instruction tag to the prompt if it's not already there
+        formatted_prompt = f"[INST] {prompt} [/INST]" if "INST" not in prompt else prompt
+        return {"prompt": formatted_prompt, **params}
 
     def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         """
@@ -173,7 +173,7 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L
         :param response_body: The response body from the Amazon Bedrock request.
         :returns: A list of string responses.
         """
-        return [output["text"] for output in response_body["outputs"]]
+        return [output.get("text", "") for output in response_body.get("outputs", [])]
 
     def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
         """
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
index 82a9f2c83..e24274803 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
@@ -273,7 +273,7 @@ def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]:
 
 class MistralChatAdapter(BedrockModelChatAdapter):
     """
-    Model adapter for the Anthropic Claude chat model.
+    Model adapter for the Mistral chat model.
     """
 
     chat_template = (
@@ -302,6 +302,8 @@ class MistralChatAdapter(BedrockModelChatAdapter):
         "{% endfor %}"
     )
     # the above template was designed to match https://docs.mistral.ai/models/#chat-template
+    # and to support system messages, otherwise we could use the default mistral chat template
+    # available on HF infrastructure
 
     # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
     ALLOWED_PARAMS: ClassVar[List[str]] = [
@@ -326,7 +328,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
         model_max_length = self.generation_kwargs.pop("model_max_length", 32000)
 
         # Use `mistralai/Mistral-7B-v0.1` as tokenizer, all mistral models likely use the same tokenizer
-        # a) we should get good estimates for the prompt length (empirically close to llama 2)
+        # a) we should get good estimates for the prompt length
         # b) we can use apply_chat_template with the template above to delineate ChatMessages
         tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
         self.prompt_handler = DefaultPromptHandler(
@@ -346,7 +348,7 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[
         default_params = {
             "max_tokens": self.generation_kwargs.get("max_tokens") or 512,  # max_tokens is required
         }
-        # replace stop_words from inference_kwargs with stop, as this is Mistral specific
+        # replace stop_words from inference_kwargs with stop, as this is a Mistral-specific parameter
         stop_words = inference_kwargs.pop("stop_words", [])
         if stop_words:
             inference_kwargs["stop"] = stop_words
@@ -362,10 +364,11 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
         :returns: The prepared chat messages as a string.
         """
         # it would be great to use the default mistral chat template, but it doesn't support system messages
-        # the above defined chat_template is a workaround to support system messages
-        # template is https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
+        # the chat_template defined above as a class variable is a workaround to support system messages
+        # default is https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
+        # but we'll use our custom chat template
         prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(
-            conversation=messages, tokenize=False, chat_template=self.chat_template
+            conversation=[m.to_openai_format() for m in messages], tokenize=False, chat_template=self.chat_template
         )
         return self._ensure_token_limit(prepared_prompt)
 