Fix fast tokenizer swallows prefix space when there are too many white spaces (#992)

* fix fast tokenizer swallowing prefix space when there are too many whitespaces

* decode prefix token id when decoding the current ones
AllentDan authored Jan 31, 2024
1 parent 789dcce commit 29b74d5
Showing 8 changed files with 231 additions and 52 deletions.
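
The gist of the issue: SentencePiece-style fast tokenizers encode a leading space as a '▁' prefix on the token, and decoding a slice of ids in isolation typically drops that prefix, so the previous offset-based decode could swallow whitespace when the output contained runs of spaces. A minimal sketch of the failure mode, assuming a SentencePiece-based fast tokenizer from Hugging Face (the checkpoint name and token split are illustrative, not taken from this commit):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')
ids = tok.encode('a  b', add_special_tokens=False)   # note the double space
print(tok.decode(ids))                                # 'a  b'
# Decoding the tail chunk on its own may drop the leading space of its
# first token, so stitching per-chunk decodes together can lose whitespace.
print(tok.decode(ids[:1]) + tok.decode(ids[1:]))      # may come out as 'a b'

The detokenize_incrementally methods added below avoid this by converting new tokens to text together with a window of previously seen tokens and emitting only the newly completed suffix.
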
7 changes: 3 additions & 4 deletions benchmark/profile_throughput.py
@@ -12,7 +12,7 @@
import numpy as np
from tqdm import tqdm

from lmdeploy.tokenizer import Tokenizer
from lmdeploy.tokenizer import DetokenizeState, Tokenizer
from lmdeploy.turbomind import TurboMind


@@ -80,7 +80,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
for prompt, input_seqlen, output_seqlen in iter(
req_queue.get, [None, None, None]):
_per_token_latency_stats = [0] * (output_seqlen + 1)
offset = 0
state = DetokenizeState()
prev = time.perf_counter()
n_prev_token = 0

@@ -96,8 +96,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
ignore_eos=True,
stream_output=stream_output):
_, res, n_token = outputs
self.tokenizer.decode(res, offset)
offset = n_token
_, state = self.tokenizer.detokenize_incrementally(res, state)
now = time.perf_counter()
if n_prev_token != n_token:
_per_token_latency_stats[n_prev_token] = np.round(
7 changes: 3 additions & 4 deletions benchmark/profile_torch_throughput.py
@@ -14,7 +14,7 @@

from lmdeploy.messages import EngineGenerationConfig, PytorchEngineConfig
from lmdeploy.pytorch.engine import Engine as LMEngine
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.tokenizer import DetokenizeState, Tokenizer


def sample_requests(
@@ -83,7 +83,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
for prompt, input_seqlen, output_seqlen in iter(
req_queue.get, [None, None, None]):
_per_token_latency_stats = [0] * (output_seqlen + 1)
offset = 0
state = DetokenizeState()
prev = time.perf_counter()
n_prev_token = 0

@@ -98,8 +98,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
input_ids=input_ids,
gen_config=gen_config):
res, n_token = outputs[-2:]
self.tokenizer.decode(res, offset)
offset = n_token
_, state = self.tokenizer.detokenize_incrementally(res, state)
now = time.perf_counter()
if n_prev_token != n_token:
_per_token_latency_stats[n_prev_token] = np.round(
13 changes: 4 additions & 9 deletions lmdeploy/pytorch/chat.py
@@ -6,7 +6,7 @@

from lmdeploy.messages import EngineGenerationConfig, PytorchEngineConfig
from lmdeploy.model import MODELS, best_match_model
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.tokenizer import DetokenizeState, Tokenizer

os.environ['TM_LOG_LEVEL'] = 'ERROR'

@@ -107,7 +107,7 @@ def run_chat(model_path: str,
continue

print(f'{prompt} ', end='', flush=True)
response_size = 0
state = DetokenizeState()
gen_config.random_seed = seed
gen_config.stop_words = stop_words
for outputs in generator.stream_infer(session_id=session_id,
@@ … @@
adapter_name=adapter_name):
status, res, tokens = outputs
# decode res
response = tokenizer.decode(res, offset=response_size)
# utf-8 char at the end means it's a potential unfinished
# byte sequence, continue to concatenate it with the next
# sequence and decode them together
if response.endswith('�'):
continue
response, state = tokenizer.detokenize_incrementally(
res, state)
response = valid_str(response)
print(f'{response}', end='', flush=True)
response_size = tokens

# update step
step += len(input_ids) + tokens
20 changes: 8 additions & 12 deletions lmdeploy/serve/async_engine.py
@@ -12,6 +12,7 @@
PytorchEngineConfig, Response,
TurbomindEngineConfig)
from lmdeploy.model import ChatTemplateConfig, best_match_model
from lmdeploy.tokenizer import DetokenizeState
from lmdeploy.utils import _stop_words, get_logger


@@ -452,7 +453,7 @@ async def generate(
else:
generator = await self.get_generator(False, session_id)
with self.safe_run(session_id):
response_size = 0
state = DetokenizeState()
async for outputs in generator.async_stream_infer(
session_id=session_id,
input_ids=input_ids,
@@ … @@
sequence_start=(sequence_start),
sequence_end=sequence_end,
step=self.id2step[str(session_id)]):
status, res, tokens = outputs
_, res, tokens = outputs
# decode res
response = self.tokenizer.decode(res, offset=response_size)
# utf-8 char at the end means it's a potential unfinished
# byte sequence, continue to concatenate it with the next
# sequence and decode them together
if response.endswith('�'):
continue
response, state = self.tokenizer.detokenize_incrementally(
res, state)
# response, history token len,
# input token len, gen token len
yield GenOut(response, self.id2step[str(session_id)],
len(input_ids), tokens, finish_reason)
response_size = tokens

finish_reason = 'length' \
if tokens >= gen_config.max_new_tokens else 'stop'
# `response_size` might not be updated since
# ` if response.endswith('�')`
if response_size == tokens:
# utf-8 char at the end means it's a potential unfinished
# byte sequence
if not response.endswith('�'):
response = ''  # avoid returning the last response twice
yield GenOut(response, self.id2step[str(session_id)],
len(input_ids), tokens, finish_reason)
191 changes: 190 additions & 1 deletion lmdeploy/tokenizer.py
@@ -2,7 +2,8 @@
import json
import os.path as osp
from collections import deque
from typing import List, Optional, Sequence, Union
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Union

import torch

@@ -12,6 +13,31 @@
# importing are starting from the package root lmdeploy


@dataclass
class DetokenizeState:
"""A state collection of incrementally detekenization.
Args:
ids_offset (int): offset to all input ids. In LMDeploy, the output
ids length is not one by one. It could be random by random.
prev_tokens (List[str] | None): for incrementally decoding.
Default to None, which means the first round.
prefix_offset (int): the start index of tokens to be converted to
string (prev + new tokens). Default to 0 for the first round.
read_offset (int): the end index of tokens to be converted to
string (prev token). Default to 0 for the first round.
"""
ids_offset: int = 0
prev_tokens: Optional[List[str]] = None
prefix_offset: int = 0
read_offset: int = 0

def as_tuple(self) -> Tuple:
"""Return a tuple of states."""
return (self.ids_offset, self.prev_tokens, self.prefix_offset,
self.read_offset)
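
A minimal sketch of how this state is meant to be threaded through a streaming loop, mirroring the call sites updated in this commit; the model path is illustrative, and the growing id prefix stands in for the engine's cumulative output ids:

from lmdeploy.tokenizer import DetokenizeState, Tokenizer

tokenizer = Tokenizer('internlm/internlm-chat-7b')   # illustrative model path
ids = tokenizer.encode('Hello   world')              # pretend these stream in
state, text = DetokenizeState(), ''
for end in range(1, len(ids) + 1):
    # each round receives all ids generated so far plus the previous state
    chunk, state = tokenizer.detokenize_incrementally(ids[:end], state)
    text += chunk
print(text)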


class SentencePieceTokenizer:
"""Tokenizer of sentencepiece.
@@ -113,6 +139,34 @@ def decode(self, t: Sequence[int], offset: Optional[int] = None):
out_string = self._maybe_add_prefix_space(t, out_string)
return out_string

def detokenize_incrementally(self,
all_input_ids: Sequence[int],
state: DetokenizeState,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True):
"""Incrementally detokenize the input indexes.
Args:
all_input_ids (List[int]): a list of token ids. Expected to be
different sections of a long sequence.
state (DetokenizeState): the incremental decoding state carried
over from the previous round.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Defaults to True.
spaces_between_special_tokens (bool): Whether or not to add spaces
between special tokens. Defaults to True.
Returns:
str: the decoded output string of the current round.
state (DetokenizeState): the updated decoding state to pass into
the next round.
"""
out_string = self.model.Decode(all_input_ids)
if state.prev_tokens is not None:
out_string = self._maybe_add_prefix_space(all_input_ids,
out_string)
state.prev_tokens = []  # make non-None so later rounds take the branch above
return out_string, state

def __call__(self, s: Union[str, Sequence[str]]):
"""Tokenize prompts.
@@ -289,9 +343,117 @@ def decode(self, t: Sequence[int], offset: Optional[int] = None):
out_string = self.model.decode(t,
skip_special_tokens=skip_special_tokens)
if offset:
logger = get_logger('lmdeploy')
logger.warning('For incremental detokenization, please try '
'the detokenize_incrementally function instead.')
out_string = self._maybe_add_prefix_space(t, out_string)
return out_string

@staticmethod
def _convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens: List[str],
skip_special_tokens: bool,
spaces_between_special_tokens: bool,
) -> str:
if tokenizer.is_fast or not tokenizer.get_added_vocab():
return tokenizer.convert_tokens_to_string(output_tokens)
# Adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.7/vllm/transformers_utils/tokenizer.py#L68-L99
sub_texts = []
current_sub_text = []
all_special_tokens = set(tokenizer.all_special_tokens)
for token in output_tokens:
if skip_special_tokens and token in all_special_tokens:
continue
if token in tokenizer.get_added_vocab():
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(
current_sub_text)
sub_texts.append(sub_text)
current_sub_text = []
sub_texts.append(token)
else:
current_sub_text.append(token)
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
if spaces_between_special_tokens:
return ' '.join(sub_texts)
else:
return ''.join(sub_texts)

# Based on
# https://github.com/vllm-project/vllm/blob/v0.2.7/vllm/transformers_utils/tokenizer.py#L105-L165
def detokenize_incrementally(self,
all_input_ids: Sequence[int],
state: DetokenizeState,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True):
"""Incrementally detokenize the input indexes.
Args:
all_input_ids (List[int]): a list of token ids. Expected to be
different sections of a long sequence.
state (DetokenizeState): the incremental decoding state carried
over from the previous round.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Defaults to True.
spaces_between_special_tokens (bool): Whether or not to add spaces
between special tokens. Defaults to True.
Returns:
str: the decoded output string of the current round.
state (DetokenizeState): the updated decoding state to pass into
the next round.
"""
tokenizer = self.model
ids_offset, prev_tokens, prefix_offset, read_offset = state.as_tuple()
new_tokens = tokenizer.convert_ids_to_tokens(
all_input_ids[ids_offset:],
skip_special_tokens=skip_special_tokens)
# prev_tokens being None means this is the first iteration for the sequence
if prev_tokens is None:
# Note that vLLM detokenizes ids one at a time, while in LMDeploy
# each round may produce a different number of new ids.
if skip_special_tokens and new_tokens[
0] in tokenizer.all_special_tokens:
read_offset = 1  # skip special token
output_tokens = new_tokens
prev_tokens = new_tokens
else:
# Append the newly decoded tokens to those from previous rounds
output_tokens = prev_tokens + new_tokens
prev_tokens += new_tokens

prefix_text = self._convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens[prefix_offset:read_offset],
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
new_text = self._convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens[prefix_offset:],
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)

# update state and get final decoded output
if len(new_text) > len(prefix_text) and not new_text.endswith('�'):
# utf-8 char at the end means it's a potential unfinished byte
# sequence from byte fallback tokenization.
# If it's in the middle, it's probably a real invalid id generated
# by the model
prefix_offset = read_offset
read_offset = len(output_tokens)
new_text = new_text[len(prefix_text):]
else:
new_text = ''

return new_text, DetokenizeState(len(all_input_ids), prev_tokens,
prefix_offset, read_offset)
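
To make the prefix_offset/read_offset bookkeeping above concrete, here is a small self-contained toy that applies the same hold-back technique to a fake word-level 'tokenizer' (plain string join); it illustrates the algorithm and is not lmdeploy code:

def toy_incremental(all_tokens, prefix_offset, read_offset):
    to_str = ' '.join                        # stand-in for real detokenization
    prefix_text = to_str(all_tokens[prefix_offset:read_offset])
    new_text = to_str(all_tokens[prefix_offset:])
    if len(new_text) > len(prefix_text) and not new_text.endswith('�'):
        # emit only the newly completed suffix and slide the window forward
        return new_text[len(prefix_text):], read_offset, len(all_tokens)
    return '', prefix_offset, read_offset    # hold back incomplete text

prefix, read, out = 0, 0, ''
for tokens in (['Hello'], ['Hello', 'world'], ['Hello', 'world', '!']):
    piece, prefix, read = toy_incremental(tokens, prefix, read)
    out += piece
print(out)   # Hello world !

The real method additionally skips special tokens and relies on the trailing '�' check to hold back incomplete byte-fallback sequences until they can be decoded as valid UTF-8.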

def __call__(self, s: Union[str, Sequence[str]]):
"""Tokenize prompts.
@@ -365,6 +527,33 @@ def decode(self, t: Sequence[int], offset: Optional[int] = None):
"""
return self.model.decode(t, offset)

def detokenize_incrementally(self,
all_input_ids: Sequence[int],
state: DetokenizeState,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True):
"""Incrementally detokenize the input indexes.
Args:
all_input_ids (List[int]): a list of token ids. Expected to be
different sections of a long sequence.
state (DetokenizeState): the incremental decoding state carried
over from the previous round.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Defaults to True.
spaces_between_special_tokens (bool): Whether or not to add spaces
between special tokens. Defaults to True.
Returns:
str: the decoded output string of the current round.
state (DetokenizeState): the updated decoding state to pass into
the next round.
"""
return self.model.detokenize_incrementally(
all_input_ids,
state=state,
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens)

def __call__(self, s: Union[str, Sequence[str]]):
"""Tokenize prompts.