From 0ed1e4d49eb433c50cb2ef78ae319507def77104 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Mon, 7 Aug 2023 12:48:17 +0800 Subject: [PATCH] Improve postprocessing in TIS serving by applying Incremental de-tokenizing (#197) * change to incremental decoding * update --- benchmark/profile_serving.py | 23 ++++++++++++----------- lmdeploy/serve/client.py | 1 - lmdeploy/serve/turbomind/chatbot.py | 29 ++++++++++++++++------------- 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/benchmark/profile_serving.py b/benchmark/profile_serving.py index 79e5c8365..b991ee475 100644 --- a/benchmark/profile_serving.py +++ b/benchmark/profile_serving.py @@ -55,7 +55,7 @@ def infer(chatbot, session_id: int, req_que: mp.Queue, res_que: mp.Queue): def warmup(tritonserver_addr: str, concurrency: int, output_seqlen: int, - warmup_round: int = 4): + warmup_round: int = 1): print('start to warmup ...') def _infer(_chatbot, session_id): @@ -87,7 +87,7 @@ def _infer(_chatbot, session_id): def read_dataset(tokenizer_path: str, dataset_path: str, samples: int, - test_round: int, session_len: int): + session_len: int): start = time.perf_counter() with open(dataset_path) as f: dataset = json.load(f) @@ -119,14 +119,12 @@ def read_dataset(tokenizer_path: str, dataset_path: str, samples: int, if samples > 0: filtered_dataset = random.sample(filtered_dataset, samples) - filtered_dataset *= test_round - random.shuffle(filtered_dataset) que = mp.Queue() for data in filtered_dataset: que.put(data) print(f'elapsed time for filtering: ' f'{round(time.perf_counter() - start, 2)} s') - return que + return que, len(filtered_dataset) def main(tritonserver_addr: str, @@ -134,11 +132,10 @@ def main(tritonserver_addr: str, dataset_path: str, concurrency: int = 1, session_len: int = 2048, - samples: int = 1000, - test_round: int = 1): + samples: int = 1000): warmup(tritonserver_addr, concurrency, session_len - 1) - req_que = read_dataset(tokenizer_path, dataset_path, samples, test_round, - session_len) + req_que, n_req = read_dataset(tokenizer_path, dataset_path, samples, + session_len) res_que = mp.Queue() procs = [] _start = time.perf_counter() @@ -168,13 +165,17 @@ def main(tritonserver_addr: str, first_token_latency_min = np.min(stats[:, 0], axis=0) first_token_latency_max = np.max(stats[:, 0], axis=0) first_token_latency_ave = np.mean(stats[:, 0], axis=0) - throughput = np.sum(stats[:, 1], axis=0) / elapsed_time + token_throughput = np.sum(stats[:, 1], axis=0) / elapsed_time + req_throughput = n_req / elapsed_time + print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n' f'elapsed_time: {elapsed_time:.2f}s\n' f'first_token latency(min, max, ave): ' f'{first_token_latency_min:.2f}s, {first_token_latency_max:.2f}s, ' f'{first_token_latency_ave:.2f}s\n' - f'throughput: {throughput:.2f} token/s\n{"-" * 50}') + f'token throughput: {token_throughput:.2f} token/s\n' + f'req throughput: {req_throughput} req/s\n' + f'{"-" * 50}\n') if __name__ == '__main__': diff --git a/lmdeploy/serve/client.py b/lmdeploy/serve/client.py index 23876ac82..eabe4d3c1 100644 --- a/lmdeploy/serve/client.py +++ b/lmdeploy/serve/client.py @@ -20,7 +20,6 @@ def main(tritonserver_addr: str, session_id: int = 1): Args: tritonserver_addr (str): the address in format "ip:port" of triton inference server - model_name (str): the name of the deployed model session_id (int): the identical id of a session """ log_level = os.environ.get('SERVICE_LOG_LEVEL', 'WARNING') diff --git a/lmdeploy/serve/turbomind/chatbot.py b/lmdeploy/serve/turbomind/chatbot.py index 
ff253e49d..05bb24869 100644 --- a/lmdeploy/serve/turbomind/chatbot.py +++ b/lmdeploy/serve/turbomind/chatbot.py @@ -26,6 +26,7 @@ class Session: request_id: str = '' histories: str = '' # history conversations of the session sequence_length: int = 0 # the total generated token number in the session + sequence_offset: int = 0 # the new generated token offset in the session prompt: str = '' response: str = '' status: int = None # status of the session @@ -539,14 +540,15 @@ def stream_consumer(postprocess, res_queue, session, n_input_token, Yields: tuple: status, text, generated token number """ - offset = n_input_token + preseq_length + session.sequence_offset = n_input_token + preseq_length + sentinel = n_input_token + preseq_length status, res, n_token = None, '', 0 while True: result = res_queue.get() if result is None: status = StatusCode.TRITON_STREAM_END res = session.response - n_token = session.sequence_length - offset + n_token = session.sequence_length - sentinel session.status = StatusCode.TRITON_STREAM_END break if 'errcode' in result: @@ -569,30 +571,31 @@ def stream_consumer(postprocess, res_queue, session, n_input_token, output_ids = result.as_numpy('output_ids') session.sequence_length = sequence_length.squeeze() - sequence_length = sequence_length - offset + new_token_length = sequence_length - session.sequence_offset last_token_id = output_ids[-1][-1][session.sequence_length - 1] if last_token_id == eos_id: session.sequence_length = session.sequence_length - 1 - sequence_length = sequence_length - 1 + new_token_length = new_token_length - 1 output_ids = output_ids.reshape((1, 1, output_ids.shape[-1])) - sequence_length = sequence_length.reshape( - (1, sequence_length.shape[-1])) + new_token_length = new_token_length.reshape( + (1, new_token_length.shape[-1])) if profile_generation: yield (StatusCode.TRITON_STREAM_ING, 'postprocessing is ignored during profiling ' - 'token generation', sequence_length.squeeze()) + 'token generation', new_token_length.squeeze()) continue - output_str = postprocess(output_ids[:, :, offset:], - sequence_length) + output_str = postprocess( + output_ids[:, :, session.sequence_offset:], + new_token_length) + session.sequence_offset = session.sequence_length text = output_str[0].decode() if display: - new_text = text[len(session.response):] - print(new_text, end='', flush=True) - session.response = text + print(text, end='', flush=True) + session.response += text yield (StatusCode.TRITON_STREAM_ING, session.response, - sequence_length.squeeze()) + session.sequence_offset - sentinel) except Exception as e: logger.error(f'catch exception: {e}')
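
Note: the chatbot.py hunks above switch stream_consumer from re-decoding the whole output on every streaming step to incremental de-tokenizing: the session remembers sequence_offset, only the ids generated since the previous step are passed to postprocess, and the decoded piece is appended to session.response. The sketch below is not part of the diff and only illustrates the idea; the HuggingFace AutoTokenizer and the model name are illustrative assumptions, not the actual Triton postprocessing model.

# Minimal sketch of incremental de-tokenizing, assuming a HuggingFace tokenizer
# stands in for the Triton "postprocessing" model used by the chatbot.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('internlm/internlm-chat-7b')  # illustrative model name


class IncrementalDecoder:

    def __init__(self, n_input_token: int):
        # start past the prompt tokens, mirroring `n_input_token + preseq_length`
        self.offset = n_input_token
        self.response = ''

    def step(self, output_ids: list, sequence_length: int) -> str:
        # decode only the slice of newly generated ids, then advance the offset
        new_text = tokenizer.decode(output_ids[self.offset:sequence_length],
                                    skip_special_tokens=True)
        self.offset = sequence_length
        self.response += new_text
        return new_text

Compared with decoding the full sequence and slicing off the already-seen prefix (the removed text[len(session.response):] logic), this keeps per-step postprocessing cost proportional to the number of new tokens rather than to the full response length.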
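
Note: profile_serving.py now reports two figures instead of one — token throughput (sum of generated tokens over wall-clock time) and request throughput (finished requests over wall-clock time, using the n_req count returned by read_dataset). The sketch below is not part of the diff; it uses made-up example values to show how both figures are derived from the per-request stats array.

# Sketch with fabricated example values; each stats row is
# (first_token_latency_s, generated_token_count) for one finished request.
import numpy as np

stats = np.array([[0.21, 412.0], [0.18, 380.0], [0.25, 497.0]])
elapsed_time = 30.0     # wall-clock seconds for the whole benchmark run
n_req = stats.shape[0]  # requests drained from the queue

token_throughput = np.sum(stats[:, 1]) / elapsed_time  # generated tokens per second
req_throughput = n_req / elapsed_time                   # completed requests per second
print(f'token throughput: {token_throughput:.2f} token/s\n'
      f'req throughput: {req_throughput:.2f} req/s')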