From 64d0d704f122f8fbf98c42956be21a67c546f9e1 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Sat, 20 Jul 2024 19:47:29 +0200
Subject: [PATCH] Improve

---
 src/mistral_inference/main.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/mistral_inference/main.py b/src/mistral_inference/main.py
index d4302fe..74743e8 100644
--- a/src/mistral_inference/main.py
+++ b/src/mistral_inference/main.py
@@ -12,6 +12,7 @@
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from mistral_common.tokens.tokenizers.base import Tokenizer
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+from mistral_common.tokens.tokenizers.tekken import Tekkenizer, SpecialTokenPolicy
 from mistral_common.tokens.tokenizers.sentencepiece import is_sentencepiece
 from mistral_common.tokens.tokenizers.tekken import is_tekken
 
@@ -36,6 +37,9 @@ def load_tokenizer(model_path: Path) -> MistralTokenizer:
 
     mistral_tokenizer = MistralTokenizer.from_file(str(model_path / tokenizer[0]))
 
+    if isinstance(mistral_tokenizer.instruct_tokenizer.tokenizer, Tekkenizer):
+        mistral_tokenizer.instruct_tokenizer.tokenizer.special_token_policy = SpecialTokenPolicy.KEEP
+
     logging.info(f"Loaded tokenizer of type {mistral_tokenizer.instruct_tokenizer.__class__}")
     return mistral_tokenizer