Decode generated token_ids incrementally (#309)
* add incremental decoding for turbomind

* update TIS

* fix triton post processing

* update doc

* fix typo

* SentencePieceTokenizer incremental decode, add qwen message prompt

* docstring

* update bot
AllentDan authored Sep 1, 2023
1 parent 22e8b2c commit 9bfe03c
Showing 8 changed files with 135 additions and 34 deletions.
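The diff below switches streaming decode from slicing decoded text by character count to slicing the generated token ids by token count: callers keep the full list of generated ids, pass the number of tokens already consumed as the new offset argument, and the tokenizer decodes only the new suffix, restoring the leading space that SentencePiece drops for word-start pieces. A minimal sketch of that idea (illustrative names, not the commit's own code):

# Minimal sketch of offset-based incremental decoding; the real logic lives
# in lmdeploy/turbomind/tokenizer.py and these names are illustrative.
from typing import Optional, Sequence

from sentencepiece import SentencePieceProcessor


class IncrementalDecoder:

    def __init__(self, model_file: str):
        self.sp = SentencePieceProcessor(model_file=model_file)

    def decode(self, token_ids: Sequence[int], offset: Optional[int] = None) -> str:
        """Decode only the tokens appended after `offset`."""
        new_ids = list(token_ids[offset:])
        text = self.sp.Decode(new_ids)
        # SentencePiece drops the word-boundary marker ('▁') of the first
        # piece when a chunk starts mid-sequence, so restore the space here.
        if offset and new_ids and self.sp.IdToPiece(new_ids[0]).startswith('▁'):
            text = ' ' + text
        return text
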
4 changes: 2 additions & 2 deletions docs/en/restful_api.md
@@ -90,8 +90,8 @@ Generate:
curl http://{server_ip}:{server_port}/generate \
-H "Content-Type: application/json" \
-d '{
"model": "internlm-chat-7b",
"prompt": "Hello! Ho are you?",
"prompt": "Hello! How are you?",
"instance_id": 1,
"sequence_start": true,
"sequence_end": true
}'
4 changes: 2 additions & 2 deletions docs/zh_cn/restful_api.md
@@ -92,8 +92,8 @@ curl http://{server_ip}:{server_port}/v1/models
curl http://{server_ip}:{server_port}/generate \
-H "Content-Type: application/json" \
-d '{
"model": "internlm-chat-7b",
"prompt": "Hello! Ho are you?",
"prompt": "Hello! How are you?",
"instance_id": 1,
"sequence_start": true,
"sequence_end": true
}'
46 changes: 46 additions & 0 deletions lmdeploy/model.py
@@ -256,6 +256,29 @@ def get_prompt(self, prompt, sequence_start=True):
else:
return f'\n{self.user}{prompt}{self.eoh}\n{self.assistant}'

def messages2prompt(self, messages, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
chat template.
Args:
messages (str | List): user's input prompt
sequence_start (bool): flag to start the sequence
Returns:
str: the concatenated prompt
"""
if isinstance(messages, str):
return self.get_prompt(messages, sequence_start)
system, users, assistants = self._translate_messages(messages)
system = self.system if not system else system
ret = f'<BOS>{system}{self.meta_instruction}{self.eosys}'
for user, assistant in zip(users, assistants):
if assistant:
ret += f'\n{self.user}{user}{self.eoh}\n{self.assistant}' \
f'{assistant}'
else:
ret += f'\n{self.user}{user}{self.eoh}\n{self.assistant}'
return ret

@property
def stop_words(self):
"""Return the stop-words' token ids."""
@@ -360,6 +383,29 @@ def get_prompt(self, prompt, sequence_start=True):
return f'\n{self.im_start}user\n{prompt}{self.im_end}' \
f'\n{self.im_start}assistant\n'

def messages2prompt(self, messages, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
chat template.
Args:
messages (str | List): user's input prompt
Returns:
str: the concatenated prompt
"""
if isinstance(messages, str):
return self.get_prompt(messages, sequence_start)
system, users, assistants = self._translate_messages(messages)
system = self.system if not system else system
ret = f'{self.im_start}system\n{system}{self.im_end}'
for user, assistant in zip(users, assistants):
if assistant:
ret += f'\n{self.im_start}user\n{user}{self.im_end}' \
f'\n{self.im_start}assistant\n{assistant}'
else:
ret += f'\n{self.im_start}user\n{user}{self.im_end}' \
f'\n{self.im_start}assistant\n'
return ret

@property
def stop_words(self):
"""Return the stop-words' token ids."""
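The messages2prompt helpers added above accept either a plain string or a list of chat messages. A hypothetical usage sketch; the MODELS registry lookup, the model name, and the OpenAI-style role/content schema are assumptions, not shown in this diff:

# Hypothetical usage of the new messages2prompt helper; the registry name
# and message schema are assumptions for illustration.
from lmdeploy.model import MODELS

chat_template = MODELS.get('qwen-7b')()
messages = [
    {'role': 'user', 'content': 'What does LMDeploy do?'},
    {'role': 'assistant', 'content': 'It serves large language models.'},
    {'role': 'user', 'content': 'How do I stream a reply?'},
]
prompt = chat_template.messages2prompt(messages)
print(prompt)  # concatenated chat-template prompt, ready to encode
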
8 changes: 5 additions & 3 deletions lmdeploy/serve/async_engine.py
@@ -138,12 +138,13 @@ async def generate(
random_seed=seed if sequence_start else None):
res, tokens = outputs[0]
# decode res
response = self.tokenizer.decode(res)[response_size:]
response = self.tokenizer.decode(res.tolist(),
offset=response_size)
# response, history token len,
# input token len, gen token len
yield GenOut(response, self.steps[str(session_id)],
len(input_ids), tokens, finish_reason)
response_size += len(response)
response_size = tokens

# update step
self.steps[str(session_id)] += len(input_ids) + tokens
@@ -229,7 +230,8 @@ async def generate_openai(
random_seed=seed if sequence_start else None):
res, tokens = outputs[0]
# decode res
response = self.tokenizer.decode(res[response_size:])
response = self.tokenizer.decode(res.tolist(),
offset=response_size)
# response, history token len, input token len, gen token len
yield GenOut(response, self.steps[str(session_id)],
len(input_ids), tokens, finish_reason)
29 changes: 13 additions & 16 deletions lmdeploy/serve/turbomind/chatbot.py
@@ -599,14 +599,12 @@ def stream_consumer(postprocess, res_queue, session, n_input_token,
Yields:
tuple: status, text, generated token number
"""
offset = n_input_token + preseq_length
status, res, n_token = None, '', 0
while True:
result = res_queue.get()
if result is None:
status = StatusCode.TRITON_STREAM_END
res = session.response
n_token = session.sequence_length - offset
session.status = StatusCode.TRITON_STREAM_END
break
if 'errcode' in result:
@@ -629,30 +627,29 @@
output_ids = result.as_numpy('output_ids')

session.sequence_length = sequence_length.squeeze()
sequence_length = sequence_length - offset
last_token_id = output_ids[-1][-1][session.sequence_length - 1]
output_ids = output_ids.reshape((1, 1, output_ids.shape[-1]))
output_ids = output_ids[:, :, n_input_token +
preseq_length:sequence_length.squeeze(
)]
last_token_id = output_ids[-1, -1, -1]
if last_token_id == eos_id:
session.sequence_length = session.sequence_length - 1
sequence_length = sequence_length - 1

output_ids = output_ids.reshape((1, 1, output_ids.shape[-1]))
sequence_length = sequence_length.reshape(
(1, sequence_length.shape[-1]))
output_ids = output_ids[:, :, :-1]

if profile_generation:
yield (StatusCode.TRITON_STREAM_ING,
'postprocessing is ignored during profiling '
'token generation', sequence_length.squeeze())
'token generation', output_ids.shape[-1])
continue
output_str = postprocess(output_ids[:, :, offset:],
sequence_length)
output_str = postprocess(
output_ids, np.array([[n_token]], dtype=np.uint32))
n_token = output_ids.shape[-1]
text = output_str[0].decode()
if display:
new_text = text[len(session.response):]
print(new_text, end='', flush=True)
session.response = text
print(text, end='', flush=True)
session.response += text
yield (StatusCode.TRITON_STREAM_ING, session.response,
sequence_length.squeeze())
output_ids.shape[-1])
except Exception as e:
logger.error(f'catch exception: {e}')

Original file line number Diff line number Diff line change
@@ -123,7 +123,7 @@ def _postprocessing(self, tokens_batch, sequence_length):
outputs = []
for beam_tokens, beam_len in zip(tokens_batch, sequence_length):
for tokens, _len in zip(beam_tokens, beam_len):
output = self.tokenizer.decode(tokens[:_len])
output = self.tokenizer.decode(tokens, _len)
output = output.encode('utf8')
outputs.append(output)
return outputs
4 changes: 2 additions & 2 deletions lmdeploy/turbomind/chat.py
@@ -99,10 +99,10 @@ def main(model_path,
random_seed=seed if nth_round == 1 else None):
res, tokens = outputs[0]
# decode res
response = tokenizer.decode(res)[response_size:]
response = tokenizer.decode(res.tolist(), offset=response_size)
response = valid_str(response)
print(f'{response}', end='', flush=True)
response_size += len(response)
response_size = tokens

# update step
step += len(input_ids) + tokens
72 changes: 64 additions & 8 deletions lmdeploy/turbomind/tokenizer.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import Sequence, Union
from typing import Optional, Sequence, Union

import torch

@@ -16,6 +16,7 @@ class SentencePieceTokenizer:
def __init__(self, model_file: str):
from sentencepiece import SentencePieceProcessor
self.model = SentencePieceProcessor(model_file=model_file)
self._no_prefix_space_tokens = None

@property
def vocab_size(self):
@@ -32,6 +33,24 @@ def eos_token_id(self):
"""end of the sentence token id."""
return self.model.eos_id()

@property
def no_prefix_space_tokens(self):
"""tokens without prefix space."""
if self._no_prefix_space_tokens is None:
vocab = self.model.IdToPiece(list(range(self.vocab_size)))
self._no_prefix_space_tokens = {
i
for i, tok in enumerate(vocab) if not tok.startswith('▁')
}
return self._no_prefix_space_tokens

def _maybe_add_prefix_space(self, tokens, decoded):
"""maybe add prefix space for incremental decoding."""
if len(tokens) and tokens[0] not in self.no_prefix_space_tokens:
return ' ' + decoded
else:
return decoded

def encode(self, s: str):
"""Tokenize a prompt.
@@ -50,17 +69,23 @@ def encode(self, s: str):
add_eos = True
return self.model.Encode(s, add_bos=add_bos, add_eos=add_eos)

def decode(self, t: Sequence[int]):
def decode(self, t: Sequence[int], offset: Optional[int] = None):
"""De-tokenize.
Args:
t (List[int]): a list of token ids
offset (int): for incrementally decoding. Default to None, which
means not applied.
Returns:
str: text of decoding tokens
"""
if isinstance(t, torch.Tensor):
t = t.tolist()
return self.model.Decode(t)
t = t[offset:]
out_string = self.model.Decode(t)
if offset:
out_string = self._maybe_add_prefix_space(t, out_string)
return out_string

def __call__(self, s: Union[str, Sequence[str]]):
"""Tokenize prompts.
@@ -86,7 +111,7 @@ class HuggingFaceTokenizer:
"""

def __init__(self, model_dir: str):
from transformers import AutoTokenizer
from transformers import AutoTokenizer, LlamaTokenizerFast
model_file = osp.join(model_dir, 'tokenizer.model')
backend_tokenizer_file = osp.join(model_dir, 'tokenizer.json')
model_file_exists = osp.exists(model_file)
@@ -95,6 +120,8 @@ def __init__(self, model_dir: str):
'It may take long time to initialize the tokenizer.')
self.model = AutoTokenizer.from_pretrained(model_dir,
trust_remote_code=True)
self.need_padding = isinstance(self.model, LlamaTokenizerFast)
self._no_prefix_space_tokens = None
# save tokenizer.json to reuse
if not osp.exists(backend_tokenizer_file) and model_file_exists:
if hasattr(self.model, 'backend_tokenizer'):
@@ -122,6 +149,26 @@ def eos_token_id(self):
"""end of the sentence token id."""
return self.model.eos_token_id

@property
def no_prefix_space_tokens(self):
"""tokens without prefix space."""
if self._no_prefix_space_tokens is None:
vocab = self.model.convert_ids_to_tokens(
list(range(self.vocab_size)))
self._no_prefix_space_tokens = {
i
for i, tok in enumerate(vocab) if not tok.startswith('▁')
}
return self._no_prefix_space_tokens

def _maybe_add_prefix_space(self, tokens, decoded):
"""maybe add prefix space for incremental decoding."""
if self.need_padding and len(
tokens) and tokens[0] not in self.no_prefix_space_tokens:
return ' ' + decoded
else:
return decoded

def encode(self, s: str):
"""Tokenize a prompt.
@@ -139,16 +186,23 @@ def encode(self, s: str):
add_special_tokens = True
return self.model.encode(s, add_special_tokens=add_special_tokens)

def decode(self, t: Sequence[int]):
def decode(self, t: Sequence[int], offset: Optional[int] = None):
"""De-tokenize.
Args:
t (List[int]): a list of token ids
offset (int): for incrementally decoding. Default to None, which
means not applied.
Returns:
str: text of decoding tokens
"""
skip_special_tokens = True
return self.model.decode(t, skip_special_tokens=skip_special_tokens)
t = t[offset:]
out_string = self.model.decode(t,
skip_special_tokens=skip_special_tokens)
if offset:
out_string = self._maybe_add_prefix_space(t, out_string)
return out_string

def __call__(self, s: Union[str, Sequence[str]]):
"""Tokenize prompts.
@@ -211,15 +265,17 @@ def encode(self, s: str):
"""
return self.model.encode(s)

def decode(self, t: Sequence[int]):
def decode(self, t: Sequence[int], offset: Optional[int] = None):
"""De-tokenize.
Args:
t (List[int]): a list of token ids
offset (int): for incrementally decoding. Default to None, which
means not applied.
Returns:
str: text of decoding tokens
"""
return self.model.decode(t)
return self.model.decode(t, offset)

def __call__(self, s: Union[str, Sequence[str]]):
"""Tokenize prompts.
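With the offset parameter in place, a streaming caller only needs to remember how many tokens it has already consumed instead of how many characters it has printed. A consumer-side sketch mirroring lmdeploy/turbomind/chat.py; the workspace path is illustrative:

# Sketch of incremental decoding from the caller's side; the tokenizer path
# is illustrative (point it at a converted turbomind workspace).
from lmdeploy.turbomind.tokenizer import Tokenizer

tokenizer = Tokenizer('./workspace/triton_models/tokenizer')


def print_new_text(generated_ids, consumed: int) -> int:
    """Print only the newly generated suffix and return the new offset."""
    text = tokenizer.decode(generated_ids, offset=consumed)
    print(text, end='', flush=True)
    return len(generated_ids)


# simulate a stream by revealing one more token per step
ids = tokenizer.encode('Hello! How are you?')
consumed = 0
for end in range(1, len(ids) + 1):
    consumed = print_new_text(ids[:end], consumed)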
