detokenize with prompt token ids (#1753)
AllentDan authored Jun 22, 2024
1 parent 4067cb2 commit fd0cefb
Showing 6 changed files with 19 additions and 19 deletions.
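
All six call sites now follow the same pattern: the DetokenizeState is seeded with the prompt length, and the prompt token ids are prepended to the generated ids on every incremental detokenization. The sketch below shows that pattern in isolation; the import path, the model directory, and the hard-coded "generated" ids are assumptions for illustration only, since in the repository the ids come from stream_infer() / async_stream_infer().

    # Minimal sketch of the calling pattern this commit standardizes.
    # Assumptions: import path, model directory, and the fake generated ids.
    from lmdeploy.tokenizer import DetokenizeState, Tokenizer

    tokenizer = Tokenizer('/path/to/model')            # hypothetical model directory
    input_ids = tokenizer('Hello, world!').input_ids   # prompt token ids

    # Seed the state with the prompt length, as every call site in this diff now does.
    state = DetokenizeState(len(input_ids))

    response = ''
    generated = []                       # stand-in for outputs.token_ids
    for new_id in [42, 43, 44]:          # fake generated ids, for illustration only
        generated.append(new_id)
        # Mirror the call sites: pass prompt ids + generated ids so far.
        res = input_ids + generated
        text, state = tokenizer.detokenize_incrementally(res, state)
        response += text

Because the state records the prompt length (exposed as state.ids_offset in async_engine.py), the prompt text is not re-emitted; each call returns only the text contributed by the ids past that offset.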
benchmark/profile_throughput.py (4 changes: 2 additions & 2 deletions)
@@ -92,11 +92,11 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
         for prompt, input_seqlen, output_seqlen in iter(
                 req_queue.get, [None, None, None]):
             _per_token_latency_stats = [0] * (output_seqlen + 1)
-            state = DetokenizeState()
             prev = time.perf_counter()
             n_prev_token = 0
 
             input_ids = self.tokenizer(prompt).input_ids
+            state = DetokenizeState(len(input_ids))
 
             for outputs in model_inst.stream_infer(
                     session_id,
@@ -110,7 +110,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
                     sequence_start=True,
                     sequence_end=True,
                     stream_output=stream_output):
-                res, n_token = outputs.token_ids, outputs.num_token
+                res, n_token = input_ids + outputs.token_ids, outputs.num_token
                 _, state = self.tokenizer.detokenize_incrementally(res, state)
                 now = time.perf_counter()
                 if n_prev_token != n_token:
lmdeploy/pytorch/chat.py (6 changes: 3 additions & 3 deletions)
@@ -113,15 +113,15 @@ def run_chat(model_path: str,
                       ' Please end the session.')
                 continue
 
-            print(f'{prompt} ', end='', flush=True)
-            state = DetokenizeState()
+            print(f'{prompt}', end='', flush=True)
+            state = DetokenizeState(len(input_ids))
             gen_config.random_seed = seed
             gen_config.stop_words = stop_words
             for outputs in generator.stream_infer(session_id=session_id,
                                                   input_ids=input_ids,
                                                   gen_config=gen_config,
                                                   adapter_name=adapter_name):
-                res, tokens = outputs.token_ids, outputs.num_token
+                res, tokens = input_ids + outputs.token_ids, outputs.num_token
                 # decode res
                 response, state = tokenizer.detokenize_incrementally(
                     res, state)
lmdeploy/serve/async_engine.py (4 changes: 2 additions & 2 deletions)
@@ -615,7 +615,7 @@ async def generate(
         else:
             generator = await self.get_generator(False, session_id)
         async with self.safe_run(session_id):
-            state = DetokenizeState()
+            state = DetokenizeState(len(input_ids))
             response = ''
             async for outputs in generator.async_stream_infer(
                     session_id=session_id,
@@ -627,7 +627,7 @@ async def generate(
                     sequence_end=sequence_end,
                     step=self.id2step[str(session_id)]):
                 # decode res
-                res, tokens = outputs.token_ids, outputs.num_token
+                res, tokens = input_ids + outputs.token_ids, outputs.num_token  # noqa
                 if len(res) <= state.ids_offset:
                     continue
 
lmdeploy/serve/gradio/vl.py (4 changes: 2 additions & 2 deletions)
@@ -133,15 +133,15 @@ def chat(chatbot, session, max_new_tokens, top_p, top_k, temperature):
                                   top_k=top_k,
                                   temperature=temperature)
     step = session.step
-    state = DetokenizeState()
+    state = DetokenizeState(len(input_ids))
     for outputs in generator.stream_infer(
             session_id=session._session_id,
             **inputs,
             sequence_start=sequence_start,
             step=step,
             gen_config=gen_config,
             stream_output=True):
-        res, tokens = outputs.token_ids, outputs.num_token
+        res, tokens = input_ids + outputs.token_ids, outputs.num_token
         response, state = engine.tokenizer.detokenize_incrementally(
             res,
             state,
lmdeploy/tokenizer.py (14 changes: 7 additions & 7 deletions)
@@ -439,16 +439,16 @@ def detokenize_incrementally(self,
             # Please notice that in VLLM, indexes are detokenized one by one
             # while in LMDeploy, every turn, the detokenized indexes length
             # can be different.
+            prev_tokens = tokenizer.convert_ids_to_tokens(
+                all_input_ids[:ids_offset],
+                skip_special_tokens=skip_special_tokens)
+            read_offset = len(prev_tokens)
             if skip_special_tokens and new_tokens and new_tokens[
                     0] in tokenizer.all_special_ids:
-                read_offset = 1  # skip special token
-            output_tokens = new_tokens
-            prev_tokens = new_tokens
-        else:
-            # Put new_token_id in a list so skip_special_tokens is respected
-            output_tokens = prev_tokens + new_tokens
-            prev_tokens += new_tokens
+                read_offset = read_offset + 1  # skip special token
 
+        output_tokens = prev_tokens + new_tokens
+        prev_tokens += new_tokens
         prefix_text = self._convert_tokens_to_string_with_added_encoders(
             tokenizer,
             output_tokens[prefix_offset:read_offset],
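The tokenizer.py hunk is where the prompt ids are actually used: on the first incremental call (when prev_tokens is None), the prompt portion all_input_ids[:ids_offset] is converted to tokens to seed prev_tokens, and read_offset starts at the prompt length, so the new ids are decoded in the context of the prompt (preserving spacing and byte merging) while only the newly generated text is returned. The standalone sketch below reproduces that idea with a plain Hugging Face tokenizer; the gpt2 checkpoint and the pretend generated ids are assumptions for illustration, not the repository's implementation.

    # Standalone illustration (assumed, simplified): decode a generated id in the
    # context of the prompt tokens, then strip the prompt prefix via read_offset.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained('gpt2')                    # hypothetical choice

    prompt_ids = tok('Hello world').input_ids                      # prompt token ids
    gen_ids = tok(' again', add_special_tokens=False).input_ids    # pretend these were generated

    ids_offset = len(prompt_ids)                 # what DetokenizeState(len(input_ids)) records
    all_ids = prompt_ids + gen_ids               # what the call sites now pass in

    prev_tokens = tok.convert_ids_to_tokens(all_ids[:ids_offset])  # seeded from the prompt
    new_tokens = tok.convert_ids_to_tokens(all_ids[ids_offset:])
    read_offset = len(prev_tokens)               # skip the prompt text in the returned string

    output_tokens = prev_tokens + new_tokens
    full_text = tok.convert_tokens_to_string(output_tokens)
    prefix_text = tok.convert_tokens_to_string(output_tokens[:read_offset])
    print(full_text[len(prefix_text):])          # -> ' again', only the newly generated text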
lmdeploy/turbomind/chat.py (6 changes: 3 additions & 3 deletions)
@@ -122,8 +122,8 @@ def main(model_path: str,
             sequence_start, sequence_end = True, True
             step = 0
 
-            print(f'{prompt} ', end='', flush=True)
-            state = DetokenizeState()
+            print(f'{prompt}', end='', flush=True)
+            state = DetokenizeState(len(input_ids))
             for outputs in generator.stream_infer(
                     session_id=session_id,
                     input_ids=[input_ids],
@@ -134,7 +134,7 @@ def main(model_path: str,
                     gen_config=gen_config,
                     ignore_eos=False,
                     random_seed=seed if nth_round == 1 else None):
-                res, tokens = outputs.token_ids, outputs.num_token
+                res, tokens = input_ids + outputs.token_ids, outputs.num_token
                 # decode res
                 response, state = tokenizer.detokenize_incrementally(
                     res, state=state)
