Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support MistralV2 Format #5118

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 73 additions & 1 deletion src/llamafactory/data/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,65 @@ def _encode(
return encoded_messages


def get_last_user_message_index(messages: Sequence[Dict[str, str]]) -> int:
    r"""
    Returns the index of the last message whose role is the user role.

    Walks the conversation from the end towards the start and stops at the first
    user message it meets. If there is none, a warning is logged and 0 is returned.
    """
    for position in range(len(messages) - 1, -1, -1):
        if messages[position]["role"] == Role.USER.value:
            return position

    logger.warning("No user message found.")
    return 0


@dataclass
class MistralV2Template(Template):
    r"""
    Template variant for the Mistral V2 chat format.

    The system prompt is not emitted at the start of the conversation; it is
    placed directly in front of the last user message instead.
    """

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: Sequence[Dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
    ) -> List[List[int]]:
        r"""
        Mistral V2 adds the system prompt prior to the last user message.
        (https://github.com/mistralai/mistral-common/blob/main/src/mistral_common/tokens/tokenizers/sentencepiece.py#L282)

        Encodes formatted inputs to one list of token ids per message.
        First message: prefix (+ tools) + message
        Last user message: system + query (query without its text up to the first space)
        Every even-indexed message after the first is preceded by the separator.

        Args:
            tokenizer: tokenizer forwarded to ``_convert_elements_to_ids``.
            messages: conversation, each item holding ``"role"`` and ``"content"``.
            system: system prompt; ``self.default_system`` is used when it is falsy.
            tools: tool description text; skipped when falsy.

        Returns:
            A list with one entry per message, each the token ids of that message.

        Raises:
            NotImplementedError: if a message has a role other than user, assistant,
                observation or function.
        """
        system = system or self.default_system
        encoded_messages = []
        last_user_message_idx = get_last_user_message_index(messages)
        for i, message in enumerate(messages):
            elements = []
            if i == 0:
                elements += self.format_prefix.apply()
                if tools:
                    # The first slot is a single string: append it as one element.
                    # ``elements += <str>`` would extend the list character by character.
                    elements.append(self.format_tools.apply(content=tools)[0])

            if i > 0 and i % 2 == 0:
                elements += self.format_separator.apply()

            if message["role"] == Role.USER.value:
                user_slot = self.format_user.apply(content=message["content"], idx=str(i // 2))
                if i == last_user_message_idx and system:
                    elements += self.format_system.apply(content=system)
                    # The system slot precedes the query, so drop everything up to the
                    # first space of the formatted user text. Keep the remainder wrapped
                    # in a list so it is added as one element, not split into characters.
                    user_slot = [user_slot[0].split(" ", 1)[-1]]
                elements += user_slot
            elif message["role"] == Role.ASSISTANT.value:
                elements += self.format_assistant.apply(content=message["content"])
            elif message["role"] == Role.OBSERVATION.value:
                elements += self.format_observation.apply(content=message["content"])
            elif message["role"] == Role.FUNCTION.value:
                elements += self.format_function.apply(content=message["content"])
            else:
                raise NotImplementedError("Unexpected role: {}".format(message["role"]))

            encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

        return encoded_messages


# Module-level mapping from string keys to Template objects; starts out empty.
# NOTE(review): presumably filled by _register_template and keyed by template name — confirm.
TEMPLATES: Dict[str, Template] = {}


Expand Down Expand Up @@ -236,7 +295,12 @@ def _register_template(
```
"""
eos_slots = [] if efficient_eos else [{"eos_token"}]
template_class = Llama2Template if name.startswith("llama2") else Template
if name.startswith("llama2"):
template_class = Llama2Template
elif name.startswith("mistral_v2"):
template_class = MistralV2Template
else:
template_class = Template
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
Expand Down Expand Up @@ -722,6 +786,14 @@ def get_template_and_fix_tokenizer(
)


# Mistral V2 chat format. The name starts with "mistral_v2", which is the prefix
# _register_template checks to select MistralV2Template as the template class.
# format_system opens its own "[INST] " block and has no closing tag: MistralV2Template._encode
# emits it directly before the last user message and drops that user slot's text up to
# its first space, so the two slots join into a single "[INST] ... [/INST]" span.
# NOTE(review): the {"bos_token"} prefix slot is presumably resolved to the tokenizer's
# BOS token when elements are converted to ids — confirm.
_register_template(
name="mistral_v2",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=["[INST] {{content}}\n\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


_register_template(
name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
Expand Down