Skip to content

Commit

Permalink
Use local to_openai_format conversion instead of ChatMessage
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Apr 5, 2024
1 parent 83120e3 commit a1a04a9
Showing 1 changed file with 16 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -368,10 +368,25 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
# default is https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
# but we'll use our custom chat template
prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(
conversation=[m.to_openai_format() for m in messages], tokenize=False, chat_template=self.chat_template
conversation=[self.to_openai_format(m) for m in messages], tokenize=False, chat_template=self.chat_template
)
return self._ensure_token_limit(prepared_prompt)

def to_openai_format(self, m: ChatMessage) -> Dict[str, Any]:
"""
Convert the message to the format expected by OpenAI's Chat API.
See the [API reference](https://platform.openai.com/docs/api-reference/chat/create) for details.
:returns: A dictionary with the following key:
- `role`
- `content`
- `name` (optional)
"""
msg = {"role": m.role.value, "content": m.content}
if m.name:
msg["name"] = m.name
return msg

def check_prompt(self, prompt: str) -> Dict[str, Any]:
"""
Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated.
Expand Down

0 comments on commit a1a04a9

Please sign in to comment.