diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py
index 03527265d2..ff20bec5bd 100644
--- a/benchmark/profile_throughput.py
+++ b/benchmark/profile_throughput.py
@@ -12,7 +12,7 @@
 import numpy as np
 from tqdm import tqdm
 
-from lmdeploy.tokenizer import Tokenizer
+from lmdeploy.tokenizer import DetokenizeState, Tokenizer
 from lmdeploy.turbomind import TurboMind
 
 
@@ -80,7 +80,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
         for prompt, input_seqlen, output_seqlen in iter(
                 req_queue.get, [None, None, None]):
             _per_token_latency_stats = [0] * (output_seqlen + 1)
-            offset = 0
+            state = DetokenizeState()
             prev = time.perf_counter()
             n_prev_token = 0
 
@@ -96,8 +96,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
                     ignore_eos=True,
                     stream_output=stream_output):
                 _, res, n_token = outputs
-                self.tokenizer.decode(res, offset)
-                offset = n_token
+                _, state = self.tokenizer.detokenize_incrementally(res, state)
                 now = time.perf_counter()
                 if n_prev_token != n_token:
                     _per_token_latency_stats[n_prev_token] = np.round(
diff --git a/benchmark/profile_torch_throughput.py b/benchmark/profile_torch_throughput.py
index e82f657593..9c5f6d3b82 100644
--- a/benchmark/profile_torch_throughput.py
+++ b/benchmark/profile_torch_throughput.py
@@ -14,7 +14,7 @@
 from lmdeploy.messages import EngineGenerationConfig, PytorchEngineConfig
 from lmdeploy.pytorch.engine import Engine as LMEngine
-from lmdeploy.tokenizer import Tokenizer
+from lmdeploy.tokenizer import DetokenizeState, Tokenizer
 
 
 def sample_requests(
@@ -83,7 +83,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
         for prompt, input_seqlen, output_seqlen in iter(
                 req_queue.get, [None, None, None]):
             _per_token_latency_stats = [0] * (output_seqlen + 1)
-            offset = 0
+            state = DetokenizeState()
             prev = time.perf_counter()
             n_prev_token = 0
 
@@ -98,8 +98,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
                     input_ids=input_ids,
                     gen_config=gen_config):
                 res, n_token = outputs[-2:]
-                self.tokenizer.decode(res, offset)
-                offset = n_token
+                _, state = self.tokenizer.detokenize_incrementally(res, state)
                 now = time.perf_counter()
                 if n_prev_token != n_token:
                     _per_token_latency_stats[n_prev_token] = np.round(
diff --git a/lmdeploy/pytorch/chat.py b/lmdeploy/pytorch/chat.py
index a0f41eb998..9876dbf98f 100644
--- a/lmdeploy/pytorch/chat.py
+++ b/lmdeploy/pytorch/chat.py
@@ -6,7 +6,7 @@
 from lmdeploy.messages import EngineGenerationConfig, PytorchEngineConfig
 from lmdeploy.model import MODELS, best_match_model
-from lmdeploy.tokenizer import Tokenizer
+from lmdeploy.tokenizer import DetokenizeState, Tokenizer
 
 os.environ['TM_LOG_LEVEL'] = 'ERROR'
 
@@ -107,7 +107,7 @@ def run_chat(model_path: str,
                 continue
 
             print(f'{prompt} ', end='', flush=True)
-            response_size = 0
+            state = DetokenizeState()
             gen_config.random_seed = seed
             gen_config.stop_words = stop_words
             for outputs in generator.stream_infer(session_id=session_id,
@@ -116,15 +116,10 @@ def run_chat(model_path: str,
                                                   adapter_name=adapter_name):
                 status, res, tokens = outputs
                 # decode res
-                response = tokenizer.decode(res, offset=response_size)
-                # utf-8 char at the end means it's a potential unfinished
-                # byte sequence, continue to concate it with the next
-                # sequence and decode them together
-                if response.endswith('�'):
-                    continue
+                response, state = tokenizer.detokenize_incrementally(
+                    res, state)
                 response = valid_str(response)
                 print(f'{response}', end='', flush=True)
-                response_size = tokens
 
             # update step
             step += len(input_ids) + tokens
diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py
index 5d2b318731..bd81372c9e 100644
--- a/lmdeploy/serve/async_engine.py
+++ b/lmdeploy/serve/async_engine.py
@@ -12,6 +12,7 @@
                                PytorchEngineConfig, Response,
                                TurbomindEngineConfig)
 from lmdeploy.model import ChatTemplateConfig, best_match_model
+from lmdeploy.tokenizer import DetokenizeState
 from lmdeploy.utils import _stop_words, get_logger
 
 
@@ -452,7 +453,7 @@ async def generate(
         else:
             generator = await self.get_generator(False, session_id)
         with self.safe_run(session_id):
-            response_size = 0
+            state = DetokenizeState()
             async for outputs in generator.async_stream_infer(
                     session_id=session_id,
                     input_ids=input_ids,
@@ -461,25 +462,20 @@ async def generate(
                     sequence_start=(sequence_start),
                     sequence_end=sequence_end,
                     step=self.id2step[str(session_id)]):
-                status, res, tokens = outputs
+                _, res, tokens = outputs
                 # decode res
-                response = self.tokenizer.decode(res, offset=response_size)
-                # utf-8 char at the end means it's a potential unfinished
-                # byte sequence, continue to concate it with the next
-                # sequence and decode them together
-                if response.endswith('�'):
-                    continue
+                response, state = self.tokenizer.detokenize_incrementally(
+                    res, state)
                 # response, history token len,
                 # input token len, gen token len
                 yield GenOut(response, self.id2step[str(session_id)],
                              len(input_ids), tokens, finish_reason)
-                response_size = tokens
 
             finish_reason = 'length' \
                 if tokens >= gen_config.max_new_tokens else 'stop'
-            # `response_size` might be note updated since
-            # ` if response.endswith('�')`
-            if response_size == tokens:
+            # utf-8 char at the end means it's a potential unfinished
+            # byte sequence
+            if not response.endswith('�'):
                 response = ''  # avaid returning the last response twice
             yield GenOut(response, self.id2step[str(session_id)],
                          len(input_ids), tokens, finish_reason)
diff --git a/lmdeploy/tokenizer.py b/lmdeploy/tokenizer.py
index 28e8092c63..b37a5a2483 100644
--- a/lmdeploy/tokenizer.py
+++ b/lmdeploy/tokenizer.py
@@ -2,7 +2,8 @@
 import json
 import os.path as osp
 from collections import deque
-from typing import List, Optional, Sequence, Union
+from dataclasses import dataclass
+from typing import List, Optional, Sequence, Tuple, Union
 
 import torch
 
@@ -12,6 +13,31 @@
 # importing are starting from the package root lmdeploy
 
 
+@dataclass
+class DetokenizeState:
+    """A state collection for incremental detokenization.
+
+    Args:
+        ids_offset (int): offset into all input ids. In LMDeploy, the number
+            of output ids returned each step is not fixed at one; it may vary
+            from step to step.
+        prev_tokens (List[str] | None): for incremental decoding.
+            Default to None, which means the first round.
+        prefix_offset (int): the start index of tokens to be converted to
+            string (prev + new tokens). Default to 0 for the first round.
+        read_offset (int): the end index of tokens to be converted to
+            string (prev token). Default to 0 for the first round.
+    """
+    ids_offset: int = 0
+    prev_tokens: Optional[List[str]] = None
+    prefix_offset: int = 0
+    read_offset: int = 0
+
+    def as_tuple(self) -> Tuple:
+        """Return a tuple of states."""
+        return (self.ids_offset, self.prev_tokens, self.prefix_offset,
+                self.read_offset)
+
+
 class SentencePieceTokenizer:
     """Tokenizer of sentencepiece.
@@ -113,6 +139,34 @@ def decode(self, t: Sequence[int], offset: Optional[int] = None):
         out_string = self._maybe_add_prefix_space(t, out_string)
         return out_string
 
+    def detokenize_incrementally(self,
+                                 all_input_ids: Sequence[int],
+                                 state: DetokenizeState,
+                                 skip_special_tokens: bool = True,
+                                 spaces_between_special_tokens: bool = True):
+        """Incrementally detokenize the input indexes.
+
+        Args:
+            all_input_ids (List[int]): a list of token ids. Expected to be
+                different sections of a long sequence.
+            state (DetokenizeState): an instance of DetokenizeState. Consists
+                of incremental decoding states.
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to True.
+            spaces_between_special_tokens (bool): Whether or not to add spaces
+                between special tokens. Default to True.
+        Returns:
+            str: decoding output string of the current round.
+            state (DetokenizeState): an instance of DetokenizeState. Consists
+                of incremental decoding states.
+        """
+        out_string = self.model.Decode(all_input_ids)
+        if state.prev_tokens is not None:
+            out_string = self._maybe_add_prefix_space(all_input_ids,
+                                                      out_string)
+        state.prev_tokens = []  # not None for the above condition
+        return out_string, state
+
     def __call__(self, s: Union[str, Sequence[str]]):
         """Tokenize prompts.
 
@@ -289,9 +343,117 @@ def decode(self, t: Sequence[int], offset: Optional[int] = None):
         out_string = self.model.decode(t,
                                        skip_special_tokens=skip_special_tokens)
         if offset:
+            logger = get_logger('lmdeploy')
+            logger.warning('For incremental detokenization, please try '
+                           'detokenize_incrementally function instead.')
             out_string = self._maybe_add_prefix_space(t, out_string)
         return out_string
 
+    @staticmethod
+    def _convert_tokens_to_string_with_added_encoders(
+        tokenizer,
+        output_tokens: List[str],
+        skip_special_tokens: bool,
+        spaces_between_special_tokens: bool,
+    ) -> str:
+        if tokenizer.is_fast or not tokenizer.get_added_vocab():
+            return tokenizer.convert_tokens_to_string(output_tokens)
+        # Adapted from
+        # https://github.com/vllm-project/vllm/blob/v0.2.7/vllm/transformers_utils/tokenizer.py#L68-L99
+        sub_texts = []
+        current_sub_text = []
+        all_special_tokens = set(tokenizer.all_special_tokens)
+        for token in output_tokens:
+            if skip_special_tokens and token in all_special_tokens:
+                continue
+            if token in tokenizer.get_added_vocab():
+                if current_sub_text:
+                    sub_text = tokenizer.convert_tokens_to_string(
+                        current_sub_text)
+                    sub_texts.append(sub_text)
+                    current_sub_text = []
+                sub_texts.append(token)
+            else:
+                current_sub_text.append(token)
+        if current_sub_text:
+            sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+            sub_texts.append(sub_text)
+        if spaces_between_special_tokens:
+            return ' '.join(sub_texts)
+        else:
+            return ''.join(sub_texts)
+
+    # Based on
+    # https://github.com/vllm-project/vllm/blob/v0.2.7/vllm/transformers_utils/tokenizer.py#L105-L165
+    def detokenize_incrementally(self,
+                                 all_input_ids: Sequence[int],
+                                 state: DetokenizeState,
+                                 skip_special_tokens: bool = True,
+                                 spaces_between_special_tokens: bool = True):
+        """Incrementally detokenize the input indexes.
+
+        Args:
+            all_input_ids (List[int]): a list of token ids. Expected to be
+                different sections of a long sequence.
+            state (DetokenizeState): an instance of DetokenizeState. Consists
+                of incremental decoding states.
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to True.
+            spaces_between_special_tokens (bool): Whether or not to add spaces
+                between special tokens. Default to True.
+        Returns:
+            str: decoding output string of the current round.
+            state (DetokenizeState): an instance of DetokenizeState. Consists
+                of incremental decoding states.
+        """
+        tokenizer = self.model
+        ids_offset, prev_tokens, prefix_offset, read_offset = state.as_tuple()
+        new_tokens = tokenizer.convert_ids_to_tokens(
+            all_input_ids[ids_offset:],
+            skip_special_tokens=skip_special_tokens)
+        if prev_tokens is None:
+            # This is the first iteration for this sequence. Note that in
+            # vLLM, indexes are detokenized one by one, while in LMDeploy the
+            # number of detokenized indexes can differ from turn to turn.
+            if skip_special_tokens and new_tokens[
+                    0] in tokenizer.all_special_ids:
+                read_offset = 1  # skip special token
+            output_tokens = new_tokens
+            prev_tokens = new_tokens
+        else:
+            # Append the new tokens to the previous ones so that
+            # skip_special_tokens is respected over the whole sequence
+            output_tokens = prev_tokens + new_tokens
+            prev_tokens += new_tokens
+
+        prefix_text = self._convert_tokens_to_string_with_added_encoders(
+            tokenizer,
+            output_tokens[prefix_offset:read_offset],
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
+        new_text = self._convert_tokens_to_string_with_added_encoders(
+            tokenizer,
+            output_tokens[prefix_offset:],
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
+
+        # update state and get final decoded output
+        if len(new_text) > len(prefix_text) and not new_text.endswith('�'):
+            # utf-8 char at the end means it's a potential unfinished byte
+            # sequence from byte fallback tokenization.
+            # If it's in the middle, it's probably a real invalid id generated
+            # by the model
+            prefix_offset = read_offset
+            read_offset = len(output_tokens)
+            new_text = new_text[len(prefix_text):]
+        else:
+            new_text = ''
+
+        return new_text, DetokenizeState(len(all_input_ids), prev_tokens,
+                                         prefix_offset, read_offset)
+
     def __call__(self, s: Union[str, Sequence[str]]):
         """Tokenize prompts.
 
@@ -365,6 +527,33 @@ def decode(self, t: Sequence[int], offset: Optional[int] = None):
         """
         return self.model.decode(t, offset)
 
+    def detokenize_incrementally(self,
+                                 all_input_ids: Sequence[int],
+                                 state: DetokenizeState,
+                                 skip_special_tokens: bool = True,
+                                 spaces_between_special_tokens: bool = True):
+        """Incrementally detokenize the input indexes.
+
+        Args:
+            all_input_ids (List[int]): a list of token ids. Expected to be
+                different sections of a long sequence.
+            state (DetokenizeState): an instance of DetokenizeState. Consists
+                of incremental decoding states.
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to True.
+            spaces_between_special_tokens (bool): Whether or not to add spaces
+                between special tokens. Default to True.
+        Returns:
+            str: decoding output string of the current round.
+            state (DetokenizeState): an instance of DetokenizeState. Consists
+                of incremental decoding states.
+        """
+        return self.model.detokenize_incrementally(
+            all_input_ids,
+            state=state,
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens)
+
     def __call__(self, s: Union[str, Sequence[str]]):
         """Tokenize prompts.
diff --git a/lmdeploy/turbomind/chat.py b/lmdeploy/turbomind/chat.py
index 694a3cbc31..e6da93b11e 100644
--- a/lmdeploy/turbomind/chat.py
+++ b/lmdeploy/turbomind/chat.py
@@ -4,6 +4,7 @@
 import random
 
 from lmdeploy.model import ChatTemplateConfig
+from lmdeploy.tokenizer import DetokenizeState
 from lmdeploy.turbomind.utils import get_gen_param
 
 os.environ['TM_LOG_LEVEL'] = 'ERROR'
 
@@ -113,7 +114,7 @@ def main(model_path: str,
                                       step, request_output_len, **kwargs)
 
             print(f'{prompt} ', end='', flush=True)
-            response_size = 0
+            state = DetokenizeState()
             for outputs in generator.stream_infer(
                     session_id=session_id,
                     input_ids=[input_ids],
@@ -123,15 +124,10 @@ def main(model_path: str,
                     random_seed=seed if nth_round == 1 else None):
                 _, res, tokens = outputs
                 # decode res
-                response = tokenizer.decode(res, offset=response_size)
-                # utf-8 char at the end means it's a potential unfinished
-                # byte sequence, continue to concate it with the next
-                # sequence and decode them together
-                if response.endswith('�'):
-                    continue
+                response, state = tokenizer.detokenize_incrementally(
+                    res, state=state)
                 response = valid_str(response)
                 print(f'{response}', end='', flush=True)
-                response_size = tokens
 
             # update step
             step += len(input_ids) + tokens
diff --git a/tests/test_lmdeploy/test_model.py b/tests/test_lmdeploy/test_model.py
index 93acfd17a4..c899b4c9d3 100644
--- a/tests/test_lmdeploy/test_model.py
+++ b/tests/test_lmdeploy/test_model.py
@@ -17,7 +17,7 @@
     ('01-ai/Yi-34B-Chat', ['yi-chat', 'yi-34b', 'yi-200k']),
    ('01-ai/Yi-6B-Chat', ['yi', 'yi-chat']),
    ('WizardLM/WizardLM-70B-V1.0', ['wizardlm']),
-    ('CodeLlama-34b-Instruct-hf', ['codellama']),
+    ('codellama/CodeLlama-34b-Instruct-hf', ['codellama']),
     ('tiiuae/falcon-7b', ['falcon']), ('workspace', [None])])
 @pytest.mark.parametrize('suffix', ['', '-w4', '-4bit', '-16bit'])
 def test_best_match_model(model_path_and_name, suffix):
diff --git a/tests/test_lmdeploy/test_tokenizer.py b/tests/test_lmdeploy/test_tokenizer.py
index d5638193d1..d887334564 100644
--- a/tests/test_lmdeploy/test_tokenizer.py
+++ b/tests/test_lmdeploy/test_tokenizer.py
@@ -1,26 +1,31 @@
 import pytest
 
-from lmdeploy.tokenizer import HuggingFaceTokenizer
+from lmdeploy.tokenizer import DetokenizeState, HuggingFaceTokenizer
 
 
 @pytest.mark.parametrize('model_path', [
     'internlm/internlm-chat-7b', 'Qwen/Qwen-7B-Chat',
-    'baichuan-inc/Baichuan2-7B-Chat', 'codellama/CodeLlama-7b-hf',
-    'upstage/SOLAR-0-70b-16bit'
+    'baichuan-inc/Baichuan2-7B-Chat', 'upstage/SOLAR-0-70b-16bit',
+    'baichuan-inc/Baichuan-7B', 'codellama/CodeLlama-7b-hf',
+    'THUDM/chatglm2-6b', '01-ai/Yi-6B-200k', '01-ai/Yi-34B-Chat',
+    '01-ai/Yi-6B-Chat', 'WizardLM/WizardLM-70B-V1.0',
+    'codellama/CodeLlama-34b-Instruct-hf', 'tiiuae/falcon-7b'
+])
+@pytest.mark.parametrize('input', [
+    'hi, this is a test 😆😆! ' * 5, '為什麼我還在用繁體字 😆😆 gg! ' * 5,
+    ' License at\n#\n#' + ' ' * 100 + 'ht', ' '
 ])
-@pytest.mark.parametrize(
-    'input', ['hi, this is a test 😆😆! ' * 5, '為什麼我還在用繁體字 😆😆 gg! ' * 5])
-def test_tokenizer(model_path, input):
+@pytest.mark.parametrize('interval', [1, 3])
+@pytest.mark.parametrize('skip_special_tokens', [True, False])
+def test_tokenizer(model_path, input, interval, skip_special_tokens):
     tokenizer = HuggingFaceTokenizer(model_path)
-    encoded = tokenizer.encode(input, False)
+    encoded = tokenizer.encode(input, False, add_special_tokens=False)
     output = ''
-    offset = 0
-    for i in range(1, len(encoded) + 1):
-        decoded = tokenizer.decode(encoded[:i], offset)
-        if decoded.endswith('�'):
-            continue
+    state = DetokenizeState()
+    for i in range(0, len(encoded), interval):
+        decoded, state = tokenizer.detokenize_incrementally(
+            encoded[:i + interval], state, skip_special_tokens)
         output += decoded
-        offset = i
     assert input == output, 'input string should equal to output after enc-dec'
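
Usage sketch (not part of the patch above): the new API threads a DetokenizeState through successive detokenize_incrementally calls instead of passing an integer offset. The snippet below mirrors the updated test; the model id is only an example, and any tokenizer accepted by HuggingFaceTokenizer should work.

    from lmdeploy.tokenizer import DetokenizeState, HuggingFaceTokenizer

    # Example model; any repo or local path with a HF tokenizer works here.
    tokenizer = HuggingFaceTokenizer('internlm/internlm-chat-7b')
    ids = tokenizer.encode('hello, incremental decoding!', False,
                           add_special_tokens=False)

    state = DetokenizeState()
    text = ''
    # Feed a growing prefix of the ids; the state remembers how far decoding
    # has read, and each call returns only the newly finalized piece of text.
    for i in range(0, len(ids), 3):
        piece, state = tokenizer.detokenize_incrementally(ids[:i + 3], state)
        text += piece
    print(text)  # should equal the original string once all ids are consumed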