Merge pull request #26 from mistralai/add_tekken
Tekken
patrickvonplaten authored Jul 18, 2024
2 parents 01f2a17 + 49582b3 commit b40f748
Showing 13 changed files with 750,596 additions and 132 deletions.
750,011 changes: 750,011 additions & 0 deletions src/mistral_common/data/tekken_240718.json

Large diffs are not rendered by default.
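The 750,011-line JSON is the new Tekken tokenizer config and is not rendered here, but the loading path added in mistral.py below consumes it directly. A minimal sketch of that entry point, using only the two helpers this PR imports from the tekken module (the repo-root-relative path is an assumption for illustration):

    from mistral_common.tokens.tokenizers.tekken import Tekkenizer, is_tekken

    path = "src/mistral_common/data/tekken_240718.json"  # assumes running from the repo root
    if is_tekken(path):                      # new helper: recognizes Tekken tokenizer files
        tokenizer = Tekkenizer.from_file(path)  # parses the bundled JSON config
        print(tokenizer.version)             # a TokenizerVersion member, per the enum added below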

2 changes: 1 addition & 1 deletion src/mistral_common/protocol/instruct/normalize.py
@@ -125,7 +125,7 @@ def _aggregate_assistant_messages(self, messages: List[UATS]) -> AssistantMessag
         return self._assistant_message_class(
             content="\n\n".join(aggregated_content) if len(aggregated_content) else None,
             tool_calls=tool_calls or None,
-            prefix=prefix
+            prefix=prefix,
         )

     def _aggregate_user_messages(self, messages: List[UATS]) -> UserMessageType:
2 changes: 1 addition & 1 deletion src/mistral_common/protocol/instruct/validator.py
@@ -226,7 +226,7 @@ def _validate_last_message(self, message: UATS) -> None:
         # The last message must be a user or tool message in serving mode or an assistant message in finetuning mode
         last_message_role = message.role
         if self._mode == ValidationMode.finetuning:
-            if last_message_role != Roles.assistant:
+            if last_message_role != Roles.assistant:
                 raise InvalidMessageStructureException(
                     f"Expected last role Assistant for finetuning but got {last_message_role.value}"
                 )
6 changes: 6 additions & 0 deletions src/mistral_common/tokens/tokenizers/base.py
@@ -22,6 +22,12 @@ class SpecialTokens(str, Enum):
     suffix = "[SUFFIX]"


+class TokenizerVersion(str, Enum):
+    v1 = "v1"  # vocab_size = 32000
+    v2 = "v2"  # vocab_size = 32768 with special control tokens [INST], [/INST]
+    v3 = "v3"  # vocab_size = 32768 (spm) OR 128000 (tekken) with improved function calling
+
+
 class Tokenized(MistralBase):
     """
     A tokenized InstructRequest
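Since TokenizerVersion subclasses str, raw version suffixes round-trip through the enum directly; a quick sketch of the behavior the loaders below rely on (written as standalone asserts for illustration):

    from mistral_common.tokens.tokenizers.base import TokenizerVersion

    assert TokenizerVersion("v3") is TokenizerVersion.v3  # lookup by value returns the member
    assert TokenizerVersion.v2 == "v2"                    # str subclass: equal to the plain string
    assert "v1" in TokenizerVersion.__members__           # the membership test get_spm_version uses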
67 changes: 42 additions & 25 deletions src/mistral_common/tokens/tokenizers/mistral.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from pathlib import Path
-from typing import Callable, Dict, Generic, List
+from typing import Callable, Dict, Generic, List, Union

 from mistral_common.exceptions import (
     TokenizerException,
@@ -25,15 +25,17 @@
     InstructRequest,
     InstructRequestType,
     InstructTokenizer,
-    Tokenized,
     TokenizedType,
+    TokenizerVersion,
 )
 from mistral_common.tokens.tokenizers.sentencepiece import (
     InstructTokenizerV1,
     InstructTokenizerV2,
     InstructTokenizerV3,
     SentencePieceTokenizer,
+    is_sentencepiece,
 )
+from mistral_common.tokens.tokenizers.tekken import Tekkenizer, is_tekken


 class MistralTokenizer(
@@ -56,26 +58,25 @@ def _data_path(cls) -> Path:
         return Path(__file__).parents[2] / "data"

     @classmethod
-    def v1(cls) -> MistralTokenizer:
-        """open-mistral-7b // open-mixtral-8x7b // mistral-embed"""
+    def v1(cls) -> "MistralTokenizer":
+        """open 7B x 8x7B + embed"""
         return cls.from_file(str(cls._data_path() / "tokenizer.model.v1"), mode=ValidationMode.test)

     @classmethod
-    def v2(cls) -> MistralTokenizer:
+    def v2(cls) -> "MistralTokenizer":
         """mistral-small // mistral-large"""
         return cls.from_file(
             str(cls._data_path() / "mistral_instruct_tokenizer_240216.model.v2"), mode=ValidationMode.test
         )

     @classmethod
-    def v3(cls) -> MistralTokenizer:
-        """open-mixtral-8x22b // codestral-22b"""
-        return cls.from_file(
-            str(cls._data_path() / "mistral_instruct_tokenizer_240323.model.v3"), mode=ValidationMode.test
-        )
+    def v3(cls, is_tekken: bool = False) -> "MistralTokenizer":
+        """open-mixtral-8x22B"""
+        tokenizer_name = "mistral_instruct_tokenizer_240323.model.v3" if not is_tekken else "tekken_240718.json"
+        return cls.from_file(str(cls._data_path() / tokenizer_name), mode=ValidationMode.test)

     @classmethod
-    def from_model(cls, model: str) -> MistralTokenizer:
+    def from_model(cls, model: str) -> "MistralTokenizer":
         model_name_to_tokenizer_cls: Dict[str, Callable[[], MistralTokenizer]] = {
             "open-mistral-7b": MistralTokenizer.v1,
             "open-mixtral-8x7b": MistralTokenizer.v1,
@@ -84,6 +85,7 @@ def from_model(cls, model: str) -> MistralTokenizer:
             "mistral-large": MistralTokenizer.v2,
             "open-mixtral-8x22b": MistralTokenizer.v3,
             "codestral-22b": MistralTokenizer.v3,
+            "mistral-nemo": lambda: MistralTokenizer.v3(is_tekken=True),
         }

         # Prefix search the model name mapping
@@ -94,37 +96,52 @@ def from_model(cls, model: str) -> MistralTokenizer:
         raise TokenizerException(f"Unrecognized model: {model}")

     @classmethod
-    def from_file(cls, tokenizer_filename: str, mode: ValidationMode = ValidationMode.test) -> MistralTokenizer:
+    def from_file(
+        cls,
+        tokenizer_filename: str,
+        mode: ValidationMode = ValidationMode.test,
+    ) -> "MistralTokenizer":
         """
         Depending on which model we are loading, tokenization and validation might be different. 💩
         """
-        if tokenizer_filename.endswith(".model.v1") or tokenizer_filename.endswith(".model"):
+        tokenizer: Union[SentencePieceTokenizer, Tekkenizer]
+
+        if is_tekken(tokenizer_filename):
+            tokenizer = Tekkenizer.from_file(tokenizer_filename)
+        elif is_sentencepiece(tokenizer_filename):
+            tokenizer = SentencePieceTokenizer(tokenizer_filename)
+        else:
+            raise TokenizerException(f"Unrecognized tokenizer file: {tokenizer_filename}")
+
+        request_normalizer = InstructRequestNormalizer.normalizer()
+
+        if tokenizer.version == TokenizerVersion.v1:
             return MistralTokenizer(
-                InstructTokenizerV1(SentencePieceTokenizer(tokenizer_filename)),
+                InstructTokenizerV1(tokenizer),
                 validator=MistralRequestValidator(mode=mode),
-                request_normalizer=InstructRequestNormalizer.normalizer(),
+                request_normalizer=request_normalizer,
             )
-        elif tokenizer_filename.endswith(".model.v2"):
+        elif tokenizer.version == TokenizerVersion.v2:
             return MistralTokenizer(
-                InstructTokenizerV2(SentencePieceTokenizer(tokenizer_filename)),
+                InstructTokenizerV2(tokenizer),
                 validator=MistralRequestValidator(mode=mode),
-                request_normalizer=InstructRequestNormalizer.normalizer(),
+                request_normalizer=request_normalizer,
             )
-        elif tokenizer_filename.endswith(".model.v3"):
+        elif tokenizer.version == TokenizerVersion.v3:
             return MistralTokenizer(
-                InstructTokenizerV3(SentencePieceTokenizer(tokenizer_filename)),
+                InstructTokenizerV3(tokenizer),
                 validator=MistralRequestValidatorV3(mode=mode),
-                request_normalizer=InstructRequestNormalizer.normalizer(),
+                request_normalizer=request_normalizer,
             )
-        else:
-            raise TokenizerException(f"Unrecognized tokenizer filename: {tokenizer_filename}")
+
+        raise TokenizerException(f"Unrecognized tokenizer version: {tokenizer.version}")

-    def encode_chat_completion(self, request: ChatCompletionRequest[UATS]) -> Tokenized:
+    def encode_chat_completion(self, request: ChatCompletionRequest[UATS]) -> TokenizedType:
         validated_request = self._chat_completion_request_validator.validate_request(request)
         instruct_request = self._instruct_request_normalizer.from_chat_completion_request(validated_request)
         return self.instruct_tokenizer.encode_instruct(instruct_request)

-    def encode_fim(self, request: FIMRequest) -> Tokenized:
+    def encode_fim(self, request: FIMRequest) -> TokenizedType:
         return self.instruct_tokenizer.encode_fim(request)

     def decode(self, tokens: List[int]) -> str:
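With this refactor a single entry point covers both backends, and callers pick Tekken either explicitly or via the model name. A short usage sketch of the API as it stands after this diff (the request/message import paths follow the library's usual layout and are an assumption here):

    from mistral_common.protocol.instruct.messages import UserMessage
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

    spm_tok = MistralTokenizer.v3()                         # SentencePiece-backed v3
    tekken_tok = MistralTokenizer.v3(is_tekken=True)        # Tekken-backed v3 (tekken_240718.json)
    nemo_tok = MistralTokenizer.from_model("mistral-nemo")  # resolved via the new lambda entry

    tokenized = tekken_tok.encode_chat_completion(
        ChatCompletionRequest(messages=[UserMessage(content="Hello!")])
    )
    print(len(tokenized.tokens))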
33 changes: 32 additions & 1 deletion src/mistral_common/tokens/tokenizers/sentencepiece.py
@@ -3,7 +3,8 @@
 import os
 from abc import abstractmethod
 from functools import cached_property
-from typing import Any, Dict, Generic, List, Optional, Set, Tuple
+from pathlib import Path
+from typing import Any, Dict, Generic, List, Optional, Set, Tuple, Union

 from mistral_common.exceptions import TokenizerException
 from mistral_common.protocol.instruct.messages import (
@@ -22,10 +23,34 @@
     Tokenized,
     TokenizedType,
     Tokenizer,
+    TokenizerVersion,
 )
 from sentencepiece import SentencePieceProcessor


+def is_sentencepiece(path: Union[str, Path]) -> bool:
+    if isinstance(path, str):
+        path = Path(path)
+
+    suffixes = [f".model.{v}" for v in list(TokenizerVersion.__members__)] + [".model"]
+    return path.is_file() and any(path.name.endswith(suffix) for suffix in suffixes)
+
+
+def get_spm_version(tokenizer_filename: str, raise_deprecated: bool = False) -> TokenizerVersion:
+    _version_str = tokenizer_filename.split(".")[-1]
+    if _version_str == "model":
+        if raise_deprecated:
+            raise TokenizerException(f"Make sure to rename your tokenizer file to end with {tokenizer_filename}.v1.")
+
+        # tokenizer.model => tokenizer.model.v1
+        return TokenizerVersion("v1")
+
+    if _version_str not in TokenizerVersion.__members__:
+        raise TokenizerException(f"Unrecognized tokenizer filename: {tokenizer_filename}")
+
+    return TokenizerVersion(_version_str)
+
+
 class SentencePieceTokenizer(Tokenizer):
     def __init__(self, model_path: str) -> None:
         self._logger = logging.getLogger(self.__class__.__name__)
@@ -36,8 +61,14 @@ def __init__(self, model_path: str) -> None:
         assert self._model.vocab_size() == self._model.get_piece_size()
         self._vocab = [self._model.id_to_piece(i) for i in range(self.n_words)]

+        self._version: TokenizerVersion = get_spm_version(model_path, raise_deprecated=False)
+
         super().__init__()

+    @property
+    def version(self) -> TokenizerVersion:
+        return self._version
+
     def get_control_token(self, s: str) -> int:
         return self._model.piece_to_id(s)  # type: ignore
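The two new helpers give from_file the filename-based dispatch it now relies on; note that the version is read from the filename suffix alone, while is_sentencepiece additionally requires the file to exist. For instance (filenames are illustrative):

    from mistral_common.tokens.tokenizers.base import TokenizerVersion
    from mistral_common.tokens.tokenizers.sentencepiece import get_spm_version, is_sentencepiece

    assert get_spm_version("mistral_instruct_tokenizer_240216.model.v2") == TokenizerVersion.v2
    assert get_spm_version("tokenizer.model") == TokenizerVersion.v1  # bare .model treated as v1
    assert not is_sentencepiece("missing.model.v3")  # path.is_file() fails for nonexistent paths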