Double check everything, minor adjustments
vblagoje committed Apr 4, 2024
1 parent b4a0b5f commit 83120e3
Showing 2 changed files with 13 additions and 10 deletions.
@@ -162,9 +162,9 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
"top_k": None,
}
params = self._get_params(inference_kwargs, default_params)
- prompt = f"<s>[INST] {prompt} [/INST]" if "INST" not in prompt else prompt
- body = {"prompt": prompt, **params}
- return body
+ # Add the instruction tag to the prompt if it's not already there
+ formatted_prompt = f"<s>[INST] {prompt} [/INST]" if "INST" not in prompt else prompt
+ return {"prompt": formatted_prompt, **params}
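For illustration, here is roughly what the reworked prepare_body now produces for a bare prompt (the prompt and parameter values below are made up; params stands in for the merged inference parameters):

    prompt = "What is Amazon Bedrock?"
    params = {"max_tokens": 512, "temperature": 0.7}
    # Wrap the prompt in Mistral instruction tags unless the caller already did so
    formatted_prompt = f"<s>[INST] {prompt} [/INST]" if "INST" not in prompt else prompt
    body = {"prompt": formatted_prompt, **params}
    # body == {"prompt": "<s>[INST] What is Amazon Bedrock? [/INST]", "max_tokens": 512, "temperature": 0.7}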

def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
"""
@@ -173,7 +173,7 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
:param response_body: The response body from the Amazon Bedrock request.
:returns: A list of string responses.
"""
- return [output["text"] for output in response_body["outputs"]]
+ return [output.get("text", "") for output in response_body.get("outputs", [])]
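The switch to .get() makes the extraction tolerant of missing keys. A quick sketch against a made-up Bedrock Mistral response body (only the "outputs"/"text" fields from the diff are assumed here):

    response_body = {"outputs": [{"text": "Bedrock hosts Mistral models.", "stop_reason": "stop"}]}
    completions = [output.get("text", "") for output in response_body.get("outputs", [])]
    # completions == ["Bedrock hosts Mistral models."]
    # A missing or empty "outputs" key now yields [] instead of raising a KeyError
    assert [o.get("text", "") for o in {}.get("outputs", [])] == []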

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
"""
@@ -273,7 +273,7 @@ def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]:

class MistralChatAdapter(BedrockModelChatAdapter):
"""
- Model adapter for the Anthropic Claude chat model.
+ Model adapter for the Mistral chat model.
"""

chat_template = (
@@ -302,6 +302,8 @@ class MistralChatAdapter(BedrockModelChatAdapter):
"{% endfor %}"
)
# the above template was designed to match https://docs.mistral.ai/models/#chat-template
+ # and to support system messages, otherwise we could use the default mistral chat template
+ # available on HF infrastructure
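For orientation, the layout such a template aims to produce (per the Mistral chat format) is sketched below with a hypothetical helper; this is not the repository's template, only an illustration of how a leading system message can be folded into the first user turn:

    def render_mistral_prompt(messages: list) -> str:
        # Hypothetical helper: renders OpenAI-style role/content dicts into a Mistral-style prompt
        system = ""
        if messages and messages[0]["role"] == "system":
            system = messages[0]["content"]
            messages = messages[1:]
        parts = ["<s>"]
        for i, m in enumerate(messages):
            if m["role"] == "user":
                content = f"{system}\n{m['content']}" if system and i == 0 else m["content"]
                parts.append(f"[INST] {content} [/INST]")
            elif m["role"] == "assistant":
                parts.append(f"{m['content']}</s>")
        return "".join(parts)

    # render_mistral_prompt([{"role": "system", "content": "Be brief."}, {"role": "user", "content": "Hi"}])
    # -> "<s>[INST] Be brief.\nHi [/INST]"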

# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
ALLOWED_PARAMS: ClassVar[List[str]] = [
@@ -326,7 +328,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
model_max_length = self.generation_kwargs.pop("model_max_length", 32000)

# Use `mistralai/Mistral-7B-v0.1` as tokenizer, all mistral models likely use the same tokenizer
- # a) we should get good estimates for the prompt length (empirically close to llama 2)
+ # a) we should get good estimates for the prompt length
# b) we can use apply_chat_template with the template above to delineate ChatMessages
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
self.prompt_handler = DefaultPromptHandler(
@@ -346,7 +348,7 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
default_params = {
"max_tokens": self.generation_kwargs.get("max_tokens") or 512, # max_tokens is required
}
- # replace stop_words from inference_kwargs with stop, as this is Mistral specific
+ # replace stop_words from inference_kwargs with stop, as this is Mistral specific parameter
stop_words = inference_kwargs.pop("stop_words", [])
if stop_words:
inference_kwargs["stop"] = stop_words
@@ -362,10 +364,11 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
:returns: The prepared chat messages as a string.
"""
# it would be great to use the default mistral chat template, but it doesn't support system messages
- # the above defined chat_template is a workaround to support system messages
- # template is https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
+ # the class variable defined chat_template is a workaround to support system messages
+ # default is https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
+ # but we'll use our custom chat template
prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(
- conversation=messages, tokenize=False, chat_template=self.chat_template
+ conversation=[m.to_openai_format() for m in messages], tokenize=False, chat_template=self.chat_template
)
return self._ensure_token_limit(prepared_prompt)
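A rough sketch of what prepare_chat_messages now does end to end (ChatMessage comes from haystack.dataclasses; the chat_template passed in is the MistralChatAdapter class attribute shown above, whose import is omitted here):

    from haystack.dataclasses import ChatMessage
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    messages = [
        ChatMessage.from_system("Answer in one sentence."),
        ChatMessage.from_user("What models does Amazon Bedrock host?"),
    ]
    # apply_chat_template expects plain role/content dicts, hence the to_openai_format() conversion
    prompt = tokenizer.apply_chat_template(
        conversation=[m.to_openai_format() for m in messages],
        tokenize=False,
        chat_template=MistralChatAdapter.chat_template,
    )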
