Instruct v7
patrickvonplaten committed Nov 15, 2024
1 parent 221cbf4 commit d93e786
Showing 14 changed files with 660 additions and 37 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mistral_common"
version = "1.4.4"
version = "1.5.0"
description = ""
authors = ["bam4d <[email protected]>"]
readme = "README.md"
2 changes: 1 addition & 1 deletion src/mistral_common/__init__.py
@@ -1 +1 @@
__version__ = "1.4.4"
__version__ = "1.5.0"
Binary file not shown.
52 changes: 49 additions & 3 deletions src/mistral_common/protocol/instruct/normalize.py
@@ -19,7 +19,7 @@
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.tool_calls import FunctionCall, Tool, ToolCall
from mistral_common.tokens.instruct.request import InstructRequest
from mistral_common.tokens.tokenizers.base import InstructRequestType
from mistral_common.tokens.tokenizers.base import InstructRequestType, TokenizerVersion


class InstructRequestNormalizer(
@@ -35,6 +35,8 @@ class InstructRequestNormalizer(
- Normalize tool calls
"""

system_prompt_in_begin: bool = False

def __init__(
self,
user_message_class: Type[UserMessageType],
@@ -117,7 +119,7 @@ def _aggregate_assistant_messages(self, messages: List[UATS]) -> AssistantMessag
weight: Optional[float] = None
for message in messages:
assert isinstance(message, self._assistant_message_class), "Expected assistant message"
if message.tool_calls is not None:
if message.tool_calls is not None and len(message.tool_calls) > 0:
for tool_call in message.tool_calls:
normalized_tool_call = self._normalize_tool_call(tool_call)
tool_calls.append(normalized_tool_call)
@@ -205,7 +207,9 @@ def _aggregate_messages(self, request: ChatCompletionRequest[UATS]) -> List[UATS

# If the first message is not a user message, or we didn't aggregate
# anything (all system messages) for example, add an empty user message
if len(aggregated_messages) == 0 or aggregated_messages[0].role != Roles.user:
if len(aggregated_messages) == 0 or (
not self.system_prompt_in_begin and aggregated_messages[0].role != Roles.user
):
aggregated_messages.insert(0, self._user_message_class(content=""))

return aggregated_messages
@@ -217,3 +221,45 @@ def from_chat_completion_request(self, request: ChatCompletionRequest[UATS]) ->
return self._instruct_request_class(
messages=messages, system_prompt=system_prompt, available_tools=request.tools
)


class InstructRequestNormalizerV7(InstructRequestNormalizer):
system_prompt_in_begin: bool = True

@staticmethod
def normalizer() -> "InstructRequestNormalizerV7":
return InstructRequestNormalizerV7(
UserMessage,
AssistantMessage,
ToolMessage,
SystemMessage,
InstructRequest[UATS, Tool],
)

def _aggregate_role(self, messages: List[UATS], role: Optional[Roles]) -> Sequence[UATS]:
if role == Roles.tool:
return self._aggregate_tool_messages(messages)
elif role == Roles.assistant:
return [self._aggregate_assistant_messages(messages)]
elif role == Roles.user:
return [self._aggregate_user_messages(messages)]
elif role == Roles.system:
return messages
else:
assert role is None and len(messages) == 0
return []

def _aggregate_system_prompts(self, request: ChatCompletionRequest[UATS]) -> Optional[str]:
raise NotImplementedError("We should not aggregate system prompts")

def from_chat_completion_request(self, request: ChatCompletionRequest[UATS]) -> InstructRequestType: # type: ignore[type-var]
messages = self._aggregate_messages(request)
return self._instruct_request_class(messages=messages, system_prompt=None, available_tools=request.tools) # type: ignore[no-any-return]


def normalizer_for_tokenizer_version(version: TokenizerVersion) -> InstructRequestNormalizer:
if version in {TokenizerVersion.v1, TokenizerVersion.v2, TokenizerVersion.v3}:
return InstructRequestNormalizer.normalizer()
elif version == TokenizerVersion.v7:
return InstructRequestNormalizerV7.normalizer()
raise ValueError(f"Unknown tokenizer version {version}")
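
Not part of the commit: a minimal sketch of how the new normalizer_for_tokenizer_version helper is meant to be used, based only on the names added in this file.

from mistral_common.protocol.instruct.normalize import (
    InstructRequestNormalizerV7,
    normalizer_for_tokenizer_version,
)
from mistral_common.tokens.tokenizers.base import TokenizerVersion

# v1-v3 get the base normalizer; v7 gets the new subclass that keeps
# system messages in the message list instead of aggregating them into
# a single system prompt.
normalizer = normalizer_for_tokenizer_version(TokenizerVersion.v7)
assert isinstance(normalizer, InstructRequestNormalizerV7)
assert normalizer.system_prompt_in_begin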
1 change: 1 addition & 0 deletions src/mistral_common/protocol/instruct/request.py
@@ -24,3 +24,4 @@ class ChatCompletionRequest(BaseCompletionRequest, Generic[ChatMessageType]):
response_format: ResponseFormat = Field(default_factory=ResponseFormat)
tools: Optional[List[Tool]] = None
tool_choice: ToolChoice = ToolChoice.auto
truncate_for_context_length: bool = False
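
Not part of the commit: a short sketch of the new per-request flag, assuming UserMessage is importable from mistral_common.protocol.instruct.messages.

from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest

# The flag defaults to False; callers opt into truncation per request.
request = ChatCompletionRequest(
    messages=[UserMessage(content="Hello!")],
    truncate_for_context_length=True,
)
print(request.truncate_for_context_length)  # True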
1 change: 1 addition & 0 deletions src/mistral_common/tokens/instruct/request.py
@@ -22,3 +22,4 @@ class InstructRequest(MistralBase, Generic[ChatMessageType, ToolType]):
messages: List[ChatMessageType]
system_prompt: Optional[str] = None
available_tools: Optional[List[ToolType]] = None
truncate_at_max_tokens: Optional[int] = None
4 changes: 4 additions & 0 deletions src/mistral_common/tokens/tokenizers/base.py
@@ -34,12 +34,16 @@ class SpecialTokens(str, Enum):
prefix = "[PREFIX]"
middle = "[MIDDLE]"
suffix = "[SUFFIX]"
begin_system = "[SYSTEM_PROMPT]"
end_system = "[/SYSTEM_PROMPT]"
begin_tool_content = "[TOOL_CONTENT]"


class TokenizerVersion(str, Enum):
v1 = "v1" # vocab_size = 32000
v2 = "v2" # vocab_size = 32768 with special control tokens [INST], [\INST]
v3 = "v3" # vocab_size = 32768 (spm) OR 128000 (tekken) with improved function calling
v7 = "v7" # vocab_size = 32768 (spm) or 128000 (tekken) with improved system prompt and function calling


class Tokenized(MistralBase):
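
Not part of the commit: a sketch of resolving the new system-prompt control tokens to ids, mirroring how the image tokens are resolved via get_control_token in mistral.py below; the attribute chain instruct_tokenizer.tokenizer is an assumption.

from mistral_common.tokens.tokenizers.base import SpecialTokens
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

# Load the v7 tokenizer and look up the ids of the new system-prompt markers.
mistral_tokenizer = MistralTokenizer.v7()
raw_tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer
begin_system_id = raw_tokenizer.get_control_token(SpecialTokens.begin_system.value)
end_system_id = raw_tokenizer.get_control_token(SpecialTokens.end_system.value)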
47 changes: 40 additions & 7 deletions src/mistral_common/tokens/tokenizers/mistral.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Callable, Dict, Generic, List, Union
from typing import Callable, Dict, Generic, List, Optional, Union

from mistral_common.exceptions import (
TokenizerException,
@@ -13,7 +13,7 @@
ToolMessageType,
UserMessageType,
)
from mistral_common.protocol.instruct.normalize import InstructRequestNormalizer
from mistral_common.protocol.instruct.normalize import InstructRequestNormalizer, normalizer_for_tokenizer_version
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.validator import (
MistralRequestValidator,
@@ -39,13 +39,17 @@
InstructTokenizerV1,
InstructTokenizerV2,
InstructTokenizerV3,
InstructTokenizerV7,
SentencePieceTokenizer,
get_mm_config,
is_sentencepiece,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer, is_tekken


def load_mm_encoder(mm_config: MultimodalConfig, tokenizer: Tekkenizer) -> MultiModalEncoder:
def load_mm_encoder(
mm_config: MultimodalConfig, tokenizer: Union[Tekkenizer, SentencePieceTokenizer]
) -> MultiModalEncoder:
special_ids = SpecialImageIDs(
img=tokenizer.get_control_token(SpecialTokens.img.value),
img_break=tokenizer.get_control_token(SpecialTokens.img_break.value),
@@ -99,6 +103,13 @@ def v3(cls, is_tekken: bool = False, is_mm: bool = False) -> "MistralTokenizer":

return cls.from_file(str(cls._data_path() / tokenizer_name), mode=ValidationMode.test)

@classmethod
def v7(cls) -> "MistralTokenizer":
"""mistral-large 2.1"""
return cls.from_file(
str(cls._data_path() / "mistral_instruct_tokenizer_241114.model.v7m1"), mode=ValidationMode.test
)

@classmethod
def from_model(cls, model: str) -> "MistralTokenizer":
model_name_to_tokenizer_cls: Dict[str, Callable[[], MistralTokenizer]] = {
@@ -136,14 +147,15 @@ def from_file(
if is_tekken(tokenizer_filename):
tokenizer = Tekkenizer.from_file(tokenizer_filename)
mm_config = tokenizer.multimodal
mm_encoder = load_mm_encoder(mm_config, tokenizer) if mm_config is not None else None
elif is_sentencepiece(tokenizer_filename):
tokenizer = SentencePieceTokenizer(tokenizer_filename)
mm_encoder = None
mm_config = get_mm_config(tokenizer_filename)
else:
raise TokenizerException(f"Unrecognized tokenizer file: {tokenizer_filename}")

request_normalizer = InstructRequestNormalizer.normalizer()
mm_encoder = load_mm_encoder(mm_config, tokenizer) if mm_config is not None else None

request_normalizer = normalizer_for_tokenizer_version(tokenizer.version)

if tokenizer.version == TokenizerVersion.v1:
assert mm_encoder is None, "Tokenizer version needs to be >= v3"
@@ -165,14 +177,35 @@ def from_file(
validator=MistralRequestValidatorV3(mode=mode),
request_normalizer=request_normalizer,
)
elif tokenizer.version == TokenizerVersion.v7:
return MistralTokenizer(
InstructTokenizerV7(tokenizer, mm_encoder=mm_encoder),
validator=MistralRequestValidatorV3(mode=mode),
request_normalizer=request_normalizer,
)
else:
raise TokenizerException(f"Unrecognized tokenizer filename: {tokenizer_filename}")

raise TokenizerException(f"Unrecognized tokenizer version: {tokenizer.version}")

def encode_chat_completion(self, request: ChatCompletionRequest[UATS]) -> TokenizedType:
def encode_chat_completion(
self, request: ChatCompletionRequest[UATS], max_model_input_len: Optional[int] = None
) -> TokenizedType:
validated_request = self._chat_completion_request_validator.validate_request(request)

if max_model_input_len is None and request.truncate_for_context_length:
# the max_model_input_len arg should not be optional;
# but this function is used in many small scripts that have no use
# for truncation, and don't provide the max model len
raise TokenizerException(
"encoding a chat completion request with truncation, but no max model len was provided",
)

instruct_request = self._instruct_request_normalizer.from_chat_completion_request(validated_request)

if request.truncate_for_context_length:
instruct_request.truncate_at_max_tokens = max_model_input_len

return self.instruct_tokenizer.encode_instruct(instruct_request)

def encode_fim(self, request: FIMRequest) -> TokenizedType:
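
Not part of the commit: an end-to-end sketch combining the new v7 loader with the truncation flag; the 32768 limit is illustrative, not a documented value.

from mistral_common.protocol.instruct.messages import SystemMessage, UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

tokenizer = MistralTokenizer.v7()
request = ChatCompletionRequest(
    messages=[
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="Hello!"),
    ],
    truncate_for_context_length=True,
)
# max_model_input_len must be passed whenever truncate_for_context_length is set,
# otherwise encode_chat_completion raises a TokenizerException.
tokenized = tokenizer.encode_chat_completion(request, max_model_input_len=32768)
# Assumes the returned Tokenized object exposes a .tokens list.
print(len(tokenized.tokens))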
13 changes: 13 additions & 0 deletions src/mistral_common/tokens/tokenizers/multimodal.py
@@ -1,6 +1,7 @@
import base64
import logging
from dataclasses import dataclass
from enum import Enum
from io import BytesIO
from typing import Tuple, Union

@@ -57,6 +58,18 @@ def image_from_chunk(chunk: Union[ImageURLChunk, ImageChunk]) -> SerializableIma
DATASET_STD = (0.26862954, 0.26130258, 0.27577711) # RGB


# only relevant for spm
class MultiModalVersion(str, Enum):
m1 = "m1"

@property
def config(self) -> "MultimodalConfig":
if self.name == "m1":
return MultimodalConfig(16, 1024)

raise NotImplementedError(f"{self.name}")


@dataclass
class MultimodalConfig:
image_patch_size: int
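
Not part of the commit: a sketch of the new MultiModalVersion enum; per the diff, "m1" maps to MultimodalConfig(16, 1024), i.e. an image patch size of 16.

from mistral_common.tokens.tokenizers.multimodal import MultiModalVersion

config = MultiModalVersion("m1").config
print(config.image_patch_size)  # 16
print(config)                   # full dataclass, including the second (1024) field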
