From 529e56bdf50250bde8c6c89f2356875c2dd1a06a Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Wed, 8 Nov 2023 11:54:05 +0800 Subject: [PATCH] fix benchmark serving computation mistake (#630) * fix benchmark serving computation mistake * fix timestamps computations * remove speed up * no mp * mp seems faster? * remove * update * remove * fix * update * update print log * typo * print fist token latency only stream==True * remove renew_session * update AsyncEngine --- benchmark/profile_restful_api.py | 102 +++++++++++++------- benchmark/profile_serving.py | 65 ++++++++----- benchmark/profile_throughput.py | 151 ++++++++++++++++++++---------- lmdeploy/serve/async_engine.py | 38 ++++---- lmdeploy/serve/openai/protocol.py | 2 - 5 files changed, 233 insertions(+), 125 deletions(-) diff --git a/benchmark/profile_restful_api.py b/benchmark/profile_restful_api.py index b827db16d..394c7ec1b 100644 --- a/benchmark/profile_restful_api.py +++ b/benchmark/profile_restful_api.py @@ -1,47 +1,56 @@ import json -import multiprocessing as mp import random import time +from queue import Queue +from threading import Thread import fire import numpy as np from lmdeploy.serve.openai.api_client import get_streaming_response from lmdeploy.tokenizer import Tokenizer -from lmdeploy.utils import get_logger -def infer(server_addr: str, session_id: int, req_queue: mp.Queue, - res_que: mp.Queue): +def infer(server_addr: str, session_id: int, req_queue: Queue, res_que: Queue, + stream_output: bool): stats = [] - while not req_queue.empty(): - prompt, input_seqlen, output_seqlen = req_queue.get() - get_logger('profile_restful_api').info( - f'request info: session {session_id}, ' - f'input_seqlen {input_seqlen}, output_seqlen {output_seqlen}') + for prompt, input_seqlen, output_seqlen in iter(req_queue.get, + [None, None, None]): + if prompt is None: + break timestamps = [] tokens = [] - start = time.perf_counter() + timestamps.append(time.perf_counter()) for res, token, status in get_streaming_response( prompt, server_addr, session_id, request_output_len=output_seqlen, - interactive_mode=False): + interactive_mode=False, + ignore_eos=True, + stream=stream_output): timestamps.append(time.perf_counter()) tokens.append(token) - first_token_latency = timestamps[1] - start - token_latency = timestamps[-1] - timestamps[0] - token = tokens[-1] - tokens[0] - stats.append([first_token_latency, token, token_latency]) + first_token_latency = np.round(timestamps[1] - timestamps[0], 3) + token_latency = np.round(timestamps[-1] - timestamps[0], 3) + completion_tokens = tokens[-1] + total_tokens = tokens[-1] + input_seqlen + stats.append([ + first_token_latency, completion_tokens, output_seqlen, + total_tokens, token_latency + ]) + print(f'session {session_id}: ' + f'input_seqlen {input_seqlen}, output_seqlen {output_seqlen}, ' + f'completion_tokens {completion_tokens}') res_que.put((session_id, stats)) def warmup(server_addr: str, concurrency: int, output_seqlen: int, - warmup_round: int = 1): + warmup_round: int = 1, + stream_output: bool = False): print('start to warmup ...') def _infer(server_addr, session_id): @@ -50,13 +59,15 @@ def _infer(server_addr, session_id): server_addr, session_id, request_output_len=output_seqlen, - interactive_mode=False): + interactive_mode=False, + stream=stream_output, + ignore_eos=True): continue _start = time.perf_counter() procs = [] for i in range(concurrency): - proc = mp.Process(target=_infer, args=(server_addr, i + 1)) + proc = Thread(target=_infer, 
args=(server_addr, i + 1)) procs.append(proc) proc.start() for proc in procs: @@ -79,6 +90,7 @@ def read_dataset(tokenizer_path: str, dataset_path: str, samples: int, print(f'elapsed time for read data: ' f'{round(time.perf_counter() - start, 2)} s') + print('start tokenization. This takes a while, please wait...') start = time.perf_counter() tokenizer = Tokenizer(tokenizer_path) prompts_token_lens = [len(tokenizer.encode(prompt)) for prompt in prompts] @@ -100,9 +112,10 @@ def read_dataset(tokenizer_path: str, dataset_path: str, samples: int, if samples > 0: filtered_dataset = random.sample(filtered_dataset, samples) - que = mp.Queue() + que = Queue() for data in filtered_dataset: que.put(data) + que.put((None, None, None)) print(f'elapsed time for filtering: ' f'{round(time.perf_counter() - start, 2)} s') return que, len(filtered_dataset) @@ -113,17 +126,20 @@ def main(server_addr: str, dataset_path: str, concurrency: int = 1, session_len: int = 2048, - samples: int = 1000): + samples: int = 1000, + stream_output: bool = False): api_url = server_addr + '/v1/chat/interactive' - warmup(api_url, concurrency, session_len - 1) + warmup(api_url, concurrency, session_len - 1, 4, stream_output) req_queue, n_req = read_dataset(tokenizer_path, dataset_path, samples, session_len) - res_que = mp.Queue() + for i in range(concurrency): + req_queue.put([None, None, None]) + res_que = Queue() procs = [] _start = time.perf_counter() for i in range(concurrency): - proc = mp.Process(target=infer, - args=(api_url, i + 1, req_queue, res_que)) + proc = Thread(target=infer, + args=(api_url, i + 1, req_queue, res_que, stream_output)) procs.append(proc) proc.start() for proc in procs: @@ -138,22 +154,40 @@ def main(server_addr: str, f'session {session_id} stats: \n{_stats}\n{"-" * 50}\n') stats.append(np.array(_stats)) - stats = np.concatenate(stats).reshape(-1, 3) + stats = np.concatenate(stats).reshape(-1, 5) first_token_latency_min = np.min(stats[:, 0], axis=0) first_token_latency_max = np.max(stats[:, 0], axis=0) first_token_latency_ave = np.mean(stats[:, 0], axis=0) - token_throughput = np.sum(stats[:, 1], axis=0) / elapsed_time - req_throughput = n_req / elapsed_time + completion_tokens = np.sum(stats[:, 1], axis=0) + request_output_tokens = np.sum(stats[:, 2], axis=0) + total_tokens = np.sum(stats[:, 3], axis=0) + prompt_tokens = total_tokens - completion_tokens + completion_token_throughput = completion_tokens / elapsed_time + total_token_throughput = total_tokens / elapsed_time + rqs = n_req / elapsed_time + rqm = rqs * 60 + + if (np.abs(stats[:, 1] - stats[:, 2]) <= 1).min() is False: + print(f'Did not generate requested number of tokens. 
' + f'Request {request_output_tokens:.0f}, ' + f'but got {completion_tokens:.0f}') print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n' - f'elapsed_time: {elapsed_time:.2f}s\n' - f'first_token latency(min, max, ave): ' - f'{first_token_latency_min:.2f}s, {first_token_latency_max:.2f}s, ' - f'{first_token_latency_ave:.2f}s\n' - f'token throughput: {token_throughput:.2f} token/s\n' - f'req throughput: {req_throughput:.2f} req/s\n' - f'{"-" * 50}\n') + f'elapsed_time: {elapsed_time:.3f}s\n') + if stream_output: + print(f'first_token latency(min, max, ave): ' + f'{first_token_latency_min:.3f}s, ' + f'{first_token_latency_max:.3f}s, ' + f'{first_token_latency_ave:.3f}s\n') + print( + f'number of prompt tokens: {prompt_tokens:.0f}\n' + f'number of completion tokens: {completion_tokens:.0f}\n' + f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n' # noqa + f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n' # noqa + f'RPS (request per second): {rqs:.3f} req/s\n' + f'RPM (request per minute): {rqm:.3f} req/min\n' + f'{"-" * 50}\n') if __name__ == '__main__': diff --git a/benchmark/profile_serving.py b/benchmark/profile_serving.py index 4580757ee..ee23452d8 100644 --- a/benchmark/profile_serving.py +++ b/benchmark/profile_serving.py @@ -17,7 +17,7 @@ def infer(chatbot, session_id: int, req_que: mp.Queue, res_que: mp.Queue): [None, None, None]): timestamps = [] tokens = [] - start = time.perf_counter() + timestamps.append(time.perf_counter()) for status, res, token in chatbot.stream_infer( session_id, prompt, @@ -26,13 +26,17 @@ def infer(chatbot, session_id: int, req_que: mp.Queue, res_que: mp.Queue): sequence_end=True): timestamps.append(time.perf_counter()) tokens.append(token) - - first_token_latency = np.round(timestamps[1] - start, 3) + first_token_latency = np.round(timestamps[1] - timestamps[0], 3) token_latency = np.round(timestamps[-1] - timestamps[0], 3) - token = tokens[-1] - tokens[0] - stats.append([first_token_latency, token, token_latency]) + completion_tokens = tokens[-1] + total_tokens = tokens[-1] + input_seqlen + stats.append([ + first_token_latency, completion_tokens, output_seqlen, + total_tokens, token_latency + ]) print(f'session {session_id}: ' - f'input_seqlen {input_seqlen}, output_seqlen {output_seqlen}') + f'input_seqlen {input_seqlen}, output_seqlen {output_seqlen}, ' + f'completion_tokens {completion_tokens}') res_que.put((session_id, stats)) @@ -84,6 +88,7 @@ def read_dataset(tokenizer_path: str, dataset_path: str, samples: int, completions = [completion for _, completion in dataset] print(f'elapsed time for read data: ' f'{round(time.perf_counter() - start, 2)} s') + print('start tokenization. 
This takes a while, please wait...') start = time.perf_counter() tokenizer = Tokenizer(tokenizer_path) @@ -124,7 +129,6 @@ def main(tritonserver_addr: str, res_que = mp.Queue() procs = [] - _start = time.perf_counter() for i in range(concurrency): chatbot = Chatbot(tritonserver_addr=tritonserver_addr, display=False, @@ -134,13 +138,15 @@ def main(tritonserver_addr: str, proc = mp.Process(target=infer, args=(chatbot, i + 1, req_que, res_que)) procs.append(proc) - proc.start() # read data and put it to queue n_req = read_dataset(tokenizer_path, dataset_path, samples, session_len, req_que) for i in range(concurrency): req_que.put([None, None, None]) + _start = time.perf_counter() + for proc in procs: + proc.start() stats = [] for i in range(concurrency): @@ -149,27 +155,42 @@ def main(tritonserver_addr: str, f'session {session_id}: processed reqs {len(_stats)}, ' f'stats: \n{_stats}\n{"-" * 50}\n') stats.append(np.array(_stats)) - _end = time.perf_counter() + elapsed_time = _end - _start - stats = np.concatenate(stats).reshape(-1, 3) + stats = np.concatenate(stats).reshape(-1, 5) first_token_latency_min = np.min(stats[:, 0], axis=0) first_token_latency_max = np.max(stats[:, 0], axis=0) first_token_latency_ave = np.mean(stats[:, 0], axis=0) - token_throughput = np.sum(stats[:, 1], axis=0) / elapsed_time - req_throughput = n_req / elapsed_time - - print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n' - f'elapsed_time: {elapsed_time:.3f}s\n' - f'first_token latency(min, max, ave): ' - f'{first_token_latency_min:.3f}s, {first_token_latency_max:.3f}s, ' - f'{first_token_latency_ave:.3f}s\n' - f'token throughput: {token_throughput:.3f} token/s\n' - f'req throughput: {req_throughput:.3f} req/s\n' - f'{"-" * 50}\n') - + completion_tokens = np.sum(stats[:, 1], axis=0) + request_output_tokens = np.sum(stats[:, 2], axis=0) + total_tokens = np.sum(stats[:, 3], axis=0) + prompt_tokens = total_tokens - completion_tokens + completion_token_throughput = completion_tokens / elapsed_time + total_token_throughput = total_tokens / elapsed_time + rqs = n_req / elapsed_time + rqm = rqs * 60 + + if (np.abs(stats[:, 1] - stats[:, 2]) <= 1).min() is False: + print(f'Did not generate requested number of tokens. ' + f'Request {request_output_tokens:.0f}, ' + f'but got {completion_tokens:.0f}') + + print( + f'\n{"-" * 50}\nconcurrency: {concurrency}\n' + f'elapsed_time: {elapsed_time:.3f}s\n' + f'first_token latency(min, max, ave): ' + f'{first_token_latency_min:.3f}s, {first_token_latency_max:.3f}s, ' + f'{first_token_latency_ave:.3f}s\n' + f'number of prompt tokens: {prompt_tokens:.0f}\n' + f'number of completion tokens: {completion_tokens:.0f}\n' + f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n' # noqa + f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n' # noqa + f'RPS (request per second): {rqs:.3f} req/s\n' + f'RPM (request per minute): {rqm:.3f} req/min\n' + f'{"-" * 50}\n') for proc in procs: proc.join() diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index 8fc5090f7..77402b559 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -7,6 +7,7 @@ from typing import List, Tuple import fire +import numpy as np from lmdeploy.tokenizer import Tokenizer from lmdeploy.turbomind import TurboMind @@ -24,8 +25,7 @@ def sample_requests( dataset = [data for data in dataset if len(data['conversations']) >= 2] # Only keep the first two turns of each conversation. 
dataset = [(data['conversations'][0]['value'], - data['conversations'][1]['value']) - for data in dataset][:num_requests * 2] # speed up encoding + data['conversations'][1]['value']) for data in dataset] # Tokenize the prompts and completions. prompts = [prompt for prompt, _ in dataset] @@ -64,80 +64,131 @@ def __init__(self, model_path: str, tp: int = 1): self.tm_model = tm_model self.tokenizer = tokenizer - def _inference(self, queue, session_id: int): - + def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, + stream_output: bool): model_inst = self.tm_model.create_instance() - while True: - request = queue.get() - if request is None: - # stop signal - queue.put(None) - return - else: - prompt, _, output_seqlen = request - input_ids = self.tokenizer.encode(prompt) - - for outputs in model_inst.stream_infer( - session_id, - input_ids=input_ids, - request_output_len=output_seqlen, - temperature=1.0, - top_p=1.0, - sequence_start=True, - sequence_end=True, - ignore_eos=True): - res, tokens = outputs[0] - self.tokenizer.decode(res) - - def process_request(self, requests, concurrency: int = 1): - q = Queue() + stats = [] + timestamps = [] + tokens = [] + timestamps.append(time.perf_counter()) + for prompt, input_seqlen, output_seqlen in iter( + req_queue.get, [None, None, None]): + input_ids = self.tokenizer.encode(prompt) + offset = 0 + for outputs in model_inst.stream_infer( + session_id, + input_ids=input_ids, + request_output_len=output_seqlen, + temperature=1.0, + top_p=1.0, + sequence_start=True, + sequence_end=True, + ignore_eos=True, + stream_output=stream_output): + res, token = outputs[0] + self.tokenizer.decode(res, offset) + offset = token + timestamps.append(time.perf_counter()) + tokens.append(token) + first_token_latency = np.round(timestamps[1] - timestamps[0], 3) + token_latency = np.round(timestamps[-1] - timestamps[0], 3) + completion_tokens = tokens[-1] + total_tokens = tokens[-1] + len(input_ids) + stats.append([ + first_token_latency, completion_tokens, output_seqlen, + total_tokens, token_latency + ]) + print( + f'session {session_id}: ' + f'input_seqlen {input_seqlen}, output_seqlen {output_seqlen}, ' + f'completion_tokens {completion_tokens}') + res_queue.put((session_id, stats)) + + def process_request(self, + requests, + concurrency: int = 1, + stream_output: bool = True): + res_queue = Queue() + req_queue = Queue() threads = [] + # feed request to q + for req in requests: + req_queue.put(req) + for i in range(concurrency): + req_queue.put([None, None, None]) + start = time.time() # start threads for i in range(concurrency): - t = Thread(target=self._inference, args=(q, i)) + t = Thread(target=self._inference, + args=(req_queue, res_queue, i, stream_output)) t.start() threads.append(t) - # feed request to q - for req in requests: - q.put(req) - - q.put(None) - # wait for finish for t in threads: t.join() - end = time.time() - - return end - start + elapsed_time = time.time() - start + + stats = [] + while not res_queue.empty(): + session_id, _stats = res_queue.get() + print(f'\n{"-" * 50}\n' + f'session {session_id} stats: \n{_stats}\n{"-" * 50}\n') + stats.append(np.array(_stats)) + + stats = np.concatenate(stats).reshape(-1, 5) + + first_token_latency_min = np.min(stats[:, 0], axis=0) + first_token_latency_max = np.max(stats[:, 0], axis=0) + first_token_latency_ave = np.mean(stats[:, 0], axis=0) + completion_tokens = np.sum(stats[:, 1], axis=0) + request_output_tokens = np.sum(stats[:, 2], axis=0) + total_tokens = np.sum(stats[:, 3], axis=0) + 
prompt_tokens = total_tokens - completion_tokens + completion_token_throughput = completion_tokens / elapsed_time + total_token_throughput = total_tokens / elapsed_time + rqs = len(requests) / elapsed_time + rqm = rqs * 60 + + if (np.abs(stats[:, 1] - stats[:, 2]) <= 1).min() is False: + print(f'Did not generate requested number of tokens. ' + f'Request {request_output_tokens:.0f}, ' + f'but got {completion_tokens:.0f}') + + print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n' + f'elapsed_time: {elapsed_time:.3f}s\n') + if stream_output: + print(f'first_token latency(min, max, ave): ' + f'{first_token_latency_min:.3f}s, ' + f'{first_token_latency_max:.3f}s, ' + f'{first_token_latency_ave:.3f}s\n') + print( + f'number of prompt tokens: {prompt_tokens:.0f}\n' + f'number of completion tokens: {completion_tokens:.0f}\n' + f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n' # noqa + f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n' # noqa + f'RPS (request per second): {rqs:.3f} req/s\n' + f'RPM (request per minute): {rqm:.3f} req/min\n' + f'{"-" * 50}\n') def main(dataset: str, model_path: str, concurrency: int = 1, num_prompts: int = 1000, - tp: int = 1): + tp: int = 1, + stream_output: bool = True): engine = Engine(model_path, tp=tp) tokenizer = engine.tokenizer requests = sample_requests(dataset, num_prompts, tokenizer) - elapsed_time = engine.process_request(requests, concurrency) - total_num_tokens = sum(prompt_len + output_len - for _, prompt_len, output_len in requests) - total_num_out_tokens = sum(output_len for _, _, output_len in requests) - print(f'Throughput requests: {len(requests) / elapsed_time:.2f} req/s') - print( - f'Throughput requests: {len(requests) * 60 / elapsed_time:.2f} req/min' - ) - print(f'Throughput tokens: {total_num_tokens / elapsed_time:.2f} tokens/s') - print('Throughput tokens(output only):' - f'{total_num_out_tokens / elapsed_time:.2f} tokens/s') + engine.process_request(requests, concurrency, stream_output) if __name__ == '__main__': diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 98c53bbf2..5abae0d97 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -107,6 +107,7 @@ def batch_infer(self, temperature=0.8, repetition_penalty=1.0, ignore_eos=False, + do_preprocess=True, **kwargs): """Inference a batch of prompts. @@ -122,6 +123,7 @@ def batch_infer(self, repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty ignore_eos (bool): indicator for ignoring eos + do_preprocess (bool): whether pre-process the messages. 
""" assert isinstance(prompts, List), 'prompts should be a list' batch_size = len(prompts) @@ -139,7 +141,9 @@ def batch_infer(self, top_p=top_p, temperature=temperature, ignore_eos=ignore_eos, - repetition_penalty=repetition_penalty)) + repetition_penalty=repetition_penalty, + do_preprocess=do_preprocess, + **kwargs)) async def _inner_call(i, generator): async for out in generator: @@ -153,22 +157,22 @@ async def gather(): return outputs async def generate( - self, - messages, - session_id, - stream_response=True, - sequence_start=True, - sequence_end=True, # no interactive mode by default - step=0, - request_output_len=512, - stop=False, - top_k=40, - top_p=0.8, - temperature=0.8, - repetition_penalty=1.0, - ignore_eos=False, - do_preprocess=True, - ): + self, + messages, + session_id, + stream_response=True, + sequence_start=True, + sequence_end=True, # no interactive mode by default + step=0, + request_output_len=512, + stop=False, + top_k=40, + top_p=0.8, + temperature=0.8, + repetition_penalty=1.0, + ignore_eos=False, + do_preprocess=True, + **kwargs): """Generate responses. Args: diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py index b39e6cbc8..bee2e2c91 100644 --- a/lmdeploy/serve/openai/protocol.py +++ b/lmdeploy/serve/openai/protocol.py @@ -71,8 +71,6 @@ class ChatCompletionRequest(BaseModel): # additional argument of lmdeploy repetition_penalty: Optional[float] = 1.0 session_id: Optional[int] = -1 - renew_session: Optional[ - bool] = False # lagecy and useless, will be removed ignore_eos: Optional[bool] = False
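Reviewer note on the corrected accounting: each worker now appends one five-column row per request, [first_token_latency, completion_tokens, output_seqlen, total_tokens, token_latency], and the summary metrics are derived from column sums over the elapsed wall-clock time rather than from differences of token indices. A minimal standalone sketch of the same arithmetic (toy numbers, not from a real run; the sanity check is phrased with np.all here):

import numpy as np

# one row per finished request:
# [first_token_latency, completion_tokens, output_seqlen, total_tokens, token_latency]
stats = np.array([
    [0.021, 128., 128., 160., 1.912],  # toy numbers, not real measurements
    [0.019, 256., 256., 300., 3.541],
])
elapsed_time = 4.2  # wall clock covering all requests, in seconds
n_req = len(stats)

# first-token latency is only meaningful when streaming, since the first
# timestamp is taken right before the streaming loop starts
first_token_latency_ave = np.mean(stats[:, 0])

completion_tokens = np.sum(stats[:, 1])
request_output_tokens = np.sum(stats[:, 2])
total_tokens = np.sum(stats[:, 3])
prompt_tokens = total_tokens - completion_tokens

completion_token_throughput = completion_tokens / elapsed_time
total_token_throughput = total_tokens / elapsed_time
rqs = n_req / elapsed_time  # requests per second
rqm = rqs * 60              # requests per minute

# warn when a request produced noticeably fewer tokens than it asked for
if not np.all(np.abs(stats[:, 1] - stats[:, 2]) <= 1):
    print(f'Did not generate requested number of tokens. '
          f'Request {request_output_tokens:.0f}, but got {completion_tokens:.0f}')

print(f'prompt tokens: {prompt_tokens:.0f}, '
      f'completion tokens: {completion_tokens:.0f}\n'
      f'completion token throughput: {completion_token_throughput:.3f} token/s\n'
      f'total token throughput: {total_token_throughput:.3f} token/s\n'
      f'RPS: {rqs:.3f} req/s, RPM: {rqm:.3f} req/min')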
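The three benchmark scripts also share the same producer/consumer shape after this change: requests are pushed onto a Queue, one [None, None, None] sentinel is enqueued per worker, and each worker thread drains the queue with iter(queue.get, sentinel) before reporting its per-session stats through a result queue. A self-contained sketch of that pattern (the worker/run helpers are illustrative only, with a short sleep standing in for the real streaming call):

import time
from queue import Queue
from threading import Thread

SENTINEL = [None, None, None]


def worker(session_id, req_queue, res_queue):
    stats = []
    # iter(get, SENTINEL) keeps pulling requests until this worker's sentinel arrives
    for prompt, input_seqlen, output_seqlen in iter(req_queue.get, SENTINEL):
        start = time.perf_counter()
        time.sleep(0.001)  # placeholder for the real streaming request
        stats.append([session_id, input_seqlen, output_seqlen,
                      time.perf_counter() - start])
    res_queue.put((session_id, stats))


def run(requests, concurrency):
    req_queue, res_queue = Queue(), Queue()
    for req in requests:
        req_queue.put(req)
    for _ in range(concurrency):  # one sentinel per worker
        req_queue.put(SENTINEL)
    start = time.perf_counter()
    threads = [Thread(target=worker, args=(i + 1, req_queue, res_queue))
               for i in range(concurrency)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    elapsed = time.perf_counter() - start
    results = [res_queue.get() for _ in range(concurrency)]
    return elapsed, results


if __name__ == '__main__':
    reqs = [('hello', 4, 16)] * 8  # toy (prompt, input_seqlen, output_seqlen) tuples
    elapsed, results = run(reqs, concurrency=2)
    print(f'elapsed {elapsed:.3f}s, '
          f'sessions finished: {sorted(sid for sid, _ in results)}')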
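On the AsyncEngine side, batch_infer now forwards do_preprocess and any extra keyword arguments down to generate. A rough usage sketch, assuming the constructor still takes the model path as its first argument (the path and generation settings below are illustrative, not verified against this revision):

from lmdeploy.serve.async_engine import AsyncEngine

engine = AsyncEngine(model_path='/path/to/turbomind/model')  # hypothetical path
outputs = engine.batch_infer(
    ['Hello, how are you?', 'Write a haiku about GPUs.'],
    request_output_len=64,
    do_preprocess=True,  # whether to pre-process the prompts before inference
)
print(outputs)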