diff --git a/benchmark/benchmark_13b.sh b/benchmark/benchmark_13b.sh new file mode 100755 index 0000000000..cdf7355471 --- /dev/null +++ b/benchmark/benchmark_13b.sh @@ -0,0 +1,92 @@ +#!/bin/bash +if [ -z "$1" ] +then + echo "Error. Please input the model path of llama2-13b model" + exit 1 +fi + +workspace_dir=$(dirname $(realpath "$0")) + +tp=1 +model_path="$1" +model_foldername=$(basename "$model_path") +turbomind_model_path="${workspace_dir}"/workspace/"${model_foldername}" + +# convert +lmdeploy convert llama2 ${model_path} --dst-path ${turbomind_model_path} --tp ${tp} +if [ $? != 0 ] +then + exit 1 +fi + +# update recommended config to config.ini +config_path=${turbomind_model_path}/triton_models/weights/config.ini + +apt-get update +apt-get install crudini -y + +crudini --set ${config_path} llama max_context_token_num 4 +crudini --set ${config_path} llama cache_chunk_size -1 +crudini --set ${config_path} llama cache_max_entry_count 500 +crudini --set ${config_path} llama max_batch_size 128 +# end of update config + +cd ${workspace_dir} + +# download dataset +wget -O ShareGPT_V3_unfiltered_cleaned_split.json https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +benchmark_rpm () { + output_path=$1 + mkdir -p "${output_path}" + + batches=(64 128) + for batch in "${batches[@]}" + do + for i in {1..3} + do + python3 profile_throughput.py \ + ShareGPT_V3_unfiltered_cleaned_split.json \ + ${turbomind_model_path} \ + --concurrency "$batch" \ + --num_prompts 3000 \ + --csv ${output_path}/rpm_localhost_batch_"${batch}"_"${i}"th.csv + done + done +} + +benchmark_generation () { + output_path=$1 + mkdir -p "${output_path}" + + python3 profile_generation.py \ + ${turbomind_model_path} \ + --concurrency 1 16 32 64 \ + --csv ${output_path}/generation.csv +} + +################################# BENCHMARK AFTER TUNING GEMM ################################# +# tune gemm +head_num=$(crudini --get "${config_path}" llama head_num) +size_per_head=$(crudini --get "${config_path}" llama size_per_head) +vocab_size=$(crudini --get "${config_path}" llama vocab_size) +inter_size=$(crudini --get "${config_path}" llama inter_size) +tensor_para_size=$(crudini --get "${config_path}" llama tensor_para_size) +max_batch_size=$(crudini --get "${config_path}" llama max_batch_size) + +echo $head_num, $size_per_head, $vocab_size, $inter_size, $tensor_para_size, $max_batch_size + +python3 -m lmdeploy.turbomind.generate_gemm_config \ + --head_num ${head_num} \ + --size_per_head ${size_per_head} \ + --vocab_size ${vocab_size} \ + --inter_size ${inter_size} \ + --tensor_para_size ${tensor_para_size} \ + --max_batch_size ${max_batch_size} + +output_path="${workspace_dir}"/output/"${model_foldername}"-tunned-gemm-tp"${tp}" +# benchmark request throughput and static inference +benchmark_rpm ${output_path} +benchmark_generation ${output_path} + +mv gemm_config.in ${output_path} diff --git a/benchmark/benchmark_20b.sh b/benchmark/benchmark_20b.sh new file mode 100755 index 0000000000..152978c453 --- /dev/null +++ b/benchmark/benchmark_20b.sh @@ -0,0 +1,92 @@ +#!/bin/bash +if [ -z "$1" ] +then + echo "Error. 
Please input the model path of internlm-20b model" + exit 1 +fi + +workspace_dir=$(dirname $(realpath "$0")) + +tp=2 +model_path="$1" +model_foldername=$(basename "$model_path") +turbomind_model_path="${workspace_dir}"/workspace/"${model_foldername}" + +# convert +lmdeploy convert internlm-20b ${model_path} --dst-path ${turbomind_model_path} --tp ${tp} +if [ $? != 0 ] +then + exit 1 +fi + +# update recommended config to config.ini +config_path=${turbomind_model_path}/triton_models/weights/config.ini + +apt-get update +apt-get install crudini -y + +crudini --set ${config_path} llama max_context_token_num 4 +crudini --set ${config_path} llama cache_chunk_size -1 +crudini --set ${config_path} llama cache_max_entry_count 700 +crudini --set ${config_path} llama max_batch_size 128 +# end of update config + +cd ${workspace_dir} + +# download dataset +wget -O ShareGPT_V3_unfiltered_cleaned_split.json https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +benchmark_rpm () { + output_path=$1 + mkdir -p "${output_path}" + + batches=(64 128) + for batch in "${batches[@]}" + do + for i in {1..3} + do + python3 profile_throughput.py \ + ShareGPT_V3_unfiltered_cleaned_split.json \ + ${turbomind_model_path} \ + --concurrency "$batch" \ + --num_prompts 3000 \ + --csv ${output_path}/rpm_localhost_batch_"${batch}"_"${i}"th.csv + done + done +} + +benchmark_generation () { + output_path=$1 + mkdir -p "${output_path}" + + python3 profile_generation.py \ + ${turbomind_model_path} \ + --concurrency 1 16 32 64 \ + --csv ${output_path}/generation.csv +} + +################################# BENCHMARK AFTER TUNING GEMM ################################# +# tune gemm +head_num=$(crudini --get "${config_path}" llama head_num) +size_per_head=$(crudini --get "${config_path}" llama size_per_head) +vocab_size=$(crudini --get "${config_path}" llama vocab_size) +inter_size=$(crudini --get "${config_path}" llama inter_size) +tensor_para_size=$(crudini --get "${config_path}" llama tensor_para_size) +max_batch_size=$(crudini --get "${config_path}" llama max_batch_size) + +echo $head_num, $size_per_head, $vocab_size, $inter_size, $tensor_para_size, $max_batch_size + +python3 -m lmdeploy.turbomind.generate_gemm_config \ + --head_num ${head_num} \ + --size_per_head ${size_per_head} \ + --vocab_size ${vocab_size} \ + --inter_size ${inter_size} \ + --tensor_para_size ${tensor_para_size} \ + --max_batch_size ${max_batch_size} + +output_path="${workspace_dir}"/output/"${model_foldername}"-tunned-gemm-tp"${tp}" +# benchmark request throughput and static inference +benchmark_rpm ${output_path} +benchmark_generation ${output_path} + +cp gemm_config.in ${output_path} diff --git a/benchmark/benchmark_70b.sh b/benchmark/benchmark_70b.sh new file mode 100755 index 0000000000..0f23033064 --- /dev/null +++ b/benchmark/benchmark_70b.sh @@ -0,0 +1,71 @@ +#!/bin/bash +if [ -z "$1" ] +then + echo "Error. Please input the model path of llama2-70b model" + exit 1 +fi + +workspace_dir=$(dirname $(realpath "$0")) + +tp=4 +model_path="$1" +model_foldername=$(basename "$model_path") +turbomind_model_path="${workspace_dir}"/workspace/"${model_foldername}" + +# convert +lmdeploy convert llama2 ${model_path} --dst-path ${turbomind_model_path} --tp ${tp} +if [ $? 
!= 0 ] +then + exit 1 +fi + +# update recommended config to config.ini +config_path=${turbomind_model_path}/triton_models/weights/config.ini + +apt-get update +apt-get install crudini -y + +crudini --set ${config_path} llama max_context_token_num 4 +crudini --set ${config_path} llama cache_chunk_size -1 +crudini --set ${config_path} llama cache_max_entry_count 4000 +crudini --set ${config_path} llama max_batch_size 256 +# end of update config + +cd ${workspace_dir} + +# download dataset +wget -O ShareGPT_V3_unfiltered_cleaned_split.json https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +benchmark_rpm () { + output_path=$1 + mkdir -p "${output_path}" + + batches=(64 128 256) + for batch in "${batches[@]}" + do + for i in {1..3} + do + python3 profile_throughput.py \ + ShareGPT_V3_unfiltered_cleaned_split.json \ + ${turbomind_model_path} \ + --concurrency "$batch" \ + --num_prompts 3000 \ + --csv ${output_path}/rpm_localhost_batch_"${batch}"_"${i}"th.csv + done + done +} + +benchmark_generation () { + output_path=$1 + mkdir -p "${output_path}" + + python3 profile_generation.py \ + ${turbomind_model_path} \ + --concurrency 1 64 128 256 \ + --csv ${output_path}/generation.csv +} + +output_path="${workspace_dir}"/output/"${model_foldername}"-tp"${tp}" +# benchmark request throughput and static inference +benchmark_rpm ${output_path} +benchmark_generation ${output_path} diff --git a/benchmark/benchmark_7b.sh b/benchmark/benchmark_7b.sh new file mode 100755 index 0000000000..bf702d56ab --- /dev/null +++ b/benchmark/benchmark_7b.sh @@ -0,0 +1,93 @@ +#!/bin/bash +if [ -z "$1" ] +then + echo "Error. Please input the model path of llama2-7b model" + exit 1 +fi + +workspace_dir=$(dirname $(realpath "$0")) + +tp=1 +model_path="$1" +model_foldername=$(basename "$model_path") +turbomind_model_path="${workspace_dir}"/workspace/"${model_foldername}" + +# convert +lmdeploy convert llama2 ${model_path} --dst-path ${turbomind_model_path} --tp ${tp} +if [ $? 
!= 0 ] +then +exit 1 +fi + +# update recommended config to config.ini +config_path=${turbomind_model_path}/triton_models/weights/config.ini + +apt-get update +apt-get install crudini -y + +crudini --set ${config_path} llama max_context_token_num 4 +crudini --set ${config_path} llama cache_chunk_size -1 +crudini --set ${config_path} llama cache_max_entry_count 1000 +crudini --set ${config_path} llama max_batch_size 128 +# end of update config + +cd ${workspace_dir} + +# download dataset +wget -O ShareGPT_V3_unfiltered_cleaned_split.json https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +benchmark_rpm () { + output_path=$1 + mkdir -p "${output_path}" + + batches=(64 128) + for batch in "${batches[@]}" + do + for i in {1..3} + do + python3 profile_throughput.py \ + ShareGPT_V3_unfiltered_cleaned_split.json \ + ${turbomind_model_path} \ + --concurrency "$batch" \ + --num_prompts 3000 \ + --csv ${output_path}/rpm_localhost_batch_"${batch}"_"${i}"th.csv + done + done +} + +benchmark_generation () { + output_path=$1 + mkdir -p "${output_path}" + + python3 profile_generation.py \ + ${turbomind_model_path} \ + --concurrency 1 16 32 64 \ + --csv ${output_path}/generation.csv +} + +################################# BENCHMARK AFTER TUNING GEMM ################################# +output_path="${workspace_dir}"/output/"${model_foldername}"-tunned-gemm-tp"${tp}" + +# tune gemm +head_num=$(crudini --get "${config_path}" llama head_num) +size_per_head=$(crudini --get "${config_path}" llama size_per_head) +vocab_size=$(crudini --get "${config_path}" llama vocab_size) +inter_size=$(crudini --get "${config_path}" llama inter_size) +tensor_para_size=$(crudini --get "${config_path}" llama tensor_para_size) +max_batch_size=$(crudini --get "${config_path}" llama max_batch_size) + +echo $head_num, $size_per_head, $vocab_size, $inter_size, $tensor_para_size, $max_batch_size + +python3 -m lmdeploy.turbomind.generate_gemm_config \ + --head_num ${head_num} \ + --size_per_head ${size_per_head} \ + --vocab_size ${vocab_size} \ + --inter_size ${inter_size} \ + --tensor_para_size ${tensor_para_size} \ + --max_batch_size ${max_batch_size} + +# benchmark request throughput and static inference +benchmark_rpm ${output_path} +benchmark_generation ${output_path} + +mv gemm_config.in ${output_path} diff --git a/benchmark/profile_generation.py b/benchmark/profile_generation.py index 41e04c16fc..d86693afb9 100644 --- a/benchmark/profile_generation.py +++ b/benchmark/profile_generation.py @@ -29,14 +29,14 @@ def infer(model, session_id: int, input_ids: List, output_seqlen: int, for _ in range(test_round): token_latency_stats = [0] * (output_seqlen + 1) prev = time.perf_counter() - n_pre_token = 0 + n_prev_token = 0 """ The iterator provided by `stream_infer` denotes the number of generated tokens so far, which is represented by the variable `n_token`. Please note that `n_token` is not a continuous value. In other words, during the iteration, its value might be 5, 7, 8, 16, and so on, rather than 1, 2, 3, 4, etc. So, it is quite difficult to get the latency of each generated token. - As a work-around, we set the latency `new-prev` of each iteration to the first token of + As a work-around, we set the latency `now-prev` of each iteration to the first token of the new generated tokens, and leave the latency of the rest tokens being 0. For example, in the first iteration, 5 tokens are generated. 
The time elapsing in this iteration `now-prev` is set to the latency of first token of @@ -54,9 +54,9 @@ def infer(model, session_id: int, input_ids: List, output_seqlen: int, temperature=temperature): _, n_token = outputs[0] now = time.perf_counter() - if n_pre_token != n_token: - token_latency_stats[n_pre_token] = np.round(now - prev, 3) - n_pre_token = n_token + if n_prev_token != n_token: + token_latency_stats[n_prev_token] = np.round(now - prev, 3) + n_prev_token = n_token prev = now if session_id == 1: pbar.update(1) diff --git a/benchmark/profile_restful_api.py b/benchmark/profile_restful_api.py index 1e0d2388ee..b16dfcd482 100644 --- a/benchmark/profile_restful_api.py +++ b/benchmark/profile_restful_api.py @@ -158,8 +158,8 @@ def process_request(self, prompt_tokens = total_tokens - completion_tokens completion_token_throughput = completion_tokens / elapsed_time total_token_throughput = total_tokens / elapsed_time - rqs = len(requests) / elapsed_time - rqm = rqs * 60 + rps = len(requests) / elapsed_time + rpm = rps * 60 if (np.abs(stats[:, 1] - stats[:, 2]) <= 1).min() is False: print(f'Did not generate requested number of tokens. ' @@ -178,8 +178,8 @@ def process_request(self, f'number of completion tokens: {completion_tokens:.0f}\n' f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n' # noqa f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n' # noqa - f'RPS (request per second): {rqs:.3f} req/s\n' - f'RPM (request per minute): {rqm:.3f} req/min\n' + f'RPS (request per second): {rps:.3f} req/s\n' + f'RPM (request per minute): {rpm:.3f} req/min\n' f'{"-" * 50}\n') if self.csv: @@ -190,7 +190,7 @@ def process_request(self, 'completion_tokens', '1st_token_latency(min)(s)', '1st_token_latency(max)(s)', '1st_token_latency(ave)(s)', 'output token thr(tokens/s', 'total token thr(token/s)', - 'RPM' + 'RPS', 'RPM' ]) writer.writerow([ concurrency, @@ -199,7 +199,7 @@ def process_request(self, f'{first_token_latency_max:.3f}' if stream_output else '-', f'{first_token_latency_ave:.3f}' if stream_output else '-', f'{completion_token_throughput:.3f}', - f'{total_token_throughput:.3f}', f'{rqm:.3f}' + f'{total_token_throughput:.3f}', f'{rps:.3f}', f'{rpm:.3f}' ]) diff --git a/benchmark/profile_serving.py b/benchmark/profile_serving.py index f8daafec87..154751737e 100644 --- a/benchmark/profile_serving.py +++ b/benchmark/profile_serving.py @@ -163,8 +163,8 @@ def process_request(self, prompt_tokens = total_tokens - completion_tokens completion_token_throughput = completion_tokens / elapsed_time total_token_throughput = total_tokens / elapsed_time - rqs = len(requests) / elapsed_time - rqm = rqs * 60 + rps = len(requests) / elapsed_time + rpm = rps * 60 if (np.abs(stats[:, 1] - stats[:, 2]) <= 1).min() is False: print(f'Did not generate requested number of tokens. 
' @@ -183,8 +183,8 @@ def process_request(self, f'number of completion tokens: {completion_tokens:.0f}\n' f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n' # noqa f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n' # noqa - f'RPS (request per second): {rqs:.3f} req/s\n' - f'RPM (request per minute): {rqm:.3f} req/min\n' + f'RPS (request per second): {rps:.3f} req/s\n' + f'RPM (request per minute): {rpm:.3f} req/min\n' f'{"-" * 50}\n') if self.csv: @@ -195,7 +195,7 @@ def process_request(self, 'completion_tokens', '1st_token_latency(min)(s)', '1st_token_latency(max)(s)', '1st_token_latency(ave)(s)', 'output token thr(tokens/s', 'total token thr(token/s)', - 'RPM' + 'RPS', 'RPM' ]) writer.writerow([ concurrency, @@ -204,7 +204,7 @@ def process_request(self, f'{first_token_latency_max:.3f}' if stream_output else '-', f'{first_token_latency_ave:.3f}' if stream_output else '-', f'{completion_token_throughput:.3f}', - f'{total_token_throughput:.3f}', f'{rqm:.3f}' + f'{total_token_throughput:.3f}', f'{rps:.3f}', f'{rpm:.3f}' ]) diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index cda390d727..1cd4353ec9 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -75,13 +75,14 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, stream_output: bool): model_inst = self.tm_model.create_instance() stats = [] + # get each generated token's latency + per_token_latency_stats = [] for prompt, input_seqlen, output_seqlen in iter( req_queue.get, [None, None, None]): + _per_token_latency_stats = [0] * (output_seqlen + 1) offset = 0 - timestamps = [] - tokens = [] - - timestamps.append(time.perf_counter()) + prev = time.perf_counter() + n_prev_token = 0 input_ids = self.tokenizer(prompt).input_ids for outputs in model_inst.stream_infer( @@ -94,25 +95,32 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, sequence_end=True, ignore_eos=True, stream_output=stream_output): - res, token = outputs[0] + res, n_token = outputs[0] self.tokenizer.decode(res, offset) - offset = token - timestamps.append(time.perf_counter()) - tokens.append(token) - first_token_latency = np.round(timestamps[1] - timestamps[0], 3) - token_latency = np.round(timestamps[-1] - timestamps[0], 3) - completion_tokens = tokens[-1] - assert output_seqlen <= completion_tokens <= output_seqlen + 1, \ + offset = n_token + now = time.perf_counter() + if n_prev_token != n_token: + _per_token_latency_stats[n_prev_token] = np.round( + now - prev, 3) + n_prev_token = n_token + prev = now + + assert output_seqlen <= n_token <= output_seqlen + 1, \ f'Error. 
session_id({session_id}) request {output_seqlen} ' \ - f'tokens, but generate {completion_tokens} tokens.\n' \ + f'tokens, but generate {n_token} tokens.\n' \ f'prompt: {prompt}' - total_tokens = tokens[-1] + input_seqlen + + first_token_latency = _per_token_latency_stats[0] + completion_tokens = n_token + total_tokens = n_token + input_seqlen stats.append([ first_token_latency, completion_tokens, output_seqlen, - total_tokens, token_latency + total_tokens ]) + # skip the first token latency + per_token_latency_stats.append(_per_token_latency_stats[1:]) self.pbar.update(1) - res_queue.put((session_id, stats)) + res_queue.put((session_id, stats, per_token_latency_stats)) def process_request(self, requests, @@ -146,13 +154,15 @@ def process_request(self, elapsed_time = time.time() - start stats = [] + per_token_latency_stats = [] while not res_queue.empty(): - session_id, _stats = res_queue.get() - # print(f'\n{"-" * 50}\n' - # f'session {session_id} stats: \n{_stats}\n{"-" * 50}\n') + session_id, _stats, _per_token_latency_stats = res_queue.get() stats.append(np.array(_stats)) - - stats = np.concatenate(stats).reshape(-1, 5) + per_token_latency_stats += [ + item for sublist in _per_token_latency_stats + for item in sublist + ] + stats = np.concatenate(stats).reshape(-1, 4) first_token_latency_min = np.min(stats[:, 0], axis=0) first_token_latency_max = np.max(stats[:, 0], axis=0) @@ -162,23 +172,33 @@ def process_request(self, prompt_tokens = total_tokens - completion_tokens completion_token_throughput = completion_tokens / elapsed_time total_token_throughput = total_tokens / elapsed_time - rqs = len(requests) / elapsed_time - rqm = rqs * 60 + rps = len(requests) / elapsed_time + rpm = rps * 60 + + per_token_latency_stats.sort() + percentiles = [ + np.round( + per_token_latency_stats[int(percent * + len(per_token_latency_stats))], 3) + for percent in [0.5, 0.75, 0.95, 0.99] + ] print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n' f'elapsed_time: {elapsed_time:.3f}s\n') if stream_output: - print(f'first_token latency(min, max, ave): ' - f'{first_token_latency_min:.3f}s, ' - f'{first_token_latency_max:.3f}s, ' - f'{first_token_latency_ave:.3f}s\n') + print(f'first token latency(s)(min, max, ave): ' + f'{first_token_latency_min:.3f}, ' + f'{first_token_latency_max:.3f}, ' + f'{first_token_latency_ave:.3f}') + print(f'per-token latency(s) percentile(50, 75, 95, 99): ' + f'{percentiles}\n') print( f'number of prompt tokens: {prompt_tokens:.0f}\n' f'number of completion tokens: {completion_tokens:.0f}\n' f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n' # noqa f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n' # noqa - f'RPS (request per second): {rqs:.3f} req/s\n' - f'RPM (request per minute): {rqm:.3f} req/min\n' + f'RPS (request per second): {rps:.3f} req/s\n' + f'RPM (request per minute): {rpm:.3f} req/min\n' f'{"-" * 50}\n') if self.csv: @@ -188,8 +208,9 @@ def process_request(self, 'batch', 'num_promts', 'prompt_tokens', 'completion_tokens', '1st_token_latency(min)(s)', '1st_token_latency(max)(s)', '1st_token_latency(ave)(s)', - 'output token thr(tokens/s', 'total token thr(token/s)', - 'RPM' + 'percentile50(s)', 'percentile75(s)', 'percentile95(s)', + 'percentile99(s)', 'output token thr(tokens/s)', + 'total token thr(token/s)', 'RPS', 'RPM' ]) writer.writerow([ concurrency, @@ -197,8 +218,12 @@ def process_request(self, f'{first_token_latency_min:.3f}' if stream_output else '-', f'{first_token_latency_max:.3f}' if 
stream_output else '-', f'{first_token_latency_ave:.3f}' if stream_output else '-', + f'{percentiles[0]:.3f}' if stream_output else '-', + f'{percentiles[1]:.3f}' if stream_output else '-', + f'{percentiles[2]:.3f}' if stream_output else '-', + f'{percentiles[3]:.3f}' if stream_output else '-', f'{completion_token_throughput:.3f}', - f'{total_token_throughput:.3f}', f'{rqm:.3f}' + f'{total_token_throughput:.3f}', f'{rps:.3f}', f'{rpm:.3f}' ]) diff --git a/docs/en/benchmark/a100_fp16.md b/docs/en/benchmark/a100_fp16.md new file mode 100644 index 0000000000..b157137868 --- /dev/null +++ b/docs/en/benchmark/a100_fp16.md @@ -0,0 +1,134 @@ +# Benchmark on A100 (FP16) + +All the following results are tested on (x8) A100-80G CUDA 11.8. + +The tested lmdeploy version is `v0.1.0a1`. + +The commands provided below facilitate benchmarking both [static inference performance](#static-inference-benchmark) and [request throughput](#request-throughput-benchmark) on an A100-80G(x8) for models of various sizes. + +```shell +bash benchmark/benchmark_7b.sh +bash benchmark/benchmark_13b.sh +bash benchmark/benchmark_20b.sh +bash benchmark/benchmark_70b.sh +``` + +## Static Inference Benchmark + +FTL: **F**irst **T**oken **L**atency + +### llama2-7b + +| batch | tp | prompt_tokens | output_tokens | throughput(out tok/s) | mem(GB) | FTL(ave)(s) | FTL(min)(s) | FTL(max)(s) | 50%(s) | 75%(s) | 95%(s) | 99%(s) | +| ----- | --- | ------------- | ------------- | --------------------- | ------- | ----------- | ----------- | ----------- | ------ | ------ | ------ | ------ | +| 1 | 1 | 1 | 128 | 100.02 | 76.55 | 0.011 | 0.01 | 0.011 | 0.009 | 0.009 | 0.01 | 0.011 | +| 1 | 1 | 128 | 128 | 102.21 | 76.59 | 0.022 | 0.022 | 0.022 | 0.01 | 0.01 | 0.01 | 0.01 | +| 1 | 1 | 128 | 2048 | 98.92 | 76.59 | 0.022 | 0.022 | 0.022 | 0.01 | 0.01 | 0.01 | 0.01 | +| 1 | 1 | 2048 | 128 | 86.1 | 76.77 | 0.139 | 0.139 | 0.14 | 0.01 | 0.01 | 0.01 | 0.011 | +| 1 | 1 | 2048 | 2048 | 93.78 | 76.77 | 0.14 | 0.139 | 0.141 | 0.011 | 0.011 | 0.011 | 0.011 | +| 16 | 1 | 1 | 128 | 1504.72 | 76.59 | 0.021 | 0.011 | 0.031 | 0.01 | 0.011 | 0.011 | 0.013 | +| 16 | 1 | 128 | 128 | 1272.47 | 76.77 | 0.129 | 0.023 | 0.149 | 0.011 | 0.011 | 0.012 | 0.014 | +| 16 | 1 | 128 | 2048 | 1010.62 | 76.77 | 0.13 | 0.023 | 0.144 | 0.015 | 0.018 | 0.02 | 0.021 | +| 16 | 1 | 2048 | 128 | 348.87 | 78.3 | 2.897 | 0.143 | 3.576 | 0.02 | 0.021 | 0.022 | 0.025 | +| 16 | 1 | 2048 | 2048 | 601.63 | 78.3 | 2.678 | 0.142 | 3.084 | 0.025 | 0.028 | 0.03 | 0.031 | +| 32 | 1 | 1 | 128 | 2136.73 | 76.62 | 0.079 | 0.014 | 0.725 | 0.011 | 0.012 | 0.013 | 0.021 | +| 32 | 1 | 128 | 128 | 2125.47 | 76.99 | 0.214 | 0.022 | 0.359 | 0.012 | 0.013 | 0.014 | 0.035 | +| 32 | 1 | 128 | 2048 | 1462.12 | 76.99 | 0.2 | 0.026 | 0.269 | 0.021 | 0.026 | 0.031 | 0.033 | +| 32 | 1 | 2048 | 128 | 450.43 | 78.3 | 4.288 | 0.143 | 5.267 | 0.031 | 0.032 | 0.034 | 0.161 | +| 32 | 1 | 2048 | 2048 | 733.34 | 78.34 | 4.118 | 0.19 | 5.429 | 0.04 | 0.045 | 0.05 | 0.053 | +| 64 | 1 | 1 | 128 | 4154.81 | 76.71 | 0.042 | 0.013 | 0.21 | 0.012 | 0.018 | 0.028 | 0.041 | +| 64 | 1 | 128 | 128 | 3024.07 | 77.43 | 0.44 | 0.026 | 1.061 | 0.014 | 0.018 | 0.026 | 0.158 | +| 64 | 1 | 128 | 2048 | 1852.06 | 77.96 | 0.535 | 0.027 | 1.231 | 0.03 | 0.041 | 0.048 | 0.053 | +| 64 | 1 | 2048 | 128 | 493.46 | 78.4 | 6.59 | 0.142 | 16.235 | 0.046 | 0.049 | 0.055 | 0.767 | +| 64 | 1 | 2048 | 2048 | 755.65 | 78.4 | 39.105 | 0.142 | 116.285 | 0.047 | 0.049 | 0.051 | 0.207 | + +### llama2-13b + +| batch | tp | prompt_tokens | output_tokens 
| throughput(out tok/s) | mem(GB) | FTL(ave)(s) | FTL(min)(s) | FTL(max)(s) | 50%(s) | 75%(s) | 95%(s) | 99%(s) | +| ----- | --- | ------------- | ------------- | --------------------- | ------- | ----------- | ----------- | ----------- | ------ | ------ | ------ | ------ | +| 1 | 1 | 1 | 128 | 57.49 | 74.84 | 0.018 | 0.018 | 0.019 | 0.017 | 0.017 | 0.017 | 0.017 | +| 1 | 1 | 128 | 128 | 56.58 | 74.84 | 0.04 | 0.039 | 0.04 | 0.017 | 0.017 | 0.017 | 0.018 | +| 1 | 1 | 128 | 2048 | 55.29 | 74.84 | 0.04 | 0.04 | 0.04 | 0.018 | 0.018 | 0.018 | 0.019 | +| 1 | 1 | 2048 | 128 | 48.99 | 75.09 | 0.242 | 0.242 | 0.243 | 0.019 | 0.019 | 0.019 | 0.019 | +| 1 | 1 | 2048 | 2048 | 52.12 | 75.09 | 0.243 | 0.24 | 0.244 | 0.019 | 0.019 | 0.019 | 0.02 | +| 16 | 1 | 1 | 128 | 869.45 | 74.87 | 0.036 | 0.019 | 0.053 | 0.018 | 0.019 | 0.019 | 0.02 | +| 16 | 1 | 128 | 128 | 757.3 | 75.09 | 0.252 | 0.041 | 0.272 | 0.019 | 0.02 | 0.02 | 0.021 | +| 16 | 1 | 128 | 2048 | 605.88 | 75.09 | 0.253 | 0.041 | 0.275 | 0.026 | 0.03 | 0.033 | 0.034 | +| 16 | 1 | 2048 | 128 | 257.92 | 76.96 | 3.442 | 0.245 | 3.668 | 0.033 | 0.034 | 0.035 | 0.035 | +| 16 | 1 | 2048 | 2048 | 366.67 | 76.99 | 3.122 | 0.249 | 3.671 | 0.04 | 0.044 | 0.047 | 0.047 | +| 32 | 1 | 1 | 128 | 1667.5 | 74.9 | 0.034 | 0.021 | 0.057 | 0.019 | 0.02 | 0.021 | 0.023 | +| 32 | 1 | 128 | 128 | 1301.27 | 75.37 | 0.461 | 0.04 | 0.497 | 0.021 | 0.022 | 0.023 | 0.025 | +| 32 | 1 | 128 | 2048 | 860.14 | 75.84 | 0.833 | 0.041 | 1.151 | 0.034 | 0.042 | 0.047 | 0.048 | +| 32 | 1 | 2048 | 128 | 291.54 | 77.02 | 5.315 | 0.245 | 13.483 | 0.046 | 0.047 | 0.049 | 0.51 | +| 32 | 1 | 2048 | 2048 | 389.64 | 77.02 | 38.725 | 0.245 | 108.104 | 0.047 | 0.047 | 0.049 | 0.05 | +| 64 | 1 | 1 | 128 | 3049.16 | 74.96 | 0.044 | 0.025 | 0.073 | 0.02 | 0.022 | 0.026 | 0.029 | +| 64 | 1 | 128 | 128 | 2033.22 | 75.87 | 0.703 | 0.046 | 0.951 | 0.024 | 0.026 | 0.029 | 0.032 | +| 64 | 1 | 128 | 2048 | 998.86 | 76.9 | 7.805 | 0.042 | 60.1 | 0.045 | 0.047 | 0.05 | 0.063 | +| 64 | 1 | 2048 | 128 | 286.32 | 76.99 | 19.69 | 0.245 | 32.394 | 0.047 | 0.048 | 0.05 | 0.27 | +| 64 | 1 | 2048 | 2048 | 387.86 | 77.09 | 190.453 | 0.245 | 307.331 | 0.047 | 0.048 | 0.049 | 0.05 | + +### internlm-20b + +| batch | tp | prompt_tokens | output_tokens | throughput(out tok/s) | mem(GB) | FTL(ave)(s) | FTL(min)(s) | FTL(max)(s) | 50%(s) | 75%(s) | 95%(s) | 99%(s) | +| ----- | --- | ------------- | ------------- | --------------------- | ------- | ----------- | ----------- | ----------- | ------ | ------ | ------ | ------ | +| 1 | 2 | 1 | 128 | 61.14 | 73.55 | 0.018 | 0.017 | 0.019 | 0.016 | 0.016 | 0.016 | 0.018 | +| 1 | 2 | 128 | 128 | 60.03 | 73.55 | 0.042 | 0.041 | 0.043 | 0.016 | 0.016 | 0.016 | 0.017 | +| 1 | 2 | 128 | 2048 | 58.26 | 73.55 | 0.042 | 0.042 | 0.043 | 0.017 | 0.017 | 0.018 | 0.018 | +| 1 | 2 | 2048 | 128 | 51.93 | 73.68 | 0.217 | 0.216 | 0.217 | 0.018 | 0.018 | 0.018 | 0.018 | +| 1 | 2 | 2048 | 2048 | 56.36 | 73.68 | 0.217 | 0.217 | 0.217 | 0.018 | 0.018 | 0.018 | 0.018 | +| 16 | 2 | 1 | 128 | 903.01 | 73.65 | 0.034 | 0.018 | 0.051 | 0.017 | 0.018 | 0.019 | 0.02 | +| 16 | 2 | 128 | 128 | 794.13 | 73.74 | 0.227 | 0.043 | 0.248 | 0.018 | 0.019 | 0.02 | 0.021 | +| 16 | 2 | 128 | 2048 | 669.87 | 73.74 | 0.227 | 0.043 | 0.25 | 0.024 | 0.027 | 0.029 | 0.03 | +| 16 | 2 | 2048 | 128 | 288.60 | 75.60 | 3.09 | 0.247 | 4.485 | 0.029 | 0.03 | 0.031 | 0.032 | +| 16 | 2 | 2048 | 2048 | 441.46 | 75.61 | 3.172 | 0.219 | 4.442 | 0.035 | 0.037 | 0.04 | 0.041 | +| 32 | 2 | 1 | 128 | 1673.64 | 73.71 | 0.037 | 0.02 | 
0.066 | 0.019 | 0.02 | 0.021 | 0.023 | +| 32 | 2 | 128 | 128 | 1347.57 | 73.90 | 0.351 | 0.043 | 0.436 | 0.02 | 0.021 | 0.023 | 0.025 | +| 32 | 2 | 128 | 2048 | 1025.62 | 73.90 | 0.391 | 0.042 | 0.441 | 0.031 | 0.037 | 0.041 | 0.043 | +| 32 | 2 | 2048 | 128 | 352.45 | 75.74 | 6.062 | 0.218 | 6.3 | 0.042 | 0.043 | 0.045 | 0.046 | +| 32 | 2 | 2048 | 2048 | 514.60 | 75.77 | 10.36 | 0.222 | 70.328 | 0.049 | 0.05 | 0.051 | 0.053 | +| 64 | 2 | 1 | 128 | 2954.34 | 73.82 | 0.05 | 0.029 | 0.074 | 0.021 | 0.023 | 0.026 | 0.03 | +| 64 | 2 | 128 | 128 | 2122.92 | 74.24 | 0.591 | 0.047 | 0.808 | 0.024 | 0.026 | 0.029 | 0.032 | +| 64 | 2 | 128 | 2048 | 1276.61 | 75.18 | 2.529 | 0.049 | 41.212 | 0.042 | 0.048 | 0.052 | 0.055 | +| 64 | 2 | 2048 | 128 | 350.82 | 75.88 | 12.382 | 0.219 | 20.986 | 0.05 | 0.051 | 0.054 | 0.249 | +| 64 | 2 | 2048 | 2048 | 512.37 | 76.26 | 111.149 | 0.221 | 211.531 | 0.05 | 0.051 | 0.052 | 0.055 | + +### llama2-70b + +| batch | tp | prompt_tokens | output_tokens | throughput(out tok/s) | mem(GB) | FTL(ave)(s) | FTL(min)(s) | FTL(max)(s) | 50%(s) | 75%(s) | 95%(s) | 99%(s) | +| ----- | --- | ------------- | ------------- | --------------------- | ------- | ----------- | ----------- | ----------- | ------ | ------ | ------ | ------ | +| 1 | 4 | 1 | 128 | 33.94 | 73.72 | 0.031 | 0.03 | 0.031 | 0.029 | 0.029 | 0.029 | 0.03 | +| 1 | 4 | 128 | 128 | 33.63 | 73.72 | 0.074 | 0.073 | 0.074 | 0.029 | 0.029 | 0.029 | 0.03 | +| 1 | 4 | 128 | 2048 | 32.38 | 73.72 | 0.074 | 0.074 | 0.075 | 0.031 | 0.031 | 0.031 | 0.031 | +| 1 | 4 | 2048 | 128 | 28.32 | 73.78 | 0.402 | 0.401 | 0.403 | 0.031 | 0.031 | 0.031 | 0.051 | +| 1 | 4 | 2048 | 2048 | 31.9 | 73.78 | 0.405 | 0.402 | 0.407 | 0.031 | 0.031 | 0.031 | 0.031 | +| 16 | 4 | 1 | 128 | 468.52 | 73.72 | 0.071 | 0.034 | 0.939 | 0.03 | 0.031 | 0.032 | 0.251 | +| 16 | 4 | 128 | 128 | 439.77 | 73.81 | 0.437 | 0.08 | 0.687 | 0.03 | 0.031 | 0.032 | 0.207 | +| 16 | 4 | 128 | 2048 | 482.99 | 73.81 | 0.403 | 0.079 | 0.44 | 0.033 | 0.033 | 0.035 | 0.036 | +| 16 | 4 | 2048 | 128 | 189.34 | 73.98 | 5.776 | 0.437 | 7.612 | 0.035 | 0.036 | 0.036 | 0.037 | +| 16 | 4 | 2048 | 2048 | 399.42 | 73.98 | 5.773 | 0.411 | 6.844 | 0.036 | 0.037 | 0.038 | 0.041 | +| 32 | 4 | 1 | 128 | 906.03 | 73.75 | 0.098 | 0.043 | 0.253 | 0.032 | 0.033 | 0.035 | 0.178 | +| 32 | 4 | 128 | 128 | 746.36 | 73.91 | 0.749 | 0.078 | 1.026 | 0.032 | 0.033 | 0.035 | 0.438 | +| 32 | 4 | 128 | 2048 | 853.56 | 73.91 | 0.732 | 0.076 | 1.129 | 0.036 | 0.038 | 0.041 | 0.158 | +| 32 | 4 | 2048 | 128 | 232.6 | 73.99 | 11.834 | 0.408 | 13.321 | 0.04 | 0.041 | 0.043 | 0.248 | +| 32 | 4 | 2048 | 2048 | 636.23 | 73.99 | 11.711 | 0.409 | 12.689 | 0.043 | 0.045 | 0.048 | 0.179 | +| 64 | 4 | 1 | 128 | 1425.79 | 73.81 | 0.213 | 0.046 | 1.264 | 0.037 | 0.039 | 0.044 | 0.329 | +| 64 | 4 | 128 | 128 | 1159.84 | 73.96 | 1.292 | 0.107 | 2.676 | 0.037 | 0.04 | 0.045 | 0.378 | +| 64 | 4 | 128 | 2048 | 1391.8 | 73.95 | 1.173 | 0.135 | 1.623 | 0.043 | 0.047 | 0.052 | 0.251 | +| 64 | 4 | 2048 | 128 | 270.47 | 74.02 | 17.402 | 0.452 | 24.164 | 0.05 | 0.052 | 0.057 | 0.345 | +| 64 | 4 | 2048 | 2048 | 930.46 | 74.01 | 21.29 | 0.423 | 24.498 | 0.055 | 0.059 | 0.065 | 0.299 | + +## Request Throughput Benchmark + +FTL: **F**irst **T**oken **L**atency + +| model | batch | tp | num_prompts | PRS | PRM | FTL(ave)(s) | FTL(min)(s) | FTL(max)(s) | throughput(out tok/s) | throughput(total tok/s) | +| ------------ | ----- | --- | ----------- | ------ | ------- | ----------- | ----------- | ----------- | --------------------- | 
----------------------- | +| llama2-7b | 64 | 1 | 3000 | 10.275 | 616.477 | 0.092 | 0.036 | 1.145 | 2562.435 | 5283.547 | +| | 128 | 1 | 3000 | 12.611 | 756.677 | 0.205 | 0.056 | 2.241 | 3210.281 | 6619.357 | +| llama2-13b | 64 | 1 | 3000 | 6.337 | 380.244 | 0.159 | 0.051 | 2.048 | 1474.786 | 3039.398 | +| | 128 | 1 | 3000 | 7.588 | 455.273 | 0.412 | 0.085 | 4.445 | 1765.788 | 3639.128 | +| internlm-20b | 64 | 2 | 3000 | 7.842 | 470.516 | 0.166 | 0.059 | 2.461 | 1564.696 | 3311.16 | +| | 128 | 2 | 3000 | 9.776 | 586.568 | 0.34 | 0.079 | 5.808 | 1950.627 | 4127.855 | +| llama2-70b | 64 | 4 | 3000 | 4.285 | 257.08 | 0.301 | 0.083 | 4.689 | 1000.376 | 2062.7 | +| | 128 | 4 | 3000 | 5.833 | 349.996 | 0.633 | 0.107 | 8.431 | 1361.939 | 2808.216 | +| | 256 | 4 | 3000 | 6.568 | 394.108 | 1.49 | 0.171 | 19.52 | 1533.592 | 3162.15 | diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 66ba6c68d2..fc18825a2e 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -31,54 +31,59 @@ def __init__(self, model_path, instance_num=32, tp=1, **kwargs) -> None: tp=tp, **kwargs) self.tokenizer = self.tm_model.tokenizer - self.generators = [ - self.tm_model.create_instance() for i in range(instance_num) - ] self.instance_num = instance_num self.model = self.tm_model.model - self.available = [True] * instance_num - self.starts = [None] * instance_num - self.steps = {} + self.id2step = {} + self.id2generator = {} self.loop = asyncio.get_event_loop() + self.special_gen = self.tm_model.create_instance() + self.gens_set = set() + for i in range(instance_num): + self.gens_set.add(self.tm_model.create_instance()) def stop_session(self, session_id: int): """Stop a session by a session_id.""" - instance_id = session_id % self.instance_num - input_ids = self.tokenizer.encode('') - for outputs in self.generators[instance_id].stream_infer( - session_id, - input_ids, - request_output_len=0, - sequence_start=False, - sequence_end=False, - stop=True): + input_ids = [self.tm_model.eos_id] + stop_generator = self.id2generator.get(str(session_id), + self.special_gen) + for outputs in stop_generator.stream_infer(session_id, + input_ids, + request_output_len=0, + sequence_start=False, + sequence_end=False, + stop=True): pass - self.available[instance_id] = True + if str(session_id) in self.id2generator and self.id2generator[str( + session_id)] not in self.gens_set: + self.gens_set.add(self.id2generator[str(session_id)]) def end_session(self, session_id: int): """Clear a session by a session_id.""" - instance_id = session_id % self.instance_num - input_ids = self.tokenizer.encode('') - for outputs in self.generators[instance_id].stream_infer( - session_id, - input_ids, - request_output_len=0, - sequence_start=False, - sequence_end=True, - stop=True): + input_ids = [self.tm_model.eos_id] + end_generator = self.id2generator.get(str(session_id), + self.special_gen) + for outputs in end_generator.stream_infer(session_id, + input_ids, + request_output_len=0, + sequence_start=False, + sequence_end=True, + stop=True): pass - self.steps[str(session_id)] = 0 - self.available[instance_id] = True + self.id2step[str(session_id)] = 0 + if str(session_id) in self.id2generator and self.id2generator[str( + session_id)] not in self.gens_set: + self.gens_set.add(self.id2generator[str(session_id)]) @contextmanager - def safe_run(self, instance_id: int, session_id: Optional[int] = None): + def safe_run(self, session_id: Optional[int] = None): """A context manager to make sure server's safe 
running.""" - self.available[instance_id] = False try: yield except (Exception, asyncio.CancelledError) as e: # noqa self.stop_session(session_id) - self.available[instance_id] = True + if str(session_id) in self.id2generator and self.id2generator[str( + session_id)] not in self.gens_set: + self.gens_set.add(self.id2generator[str(session_id)]) async def get_embeddings(self, prompt, do_prerpocess=False): if do_prerpocess: @@ -86,12 +91,13 @@ async def get_embeddings(self, prompt, do_prerpocess=False): input_ids = self.tokenizer.encode(prompt) return input_ids - async def get_generator(self, instance_id: int, stop: bool = False): + async def get_generator(self, stop: bool, session_id: int): """Only return the model instance if it is available.""" - if not stop: - while self.available[instance_id] is False: - await asyncio.sleep(0.1) - return self.generators[instance_id] + if stop: + return self.id2generator.get(str(session_id), self.special_gen) + while self.gens_set == set(): + await asyncio.sleep(0) + return self.gens_set.pop() def batch_infer(self, prompts: List[str], @@ -189,27 +195,27 @@ async def generate( ignore_eos (bool): indicator for ignoring eos do_preprocess (bool): whether pre-process the messages. """ - instance_id = session_id % self.instance_num - if str(session_id) not in self.steps: - self.steps[str(session_id)] = 0 + if str(session_id) not in self.id2step: + self.id2step[str(session_id)] = 0 if step != 0: - self.steps[str(session_id)] = step + self.id2step[str(session_id)] = step seed = random.getrandbits(64) prompt = messages if do_preprocess: prompt = self.model.messages2prompt(prompt, sequence_start) input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start) finish_reason = 'stop' if stop else None - if self.steps[str(session_id)] + len( + if self.id2step[str(session_id)] + len( input_ids) + request_output_len >= self.tm_model.session_len: finish_reason = 'length' - yield GenOut('', self.steps[str(session_id)], len(input_ids), 0, + yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0, finish_reason) if sequence_end is True and sequence_start is False: self.end_session(session_id) else: - generator = await self.get_generator(instance_id, stop) - with self.safe_run(instance_id, session_id): + generator = await self.get_generator(stop, session_id) + self.id2generator[str(session_id)] = generator + with self.safe_run(session_id): response_size = 0 async for outputs in generator.async_stream_infer( session_id=session_id, @@ -218,7 +224,7 @@ async def generate( request_output_len=request_output_len, sequence_start=(sequence_start), sequence_end=sequence_end, - step=self.steps[str(session_id)], + step=self.id2step[str(session_id)], stop=stop, top_k=top_k, top_p=top_p, @@ -237,16 +243,16 @@ async def generate( continue # response, history token len, # input token len, gen token len - yield GenOut(response, self.steps[str(session_id)], + yield GenOut(response, self.id2step[str(session_id)], len(input_ids), tokens, finish_reason) response_size = tokens # `response_size` might be note updated since # ` if response.endswith('�')` if response_size != tokens: - yield GenOut(response, self.steps[str(session_id)], + yield GenOut(response, self.id2step[str(session_id)], len(input_ids), tokens, finish_reason) # update step - self.steps[str(session_id)] += len(input_ids) + tokens + self.id2step[str(session_id)] += len(input_ids) + tokens if sequence_end or stop: - self.steps[str(session_id)] = 0 + self.id2step[str(session_id)] = 0 diff --git 
a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 0b61f7967b..bcf40ba2b2 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -136,7 +136,10 @@ async def chat_completions_v1(request: ChatCompletionRequest, top_p=request.top_p, temperature=request.temperature, repetition_penalty=request.repetition_penalty, - ignore_eos=request.ignore_eos) + ignore_eos=request.ignore_eos, + do_preprocess=not isinstance(request.messages, + str), # text completion for string input + ) def create_stream_response_json( index: int, @@ -424,7 +427,7 @@ async def chat_interactive_v1(request: GenerateRequest, request.session_id = random.randint(10087, 23333) async_engine = VariableInterface.async_engine - sequence_start = async_engine.steps.get(str(request.session_id), 0) == 0 + sequence_start = async_engine.id2step.get(str(request.session_id), 0) == 0 sequence_end = not request.interactive_mode generation = async_engine.generate( diff --git a/lmdeploy/version.py b/lmdeploy/version.py index f72f79f35e..1109a4bcdc 100644 --- a/lmdeploy/version.py +++ b/lmdeploy/version.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Tuple -__version__ = '0.1.0a1' +__version__ = '0.1.0a2' short_version = __version__ diff --git a/src/turbomind/kernels/sampling_penalty_kernels.cu b/src/turbomind/kernels/sampling_penalty_kernels.cu index 4877bdb1a0..f7ebfeff03 100644 --- a/src/turbomind/kernels/sampling_penalty_kernels.cu +++ b/src/turbomind/kernels/sampling_penalty_kernels.cu @@ -446,10 +446,16 @@ void invokeBatchApplyRepetitionPenalty(T* logits, dim3 grid(local_batch_size); size_t smem_size = step * (sizeof(float) + sizeof(int)); if (penalty_type == RepetitionPenaltyType::Additive) { + check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); batchApplyRepetitionPenalty<<>>( logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step); } else if (penalty_type == RepetitionPenaltyType::Multiplicative) { + check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); batchApplyRepetitionPenalty<<>>( logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step); } diff --git a/src/turbomind/models/llama/BlockManager.cc b/src/turbomind/models/llama/BlockManager.cc index 2738e674e3..32dd0f12c7 100644 --- a/src/turbomind/models/llama/BlockManager.cc +++ b/src/turbomind/models/llama/BlockManager.cc @@ -4,6 +4,7 @@ #include "src/turbomind/utils/cuda_utils.h" #include "src/turbomind/utils/debug_utils.h" #include "src/turbomind/utils/logger.h" +#include "src/turbomind/utils/string_utils.h" #include #include #include @@ -70,7 +71,6 @@ bool BlockManager::Malloc() for (int i = 0; i < chunk_size; ++i, ptr += block_size_) { auto& block = blocks_.emplace_back(); block.use_count = 0; - block.ref_count = 0; block.id = (int)blocks_.size() - 1; block.timestamp = 0; block.data = ptr; @@ -91,16 +91,23 @@ size_t BlockManager::GetBlockCount(size_t block_size, double ratio) void BlockManager::Move(std::vector& src, const std::vector& delta, std::vector& dst) { + FT_CHECK(src.size() >= delta.size()); std::vector src1(src.size() - delta.size()); - std::set_difference(src.begin(), src.end(), delta.begin(), delta.end(), src1.begin()); + { + auto end = std::set_difference(src.begin(), src.end(), delta.begin(), delta.end(), src1.begin()); + FT_CHECK(end == 
src1.end()); + } src.swap(src1); std::vector dst1(dst.size() + delta.size()); - std::set_union(dst.begin(), dst.end(), delta.begin(), delta.end(), dst1.begin()); + { + auto end = std::set_union(dst.begin(), dst.end(), delta.begin(), delta.end(), dst1.begin()); + FT_CHECK(end == dst1.end()); + } dst.swap(dst1); } -std::vector BlockManager::Allocate(int count) +auto BlockManager::Allocate(int count) -> std::pair { while (free_ids_.size() < count) { if (!Malloc()) { @@ -108,30 +115,30 @@ std::vector BlockManager::Allocate(int count) } } - std::vector ret; - - std::vector idxs(count); + BlockIds block_ids(count); + UniqueIds unique_ids(count); for (int i = 0; i < count; ++i) { - int idx = free_ids_[i]; - idxs[i] = idx; - auto& block = blocks_[idx]; - FT_CHECK(is_free(block)); - block.ref_count = 1; - block.use_count = 1; - block.unique_id = unique_id_++; - ret.push_back(&block); + int idx = free_ids_[i]; + auto& b = blocks_[idx]; + FT_CHECK(is_free(b)); // pre-condition: uc == 0 && ts == 0 + b.use_count = 1; + b.unique_id = unique_id_++; + FT_CHECK(is_active(b)); // post-condition + block_ids[i] = idx; + unique_ids[i] = b.unique_id; } - Move(free_ids_, idxs, active_ids_); + Move(free_ids_, block_ids, active_ids_); dbg(free_ids_, active_ids_); - return ret; + return {block_ids, unique_ids}; } void BlockManager::Evict(int count) { + FT_CHECK(count <= cached_ids_.size()); std::vector idxs(cached_ids_); // get first `count` cached ids according to timestamp std::nth_element(idxs.begin(), idxs.begin() + count, idxs.end(), [&](int i, int j) { @@ -146,9 +153,9 @@ void BlockManager::Evict(int count) for (const auto& idx : idxs) { auto& b = blocks_[idx]; FT_CHECK(is_cached(b)); - b.ref_count = 0; b.unique_id = 0; b.timestamp = 0; + FT_CHECK(is_free(b)); } Move(cached_ids_, idxs, free_ids_); @@ -156,79 +163,94 @@ void BlockManager::Evict(int count) dbg(cached_ids_, free_ids_); } -int BlockManager::Free(const std::vector& bs) +void BlockManager::Free(BlockIds ids) { - std::vector idxs; + std::sort(ids.begin(), ids.end()); - for (const auto& p : bs) { - auto& b = blocks_[p->id]; - FT_CHECK(is_cached(b)); - if (--b.ref_count == 0) { - b.unique_id = 0; - b.timestamp = 0; - idxs.push_back(b.id); - } + for (const auto& i : ids) { + auto& b = blocks_[i]; + FT_CHECK(is_cached(b)); // uc == 0 && ts != 0 + b.unique_id = 0; + b.timestamp = 0; + FT_CHECK(is_free(b)); } - std::sort(idxs.begin(), idxs.end()); - - Move(cached_ids_, idxs, free_ids_); - - dbg(cached_ids_, free_ids_); - - return idxs.size(); + Move(cached_ids_, ids, free_ids_); } -int BlockManager::Unlock(const std::vector& bs) +int BlockManager::Unlock(const BlockIds& ids) { - std::vector idxs; - - for (const auto& p : bs) { - auto& block = blocks_[p->id]; - FT_CHECK(is_active(block)); - if (--block.use_count == 0) { - idxs.push_back(block.id); + BlockIds unlock; + unlock.reserve(ids.size()); + + for (const auto& i : ids) { + auto& b = blocks_[i]; + FT_CHECK(is_active(b)); // pre-condition: uc > 0 + if (--b.use_count == 0) { + unlock.push_back(b.id); + FT_CHECK(is_cached(b)); // post-condition } } - std::sort(idxs.begin(), idxs.end()); + std::sort(unlock.begin(), unlock.end()); - Move(active_ids_, idxs, cached_ids_); + Move(active_ids_, unlock, cached_ids_); dbg(active_ids_, cached_ids_); - - return idxs.size(); + return unlock.size(); } -int BlockManager::Lock(const std::vector& bs) +int BlockManager::Lock(const BlockIds& ids) { - std::vector idxs; - - for (const auto& p : bs) { - auto& block = blocks_[p->id]; - FT_CHECK(is_cached(block)); - if 
(++block.use_count == 1) { - idxs.push_back(p->id); + BlockIds lock; + lock.reserve(ids.size()); + + for (const auto& i : ids) { + auto& b = blocks_[i]; + FT_CHECK_WITH_INFO(is_cached(b), to_string(b)); + if (++b.use_count == 1) { + lock.push_back(i); + FT_CHECK(is_active(b)); } } - std::sort(idxs.begin(), idxs.end()); + std::sort(lock.begin(), lock.end()); - Move(cached_ids_, idxs, active_ids_); + Move(cached_ids_, lock, active_ids_); // dbg(cached_ids_, active_ids_); - return idxs.size(); + return lock.size(); } -void BlockManager::Touch(const std::vector& bs) +void BlockManager::Touch(const BlockIds& ids) { - std::for_each(bs.crbegin(), bs.crend(), [this](const Block* p) { - FT_CHECK(is_active(*p)); - const_cast(p)->timestamp = timestamp_++; + std::for_each(ids.crbegin(), ids.crend(), [this](int i) { + FT_CHECK(is_active(blocks_[i])); + blocks_[i].timestamp = timestamp_++; }); } +int BlockManager::Verify(const std::vector& block_ids, const std::vector& unique_ids) +{ + FT_CHECK(block_ids.size() == unique_ids.size()); + int valid = block_ids.size(); + for (int i = 0; i < block_ids.size(); ++i) { + if (unique_id(block_ids[i]) != unique_ids[i]) { + valid = i; + break; + } + } + int miss = 0; + for (int i = valid; i < block_ids.size(); ++i) { + miss += (unique_id(block_ids[i]) != unique_ids[i]); + } + // All later blocks should have been invalidated + FT_CHECK_WITH_INFO(miss == (int)block_ids.size() - valid, + fmtstr("count = %d, valid = %d, miss = %d", (int)block_ids.size(), valid, miss)); + return valid; +} + Snapshot BlockManager::TakeSnapshot() { std::vector use_count(blocks_.size()); diff --git a/src/turbomind/models/llama/BlockManager.h b/src/turbomind/models/llama/BlockManager.h index da3e53ee54..c9ec2d06dc 100644 --- a/src/turbomind/models/llama/BlockManager.h +++ b/src/turbomind/models/llama/BlockManager.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -22,28 +23,37 @@ namespace turbomind { struct Block { int id; // fixed linear id in the pool - int ref_count; // all sequences referencing the block int use_count; // active sequences using the block uint64_t unique_id; // unique for every block allocation uint64_t timestamp; void* data; friend std::ostream& operator<<(std::ostream& os, const Block& block); + friend std::string to_string(const Block& b) + { + std::stringstream ss; + ss << b; + return ss.str(); + } }; +using BlockIds = std::vector; +using UniqueIds = std::vector; + inline bool is_active(const Block& block) { - return block.ref_count > 0 && block.use_count > 0; + // timestamp may be 0 for newly allocated block that has not been written + return block.use_count > 0; } inline bool is_cached(const Block& block) { - return block.ref_count > 0 && block.use_count == 0; + return block.use_count == 0 && block.timestamp != 0; } inline bool is_free(const Block& block) { - return block.ref_count == 0 && block.use_count == 0 && block.timestamp == 0; + return block.use_count == 0 && block.timestamp == 0; } struct Snapshot { @@ -60,22 +70,24 @@ class BlockManager { ~BlockManager(); // free -> active (use_count = 1, ref_count = 1) - [[nodiscard]] std::vector Allocate(int count); + [[nodiscard]] std::pair Allocate(int count); // cached -> active (use_count += 1) - [[maybe_unused]] int Lock(const std::vector& bs); + [[maybe_unused]] int Lock(const BlockIds& ids); // active -> cached (use_count -= 1) - [[maybe_unused]] int Unlock(const std::vector& bs); + [[maybe_unused]] int Unlock(const BlockIds& ids); // cached -> free (ref_count = 0) void Evict(int 
count); // cached -> free (ref_count -= 1) - [[maybe_unused]] int Free(const std::vector& bs); + void Free(BlockIds bs); // increase timestamp in reversed order - void Touch(const std::vector& bs); + void Touch(const BlockIds& bs); + + [[nodiscard]] int Verify(const BlockIds& block_ids, const UniqueIds& unique_ids); Snapshot TakeSnapshot(); @@ -99,13 +111,23 @@ class BlockManager { return (max_block_count_ - blocks_.size()) + free_ids_.size(); } + Block& block(int idx) + { + return blocks_[idx]; + } + + int unique_id(int idx) + { + return blocks_[idx].unique_id; + } + friend std::ostream& operator<<(std::ostream& os, const BlockManager&); private: static size_t GetBlockCount(size_t block_size, double ratio); // move indices between sets - static void Move(std::vector& src, const std::vector& delta, std::vector& dst); + static void Move(BlockIds& src, const BlockIds& delta, BlockIds& dst); // allocate a chunk of blocks bool Malloc(); @@ -118,13 +140,12 @@ class BlockManager { std::vector chunks_; - std::vector active_ids_; - std::vector cached_ids_; - std::vector free_ids_; + BlockIds active_ids_; + BlockIds cached_ids_; + BlockIds free_ids_; std::vector blocks_; // < 100k - // uint64_t unique_id_{1UL << 63}; uint64_t unique_id_{1}; uint64_t timestamp_{1}; }; diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index 8ba2d3e4a3..01d05b2962 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -569,11 +569,11 @@ void LlamaBatch::Initialize(GenerationState& g) FT_CHECK_WITH_INFO(h_cu_block_counts_[i + 1] <= sequence_manager_->max_block_count(), std::to_string(h_cu_block_counts_[i + 1])); - k_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), k_ptrs, [&](auto p) { - return reinterpret_cast(sequence_manager_->OffsetKey(p->data)); + k_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), k_ptrs, [&](int block_id) { + return reinterpret_cast(sequence_manager_->GetKeyPtr(block_id)); }); - v_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), v_ptrs, [&](auto p) { - return reinterpret_cast(sequence_manager_->OffsetVal(p->data)); + v_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), v_ptrs, [&](int block_id) { + return reinterpret_cast(sequence_manager_->GetValPtr(block_id)); }); } @@ -589,7 +589,7 @@ void LlamaBatch::Initialize(GenerationState& g) // check if the last sequence is partial int partial = 0; int partial_len = -1; - { + if (state_->active_size) { const int i = state_->active_size - 1; partial = state_->sequences[i]->cache_len + state_->sequences[i]->input_length != state_->h_context_length[i]; if (partial) { @@ -958,6 +958,8 @@ LlamaBatch::LlamaBatch(const EngineParams& params, int cache_block_seq_len, i session_len_ = max_session_len; } + FT_CHECK(max_context_token_num_ >= session_len_); + for (auto& s : states_) { s.requests.resize(max_batch_size_); s.sequences.resize(max_batch_size_); diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc index 69bac0d4e8..70de38187f 100644 --- a/src/turbomind/models/llama/LlamaV2.cc +++ b/src/turbomind/models/llama/LlamaV2.cc @@ -76,7 +76,7 @@ LlamaV2::LlamaV2(size_t head_num, end_id_(end_id), hidden_units_(head_num * size_per_head), local_head_num_(head_num / tensor_para.world_size_), - local_kv_head_num_(head_num / tensor_para.world_size_), + local_kv_head_num_(kv_head_num / tensor_para.world_size_), weights_(weights), tensor_para_(tensor_para), stream_(stream), diff --git 
a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc index b9f1219f55..0951fc4045 100644 --- a/src/turbomind/models/llama/SequenceManager.cc +++ b/src/turbomind/models/llama/SequenceManager.cc @@ -37,29 +37,20 @@ SequenceManager::SequenceManager(size_t layer_num, const Sequence* SequenceManager::Create(uint64_t id) { Sequence sequence{id}; - - auto it = sequences_.find(id); + auto it = sequences_.find(id); if (it != sequences_.end()) { if (rank_ == 0) { TM_LOG_WARNING("[SequenceManager][Create] Removing conflicting ID %ld", (long)id); } - auto& seq = it->second; - if (seq.status != Sequence::kCached) { - unlocked_.insert(unlocked_.end(), seq.blocks.begin(), seq.blocks.end()); - } - seq = std::move(sequence); + Erase(it); } - else { - it = sequences_.emplace_hint(it, id, std::move(sequence)); - } - + it = sequences_.emplace_hint(it, id, std::move(sequence)); return &it->second; } const Sequence* SequenceManager::Get(uint64_t id) { if (auto it = sequences_.find(id); it != sequences_.end()) { - auto& sequence = it->second; return &it->second; } return nullptr; @@ -70,23 +61,24 @@ bool SequenceManager::Contains(uint64_t id) return sequences_.find(id) != sequences_.end(); } +void SequenceManager::Erase(std::map::iterator it) +{ + auto& seq = it->second; + if (seq.status == Sequence::kCached) { + const int count = block_manager_->Verify(seq.blocks, seq.block_unique_ids); + seq.blocks.resize(count); + } + else { + UpdateAndSetUnlock(seq); + } + freed_.insert(freed_.end(), seq.blocks.begin(), seq.blocks.end()); + sequences_.erase(it); +} + bool SequenceManager::Erase(uint64_t id) { if (auto it = sequences_.find(id); it != sequences_.end()) { - auto& seq = it->second; - if (seq.status != Sequence::kCached) { - unlocked_.insert(unlocked_.end(), seq.blocks.begin(), seq.blocks.end()); - freed_.insert(freed_.end(), seq.blocks.begin(), seq.blocks.end()); - } - else { - for (int i = 0; i < seq.blocks.size(); ++i) { - // filter invalidated blocks - if (seq.blocks[i]->unique_id == seq.block_unique_ids[i]) { - freed_.push_back(seq.blocks[i]); - } - } - } - sequences_.erase(it); + Erase(it); return true; } return false; @@ -94,28 +86,23 @@ bool SequenceManager::Erase(uint64_t id) void SequenceManager::VerifyAndLockCached(const Sequences& sequences) { - std::vector blocks; + BlockIds blocks; for (const auto& p : sequences) { auto& seq = const_cast(*p); if (seq.status != Sequence::kCached) { continue; } FT_CHECK(seq.blocks.size() == seq.block_unique_ids.size()); - if (need_verify_) { - for (int i = 0; i < seq.blocks.size(); ++i) { - if (seq.blocks[i]->unique_id != seq.block_unique_ids[i]) { - seq.blocks.resize(i); - seq.block_unique_ids.resize(i); - break; - } - } - } + // Verify cache blocks that may be invalidated + const int count = block_manager_->Verify(seq.blocks, seq.block_unique_ids); + seq.blocks.resize(count); + seq.block_unique_ids.resize(count); + blocks.insert(blocks.end(), seq.blocks.begin(), seq.blocks.end()); seq.cache_len = std::min(seq.cache_len, seq.blocks.size() * block_seq_len_); seq.status = Sequence::kLocked; } block_manager_->Lock(blocks); - need_verify_ = false; } void SequenceManager::CommitUnlockAndFree() @@ -177,8 +164,8 @@ struct Schedule { while (vidx < it_) { const auto& blocks = seqs[--it_]->blocks; int count = 0; - for (const auto& p : blocks) { - count += static_cast(--use_count_[p->id] == 0); + for (const auto& bid : blocks) { + count += static_cast(--use_count_[bid] == 0); } unlocked_[it_] = count; } @@ -354,21 +341,20 @@ 
std::vector SequenceManager::CountRequiredBlocks(const Sequences& se return required; } -void SequenceManager::AssignAndActivate(const Sequences& sequences, // - const std::vector& counts, - const std::vector& blocks) +void SequenceManager::AssignAndActivate(const Sequences& sequences, // + const std::vector& counts, + const BlockIds& blocks, + const UniqueIds& unique_ids) { FT_CHECK(sequences.size() == counts.size()); - auto first = blocks.begin(); + int first = 0; for (int i = 0; i < sequences.size(); ++i) { auto& s = const_cast(*sequences[i]); auto count = counts[i]; - // dbg(count); - auto last = first + count; - std::for_each(first, last, [&](const Block* b) { - s.blocks.push_back(b); - s.block_unique_ids.push_back(b->unique_id); - }); + int last = first + count; + FT_CHECK(last <= blocks.size()); + s.blocks.insert(s.blocks.end(), blocks.begin() + first, blocks.begin() + last); + s.block_unique_ids.insert(s.block_unique_ids.end(), unique_ids.begin() + first, unique_ids.begin() + last); s.status = Sequence::kActive; first = last; } @@ -448,16 +434,16 @@ auto SequenceManager::Materialize(Sequences sequences, // evict cached blocks -> free if (schedule.evict) { block_manager_->Evict(schedule.evict); - need_verify_ = true; } // allocate & assign blocks { - std::vector blocks; + BlockIds block_ids; + UniqueIds unique_ids; if (schedule.allocate) { - blocks = block_manager_->Allocate(schedule.allocate); + std::tie(block_ids, unique_ids) = block_manager_->Allocate(schedule.allocate); } - AssignAndActivate(schedule.active, schedule.block_counts, blocks); + AssignAndActivate(schedule.active, schedule.block_counts, block_ids, unique_ids); } // active -> locked @@ -467,6 +453,11 @@ auto SequenceManager::Materialize(Sequences sequences, } } + // TM_LOG_ERROR("active: %4d, cached: %4d, free: %4d", + // block_manager_->active_count(), + // block_manager_->cached_count(), + // block_manager_->free_count()); + return outcome; } diff --git a/src/turbomind/models/llama/SequenceManager.h b/src/turbomind/models/llama/SequenceManager.h index 0f27c03a3b..31a12d113b 100644 --- a/src/turbomind/models/llama/SequenceManager.h +++ b/src/turbomind/models/llama/SequenceManager.h @@ -19,8 +19,8 @@ struct Sequence { uint64_t id; Status status = kCached; - std::vector blocks; - std::vector block_unique_ids; + BlockIds blocks; + UniqueIds block_unique_ids; int input_length = 0; @@ -38,7 +38,7 @@ struct Sequence { mutable std::vector embedding_begins; mutable std::vector embedding_ends; - Sequence(uint64_t _id): id(_id) {} + explicit Sequence(uint64_t _id): id(_id) {} friend std::ostream& operator<<(std::ostream& os, const Sequence& seq); }; @@ -92,14 +92,14 @@ class SequenceManager { int step_length, AdjustInputCount adjust); - void* OffsetKey(void* block_ptr) + [[nodiscard]] void* GetKeyPtr(int block_id) { - return block_ptr; + return block_manager_->block(block_id).data; } - void* OffsetVal(void* block_ptr) + [[nodiscard]] void* GetValPtr(int block_id) { - return (std::byte*)block_ptr + val_offset_; + return (std::byte*)GetKeyPtr(block_id) + val_offset_; } int max_block_count() const noexcept @@ -108,6 +108,8 @@ class SequenceManager { } private: + void Erase(std::map::iterator it); + void CommitUnlockAndFree(); void VerifyAndLockCached(const Sequences& sequences); @@ -120,24 +122,23 @@ class SequenceManager { std::vector& context_lengths, const std::vector& priorities); - static void AssignAndActivate(const Sequences& sequences, // - const std::vector& block_counts, - const std::vector& blocks); + static void 
AssignAndActivate(const Sequences& sequences, // + const std::vector& counts, + const BlockIds& blocks, + const UniqueIds& unique_ids); private: int block_seq_len_; int rank_; size_t val_offset_{}; - bool need_verify_{}; - // Use `std::map` to avoid reference invalidation std::map sequences_; std::unique_ptr block_manager_; - std::vector unlocked_; - std::vector freed_; + BlockIds unlocked_; + BlockIds freed_; }; inline std::ostream& operator<<(std::ostream& os, const SequenceManager::Outcome& oc) diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc index 20974eeea9..358f5c04f6 100644 --- a/src/turbomind/models/llama/unified_decoder.cc +++ b/src/turbomind/models/llama/unified_decoder.cc @@ -15,8 +15,14 @@ void UnifiedDecoder::allocateBuffer(size_t num_token, size_t pf_batch_size, s TM_LOG_DEBUG(__PRETTY_FUNCTION__); if (pf_batch_size) { - attention_mask_ = - (T*)allocator_->reMalloc(attention_mask_, sizeof(T) * pf_batch_size * pf_max_q_len * pf_max_k_len, false); + if (need_causal_mask_) { + attention_mask_ = (T*)allocator_->reMalloc( + attention_mask_, sizeof(T) * pf_batch_size * pf_max_q_len * pf_max_k_len, false); + } + else { + // just to avoid nullptr + attention_mask_ = (T*)allocator_->reMalloc(attention_mask_, sizeof(T), false); + } padding_offset_ = (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * pf_batch_size * pf_max_q_len, false); cu_seqlens_ = (int*)allocator_->reMalloc(cu_seqlens_, sizeof(int) * (pf_batch_size + 1), false); @@ -162,14 +168,16 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con FT_CHECK(tmp_token_num == token_num - dc_batch_size); - invokeCreateCausalMasks(attention_mask_, - input_length + pf_offset, - context_length + pf_offset, - pf_max_q_len, - pf_max_k_len, - pf_batch_size, - stream_); - sync_check_cuda_error(); + if (need_causal_mask_) { + invokeCreateCausalMasks(attention_mask_, + input_length + pf_offset, + context_length + pf_offset, + pf_max_q_len, + pf_max_k_len, + pf_batch_size, + stream_); + sync_check_cuda_error(); + } } ///////////////////////////////////////////// diff --git a/src/turbomind/models/llama/unified_decoder.h b/src/turbomind/models/llama/unified_decoder.h index daac2b4df6..533976f947 100644 --- a/src/turbomind/models/llama/unified_decoder.h +++ b/src/turbomind/models/llama/unified_decoder.h @@ -5,6 +5,7 @@ #include "src/turbomind/models/llama/llama_params.h" #include "src/turbomind/models/llama/unified_attention_layer.h" #include "src/turbomind/utils/cublasMMWrapper.h" +#include "src/turbomind/utils/cuda_utils.h" #include "src/turbomind/utils/nccl_utils.h" namespace turbomind { @@ -46,6 +47,8 @@ class UnifiedDecoder { const DataType dtype_; + bool need_causal_mask_{false}; + using WeightType = LlamaDecoderLayerWeight; void forwardSelfAttn(T* attn_io, @@ -88,6 +91,14 @@ class UnifiedDecoder { tensor_para_(tensor_para), dtype_(getTensorType()) { +#ifdef _MSC_VER + // Both unfused MHA and flash attention 1 need causal mask + need_causal_mask_ = true; +#endif + // attention mask is not used for FA-1 (which requires sm80+ and half/bf16 data type) + if (!use_fmha || (getSMVersion() < 80 || sizeof(T) != 2)) { + need_causal_mask_ = true; + } initialize(attn_params, kv_head_num, use_fmha, cache_block_seq_len, quant_policy); }
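
Note on the per-token latency bookkeeping added in `benchmark/profile_throughput.py` and `benchmark/profile_generation.py`: the `stream_infer` iterator reports the cumulative generated-token count (which may jump by several tokens per step), so the profilers charge each iteration's `now - prev` gap to the first token of the newly generated chunk, leave the remaining tokens of the chunk at 0, then sort all per-token latencies (skipping the first-token latency) and index the 50/75/95/99th percentiles. The sketch below is a minimal, self-contained illustration of that bookkeeping under stated assumptions: `fake_stream` is a hypothetical stand-in for `stream_infer`, and the timings are synthetic, so it demonstrates the accounting only, not real model latency.

```python
import time

import numpy as np


def fake_stream(output_seqlen: int):
    """Hypothetical stand-in for `stream_infer`: yields the cumulative
    number of generated tokens, which may jump by more than 1 per step."""
    n_token = 0
    while n_token < output_seqlen:
        time.sleep(0.002)  # pretend decoding work
        n_token = min(n_token + int(np.random.randint(1, 4)), output_seqlen)
        yield n_token


def profile_one_request(output_seqlen: int):
    # one latency slot per token, mirroring `_per_token_latency_stats`
    per_token_latency = [0.0] * (output_seqlen + 1)
    prev = time.perf_counter()
    n_prev_token = 0
    for n_token in fake_stream(output_seqlen):
        now = time.perf_counter()
        if n_prev_token != n_token:
            # charge the whole iteration gap to the first new token of the
            # chunk; the rest of the chunk keeps latency 0
            per_token_latency[n_prev_token] = round(now - prev, 3)
            n_prev_token = n_token
        prev = now
    first_token_latency = per_token_latency[0]
    # per-token stats skip the first-token latency
    return first_token_latency, per_token_latency[1:]


if __name__ == '__main__':
    all_latencies = []
    for _ in range(8):  # pretend 8 finished requests
        _, latencies = profile_one_request(output_seqlen=32)
        all_latencies.extend(latencies)
    all_latencies.sort()
    percentiles = [
        round(all_latencies[int(p * len(all_latencies))], 3)
        for p in (0.5, 0.75, 0.95, 0.99)
    ]
    print('per-token latency(s) percentile(50, 75, 95, 99):', percentiles)
```

Because the non-leading tokens of each chunk are recorded as 0, the lower percentiles understate the true per-token latency when chunks are large; this is the work-around described in the profiler's docstring rather than an exact per-token measurement.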