From 9f4a6b916bfc1e761ba756f19b189b9db1def69a Mon Sep 17 00:00:00 2001
From: juvi <140188098+juvi21@users.noreply.github.com>
Date: Sat, 3 Aug 2024 09:31:40 +0200
Subject: [PATCH 1/2] add: use token_str from vocab

---
 src/mistral_common/tokens/tokenizers/tekken.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/mistral_common/tokens/tokenizers/tekken.py b/src/mistral_common/tokens/tokenizers/tekken.py
index 864fa98..2a884ed 100644
--- a/src/mistral_common/tokens/tokenizers/tekken.py
+++ b/src/mistral_common/tokens/tokenizers/tekken.py
@@ -89,6 +89,7 @@ def __init__(
         )
         self._vocab_size = vocab_size
         self._path = _path
+        self._tokens_str = {token['rank']: token['token_str'] for token in vocab if token['token_str'] is not None}

         special_tokens = list(self.SPECIAL_TOKENS)
         assert len(special_tokens) == len(set(special_tokens)), f"Special tokens must be unique: {special_tokens}"
@@ -195,7 +196,6 @@ def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
         return tokens

     def _decode_all(self, tokens: List[int], special_token_policy: SpecialTokenPolicy) -> List[str]:
-        # Lump special and non-special tokens together to minimize calls to decode
         decoded: List[str] = []
         for is_special, group in groupby(tokens, lambda t: t < self.num_special_tokens):
             if is_special:
@@ -215,12 +215,16 @@ def _decode_all(self, tokens: List[int], special_token_policy: SpecialTokenPolic
                     decoded.extend(self._all_special_tokens[t] for t in group)
                 elif special_token_policy == SpecialTokenPolicy.IGNORE:
                     continue
-                # TODO: Could use "tokens_str" from vocab.json
-                # but need to handle null cases.
             else:
-                decoded.append(self._model.decode([t - self.num_special_tokens for t in group]))
+                decoded.extend(self._decode_token(t) for t in group)
         return decoded

+    def _decode_token(self, token: int) -> str:
+        adjusted_token = token - self.num_special_tokens
+        if adjusted_token in self._tokens_str:
+            return self._tokens_str[adjusted_token]
+        return self._model.decode([adjusted_token])
+
     def is_byte(self, token_id: int) -> bool:
         return 0 <= token_id - self.num_special_tokens < 256


From 043712ceeb7e19bc0895da6bb109d6907667e25c Mon Sep 17 00:00:00 2001
From: juvi <140188098+juvi21@users.noreply.github.com>
Date: Sat, 3 Aug 2024 09:57:16 +0200
Subject: [PATCH 2/2] Update tekken.py

add: removed comment
---
 src/mistral_common/tokens/tokenizers/tekken.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/mistral_common/tokens/tokenizers/tekken.py b/src/mistral_common/tokens/tokenizers/tekken.py
index 2a884ed..8b13367 100644
--- a/src/mistral_common/tokens/tokenizers/tekken.py
+++ b/src/mistral_common/tokens/tokenizers/tekken.py
@@ -196,6 +196,7 @@ def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
         return tokens

     def _decode_all(self, tokens: List[int], special_token_policy: SpecialTokenPolicy) -> List[str]:
+        # Lump special and non-special tokens together to minimize calls to decode
         decoded: List[str] = []
         for is_special, group in groupby(tokens, lambda t: t < self.num_special_tokens):
             if is_special:
@@ -215,6 +216,7 @@ def _decode_all(self, tokens: List[int], special_token_policy: SpecialTokenPolic
                     decoded.extend(self._all_special_tokens[t] for t in group)
                 elif special_token_policy == SpecialTokenPolicy.IGNORE:
                     continue
+
             else:
                 decoded.extend(self._decode_token(t) for t in group)
         return decoded
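
For context: the first patch replaces the per-group self._model.decode(...) call for regular tokens with a per-token lookup. The new _decode_token helper first consults the token_str strings carried in the vocab (cached in self._tokens_str, keyed by rank and skipping null entries) and only falls back to BPE decoding when no string is available. Below is a minimal, standalone sketch of that lookup-then-fallback behaviour; the function names, sample vocab entries, and the stub BPE decoder are illustrative assumptions, not part of mistral_common.

    # Standalone sketch of the token_str lookup with BPE fallback (illustrative only).
    from typing import Callable, Dict, List


    def build_tokens_str(vocab: List[dict]) -> Dict[int, str]:
        # Mirrors the new self._tokens_str: keep only entries whose token_str is not null.
        return {entry["rank"]: entry["token_str"] for entry in vocab if entry["token_str"] is not None}


    def decode_token(token: int, num_special_tokens: int, tokens_str: Dict[int, str],
                     bpe_decode: Callable[[List[int]], str]) -> str:
        # Mirrors _decode_token: shift out the special-token offset, prefer the
        # precomputed string, and fall back to the BPE model otherwise.
        adjusted = token - num_special_tokens
        if adjusted in tokens_str:
            return tokens_str[adjusted]
        return bpe_decode([adjusted])


    if __name__ == "__main__":
        vocab = [
            {"rank": 0, "token_str": "hello"},
            {"rank": 1, "token_str": None},  # null token_str -> must fall back to BPE decode
        ]
        tokens_str = build_tokens_str(vocab)
        stub_bpe_decode = lambda ids: f"<bpe:{ids}>"  # stand-in for self._model.decode
        print(decode_token(3, num_special_tokens=3, tokens_str=tokens_str, bpe_decode=stub_bpe_decode))  # "hello"
        print(decode_token(4, num_special_tokens=3, tokens_str=tokens_str, bpe_decode=stub_bpe_decode))  # "<bpe:[1]>"

One design note implied by the diff: the old code decoded each run of non-special tokens in a single self._model.decode call, while the new path decodes token by token; that per-token granularity is what makes the token_str shortcut (and its null fallback) possible.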