From 32cc6ce474a7c624f09a24dbf0eac308fb319c10 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Mon, 18 Dec 2023 16:24:20 +0800 Subject: [PATCH] Unified paging (#860) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * change 'model_format' to 'qwen' when 'model_name' starts with 'qwen' (#575) * avoid split chinese characters during decoding (#566) * add solar chat template (#576) * robust incremental decode for leading space (#581) * robust incremental decode for leading space * speed up lookup as prefix_space_tokens is shorter than no_prefix_space_tokens * add UT and fix qwen stuff * update solar chat template (#587) * Revert "[Docs] Simplify `build.md` (#370)" (#586) This reverts commit 4b5c2bda074eb4ac2e70c3c793fb5ef48f87d9c8. * Fix crash and remove `sys_instruct` from `chat.py` and `client.py`(#591) * fix crash * update profile_generation.py * format * use self.bos_id * remove sys_instruct * bump version to v0.0.12 (#604) * Add "build from docker" section (#602) * add build from docker section * update * install python package * update * update * update * Add more user-friendly CLI (#541) * add * import fire in main * wrap to speed up fire cli * update * update docs * update docs * fix * resolve commennts * resolve confict and add test for cli * support inference a batch of prompts (#467) * support inference a batch of prompts * docstring and assert * bump version to v0.0.13 (#620) * Improve api_server and webui usage (#544) * make IPv6 compatible, safe run for coroutine interrupting * instance_id -> session_id and fix api_client.py * update doc * remove useless faq * safe ip mapping * update app.py * WIP completion * completion * update doc * disable interactive mode for /v1/chat/completions * docstring * docstring * refactor gradio * update gradio * udpate * update doc * rename * session_id default -1 * missed two files * add a APIClient * add chat func for APIClient * refine * add concurrent function * sequence_start, sequence_end --> interactive_mode * update doc * comments * doc * better text completion * remove /v1/embeddings * comments * deprecate generate and use /v1/interactive/completions * /v1/interactive/completion -> /v1/chat/interactive * embeddings * rename * remove wrong arg description * docstring * fix * update cli * update doc * strict session_len limit condition * pass model args to api_server * fix: gradio gr.Button.update deprecated after 4.0.0 (#637) * add cli to list the supported model names (#639) * update * resolve comment * Refactor model conversion (#296) * split deploy.py * fix get_cuda_tensor * deploy qwen_awq * fix lint * add docstring * fix * support baichuan/baichuan-awq * parameterizing size_per_head * remove try/except * limit input model_format * add quant_path param * remove old deploy.py * fix path * fix transformer layer range when load bins * fix qwen init * split & save log * relative import * update get_config * WeightFileMgr -> Reader * rename * update * fix init_layer_id * rename llama.py -> meta_llama.py, hf.py -> llama.py * reduce code * update arg description * fix meta llama * manually cleanup meta model params * [Enchance] internlm message to prompt (#499) * update turbomind session_len with model.session_len (#634) * [Fix] Qwen's quantization results are abnormal & Baichuan cannot be quantized (#605) * fix awq * adapt new qwen code * adapt qwen 14b and baichuan2 7b * add docstring * add runtime error for qwen * FIX: fix stop_session func bug (#578) * FIX: fix stop_session func bug * keep sequence_end 
= False --------- Co-authored-by: honglei.yan Co-authored-by: AllentDan * Manage session id using random int for gradio local mode (#553) * Use session id from gradio state * use a new session id after reset * rename session id like a state * update comments * reformat files * init session id on block loaded * use auto increased session id * remove session id textbox * apply to api_server and tritonserver * update docstring * add lock for safety --------- Co-authored-by: AllentDan * fix benchmark serving computation mistake (#630) * fix benchmark serving computation mistake * fix timestamps computations * remove speed up * no mp * mp seems faster? * remove * update * remove * fix * update * update print log * typo * print fist token latency only stream==True * remove renew_session * update AsyncEngine * fix tokenizer_info when convert the model (#661) * Add check env sub command (#654) * add check env * update issue template' * remove some reqs from check env * resolve comment * fix Tokenizer load error when the path of the being-converted model is not writable (#669) * Add UltraCM and WizardLM chat templates (#599) * add ultracm eval chat template * add WizardLM chat template * use ultrachat template instead of ultracm usecase * bump version to v0.0.14 (#663) * Add extra_requires to reduce dependencies (#580) * update reqs * update docs * resolve comments * upgrade pydantic * fix rebase * update doc * update * update * update readme * update * add flash-attn * TurboMind 2 (#590) * refresh decoder attention kernel * block-level kv cache * `BlockManager` & `SequenceManager` * update * update * update * update * rename * GQA support * fix context length * GQA dispatch * kv8 * tune * async stream cb * nvtx * config parsing * debug * optimize output cost * split-k decoding * minor * truncate `session_len` by available blocks * minor * license * fix * dispatch `cp.async` * fix linking * fix * fix deadlock * guard input length * correct start offset * fix prefill chunking * fix `cache_block_seq_len` param passing * fix `block_size` fmtstr * fix output tokens * fix batch resizing * fix masking of finished sequences * add debug util * free unused block early * add ntk scaling and logn scaling * cmake flags * fix typo * w4a16 for sm75 * fix msvc build * fix msvc build * fix block verification * fix msvc build * use `std::shuffle` * fix lint * fix lint * fix lint * clear incoming buffer * clear finished requests * fix batch initialization * fix typo * fix typo * fix comparison * [Docs] Update Supported Matrix (#679) * update supported matrix * change the default shard size when saving quantized weights * baichuan2 kv8 * update kv8 docs (#681) * Fix init of batch state (#682) * fix init of finished buf * fix `finished_count` * fix turbomind stream canceling (#686) * fix * instance for each forward * [Fix] Fix load_checkpoint_in_model bug (#690) * fix load_checkpoint_in_model bug * fix comments * fix comments * fix bugs * [Doc] Update restful api doc (#662) * update restful_api.md * add a hint * repeat 3 time * Fix Tokenizer encode (#645) * same encode with HF * sequence_start -> add_bos * complement * Fix wrong eos_id and bos_id obtained through grpc api (#644) * Fix wrong eos_id and bos_id obtained through grpc api * fix according to review comments * update * Optimize for throughput (#701) * tmp * update * update * optimize for throughput * update * fix eos * clean up * fix serving * fix indexed copy * minor * minor --------- Co-authored-by: lvhan028 * Check-in user guide about turbomind config 
(#680) * update * update config guide * update guide * upate user guide according to review comments * Replace mmengine with mmengine-lite (#715) * Support loading hf model directly (#685) * turbomind support export model params * fix overflow * support turbomind.from_pretrained * fix tp * support AutoModel * support load kv qparams * update auto_awq * udpate docstring * export lmdeploy version * update doc * remove download_hf_repo * LmdeployForCausalLM -> LmdeployForCausalLM * refactor turbomind.py * update comment * add bfloat16 convert back * support gradio run_locl load hf * support resuful api server load hf * add docs * support loading previous quantized model * adapt pr 690 * udpate docs * not export turbomind config when quantize a model * check model_name when can not get it from config.json * update readme * remove model_name in auto_awq * update * update * udpate * fix build * absolute import * Fix cache/output length calculation (#738) * bump version to v0.1.0a0 (#709) * [Fix] Skip empty batch (#747) * [Fix] build docker image failed since `packaging` is missing (#753) * [Fix] Rollback the data type of input_ids to TYPE_UINT32 in preprocessor's proto (#758) * Set the default value of `max_context_token_num` 1 (#761) * rename pytorch poc * fix lint * add docstring * add docstring * refactor patch * add recompute eviction support * fix typo (#769) * add triton server test and workflow yml (#760) * add triton server test and workflow yml * update * revert changes in dockerfile * update prompts * recovery modeling * fix turbomind build on sm<80 (#754) * fix * fix lint * improvement(build): enable ninja and gold linker (#767) * feat(build): enable ninja and lld * fix(.github): add ninja installation * fix(CI): remove dimsize=256 * fix(CI): add option for generate.sh * fix(docs): update * Report first-token-latency and token-latency percentiles (#736) * update profile scripts * add top_p, top_k and temperature as input arguments * fix input_ids * update profile_throughput * update profile_restful_api * update profile_serving * update * update * add progress bar * remove TODO comments * update * remove useless profile_* argument * remove log level * change concurrency default value to 64 * update restful_api.md * update according to review comments * fix docstring * convert model with hf repo_id (#774) * bump version to 0.1.0a1 (#776) * Update benchmark user guide (#763) * user guide of benchmark generation * update benchmark generation guide * update profiling throughput guide * update profiling api_server guide * rename file names * update profile tis user guide * update * fix according to review comments * update * update according to review comments * updaste * add an example * update * add docstring * add unified paging attention support * refactor block manager * do not alloc zero * Fix early exit condition in attention kernel (#788) * add chat template for Yi (#779) * Fix missed arguments when benchmark static inference performance (#787) * minor fix in the profile scripts and docs * miss arguments * typo * fix lint * update * Unify prefill & decode passes (#775) * Unify prefill and decode passes * dynamic split-fuse * refactor * correct input count calculation * remove unused * lint * lint * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * add cuda12.1 build check ci (#782) * update cuda12.1 build check ci * use matrix * auto upload cuda12.1 python pkg to release when 
create new tag (#784) * add cuda12-whl-release ci * enable environment * test py310-311 windows wheel * fix py310, py311 setup.py error on windows * fix lint * fix extra colon in InternLMChat7B (#796) * fix local kv head num (#806) * Report the inference benchmark of models with different size (#794) * update test scripts for models with different sizes * update * only test after tunning gemm * chmod +x * fix typo * benchmark on a100 * fix typo * fix typo * per-token latency percentile in profile_throughput * fix * fix * rename * make the script accept parameters * minor fix * indent * reformat table * change to 3000 * minor fix * bump version to v0.1.0a2 (#807) * fix out of bounds access (#809) * update scheduler * optimize request * Simplify block manager (#812) * simplify block manager * fix lint * set smem size for repetition penalty kernel (#818) * add mbgemm&mbgemv * fix recompute, fix mbgmm --------- Co-authored-by: Lyu Han Co-authored-by: AllentDan <41138331+AllentDan@users.noreply.github.com> Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com> Co-authored-by: Chen Xin Co-authored-by: RunningLeon Co-authored-by: Yam(长琴) Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Co-authored-by: yunzhongyan0 <549713537@qq.com> Co-authored-by: honglei.yan Co-authored-by: AllentDan Co-authored-by: aisensiy Co-authored-by: Li Zhang Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: tpoisonooo Co-authored-by: Qian Zhao <112053249+C1rN09@users.noreply.github.com> --- .github/ISSUE_TEMPLATE/1-bug-report.yml | 12 + .github/scripts/test_triton_server.py | 73 + .github/workflows/cuda12.1-whl-release.yml | 108 + .github/workflows/linux-x64-gpu.yml | 12 +- .github/workflows/pypi.yml | 2 + .github/workflows/test.yml | 215 ++ .github/workflows/windows-x64-gpu.yml | 8 +- .gitignore | 3 + CMakeLists.txt | 37 +- README.md | 85 +- README_zh-CN.md | 84 +- benchmark/README.md | 4 +- benchmark/benchmark_13b.sh | 92 + benchmark/benchmark_20b.sh | 92 + benchmark/benchmark_70b.sh | 71 + benchmark/benchmark_7b.sh | 93 + benchmark/profile_generation.py | 402 ++-- benchmark/profile_restful_api.py | 411 ++-- benchmark/profile_serving.py | 386 ++-- benchmark/profile_throughput.py | 286 ++- builder/manywheel/build_all_wheel.sh | 4 +- builder/manywheel/entrypoint_build.sh | 4 +- docker/Dockerfile | 2 +- docs/en/benchmark/a100_fp16.md | 134 ++ docs/en/benchmark/profile_api_server.md | 107 + docs/en/benchmark/profile_generation.md | 88 + docs/en/benchmark/profile_throughput.md | 105 + docs/en/benchmark/profile_triton_server.md | 107 + docs/en/build.md | 82 +- docs/en/faq.md | 4 +- docs/en/kv_int8.md | 13 +- docs/en/load_hf.md | 71 + docs/en/pytorch.md | 6 +- docs/en/restful_api.md | 139 +- docs/en/serving.md | 18 +- docs/en/supported_models/codellama.md | 22 +- docs/en/turbomind.md | 2 +- docs/en/turbomind_config.md | 200 ++ docs/en/w4a16.md | 12 +- docs/zh_cn/benchmark/profile_api_server.md | 105 + docs/zh_cn/benchmark/profile_generation.md | 87 + docs/zh_cn/benchmark/profile_throughput.md | 105 + docs/zh_cn/benchmark/profile_triton_server.md | 109 + docs/zh_cn/build.md | 81 +- docs/zh_cn/faq.md | 4 +- docs/zh_cn/kv_int8.md | 13 +- docs/zh_cn/load_hf.md | 72 + docs/zh_cn/restful_api.md | 137 +- docs/zh_cn/serving.md | 18 +- docs/zh_cn/supported_models/codellama.md | 24 +- docs/zh_cn/turbomind.md | 2 +- docs/zh_cn/turbomind_config.md | 202 ++ docs/zh_cn/w4a16.md | 12 +- 
examples/cpp/llama/llama_triton_example.cc | 39 +- generate.sh | 8 +- lmdeploy/cli/__init__.py | 4 + lmdeploy/cli/chat.py | 90 + lmdeploy/cli/cli.py | 139 ++ lmdeploy/cli/lite.py | 100 + lmdeploy/cli/serve.py | 123 + lmdeploy/legacy/pytorch/chat.py | 3 +- lmdeploy/legacy/pytorch/modules/linear.py | 46 +- lmdeploy/lite/apis/auto_awq.py | 51 +- lmdeploy/lite/apis/calibrate.py | 130 +- lmdeploy/lite/apis/kv_qparams.py | 44 +- lmdeploy/lite/quantization/awq.py | 6 +- .../lite/quantization/weight/quantizer.py | 15 +- lmdeploy/lite/utils/__init__.py | 7 +- lmdeploy/lite/utils/cal_qparams.py | 48 +- lmdeploy/lite/utils/collect.py | 4 - lmdeploy/lite/utils/export_turbomind.py | 70 + lmdeploy/lite/utils/load.py | 55 + lmdeploy/model.py | 283 ++- lmdeploy/pytorch/block.py | 106 + lmdeploy/pytorch/chat.py | 6 +- lmdeploy/pytorch/config.py | 1 + lmdeploy/pytorch/engine/engine.py | 99 +- lmdeploy/pytorch/engine/model_agent.py | 61 +- lmdeploy/pytorch/engine/request.py | 49 +- .../pytorch/kernels/apply_rotary_pos_emb.py | 1 + lmdeploy/pytorch/kernels/fill_kv_cache.py | 164 +- lmdeploy/pytorch/kernels/mbgmm.py | 267 +++ lmdeploy/pytorch/kernels/mbgmv.py | 223 ++ lmdeploy/pytorch/kernels/pagedattention.py | 76 +- lmdeploy/pytorch/kernels/rerope_attention.py | 2 +- lmdeploy/pytorch/kernels/rms_norm.py | 24 +- lmdeploy/pytorch/messages.py | 91 +- lmdeploy/pytorch/models/chatglm2.py | 4 +- lmdeploy/pytorch/models/falcon.py | 3 +- lmdeploy/pytorch/models/functional.py | 3 +- lmdeploy/pytorch/paging/__init__.py | 4 +- lmdeploy/pytorch/paging/block_manager.py | 520 +++-- .../eviction_helper/base_eviction_helper.py | 30 +- .../eviction_helper/copy_eviction_helper.py | 57 +- .../recompute_eviction_helper.py | 32 +- lmdeploy/pytorch/paging/scheduler.py | 184 +- lmdeploy/serve/async_engine.py | 244 +- lmdeploy/serve/client.py | 9 +- lmdeploy/serve/gradio/api_server_backend.py | 186 ++ lmdeploy/serve/gradio/app.py | 535 +---- lmdeploy/serve/gradio/constants.py | 28 + lmdeploy/serve/gradio/css.py | 18 - .../serve/gradio/triton_server_backend.py | 143 ++ lmdeploy/serve/gradio/turbomind_coupled.py | 194 ++ lmdeploy/serve/openai/api_client.py | 342 ++- lmdeploy/serve/openai/api_server.py | 262 ++- lmdeploy/serve/openai/protocol.py | 11 +- lmdeploy/serve/turbomind/chatbot.py | 44 +- lmdeploy/serve/turbomind/deploy.py | 1046 --------- .../triton_models/preprocessing/1/model.py | 30 +- .../triton_models/preprocessing/config.pbtxt | 37 - lmdeploy/serve/turbomind/utils.py | 6 +- lmdeploy/tokenizer.py | 92 +- lmdeploy/turbomind/chat.py | 91 +- lmdeploy/turbomind/decode.py | 3 +- lmdeploy/turbomind/deploy/__init__.py | 1 + lmdeploy/turbomind/deploy/converter.py | 260 +++ .../turbomind/deploy/source_model/__init__.py | 8 + .../turbomind/deploy/source_model/baichuan.py | 67 + .../deploy/source_model/baichuan_awq.py | 87 + .../turbomind/deploy/source_model/base.py | 175 ++ .../turbomind/deploy/source_model/llama.py | 198 ++ .../deploy/source_model/llama_awq.py | 68 + .../deploy/source_model/meta_llama.py | 224 ++ .../turbomind/deploy/source_model/qwen.py | 113 + .../turbomind/deploy/source_model/qwen_awq.py | 58 + .../turbomind/deploy/target_model/__init__.py | 3 + .../turbomind/deploy/target_model/base.py | 269 +++ lmdeploy/turbomind/deploy/target_model/fp.py | 80 + lmdeploy/turbomind/deploy/target_model/w4.py | 162 ++ lmdeploy/turbomind/generate_gemm_config.py | 4 +- lmdeploy/turbomind/hf_repo/config.json | 11 + .../hf_repo/configuration_lmdeploy.py | 36 + .../turbomind/hf_repo/modeling_lmdeploy.py | 226 ++ 
lmdeploy/turbomind/turbomind.py | 369 ++- lmdeploy/turbomind/utils.py | 120 + lmdeploy/version.py | 2 +- requirements.txt | 23 +- requirements/build.txt | 2 + requirements/lite.txt | 3 + requirements/readthedocs.txt | 2 +- requirements/runtime.txt | 9 + requirements/serve.txt | 5 + requirements/test.txt | 5 + setup.py | 56 +- src/turbomind/kernels/CMakeLists.txt | 1 + .../kernels/bert_preprocess_kernels.cu | 16 +- .../decoder_masked_multihead_attention_128.cu | 6 + ...er_masked_multihead_attention_template.cuh | 11 + .../CMakeLists.txt | 16 + .../decoder_multihead_attention/array_ops.h | 490 ++++ .../decoder_multihead_attention.cu | 115 + .../decoder_multihead_attention.h | 12 + .../decoder_multihead_attention_params.h | 69 + .../decoder_multihead_attention_template.h | 932 ++++++++ .../decoder_multihead_attention/iterator.h | 333 +++ .../decoder_multihead_attention/kv_cache.cu | 481 ++++ .../decoder_multihead_attention/kv_cache.h | 67 + .../test_decoder_multihead_attention.cu | 328 +++ .../decoder_multihead_attention/test_utils.cu | 252 +++ .../decoder_multihead_attention/test_utils.h | 43 + .../decoder_multihead_attention/thread_map.h | 98 + src/turbomind/kernels/gemm_s_f16/common.h | 83 +- .../kernels/gemm_s_f16/cta_iterator.h | 19 + .../kernels/gemm_s_f16/gemm_template.h | 22 +- .../kernels/sampling_penalty_kernels.cu | 6 + .../kernels/unfused_attention_kernels.cu | 81 +- .../kernels/unfused_attention_kernels.h | 3 +- src/turbomind/layers/DynamicDecodeLayer.cc | 4 +- src/turbomind/layers/DynamicDecodeLayer.h | 9 - .../sampling_layers/BaseSamplingLayer.cc | 32 - .../sampling_layers/BaseSamplingLayer.h | 11 +- .../sampling_layers/TopKSamplingLayer.cu | 4 +- .../sampling_layers/TopKSamplingLayer.h | 2 - .../sampling_layers/TopPSamplingLayer.cu | 5 +- .../sampling_layers/TopPSamplingLayer.h | 2 - src/turbomind/models/llama/Barrier.h | 17 +- src/turbomind/models/llama/BlockManager.cc | 286 +++ src/turbomind/models/llama/BlockManager.h | 153 ++ src/turbomind/models/llama/CMakeLists.txt | 17 +- src/turbomind/models/llama/LlamaBatch.cc | 1972 ++++++++++------- src/turbomind/models/llama/LlamaBatch.h | 306 ++- .../models/llama/LlamaCacheManager.cc | 192 -- .../models/llama/LlamaCacheManager.h | 102 - .../llama/LlamaContextAttentionLayer.cc | 423 ---- .../models/llama/LlamaContextAttentionLayer.h | 130 -- .../models/llama/LlamaContextDecoder.cc | 290 --- .../models/llama/LlamaContextDecoder.h | 112 - src/turbomind/models/llama/LlamaDecoder.cc | 247 --- src/turbomind/models/llama/LlamaDecoder.h | 96 - .../models/llama/LlamaDecoderLayerWeight.cc | 73 +- .../models/llama/LlamaDecoderLayerWeight.h | 3 + .../llama/LlamaDecoderSelfAttentionLayer.cc | 309 --- .../llama/LlamaDecoderSelfAttentionLayer.h | 96 - src/turbomind/models/llama/LlamaFfnLayer.cc | 26 +- src/turbomind/models/llama/LlamaV2.cc | 333 +-- src/turbomind/models/llama/LlamaV2.h | 100 +- src/turbomind/models/llama/LlamaWeight.cc | 29 + src/turbomind/models/llama/LlamaWeight.h | 2 + src/turbomind/models/llama/Request.h | 30 +- src/turbomind/models/llama/SequenceManager.cc | 466 ++++ src/turbomind/models/llama/SequenceManager.h | 147 ++ src/turbomind/models/llama/copy.h | 37 + .../llama/flash_attention2/CMakeLists.txt | 6 +- .../flash_fwd_launch_template.h | 10 + .../llama/flash_attention2/static_switch.h | 8 + .../models/llama/llama_decoder_kernels.cu | 2 + src/turbomind/models/llama/llama_kernels.cu | 961 ++++---- src/turbomind/models/llama/llama_kernels.h | 63 +- src/turbomind/models/llama/llama_params.h | 24 +- 
src/turbomind/models/llama/llama_utils.cu | 9 + src/turbomind/models/llama/llama_utils.h | 15 + .../models/llama/test_cache_manager.cc | 116 + .../models/llama/unified_attention_layer.cc | 630 ++++++ .../models/llama/unified_attention_layer.h | 186 ++ src/turbomind/models/llama/unified_decoder.cc | 257 +++ src/turbomind/models/llama/unified_decoder.h | 99 + src/turbomind/python/bind.cpp | 36 +- .../triton_backend/libfastertransformer.cc | 2 +- .../triton_backend/llama/LlamaTritonModel.cc | 205 +- .../triton_backend/llama/LlamaTritonModel.h | 13 +- .../llama/LlamaTritonModelInstance.cc | 69 +- .../transformer_triton_backend.hpp | 5 + src/turbomind/triton_backend/triton_utils.hpp | 1 + src/turbomind/utils/cuda_utils.h | 2 +- src/turbomind/utils/debug_utils.h | 7 + src/turbomind/utils/dispatch.h | 35 + tests/pytorch/kernel/test_mbgmm.py | 122 + tests/pytorch/kernel/test_mbgmv.py | 112 + tests/pytorch/kernel/test_paged_attention.py | 10 +- tests/pytorch/paging/test_block_manager.py | 165 +- tests/pytorch/paging/test_scheduler.py | 21 +- tests/test_lmdeploy/test_cli.py | 51 + tests/test_lmdeploy/test_tokenizer.py | 24 + 234 files changed, 18923 insertions(+), 7696 deletions(-) create mode 100644 .github/scripts/test_triton_server.py create mode 100644 .github/workflows/cuda12.1-whl-release.yml create mode 100644 .github/workflows/test.yml create mode 100755 benchmark/benchmark_13b.sh create mode 100755 benchmark/benchmark_20b.sh create mode 100755 benchmark/benchmark_70b.sh create mode 100755 benchmark/benchmark_7b.sh create mode 100644 docs/en/benchmark/a100_fp16.md create mode 100644 docs/en/benchmark/profile_api_server.md create mode 100644 docs/en/benchmark/profile_generation.md create mode 100644 docs/en/benchmark/profile_throughput.md create mode 100644 docs/en/benchmark/profile_triton_server.md create mode 100644 docs/en/load_hf.md create mode 100644 docs/en/turbomind_config.md create mode 100644 docs/zh_cn/benchmark/profile_api_server.md create mode 100644 docs/zh_cn/benchmark/profile_generation.md create mode 100644 docs/zh_cn/benchmark/profile_throughput.md create mode 100644 docs/zh_cn/benchmark/profile_triton_server.md create mode 100644 docs/zh_cn/load_hf.md create mode 100644 docs/zh_cn/turbomind_config.md create mode 100644 lmdeploy/cli/__init__.py create mode 100644 lmdeploy/cli/chat.py create mode 100644 lmdeploy/cli/cli.py create mode 100644 lmdeploy/cli/lite.py create mode 100644 lmdeploy/cli/serve.py create mode 100644 lmdeploy/lite/utils/export_turbomind.py create mode 100644 lmdeploy/lite/utils/load.py create mode 100644 lmdeploy/pytorch/kernels/mbgmm.py create mode 100644 lmdeploy/pytorch/kernels/mbgmv.py create mode 100644 lmdeploy/serve/gradio/api_server_backend.py create mode 100644 lmdeploy/serve/gradio/constants.py delete mode 100644 lmdeploy/serve/gradio/css.py create mode 100644 lmdeploy/serve/gradio/triton_server_backend.py create mode 100644 lmdeploy/serve/gradio/turbomind_coupled.py delete mode 100644 lmdeploy/serve/turbomind/deploy.py create mode 100644 lmdeploy/turbomind/deploy/__init__.py create mode 100644 lmdeploy/turbomind/deploy/converter.py create mode 100644 lmdeploy/turbomind/deploy/source_model/__init__.py create mode 100644 lmdeploy/turbomind/deploy/source_model/baichuan.py create mode 100644 lmdeploy/turbomind/deploy/source_model/baichuan_awq.py create mode 100644 lmdeploy/turbomind/deploy/source_model/base.py create mode 100644 lmdeploy/turbomind/deploy/source_model/llama.py create mode 100644 lmdeploy/turbomind/deploy/source_model/llama_awq.py 
create mode 100644 lmdeploy/turbomind/deploy/source_model/meta_llama.py create mode 100644 lmdeploy/turbomind/deploy/source_model/qwen.py create mode 100644 lmdeploy/turbomind/deploy/source_model/qwen_awq.py create mode 100644 lmdeploy/turbomind/deploy/target_model/__init__.py create mode 100644 lmdeploy/turbomind/deploy/target_model/base.py create mode 100644 lmdeploy/turbomind/deploy/target_model/fp.py create mode 100644 lmdeploy/turbomind/deploy/target_model/w4.py create mode 100644 lmdeploy/turbomind/hf_repo/config.json create mode 100644 lmdeploy/turbomind/hf_repo/configuration_lmdeploy.py create mode 100644 lmdeploy/turbomind/hf_repo/modeling_lmdeploy.py create mode 100644 lmdeploy/turbomind/utils.py create mode 100644 requirements/build.txt create mode 100644 requirements/lite.txt create mode 100644 requirements/runtime.txt create mode 100644 requirements/serve.txt create mode 100644 requirements/test.txt create mode 100644 src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt create mode 100644 src/turbomind/kernels/decoder_multihead_attention/array_ops.h create mode 100644 src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu create mode 100644 src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h create mode 100644 src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h create mode 100644 src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h create mode 100644 src/turbomind/kernels/decoder_multihead_attention/iterator.h create mode 100644 src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu create mode 100644 src/turbomind/kernels/decoder_multihead_attention/kv_cache.h create mode 100644 src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu create mode 100644 src/turbomind/kernels/decoder_multihead_attention/test_utils.cu create mode 100644 src/turbomind/kernels/decoder_multihead_attention/test_utils.h create mode 100644 src/turbomind/kernels/decoder_multihead_attention/thread_map.h create mode 100644 src/turbomind/models/llama/BlockManager.cc create mode 100644 src/turbomind/models/llama/BlockManager.h delete mode 100644 src/turbomind/models/llama/LlamaCacheManager.cc delete mode 100644 src/turbomind/models/llama/LlamaCacheManager.h delete mode 100644 src/turbomind/models/llama/LlamaContextAttentionLayer.cc delete mode 100644 src/turbomind/models/llama/LlamaContextAttentionLayer.h delete mode 100644 src/turbomind/models/llama/LlamaContextDecoder.cc delete mode 100644 src/turbomind/models/llama/LlamaContextDecoder.h delete mode 100644 src/turbomind/models/llama/LlamaDecoder.cc delete mode 100644 src/turbomind/models/llama/LlamaDecoder.h delete mode 100644 src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc delete mode 100644 src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h create mode 100644 src/turbomind/models/llama/SequenceManager.cc create mode 100644 src/turbomind/models/llama/SequenceManager.h create mode 100644 src/turbomind/models/llama/copy.h create mode 100644 src/turbomind/models/llama/test_cache_manager.cc create mode 100644 src/turbomind/models/llama/unified_attention_layer.cc create mode 100644 src/turbomind/models/llama/unified_attention_layer.h create mode 100644 src/turbomind/models/llama/unified_decoder.cc create mode 100644 src/turbomind/models/llama/unified_decoder.h create mode 100644 src/turbomind/utils/debug_utils.h create mode 100644 src/turbomind/utils/dispatch.h 
create mode 100644 tests/pytorch/kernel/test_mbgmm.py create mode 100644 tests/pytorch/kernel/test_mbgmv.py create mode 100644 tests/test_lmdeploy/test_cli.py create mode 100644 tests/test_lmdeploy/test_tokenizer.py diff --git a/.github/ISSUE_TEMPLATE/1-bug-report.yml b/.github/ISSUE_TEMPLATE/1-bug-report.yml index 86838836de..d9e6956735 100644 --- a/.github/ISSUE_TEMPLATE/1-bug-report.yml +++ b/.github/ISSUE_TEMPLATE/1-bug-report.yml @@ -25,6 +25,18 @@ body: A placeholder for the command. validations: required: true +- type: textarea + attributes: + label: Environment + description: | + 1. Please run `lmdeploy check_env` to collect necessary environment information and paste it here. + 2. You may add addition that may be helpful for locating the problem, such as + - How you installed PyTorch \[e.g., pip, conda, source\] + - Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.) + placeholder: Environment here. + render: Shell + validations: + required: true - type: textarea attributes: label: Error traceback diff --git a/.github/scripts/test_triton_server.py b/.github/scripts/test_triton_server.py new file mode 100644 index 0000000000..a5146b150c --- /dev/null +++ b/.github/scripts/test_triton_server.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import time +from subprocess import PIPE, Popen + +import fire + + +def parse_dialogue(inputs: str): + sep = 'double enter to end input >>>' + dialogues = inputs.strip() + if dialogues.endswith(sep): + dialogues = dialogues[:-len(sep)] + dialogues = dialogues.strip() + dialogues = dialogues.split(sep) + dialogues = [d.strip() for d in dialogues] + return dialogues[1:] + + +def test(port=33337): + cmd = [f'lmdeploy serve triton_client localhost:{port}'] + + test_cases = [ + dict( + prompts='Hello! Please answer in English.', + keywords=['Hello', 'hi'], + ), + dict( + prompts='您好! 请用中文回答。', + keywords=['您好', '你好'], + ), + dict( + prompts='How many days does a week have? ', + keywords=['seven', '7'], + ), + dict( + prompts='一周有多少天?请用中文回答。', + keywords=['七天', '7天'], + ), + ] + + sep = '\n\n' + end = sep + 'exit\n\n\n' + all_pass = True + for cases in test_cases: + quest = cases['prompts'] + keywords = cases['keywords'] + inputs = quest + end + print(f'Test Input prompts: {quest}\nKey words: {keywords}') + time.sleep(5) + + with Popen(cmd, + stdin=PIPE, + stdout=PIPE, + stderr=PIPE, + shell=True, + text=True, + encoding='utf-8') as proc: + out, err = proc.communicate(input=inputs) + print(f'Output: {out}') + if proc.returncode == 0: + out = parse_dialogue(out)[0] + success = any([k in out for k in keywords]) + if not success: + print(f'>>> Failed to output keywords: {out} {keywords}') + all_pass = False + else: + all_pass = False + print(f'Failed to get outputs: {out} {err}') + assert all_pass, 'Tests failed!' 
+ + +if __name__ == '__main__': + fire.Fire(test) diff --git a/.github/workflows/cuda12.1-whl-release.yml b/.github/workflows/cuda12.1-whl-release.yml new file mode 100644 index 0000000000..4695a3a5d2 --- /dev/null +++ b/.github/workflows/cuda12.1-whl-release.yml @@ -0,0 +1,108 @@ +name: cuda12.1-whl-release + +on: + push: + tags: + - '*' + workflow_dispatch: + +permissions: + contents: write + +jobs: + linux-build: + strategy: + matrix: + pyver: [py38, py39, py310, py311] + runs-on: ubuntu-latest + env: + PYTHON_VERSION: ${{ matrix.pyver }} + PLAT_NAME: manylinux2014_x86_64 + DOCKER_TAG: cuda12.1 + OUTPUT_FOLDER: cuda12.1_dist + CUDA_VER: 12.1 + steps: + - name: Free disk space + uses: jlumbroso/free-disk-space@main + with: + # This might remove tools that are actually needed, if set to "true" but frees about 6 GB + tool-cache: false + docker-images: false + # All of these default to true, but feel free to set to "false" if necessary for your workflow + android: true + dotnet: true + haskell: true + large-packages: true + swap-storage: false + - name: Checkout repository + uses: actions/checkout@v3 + - name: Build + run: | + echo ${PYTHON_VERSION} + echo ${PLAT_NAME} + echo ${DOCKER_TAG} + echo ${OUTPUT_FOLDER} + # remove -it + sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh + bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER} + - name: Upload Artifacts + uses: actions/upload-artifact@v3 + with: + if-no-files-found: error + path: builder/manywheel/${{ env.OUTPUT_FOLDER }}/* + retention-days: 1 + + windows-build: + strategy: + matrix: + pyver: ['3.8', '3.9', '3.10', '3.11'] + runs-on: windows-latest + steps: + - name: Checkout repository + uses: actions/checkout@v3 + - name: Set up python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.pyver }} + - name: Install python packages + run: | + pip install pybind11 wheel + - uses: Jimver/cuda-toolkit@v0.2.11 + id: cuda-toolkit + with: + cuda: '12.1.0' + use-github-cache: false + - name: Build wheel + run: | + mkdir build + cd build + pip install -U setuptools + ..\builder\windows\generate.ps1 + cmake --build . --config Release -- /m > build.log.txt + cmake --install . --config Release + cd .. 
+ rm build -Force -Recurse + python setup.py bdist_wheel -d build/wheel + - name: Upload Artifacts + uses: actions/upload-artifact@v3 + with: + if-no-files-found: error + path: build/wheel/* + retention-days: 1 + + publish: + runs-on: ubuntu-latest + environment: 'prod' + needs: + - linux-build + - windows-build + steps: + - name: Download artifacts + uses: actions/download-artifact@v3 + - name: Display artifacts + run: ls artifact/ -lh + - name: Publish + uses: softprops/action-gh-release@v1 + if: startsWith(github.ref, 'refs/tags/') + with: + files: artifact/* diff --git a/.github/workflows/linux-x64-gpu.yml b/.github/workflows/linux-x64-gpu.yml index d940408ce7..ca75f22878 100644 --- a/.github/workflows/linux-x64-gpu.yml +++ b/.github/workflows/linux-x64-gpu.yml @@ -25,7 +25,11 @@ permissions: contents: read jobs: - cuda-118: + build: + strategy: + matrix: + cudaver: [11.8, 12.1] + name: cuda-${{ matrix.cudaver }} runs-on: ubuntu-latest steps: - name: Free disk space @@ -45,12 +49,12 @@ jobs: - name: Build uses: addnab/docker-run-action@v3 with: - image: openmmlab/lmdeploy-builder:cuda11.8 - options: -v ${{ github.workspace }}:/work --cpus=1.8 + image: openmmlab/lmdeploy-builder:cuda${{ matrix.cudaver }} + options: -v ${{ github.workspace }}:/work run: | cd /work source /opt/conda/bin/activate conda activate py38 mkdir build && cd build - bash ../generate.sh + bash ../generate.sh make make -j$(nproc) && make install diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 7c56e08f7d..7c9c0f8dc1 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -75,6 +75,8 @@ jobs: run: | mkdir build cd build + # https://github.com/pypa/setuptools/issues/1631 + pip install -U setuptools ..\builder\windows\generate.ps1 cmake --build . --config Release -- /m > build.log.txt cmake --install . --config Release diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000000..6f052b0d04 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,215 @@ +name: test + +on: + pull_request: + paths: + - ".github/scripts/test_triton_server.py" + - ".github/workflows/test.yml" + - "cmake/**" + - "src/**" + - "3rdparty/**" + - "lmdeploy/**" + - "requirements/**" + - "requirements.txt" + - "CMakeLists.txt" + - "setup.py" + push: + branches: + - main + paths: + - "lmdeploy/version.py" + tags: + - "v*.*.*" + + workflow_dispatch: + inputs: + markers: + required: false + description: 'Tested markers. 
eg: "-m internlm_chat_7b"' + type: string + default: '' + +env: + HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache + HOST_LOCALTIME: /usr/share/zoneinfo/Asia/Shanghai + + +jobs: + test_functions: + runs-on: [self-hosted, linux-a100] + timeout-minutes: 4320 # 72hours + environment: 'prod' + env: + REPORT_DIR: /nvme/qa_test_models/test-reports + container: + image: nvcr.io/nvidia/tritonserver:22.12-py3 + options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip" + volumes: + - /nvme/github-actions/pip-cache:/root/.cache/pip + - /nvme/github-actions/packages:/root/packages + - /nvme/qa_test_models:/nvme/qa_test_models + - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro + steps: + - name: Setup systems + run: | + rm /etc/apt/sources.list.d/cuda*.list + apt-get update && apt-get install -y --no-install-recommends rapidjson-dev \ + libgoogle-glog-dev libgl1 openjdk-8-jre-headless + dpkg -i /root/packages/allure_2.24.1-1_all.deb + rm -rf /var/lib/apt/lists/* + - name: Clone repository + uses: actions/checkout@v2 + - name: Install pytorch + run: | + python3 -m pip cache dir + python3 -m pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 + - name: Build lmdeploy + run: | + python3 -m pip install cmake + python3 -m pip install -r requirements/build.txt + # use cached build + cp -r ../../build build + cd build + cmake .. \ + -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=1 \ + -DCMAKE_INSTALL_PREFIX=./install \ + -DBUILD_PY_FFI=ON \ + -DBUILD_MULTI_GPU=ON \ + -DCMAKE_CUDA_FLAGS="-lineinfo" \ + -DUSE_NVTX=ON \ + -DSM=80 \ + -DCMAKE_CUDA_ARCHITECTURES=80 \ + -DBUILD_TEST=OFF + make -j$(nproc) && make install + - name: Install lmdeploy + run: | + python3 -m pip install packaging protobuf transformers_stream_generator + python3 -m pip install -r requirements.txt -r requirements/test.txt + python3 -m pip install . + - name: Check env + run: | + python3 -m pip list + lmdeploy check_env + - name: Test lmdeploy + run: | + echo "TODO: awaiting PR of adding autotest" + # pytest autotest ${{github.event.inputs.markers}} --alluredir=allure-results --clean-alluredir + - name: Generate reports + if: always() + run: | + if test -D "allure-results"; then + export date_today="$(date +'%Y%m%d-%H%M%S')" + export report_dir="$REPORT_DIR/$date_today" + echo "Save report to $ALLURE_DIR" + allure generate -c -o $report_dir + fi + - name: Clear workfile + if: always() + run: | + export workdir=$(pwd) + cd .. 
+ rm -rf $workdir + mkdir $workdir + chmod -R 777 $workdir + + test_triton: + runs-on: [self-hosted, linux-a100] + timeout-minutes: 4320 # 72hours + environment: 'prod' + env: + HF_MODEL: /nvme/qa_test_models/internlm-chat-20b + WORKDIR: /nvme/qa_test_models/triton_workspace + TB_MODEL: internlm-chat-20b-fp16-tp2 + GRPC_PORT: 33337 + steps: + - name: Clone repository + uses: actions/checkout@v2 + - name: Create test container + run: | + export CONTAINER_ID=$(docker create \ + --rm \ + --gpus='"device=0,1"' \ + --shm-size 16g \ + --cap-add=SYS_PTRACE \ + --cap-add=SYS_ADMIN \ + --security-opt seccomp=unconfined \ + --name lmdeploy-ci-triton \ + --workdir /__w/lmdeploy/lmdeploy \ + --env PIP_CACHE_DIR=/root/.cache/pip \ + --env NCCL_LAUNCH_MODE=GROUP \ + -v $(pwd)/../../:/__w \ + -v ${HF_MODEL}:/root/workspace/hf_model \ + -v ${WORKDIR}:/root/workspace/workdir \ + -v ${HOST_PIP_CACHE_DIR}:/root/.cache/pip \ + -v ${HOST_LOCALTIME}:/etc/localtime:ro \ + openmmlab/lmdeploy:latest tail -f /dev/null \ + ) + docker start $CONTAINER_ID + echo "CONTAINER_ID=$CONTAINER_ID" + echo "CONTAINER_ID=$CONTAINER_ID" >> $GITHUB_ENV + - name: Build lmdeploy from source + run: | + docker exec $CONTAINER_ID cp -r ../../build build + docker exec --workdir /__w/lmdeploy/lmdeploy/build \ + --env http_proxy=${{secrets.PROXY}} \ + --env https_proxy=${{secrets.PROXY}} \ + $CONTAINER_ID cmake .. \ + -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=1 \ + -DCMAKE_INSTALL_PREFIX=./install \ + -DBUILD_PY_FFI=ON \ + -DBUILD_MULTI_GPU=ON \ + -DCMAKE_CUDA_FLAGS="-lineinfo" \ + -DUSE_NVTX=ON \ + -DSM=80 \ + -DCMAKE_CUDA_ARCHITECTURES=80 \ + -DBUILD_TEST=OFF + docker exec --workdir /__w/lmdeploy/lmdeploy/build $CONTAINER_ID make -j$(nproc) + docker exec --workdir /__w/lmdeploy/lmdeploy/build $CONTAINER_ID make install + - name: Install lmdeploy + run: | + docker exec \ + --env http_proxy=${{secrets.PROXY}} \ + --env https_proxy=${{secrets.PROXY}} \ + $CONTAINER_ID python3 -m pip install tritonclient[grpc] + docker exec \ + --env http_proxy=${{secrets.PROXY}} \ + --env https_proxy=${{secrets.PROXY}} \ + $CONTAINER_ID python3 -m pip install -r requirements/test.txt + docker exec $CONTAINER_ID python3 -m pip install . 
+ # docker exec $CONTAINER_ID check_env + - name: Convert to turbomind model + run: | + docker exec $CONTAINER_ID \ + lmdeploy convert \ + --model-name internlm-chat-20b \ + --model-path /root/workspace/hf_model \ + --tp 2 \ + --dst-path /root/workspace/workdir/${TB_MODEL} + - name: Start triton server service + run: | + docker exec --detach $CONTAINER_ID \ + tritonserver \ + --model-repository=/root/workspace/workdir/${TB_MODEL}/model_repository \ + --allow-http=0 \ + --allow-grpc=1 \ + --grpc-port=${GRPC_PORT} \ + --log-verbose=0 \ + --allow-metrics=1 + # wait for triton server to fully start up + sleep 180s + - name: Test triton server + run: | + docker exec \ + --env no_proxy=localhost,127.0.0.1 \ + $CONTAINER_ID python3 .github/scripts/test_triton_server.py --port ${GRPC_PORT} + - name: Clear workfile + if: always() + run: | + export workdir=$(pwd) + docker exec --workdir /__w/lmdeploy $CONTAINER_ID rm -rf lmdeploy + mkdir $workdir + chmod -R 777 $workdir + docker exec --workdir /__w/lmdeploy $CONTAINER_ID rm -rf /root/workspace/workdir/${TB_MODEL} + docker stop $CONTAINER_ID diff --git a/.github/workflows/windows-x64-gpu.yml b/.github/workflows/windows-x64-gpu.yml index 93839cfb89..c0e4e009cf 100644 --- a/.github/workflows/windows-x64-gpu.yml +++ b/.github/workflows/windows-x64-gpu.yml @@ -25,7 +25,11 @@ permissions: contents: read jobs: - cuda-118: + build: + strategy: + matrix: + cudaver: [11.8.0, 12.1.0] + name: cuda-${{ matrix.cudaver }} runs-on: windows-latest steps: - name: Checkout repository @@ -40,7 +44,7 @@ jobs: - uses: Jimver/cuda-toolkit@v0.2.11 id: cuda-toolkit with: - cuda: '11.8.0' + cuda: ${{ matrix.cudaver }} use-github-cache: false - name: Build wheel run: | diff --git a/.gitignore b/.gitignore index 7080c3d634..318740da9e 100644 --- a/.gitignore +++ b/.gitignore @@ -59,6 +59,7 @@ work_dir*/ *.bin *config.json *generate_config.json +!lmdeploy/turbomind/hf_repo/config.json # Pytorch *.pt @@ -74,3 +75,5 @@ work_dir*/ *.out *.csv *.pkl + +!CMakeLists.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index 9b1979ffce..27b6b150e7 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -61,6 +61,22 @@ option(SPARSITY_SUPPORT "Build project with Ampere sparsity feature support" OFF option(BUILD_FAST_MATH "Build in fast math mode" ON) +# the environment variable +# ASAN_OPTIONS=protect_shadow_gap=0,intercept_tls_get_addr=0 +# must be set at runtime +# https://github.com/google/sanitizers/issues/1322 +if (LMDEPLOY_ASAN_ENABLE) + add_compile_options($<$:-fsanitize=address>) + add_link_options(-fsanitize=address) +endif () + +# notice that ubsan has linker issues for ubuntu < 18.04, see +# https://stackoverflow.com/questions/50024731/ld-unrecognized-option-push-state-no-as-needed +if (LMDEPLOY_UBSAN_ENABLE) + add_compile_options($<$:-fsanitize=undefined>) + add_link_options(-fsanitize=undefined) +endif () + if(BUILD_MULTI_GPU) message(STATUS "Add DBUILD_MULTI_GPU, requires MPI and NCCL") add_definitions("-DBUILD_MULTI_GPU") @@ -87,7 +103,9 @@ if(USE_TRITONSERVER_DATATYPE) endif() set(CXX_STD "17" CACHE STRING "C++ standard") - +# enable gold linker for binary and .so +set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fuse-ld=gold") +set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -fuse-ld=gold") set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR}) set(TF_PATH "" CACHE STRING "TensorFlow path") @@ -180,12 +198,16 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS 
"${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD} -DCUDA_PTX_FP8_F2FP_ENABLED") -set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") +set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -O3") # set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 --ptxas-options=--verbose") -set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED") +set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED") +set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED") + if(BUILD_FAST_MATH) -set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} --use_fast_math") -message("CMAKE_CUDA_FLAGS_RELEASE: ${CMAKE_CUDA_FLAGS_RELEASE}") + set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} --use_fast_math") + set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} --use_fast_math") + message("Release build CUDA flags: ${CMAKE_CUDA_FLAGS_RELEASE}") endif() set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) @@ -252,11 +274,15 @@ print(torch._C._GLIBCXX_USE_CXX11_ABI,end='');" OUTPUT_VARIABLE USE_CXX11_ABI) message("-- USE_CXX11_ABI=${USE_CXX11_ABI}") if (USE_CXX11_ABI) + set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=1") + set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=1") set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1") set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=1") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=1") else() + set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=0") + set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=0") set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0") set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=0") @@ -327,6 +353,7 @@ add_library(transformer-shared SHARED $ $ $ + $ $ $ $ diff --git a/README.md b/README.md index a2de4d6ac0..7da9778b40 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,8 @@ ______________________________________________________________________ ## News 🎉 +- \[2023/11\] Turbomind supports loading hf model directly. Click [here](./docs/en/load_hf.md) for details. +- \[2023/11\] TurboMind major upgrades, including: Paged Attention, faster attention kernels without sequence length limitation, 2x faster KV8 kernels, Split-K decoding (Flash Decoding), and W4A16 inference for sm_75 - \[2023/09\] TurboMind supports Qwen-14B - \[2023/09\] TurboMind supports InternLM-20B - \[2023/09\] TurboMind supports all features of Code Llama: code completion, infilling, chat / instruct, and python specialist. Click [here](./docs/en/supported_models/codellama.md) for deployment guide @@ -52,7 +54,7 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by ## Supported Models -`LMDeploy` has two inference backends, `Pytorch` and `TurboMind`. +`LMDeploy` has two inference backends, `Pytorch` and `TurboMind`. You can run `lmdeploy list` to check the supported model names. 
### TurboMind @@ -63,12 +65,13 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by | :----------: | :-------------: | :--: | :-----: | :---: | :--: | | Llama | Yes | Yes | Yes | Yes | No | | Llama2 | Yes | Yes | Yes | Yes | No | +| SOLAR | Yes | Yes | Yes | Yes | No | | InternLM-7B | Yes | Yes | Yes | Yes | No | | InternLM-20B | Yes | Yes | Yes | Yes | No | -| QWen-7B | Yes | Yes | Yes | No | No | -| QWen-14B | Yes | Yes | Yes | No | No | +| QWen-7B | Yes | Yes | Yes | Yes | No | +| QWen-14B | Yes | Yes | Yes | Yes | No | | Baichuan-7B | Yes | Yes | Yes | Yes | No | -| Baichuan2-7B | Yes | Yes | No | No | No | +| Baichuan2-7B | Yes | Yes | Yes | Yes | No | | Code Llama | Yes | Yes | No | No | No | ### Pytorch @@ -102,32 +105,28 @@ Install lmdeploy with pip ( python 3.8+) or [from source](./docs/en/build.md) pip install lmdeploy ``` -### Deploy InternLM - -#### Get InternLM model - -```shell -# 1. Download InternLM model - -# Make sure you have git-lfs installed (https://git-lfs.com) -git lfs install -git clone https://huggingface.co/internlm/internlm-chat-7b-v1_1 /path/to/internlm-chat-7b +> **Note**
+> `pip install lmdeploy` only installs the packages required at runtime. If users want to run code from modules like `lmdeploy.lite` and `lmdeploy.serve`, they need to install the extra required packages. +> For instance, running `pip install lmdeploy[lite]` would install extra dependencies for the `lmdeploy.lite` module. +> +> - `all`: Install lmdeploy with all dependencies in `requirements.txt` +> - `lite`: Install lmdeploy with extra dependencies in `requirements/lite.txt` +> - `serve`: Install lmdeploy with dependencies in `requirements/serve.txt` -# if you want to clone without large files – just their pointers -# prepend your git clone with the following env var: -GIT_LFS_SKIP_SMUDGE=1 +### Deploy InternLM -# 2. Convert InternLM model to turbomind's format, which will be in "./workspace" by default -python3 -m lmdeploy.serve.turbomind.deploy internlm-chat-7b /path/to/internlm-chat-7b +To use the TurboMind inference engine, you need to first convert the model into TurboMind format. Currently, we support online conversion and offline conversion. With online conversion, TurboMind can load the Huggingface model directly. With offline conversion, you should save the converted model first before using it. -``` +The following uses [internlm/internlm-chat-7b-v1_1](https://huggingface.co/internlm/internlm-chat-7b-v1_1) as an example to show how to use TurboMind with online conversion. You can refer to [load_hf.md](docs/en/load_hf.md) for other methods. #### Inference by TurboMind ```shell -python -m lmdeploy.turbomind.chat ./workspace +lmdeploy chat turbomind internlm/internlm-chat-7b-v1_1 --model-name internlm-chat-7b ``` +> **Note**
The internlm/internlm-chat-7b-v1_1 model will be downloaded under the `.cache` folder. You can also use a local path here. > **Note**
> When inferring with FP16 precision, the InternLM-7B model requires at least 15.7G of GPU memory overhead on TurboMind.
> It is recommended to use NVIDIA cards such as 3090, V100, A100, etc. @@ -139,7 +138,10 @@ python -m lmdeploy.turbomind.chat ./workspace #### Serving with gradio ```shell -python3 -m lmdeploy.serve.gradio.app ./workspace +# install lmdeploy with extra dependencies +pip install lmdeploy[serve] + +lmdeploy serve gradio internlm/internlm-chat-7b-v1_1 --model-name internlm-chat-7b ``` ![](https://github.com/InternLM/lmdeploy/assets/67539920/08d1e6f2-3767-44d5-8654-c85767cec2ab) @@ -149,49 +151,30 @@ python3 -m lmdeploy.serve.gradio.app ./workspace Launch inference server by: ```shell -python3 -m lmdeploy.serve.openai.api_server ./workspace server_ip server_port --instance_num 32 --tp 1 +# install lmdeploy with extra dependencies +pip install lmdeploy[serve] + +lmdeploy serve api_server internlm/internlm-chat-7b-v1_1 --model-name internlm-chat-7b --instance_num 32 --tp 1 ``` Then, you can communicate with it by command line, ```shell -# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 -python -m lmdeploy.serve.openai.api_client restful_api_url +# api_server_url is what printed in api_server.py, e.g. http://localhost:23333 +lmdeploy serve api_client api_server_url ``` or webui, ```shell -# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 +# api_server_url is what printed in api_server.py, e.g. http://localhost:23333 # server_ip and server_port here are for gradio ui -# example: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True -python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True +# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 +lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port} ``` Refer to [restful_api.md](docs/en/restful_api.md) for more details. -#### Serving with Triton Inference Server - -Launch inference server by: - -```shell -bash workspace/service_docker_up.sh -``` - -Then, you can communicate with the inference server by command line, - -```shell -python3 -m lmdeploy.serve.client {server_ip_addresss}:33337 -``` - -or webui, - -```shell -python3 -m lmdeploy.serve.gradio.app {server_ip_addresss}:33337 -``` - -For the deployment of other supported models, such as LLaMA, LLaMA-2, vicuna and so on, you can find the guide from [here](docs/en/serving.md) - ### Inference with PyTorch For detailed instructions on Inference pytorch models, see [here](docs/en/pytorch.md). 
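As a companion to the `api_server` commands above, here is a minimal client sketch for the OpenAI-compatible `/v1/chat/completions` route. It assumes the server from the quickstart is reachable at http://localhost:23333 and that the served model is registered as `internlm-chat-7b`; the payload follows the standard OpenAI chat-completions schema, so adjust the fields to your deployment.

```python
# Minimal sketch: query the OpenAI-compatible route exposed by
# `lmdeploy serve api_server`. URL and model name are quickstart assumptions.
import requests

API_SERVER_URL = 'http://localhost:23333'

payload = {
    'model': 'internlm-chat-7b',
    'messages': [{'role': 'user', 'content': 'Hello! Please answer in English.'}],
    'temperature': 0.8,
    'top_p': 0.95,
}

resp = requests.post(f'{API_SERVER_URL}/v1/chat/completions', json=payload, timeout=300)
resp.raise_for_status()
# The response mirrors the OpenAI schema: the generated text is in
# choices[0]["message"]["content"].
print(resp.json()['choices'][0]['message']['content'])
```

The same interaction is available from the command line via `lmdeploy serve api_client api_server_url`, as shown above.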
@@ -199,7 +182,7 @@ For detailed instructions on Inference pytorch models, see [here](docs/en/pytorc #### Single GPU ```shell -python3 -m lmdeploy.pytorch.chat $NAME_OR_PATH_TO_HF_MODEL \ +lmdeploy chat torch $NAME_OR_PATH_TO_HF_MODEL \ --max_new_tokens 64 \ --temperture 0.8 \ --top_p 0.95 \ diff --git a/README_zh-CN.md b/README_zh-CN.md index 09c66c2826..e9b3734e41 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -20,6 +20,8 @@ ______________________________________________________________________ ## 更新 🎉 +- \[2023/11\] Turbomind 支持直接读取 Huggingface 模型。点击[这里](./docs/en/load_hf.md)查看使用方法 +- \[2023/11\] TurboMind 重磅升级。包括:Paged Attention、更快的且不受序列最大长度限制的 attention kernel、2+倍快的 KV8 kernels、Split-K decoding (Flash Decoding) 和 支持 sm_75 架构的 W4A16 - \[2023/09\] TurboMind 支持 Qwen-14B - \[2023/09\] TurboMind 支持 InternLM-20B 模型 - \[2023/09\] TurboMind 支持 Code Llama 所有功能:代码续写、填空、对话、Python专项。点击[这里](./docs/zh_cn/supported_models/codellama.md)阅读部署方法 @@ -53,7 +55,7 @@ LMDeploy 由 [MMDeploy](https://github.com/open-mmlab/mmdeploy) 和 [MMRazor](ht ## 支持的模型 -`LMDeploy` 支持 `TurboMind` 和 `Pytorch` 两种推理后端 +`LMDeploy` 支持 `TurboMind` 和 `Pytorch` 两种推理后端。运行`lmdeploy list`可查看支持模型列表 ### TurboMind @@ -64,12 +66,13 @@ LMDeploy 由 [MMDeploy](https://github.com/open-mmlab/mmdeploy) 和 [MMRazor](ht | :----------: | :------: | :--: | :-----: | :---: | :--: | | Llama | Yes | Yes | Yes | Yes | No | | Llama2 | Yes | Yes | Yes | Yes | No | +| SOLAR | Yes | Yes | Yes | Yes | No | | InternLM-7B | Yes | Yes | Yes | Yes | No | | InternLM-20B | Yes | Yes | Yes | Yes | No | -| QWen-7B | Yes | Yes | Yes | No | No | -| QWen-14B | Yes | Yes | Yes | No | No | +| QWen-7B | Yes | Yes | Yes | Yes | No | +| QWen-14B | Yes | Yes | Yes | Yes | No | | Baichuan-7B | Yes | Yes | Yes | Yes | No | -| Baichuan2-7B | Yes | Yes | No | No | No | +| Baichuan2-7B | Yes | Yes | Yes | Yes | No | | Code Llama | Yes | Yes | No | No | No | ### Pytorch @@ -103,32 +106,27 @@ TurboMind 的 output token throughput 超过 2000 token/s, 整体比 DeepSpeed pip install lmdeploy ``` -### 部署 InternLM - -#### 获取 InternLM 模型 - -```shell -# 1. 下载 InternLM 模型 - -# Make sure you have git-lfs installed (https://git-lfs.com) -git lfs install -git clone https://huggingface.co/internlm/internlm-chat-7b-v1_1 /path/to/internlm-chat-7b +> **Note**
+> `pip install lmdeploy`默认安装runtime依赖包,使用lmdeploy的lite和serve功能时,用户需要安装额外依赖包。例如: `pip install lmdeploy[lite]` 会额外安装`lmdeploy.lite`模块的依赖包 +> +> - `all`: 安装`lmdeploy`所有依赖包,具体可查看`requirements.txt` +> - `lite`: 额外安装`lmdeploy.lite`模块的依赖包,具体可查看`requirements/lite.txt` +> - `serve`: 额外安装`lmdeploy.serve`模块的依赖包,具体可查看`requirements/serve.txt` -# if you want to clone without large files – just their pointers -# prepend your git clone with the following env var: -GIT_LFS_SKIP_SMUDGE=1 +### 部署 InternLM -# 2. 转换为 trubomind 要求的格式。默认存放路径为 ./workspace -python3 -m lmdeploy.serve.turbomind.deploy internlm-chat-7b /path/to/internlm-chat-7b +使用 TurboMind 推理模型需要先将模型转化为 TurboMind 的格式,目前支持在线转换和离线转换两种形式。在线转换可以直接加载 Huggingface 模型,离线转换需要先保存模型再加载。 -``` +下面以 [internlm/internlm-chat-7b-v1_1](https://huggingface.co/internlm/internlm-chat-7b-v1_1) 为例,展示在线转换的使用方式。其他方式可参考[load_hf.md](docs/zh_cn/load_hf.md) #### 使用 turbomind 推理 ```shell -python3 -m lmdeploy.turbomind.chat ./workspace +lmdeploy chat turbomind internlm/internlm-chat-7b-v1_1 --model-name internlm-chat-7b ``` +> **Note**
internlm/internlm-chat-7b-v1_1 会自动下载到 `.cache` 文件夹,这里也可以传下载好的路径。 + > **Note**
> turbomind 在使用 FP16 精度推理 InternLM-7B 模型时，显存开销至少需要 15.7G。建议使用 3090、V100、A100 等型号的显卡。
> 关闭显卡的 ECC 可以腾出 10% 显存,执行 `sudo nvidia-smi --ecc-config=0` 重启系统生效。 @@ -139,7 +137,10 @@ python3 -m lmdeploy.turbomind.chat ./workspace #### 启动 gradio server ```shell -python3 -m lmdeploy.serve.gradio.app ./workspace +# 安装lmdeploy额外依赖 +pip install lmdeploy[serve] + +lmdeploy serve gradio internlm/internlm-chat-7b-v1_1 --model-name internlm-chat-7b ``` ![](https://github.com/InternLM/lmdeploy/assets/67539920/08d1e6f2-3767-44d5-8654-c85767cec2ab) @@ -149,49 +150,30 @@ python3 -m lmdeploy.serve.gradio.app ./workspace 使用下面的命令启动推理服务: ```shell -python3 -m lmdeploy.serve.openai.api_server ./workspace server_ip server_port --instance_num 32 --tp 1 +# 安装lmdeploy额外依赖 +pip install lmdeploy[serve] + +lmdeploy serve api_server internlm/internlm-chat-7b-v1_1 --model-name internlm-chat-7b --instance_num 32 --tp 1 ``` 你可以通过命令行方式与推理服务进行对话: ```shell -# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 -python -m lmdeploy.serve.openai.api_client restful_api_url +# api_server_url is what printed in api_server.py, e.g. http://localhost:23333 +lmdeploy serve api_client api_server_url ``` 也可以通过 WebUI 方式来对话: ```shell -# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 +# api_server_url is what printed in api_server.py, e.g. http://localhost:23333 # server_ip and server_port here are for gradio ui -# example: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True -python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True +# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 +lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port} ``` 更多详情可以查阅 [restful_api.md](docs/zh_cn/restful_api.md)。 -#### 通过容器部署推理服务 - -使用下面的命令启动推理服务: - -```shell -bash workspace/service_docker_up.sh -``` - -你可以通过命令行方式与推理服务进行对话: - -```shell -python3 -m lmdeploy.serve.client {server_ip_addresss}:33337 -``` - -也可以通过 WebUI 方式来对话: - -```shell -python3 -m lmdeploy.serve.gradio.app {server_ip_addresss}:33337 -``` - -其他模型的部署方式,比如 LLaMA,LLaMA-2,vicuna等等,请参考[这里](docs/zh_cn/serving.md) - ### 基于 PyTorch 的推理 你必须确保环境中有安装 deepspeed: @@ -203,7 +185,7 @@ pip install deepspeed #### 单个 GPU ```shell -python3 -m lmdeploy.pytorch.chat $NAME_OR_PATH_TO_HF_MODEL\ +lmdeploy chat torch $NAME_OR_PATH_TO_HF_MODEL\ --max_new_tokens 64 \ --temperture 0.8 \ --top_p 0.95 \ diff --git a/benchmark/README.md b/benchmark/README.md index b5573ae2b8..057d38bb11 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -29,8 +29,8 @@ pip install nvidia-ml-py ```bash python profile_generation.py \ - --model-path /path/to/your/model \ - --concurrency 1 8 --prompt-tokens 0 512 --completion-tokens 2048 512 + /path/to/your/model \ + --concurrency 1 8 --prompt-tokens 1 512 --completion-tokens 2048 512 ``` ## profile serving diff --git a/benchmark/benchmark_13b.sh b/benchmark/benchmark_13b.sh new file mode 100755 index 0000000000..cdf7355471 --- /dev/null +++ b/benchmark/benchmark_13b.sh @@ -0,0 +1,92 @@ +#!/bin/bash +if [ -z "$1" ] +then + echo "Error. Please input the model path of llama2-13b model" + exit 1 +fi + +workspace_dir=$(dirname $(realpath "$0")) + +tp=1 +model_path="$1" +model_foldername=$(basename "$model_path") +turbomind_model_path="${workspace_dir}"/workspace/"${model_foldername}" + +# convert +lmdeploy convert llama2 ${model_path} --dst-path ${turbomind_model_path} --tp ${tp} +if [ $? 
!= 0 ] +then + exit 1 +fi + +# update recommended config to config.ini +config_path=${turbomind_model_path}/triton_models/weights/config.ini + +apt-get update +apt-get install crudini -y + +crudini --set ${config_path} llama max_context_token_num 4 +crudini --set ${config_path} llama cache_chunk_size -1 +crudini --set ${config_path} llama cache_max_entry_count 500 +crudini --set ${config_path} llama max_batch_size 128 +# end of update config + +cd ${workspace_dir} + +# download dataset +wget -O ShareGPT_V3_unfiltered_cleaned_split.json https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +benchmark_rpm () { + output_path=$1 + mkdir -p "${output_path}" + + batches=(64 128) + for batch in "${batches[@]}" + do + for i in {1..3} + do + python3 profile_throughput.py \ + ShareGPT_V3_unfiltered_cleaned_split.json \ + ${turbomind_model_path} \ + --concurrency "$batch" \ + --num_prompts 3000 \ + --csv ${output_path}/rpm_localhost_batch_"${batch}"_"${i}"th.csv + done + done +} + +benchmark_generation () { + output_path=$1 + mkdir -p "${output_path}" + + python3 profile_generation.py \ + ${turbomind_model_path} \ + --concurrency 1 16 32 64 \ + --csv ${output_path}/generation.csv +} + +################################# BENCHMARK AFTER TUNING GEMM ################################# +# tune gemm +head_num=$(crudini --get "${config_path}" llama head_num) +size_per_head=$(crudini --get "${config_path}" llama size_per_head) +vocab_size=$(crudini --get "${config_path}" llama vocab_size) +inter_size=$(crudini --get "${config_path}" llama inter_size) +tensor_para_size=$(crudini --get "${config_path}" llama tensor_para_size) +max_batch_size=$(crudini --get "${config_path}" llama max_batch_size) + +echo $head_num, $size_per_head, $vocab_size, $inter_size, $tensor_para_size, $max_batch_size + +python3 -m lmdeploy.turbomind.generate_gemm_config \ + --head_num ${head_num} \ + --size_per_head ${size_per_head} \ + --vocab_size ${vocab_size} \ + --inter_size ${inter_size} \ + --tensor_para_size ${tensor_para_size} \ + --max_batch_size ${max_batch_size} + +output_path="${workspace_dir}"/output/"${model_foldername}"-tunned-gemm-tp"${tp}" +# benchmark request throughput and static inference +benchmark_rpm ${output_path} +benchmark_generation ${output_path} + +mv gemm_config.in ${output_path} diff --git a/benchmark/benchmark_20b.sh b/benchmark/benchmark_20b.sh new file mode 100755 index 0000000000..152978c453 --- /dev/null +++ b/benchmark/benchmark_20b.sh @@ -0,0 +1,92 @@ +#!/bin/bash +if [ -z "$1" ] +then + echo "Error. Please input the model path of internlm-20b model" + exit 1 +fi + +workspace_dir=$(dirname $(realpath "$0")) + +tp=2 +model_path="$1" +model_foldername=$(basename "$model_path") +turbomind_model_path="${workspace_dir}"/workspace/"${model_foldername}" + +# convert +lmdeploy convert internlm-20b ${model_path} --dst-path ${turbomind_model_path} --tp ${tp} +if [ $? 
!= 0 ] +then + exit 1 +fi + +# update recommended config to config.ini +config_path=${turbomind_model_path}/triton_models/weights/config.ini + +apt-get update +apt-get install crudini -y + +crudini --set ${config_path} llama max_context_token_num 4 +crudini --set ${config_path} llama cache_chunk_size -1 +crudini --set ${config_path} llama cache_max_entry_count 700 +crudini --set ${config_path} llama max_batch_size 128 +# end of update config + +cd ${workspace_dir} + +# download dataset +wget -O ShareGPT_V3_unfiltered_cleaned_split.json https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +benchmark_rpm () { + output_path=$1 + mkdir -p "${output_path}" + + batches=(64 128) + for batch in "${batches[@]}" + do + for i in {1..3} + do + python3 profile_throughput.py \ + ShareGPT_V3_unfiltered_cleaned_split.json \ + ${turbomind_model_path} \ + --concurrency "$batch" \ + --num_prompts 3000 \ + --csv ${output_path}/rpm_localhost_batch_"${batch}"_"${i}"th.csv + done + done +} + +benchmark_generation () { + output_path=$1 + mkdir -p "${output_path}" + + python3 profile_generation.py \ + ${turbomind_model_path} \ + --concurrency 1 16 32 64 \ + --csv ${output_path}/generation.csv +} + +################################# BENCHMARK AFTER TUNING GEMM ################################# +# tune gemm +head_num=$(crudini --get "${config_path}" llama head_num) +size_per_head=$(crudini --get "${config_path}" llama size_per_head) +vocab_size=$(crudini --get "${config_path}" llama vocab_size) +inter_size=$(crudini --get "${config_path}" llama inter_size) +tensor_para_size=$(crudini --get "${config_path}" llama tensor_para_size) +max_batch_size=$(crudini --get "${config_path}" llama max_batch_size) + +echo $head_num, $size_per_head, $vocab_size, $inter_size, $tensor_para_size, $max_batch_size + +python3 -m lmdeploy.turbomind.generate_gemm_config \ + --head_num ${head_num} \ + --size_per_head ${size_per_head} \ + --vocab_size ${vocab_size} \ + --inter_size ${inter_size} \ + --tensor_para_size ${tensor_para_size} \ + --max_batch_size ${max_batch_size} + +output_path="${workspace_dir}"/output/"${model_foldername}"-tunned-gemm-tp"${tp}" +# benchmark request throughput and static inference +benchmark_rpm ${output_path} +benchmark_generation ${output_path} + +cp gemm_config.in ${output_path} diff --git a/benchmark/benchmark_70b.sh b/benchmark/benchmark_70b.sh new file mode 100755 index 0000000000..0f23033064 --- /dev/null +++ b/benchmark/benchmark_70b.sh @@ -0,0 +1,71 @@ +#!/bin/bash +if [ -z "$1" ] +then + echo "Error. Please input the model path of llama2-70b model" + exit 1 +fi + +workspace_dir=$(dirname $(realpath "$0")) + +tp=4 +model_path="$1" +model_foldername=$(basename "$model_path") +turbomind_model_path="${workspace_dir}"/workspace/"${model_foldername}" + +# convert +lmdeploy convert llama2 ${model_path} --dst-path ${turbomind_model_path} --tp ${tp} +if [ $? 
!= 0 ] +then + exit 1 +fi + +# update recommended config to config.ini +config_path=${turbomind_model_path}/triton_models/weights/config.ini + +apt-get update +apt-get install crudini -y + +crudini --set ${config_path} llama max_context_token_num 4 +crudini --set ${config_path} llama cache_chunk_size -1 +crudini --set ${config_path} llama cache_max_entry_count 4000 +crudini --set ${config_path} llama max_batch_size 256 +# end of update config + +cd ${workspace_dir} + +# download dataset +wget -O ShareGPT_V3_unfiltered_cleaned_split.json https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +benchmark_rpm () { + output_path=$1 + mkdir -p "${output_path}" + + batches=(64 128 256) + for batch in "${batches[@]}" + do + for i in {1..3} + do + python3 profile_throughput.py \ + ShareGPT_V3_unfiltered_cleaned_split.json \ + ${turbomind_model_path} \ + --concurrency "$batch" \ + --num_prompts 3000 \ + --csv ${output_path}/rpm_localhost_batch_"${batch}"_"${i}"th.csv + done + done +} + +benchmark_generation () { + output_path=$1 + mkdir -p "${output_path}" + + python3 profile_generation.py \ + ${turbomind_model_path} \ + --concurrency 1 64 128 256 \ + --csv ${output_path}/generation.csv +} + +output_path="${workspace_dir}"/output/"${model_foldername}"-tp"${tp}" +# benchmark request throughput and static inference +benchmark_rpm ${output_path} +benchmark_generation ${output_path} diff --git a/benchmark/benchmark_7b.sh b/benchmark/benchmark_7b.sh new file mode 100755 index 0000000000..bf702d56ab --- /dev/null +++ b/benchmark/benchmark_7b.sh @@ -0,0 +1,93 @@ +#!/bin/bash +if [ -z "$1" ] +then + echo "Error. Please input the model path of llama2-7b model" + exit 1 +fi + +workspace_dir=$(dirname $(realpath "$0")) + +tp=1 +model_path="$1" +model_foldername=$(basename "$model_path") +turbomind_model_path="${workspace_dir}"/workspace/"${model_foldername}" + +# convert +lmdeploy convert llama2 ${model_path} --dst-path ${turbomind_model_path} --tp ${tp} +if [ $? 
!= 0 ] +then +exit 1 +fi + +# update recommended config to config.ini +config_path=${turbomind_model_path}/triton_models/weights/config.ini + +apt-get update +apt-get install crudini -y + +crudini --set ${config_path} llama max_context_token_num 4 +crudini --set ${config_path} llama cache_chunk_size -1 +crudini --set ${config_path} llama cache_max_entry_count 1000 +crudini --set ${config_path} llama max_batch_size 128 +# end of update config + +cd ${workspace_dir} + +# download dataset +wget -O ShareGPT_V3_unfiltered_cleaned_split.json https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +benchmark_rpm () { + output_path=$1 + mkdir -p "${output_path}" + + batches=(64 128) + for batch in "${batches[@]}" + do + for i in {1..3} + do + python3 profile_throughput.py \ + ShareGPT_V3_unfiltered_cleaned_split.json \ + ${turbomind_model_path} \ + --concurrency "$batch" \ + --num_prompts 3000 \ + --csv ${output_path}/rpm_localhost_batch_"${batch}"_"${i}"th.csv + done + done +} + +benchmark_generation () { + output_path=$1 + mkdir -p "${output_path}" + + python3 profile_generation.py \ + ${turbomind_model_path} \ + --concurrency 1 16 32 64 \ + --csv ${output_path}/generation.csv +} + +################################# BENCHMARK AFTER TUNING GEMM ################################# +output_path="${workspace_dir}"/output/"${model_foldername}"-tunned-gemm-tp"${tp}" + +# tune gemm +head_num=$(crudini --get "${config_path}" llama head_num) +size_per_head=$(crudini --get "${config_path}" llama size_per_head) +vocab_size=$(crudini --get "${config_path}" llama vocab_size) +inter_size=$(crudini --get "${config_path}" llama inter_size) +tensor_para_size=$(crudini --get "${config_path}" llama tensor_para_size) +max_batch_size=$(crudini --get "${config_path}" llama max_batch_size) + +echo $head_num, $size_per_head, $vocab_size, $inter_size, $tensor_para_size, $max_batch_size + +python3 -m lmdeploy.turbomind.generate_gemm_config \ + --head_num ${head_num} \ + --size_per_head ${size_per_head} \ + --vocab_size ${vocab_size} \ + --inter_size ${inter_size} \ + --tensor_para_size ${tensor_para_size} \ + --max_batch_size ${max_batch_size} + +# benchmark request throughput and static inference +benchmark_rpm ${output_path} +benchmark_generation ${output_path} + +mv gemm_config.in ${output_path} diff --git a/benchmark/profile_generation.py b/benchmark/profile_generation.py index c46f790da0..1e6929b490 100644 --- a/benchmark/profile_generation.py +++ b/benchmark/profile_generation.py @@ -1,11 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
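The heart of the rewritten `profile_generation.py` below is the per-token latency bookkeeping explained in the long comment inside `infer()`: `stream_infer` reports a cumulative token count that may jump by several tokens per iteration, so each iteration's elapsed time is charged to the first newly generated token and the remaining tokens of that iteration keep a latency of 0. Here is a standalone sketch of that bookkeeping, with made-up timings and no lmdeploy dependency; the helper name and the example numbers are assumptions for illustration only.

```python
def charge_iteration_latency(iterations, output_seqlen):
    """iterations: list of (elapsed_seconds, cumulative_n_token) pairs, one per
    stream_infer iteration. Mirrors the work-around used in `infer()`: the whole
    iteration latency goes to the first token generated in that iteration."""
    token_latency_stats = [0.0] * (output_seqlen + 1)
    n_prev_token = 0
    for elapsed, n_token in iterations:
        if n_token != n_prev_token:
            # the real code uses np.round(now - prev, 3); plain round is enough here
            token_latency_stats[n_prev_token] = round(elapsed, 3)
            n_prev_token = n_token
    return token_latency_stats

# Three iterations producing 5, 2 and 1 new tokens respectively:
# the first token gets 0.21 s, tokens 5 and 7 get 0.052 s and 0.031 s, the rest stay 0.
print(charge_iteration_latency([(0.210, 5), (0.052, 7), (0.031, 8)], 8))
```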
-# import multiprocessing as mp import argparse import csv import logging import os -import os.path as osp -import random import time from dataclasses import dataclass from queue import Queue @@ -17,69 +14,121 @@ nvmlDeviceGetMemoryInfo, nvmlDeviceGetName, nvmlDeviceGetPowerState, nvmlDeviceGetTemperature, nvmlInit, nvmlShutdown, nvmlSystemGetDriverVersion) +from tqdm import tqdm -def infer(model, - session_id: int, - input_ids: str, - output_seqlen: int, - test_round: int, - que: Queue, - sampling_param=None): +def infer(model, session_id: int, input_ids: List, output_seqlen: int, + top_k: int, top_p: float, temperature: float, test_round: int, + que: Queue): + from lmdeploy.pytorch.engine import Engine + from lmdeploy.pytorch.messages import SamplingParam + + if session_id == 1: + pbar = tqdm(total=test_round) chatbot = model.create_instance() stats = [] - for i in range(test_round): - start = time.perf_counter() - timestamps = [] - tokens = [] - for outputs in chatbot.stream_infer(session_id, - input_ids, - request_output_len=output_seqlen, - sequence_start=True, - sequence_end=True, - ignore_eos=True, - sampling_param=sampling_param): - if len(outputs) > 1: - res, token = outputs[-2:] - else: - res, token = outputs[0] - timestamps.append(time.perf_counter()) - tokens.append(token) - - # for pytorch engine to restart a session - if hasattr(chatbot, 'end'): + for _ in range(test_round): + token_latency_stats = [0] * (output_seqlen + 1) + prev = time.perf_counter() + n_prev_token = 0 + + """ + The iterator provided by `stream_infer` denotes the number of generated tokens so far, + which is represented by the variable `n_token`. + Please note that `n_token` is not a continuous value. In other words, during the iteration, + its value might be 5, 7, 8, 16, and so on, rather than 1, 2, 3, 4, etc. + So, it is quite difficult to get the latency of each generated token. + As a work-around, we set the latency `now-prev` of each iteration to the first token of + the new generated tokens, and leave the latency of the rest tokens being 0. + For example, in the first iteration, 5 tokens are generated. + The time elapsing in this iteration `now-prev` is set to the latency of first token of + the 5 tokens, i.e. 
`token_latency_stats[0]`, and `token_latency_stats[1:4]` is set 0` + """ # noqa: E501 + # TODO: use same inference interface + if isinstance(model, Engine): + sampling_param = SamplingParam(top_k=top_k, + top_p=top_p, + temperature=temperature, + ignore_eos=True) + for outputs in chatbot.stream_infer( + session_id, + input_ids=input_ids, + request_output_len=output_seqlen, + sampling_param=sampling_param): + if len(outputs) > 1: + _, n_token = outputs[-2:] + else: + _, n_token = outputs[0] + now = time.perf_counter() + if n_prev_token != n_token: + token_latency_stats[n_prev_token] = np.round(now - prev, 3) + n_prev_token = n_token + prev = now chatbot.end(session_id) - - # TODO: ignore first token - first_token_latency = np.round(timestamps[0] - start, 2) - if len(timestamps) == 1: - token_latency = np.round(timestamps[0] - start, 2) - token = tokens[0] else: - token_latency = np.round(timestamps[-1] - timestamps[0], 2) - token = tokens[-1] - tokens[0] - stats.append([first_token_latency, token, token_latency]) + for outputs in chatbot.stream_infer( + session_id, + input_ids, + request_output_len=output_seqlen, + sequence_start=True, + sequence_end=True, + ignore_eos=True, + stream_output=True, + top_k=top_k, + top_p=top_p, + temperature=temperature): + _, n_token = outputs[0] + now = time.perf_counter() + if n_prev_token != n_token: + token_latency_stats[n_prev_token] = np.round(now - prev, 3) + n_prev_token = n_token + prev = now + if session_id == 1: + pbar.update(1) + + assert output_seqlen <= n_token <= output_seqlen + 1, \ + f'Error. session_id({session_id}) request {output_seqlen} ' \ + f'tokens, but generate {n_token} tokens' + stats.append(token_latency_stats[:output_seqlen]) que.put((session_id, stats)) -def warmup(model, - concurrency: int, - input_ids: List[int], - output_seqlen: int, - warmup_round: int = 2, - sampling_param=None): +def warmup(model, concurrency: int, input_ids: List[int], output_seqlen: int, + warmup_round: int): + from lmdeploy.pytorch.engine import Engine + from lmdeploy.pytorch.messages import SamplingParam + if not warmup_round: + return + print('start to warmup ...') def _infer(model, session_id): chatbot = model.create_instance() for _ in range(warmup_round): - for _ in chatbot.stream_infer(session_id, - input_ids=input_ids, - request_output_len=output_seqlen, - sequence_start=True, - sequence_end=True, - ignore_eos=True, - sampling_param=sampling_param): + # TODO: use same inference interface + if isinstance(model, Engine): + sampling_param = SamplingParam(top_k=1, + top_p=1.0, + temperature=0.8, + repetition_penalty=1.0, + ignore_eos=True) + generator = chatbot.stream_infer( + session_id, + input_ids=input_ids, + request_output_len=output_seqlen, + sampling_param=sampling_param) + else: + generator = chatbot.stream_infer( + session_id, + input_ids=input_ids, + request_output_len=output_seqlen, + sequence_start=True, + sequence_end=True, + ignore_eos=True, + top_k=1, + top_p=1.0, + temperature=1.0) + for _ in generator: continue # for pytorch engine to restart a session if hasattr(chatbot, 'end'): @@ -92,104 +141,105 @@ def _infer(model, session_id): procs.append(proc) proc.start() - try: - for proc in procs: - proc.join() - except Exception: - for proc in procs: - proc.stop() - exit(1) + for proc in procs: + proc.join() + _end = time.perf_counter() print(f'end warmup, elapsed time: {round(_end - _start, 2)}s') -def profile_throughput(model_path: str, - concurrency: int = 1, - input_seqlen: int = 0, - output_seqlen: int = 512, - test_round: int = 10, 
- tp: int = 1, - sampling_param=None): - from lmdeploy.pytorch.engine import Engine - from lmdeploy.pytorch.messages import SamplingParam - from lmdeploy.tokenizer import Tokenizer - tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer') +def profile_throughput(model_path: str, concurrency: int, input_seqlen: int, + output_seqlen: int, tp: int, top_k: int, top_p: float, + temperature: float, test_round: int, warmup_round: int, + **kwargs): + + print(f'profiling ... concurrency: {concurrency}, ' + f'n_prompt_token: {input_seqlen}, ' + f'n_completion_token: {output_seqlen}, ' + f'test_round: {test_round}, warmup_round: {warmup_round}') + + tokenizer_model_path = os.path.join(model_path, 'triton_models', + 'tokenizer') + if os.path.exists(tokenizer_model_path): from lmdeploy.turbomind import TurboMind - tokenizer = Tokenizer(tokenizer_model_path) - tm_model = TurboMind(model_path=model_path, tp=tp) + + # avoid turbomind checking chat template name by setting `model_name='llama'` # noqa + tm_model = TurboMind(model_path=model_path, + tp=tp, + model_name='llama', + **kwargs) else: - tokenizer = Tokenizer(model_path) - tm_model = Engine(model_path, tp=tp) - - sampling_param = SamplingParam( - top_k=40, - top_p=0.8, - temperature=0.8, - repetition_penalty=1.0, - ignore_eos=True, - random_seed=random.getrandbits(64), - ) - - # make up a prompt that can be tokenized into {input_seqlen} tokens - prompt = '' if input_seqlen == 0 else 'hi' + ' hi' * (input_seqlen - 1) - input_ids = tokenizer.encode(prompt) - - warmup(tm_model, - concurrency, - input_ids, - output_seqlen, - sampling_param=sampling_param) + from lmdeploy.pytorch.engine import Engine + + # tokenizer = Tokenizer(model_path) + tm_model = Engine(model_path, tp=tp, model_name='llama') + + # make up a dummy `input_ids` with the length of `input_seqlen` exactly + assert input_seqlen > 0, 'input_seqlen should > 0' + input_ids = np.random.randint(low=0, high=101, size=input_seqlen).tolist() + warmup(tm_model, concurrency, input_ids, output_seqlen, warmup_round) que = Queue() procs = [] _start = time.perf_counter() - # TODO: update to the multithread version for i in range(concurrency): proc = Thread(target=infer, - args=(tm_model, i, input_ids, output_seqlen, test_round, - que, sampling_param)) + args=(tm_model, i + 1, input_ids, output_seqlen, top_k, + top_p, temperature, test_round, que)) procs.append(proc) proc.start() - try: - for proc in procs: - proc.join() - except Exception: - for proc in procs: - proc.stop() - exit(1) + for proc in procs: + proc.join() + _end = time.perf_counter() elapsed_time = _end - _start - stats = [] + token_latency_stats = [] while not que.empty(): - session_id, _stats = que.get() - print(f'\n{"-" * 50}\n' - f'session {session_id} stats: \n{_stats}\n{"-" * 50}\n') - stats.append(_stats) - - stats = np.array(stats).reshape(-1, 3) - - first_token_latency_min = np.min(stats[:, 0], axis=0) - first_token_latency_max = np.max(stats[:, 0], axis=0) - first_token_latency_ave = np.mean(stats[:, 0], axis=0) - token_latency_min = np.min(stats[:, 2], axis=0) - token_latency_max = np.max(stats[:, 2], axis=0) - token_latency_ave = np.mean(stats[:, 2], axis=0) - throughput = np.sum(stats[:, 1], axis=0) / np.sum(stats[:, 2], - axis=0) * concurrency - print(f'\n{"-" * 50}\nconcurrency: {concurrency}, input_tokens: ' - f'{input_seqlen}, output_tokens: {output_seqlen}\n' - f'elapsed_time: {elapsed_time:.2f}s\n' + _, _stats = que.get() + token_latency_stats += _stats + + # The shape is [concurrency*test_round, 
output_seqlen] + token_latency_stats = np.stack(token_latency_stats, axis=0) + + first_token_latency_min = np.round( + np.min(token_latency_stats[:, 0], axis=0), 3) + first_token_latency_max = np.round( + np.max(token_latency_stats[:, 0], axis=0), 3) + first_token_latency_ave = np.round( + np.mean(token_latency_stats[:, 0], axis=0), 3) + token_latency_max = np.round(np.max(np.sum(token_latency_stats, axis=1)), + 3) + token_latency_min = np.round(np.min(np.sum(token_latency_stats, axis=1)), + 3) + token_latency_ave = np.round(np.mean(np.sum(token_latency_stats, axis=1)), + 3) + # sort token_latency without the first token's latency + sorted_token_latency = np.sort(token_latency_stats[:, 1:].flatten()) + percentiles = [ + np.round( + sorted_token_latency[int(percent * len(sorted_token_latency))], 3) + for percent in [0.5, 0.75, 0.95, 0.99] + ] + + throughput = np.round(token_latency_stats.size / elapsed_time, 2) + print(f'\n{"-" * 50}\ntotal time: {elapsed_time:.2f}s\n' + f'concurrency: {concurrency}, test_round: {test_round}\n' + f'input_tokens: {input_seqlen}, output_tokens: {output_seqlen}\n' f'first_token latency(min, max, ave): ' - f'{first_token_latency_min:.2f}s, {first_token_latency_max:.2f}s, ' - f'{first_token_latency_ave:.2f}s\ntoken latency(min, max, ave): ' - f'{token_latency_min:.2f}s, {token_latency_max:.2f}s, ' - f'{token_latency_ave:.2f}s\n' - f'throughput: {throughput:.2f} token/s\n{"-" * 50}') - return throughput, tm_model.gpu_count + f'{first_token_latency_min}s, {first_token_latency_max}s, ' + f'{first_token_latency_ave}s\ntotal_token latency(min, max, ave): ' + f'{token_latency_min}s, {token_latency_max}s, ' + f'{token_latency_ave}s\n' + f'token_latency percentiles(50%,75%,95%,99%)(s): {percentiles}\n' + f'throughput: {throughput} token/s\n{"-" * 50}') + return tm_model.model_name, \ + [first_token_latency_min, first_token_latency_max, + first_token_latency_ave], \ + percentiles, throughput, tm_model.gpu_count class MemoryMonitor: @@ -266,9 +316,12 @@ def terminate(cls) -> float: @dataclass class ProfileResult: + model_name: str batch: int prompt_tokens: int completion_tokens: int + first_token_latency: List + percentiles: List throughput_per_proc: float throughput_per_node: float mem_per_proc: float @@ -278,84 +331,127 @@ class ProfileResult: def parse_args(): parser = argparse.ArgumentParser(description='Regression Test') - parser.add_argument('--model-path', + parser.add_argument('model_path', type=str, - help='benchmark test model path') + help='the path of the model in localhost or ' + 'the repo_id of the model in huggingface.co') parser.add_argument('--concurrency', nargs='+', type=int, help='how many requests launched concurrently', - default=[1, 8, 16, 32]) + default=[1, 16, 32, 64]) parser.add_argument( '--prompt-tokens', nargs='+', type=int, help='how many requests launched concurrently. One-to-one' 'correspondence with completion-tokens', - default=[64, 512, 512, 1024]) + default=[1, 128, 128, 2048, 2048]) parser.add_argument('--completion-tokens', nargs='+', type=int, help='how many tokens to be generated. 
One-to-one' 'correspondence with prompt-tokens', - default=[512, 512, 1024, 1024]) + default=[128, 128, 2048, 128, 2048]) parser.add_argument('--tp', type=int, help='Tensor parallel', default=1) - parser.add_argument('--dst-csv', + parser.add_argument('--top_k', + type=int, + help='The number of highest probability vocabulary ' + 'tokens to keep for top-k-filtering', + default=1) + parser.add_argument('--top_p', + type=float, + help='the set of most probable tokens with ' + 'probabilities that add up to top_p or higher ' + 'are kept for generation', + default=1.0) + parser.add_argument('--temperature', + type=float, + help='The value used to modulate the next token ' + 'probabilities', + default=1.0) + parser.add_argument('--csv', type=str, help='Where to save the result.', default='profile_generation.csv') parser.add_argument('--log-level', help='set log level', - default='INFO', + default='ERROR', choices=list(logging._nameToLevel.keys())) + parser.add_argument('--test-round', + type=int, + help='number of test rounds', + default=6) + parser.add_argument('--warmup-round', + type=int, + help='number of warmuop rounds', + default=1) args = parser.parse_args() return args def main(): - import multiprocessing as mp args = parse_args() + assert len(args.prompt_tokens) == len(args.completion_tokens), \ + f'mismatched size between `prompt-tokens` and `completion-tokenes`' \ + f', {len(args.prompt_tokens)} vs {len(args.completion_tokens)}' + os.environ['TM_LOG_LEVEL'] = args.log_level results: List[ProfileResult] = [] for batch in args.concurrency: for prompt_tokens, completion_tokens in zip(args.prompt_tokens, args.completion_tokens): - from functools import partial MemoryMonitor.start() + from functools import partial + from multiprocessing import Pool profile_target = partial(profile_throughput, concurrency=batch, input_seqlen=prompt_tokens, output_seqlen=completion_tokens, - tp=args.tp) - output = mp.Pool(1).map(profile_target, (args.model_path, )) - throughput_per_proc, tp = output[0] + tp=args.tp, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + test_round=args.test_round, + warmup_round=args.warmup_round) + output = Pool(1).map(profile_target, (args.model_path, )) + model_name, first_token_latency, percentiles, \ + throughput_per_proc, tp = output[0] time.sleep(5) # wait a while for releasing GPU mem memory = MemoryMonitor.terminate() device_count = MemoryMonitor.device_count.value results.append( - ProfileResult(batch=batch, + ProfileResult(model_name=model_name, + batch=batch, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, + first_token_latency=first_token_latency, + percentiles=percentiles, throughput_per_proc=throughput_per_proc, throughput_per_node=throughput_per_proc / tp * device_count, mem_per_proc=memory, mem_per_gpu=memory / tp, mem_per_node=memory / tp * device_count)) - with open(args.dst_csv, 'w') as csvfile: - writer = csv.writer(csvfile) - writer.writerow([ - 'batch', 'prompt_tokens', 'completion_tokens', - 'throughput_per_proc(token/s)', 'throughput_per_node(token/s)', - 'mem_per_proc(GB)', 'mem_per_gpu(GB)', 'mem_per_node(GB)' - ]) - for re in results: + if args.csv: + with open(args.csv, 'w') as csvfile: + writer = csv.writer(csvfile) writer.writerow([ - re.batch, re.prompt_tokens, re.completion_tokens, - f'{re.throughput_per_proc:.2f}', - f'{re.throughput_per_node:.2f}', f'{re.mem_per_proc:.2f}', - f'{re.mem_per_gpu:.2f}', f'{re.mem_per_node:.2f}' + 'batch', 'prompt_tokens', 'completion_tokens', + '1st_token_latency(min)(s)', 
'1st_token_latency(max)(s)', + '1st_token_latency(ave)(s)', 'percentile50(s)', + 'percentile75(s)', 'percentile95(s)', 'percentile99(s)', + 'throughput(token/s)', 'mem_per_proc(GB)', 'mem_per_gpu(GB)' ]) + for re in results: + writer.writerow([ + re.batch, re.prompt_tokens, re.completion_tokens, + re.first_token_latency[0], re.first_token_latency[1], + re.first_token_latency[2], re.percentiles[0], + re.percentiles[1], re.percentiles[2], re.percentiles[3], + f'{re.throughput_per_proc:.2f}', f'{re.mem_per_proc:.2f}', + f'{re.mem_per_gpu:.2f}' + ]) if __name__ == '__main__': diff --git a/benchmark/profile_restful_api.py b/benchmark/profile_restful_api.py index d1f6ebf80e..b16dfcd482 100644 --- a/benchmark/profile_restful_api.py +++ b/benchmark/profile_restful_api.py @@ -1,195 +1,252 @@ +import csv import json -import multiprocessing as mp import random import time -from typing import Iterable, List +from queue import Queue +from threading import Thread +from typing import List, Tuple import fire import numpy as np -import requests +from tqdm import tqdm +from lmdeploy.serve.openai.api_client import APIClient from lmdeploy.tokenizer import Tokenizer -from lmdeploy.utils import get_logger - - -def get_streaming_response(prompt: str, - api_url: str, - session_id: int, - request_output_len: int, - stream: bool = True, - sequence_start: bool = True, - sequence_end: bool = False, - ignore_eos: bool = False) -> Iterable[List[str]]: - headers = {'User-Agent': 'Test Client'} - pload = { - 'prompt': prompt, - 'stream': stream, - 'session_id': session_id, - 'request_output_len': request_output_len, - 'sequence_start': sequence_start, - 'sequence_end': sequence_end, - 'ignore_eos': ignore_eos - } - response = requests.post(api_url, - headers=headers, - json=pload, - stream=stream) - for chunk in response.iter_lines(chunk_size=8192, - decode_unicode=False, - delimiter=b'\n'): - if chunk: - data = json.loads(chunk.decode('utf-8')) - output = data['text'] - tokens = data['tokens'] - yield output, tokens - - -def infer(server_addr: str, session_id: int, req_queue: mp.Queue, - res_que: mp.Queue): - stats = [] - while not req_queue.empty(): - prompt, input_seqlen, output_seqlen = req_queue.get() - get_logger('profile_restful_api').info( - f'request info: session {session_id}, ' - f'input_seqlen {input_seqlen}, output_seqlen {output_seqlen}') - timestamps = [] - tokens = [] - start = time.perf_counter() - for res, token in get_streaming_response( - prompt, - server_addr, - session_id, - request_output_len=output_seqlen, - sequence_start=True, - sequence_end=True): - timestamps.append(time.perf_counter()) - tokens.append(token) - - first_token_latency = timestamps[1] - start - token_latency = timestamps[-1] - timestamps[0] - token = tokens[-1] - tokens[0] - stats.append([first_token_latency, token, token_latency]) - res_que.put((session_id, stats)) - - -def warmup(server_addr: str, - concurrency: int, - output_seqlen: int, - warmup_round: int = 1): - print('start to warmup ...') - - def _infer(server_addr, session_id): - for _ in range(warmup_round): - for _, _ in get_streaming_response( - '', - server_addr, - session_id, - request_output_len=output_seqlen, - sequence_start=True, - sequence_end=True): - continue - - _start = time.perf_counter() - procs = [] - for i in range(concurrency): - proc = mp.Process(target=_infer, args=(server_addr, i + 1)) - procs.append(proc) - proc.start() - for proc in procs: - proc.join() - _end = time.perf_counter() - print(f'end warmup, elapsed time: {round(_end - _start, 2)} s') 
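The remainder of this file's diff replaces the hand-rolled `/generate` streaming client and dataset reader with an `APIClient`-driven `Engine` that samples its workload from ShareGPT. For orientation, here is a hedged sketch of calling the new `main` entry point directly, with the same arguments the `fire` CLI exposes; the server URL, tokenizer path, and dataset location are placeholders, not values from the patch.

```python
# A sketch only: assumes benchmark/ is the working directory, an api_server is
# running at the given URL, and the ShareGPT json has been downloaded locally.
from profile_restful_api import main

main(server_addr='http://localhost:23333',
     tokenizer_path='/path/to/internlm-chat-7b',
     dataset='ShareGPT_V3_unfiltered_cleaned_split.json',
     concurrency=64,
     num_prompts=2000,
     stream_output=True,
     csv='./profile_api_server.csv')
```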
- - -def read_dataset(tokenizer_path: str, dataset_path: str, samples: int, - session_len: int): - start = time.perf_counter() + + +def sample_requests( + dataset_path: str, + num_requests: int, + tokenizer: Tokenizer, +) -> List[Tuple[str, int, int]]: + # Load the dataset. with open(dataset_path) as f: dataset = json.load(f) - dataset = [data for data in dataset if len(data['conversations']) >= 2] - # Only keep the first two turns of each conversation. - dataset = [(data['conversations'][0]['value'], - data['conversations'][1]['value']) for data in dataset] - prompts = [prompt for prompt, _ in dataset] - completions = [completion for _, completion in dataset] - print(f'elapsed time for read data: ' - f'{round(time.perf_counter() - start, 2)} s') - - start = time.perf_counter() - tokenizer = Tokenizer(tokenizer_path) - prompts_token_lens = [len(tokenizer.encode(prompt)) for prompt in prompts] - completions_token_lens = [ - len(tokenizer.encode(prompt)) for prompt in completions - ] - print(f'elapsed time for tokenization: ' - f'{round(time.perf_counter() - start, 2)} s') - - start = time.perf_counter() - filtered_dataset = [] - for (prompt, _), input_len, output_len in zip(dataset, prompts_token_lens, - completions_token_lens): - if input_len + output_len > session_len: - # ignore too long conversation + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data['conversations']) >= 2] + # Only keep the first two turns of each conversation. + dataset = [(data['conversations'][0]['value'], + data['conversations'][1]['value']) for data in dataset] + + # Tokenize the prompts and completions. + prompts = [prompt for prompt, _ in dataset] + prompt_token_ids = tokenizer(prompts).input_ids + completions = [completion for _, completion in dataset] + completion_token_ids = tokenizer(completions).input_ids + tokenized_dataset = [] + for i in range(len(dataset)): + output_len = len(completion_token_ids[i]) + tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) + + # Filter out too long sequences. + filtered_dataset: List[Tuple[str, int, int]] = [] + for prompt, prompt_token_ids, output_len in tokenized_dataset: + prompt_len = len(prompt_token_ids) + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. continue - filtered_dataset.append([prompt, input_len, output_len]) - - if samples > 0: - filtered_dataset = random.sample(filtered_dataset, samples) - - que = mp.Queue() - for data in filtered_dataset: - que.put(data) - print(f'elapsed time for filtering: ' - f'{round(time.perf_counter() - start, 2)} s') - return que, len(filtered_dataset) + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. + continue + filtered_dataset.append((prompt, prompt_len, output_len)) + + # Sample the requests. 
+ sampled_requests = random.sample(filtered_dataset, num_requests) + return sampled_requests + + +class Engine: + + def __init__(self, + server_addr: str, + tokenzier_path: str, + temperature: float = 0.8, + top_p: float = 1.0, + csv: str = '', + **kwargs): + self.tokenizer = Tokenizer(tokenzier_path) + self.server_addr = server_addr + self.temperature = temperature + self.top_p = top_p + self.csv = csv + client = APIClient(self.server_addr) + self.model_name = client.available_models[0] + self.pbar = None + + def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, + stream_output: bool): + + stats = [] + client = APIClient(self.server_addr) + + for prompt, input_seqlen, output_seqlen in iter( + req_queue.get, [None, None, None]): + timestamps = [] + timestamps.append(time.perf_counter()) + for output in client.chat_completions_v1( + model=self.model_name, + messages=prompt, + temperature=self.temperature, + top_p=self.top_p, + n=1, + max_tokens=output_seqlen, + stream=stream_output, + session_id=session_id, + ignore_eos=True): + timestamps.append(time.perf_counter()) + + first_token_latency = np.round(timestamps[1] - timestamps[0], 3) + token_latency = np.round(timestamps[-1] - timestamps[0], 3) + # assert output.pop('finish_reason') == 'length', \ + # f'Error. session_id({session_id}) request {output_seqlen} ' \ + # f'tokens, but `finish_reason` is not `length`' + total_tokens = input_seqlen + output_seqlen + stats.append([ + first_token_latency, output_seqlen, output_seqlen, + total_tokens, token_latency + ]) + self.pbar.update(1) + + res_queue.put((session_id, stats)) + + def process_request(self, + requests, + concurrency: int = 1, + stream_output: bool = False): + res_queue = Queue() + req_queue = Queue() + threads = [] + + self.pbar = tqdm(total=len(requests)) + + # feed request to q + for req in requests: + req_queue.put(req) + for i in range(concurrency): + req_queue.put([None, None, None]) + + start = time.time() + + # start threads + for i in range(concurrency): + t = Thread(target=self._inference, + args=(req_queue, res_queue, i, stream_output)) + t.start() + threads.append(t) + + # wait for finish + for t in threads: + t.join() + + elapsed_time = time.time() - start + + stats = [] + while not res_queue.empty(): + session_id, _stats = res_queue.get() + # print(f'\n{"-" * 50}\n' + # f'session {session_id} stats: \n{_stats}\n{"-" * 50}\n') + stats.append(np.array(_stats)) + + stats = np.concatenate(stats).reshape(-1, 5) + + first_token_latency_min = np.min(stats[:, 0], axis=0) + first_token_latency_max = np.max(stats[:, 0], axis=0) + first_token_latency_ave = np.mean(stats[:, 0], axis=0) + completion_tokens = np.sum(stats[:, 1], axis=0) + request_output_tokens = np.sum(stats[:, 2], axis=0) + total_tokens = np.sum(stats[:, 3], axis=0) + prompt_tokens = total_tokens - completion_tokens + completion_token_throughput = completion_tokens / elapsed_time + total_token_throughput = total_tokens / elapsed_time + rps = len(requests) / elapsed_time + rpm = rps * 60 + + if (np.abs(stats[:, 1] - stats[:, 2]) <= 1).min() is False: + print(f'Did not generate requested number of tokens. 
' + f'Request {request_output_tokens:.0f}, ' + f'but got {completion_tokens:.0f}') + + print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n' + f'elapsed_time: {elapsed_time:.3f}s\n') + if stream_output: + print(f'first_token latency(min, max, ave): ' + f'{first_token_latency_min:.3f}s, ' + f'{first_token_latency_max:.3f}s, ' + f'{first_token_latency_ave:.3f}s\n') + print( + f'number of prompt tokens: {prompt_tokens:.0f}\n' + f'number of completion tokens: {completion_tokens:.0f}\n' + f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n' # noqa + f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n' # noqa + f'RPS (request per second): {rps:.3f} req/s\n' + f'RPM (request per minute): {rpm:.3f} req/min\n' + f'{"-" * 50}\n') + + if self.csv: + with open(self.csv, 'w') as csvfile: + writer = csv.writer(csvfile) + writer.writerow([ + 'batch', 'num_prompts', 'prompt_tokens', + 'completion_tokens', '1st_token_latency(min)(s)', + '1st_token_latency(max)(s)', '1st_token_latency(ave)(s)', + 'output token thr(tokens/s', 'total token thr(token/s)', + 'RPS', 'RPM' + ]) + writer.writerow([ + concurrency, + len(requests), prompt_tokens, completion_tokens, + f'{first_token_latency_min:.3f}' if stream_output else '-', + f'{first_token_latency_max:.3f}' if stream_output else '-', + f'{first_token_latency_ave:.3f}' if stream_output else '-', + f'{completion_token_throughput:.3f}', + f'{total_token_throughput:.3f}', f'{rps:.3f}', f'{rpm:.3f}' + ]) def main(server_addr: str, tokenizer_path: str, - dataset_path: str, - concurrency: int = 1, - session_len: int = 2048, - samples: int = 1000): - api_url = server_addr + '/generate' - warmup(api_url, concurrency, session_len - 1) - req_queue, n_req = read_dataset(tokenizer_path, dataset_path, samples, - session_len) - res_que = mp.Queue() - procs = [] - _start = time.perf_counter() - for i in range(concurrency): - proc = mp.Process(target=infer, - args=(api_url, i + 1, req_queue, res_que)) - procs.append(proc) - proc.start() - for proc in procs: - proc.join() - _end = time.perf_counter() - elapsed_time = _end - _start - - stats = [] - while not res_que.empty(): - session_id, _stats = res_que.get() - print(f'\n{"-" * 50}\n' - f'session {session_id} stats: \n{_stats}\n{"-" * 50}\n') - stats.append(np.array(_stats)) - - stats = np.concatenate(stats).reshape(-1, 3) - - first_token_latency_min = np.min(stats[:, 0], axis=0) - first_token_latency_max = np.max(stats[:, 0], axis=0) - first_token_latency_ave = np.mean(stats[:, 0], axis=0) - token_throughput = np.sum(stats[:, 1], axis=0) / elapsed_time - req_throughput = n_req / elapsed_time - - print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n' - f'elapsed_time: {elapsed_time:.2f}s\n' - f'first_token latency(min, max, ave): ' - f'{first_token_latency_min:.2f}s, {first_token_latency_max:.2f}s, ' - f'{first_token_latency_ave:.2f}s\n' - f'token throughput: {token_throughput:.2f} token/s\n' - f'req throughput: {req_throughput:.2f} req/s\n' - f'{"-" * 50}\n') + dataset: str, + concurrency: int = 64, + num_prompts: int = 2000, + top_p: float = 1.0, + temperature: float = 1.0, + stream_output: bool = False, + csv: str = './profile_api_server.csv', + seed: int = 0): + """Benchmark the request througput of api server. 
+ + Args: + server_addr (str): http url of api_server with format http://0.0.0.0:0 + tokenizer_path (str): Path to the tokenizer model in localhost + dataset (str): Path to the dataset + concurrency (int, optional): Number of working threads to process the sampled prompts. + Defaults to 64. + num_prompts (int, optional): Number of prompts to process. Defaults to 2000. + top_p (float, optional): the set of most probable tokens with + probabilities that add up to top_p or higher + are kept for generation. Defaults to 1.0. + temperature (float, optional): The value used to modulate the next token probabilities. + Defaults to 1.0. + stream_output (bool, optional): Indicator for streaming output. Defaults to False. + csv (str, optional): The path to save the result. + seed (int, optional): Seed used in sampling prompts from dataset. Defaults to 0. + """ # noqa + if not server_addr.startswith('http://'): + print(f'[WARNING] server_addr of the api_server should ' + f'start with "http://", but got "{server_addr}"') + server_addr = 'http://' + server_addr.strip() + + random.seed(seed) + + engine = Engine(server_addr, + tokenizer_path, + top_p=top_p, + temperature=temperature, + csv=csv) + + requests = sample_requests(dataset, num_prompts, engine.tokenizer) + + engine.process_request(requests, concurrency, stream_output) if __name__ == '__main__': diff --git a/benchmark/profile_serving.py b/benchmark/profile_serving.py index 4580757eeb..154751737e 100644 --- a/benchmark/profile_serving.py +++ b/benchmark/profile_serving.py @@ -1,177 +1,257 @@ +import csv import json -import logging -import multiprocessing as mp import random import time +from queue import Queue +from threading import Thread +from typing import List, Tuple import fire import numpy as np +from tqdm import tqdm from lmdeploy.serve.turbomind.chatbot import Chatbot from lmdeploy.tokenizer import Tokenizer -def infer(chatbot, session_id: int, req_que: mp.Queue, res_que: mp.Queue): - stats = [] - for prompt, input_seqlen, output_seqlen in iter(req_que.get, - [None, None, None]): - timestamps = [] - tokens = [] - start = time.perf_counter() - for status, res, token in chatbot.stream_infer( - session_id, - prompt, - request_output_len=output_seqlen, - sequence_start=True, - sequence_end=True): +def sample_requests( + dataset_path: str, + num_requests: int, + tokenizer: Tokenizer, +) -> List[Tuple[str, int, int]]: + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data['conversations']) >= 2] + # Only keep the first two turns of each conversation. + dataset = [(data['conversations'][0]['value'], + data['conversations'][1]['value']) for data in dataset] + + # Tokenize the prompts and completions. + prompts = [prompt for prompt, _ in dataset] + prompt_token_ids = tokenizer(prompts).input_ids + completions = [completion for _, completion in dataset] + completion_token_ids = tokenizer(completions).input_ids + tokenized_dataset = [] + for i in range(len(dataset)): + output_len = len(completion_token_ids[i]) + tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) + + # Filter out too long sequences. + filtered_dataset: List[Tuple[str, int, int]] = [] + for prompt, prompt_token_ids, output_len in tokenized_dataset: + prompt_len = len(prompt_token_ids) + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. 
+ continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. + continue + filtered_dataset.append((prompt, prompt_len, output_len)) + + # Sample the requests. + sampled_requests = random.sample(filtered_dataset, num_requests) + return sampled_requests + + +class Engine: + + def __init__(self, + server_addr: str, + tokenzier_path: str, + temperature: float = 0.8, + top_k: int = 1, + top_p: float = 1.0, + csv: str = '', + log_level: str = 'ERROR', + **kwargs): + self.server_addr = server_addr + self.tokenizer = Tokenizer(tokenzier_path) + self.temperature = temperature + self.top_k = top_k + self.top_p = top_p + self.csv = csv + self.log_level = log_level + self.pbar = None + + def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, + stream_output: bool): + + chatbot = Chatbot(self.server_addr, + ignore_eos=True, + profile_serving=True, + top_k=self.top_k, + top_p=self.top_p, + temperature=self.temperature, + log_level=self.log_level) + stats = [] + for prompt, input_seqlen, output_seqlen in iter( + req_queue.get, [None, None, None]): + timestamps = [] + tokens = [] timestamps.append(time.perf_counter()) - tokens.append(token) - - first_token_latency = np.round(timestamps[1] - start, 3) - token_latency = np.round(timestamps[-1] - timestamps[0], 3) - token = tokens[-1] - tokens[0] - stats.append([first_token_latency, token, token_latency]) - print(f'session {session_id}: ' - f'input_seqlen {input_seqlen}, output_seqlen {output_seqlen}') - res_que.put((session_id, stats)) - - -def warmup(tritonserver_addr: str, - concurrency: int, - output_seqlen: int, - warmup_round: int = 1): - print('start to warmup ...') - - def _infer(_chatbot, session_id): - for _ in range(warmup_round): - for _, _, _ in _chatbot.stream_infer( + for _, _, n_token in chatbot.stream_infer( session_id, - prompt='', + prompt, request_output_len=output_seqlen, sequence_start=True, sequence_end=True): - continue - _chatbot.reset_session() - - _start = time.perf_counter() - chatbots = [ - Chatbot(tritonserver_addr=tritonserver_addr, - ignore_eos=True, - log_level=logging.ERROR, - profile_generation=True) for _ in range(concurrency) - ] - procs = [] - for i, chatbot in enumerate(chatbots): - proc = mp.Process(target=_infer, args=(chatbot, i + 1)) - procs.append(proc) - proc.start() - for proc in procs: - proc.join() - _end = time.perf_counter() - print(f'end warmup, elapsed time: {round(_end - _start, 2)} s') - - -def read_dataset(tokenizer_path: str, dataset_path: str, samples: int, - session_len: int, que: mp.Queue): - start = time.perf_counter() - with open(dataset_path) as f: - dataset = json.load(f) - dataset = [data for data in dataset if len(data['conversations']) >= 2] - # Only keep the first two turns of each conversation. 
- dataset = [(data['conversations'][0]['value'], - data['conversations'][1]['value']) for data in dataset] - prompts = [prompt for prompt, _ in dataset] - completions = [completion for _, completion in dataset] - print(f'elapsed time for read data: ' - f'{round(time.perf_counter() - start, 2)} s') - - start = time.perf_counter() - tokenizer = Tokenizer(tokenizer_path) - prompts_token_lens = [len(tokenizer.encode(prompt)) for prompt in prompts] - completions_token_lens = [ - len(tokenizer.encode(prompt)) for prompt in completions - ] - print(f'elapsed time for tokenization: ' - f'{round(time.perf_counter() - start, 2)} s') - - start = time.perf_counter() - filtered_dataset = [] - for (prompt, _), input_len, output_len in zip(dataset, prompts_token_lens, - completions_token_lens): - if input_len + output_len > session_len: - # ignore too long conversation - continue - filtered_dataset.append([prompt, input_len, output_len]) + timestamps.append(time.perf_counter()) + tokens.append(n_token) + first_token_latency = np.round(timestamps[1] - timestamps[0], 3) + token_latency = np.round(timestamps[-1] - timestamps[0], 3) + completion_tokens = tokens[-1] + assert output_seqlen <= completion_tokens <= output_seqlen + 1, \ + f'Error. session_id({session_id}) request {output_seqlen} ' \ + f'tokens, but generate {completion_tokens} tokens.\n' \ + f'prompt: {prompt}' + total_tokens = tokens[-1] + input_seqlen + stats.append([ + first_token_latency, completion_tokens, output_seqlen, + total_tokens, token_latency + ]) + self.pbar.update(1) + res_queue.put((session_id, stats)) + + def process_request(self, + requests, + concurrency: int = 1, + stream_output: bool = True): + res_queue = Queue() + req_queue = Queue() + threads = [] - if samples > 0: - filtered_dataset = random.sample(filtered_dataset, samples) + self.pbar = tqdm(total=len(requests)) - for data in filtered_dataset: - que.put(data) - print(f'elapsed time for filtering: ' - f'{round(time.perf_counter() - start, 2)} s') - return len(filtered_dataset) + # feed request to q + for req in requests: + req_queue.put(req) + for i in range(concurrency): + req_queue.put([None, None, None]) + start = time.time() -def main(tritonserver_addr: str, + # start threads + for i in range(concurrency): + t = Thread(target=self._inference, + args=(req_queue, res_queue, i, stream_output)) + t.start() + threads.append(t) + + # wait for finish + for t in threads: + t.join() + + elapsed_time = time.time() - start + + stats = [] + while not res_queue.empty(): + session_id, _stats = res_queue.get() + # print(f'\n{"-" * 50}\n' + # f'session {session_id} stats: \n{_stats}\n{"-" * 50}\n') + stats.append(np.array(_stats)) + + stats = np.concatenate(stats).reshape(-1, 5) + + first_token_latency_min = np.min(stats[:, 0], axis=0) + first_token_latency_max = np.max(stats[:, 0], axis=0) + first_token_latency_ave = np.mean(stats[:, 0], axis=0) + completion_tokens = np.sum(stats[:, 1], axis=0) + request_output_tokens = np.sum(stats[:, 2], axis=0) + total_tokens = np.sum(stats[:, 3], axis=0) + prompt_tokens = total_tokens - completion_tokens + completion_token_throughput = completion_tokens / elapsed_time + total_token_throughput = total_tokens / elapsed_time + rps = len(requests) / elapsed_time + rpm = rps * 60 + + if (np.abs(stats[:, 1] - stats[:, 2]) <= 1).min() is False: + print(f'Did not generate requested number of tokens. 
' + f'Request {request_output_tokens:.0f}, ' + f'but got {completion_tokens:.0f}') + + print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n' + f'elapsed_time: {elapsed_time:.3f}s\n') + if stream_output: + print(f'first_token latency(min, max, ave): ' + f'{first_token_latency_min:.3f}s, ' + f'{first_token_latency_max:.3f}s, ' + f'{first_token_latency_ave:.3f}s\n') + print( + f'number of prompt tokens: {prompt_tokens:.0f}\n' + f'number of completion tokens: {completion_tokens:.0f}\n' + f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n' # noqa + f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n' # noqa + f'RPS (request per second): {rps:.3f} req/s\n' + f'RPM (request per minute): {rpm:.3f} req/min\n' + f'{"-" * 50}\n') + + if self.csv: + with open(self.csv, 'w') as csvfile: + writer = csv.writer(csvfile) + writer.writerow([ + 'batch', 'num_prompts', 'prompt_tokens', + 'completion_tokens', '1st_token_latency(min)(s)', + '1st_token_latency(max)(s)', '1st_token_latency(ave)(s)', + 'output token thr(tokens/s', 'total token thr(token/s)', + 'RPS', 'RPM' + ]) + writer.writerow([ + concurrency, + len(requests), prompt_tokens, completion_tokens, + f'{first_token_latency_min:.3f}' if stream_output else '-', + f'{first_token_latency_max:.3f}' if stream_output else '-', + f'{first_token_latency_ave:.3f}' if stream_output else '-', + f'{completion_token_throughput:.3f}', + f'{total_token_throughput:.3f}', f'{rps:.3f}', f'{rpm:.3f}' + ]) + + +def main(server_addr: str, tokenizer_path: str, - dataset_path: str, - concurrency: int = 1, - session_len: int = 2048, - samples: int = 1000): - warmup(tritonserver_addr, concurrency, session_len - 1) - req_que = mp.Queue() - res_que = mp.Queue() - - procs = [] - _start = time.perf_counter() - for i in range(concurrency): - chatbot = Chatbot(tritonserver_addr=tritonserver_addr, - display=False, - profile_serving=True, - ignore_eos=True, - log_level=logging.ERROR) - proc = mp.Process(target=infer, - args=(chatbot, i + 1, req_que, res_que)) - procs.append(proc) - proc.start() - - # read data and put it to queue - n_req = read_dataset(tokenizer_path, dataset_path, samples, session_len, - req_que) - for i in range(concurrency): - req_que.put([None, None, None]) - - stats = [] - for i in range(concurrency): - session_id, _stats = res_que.get() - print(f'\n{"-" * 50}\n' - f'session {session_id}: processed reqs {len(_stats)}, ' - f'stats: \n{_stats}\n{"-" * 50}\n') - stats.append(np.array(_stats)) - - _end = time.perf_counter() - elapsed_time = _end - _start - - stats = np.concatenate(stats).reshape(-1, 3) - - first_token_latency_min = np.min(stats[:, 0], axis=0) - first_token_latency_max = np.max(stats[:, 0], axis=0) - first_token_latency_ave = np.mean(stats[:, 0], axis=0) - token_throughput = np.sum(stats[:, 1], axis=0) / elapsed_time - req_throughput = n_req / elapsed_time - - print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n' - f'elapsed_time: {elapsed_time:.3f}s\n' - f'first_token latency(min, max, ave): ' - f'{first_token_latency_min:.3f}s, {first_token_latency_max:.3f}s, ' - f'{first_token_latency_ave:.3f}s\n' - f'token throughput: {token_throughput:.3f} token/s\n' - f'req throughput: {req_throughput:.3f} req/s\n' - f'{"-" * 50}\n') - - for proc in procs: - proc.join() + dataset: str, + concurrency: int = 32, + num_prompts: int = 1000, + top_k: int = 1, + top_p: float = 1.0, + temperature: float = 1.0, + stream_output: bool = True, + csv: str = './profile_tis.csv', + seed: int = 0): + """Benchmark 
the request througput of the triton inference server. + + Args: + server_addr (str): Address of the triton inference server with format 0.0.0.0:0 + tokenizer_path (str): Path to the tokenizer model in localhost + dataset (str): Path to the dataset + concurrency (int, optional): Number of working threads to process the sampled prompts. + Defaults to 32. + num_prompts (int, optional): Number of prompts to process. Defaults to 1000. + top_k (int, optional): The number of highest probability vocabulary tokens + to keep for top-k-filtering. Defaults to 1. + top_p (float, optional): the set of most probable tokens with + probabilities that add up to top_p or higher + are kept for generation. Defaults to 1.0. + temperature (float, optional): The value used to modulate the next token probabilities. + Defaults to 1.0. + stream_output (bool, optional): Indicator for streaming output. Defaults to True. + seed (int, optional): Seed used in sampling prompts from dataset. Defaults to 0. + """ # noqa + + random.seed(seed) + + engine = Engine(server_addr, + tokenizer_path, + top_k=top_k, + top_p=top_p, + temperature=temperature, + log_level='ERROR', + csv=csv) + + requests = sample_requests(dataset, num_prompts, engine.tokenizer) + + engine.process_request(requests, concurrency, stream_output) if __name__ == '__main__': diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index 1f458665ab..010218aeeb 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -1,6 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import csv import json import os -import os.path as osp import random import time from queue import Queue @@ -8,6 +9,8 @@ from typing import List, Tuple import fire +import numpy as np +from tqdm import tqdm from lmdeploy.tokenizer import Tokenizer @@ -55,44 +58,66 @@ def sample_requests( class Engine: - def __init__(self, model_path: str, tp: int = 1): + def __init__(self, model_path: str, tp: int, csv: str, **kwargs): + # avoid turbomind checking chat template name by setting + # `model_name='llama'` from lmdeploy.pytorch.engine import Engine - from lmdeploy.pytorch.messages import SamplingParam - tokenizer_model_path = osp.join(model_path, 'triton_models', - 'tokenizer') + tokenizer_model_path = os.path.join(model_path, 'triton_models', + 'tokenizer') if os.path.exists(tokenizer_model_path): from lmdeploy.turbomind import TurboMind - tokenizer = Tokenizer(tokenizer_model_path) - tm_model = TurboMind(model_path=model_path, tp=tp) + tm_model = TurboMind(model_path=model_path, + model_name='llama', + tp=tp, + **kwargs) else: - tokenizer = Tokenizer(model_path) - tm_model = Engine(model_path, tp=tp) - - self.sampling_param = SamplingParam( - top_k=40, - top_p=0.8, - temperature=0.8, - repetition_penalty=1.0, - ignore_eos=True, - random_seed=random.getrandbits(64), - ) - + tm_model = Engine(model_path, tp=tp, model_name='llama') self.tm_model = tm_model - self.tokenizer = tokenizer - - def _inference(self, queue, session_id: int): + self.tokenizer = tm_model.tokenizer + self.csv = csv + self.pbar = None + def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, + stream_output: bool): + from lmdeploy.pytorch.engine import Engine + from lmdeploy.pytorch.messages import SamplingParam model_inst = self.tm_model.create_instance() - while True: - request = queue.get() - if request is None: - # stop signal - queue.put(None) - return - else: - prompt, _, output_seqlen = request - input_ids = self.tokenizer.encode(prompt) + stats 
= [] + # get each generated token's latency + per_token_latency_stats = [] + for prompt, input_seqlen, output_seqlen in iter( + req_queue.get, [None, None, None]): + _per_token_latency_stats = [0] * (output_seqlen + 1) + offset = 0 + prev = time.perf_counter() + n_prev_token = 0 + input_ids = self.tokenizer(prompt).input_ids + # TODO: share same stream infer + if isinstance(self.tm_model, Engine): + sampling_param = SamplingParam(top_k=1, + top_p=1.0, + temperature=1.0, + ignore_eos=True) + for outputs in model_inst.stream_infer( + session_id, + input_ids=input_ids, + request_output_len=output_seqlen, + sampling_param=sampling_param): + if len(outputs) > 1: + res, n_token = outputs[-2:] + else: + res, n_token = outputs[0] + self.tokenizer.decode(res, offset) + offset = n_token + now = time.perf_counter() + if n_prev_token != n_token: + _per_token_latency_stats[n_prev_token] = np.round( + now - prev, 3) + n_prev_token = n_token + prev = now + model_inst.end(session_id) + else: for outputs in model_inst.stream_infer( session_id, input_ids=input_ids, @@ -102,66 +127,185 @@ def _inference(self, queue, session_id: int): sequence_start=True, sequence_end=True, ignore_eos=True, - sampling_param=self.sampling_param): - if len(outputs) > 1: - res, tokens = outputs[-2:] - else: - res, tokens = outputs[0] - self.tokenizer.decode(res) + stream_output=stream_output): + res, n_token = outputs[0] + self.tokenizer.decode(res, offset) + offset = n_token + now = time.perf_counter() + if n_prev_token != n_token: + _per_token_latency_stats[n_prev_token] = np.round( + now - prev, 3) + n_prev_token = n_token + prev = now - # for pytorch engine to restart a session - if hasattr(model_inst, 'end'): - model_inst.end(session_id) + assert output_seqlen <= n_token <= output_seqlen + 1, \ + f'Error. 
session_id({session_id}) request {output_seqlen} ' \ + f'tokens, but generate {n_token} tokens.\n' \ + f'prompt: {prompt}' - def process_request(self, requests, concurrency: int = 1): - q = Queue() + first_token_latency = _per_token_latency_stats[0] + completion_tokens = n_token + total_tokens = n_token + input_seqlen + stats.append([ + first_token_latency, completion_tokens, output_seqlen, + total_tokens + ]) + # skip the first token latency + per_token_latency_stats.append(_per_token_latency_stats[1:]) + self.pbar.update(1) + res_queue.put((session_id, stats, per_token_latency_stats)) + + def process_request(self, + requests, + concurrency: int = 1, + stream_output: bool = True): + res_queue = Queue() + req_queue = Queue() threads = [] + self.pbar = tqdm(total=len(requests)) + + # feed request to q + for req in requests: + req_queue.put(req) + for i in range(concurrency): + req_queue.put([None, None, None]) + start = time.time() # start threads for i in range(concurrency): - t = Thread(target=self._inference, args=(q, i)) + t = Thread(target=self._inference, + args=(req_queue, res_queue, i, stream_output)) t.start() threads.append(t) - # feed request to q - for req in requests: - q.put(req) - - q.put(None) - # wait for finish for t in threads: t.join() - end = time.time() + elapsed_time = time.time() - start - return end - start + stats = [] + per_token_latency_stats = [] + while not res_queue.empty(): + session_id, _stats, _per_token_latency_stats = res_queue.get() + stats.append(np.array(_stats)) + per_token_latency_stats += [ + item for sublist in _per_token_latency_stats + for item in sublist + ] + stats = np.concatenate(stats).reshape(-1, 4) + + first_token_latency_min = np.min(stats[:, 0], axis=0) + first_token_latency_max = np.max(stats[:, 0], axis=0) + first_token_latency_ave = np.mean(stats[:, 0], axis=0) + completion_tokens = np.sum(stats[:, 1], axis=0) + total_tokens = np.sum(stats[:, 3], axis=0) + prompt_tokens = total_tokens - completion_tokens + completion_token_throughput = completion_tokens / elapsed_time + total_token_throughput = total_tokens / elapsed_time + rps = len(requests) / elapsed_time + rpm = rps * 60 + + per_token_latency_stats.sort() + percentiles = [ + np.round( + per_token_latency_stats[int(percent * + len(per_token_latency_stats))], 3) + for percent in [0.5, 0.75, 0.95, 0.99] + ] + + print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n' + f'elapsed_time: {elapsed_time:.3f}s\n') + if stream_output: + print(f'first token latency(s)(min, max, ave): ' + f'{first_token_latency_min:.3f}, ' + f'{first_token_latency_max:.3f}, ' + f'{first_token_latency_ave:.3f}') + print(f'per-token latency(s) percentile(50, 75, 95, 99): ' + f'{percentiles}\n') + print( + f'number of prompt tokens: {prompt_tokens:.0f}\n' + f'number of completion tokens: {completion_tokens:.0f}\n' + f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n' # noqa + f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n' # noqa + f'RPS (request per second): {rps:.3f} req/s\n' + f'RPM (request per minute): {rpm:.3f} req/min\n' + f'{"-" * 50}\n') + + if self.csv: + with open(self.csv, 'w') as csvfile: + writer = csv.writer(csvfile) + writer.writerow([ + 'batch', 'num_promts', 'prompt_tokens', + 'completion_tokens', '1st_token_latency(min)(s)', + '1st_token_latency(max)(s)', '1st_token_latency(ave)(s)', + 'percentile50(s)', 'percentile75(s)', 'percentile95(s)', + 'percentile99(s)', 'output token thr(tokens/s)', + 'total token thr(token/s)', 'RPS', 
'RPM' + ]) + writer.writerow([ + concurrency, + len(requests), prompt_tokens, completion_tokens, + f'{first_token_latency_min:.3f}' if stream_output else '-', + f'{first_token_latency_max:.3f}' if stream_output else '-', + f'{first_token_latency_ave:.3f}' if stream_output else '-', + f'{percentiles[0]:.3f}' if stream_output else '-', + f'{percentiles[1]:.3f}' if stream_output else '-', + f'{percentiles[2]:.3f}' if stream_output else '-', + f'{percentiles[3]:.3f}' if stream_output else '-', + f'{completion_token_throughput:.3f}', + f'{total_token_throughput:.3f}', f'{rps:.3f}', f'{rpm:.3f}' + ]) def main(dataset: str, model_path: str, - concurrency: int = 1, - num_prompts: int = 1000, - tp: int = 1): - - engine = Engine(model_path, tp=tp) - tokenizer = engine.tokenizer - - requests = sample_requests(dataset, num_prompts, tokenizer) - - elapsed_time = engine.process_request(requests, concurrency) - total_num_tokens = sum(prompt_len + output_len - for _, prompt_len, output_len in requests) - total_num_out_tokens = sum(output_len for _, _, output_len in requests) - print(f'Throughput requests: {len(requests) / elapsed_time:.2f} req/s') - print( - f'Throughput requests: {len(requests) * 60 / elapsed_time:.2f} req/min' - ) - print(f'Throughput tokens: {total_num_tokens / elapsed_time:.2f} tokens/s') - print('Throughput tokens(output only):' - f'{total_num_out_tokens / elapsed_time:.2f} tokens/s') + concurrency: int = 64, + num_prompts: int = 2000, + tp: int = 1, + top_k: int = 1, + top_p: float = 1.0, + temperature: float = 1.0, + stream_output: bool = True, + csv: str = './profile_throughput.csv', + log_level: str = 'ERROR', + seed: int = 0): + """Benchmark the request throughput of lmdeploy in localhost. + + Args: + dataset (str): Path to the dataset + model_path (str): Path to a model in localhost or a model_repo_id in huggingface.co + concurrency (int, optional): Number of working threads to process the sampled prompts. + Defaults to 64. + num_prompts (int, optional): Number of prompts to process. Defaults to 2000. + tp (int, optional): Number of GPUs for tensor parallel. Defaults to 1. + top_k (int, optional): The number of highest probability vocabulary tokens + to keep for top-k-filtering. Defaults to 1. + top_p (float, optional): the set of most probable tokens with + probabilities that add up to top_p or higher + are kept for generation. Defaults to 1.0. + temperature (float, optional): The value used to modulate the next token probabilities. + Defaults to 1.0. + stream_output (bool, optional): Indicator for streaming output. Defaults to True. + csv (str, optional): The path to save the result. + log_level(str, optional): The log level. Defaults to INFO + seed (int, optional): Seed used in sampling prompts from dataset. Defaults to 0. 
+ """ # noqa + random.seed(seed) + os.environ['TM_LOG_LEVEL'] = log_level + + engine = Engine(model_path, + tp=tp, + top_k=top_k, + top_p=top_p, + temperature=temperature, + csv=csv) + + requests = sample_requests(dataset, num_prompts, engine.tokenizer) + + engine.process_request(requests, concurrency, stream_output) if __name__ == '__main__': diff --git a/builder/manywheel/build_all_wheel.sh b/builder/manywheel/build_all_wheel.sh index 967743f4c0..b8da6fd720 100755 --- a/builder/manywheel/build_all_wheel.sh +++ b/builder/manywheel/build_all_wheel.sh @@ -4,8 +4,10 @@ set -eou pipefail TOPDIR=$(git rev-parse --show-toplevel)/builder +CUDA_VER=${CUDA_VER:-11.8} + PLAT_NAME=manylinux2014_x86_64 -for cuver in 11.8; do +for cuver in ${CUDA_VER}; do DOCKER_TAG=cuda${cuver} OUTPUT_FOLDER=cuda${cuver}_dist for pyver in py38 py39 py310 py311; do diff --git a/builder/manywheel/entrypoint_build.sh b/builder/manywheel/entrypoint_build.sh index abb90562a2..c9c2cae6e9 100755 --- a/builder/manywheel/entrypoint_build.sh +++ b/builder/manywheel/entrypoint_build.sh @@ -11,8 +11,8 @@ source /opt/conda/bin/activate conda activate $PYTHON_VERSION cd lmdeploy -mkdir build && cd build -bash ../generate.sh +mkdir -p build && cd build && rm -rf * +bash ../generate.sh make make -j$(nproc) && make install if [ $? != 0 ]; then echo "build failed" diff --git a/docker/Dockerfile b/docker/Dockerfile index 6d41afa7a6..1cc53d3888 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,7 +5,7 @@ RUN rm /etc/apt/sources.list.d/cuda*.list && apt-get update && apt-get install - && rm -rf /var/lib/apt/lists/* RUN python3 -m pip install --no-cache-dir torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 -RUN python3 -m pip install --no-cache-dir cmake +RUN python3 -m pip install --no-cache-dir cmake packaging ENV NCCL_LAUNCH_MODE=GROUP diff --git a/docs/en/benchmark/a100_fp16.md b/docs/en/benchmark/a100_fp16.md new file mode 100644 index 0000000000..b157137868 --- /dev/null +++ b/docs/en/benchmark/a100_fp16.md @@ -0,0 +1,134 @@ +# Benchmark on A100 (FP16) + +All the following results are tested on (x8) A100-80G CUDA 11.8. + +The tested lmdeploy version is `v0.1.0a1`. + +The commands provided below facilitate benchmarking both [static inference performance](#static-inference-benchmark) and [request throughput](#request-throughput-benchmark) on an A100-80G(x8) for models of various sizes. 
+ +```shell +bash benchmark/benchmark_7b.sh +bash benchmark/benchmark_13b.sh +bash benchmark/benchmark_20b.sh +bash benchmark/benchmark_70b.sh +``` + +## Static Inference Benchmark + +FTL: **F**irst **T**oken **L**atency + +### llama2-7b + +| batch | tp | prompt_tokens | output_tokens | throughput(out tok/s) | mem(GB) | FTL(ave)(s) | FTL(min)(s) | FTL(max)(s) | 50%(s) | 75%(s) | 95%(s) | 99%(s) | +| ----- | --- | ------------- | ------------- | --------------------- | ------- | ----------- | ----------- | ----------- | ------ | ------ | ------ | ------ | +| 1 | 1 | 1 | 128 | 100.02 | 76.55 | 0.011 | 0.01 | 0.011 | 0.009 | 0.009 | 0.01 | 0.011 | +| 1 | 1 | 128 | 128 | 102.21 | 76.59 | 0.022 | 0.022 | 0.022 | 0.01 | 0.01 | 0.01 | 0.01 | +| 1 | 1 | 128 | 2048 | 98.92 | 76.59 | 0.022 | 0.022 | 0.022 | 0.01 | 0.01 | 0.01 | 0.01 | +| 1 | 1 | 2048 | 128 | 86.1 | 76.77 | 0.139 | 0.139 | 0.14 | 0.01 | 0.01 | 0.01 | 0.011 | +| 1 | 1 | 2048 | 2048 | 93.78 | 76.77 | 0.14 | 0.139 | 0.141 | 0.011 | 0.011 | 0.011 | 0.011 | +| 16 | 1 | 1 | 128 | 1504.72 | 76.59 | 0.021 | 0.011 | 0.031 | 0.01 | 0.011 | 0.011 | 0.013 | +| 16 | 1 | 128 | 128 | 1272.47 | 76.77 | 0.129 | 0.023 | 0.149 | 0.011 | 0.011 | 0.012 | 0.014 | +| 16 | 1 | 128 | 2048 | 1010.62 | 76.77 | 0.13 | 0.023 | 0.144 | 0.015 | 0.018 | 0.02 | 0.021 | +| 16 | 1 | 2048 | 128 | 348.87 | 78.3 | 2.897 | 0.143 | 3.576 | 0.02 | 0.021 | 0.022 | 0.025 | +| 16 | 1 | 2048 | 2048 | 601.63 | 78.3 | 2.678 | 0.142 | 3.084 | 0.025 | 0.028 | 0.03 | 0.031 | +| 32 | 1 | 1 | 128 | 2136.73 | 76.62 | 0.079 | 0.014 | 0.725 | 0.011 | 0.012 | 0.013 | 0.021 | +| 32 | 1 | 128 | 128 | 2125.47 | 76.99 | 0.214 | 0.022 | 0.359 | 0.012 | 0.013 | 0.014 | 0.035 | +| 32 | 1 | 128 | 2048 | 1462.12 | 76.99 | 0.2 | 0.026 | 0.269 | 0.021 | 0.026 | 0.031 | 0.033 | +| 32 | 1 | 2048 | 128 | 450.43 | 78.3 | 4.288 | 0.143 | 5.267 | 0.031 | 0.032 | 0.034 | 0.161 | +| 32 | 1 | 2048 | 2048 | 733.34 | 78.34 | 4.118 | 0.19 | 5.429 | 0.04 | 0.045 | 0.05 | 0.053 | +| 64 | 1 | 1 | 128 | 4154.81 | 76.71 | 0.042 | 0.013 | 0.21 | 0.012 | 0.018 | 0.028 | 0.041 | +| 64 | 1 | 128 | 128 | 3024.07 | 77.43 | 0.44 | 0.026 | 1.061 | 0.014 | 0.018 | 0.026 | 0.158 | +| 64 | 1 | 128 | 2048 | 1852.06 | 77.96 | 0.535 | 0.027 | 1.231 | 0.03 | 0.041 | 0.048 | 0.053 | +| 64 | 1 | 2048 | 128 | 493.46 | 78.4 | 6.59 | 0.142 | 16.235 | 0.046 | 0.049 | 0.055 | 0.767 | +| 64 | 1 | 2048 | 2048 | 755.65 | 78.4 | 39.105 | 0.142 | 116.285 | 0.047 | 0.049 | 0.051 | 0.207 | + +### llama2-13b + +| batch | tp | prompt_tokens | output_tokens | throughput(out tok/s) | mem(GB) | FTL(ave)(s) | FTL(min)(s) | FTL(max)(s) | 50%(s) | 75%(s) | 95%(s) | 99%(s) | +| ----- | --- | ------------- | ------------- | --------------------- | ------- | ----------- | ----------- | ----------- | ------ | ------ | ------ | ------ | +| 1 | 1 | 1 | 128 | 57.49 | 74.84 | 0.018 | 0.018 | 0.019 | 0.017 | 0.017 | 0.017 | 0.017 | +| 1 | 1 | 128 | 128 | 56.58 | 74.84 | 0.04 | 0.039 | 0.04 | 0.017 | 0.017 | 0.017 | 0.018 | +| 1 | 1 | 128 | 2048 | 55.29 | 74.84 | 0.04 | 0.04 | 0.04 | 0.018 | 0.018 | 0.018 | 0.019 | +| 1 | 1 | 2048 | 128 | 48.99 | 75.09 | 0.242 | 0.242 | 0.243 | 0.019 | 0.019 | 0.019 | 0.019 | +| 1 | 1 | 2048 | 2048 | 52.12 | 75.09 | 0.243 | 0.24 | 0.244 | 0.019 | 0.019 | 0.019 | 0.02 | +| 16 | 1 | 1 | 128 | 869.45 | 74.87 | 0.036 | 0.019 | 0.053 | 0.018 | 0.019 | 0.019 | 0.02 | +| 16 | 1 | 128 | 128 | 757.3 | 75.09 | 0.252 | 0.041 | 0.272 | 0.019 | 0.02 | 0.02 | 0.021 | +| 16 | 1 | 128 | 2048 | 605.88 | 75.09 | 0.253 | 0.041 | 0.275 | 0.026 | 
0.03 | 0.033 | 0.034 | +| 16 | 1 | 2048 | 128 | 257.92 | 76.96 | 3.442 | 0.245 | 3.668 | 0.033 | 0.034 | 0.035 | 0.035 | +| 16 | 1 | 2048 | 2048 | 366.67 | 76.99 | 3.122 | 0.249 | 3.671 | 0.04 | 0.044 | 0.047 | 0.047 | +| 32 | 1 | 1 | 128 | 1667.5 | 74.9 | 0.034 | 0.021 | 0.057 | 0.019 | 0.02 | 0.021 | 0.023 | +| 32 | 1 | 128 | 128 | 1301.27 | 75.37 | 0.461 | 0.04 | 0.497 | 0.021 | 0.022 | 0.023 | 0.025 | +| 32 | 1 | 128 | 2048 | 860.14 | 75.84 | 0.833 | 0.041 | 1.151 | 0.034 | 0.042 | 0.047 | 0.048 | +| 32 | 1 | 2048 | 128 | 291.54 | 77.02 | 5.315 | 0.245 | 13.483 | 0.046 | 0.047 | 0.049 | 0.51 | +| 32 | 1 | 2048 | 2048 | 389.64 | 77.02 | 38.725 | 0.245 | 108.104 | 0.047 | 0.047 | 0.049 | 0.05 | +| 64 | 1 | 1 | 128 | 3049.16 | 74.96 | 0.044 | 0.025 | 0.073 | 0.02 | 0.022 | 0.026 | 0.029 | +| 64 | 1 | 128 | 128 | 2033.22 | 75.87 | 0.703 | 0.046 | 0.951 | 0.024 | 0.026 | 0.029 | 0.032 | +| 64 | 1 | 128 | 2048 | 998.86 | 76.9 | 7.805 | 0.042 | 60.1 | 0.045 | 0.047 | 0.05 | 0.063 | +| 64 | 1 | 2048 | 128 | 286.32 | 76.99 | 19.69 | 0.245 | 32.394 | 0.047 | 0.048 | 0.05 | 0.27 | +| 64 | 1 | 2048 | 2048 | 387.86 | 77.09 | 190.453 | 0.245 | 307.331 | 0.047 | 0.048 | 0.049 | 0.05 | + +### internlm-20b + +| batch | tp | prompt_tokens | output_tokens | throughput(out tok/s) | mem(GB) | FTL(ave)(s) | FTL(min)(s) | FTL(max)(s) | 50%(s) | 75%(s) | 95%(s) | 99%(s) | +| ----- | --- | ------------- | ------------- | --------------------- | ------- | ----------- | ----------- | ----------- | ------ | ------ | ------ | ------ | +| 1 | 2 | 1 | 128 | 61.14 | 73.55 | 0.018 | 0.017 | 0.019 | 0.016 | 0.016 | 0.016 | 0.018 | +| 1 | 2 | 128 | 128 | 60.03 | 73.55 | 0.042 | 0.041 | 0.043 | 0.016 | 0.016 | 0.016 | 0.017 | +| 1 | 2 | 128 | 2048 | 58.26 | 73.55 | 0.042 | 0.042 | 0.043 | 0.017 | 0.017 | 0.018 | 0.018 | +| 1 | 2 | 2048 | 128 | 51.93 | 73.68 | 0.217 | 0.216 | 0.217 | 0.018 | 0.018 | 0.018 | 0.018 | +| 1 | 2 | 2048 | 2048 | 56.36 | 73.68 | 0.217 | 0.217 | 0.217 | 0.018 | 0.018 | 0.018 | 0.018 | +| 16 | 2 | 1 | 128 | 903.01 | 73.65 | 0.034 | 0.018 | 0.051 | 0.017 | 0.018 | 0.019 | 0.02 | +| 16 | 2 | 128 | 128 | 794.13 | 73.74 | 0.227 | 0.043 | 0.248 | 0.018 | 0.019 | 0.02 | 0.021 | +| 16 | 2 | 128 | 2048 | 669.87 | 73.74 | 0.227 | 0.043 | 0.25 | 0.024 | 0.027 | 0.029 | 0.03 | +| 16 | 2 | 2048 | 128 | 288.60 | 75.60 | 3.09 | 0.247 | 4.485 | 0.029 | 0.03 | 0.031 | 0.032 | +| 16 | 2 | 2048 | 2048 | 441.46 | 75.61 | 3.172 | 0.219 | 4.442 | 0.035 | 0.037 | 0.04 | 0.041 | +| 32 | 2 | 1 | 128 | 1673.64 | 73.71 | 0.037 | 0.02 | 0.066 | 0.019 | 0.02 | 0.021 | 0.023 | +| 32 | 2 | 128 | 128 | 1347.57 | 73.90 | 0.351 | 0.043 | 0.436 | 0.02 | 0.021 | 0.023 | 0.025 | +| 32 | 2 | 128 | 2048 | 1025.62 | 73.90 | 0.391 | 0.042 | 0.441 | 0.031 | 0.037 | 0.041 | 0.043 | +| 32 | 2 | 2048 | 128 | 352.45 | 75.74 | 6.062 | 0.218 | 6.3 | 0.042 | 0.043 | 0.045 | 0.046 | +| 32 | 2 | 2048 | 2048 | 514.60 | 75.77 | 10.36 | 0.222 | 70.328 | 0.049 | 0.05 | 0.051 | 0.053 | +| 64 | 2 | 1 | 128 | 2954.34 | 73.82 | 0.05 | 0.029 | 0.074 | 0.021 | 0.023 | 0.026 | 0.03 | +| 64 | 2 | 128 | 128 | 2122.92 | 74.24 | 0.591 | 0.047 | 0.808 | 0.024 | 0.026 | 0.029 | 0.032 | +| 64 | 2 | 128 | 2048 | 1276.61 | 75.18 | 2.529 | 0.049 | 41.212 | 0.042 | 0.048 | 0.052 | 0.055 | +| 64 | 2 | 2048 | 128 | 350.82 | 75.88 | 12.382 | 0.219 | 20.986 | 0.05 | 0.051 | 0.054 | 0.249 | +| 64 | 2 | 2048 | 2048 | 512.37 | 76.26 | 111.149 | 0.221 | 211.531 | 0.05 | 0.051 | 0.052 | 0.055 | + +### llama2-70b + +| batch | tp | prompt_tokens | output_tokens | throughput(out 
tok/s) | mem(GB) | FTL(ave)(s) | FTL(min)(s) | FTL(max)(s) | 50%(s) | 75%(s) | 95%(s) | 99%(s) | +| ----- | --- | ------------- | ------------- | --------------------- | ------- | ----------- | ----------- | ----------- | ------ | ------ | ------ | ------ | +| 1 | 4 | 1 | 128 | 33.94 | 73.72 | 0.031 | 0.03 | 0.031 | 0.029 | 0.029 | 0.029 | 0.03 | +| 1 | 4 | 128 | 128 | 33.63 | 73.72 | 0.074 | 0.073 | 0.074 | 0.029 | 0.029 | 0.029 | 0.03 | +| 1 | 4 | 128 | 2048 | 32.38 | 73.72 | 0.074 | 0.074 | 0.075 | 0.031 | 0.031 | 0.031 | 0.031 | +| 1 | 4 | 2048 | 128 | 28.32 | 73.78 | 0.402 | 0.401 | 0.403 | 0.031 | 0.031 | 0.031 | 0.051 | +| 1 | 4 | 2048 | 2048 | 31.9 | 73.78 | 0.405 | 0.402 | 0.407 | 0.031 | 0.031 | 0.031 | 0.031 | +| 16 | 4 | 1 | 128 | 468.52 | 73.72 | 0.071 | 0.034 | 0.939 | 0.03 | 0.031 | 0.032 | 0.251 | +| 16 | 4 | 128 | 128 | 439.77 | 73.81 | 0.437 | 0.08 | 0.687 | 0.03 | 0.031 | 0.032 | 0.207 | +| 16 | 4 | 128 | 2048 | 482.99 | 73.81 | 0.403 | 0.079 | 0.44 | 0.033 | 0.033 | 0.035 | 0.036 | +| 16 | 4 | 2048 | 128 | 189.34 | 73.98 | 5.776 | 0.437 | 7.612 | 0.035 | 0.036 | 0.036 | 0.037 | +| 16 | 4 | 2048 | 2048 | 399.42 | 73.98 | 5.773 | 0.411 | 6.844 | 0.036 | 0.037 | 0.038 | 0.041 | +| 32 | 4 | 1 | 128 | 906.03 | 73.75 | 0.098 | 0.043 | 0.253 | 0.032 | 0.033 | 0.035 | 0.178 | +| 32 | 4 | 128 | 128 | 746.36 | 73.91 | 0.749 | 0.078 | 1.026 | 0.032 | 0.033 | 0.035 | 0.438 | +| 32 | 4 | 128 | 2048 | 853.56 | 73.91 | 0.732 | 0.076 | 1.129 | 0.036 | 0.038 | 0.041 | 0.158 | +| 32 | 4 | 2048 | 128 | 232.6 | 73.99 | 11.834 | 0.408 | 13.321 | 0.04 | 0.041 | 0.043 | 0.248 | +| 32 | 4 | 2048 | 2048 | 636.23 | 73.99 | 11.711 | 0.409 | 12.689 | 0.043 | 0.045 | 0.048 | 0.179 | +| 64 | 4 | 1 | 128 | 1425.79 | 73.81 | 0.213 | 0.046 | 1.264 | 0.037 | 0.039 | 0.044 | 0.329 | +| 64 | 4 | 128 | 128 | 1159.84 | 73.96 | 1.292 | 0.107 | 2.676 | 0.037 | 0.04 | 0.045 | 0.378 | +| 64 | 4 | 128 | 2048 | 1391.8 | 73.95 | 1.173 | 0.135 | 1.623 | 0.043 | 0.047 | 0.052 | 0.251 | +| 64 | 4 | 2048 | 128 | 270.47 | 74.02 | 17.402 | 0.452 | 24.164 | 0.05 | 0.052 | 0.057 | 0.345 | +| 64 | 4 | 2048 | 2048 | 930.46 | 74.01 | 21.29 | 0.423 | 24.498 | 0.055 | 0.059 | 0.065 | 0.299 | + +## Request Throughput Benchmark + +FTL: **F**irst **T**oken **L**atency + +| model | batch | tp | num_prompts | PRS | PRM | FTL(ave)(s) | FTL(min)(s) | FTL(max)(s) | throughput(out tok/s) | throughput(total tok/s) | +| ------------ | ----- | --- | ----------- | ------ | ------- | ----------- | ----------- | ----------- | --------------------- | ----------------------- | +| llama2-7b | 64 | 1 | 3000 | 10.275 | 616.477 | 0.092 | 0.036 | 1.145 | 2562.435 | 5283.547 | +| | 128 | 1 | 3000 | 12.611 | 756.677 | 0.205 | 0.056 | 2.241 | 3210.281 | 6619.357 | +| llama2-13b | 64 | 1 | 3000 | 6.337 | 380.244 | 0.159 | 0.051 | 2.048 | 1474.786 | 3039.398 | +| | 128 | 1 | 3000 | 7.588 | 455.273 | 0.412 | 0.085 | 4.445 | 1765.788 | 3639.128 | +| internlm-20b | 64 | 2 | 3000 | 7.842 | 470.516 | 0.166 | 0.059 | 2.461 | 1564.696 | 3311.16 | +| | 128 | 2 | 3000 | 9.776 | 586.568 | 0.34 | 0.079 | 5.808 | 1950.627 | 4127.855 | +| llama2-70b | 64 | 4 | 3000 | 4.285 | 257.08 | 0.301 | 0.083 | 4.689 | 1000.376 | 2062.7 | +| | 128 | 4 | 3000 | 5.833 | 349.996 | 0.633 | 0.107 | 8.431 | 1361.939 | 2808.216 | +| | 256 | 4 | 3000 | 6.568 | 394.108 | 1.49 | 0.171 | 19.52 | 1533.592 | 3162.15 | diff --git a/docs/en/benchmark/profile_api_server.md b/docs/en/benchmark/profile_api_server.md new file mode 100644 index 0000000000..23b7f8d385 --- /dev/null +++ 
b/docs/en/benchmark/profile_api_server.md @@ -0,0 +1,107 @@ +# API Server Performance Test Method + +The way to profiling api_server performance is similar to the method for [profiling throughput](./profile_throughput.md). The difference is api_server should be launched successfully before testing. + +The evaluation script is `profile_restful_api.py`. Before running it, please install the lmdeploy precompiled package, download the evaluation script and the test dataset: + +```shell +pip install 'lmdeploy[serve]>=0.1.0a1' +git clone --depth=1 https://github.com/InternLM/lmdeploy +cd lmdeploy/benchmark +wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json +``` + +During performance test, a specific model needs to be inputted. We recommend converting the model into turbomind format via `lmdeploy convert`, then proceed with testing. +The reason is to conveniently adjust the parameters of the inference engine in order to achieve better performance, such as batch size (max_batch_size), K/V cache size (max_cache_entry_count), etc. For detailed explanations of these parameters, please refer to [here](../turbomind_config.md). + +In the following sections, we assume the model is in turbomind format. + +## Metrics + +LMDeploy records the performance metrics like first token latency, token throughput (tokens/s) and request throughput (RPM) + +`first_token_latency` is only reported in the case of streaming inference. + +The formula for calculating `token throughput` is: + +$$ +TokenThroughput = Number\\ of\\ generated\\ tokens/TotalTime +$$ + +And the formula for calculating `request throughput` is: + +$$ +RPM(request\\ per\\ minute)=Number\\ of\\ prompts/TotalTime * 60 +$$ + +Total time includes prefill time. + +## Example + +We take `internlm-7b` as an example. The entire benchmark procedure is: + +```shell +pip install 'lmdeploy[serve]>=0.1.0a1' +git clone --depth=1 https://github.com/InternLM/lmdeploy +cd lmdeploy/benchmark +wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +# get internlm-7b from huggingface and convert it to turbomind format +lmdeploy convert internlm internlm/internlm-7b --dst-path ./internlm-7b + +# launch server +lmdeploy serve api_server ./internlm-7b --server-port 23333 + +# open another terminal and run the following command in the directory `lmdeploy/benchmark` +python3 ./profile_restful_api.py http://0.0.0.0:23333 ./internlm-7b/triton_models/tokenizer ./ShareGPT_V3_unfiltered_cleaned_split.json +``` + +## Methods + +Please refer to [this](../restful_api.md) guide to start `api_server`. +The argument `--instance-num` reflects the inference instance number. When more than `--instance-num` requests arrive at the `api_server` at the same time, the exceeding part of the requests will wait in the inference queue. + +```shell +python3 profile_restful_api.py +``` + +The required parameters are: + +- `server_addr` + + The address of api_server with format `http://{server_ip}:{server_port}` + +- `tokenizer_path` + + The path of the tokenizer model, which is used to encode the dataset to get the token size of prompts and responses + +- `dataset` + + The path of the downloaded dataset + +Optional arguments are listed as below: + +- `--concurrency` + + It represents the number of request threads with default value 64. Requests of concurrent threads will be batched by the inference engine. 
Its value should not exceed the number of inference instances in the api_server. + Otherwise, the excess requests will wait in the inference queue. + +- `--num-prompts` + + The number of sampled prompts from dataset to process. The default is 2000. + +- `--top_p` and `--temperature` + + They are used to sample the generated token_id. + +- `--stream_output` + + Indicator for streaming output. The default is `False`. + +- `--csv` + + The path of a csv file to save the result with default value `../profile_api_server.csv` + +- `--seed` + + It is the seed used in sampling prompts from dataset with default value 0. diff --git a/docs/en/benchmark/profile_generation.md b/docs/en/benchmark/profile_generation.md new file mode 100644 index 0000000000..5b117b8828 --- /dev/null +++ b/docs/en/benchmark/profile_generation.md @@ -0,0 +1,88 @@ +# Static Inference Performance Test Method + +We view the performance of the inference engine with a fixed batch size and fixed input/output token lengths as its static inference performance. + +The evaluation script is `profile_generation.py`. Before running it, please install the lmdeploy precompiled package and download the evaluation script: + +```shell +pip install 'lmdeploy>=0.1.0a1' +git clone --depth=1 https://github.com/InternLM/lmdeploy +``` + +During performance test, a specific model needs to be inputted. We recommend converting the model into turbomind format via `lmdeploy convert`, then proceed with testing. +The reason is to conveniently adjust the parameters of the inference engine in order to achieve better performance, such as batch size (max_batch_size), K/V cache size (max_cache_entry_count), etc. For detailed explanations of these parameters, please refer to [here](../turbomind_config.md). + +In the following sections, we assume the model is in turbomind format. + +## Metrics + +LMDeploy records test results like first token latency, token throughput (tokens/s), percentile data of each token's latency (P50, P75, P95, P99), GPU mem, etc. + +`first_token_latency` is only reported in the case of streaming inference. + +The formula for calculating `throughput` is: + +$$ +TokenThroughput = Number\\ of\\ generated\\ tokens/TotalTime +$$ + +Total time includes prefill time. + +During the test process, all graphics cards on the node should not run any other programs, otherwise the statistics of GPU mem would be inaccurate. + +## Example + +We take `internlm-7b` as an example. The entire benchmark procedure is: + +```shell +pip install 'lmdeploy>=0.1.0a1' +git clone --depth=1 https://github.com/InternLM/lmdeploy +cd lmdeploy/benchmark + +# get internlm-7b from huggingface and convert it to turbomind format +lmdeploy convert internlm internlm/internlm-7b --dst-path ./internlm-7b + +# benchmark +python3 profile_generation.py ./internlm-7b +``` + +## Command details + +```shell +python3 profile_generation.py +``` + +`model_path` refers to the path on localhost where the model in turbomind format is located. + +Optional arguments are listed as below: + +- `--concurrency` + + It represents the number of request threads. Requests of concurrent threads will be batched by the inference engine. It is a list with default value `[1, 16, 32, 64]`, which implies that the performance under 4 different levels of concurrency is tested. The level of concurrency should not exceed `max_batch_size` in [turbomind config](../turbomind_config.md#turbomind-20-config). Otherwise, there will be `concurrency - max_batch_size` threads randomly waiting at almost any time during the test.
+ +- `--prompt-tokens` and `--completion-tokens` + + Input token and output token numbers. They are lists of the same length. The elements in the list correspond one-to-one, that is, + the pair `(prompt_tokens[i], completion_tokens[i])` is a test case. In the default list `[1, 128, 128, 2048, 2048]` and `[128, 128, 2048, 128, 2048]`, the test cases are `(1, 128)`, `(128, 128)`, `(128, 2048)`, `(2048, 128)` and `(2048, 2048)` + +- `--tp` + + The number of GPUs used when the inference is in tensor parallel mode. It must be a power of 2. The default is 1. + +- `--top_k`, `--top_p` and `--temperature` + + They are used to sample the generated token_id. + +- `--csv` + + A csv file path used to store test results. The default is `./profile_generation.csv` + +- `--log-level` + + The log level. The default is 'ERROR'. + +- `--test-round` + + The number of test rounds is set to 10 by default. This means that each case will undergo 10 rounds of testing, and the average result will be calculated. + +We refer to a tuple of `(#concurrency, #prompt_token, #completion_token)` as a test case. Therefore, the total number of test cases (`#test_cases`) executed by the script is `len(concurrency) * len(prompt-tokens)`, and the total test rounds are `#test_cases * #test_round`. Users can flexibly adjust test parameters according to their actual situation. diff --git a/docs/en/benchmark/profile_throughput.md b/docs/en/benchmark/profile_throughput.md new file mode 100644 index 0000000000..31bfe1f959 --- /dev/null +++ b/docs/en/benchmark/profile_throughput.md @@ -0,0 +1,105 @@ +# Request Throughput Test Method + +In the applications, the length of the user's input prompt and the size of generated tokens are dynamic. The static inference performance is insufficient to reflect the inference engine's ability to handle the dynamic characteristics. + +Therefore, it is necessary to use real dialogue data to evaluate the dynamic inference capabilities of the inference engine. This article will introduce how to test the dynamic inference performance of LMDeploy on localhost. + +The evaluation script is `profile_throughput.py`. Before running it, please install the lmdeploy precompiled package, download the evaluation script and the test dataset: + +```shell +pip install 'lmdeploy>=0.1.0a1' +git clone --depth=1 https://github.com/InternLM/lmdeploy +cd lmdeploy/benchmark +wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json +``` + +During performance test, a specific model needs to be inputted. We recommend converting the model into turbomind format via `lmdeploy convert`, then proceed with testing. +The reason is to conveniently adjust the parameters of the inference engine in order to achieve better performance, such as batch size (max_batch_size), K/V cache size (max_cache_entry_count), etc. For detailed explanations of these parameters, please refer to [here](../turbomind_config.md). + +In the following sections, we assume the model is in turbomind format. + +## Metrics + +LMDeploy records the performance metrics like first token latency, token throughput (tokens/s) and request throughput (RPM) + +`first_token_latency` is only reported in the case of streaming inference. 
+ +The formula for calculating `token throughput` is: + +$$ +TokenThroughput = Number\\ of\\ generated\\ tokens/TotalTime +$$ + +And the formula for calculating `request throughput` is: + +$$ +RPM(request\\ per\\ minute) = Number\\ of\\ prompts/TotalTime * 60 +$$ + +Total time includes prefill time. + +## Example + +We take `internlm-7b` as an example. The entire benchmark procedure is: + +```shell +pip install 'lmdeploy>=0.1.0a1' +git clone --depth=1 https://github.com/InternLM/lmdeploy +cd lmdeploy/benchmark +wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +# get internlm-7b from huggingface and convert it to turbomind format +lmdeploy convert internlm internlm/internlm-7b --dst-path ./internlm-7b + +python3 profile_throughput.py ./ShareGPT_V3_unfiltered_cleaned_split.json ./internlm-7b +``` + +## Command details + +```shell +python3 profile_throughput.py +``` + +The required parameters are: + +- `dataset` + + The path of the downloaded dataset + +- `model_path` + + The path on localhost where the model in turbomind format is located. + +Optional arguments are listed as below: + +- `--concurrency` + + It represents the number of request threads with default value 64. Requests of concurrent threads will be batched by the inference engine. Its value should not exceed `max_batch_size` in `config.ini`. Otherwise, the excess requests will wait in the inference queue. + +- `--num-prompts` + + The number of sampled prompts from dataset to process. The default is 2000. + +- `--tp` + + The number of GPUs used when the inference is in tensor parallel mode. It must be a power of 2. The default is 1. + +- `--top_k`、`--top_p` and `--temperature` + + They are used to sample the generated token_id. + +- `--stream_output` + + Indicator for streaming output. The default is `True`. + +- `--csv` + + The path of a csv file to save the result with default value `./profile_throughput.csv` + +- `--log-level` + + The log level. The default is `ERROR`. + +- `--seed` + + It is the seed used in sampling prompts from dataset with default value 0. diff --git a/docs/en/benchmark/profile_triton_server.md b/docs/en/benchmark/profile_triton_server.md new file mode 100644 index 0000000000..267d830055 --- /dev/null +++ b/docs/en/benchmark/profile_triton_server.md @@ -0,0 +1,107 @@ +# Triton Inference Server Performance Test Method + +Triton Inference Server (TIS) is another serving method supported by LMDeploy besides from api_server. Its performance testing methods and metrics are similar to those of [api_server](./profile_api_server.md). + +The evaluation script is `profile_serving.py`. Before running it, please install the lmdeploy precompiled package, download the evaluation script and the test dataset: + +```shell +pip install 'lmdeploy[serve]>=0.1.0a1' +git clone --depth=1 https://github.com/InternLM/lmdeploy +cd lmdeploy/benchmark +wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json +``` + +During performance test, a specific model needs to be inputted. We recommend converting the model into turbomind format via `lmdeploy convert`, then proceed with testing. +The reason is to conveniently adjust the parameters of the inference engine in order to achieve better performance, such as batch size (max_batch_size), K/V cache size (max_cache_entry_count), etc. For detailed explanations of these parameters, please refer to [here](../turbomind_config.md). 
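+
+As a minimal sketch of what such an adjustment can look like, assuming the key names found in the generated `{model_path}/triton_models/weights/config.ini` (see the linked guide for their exact semantics), and using placeholder values rather than tuned recommendations:
+
+```toml
+# {model_path}/triton_models/weights/config.ini (excerpt, illustrative values only)
+max_batch_size = 64          # upper bound on how many requests are batched together
+cache_max_entry_count = 48   # k/v cache size; see the turbomind config guide
+```
+
+The values take effect the next time the service is launched.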
+ +In the following sections, we assume the model is in turbomind format. + +## Metrics + +LMDeploy records the performance metrics like first token latency, token throughput (tokens/s) and request throughput (RPM) + +`first_token_latency` is only reported in the case of streaming inference. + +The formula for calculating `token throughput` is: + +$$ +TokenThroughput=Number\\ of\\ generated\\ tokens/TotalTime +$$ + +And the formula for calculating `request throughput` is: + +$$ +RPM(request\\ per\\ minute)=Number\\ of\\ prompts/TotalTime * 60 +$$ + +Total time includes prefill time. + +## Example + +We take `internlm-7b` as an example. The entire benchmark procedure is: + +```shell +pip install 'lmdeploy[serve]>=0.1.0a1' +git clone --depth=1 https://github.com/InternLM/lmdeploy +cd lmdeploy/benchmark +wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +# get internlm-7b from huggingface and convert it to turbomind format +lmdeploy convert internlm internlm/internlm-7b --dst-path ./internlm-7b + +# launch server +bash ./internlm-7b/service_docker_up.sh + +# open another terminal and run the following command in the directory `lmdeploy/benchmark` +python3 ./profile_serving.py 0.0.0.0:33337 ./internlm-7b/triton_models/tokenizer ./ShareGPT_V3_unfiltered_cleaned_split.json +``` + +## Command details + +```shell +python3 profile_serving.py +``` + +The required parameters are: + +- `server_addr` + + The address of the triton inference server with format `{server_ip}:{server_port}` + +- `tokenizer_path` + + The path of the tokenizer model, which is used to encode the dataset to get the token size of prompts and responses + +- `dataset` + + The path of the downloaded dataset + +Optional arguments are listed as below: + +- `--concurrency` + + It represents the number of request threads with default value 32. Requests of concurrent threads will be batched by the inference engine. + It is recommended that `concurrency` does not exceed the `max_batch_size` in `config.ini`, nor should it exceed the number of inference instances in `triton_models`. + Otherwise, the excess requests will wait in the inference queue. + + The configuration item for the number of inference instances is `instance_group`, which is located in the file `{model_path}/triton_models/interactive/config.pbtxt`, and the default is 48. + +- `--num-prompts` + + The number of sampled prompts from dataset to process. The default is 1000. It is suggested to use 2000 when `concurrency >= 64`. + +- `--top_k`, `--top_p` and `--temperature` + + They are used to sample the generated token_id. + +- `--stream_output` + + Indicator for streaming output. The default is `True`. + +- `--csv` + + The path of a csv file to save the result with default value `../profile_tis.csv` + +- `--seed` + + It is the seed used in sampling prompts from dataset with default value 0. diff --git a/docs/en/build.md b/docs/en/build.md index 7ee53ac90c..b2de1d34b6 100644 --- a/docs/en/build.md +++ b/docs/en/build.md @@ -1,22 +1,80 @@ -## Build from source +# Build from source -- install packages for compiling and running: +LMDeploy provides a prebuilt package that can be easily installed by `pip install lmdeploy`.
- ```shell - conda create -n lmdeploy python=3.10 - conda activate lmdeploy +If you have requests to build lmdeploy from source, please clone lmdeploy repository from GitHub, and follow instructions in next sections - git clone https://github.com/InternLM/lmdeploy.git - cd lmdeploy +```shell +git clone --depth=1 https://github.com/InternLM/lmdeploy +``` - pip install -r requirements.txt - conda install openmpi-mpicxx nccl rapidjson -c conda-forge - ``` +## Build in Docker (recommended) + +We highly advise using the provided docker image for lmdeploy build to circumvent complex environment setup. + +The docker image is `openmmlab/lmdeploy-builder:cuda11.8`. Make sure that docker is installed before using this image. + +In the root directory of the lmdeploy source code, please run the following command: + +```shell +cd lmdeploy # the home folder of lmdeploy source code +bash builder/manywheel/build_all_wheel.sh +``` + +All the wheel files for lmdeploy under py3.8 - py3.11 will be found in the `builder/manywheel/cuda11.8_dist` directory, such as, + +```text +builder/manywheel/cuda11.8_dist/ +├── lmdeploy-0.0.12-cp310-cp310-manylinux2014_x86_64.whl +├── lmdeploy-0.0.12-cp311-cp311-manylinux2014_x86_64.whl +├── lmdeploy-0.0.12-cp38-cp38-manylinux2014_x86_64.whl +└── lmdeploy-0.0.12-cp39-cp39-manylinux2014_x86_64.whl +``` + +If the wheel file for a specific Python version is required, such as py3.8, please execute: + +```shell +bash builder/manywheel/build_wheel.sh py38 manylinux2014_x86_64 cuda11.8 cuda11.8_dist +``` + +And the wheel file will be found in the `builder/manywheel/cuda11.8_dist` directory. + +You can use `pip install` to install the wheel file that matches the Python version on your host machine. -- build and install lmdeploy: +## Build in localhost (optional) +Firstly, please make sure gcc version is no less than 9, which can be conformed by `gcc --version`. + +Then, follow the steps below to set up the compilation environment: + +- install the dependent packages: + ```shell + pip install -r requirements.txt + apt-get install rapidjson-dev + ``` +- install [nccl](https://docs.nvidia.com/deeplearning/nccl/install-guide/index.html), and set environment variables: + ```shell + export NCCL_ROOT_DIR=/path/to/nccl/build + export NCCL_LIBRARIES=/path/to/nccl/build/lib + ``` +- install openmpi from source: ```shell + wget https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-4.1.5.tar.gz + tar xf openmpi-4.1.5.tar.gz + cd openmpi-4.1.5 + ./configure + make -j$(nproc) && make install + ``` +- build and install lmdeploy libraries: + ```shell + apt install ninja-build # install ninja + cd lmdeploy # the home folder of lmdeploy mkdir build && cd build sh ../generate.sh - make -j$(nproc) && make install + ninja -j$(nproc) && ninja install + ``` +- install lmdeploy python package: + ```shell + cd .. + pip install -e . ``` diff --git a/docs/en/faq.md b/docs/en/faq.md index 636da5c947..6f109997cc 100644 --- a/docs/en/faq.md +++ b/docs/en/faq.md @@ -17,7 +17,7 @@ It may have been caused by the following reasons. 1. You haven't installed lmdeploy's precompiled package. `_turbomind` is the pybind package of c++ turbomind, which involves compilation. It is recommended that you install the precompiled one. ```shell -pip install lmdeploy +pip install lmdeploy[all] ``` 2. If you have installed it and still encounter this issue, it is probably because you are executing turbomind-related command in the root directory of lmdeploy source code. 
Switching to another directory will fix it @@ -26,7 +26,7 @@ pip install lmdeploy ### libnccl.so.2 not found -Make sure you have install lmdeploy (>=v0.0.5) through `pip install lmdeploy`. +Make sure you have install lmdeploy (>=v0.0.5) through `pip install lmdeploy[all]`. If the issue still exists after lmdeploy installation, add the path of `libnccl.so.2` to environment variable LD_LIBRARY_PATH. diff --git a/docs/en/kv_int8.md b/docs/en/kv_int8.md index 1f5f5aa125..10df19e84a 100644 --- a/docs/en/kv_int8.md +++ b/docs/en/kv_int8.md @@ -18,7 +18,7 @@ dequant: f = q * scale + zp Convert the Hugging Face model format to the TurboMind inference format to create a workspace directory. ```bash -python3 -m lmdeploy.serve.turbomind.deploy internlm-chat-7b /path/to/internlm-chat-7b +lmdeploy convert internlm-chat-7b /path/to/internlm-chat-7b ``` If you already have a workspace directory, skip this step. @@ -29,7 +29,7 @@ Get the quantization parameters by these two steps: ```bash # get minmax -python3 -m lmdeploy.lite.apis.calibrate \ +lmdeploy lite calibrate \ --model $HF_MODEL \ --calib_dataset 'c4' \ # Support c4, ptb, wikitext2, pileval --calib_samples 128 \ # Number of samples in the calibration set, if the memory is not enough, it can be adjusted appropriately @@ -37,7 +37,7 @@ python3 -m lmdeploy.lite.apis.calibrate \ --work_dir $WORK_DIR \ # Directory for saving quantized statistical parameters and quantized weights in Pytorch format # get quant parameters -python3 -m lmdeploy.lite.apis.kv_qparams \ +lmdeploy lite kv_qparams \ --work_dir $WORK_DIR \ # Directory of the last output --turbomind_dir workspace/triton_models/weights/ \ # Directory to save the quantization parameters --kv_sym False \ # Symmetric or asymmetric quantization, default is False @@ -52,19 +52,14 @@ You can also first set `turbomind_dir` to a private directory, then copy the sca Modify `workspace/triton_models/weights/config.ini`: -- Set use_context_fmha to 0, which means turning off flashattention - Set quant_policy to 4. This means enabling kv_cache int8 -This is because there are two versions of flashattention, v1 and v2, and kv_cache int8 has also previously realized the symmetric version. - -Considering there are four combinations of kernels needed to be implemented, premature optimization when the algorithm is uncertain can be disastrous for software. - ### **Step Four** Test the chat performance. ```bash -python3 -m lmdeploy.turbomind.chat ./workspace +lmdeploy chat turbomind ./workspace ``` ## GPU Memory Test diff --git a/docs/en/load_hf.md b/docs/en/load_hf.md new file mode 100644 index 0000000000..ddf6fe8bfd --- /dev/null +++ b/docs/en/load_hf.md @@ -0,0 +1,71 @@ +# Load huggingface model directly + +Starting from v0.1.0, Turbomind adds the ability to pre-process the model parameters on-the-fly while loading them from huggingface style models. + +## Supported model type + +Currently, Turbomind support loading three types of model: + +1. A lmdeploy-quantized model hosted on huggingface.co, such as [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit), etc. +2. Other LM models on huggingface.co like Qwen/Qwen-7B-Chat +3. 
A model converted by `lmdeploy convert`, legacy format + +## Usage + +### 1) A lmdeploy-quantized model + +For models quantized by `lmdeploy.lite` such as [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit), etc. + +``` +repo_id=internlm/internlm-chat-20b-4bit +model_name=internlm-chat-20b +# or +# repo_id=/path/to/downloaded_model + +# Inference by TurboMind +lmdeploy chat turbomind $repo_id --model-name $model_name + +# Serving with gradio +lmdeploy serve gradio $repo_id --model-name $model_name + +# Serving with Restful API +lmdeploy serve api_server $repo_id --model-name $model_name --instance_num 32 --tp 1 +``` + +### 2) Other LM models + +For other LM models such as Qwen/Qwen-7B-Chat or baichuan-inc/Baichuan2-7B-Chat. LMDeploy supported models can be viewed through `lmdeploy list`. + +``` +repo_id=Qwen/Qwen-7B-Chat +model_name=qwen-7b +# or +# repo_id=/path/to/Qwen-7B-Chat/local_path + +# Inference by TurboMind +lmdeploy chat turbomind $repo_id --model-name $model_name + +# Serving with gradio +lmdeploy serve gradio $repo_id --model-name $model_name + +# Serving with Restful API +lmdeploy serve api_server $repo_id --model-name $model_name --instance_num 32 --tp 1 +``` + +### 3) A model converted by `lmdeploy convert` + +The usage is like previous + +``` +# Convert a model +lmdeploy convert /path/to/model ./workspace --model-name MODEL_NAME + +# Inference by TurboMind +lmdeploy chat turbomind ./workspace + +# Serving with gradio +lmdeploy serve gradio ./workspace + +# Serving with Restful API +lmdeploy serve api_server ./workspace --instance_num 32 --tp 1 +``` diff --git a/docs/en/pytorch.md b/docs/en/pytorch.md index e3662ab373..e4cd5a9cbe 100644 --- a/docs/en/pytorch.md +++ b/docs/en/pytorch.md @@ -9,13 +9,13 @@ This submodule allow user to chat with language model through command line, and **Example 1**: Chat with default setting ```shell -python -m lmdeploy.pytorch.chat $PATH_TO_HF_MODEL +lmdeploy chat torch $PATH_TO_HF_MODEL ``` **Example 2**: Disable sampling and chat history ```shell -python -m lmdeploy.pytorch.chat \ +lmdeploy chat torch \ $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \ --temperature 0 --max-history 0 ``` @@ -23,7 +23,7 @@ python -m lmdeploy.pytorch.chat \ **Example 3**: Accelerate with deepspeed inference ```shell -python -m lmdeploy.pytorch.chat \ +lmdeploy chat torch \ $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \ --accel deepspeed ``` diff --git a/docs/en/restful_api.md b/docs/en/restful_api.md index cb70e26375..279f563b53 100644 --- a/docs/en/restful_api.md +++ b/docs/en/restful_api.md @@ -2,57 +2,67 @@ ### Launch Service +The user can open the http url print by the following command in a browser. + +- **Please check the http url for the detailed api usage!!!** +- **Please check the http url for the detailed api usage!!!** +- **Please check the http url for the detailed api usage!!!** + ```shell -python3 -m lmdeploy.serve.openai.api_server ./workspace 0.0.0.0 server_port --instance_num 32 --tp 1 +lmdeploy serve api_server ./workspace --server_name 0.0.0.0 --server_port ${server_port} --instance_num 64 --tp 1 ``` -Then, the user can open the swagger UI: `http://{server_ip}:{server_port}` for the detailed api usage. -We provide four restful api in total. Three of them are in OpenAI format. However, we recommend users try -our own api which provides more arguments for users to modify. The performance is comparatively better. +We provide four restful api in total. 
Three of them are in OpenAI format. + +- /v1/chat/completions + +- /v1/models + +- /v1/completions + +However, we recommend that users try +our own api `/v1/chat/interactive`, which provides more arguments for users to modify. Its performance is comparatively better. + +**Note**: if you want to launch multiple requests, please set a different `session_id` for each of them when calling the +`/v1/chat/completions` and `/v1/chat/interactive` apis. Otherwise, random values will be assigned. ### python -Here is an example for our own api `generate`. +We have integrated the client-side functionalities of these services into the `APIClient` class. Below are some examples demonstrating how to invoke the `api_server` service on the client side. + +If you want to use the `/v1/chat/completions` endpoint, you can try the following code: + +```python +from lmdeploy.serve.openai.api_client import APIClient +api_client = APIClient('http://{server_ip}:{server_port}') +model_name = api_client.available_models[0] +messages = [{"role": "user", "content": "Say this is a test!"}] +for item in api_client.chat_completions_v1(model=model_name, messages=messages): + print(item) +``` + +If you want to use the `/v1/completions` endpoint, you can try: + +```python +from lmdeploy.serve.openai.api_client import APIClient +api_client = APIClient('http://{server_ip}:{server_port}') +model_name = api_client.available_models[0] +for item in api_client.completions_v1(model=model_name, prompt='hi'): + print(item) +``` + +LMDeploy supports maintaining session histories on the server for the `/v1/chat/interactive` api. The feature is +disabled by default. + +- In interactive mode, the chat history is kept on the server. In a multi-round conversation, you should set + `interactive_mode = True` and pass the same `session_id` (it can't be -1, which is the default value) to `/v1/chat/interactive` for all requests. +- In normal mode, no chat history is kept on the server. + +The interactive mode can be controlled by the `interactive_mode` boolean parameter. The following is an example of normal mode. If you want to experience the interactive mode, simply pass in `interactive_mode=True`.
```python -import json -import requests -from typing import Iterable, List - - -def get_streaming_response(prompt: str, - api_url: str, - session_id: int, - request_output_len: int, - stream: bool = True, - sequence_start: bool = True, - sequence_end: bool = True, - ignore_eos: bool = False) -> Iterable[List[str]]: - headers = {'User-Agent': 'Test Client'} - pload = { - 'prompt': prompt, - 'stream': stream, - 'session_id': session_id, - 'request_output_len': request_output_len, - 'sequence_start': sequence_start, - 'sequence_end': sequence_end, - 'ignore_eos': ignore_eos - } - response = requests.post( - api_url, headers=headers, json=pload, stream=stream) - for chunk in response.iter_lines( - chunk_size=8192, decode_unicode=False, delimiter=b'\n'): - if chunk: - data = json.loads(chunk.decode('utf-8')) - output = data['text'] - tokens = data['tokens'] - yield output, tokens - - -for output, tokens in get_streaming_response( - "Hi, how are you?", "http://{server_ip}:{server_port}/generate", 0, - 512): - print(output, end='') +from lmdeploy.serve.openai.api_client import APIClient +api_client = APIClient('http://{server_ip}:{server_port}') +for item in api_client.generate(prompt='hi'): + print(item) ``` ### Java/Golang/Rust @@ -84,16 +94,15 @@ List Models: curl http://{server_ip}:{server_port}/v1/models ``` -Generate: +Interactive Chat: ```bash -curl http://{server_ip}:{server_port}/generate \ +curl http://{server_ip}:{server_port}/v1/chat/interactive \ -H "Content-Type: application/json" \ -d '{ "prompt": "Hello! How are you?", "session_id": 1, - "sequence_start": true, - "sequence_end": true + "interactive_mode": true }' ``` @@ -104,19 +113,19 @@ curl http://{server_ip}:{server_port}/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "internlm-chat-7b", - "messages": [{"role": "user", "content": "Hello! Ho are you?"}] + "messages": [{"role": "user", "content": "Hello! How are you?"}] }' ``` -Embeddings: +Text Completions: -```bash -curl http://{server_ip}:{server_port}/v1/embeddings \ - -H "Content-Type: application/json" \ +```shell +curl http://{server_ip}:{server_port}/v1/completions \ + -H 'Content-Type: application/json' \ -d '{ - "model": "internlm-chat-7b", - "input": "Hello world!" - }' + "model": "llama", + "prompt": "two steps to build a house:" +}' ``` ### CLI client @@ -125,7 +134,7 @@ There is a client script for restful api server. ```shell # restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 -python -m lmdeploy.serve.openai.api_client restful_api_url +lmdeploy serve api_client api_server_url ``` ### webui @@ -133,23 +142,19 @@ python -m lmdeploy.serve.openai.api_client restful_api_url You can also test restful-api through webui. ```shell -# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 +# api_server_url is what printed in api_server.py, e.g. http://localhost:23333 # server_ip and server_port here are for gradio ui -# example: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True -python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True +# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 +lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port} ``` ### FAQ -1. When user got `"finish_reason":"length"` which means the session is too long to be continued. - Please add `"renew_session": true` into the next request. +1. 
When user got `"finish_reason":"length"`, it means the session is too long to be continued. The session length can be + modified by passing `--session_len` to api_server. 2. When OOM appeared at the server side, please reduce the number of `instance_num` when lanching the service. -3. When the request with the same `session_id` to `generate` got a empty return value and a negative `tokens`, please consider setting `sequence_start=false` for the second question and the same for the afterwards. - -4. Requests were previously being handled sequentially rather than concurrently. To resolve this issue, - - - kindly provide unique session_id values when calling the `generate` API or else your requests may be associated with client IP addresses +3. When the request with the same `session_id` to `/v1/chat/interactive` got a empty return value and a negative `tokens`, please consider setting `interactive_mode=false` to restart the session. -5. Both `generate` api and `v1/chat/completions` upport engaging in multiple rounds of conversation, where input `prompt` or `messages` consists of either single strings or entire chat histories.These inputs are interpreted using multi-turn dialogue modes. However, ff you want to turn the mode of and manage the chat history in clients, please the parameter `sequence_end: true` when utilizing the `generate` function, or specify `renew_session: true` when making use of `v1/chat/completions` +4. The `/v1/chat/interactive` api disables engaging in multiple rounds of conversation by default. The input argument `prompt` consists of either single strings or entire chat histories. diff --git a/docs/en/serving.md b/docs/en/serving.md index 1e6f783d7a..6cc18018d0 100644 --- a/docs/en/serving.md +++ b/docs/en/serving.md @@ -8,7 +8,7 @@ You can download [llama-2 models from huggingface](https://huggingface.co/meta-l 7B ```shell -python3 -m lmdeploy.serve.turbomind.deploy llama2 /path/to/llama-2-7b-chat-hf +lmdeploy convert llama2 /path/to/llama-2-7b-chat-hf bash workspace/service_docker_up.sh ``` @@ -18,7 +18,7 @@ bash workspace/service_docker_up.sh 13B ```shell -python3 -m lmdeploy.serve.turbomind.deploy llama2 /path/to/llama-2-13b-chat-hf --tp 2 +lmdeploy convert llama2 /path/to/llama-2-13b-chat-hf --tp 2 bash workspace/service_docker_up.sh ``` @@ -28,7 +28,7 @@ bash workspace/service_docker_up.sh 70B ```shell -python3 -m lmdeploy.serve.turbomind.deploy llama2 /path/to/llama-2-70b-chat-hf --tp 8 +lmdeploy convert llama2 /path/to/llama-2-70b-chat-hf --tp 8 bash workspace/service_docker_up.sh ``` @@ -42,7 +42,7 @@ Weights for the LLaMA models can be obtained from by filling out [this form](htt 7B ```shell -python3 -m lmdeploy.serve.turbomind.deploy llama /path/to/llama-7b llama \ +lmdeploy convert llama /path/to/llama-7b llama \ --tokenizer_path /path/to/tokenizer/model bash workspace/service_docker_up.sh ``` @@ -53,7 +53,7 @@ bash workspace/service_docker_up.sh 13B ```shell -python3 -m lmdeploy.serve.turbomind.deploy llama /path/to/llama-13b llama \ +lmdeploy convert llama /path/to/llama-13b llama \ --tokenizer_path /path/to/tokenizer/model --tp 2 bash workspace/service_docker_up.sh ``` @@ -64,7 +64,7 @@ bash workspace/service_docker_up.sh 30B ```shell -python3 -m lmdeploy.serve.turbomind.deploy llama /path/to/llama-30b llama \ +lmdeploy convert llama /path/to/llama-30b llama \ --tokenizer_path /path/to/tokenizer/model --tp 4 bash workspace/service_docker_up.sh ``` @@ -75,7 +75,7 @@ bash workspace/service_docker_up.sh 65B ```shell -python3 -m 
lmdeploy.serve.turbomind.deploy llama /path/to/llama-65b llama \ +lmdeploy convert llama /path/to/llama-65b llama \ --tokenizer_path /path/to/tokenizer/model --tp 8 bash workspace/service_docker_up.sh ``` @@ -94,7 +94,7 @@ python3 -m fastchat.model.apply_delta \ --target-model-path /path/to/vicuna-7b \ --delta-path lmsys/vicuna-7b-delta-v1.1 -python3 -m lmdeploy.serve.turbomind.deploy vicuna /path/to/vicuna-7b +lmdeploy convert vicuna /path/to/vicuna-7b bash workspace/service_docker_up.sh ``` @@ -110,7 +110,7 @@ python3 -m fastchat.model.apply_delta \ --target-model-path /path/to/vicuna-13b \ --delta-path lmsys/vicuna-13b-delta-v1.1 -python3 -m lmdeploy.serve.turbomind.deploy vicuna /path/to/vicuna-13b +lmdeploy convert vicuna /path/to/vicuna-13b bash workspace/service_docker_up.sh ``` diff --git a/docs/en/supported_models/codellama.md b/docs/en/supported_models/codellama.md index 1b51402056..4471b78401 100644 --- a/docs/en/supported_models/codellama.md +++ b/docs/en/supported_models/codellama.md @@ -26,10 +26,10 @@ Based on the above table, download the model that meets your requirements. Execu ```shell # install lmdeploy -python3 -m pip install lmdeploy +python3 -m pip install lmdeploy[all] # convert weight layout -python3 -m lmdeploy.serve.turbomind.deploy codellama /the/path/of/codellama/model +lmdeploy convert codellama /the/path/of/codellama/model ``` Then, you can communicate with codellama in consolo by following instructions in next sections @@ -42,13 +42,13 @@ Then, you can communicate with codellama in consolo by following instructions in ### Completion ```shell -python3 -m lmdeploy.turbomind.chat ./workspace --cap completion +lmdeploy chat turbomind ./workspace --cap completion ``` ### Infilling ```shell -python3 -m lmdeploy.turbomind.chat ./workspace --cap infilling +lmdeploy chat turbomind ./workspace --cap infilling ``` The input code is supposed to have a special placeholder ``. For example, @@ -64,7 +64,7 @@ And the generated code piece by `turbomind.chat` is the one to be filled in ` 0, it represents the total number of k/v blocks + +The `cache_chunk_size` indicates the size of the k/v cache chunk to be allocated each time new k/v cache blocks are needed. Different values represent different meanings: + +- When it is an integer > 0, `cache_chunk_size` number of k/v cache blocks are allocated. +- When the value is -1, `cache_max_entry_count` number of k/v cache blocks are allocated. +- When the value is 0, `sqrt(cache_max_entry_count)` number of k/v cache blocks are allocated. + +### kv int8 switch + +When initiating 8bit k/v inference, set `quant_policy = 4`. Please refer to [kv int8](./kv_int8.md) for a guide. + +### long context switch + +By setting `rope_scaling_factor = 1.0`, you can enable the Dynamic NTK option of RoPE, which allows the model to use long-text input and output. + +Regarding the principle of Dynamic NTK, please refer to: + +1. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases +2. https://kexue.fm/archives/9675 + +You can also turn on [LogN attention scaling](https://kexue.fm/archives/8823) by setting `use_logn_attn = 1`. 
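+
+Putting the TurboMind 2.0 cache parameters above together, here is a minimal illustrative sketch (not part of LMDeploy) that estimates the per-block k/v memory and the resulting block budget from a converted model's `config.ini`. The config path, the half-precision k/v assumption and the 80 GB GPU size are assumptions made only for this example:
+
+```python
+# Rough estimate of TurboMind 2.0 k/v cache sizing (illustrative only).
+import configparser
+
+cfg = configparser.ConfigParser()
+cfg.read('workspace/triton_models/weights/config.ini')  # assumed location
+llama = cfg['llama']
+
+# one k/v block = cache_block_seq_len * num_layer * kv_head_num * size_per_head
+#                 * 2 (K and V) * 2 bytes (half precision)
+block_bytes = (int(llama['cache_block_seq_len']) * int(llama['num_layer']) *
+               int(llama['kv_head_num']) * int(llama['size_per_head']) * 2 * 2)
+print(f'one k/v block ~= {block_bytes / 2**20:.0f} MB')  # 64 MB for llama-2-7b
+
+entry = float(llama['cache_max_entry_count'])
+gpu_mem_bytes = 80 * 2**30  # example: an 80 GB GPU
+if 0 < entry < 1:           # a fraction: share of GPU memory used for k/v cache
+    num_blocks = int(gpu_mem_bytes * entry / block_bytes)
+else:                       # an integer: the total number of k/v blocks
+    num_blocks = int(entry)
+print(f'k/v cache budget ~= {num_blocks} blocks')
+```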
+ +## TurboMind 1.0 config + +Taking the `llama-2-7b-chat` model as an example, in TurboMind 1.0, its `config.ini` content is as follows: + +```toml +[llama] +model_name = llama2 +tensor_para_size = 1 +head_num = 32 +kv_head_num = 32 +vocab_size = 32000 +num_layer = 32 +inter_size = 11008 +norm_eps = 1e-06 +attn_bias = 0 +start_id = 1 +end_id = 2 +session_len = 4104 +weight_type = fp16 +rotary_embedding = 128 +rope_theta = 10000.0 +size_per_head = 128 +group_size = 0 +max_batch_size = 32 +max_context_token_num = 4 +step_length = 1 +cache_max_entry_count = 48 +cache_chunk_size = 1 +use_context_fmha = 1 +quant_policy = 0 +max_position_embeddings = 2048 +use_dynamic_ntk = 0 +use_logn_attn = 0 +``` + +These parameters are composed of model attributes and inference parameters. Model attributes include the number of layers, the number of heads, dimensions, etc., and they are **not modifiable**. + +```toml +model_name = llama2 +head_num = 32 +kv_head_num = 32 +vocab_size = 32000 +num_layer = 32 +inter_size = 11008 +norm_eps = 1e-06 +attn_bias = 0 +start_id = 1 +end_id = 2 +rotary_embedding = 128 +rope_theta = 10000.0 +size_per_head = 128 +``` + +In the following sections, we will focus on introducing the inference parameters. + +### data type + +`weight_type` and `group_size` are the relevant parameters, **which cannot be modified**. + +`weight_type` represents the data type of weights. Currently, `fp16` and `int4` are supported. `int4` represents 4bit weights. When `weight_type` is `int4`, `group_size` means the group size used when quantizing weights with `awq`. In LMDeploy prebuilt package, kernels with `group size = 128` are included. + +### batch size + +`max_batch_size` determines the max size of a batch during inference. In general, the larger the batch size is, the higher the throughput is. But make sure that `max_batch_size <= cache_max_entry_count` + +### k/v cache size + +TurboMind allocates k/v cache memory based on `session_len`, `cache_chunk_size`, and `cache_max_entry_count`. + +- `session_len` denotes the maximum length of a sequence, i.e., the size of the context window. +- `cache_chunk_size` indicates the size of k/v sequences to be allocated when new sequences are added. +- `cache_max_entry_count` signifies the maximum number of k/v sequences that can be cached. + +### kv int8 switch + +When initiating 8bit k/v inference, change `quant_policy = 4` and `use_context_fmha = 0`. Please refer to [kv int8](./kv_int8.md) for a guide. + +### long context switch + +By setting `use_dynamic_ntk = 1`, you can enable the Dynamic NTK option of RoPE, which allows the model to use long-text input and output. + +Regarding the principle of Dynamic NTK, please refer to: + +1. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases +2. https://kexue.fm/archives/9675 + +You can also turn on [LogN attention scaling](https://kexue.fm/archives/8823) by setting `use_logn_attn = 1`. diff --git a/docs/en/w4a16.md b/docs/en/w4a16.md index 96bde48571..1a23678923 100644 --- a/docs/en/w4a16.md +++ b/docs/en/w4a16.md @@ -5,7 +5,7 @@ LMDeploy supports LLM model inference of 4-bit weight, with the minimum requirem Before proceeding with the inference, please ensure that lmdeploy is installed. ```shell -pip install lmdeploy +pip install lmdeploy[all] ``` ## 4-bit LLM model Inference @@ -26,14 +26,14 @@ As demonstrated in the command below, first convert the model's layout using `tu ```shell ## Convert the model's layout and store it in the default path, ./workspace. 
-python3 -m lmdeploy.serve.turbomind.deploy \ +lmdeploy convert \ --model-name llama2 \ --model-path ./llama2-chat-7b-w4 \ --model-format awq \ --group-size 128 ## inference -python3 -m lmdeploy.turbomind.chat ./workspace +lmdeploy chat turbomind ./workspace ``` ## Serve with gradio @@ -41,7 +41,7 @@ python3 -m lmdeploy.turbomind.chat ./workspace If you wish to interact with the model via web ui, please initiate the gradio server as indicated below: ```shell -python3 -m lmdeploy.serve.turbomind ./workspace --server_name {ip_addr} ----server_port {port} +lmdeploy serve gradio ./workspace --server_name {ip_addr} --server_port {port} ``` Subsequently, you can open the website `http://{ip_addr}:{port}` in your browser and interact with the model @@ -84,7 +84,7 @@ It includes two steps: ### Step 1: Generate Quantization Parameter ```shell -python3 -m lmdeploy.lite.apis.calibrate \ +lmdeploy lite calibrate \ --model $HF_MODEL \ --calib_dataset 'c4' \ # Calibration dataset, supports c4, ptb, wikitext2, pileval --calib_samples 128 \ # Number of samples in the calibration set, if memory is insufficient, you can appropriately reduce this @@ -97,7 +97,7 @@ python3 -m lmdeploy.lite.apis.calibrate \ LMDeploy employs AWQ algorithm for model weight quantization. ```shell -python3 -m lmdeploy.lite.apis.auto_awq \ +lmdeploy lite auto_awq \ --model $HF_MODEL \ --w_bits 4 \ # Bit number for weight quantization --w_group_size 128 \ # Group size for weight quantization statistics diff --git a/docs/zh_cn/benchmark/profile_api_server.md b/docs/zh_cn/benchmark/profile_api_server.md new file mode 100644 index 0000000000..c73151f57d --- /dev/null +++ b/docs/zh_cn/benchmark/profile_api_server.md @@ -0,0 +1,105 @@ +# api_server 性能测试 + +api_server 的测试方式与[求吞吐量测试方法](./profile_throughput.md)类似。不同的是,在测试前,需要先启动 api_server,然后再通过测试脚本发送请求进行测试。 + +测试脚本是 `profile_restful_api.py`。测试之前,请安装 lmdeploy 预编译包,并下载评测脚本和测试数据集。 + +```shell +pip install 'lmdeploy[serve]>=0.1.0a1' +git clone --depth=1 https://github.com/InternLM/lmdeploy +cd lmdeploy/benchmark +wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json +``` + +测速时,需输入具体的模型。我们推荐把模型下载到本地,并通过 `lmdeploy convert` 把模型转换为 turbomind 格式,然后再进行测试。 +这么做的原因是,方便调节推理引擎参数,以达到比较好的推理性能,比如批处理大小(max_batch_size),K/V cache缓存大小(max_cache_entry_count)等等。有关这些参数的详细说明,请参考[这里](../turbomind_config.md). 
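+
+下面是一段示意性的 Python 代码(并非 LMDeploy 自带脚本),展示如何在测速前查看并调整转换后模型目录中的 `triton_models/weights/config.ini`。其中模型路径 `./internlm-7b` 和具体取值仅为示例,请按实际情况修改:
+
+```python
+# 调整 turbomind 推理引擎参数的示意代码(仅供参考)
+import configparser
+
+ini_path = './internlm-7b/triton_models/weights/config.ini'  # 示例路径
+cfg = configparser.ConfigParser()
+cfg.read(ini_path)
+
+print('当前 max_batch_size =', cfg['llama']['max_batch_size'])
+print('当前 cache_max_entry_count =', cfg['llama']['cache_max_entry_count'])
+
+# 示例:增大批处理量,并把 k/v cache 占用的显存比例调整为 0.5
+cfg['llama']['max_batch_size'] = '128'
+cfg['llama']['cache_max_entry_count'] = '0.5'
+with open(ini_path, 'w') as f:
+    cfg.write(f)
+```
+
+修改完成后,再启动 api_server 进行测速即可。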
+ +以下章节中,我们默认模型是 turbomind 格式的。 + +## 测量指标 + +LMDeploy 统计首token延时(first_token_latency)、token吞吐量(tokens/s)和请求吞吐量(RPM)。 + +`first_token_latency` 只有在流式推理的情况下才会输出。 + +token吞吐量的计算公式为: + +$$ +吞吐量 = 生成的token数量 / 总时间 +$$ + +请求吞吐量的计算公式为: + +$$ +吞吐量 = 请求数量 / 总时间 +$$ + +总时间包括 prefill 时间 + +## 测试案例 + +我们用 `internlm-7b` 为例,api_server的速度测试全流程如下: + +```shell +pip install 'lmdeploy[serve]>=0.1.0a1' +git clone --depth=1 https://github.com/InternLM/lmdeploy +cd lmdeploy/benchmark +wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +# 从huggingface下载internlm-7b,并转为turbomind模型格式 +lmdeploy convert internlm internlm/internlm-7b --dst-path ./internlm-7b + +# 启动server +lmdeploy serve api_server ./internlm-7b --server-port 23333 + +# 另起终端,在`lmdeploy/benchmark`目录下,执行测速脚本 +python3 ./profile_restful_api.py http://0.0.0.0:23333 ./internlm-7b/triton_models/tokenizer ./ShareGPT_V3_unfiltered_cleaned_split.json +``` + +## 测试方法 + +请参考[这里](../restful_api.md) 启动推理服务。启动时的参数 `--instance-num` 表示推理服务中的推理实例数量。当同一时刻到达 api_server 的请求数超过它时,请求会在推理队列中等待。 + +```shell +python3 profile_restful_api.py +``` + +其中,必填参数是: + +- `server_addr` + + api_server 的地址,格式是 `http://{server_ip}:{server_port}` + +- `tokenizer_path` + + tokenizer model 的路径。作用是对测试数据集预先 encode,获取对话数据的 token 长度 + +- `dataset` + + 下载的测试数据集的路径 + +可选测试参数如下: + +- `--concurrency` + + 客户端请求线程的数量,并发请求会被推理引擎拼成 batch,默认为 64。并发请求会被推理引擎拼成 batch。并发数不能超过api_server的`--instance-num`。否则,超出部分的请求会在推理队列中等待。 + +- `--num-prompts` + + 从数据集中采样的prompt数量,默认是 2000 + +- `--top_p` 和 `--temperature` + + 这三个参数用来采样生成的 token_id + +- `--stream_output` + + 流式推理的开关。默认值为 `False` + +- `--csv` + + 一个 csv 文件路径,用来存放测试结果。默认是 `./profile_api_server.csv` + +- `--seed` + + 从测试数据集中随机采样prompt时的种子。默认为0 diff --git a/docs/zh_cn/benchmark/profile_generation.md b/docs/zh_cn/benchmark/profile_generation.md new file mode 100644 index 0000000000..7756b9af36 --- /dev/null +++ b/docs/zh_cn/benchmark/profile_generation.md @@ -0,0 +1,87 @@ +# 静态推理性能测试方法 + +我们把推理引擎在固定 batch、固定输入输出 token 数量的前提下的推理,称之为静态推理。 + +评测脚本是 `profile_generation.py`,在运行此脚本前,请安装 lmdeploy 预编译包,并下载评测脚本 + +```shell +pip install 'lmdeploy>=0.1.0a1' +git clone --depth=1 https://github.com/InternLM/lmdeploy +``` + +测速时,需输入具体的模型。我们推荐把模型下载到本地,并通过 `lmdeploy convert` 把模型转换为 turbomind 格式,然后再进行测试。 +这么做的原因是,方便调节推理引擎参数,以达到比较好的推理性能,比如批处理大小(max_batch_size),K/V cache缓存大小(max_cache_entry_count)等等。有关这些参数的详细说明,请参考[这里](../turbomind_config.md). 
+ +以下章节中,我们默认模型是 turbomind 格式的。 + +## 测量指标 + +LMDeploy 统计首token延时(first_token_latency)、token 吞吐量(tokens/s),每个token延时的百分位数据(P50,P75,P95,P99)、GPU mem 等测试结果。 + +`first_token_latency` 只有在流式推理的情况下才会输出。 + +吞吐量的计算公式为: + +$$ +token吞吐量 = 生成的token数量 / 总时间 +$$ + +总时间包括 prefill 时间。 + +测试过程中,节点上所有的显卡不要运行其他任何程序,否则 GPU mem 的统计会不准确。 + +## 测试案例 + +我们用 `internlm-7b` 为例,api_server的速度测试全流程如下: + +```shell +pip install 'lmdeploy>=0.1.0a1' +git clone --depth=1 https://github.com/InternLM/lmdeploy +cd lmdeploy/benchmark + +# 从huggingface下载internlm-7b,并转为turbomind模型格式 +lmdeploy convert internlm internlm/internlm-7b --dst-path ./internlm-7b + +# 执行测速脚本 +python3 profile_generation ./internlm-7b +``` + +## 测试方法 + +```shell +python3 profile_generation.py +``` + +其中,`model_path` turbomind格式的模型在 localhost 上的路径。 + +可选测试参数如下: + +- `--concurrency` + + 代表请求线程的数量,并发请求会被推理引擎拼成 batch。默认值为`[1, 16, 32, 64]`,意味着默认测试 4 种不同并发度下的性能。并发量不能超过`config.ini`中的`max_batch_size`。否则,超出部分的请求会在推理队列中等待。 + +- `--prompt-tokens` 和 `--completion-tokens` + + 输入token和输出token数量。它们是一个列表,列表中的元素是一一对应关系,即,`(--prompt-tokens[i]`, `--completion-tokens[i])` 是一组。比如在默认列表中,`[1, 128, 128, 2048, 2048]`和`[128, 128, 2048, 128, 2048]`,测试组合分别是,`(1, 128)`、`(128, 128)`、`(128, 2048)`、`(2048, 128)`和`(2048, 2048)` + +- `--tp` + + 模型在张量并行时,使用的显卡数量。必须是2的整数次幂。默认为 1。 + +- `--top_k`、`--top_p` 和 `--temperature` + + 这三个参数用来采样生成的 token_id。 + +- `--csv` + + 一个 csv 文件路径,用来存放测试结果。默认是 `./profile_generation.csv` + +- `--log-level` + + 日志级别。默认是 `ERROR` + +- `--test-round` + + 测试的轮数,默认是 10。表示每组测试设置,会测试 10 轮,统计其平均结果。 + +我们把一组 `(并发数, prompt_token数量, completion-token数量)` 称为一组测试用例。所以,脚本执行的`测试用例总数 = 并发数列表长度 x prompt_token 列表长度`,`测试规模 = 测试用例总数 x 测试轮数`。用户可以根据自己的实际情况,灵活的调整测试参数。 diff --git a/docs/zh_cn/benchmark/profile_throughput.md b/docs/zh_cn/benchmark/profile_throughput.md new file mode 100644 index 0000000000..d39ae0e873 --- /dev/null +++ b/docs/zh_cn/benchmark/profile_throughput.md @@ -0,0 +1,105 @@ +# 请求吞吐量测试方法 + +在真实应用中,用户输入的 prompt 长度以及模型回复的 token 数量是动态变化的。而静态推理能力不足以反映推理引擎对动态输入输出的处理能力。 + +所以需要使用真实对话数据,评测推理引擎的动态推理能力。本文将介绍如何在 localhost 上测试 LMDeploy 的动态推理性能。 + +测试脚本是 `profile_restful_api.py`。测试之前,请安装 lmdeploy 预编译包,并下载评测脚本和测试数据集。 + +```shell +pip install 'lmdeploy>=0.1.0a1' +git clone --depth=1 https://github.com/InternLM/lmdeploy +wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json +``` + +测速时,需输入具体的模型。我们推荐把模型下载到本地,并通过 `lmdeploy convert` 把模型转换为 turbomind 格式,然后再进行测试。 +这么做的原因是,方便调节推理引擎参数,以达到比较好的推理性能,比如批处理大小(max_batch_size),K/V cache缓存大小(max_cache_entry_count)等等。有关这些参数的详细说明,请参考[这里](../turbomind_config.md). 
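+
+在正式测速前,如果想先大致了解测试数据集的规模和对话轮数,可以参考下面这段示意代码(并非 `profile_throughput.py` 的实现)。其中数据集的字段结构为假设,请以实际下载的文件为准:
+
+```python
+# 粗略统计 ShareGPT 数据集的示意代码(仅供参考)
+# 假设数据格式为: [{"conversations": [{"from": "human", "value": "..."}, ...]}, ...]
+import json
+import random
+
+with open('ShareGPT_V3_unfiltered_cleaned_split.json') as f:
+    dataset = json.load(f)
+
+print('对话总数:', len(dataset))
+
+# 随机采样 2000 条,模拟测试脚本默认的 --num-prompts
+samples = random.sample(dataset, k=min(2000, len(dataset)))
+lengths = [len(item.get('conversations', [])) for item in samples]
+print('平均每条对话的轮数:', sum(lengths) / len(lengths))
+```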
+ +以下章节中,我们默认模型是 turbomind 格式的。 + +## 测量指标 + +LMDeploy 统计首token延时(first_token_latency)、token吞吐量(tokens/s)和请求吞吐量(RPM)。 + +`first_token_latency` 只有在流式推理的情况下才会输出。 + +token吞吐量的计算公式为: + +$$ +token吞吐量 = 生成的token数量 / 总时间 +$$ + +请求吞吐量的计算公式为: + +$$ +吞吐量 = 请求数量 / 总时间 +$$ + +总时间包括 prefill 时间 + +## 测试案例 + +我们用 `internlm-7b` 为例,api_server的速度测试全流程如下: + +```shell +pip install 'lmdeploy>=0.1.0a1' +git clone --depth=1 https://github.com/InternLM/lmdeploy +cd lmdeploy/benchmark +wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +# 从huggingface下载internlm-7b,并转为turbomind模型格式 +lmdeploy convert internlm internlm/internlm-7b --dst-path ./internlm-7b + +# 执行测速脚本 +python3 profile_throughput.py ./ShareGPT_V3_unfiltered_cleaned_split.json ./internlm-7b +``` + +## 测试方法 + +```shell +python3 profile_throughput.py +``` + +其中,必填参数是: + +- `dataset` + + 测试数据集的路径 + +- `model_path` + + turbomind格式的模型在 localhost 上的路径 + +可选测试参数如下: + +- `--concurrency` + + 代表请求线程的数量,并发请求会被推理引擎拼成 batch,默认为 64。并发请求会被推理引擎拼成 batch。并发数不能超过`config.ini`中的`max_batch_size`。否则,超出部分的请求会在推理队列中等待。 + +- `--num-prompts` + + 从数据集中采样的prompt数量。默认是 2000 + +- `--tp` + + 模型在张量并行时,使用的显卡数量。必须是2的整数次幂。默认为 1 + +- `--top_k`、`--top_p` 和 `--temperature` + + 这三个参数用来采样生成的 token_id + +- `--stream_output` + + 流式推理的开关。默认值为 `True` + +- `--csv` + + 一个 csv 文件路径,用来存放测试结果。默认是 `./profile_throughput.csv` + +- `--log-level` + + 日志级别。默认是 `ERROR` + +- `--seed` + + 从测试数据集中随机采样prompt时的种子。默认为0 diff --git a/docs/zh_cn/benchmark/profile_triton_server.md b/docs/zh_cn/benchmark/profile_triton_server.md new file mode 100644 index 0000000000..beafc80937 --- /dev/null +++ b/docs/zh_cn/benchmark/profile_triton_server.md @@ -0,0 +1,109 @@ +# Triton Inference Server 性能测试方法 + +Triton Inference Server(TIS) 是 LMDeploy 支持的除了 api_server 之外的另一种 serving 方式。它的性能测试方式和测试指标和 [api_server](./profile_api_server.md) 的测试方式类似。 + +```{note} +LMDeploy 尚未实现 Triton Inference Server 的 ensemble 推理模式,所以推理性能要比 api_server 弱。对于追求性能的用户,我们推荐使用 api_server 部署服务。 +``` + +TIS 性能测试脚本是 `profile_serving.py`。测试之前,请安装 lmdeploy 预编译包,并下载评测脚本和测试数据集。 + +```shell +pip install 'lmdeploy[serve]>=0.1.0a1' +git clone --depth=1 https://github.com/InternLM/lmdeploy +wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json +``` + +测速时,需输入具体的模型。我们推荐把模型下载到本地,并通过 `lmdeploy convert` 把模型转换为 turbomind 格式,然后再进行测试。 +这么做的原因是,方便调节推理引擎参数,以达到比较好的推理性能,比如批处理大小(max_batch_size),K/V cache缓存大小(max_cache_entry_count)等等。有关这些参数的详细说明,请参考[这里](../turbomind_config.md). 
+ +以下章节中,我们默认模型是 turbomind 格式的。 + +## 测量指标 + +LMDeploy 统计首token延时(first_token_latency)、token吞吐量(tokens/s)和请求吞吐量(RPM)。 + +`first_token_latency` 只有在流式推理的情况下才会输出。 + +token吞吐量的计算公式为: + +$$ +吞吐量 = 生成的token数量 / 总时间 +$$ + +请求吞吐量的计算公式为: + +$$ +吞吐量 = 请求数量 / 总时间 +$$ + +总时间包括 prefill 时间 + +## 测试案例 + +我们用 `internlm-7b` 为例,api_server的速度测试全流程如下: + +```shell +pip install 'lmdeploy[serve]>=0.1.0a1' +git clone --depth=1 https://github.com/InternLM/lmdeploy +cd lmdeploy/benchmark +wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +# 从huggingface下载internlm-7b,并转为turbomind模型格式 +lmdeploy convert internlm internlm/internlm-7b --dst-path ./internlm-7b + +# 启动server +bash ./internlm-7b/service_docker_up.sh + +# 另起终端,在`lmdeploy/benchmark`目录下,执行测速脚本 +python3 ./profile_serving 0.0.0.0:33337 ./internlm-7b/triton_models/tokenizer ./ShareGPT_V3_unfiltered_cleaned_split.json +``` + +## 测试方法 + +启动服务 + +```shell +python3 profile_restful_api.py +``` + +其中,必填参数是: + +- `server_addr` + + api_server 的地址,格式是 `{server_ip}:{server_port}` + +- `tokenizer_path` + + tokenizer model 的路径。作用是对测试数据集预先 encode,获取对话数据的 token 长度 + +- `dataset` + + 下载的测试数据集的路径 + +可选测试参数如下: + +- `--concurrency` + + 客户端请求线程的数量,并发请求会被推理引擎拼成 batch,默认为 32。并发请求会被推理引擎拼成 batch。建议 concurrency 的值不要超过推理引擎的 `max_batch_size`,也不要超过 triton_models 中的推理实例的数量。 + 推理实例数量的配置项是 `instance_group`,在文件 `{model_path}/triton_models/interactive/config.pbtxt` 里,默认是 48。 + +- `--num-prompts` + + 从数据集中采样的prompt数量,默认是 1000 + +- `--top_k`、`--top_p` 和 `--temperature` + + 这三个参数用来采样生成的 token_id + +- `--stream_output` + + 流式推理的开关。默认值为 `False` + +- `--csv` + + 一个 csv 文件路径,用来存放测试结果。默认是 `./profile_tis.csv` + +- `--seed` + + 从测试数据集中随机采样prompt时的种子。默认为0 diff --git a/docs/zh_cn/build.md b/docs/zh_cn/build.md index d97bab7196..2d3b329b62 100644 --- a/docs/zh_cn/build.md +++ b/docs/zh_cn/build.md @@ -1,22 +1,81 @@ -### 源码安装 +# 编译和安装 -- 安装编译和运行依赖包: +LMDeploy 提供了预编译包,可以很方便的通过 `pip install lmdeploy` 安装和使用。 - ```shell - conda create -n lmdeploy python=3.10 - conda activate lmdeploy +如果有源码编译的需求,请先下载 lmdeploy 源码: + +```shell +git clone --depth=1 https://github.com/InternLM/lmdeploy +``` + +然后,参考以下章节编译和安装。 + +## 在 docker 内编译安装(强烈推荐) + +LMDeploy 提供了编译镜像 `openmmlab/lmdeploy-builder:cuda11.8`。使用之前,请确保 docker 已安装。 + +在 lmdeploy 源码的根目录下,运行以下命令: + +```shell +cd lmdeploy # lmdeploy 源码根目录 +bash builder/manywheel/build_all_wheel.sh +``` + +即可在 `builder/manywheel/cuda11.8_dist` 文件夹下,得到 lmdeploy 在 py3.8 - py3.11 下所有的 wheel 文件。比如, + +```text +builder/manywheel/cuda11.8_dist/ +├── lmdeploy-0.0.12-cp310-cp310-manylinux2014_x86_64.whl +├── lmdeploy-0.0.12-cp311-cp311-manylinux2014_x86_64.whl +├── lmdeploy-0.0.12-cp38-cp38-manylinux2014_x86_64.whl +└── lmdeploy-0.0.12-cp39-cp39-manylinux2014_x86_64.whl +``` + +如果需要固定 python 版本的 wheel 文件,比如 py3.8,可以执行: + +```shell +bash builder/manywheel/build_wheel.sh py38 manylinux2014_x86_64 cuda11.8 cuda11.8_dist +``` + +wheel 文件存放在目录 `builder/manywheel/cuda11.8_dist` 下。 + +在宿主机上,通过 `pip install` 安装和宿主机python版本一致的 wheel 文件,即完成 lmdeploy 整个编译安装过程。 - git clone https://github.com/InternLM/lmdeploy.git - cd lmdeploy +## 在物理机上编译安装(可选) +首先,请确保物理机环境的 gcc 版本不低于 9,可以通过`gcc --version`确认。 + +然后,按如下步骤,配置编译环境: + +- 安装编译和运行依赖包: + ```shell pip install -r requirements.txt - conda install openmpi-mpicxx nccl rapidjson -c conda-forge + apt-get install rapidjson-dev + ``` +- 安装 [nccl](https://docs.nvidia.com/deeplearning/nccl/install-guide/index.html),设置环境变量 + ```shell + export NCCL_ROOT_DIR=/path/to/nccl/build + export 
NCCL_LIBRARIES=/path/to/nccl/build/lib + ``` +- 源码编译安装 openmpi: + ```shell + wget https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-4.1.5.tar.gz + tar xf openmpi-4.1.5.tar.gz + cd openmpi-4.1.5 + ./configure + make -j$(nproc) && make install ``` - - lmdeploy 编译安装: - ```shell + apt install ninja-build # 安装更快的 Ninja + cd lmdeploy # lmdeploy 源码的根目录 mkdir build && cd build sh ../generate.sh - make -j$(nproc) && make install + ninja && ninja install + ninja -j$(nproc) && ninja install + ``` +- 安装 lmdeploy python package: + ```shell + cd .. + pip install -e . ``` diff --git a/docs/zh_cn/faq.md b/docs/zh_cn/faq.md index 5f3bf0b117..c86bfc1841 100644 --- a/docs/zh_cn/faq.md +++ b/docs/zh_cn/faq.md @@ -17,7 +17,7 @@ pip install --upgrade mmengine 1. 您没有安装 lmdeploy 的预编译包。`_turbomind`是 turbomind c++ 的 pybind部分,涉及到编译。推荐您直接安装预编译包。 ``` -pip install lmdeploy +pip install lmdeploy[all] ``` 2. 如果已经安装了,还是出现这个问题,请检查下执行目录。不要在 lmdeploy 的源码根目录下执行 python -m lmdeploy.turbomind.\*下的package,换到其他目录下执行。 @@ -26,7 +26,7 @@ pip install lmdeploy ### libnccl.so.2 not found -确保通过 `pip install lmdeploy` 安装了 lmdeploy (>=v0.0.5)。 +确保通过 `pip install lmdeploy[all]` 安装了 lmdeploy (>=v0.0.5)。 如果安装之后,问题还存在,那么就把`libnccl.so.2`的路径加入到环境变量 LD_LIBRARY_PATH 中。 diff --git a/docs/zh_cn/kv_int8.md b/docs/zh_cn/kv_int8.md index 3e006c6135..f392abf1fb 100644 --- a/docs/zh_cn/kv_int8.md +++ b/docs/zh_cn/kv_int8.md @@ -18,7 +18,7 @@ dequant: f = q * scale + zp 把 huggingface 格式的模型,转成 turbomind 推理格式,得到一个 workspace 目录 ```bash -python3 -m lmdeploy.serve.turbomind.deploy internlm-chat-7b /path/to/internlm-chat-7b +lmdeploy convert internlm-chat-7b /path/to/internlm-chat-7b ``` 如果已经有 workspace 目录,可以跳过这步。 @@ -29,7 +29,7 @@ python3 -m lmdeploy.serve.turbomind.deploy internlm-chat-7b /path/to/internlm-ch ```bash # 计算 minmax -python3 -m lmdeploy.lite.apis.calibrate \ +lmdeploy lite calibrate \ --model $HF_MODEL \ --calib_dataset 'c4' \ # 校准数据集,支持 c4, ptb, wikitext2, pileval --calib_samples 128 \ # 校准集的样本数,如果显存不够,可以适当调小 @@ -37,7 +37,7 @@ python3 -m lmdeploy.lite.apis.calibrate \ --work_dir $WORK_DIR \ # 保存 Pytorch 格式量化统计参数和量化后权重的文件夹 # 通过 minmax 获取量化参数 -python3 -m lmdeploy.lite.apis.kv_qparams \ +lmdeploy lite kv_qparams \ --work_dir $WORK_DIR \ # 上一步的结果 --turbomind_dir workspace/triton_models/weights/ \ # 保存量化参数的目录,推理要用 --kv_sym False \ # 对称量化或非对称量化,默认为 False @@ -52,19 +52,14 @@ python3 -m lmdeploy.lite.apis.kv_qparams \ 修改 `workspace/triton_models/weights/config.ini`: -- use_context_fmha 改为 0,表示关闭 flashattention - quant_policy 设置为 4。表示打开 kv_cache int8 -这是因为 flashattention 有 v1、v2 两个版本,kv cache int8 曾经也实现过对称版本。 - -排列组合需要实现 4 套 kernel,在算法不确定的时候过早优化,对软件来说是场灾难。 - ### **第四步** 测试聊天效果 ```bash -python3 -m lmdeploy.turbomind.chat ./workspace +lmdeploy chat turbomind ./workspace ``` ## 显存测试 diff --git a/docs/zh_cn/load_hf.md b/docs/zh_cn/load_hf.md new file mode 100644 index 0000000000..63c08fe2d9 --- /dev/null +++ b/docs/zh_cn/load_hf.md @@ -0,0 +1,72 @@ +# 直接读取 huggingface 模型 + +从 v0.1.0 开始,Turbomid 添加了直接读取 Huggingface 格式权重的能力。 + +## 支持的类型 + +目前,TurboMind 支持加载三种类型的模型: + +1. 在 huggingface.co 上面通过 lmdeploy 量化的模型,如 [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit) +2. huggingface.co 上面其他 LM 模型,如Qwen/Qwen-7B-Chat +3. 
通过 `lmdeploy convert` 命令转换好的模型,兼容旧格式 + +## 使用方式 + +### 1) 通过 lmdeploy 量化的模型 + +对于通过 `lmdeploy.lite` 量化的模型,TurboMind 可以直接加载,比如 [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit). + +``` +repo_id=internlm/internlm-chat-20b-4bit +model_name=internlm-chat-20b + +# or +# repo_id=/path/to/downloaded_model + +# Inference by TurboMind +lmdeploy chat turbomind $repo_id --model-name $model_name + +# Serving with gradio +lmdeploy serve gradio $repo_id --model-name $model_name + +# Serving with Restful API +lmdeploy serve api_server $repo_id --model-name $model_name --instance_num 32 --tp 1 +``` + +### 2) 其他的 LM 模型 + +其他 LM 模型比如 Qwen/Qwen-7B-Chat, baichuan-inc/Baichuan2-7B-Chat。LMDeploy 模型支持情况可通过 `lmdeploy list` 查看。 + +``` +repo_id=Qwen/Qwen-7B-Chat +model_name=qwen-7b +# or +# repo_id=/path/to/Qwen-7B-Chat/local_path + +# Inference by TurboMind +lmdeploy chat turbomind $repo_id --model-name $model_name + +# Serving with gradio +lmdeploy serve gradio $repo_id --model-name $model_name + +# Serving with Restful API +lmdeploy serve api_server $repo_id --model-name $model_name --instance_num 32 --tp 1 +``` + +### 3) 通过 `lmdeploy convert` 命令转换好的模型 + +使用方式与之前相同 + +``` +# Convert a model +lmdeploy convert /path/to/model ./workspace --model-name MODEL_NAME + +# Inference by TurboMind +lmdeploy chat turbomind ./workspace + +# Serving with gradio +lmdeploy serve gradio ./workspace + +# Serving with Restful API +lmdeploy serve api_server ./workspace --instance_num 32 --tp 1 +``` diff --git a/docs/zh_cn/restful_api.md b/docs/zh_cn/restful_api.md index 2b56fa0f26..96a3094ac7 100644 --- a/docs/zh_cn/restful_api.md +++ b/docs/zh_cn/restful_api.md @@ -2,59 +2,62 @@ ### 启动服务 -运行脚本 +用户将下面命令输出的 http url 复制到浏览器打开,详细查看所有的 API 及其使用方法。 +请一定查看`http://{server_ip}:{server_port}`!!! +请一定查看`http://{server_ip}:{server_port}`!!! +请一定查看`http://{server_ip}:{server_port}`!!! 
+重要的事情说三遍。 ```shell -python3 -m lmdeploy.serve.openai.api_server ./workspace 0.0.0.0 server_port --instance_num 32 --tp 1 +lmdeploy serve api_server ./workspace 0.0.0.0 --server_port ${server_port} --instance_num 64 --tp 1 ``` -然后用户可以打开 swagger UI: `http://{server_ip}:{server_port}` 详细查看所有的 API 及其使用方法。 -我们一共提供四个 restful api,其中三个仿照 OpenAI 的形式。不过,我们建议用户用我们提供的另一个 API: `generate`。 +我们一共提供四个 restful api,其中三个仿照 OpenAI 的形式。 + +- /v1/chat/completions +- /v1/models +- /v1/completions + +不过,我们建议用户用我们提供的另一个 API: `/v1/chat/interactive`。 它有更好的性能,提供更多的参数让用户自定义修改。 ### python -这是一个 python 示例,展示如何使用 `generate`。 +我们将这些服务的客户端功能集成在 `APIClient` 类中。下面是一些例子,展示如何在客户端调用 `api_server` 服务。 +如果你想用 `/v1/chat/completions` 接口,你可以尝试下面代码: + +```python +from lmdeploy.serve.openai.api_client import APIClient +api_client = APIClient('http://{server_ip}:{server_port}') +model_name = api_client.available_models[0] +messages = [{"role": "user", "content": "Say this is a test!"}] +for item in api_client.chat_completions_v1(model=model_name, messages=messages): + print(item) +``` + +如果你想用 `/v1/completions` 接口,你可以尝试: ```python -import json -import requests -from typing import Iterable, List - - -def get_streaming_response(prompt: str, - api_url: str, - session_id: int, - request_output_len: int, - stream: bool = True, - sequence_start: bool = True, - sequence_end: bool = True, - ignore_eos: bool = False) -> Iterable[List[str]]: - headers = {'User-Agent': 'Test Client'} - pload = { - 'prompt': prompt, - 'stream': stream, - 'session_id': session_id, - 'request_output_len': request_output_len, - 'sequence_start': sequence_start, - 'sequence_end': sequence_end, - 'ignore_eos': ignore_eos - } - response = requests.post( - api_url, headers=headers, json=pload, stream=stream) - for chunk in response.iter_lines( - chunk_size=8192, decode_unicode=False, delimiter=b'\n'): - if chunk: - data = json.loads(chunk.decode('utf-8')) - output = data['text'] - tokens = data['tokens'] - yield output, tokens - - -for output, tokens in get_streaming_response( - "Hi, how are you?", "http://{server_ip}:{server_port}/generate", 0, - 512): - print(output, end='') +from lmdeploy.serve.openai.api_client import APIClient +api_client = APIClient('http://{server_ip}:{server_port}') +model_name = api_client.available_models[0] +for item in api_client.completions_v1(model=model_name, prompt='hi'): + print(item) +``` + +LMDeploy 的 `/v1/chat/interactive` api 支持将对话内容管理在服务端,但是我们默认关闭。如果想尝试,请阅读以下介绍: + +- 交互模式下,对话历史保存在 server。在一次完整的多轮对话中,所有请求设置`interactive_mode = True`, `session_id`保持相同 (不为 -1,这是缺省值)。 +- 非交互模式下,server 不保存历史记录。 + +交互模式可以通过 `interactive_mode` 布尔量参数控制。下面是一个普通模式的例子, +如果要体验交互模式,将 `interactive_mode=True` 传入即可。 + +```python +from lmdeploy.serve.openai.api_client import APIClient +api_client = APIClient('http://{server_ip}:{server_port}') +for item in api_client.generate(prompt='hi'): + print(item) ``` ### Java/Golang/Rust @@ -86,16 +89,15 @@ cURL 也可以用于查看 API 的输出结果 curl http://{server_ip}:{server_port}/v1/models ``` -使用 generate: +Interactive Chat: ```bash -curl http://{server_ip}:{server_port}/generate \ +curl http://{server_ip}:{server_port}/v1/chat/interactive \ -H "Content-Type: application/json" \ -d '{ "prompt": "Hello! How are you?", "session_id": 1, - "sequence_start": true, - "sequence_end": true + "interactive_mode": true }' ``` @@ -106,19 +108,19 @@ curl http://{server_ip}:{server_port}/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "internlm-chat-7b", - "messages": [{"role": "user", "content": "Hello! 
Ho are you?"}] + "messages": [{"role": "user", "content": "Hello! How are you?"}] }' ``` -Embeddings: +Text Completions: -```bash -curl http://{server_ip}:{server_port}/v1/embeddings \ - -H "Content-Type: application/json" \ +```shell +curl http://{server_ip}:{server_port}/v1/completions \ + -H 'Content-Type: application/json' \ -d '{ - "model": "internlm-chat-7b", - "input": "Hello world!" - }' + "model": "llama", + "prompt": "two steps to build a house:" +}' ``` ### CLI client @@ -126,8 +128,8 @@ curl http://{server_ip}:{server_port}/v1/embeddings \ restful api 服务可以通过客户端测试,例如 ```shell -# restful_api_url 就是 api_server 产生的,比如 http://localhost:23333 -python -m lmdeploy.serve.openai.api_client restful_api_url +# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 +lmdeploy serve api_client api_server_url ``` ### webui @@ -135,25 +137,18 @@ python -m lmdeploy.serve.openai.api_client restful_api_url 也可以直接用 webui 测试使用 restful-api。 ```shell -# restful_api_url 就是 api_server 产生的,比如 http://localhost:23333 -# server_ip 和 server_port 是用来提供 gradio ui 访问服务的 -# 例子: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True -python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True +# api_server_url 就是 api_server 产生的,比如 http://localhost:23333 +# server_name 和 server_port 是用来提供 gradio ui 访问服务的 +# 例子: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 +lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port} ``` ### FAQ -1. 当返回结果结束原因为 `"finish_reason":"length"`,这表示回话长度超过最大值。 - 请添加 `"renew_session": true` 到下一次请求中。 +1. 当返回结果结束原因为 `"finish_reason":"length"`,这表示回话长度超过最大值。如需调整会话支持的最大长度,可以通过启动`api_server`时,设置`--session_len`参数大小。 2. 当服务端显存 OOM 时,可以适当减小启动服务时的 `instance_num` 个数 -3. 当同一个 `session_id` 的请求给 `generate` 函数后,出现返回空字符串和负值的 `tokens`,应该是第二次问话没有设置 `sequence_start=false` - -4. 如果感觉请求不是并发地被处理,而是一个一个地处理,请设置好以下参数: - - - 不同的 session_id 传入 `generate` api。否则,我们将自动绑定会话 id 为请求端的 ip 地址编号。 +3. 当同一个 `session_id` 的请求给 `/v1/chat/interactive` 函数后,出现返回空字符串和负值的 `tokens`,应该是 `session_id` 混乱了,可以先将交互模式关闭,再重新开启。 -5. `generate` api 和 `v1/chat/completions` 均支持多轮对话。`messages` 或者 `prompt` 参数既可以是一个简单字符串表示用户的单词提问,也可以是一段对话历史。 - 两个 api 都是默认开启多伦对话的,如果你想关闭这个功能,然后在客户端管理会话记录,请设置 `sequence_end: true` 传入 `generate`,或者设置 - `renew_session: true` 传入 `v1/chat/completions`。 +4. 
`/v1/chat/interactive` api 支持多轮对话, 但是默认关闭。`messages` 或者 `prompt` 参数既可以是一个简单字符串表示用户的单词提问,也可以是一段对话历史。 diff --git a/docs/zh_cn/serving.md b/docs/zh_cn/serving.md index e0a2f5a986..db4ebb8d3c 100644 --- a/docs/zh_cn/serving.md +++ b/docs/zh_cn/serving.md @@ -8,7 +8,7 @@ 7B ```shell -python3 -m lmdeploy.serve.turbomind.deploy llama2 /path/to/llama-2-7b-chat-hf +lmdeploy convert llama2 /path/to/llama-2-7b-chat-hf bash workspace/service_docker_up.sh ``` @@ -18,7 +18,7 @@ bash workspace/service_docker_up.sh 13B ```shell -python3 -m lmdeploy.serve.turbomind.deploy llama2 /path/to/llama-2-13b-chat-hf --tp 2 +lmdeploy convert llama2 /path/to/llama-2-13b-chat-hf --tp 2 bash workspace/service_docker_up.sh ``` @@ -28,7 +28,7 @@ bash workspace/service_docker_up.sh 70B ```shell -python3 -m lmdeploy.serve.turbomind.deploy llama2 /path/to/llama-2-70b-chat-hf --tp 8 +lmdeploy convert llama2 /path/to/llama-2-70b-chat-hf --tp 8 bash workspace/service_docker_up.sh ``` @@ -42,7 +42,7 @@ bash workspace/service_docker_up.sh 7B ```shell -python3 -m lmdeploy.serve.turbomind.deploy llama /path/to/llama-7b llama \ +lmdeploy convert llama /path/to/llama-7b llama \ --tokenizer_path /path/to/tokenizer/model bash workspace/service_docker_up.sh ``` @@ -53,7 +53,7 @@ bash workspace/service_docker_up.sh 13B ```shell -python3 -m lmdeploy.serve.turbomind.deploy llama /path/to/llama-13b llama \ +lmdeploy convert llama /path/to/llama-13b llama \ --tokenizer_path /path/to/tokenizer/model --tp 2 bash workspace/service_docker_up.sh ``` @@ -64,7 +64,7 @@ bash workspace/service_docker_up.sh 30B ```shell -python3 -m lmdeploy.serve.turbomind.deploy llama /path/to/llama-30b llama \ +lmdeploy convert llama /path/to/llama-30b llama \ --tokenizer_path /path/to/tokenizer/model --tp 4 bash workspace/service_docker_up.sh ``` @@ -75,7 +75,7 @@ bash workspace/service_docker_up.sh 65B ```shell -python3 -m lmdeploy.serve.turbomind.deploy llama /path/to/llama-65b llama \ +lmdeploy convert llama /path/to/llama-65b llama \ --tokenizer_path /path/to/tokenizer/model --tp 8 bash workspace/service_docker_up.sh ``` @@ -94,7 +94,7 @@ python3 -m fastchat.model.apply_delta \ --target-model-path /path/to/vicuna-7b \ --delta-path lmsys/vicuna-7b-delta-v1.1 -python3 -m lmdeploy.serve.turbomind.deploy vicuna /path/to/vicuna-7b +lmdeploy convert vicuna /path/to/vicuna-7b bash workspace/service_docker_up.sh ``` @@ -110,7 +110,7 @@ python3 -m fastchat.model.apply_delta \ --target-model-path /path/to/vicuna-13b \ --delta-path lmsys/vicuna-13b-delta-v1.1 -python3 -m lmdeploy.serve.turbomind.deploy vicuna /path/to/vicuna-13b +lmdeploy convert vicuna /path/to/vicuna-13b bash workspace/service_docker_up.sh ``` diff --git a/docs/zh_cn/supported_models/codellama.md b/docs/zh_cn/supported_models/codellama.md index ca9029a527..b7776a1e64 100644 --- a/docs/zh_cn/supported_models/codellama.md +++ b/docs/zh_cn/supported_models/codellama.md @@ -26,10 +26,10 @@ ```shell # 安装 lmdeploy -python3 -m pip install lmdeploy +python3 -m pip install lmdeploy[all] # 转模型格式 -python3 -m lmdeploy.serve.turbomind.deploy codellama /path/of/codellama/model +lmdeploy convert codellama /path/of/codellama/model ``` 接下来,可参考如下章节,在控制台与 codellama 进行交互式对话。 @@ -42,13 +42,13 @@ python3 -m lmdeploy.serve.turbomind.deploy codellama /path/of/codellama/model ### 代码续写 ```shell -python3 -m lmdeploy.turbomind.chat ./workspace --cap completion +lmdeploy chat turbomind ./workspace --cap completion ``` ### 代码填空 ```shell -python3 -m lmdeploy.turbomind.chat ./workspace --cap infilling +lmdeploy chat turbomind 
./workspace --cap infilling ``` 输入的代码块中要包含 ``,比如: @@ -64,7 +64,7 @@ def remove_non_ascii(s: str) -> str: ### 对话 ``` -python3 -m lmdeploy.turbomind.chat ./workspace --cap chat --sys-instruct "Provide answers in Python" +lmdeploy chat turbomind ./workspace --cap chat --sys-instruct "Provide answers in Python" ``` 可以把 `--sys-instruct` 的指令换成 codellama 支持的其他变成语言。 @@ -72,7 +72,7 @@ python3 -m lmdeploy.turbomind.chat ./workspace --cap chat --sys-instruct "Provid ### Python 专项 ``` -python3 -m lmdeploy.turbomind.chat ./workspace --cap python +lmdeploy chat turbomind ./workspace --cap python ``` 建议这里部署 Python 微调模型 @@ -90,7 +90,7 @@ TBD ```shell # --instance_num: turbomind推理实例的个数。可理解为支持的最大并发数 # --tp: 在 tensor parallel时,使用的GPU数量 -python3 -m lmdeploy.serve.openai.api_server ./workspace server_ip server_port --instance_num 32 --tp 1 +lmdeploy serve api_server ./workspace --server_name 0.0.0.0 --server_port ${server_port} --instance_num 32 --tp 1 ``` 打开 `http://{server_ip}:{server_port}`,即可访问 swagger,查阅 RESTful API 的详细信息。 @@ -98,17 +98,17 @@ python3 -m lmdeploy.serve.openai.api_server ./workspace server_ip server_port -- 你可以用命令行,在控制台与 server 通信: ```shell -# restful_api_url 就是 api_server 产生的,比如 http://localhost:23333 -python -m lmdeploy.serve.openai.api_client restful_api_url +# api_server_url 就是 api_server 产生的,比如 http://localhost:23333 +lmdeploy serve api_client api_server_url ``` 或者,启动 gradio,在 webui 的聊天对话框中,与 codellama 交流: ```shell -# restful_api_url 就是 api_server 产生的,比如 http://localhost:23333 +# api_server_url 就是 api_server 产生的,比如 http://localhost:23333 # server_ip 和 server_port 是用来提供 gradio ui 访问服务的 -# 例子: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True -python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True +# 例子: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 +lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port} ``` 关于 RESTful API的详细介绍,请参考[这份](../restful_api.md)文档。 diff --git a/docs/zh_cn/turbomind.md b/docs/zh_cn/turbomind.md index e51a0199b8..95d30f14d1 100644 --- a/docs/zh_cn/turbomind.md +++ b/docs/zh_cn/turbomind.md @@ -35,7 +35,7 @@ TurboMind 是一款关于 LLM 推理的高效推理引擎,基于英伟达的 [ ## KV 缓存管理器 -TurboMind 的 [KV 缓存管理器](https://github.com/InternLM/lmdeploy/blob/main/src/turbomind/models/llama/LlamaCacheManager.h) 是一个内存池类型的对象,并且在其中加入了 LRU 的实现,这样整个管理器可以被看作是一个 **KV 缓存的缓存**。大致工作方式如下: +TurboMind 的 [KV 缓存管理器](https://github.com/InternLM/lmdeploy/blob/main/src/turbomind/models/llama/SequenceManager.h) 是一个内存池类型的对象,并且在其中加入了 LRU 的实现,这样整个管理器可以被看作是一个 **KV 缓存的缓存**。大致工作方式如下: - KV 缓存由管理器分配。管理器会根据预先配置好的 slot 数量开辟空间。每个 slot 对应于一个 sequence 所需的 KV 缓存。分配的内存块大小可通过配置来实现预分配或者按需分配(或介于两者之间)。 - 当有新的请求,但是缓存池中没有空闲 slot时,根据 LRU 机制,管理器会踢除最近使用最少的 sequence,把它占据的 slot 分给新的请求。不仅仅如此, diff --git a/docs/zh_cn/turbomind_config.md b/docs/zh_cn/turbomind_config.md new file mode 100644 index 0000000000..86cb92313a --- /dev/null +++ b/docs/zh_cn/turbomind_config.md @@ -0,0 +1,202 @@ +# TurboMind 配置 + +TurboMind 是 LMDeploy 的推理引擎,在用它推理 LLM 模型时,需要把输入模型转成 TurboMind 模型。在 TurboMind 的模型文件夹中,除模型权重外,TurboMind 模型还包括其他一些文件,其中最重要的是和推理性能息息相关的配置文件`triton_models/weights/config.ini`。 + +如果你使用的是 LMDeploy 0.0.x 版本,请参考[turbomind 1.0 配置](#turbomind-10-配置)章节,了解配置中的相关内容。如果使用的是 LMDeploy 0.1.x 版本,请阅读[turbomind 2.0 配置](#turbomind-20-配置)了解配置细节。 + +## TurboMind 2.0 配置 + +以 `llama-2-7b-chat` 模型为例,在 TurboMind 2.0 中,它的`config.ini`内容如下: + +```toml +[llama] +model_name = llama2 +tensor_para_size = 1 +head_num = 32 
+kv_head_num = 32 +vocab_size = 32000 +num_layer = 32 +inter_size = 11008 +norm_eps = 1e-06 +attn_bias = 0 +start_id = 1 +end_id = 2 +session_len = 4104 +weight_type = fp16 +rotary_embedding = 128 +rope_theta = 10000.0 +size_per_head = 128 +group_size = 0 +max_batch_size = 64 +max_context_token_num = 1 +step_length = 1 +cache_max_entry_count = 0.5 +cache_block_seq_len = 128 +cache_chunk_size = 1 +use_context_fmha = 1 +quant_policy = 0 +max_position_embeddings = 2048 +rope_scaling_factor = 0.0 +use_logn_attn = 0 +``` + +这些参数由模型属性和推理参数组成。模型属性包括层数、head个数、维度等等,它们**不可修改** + +```toml +model_name = llama2 +head_num = 32 +kv_head_num = 32 +vocab_size = 32000 +num_layer = 32 +inter_size = 11008 +norm_eps = 1e-06 +attn_bias = 0 +start_id = 1 +end_id = 2 +rotary_embedding = 128 +rope_theta = 10000.0 +size_per_head = 128 +``` + +和 TurboMind 1.0 config 相比,TurboMind 2.0 config 中的模型属性部分和 1.0 一致,但推理参数发生了变化。 + +在接下来的章节中,我们重点介绍推理参数。 + +### 数据类型 + +和数据类型相关的参数是 `weight_type` 和 `group_size`。它们**不可被修改**。 + +`weight_type` 表示权重的数据类型。目前支持 fp16 和 int4。int4 表示 4bit 权重。当 `weight_type`为 4bit 权重时,`group_size` 表示 `awq` 量化权重时使用的 group 大小。目前,在 LMDeploy 的预编译包中,使用的是 `group_size = 128`。 + +### 批处理大小 + +仍通过 `max_batch_size` 设置最大批处理量。默认值由原来的 32 改成 64。 +在 TurboMind 2.0 中,`max_batch_size` 和 `cache_max_entry_count`无关。 + +### k/v 缓存大小 + +`cache_block_seq_len` 和 `cache_max_entry_count` 用来调节 k/v cache 的内存大小。 + +TurboMind 2.0 实现了 Paged Attention,按块管理 k/v cache。 + +`cache_block_seq_len` 表示一块 k/v block 可以存放的 token 序列长度,默认 128。TurboMind 按照以下公式计算 k/v block 的内存大小: + +``` +cache_block_seq_len * num_layer * kv_head_num * size_per_head * 2 * sizeof(kv_data_type) +``` + +对于 llama2-7b 模型来说,以 half 类型存放 k/v 时,一块 k/v block 的内存为:`128 * 32 * 32 * 128 * 2 * sizeof(half) = 64MB` + +`cache_max_entry_count` 根据取值不同,表示不同的含义: + +- 当值为 (0, 1) 之间的小数时,`cache_max_entry_count` 表示 k/v block 使用的内存百分比。比如 A100-80G 显卡内存是80G,当`cache_max_entry_count`为0.5时,表示 k/v block 使用的内存总量为 80 * 0.5 = 40G +- 当值为 > 1的整数时,表示 k/v block 数量 + +`cache_chunk_size` 表示在每次需要新的 k/v cache 块时,开辟 k/v cache 块的大小。不同的取值,表示不同的含义: + +- 当为 > 0 的整数时,开辟 `cache_chunk_size` 个 k/v cache 块 +- 当值为 -1 时,开辟 `cache_max_entry_count` 个 k/v cache 块 +- 当值为 0 时,时,开辟 `sqrt(cache_max_entry_count)` 个 k/v cache 块 + +### kv int8 开关 + +`quant_policy`是 KV-int8 推理开关。具体使用方法,请参考 [kv int8](./kv_int8.md) 部署文档 + +### 外推能力开关 + +默认 `rope_scaling_factor = 0` 不具备外推能力。设置为 1.0,可以开启 RoPE 的 Dynamic NTK 功能,支持长文本推理。 + +关于 Dynamic NTK 的原理,详细请参考: + +1. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases +2. 
https://kexue.fm/archives/9675 + +设置 `use_logn_attn = 1`,可以开启 [LogN attention scaling](https://kexue.fm/archives/8823)。 + +## TurboMind 1.0 配置 + +以 `llama-2-7b-chat` 模型为例,在 TurboMind 1.0 中,它的`config.ini`内容如下: + +```toml +[llama] +model_name = llama2 +tensor_para_size = 1 +head_num = 32 +kv_head_num = 32 +vocab_size = 32000 +num_layer = 32 +inter_size = 11008 +norm_eps = 1e-06 +attn_bias = 0 +start_id = 1 +end_id = 2 +session_len = 4104 +weight_type = fp16 +rotary_embedding = 128 +rope_theta = 10000.0 +size_per_head = 128 +group_size = 0 +max_batch_size = 32 +max_context_token_num = 4 +step_length = 1 +cache_max_entry_count = 48 +cache_chunk_size = 1 +use_context_fmha = 1 +quant_policy = 0 +max_position_embeddings = 2048 +use_dynamic_ntk = 0 +use_logn_attn = 0 +``` + +这些参数由模型属性和推理参数组成。模型属性包括层数、head个数、维度等等,它们**不可修改** + +```toml +model_name = llama2 +head_num = 32 +kv_head_num = 32 +vocab_size = 32000 +num_layer = 32 +inter_size = 11008 +norm_eps = 1e-06 +attn_bias = 0 +start_id = 1 +end_id = 2 +rotary_embedding = 128 +rope_theta = 10000.0 +size_per_head = 128 +``` + +在接下来的章节中,我们重点介绍推理参数。 + +### 数据类型 + +和数据类型相关的参数是 `weight_type` 和 `group_size`。它们**不可被修改**。 + +`weight_type` 表示权重的数据类型。目前支持 fp16 和 int4。int4 表示 4bit 权重。当 `weight_type`为 4bit 权重时,`group_size` 表示 `awq` 量化权重时使用的 group 大小。目前,在 LMDeploy 的预编译包中,使用的是 `group_size = 128`。 + +### 批处理大小 + +可通过`max_batch_size`调节推理时最大的 batch 数。一般,batch 越大吞吐量越高。但务必保证 `max_batch_size <= cache_max_entry_count` + +### k/v cache 大小 + +TurboMind 根据 `session_len`、 `cache_chunk_size` 和 `cache_max_entry_count` 开辟 k/v cache 内存。 + +- `session_len` 表示一个序列的最大长度,即 context window 的大小。 +- `cache_chunk_size` 表示当新增对话序列时,每次要开辟多少个序列的 k/v cache +- `cache_max_entry_count` 表示最多缓存多少个对话序列 + +### kv int8 开关 + +当启动 8bit k/v 推理时,需要修改参数 `quant_policy` 和 `use_context_fmha`。详细内容请查阅 [kv int8](./kv_int8.md) 部署文档。 + +### 外推能力开关 + +设置 `use_dynamic_ntk = 1`,可以开启 RoPE 的 Dynamic NTK 选项,支持长文本推理。 + +关于 Dynamic NTK 的原理,详细请参考: + +1. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases +2. 
https://kexue.fm/archives/9675 + +设置 `use_logn_attn = 1`,可以开启 [LogN attention scaling](https://kexue.fm/archives/8823)。 diff --git a/docs/zh_cn/w4a16.md b/docs/zh_cn/w4a16.md index 68cc094df8..899a451d3e 100644 --- a/docs/zh_cn/w4a16.md +++ b/docs/zh_cn/w4a16.md @@ -5,7 +5,7 @@ LMDeploy 支持 4bit 权重模型的推理,**对 NVIDIA 显卡的最低要求 在推理之前,请确保安装了 lmdeploy ```shell -pip install lmdeploy +pip install lmdeploy[all] ``` ## 4bit 权重模型推理 @@ -24,14 +24,14 @@ git clone https://huggingface.co/lmdeploy/llama2-chat-7b-w4 ```shell ## 转换模型的layout,存放在默认路径 ./workspace 下 -python3 -m lmdeploy.serve.turbomind.deploy \ +lmdeploy convert \ --model-name llama2 \ --model-path ./llama2-chat-7b-w4 \ --model-format awq \ --group-size 128 ## 推理 -python3 -m lmdeploy.turbomind.chat ./workspace +lmdeploy chat turbomind ./workspace ``` ## 启动 gradio 服务 @@ -39,7 +39,7 @@ python3 -m lmdeploy.turbomind.chat ./workspace 如果想通过 webui 与模型对话,请执行以下命令启动 gradio 服务 ```shell -python3 -m lmdeploy.serve.turbomind ./workspace --server_name {ip_addr} ----server_port {port} +lmdeploy serve gradio ./workspace --server_name {ip_addr} --server_port {port} ``` 然后,在浏览器中打开 http://{ip_addr}:{port},即可在线对话 @@ -82,7 +82,7 @@ python benchmark/profile_generation.py \ ### 第一步:生成量化参数 ```shell -python3 -m lmdeploy.lite.apis.calibrate \ +lmdeploy lite calibrate \ --model $HF_MODEL \ --calib_dataset 'c4' \ # 校准数据集,支持 c4, ptb, wikitext2, pileval --calib_samples 128 \ # 校准集的样本数,如果显存不够,可以适当调小 @@ -95,7 +95,7 @@ python3 -m lmdeploy.lite.apis.calibrate \ LMDeploy 使用 AWQ 算法对模型权重进行量化。在执行下面的命令时,需要把步骤1的`$WORK_DIR`传入。量化结束后,权重文件也会存放在这个目录中。然后就可以根据 ["4bit权重模型推理"](#4bit-权重模型推理)章节的说明,进行模型推理。 ```shell -python3 -m lmdeploy.lite.apis.auto_awq \ +lmdeploy lite auto_awq \ --model $HF_MODEL \ --w_bits 4 \ # 权重量化的 bit 数 --w_group_size 128 \ # 权重量化分组统计尺寸 diff --git a/examples/cpp/llama/llama_triton_example.cc b/examples/cpp/llama/llama_triton_example.cc index 2f50ae19a0..800090a9e7 100644 --- a/examples/cpp/llama/llama_triton_example.cc +++ b/examples/cpp/llama/llama_triton_example.cc @@ -80,7 +80,9 @@ broadCastRequest(const std::vector& v_start_ids, if (node_id == 0) { memcpy(v_input_ids.data(), v_start_ids.data(), size_1 * sizeof(int)); memcpy(v_input_lengths.data(), v_start_lengths.data(), size_2 * sizeof(int)); - memcpy(v_input_bad_words.data(), v_bad_words.data(), size_bad_words * sizeof(int)); + if (!v_input_bad_words.empty()) { + memcpy(v_input_bad_words.data(), v_bad_words.data(), size_bad_words * sizeof(int)); + } } if (kUSE_MPI) { ft::mpi::barrier(); @@ -431,6 +433,8 @@ int main(int argc, char* argv[]) const int beam_width = output_tensors_lists[0].get()->at("output_ids").shape[1]; const int seq_len = output_tensors_lists[0].get()->at("output_ids").shape[2]; + ft::FT_CHECK(beam_width == 1); + std::vector seq_lens(batch_size); // step 6: check results if (node_id == 0) { @@ -440,32 +444,25 @@ int main(int argc, char* argv[]) printf("[WARNING] Cannot write results into output file %s \n", fName.c_str()); } else { - size_t outCount = batch_size * beam_width * seq_len; - // int* hBuf = new int[outCount]; + const size_t outCount = batch_size * beam_width * seq_len; + std::vector hBuf(outCount); - ft::cudaD2Hcpy(hBuf.data(), d_output_ids, outCount); - ft::cudaD2Hcpy(seq_lens.data(), d_seq_lens, batch_size); + + ft::cudaAutoCpy(hBuf.data(), d_output_ids, outCount); + ft::cudaAutoCpy(seq_lens.data(), d_seq_lens, batch_size); + std::cout << "sequence length: "; for (int i = 0; i < batch_size; ++i) { std::cout << (i ? 
", " : "") << seq_lens[i]; } std::cout << "\n"; - { - std::cout << "Writing " << outCount << " elements\n"; - int zeroCount = 0; - for (size_t i = 0; i < outCount; i++) { - if (hBuf[i] == int(0)) - zeroCount++; - outFile << hBuf[i] << " "; - if ((i + 1) % (seq_len) == 0) - outFile << std::endl; - - if (i < 10) - printf("%5d ", hBuf[i]); - if ((i + 1) % (seq_len) == 0 && i < 10) - std::cout << std::endl; + + for (int i = 0; i < batch_size; ++i) { + outFile << (i ? "\n" : ""); + auto buf = hBuf.data() + seq_len * i; + for (int j = 0; j < seq_lens[i]; ++j) { + outFile << buf[j] << " "; } - std::cout << std::endl << "zeroCount = " << zeroCount << std::endl; } } } @@ -475,7 +472,7 @@ int main(int argc, char* argv[]) } cudaDeviceSynchronize(); - if (1) { + if (0) { // test time auto start = std::chrono::high_resolution_clock::now(); diff --git a/generate.sh b/generate.sh index 5e09688663..6648d2e22a 100755 --- a/generate.sh +++ b/generate.sh @@ -1,6 +1,12 @@ #!/bin/sh -cmake .. \ +builder="-G Ninja" + +if [ "$1" == "make" ]; then + builder="" +fi + +cmake ${builder} .. \ -DCMAKE_BUILD_TYPE=RelWithDebInfo \ -DCMAKE_EXPORT_COMPILE_COMMANDS=1 \ -DCMAKE_INSTALL_PREFIX=./install \ diff --git a/lmdeploy/cli/__init__.py b/lmdeploy/cli/__init__.py new file mode 100644 index 0000000000..3575bec5bd --- /dev/null +++ b/lmdeploy/cli/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .cli import run + +__all__ = ['run'] diff --git a/lmdeploy/cli/chat.py b/lmdeploy/cli/chat.py new file mode 100644 index 0000000000..735b24c7cc --- /dev/null +++ b/lmdeploy/cli/chat.py @@ -0,0 +1,90 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + + +class SubCliChat(object): + """Chat through terminal with pytorch or turbomind model.""" + + def torch(self, + model_path: str, + tokenizer_path: Optional[str] = None, + accel: Optional[str] = None, + max_new_tokens: int = 128, + temperature: float = 0.8, + top_p: float = 0.95, + seed: int = 0, + use_fast_tokenizer: bool = True, + max_alloc: int = 2048, + max_session_len: int = None, + log_file: Optional[str] = None, + debug: bool = False, + adapter: Optional[str] = None): + """Chat with pytorch model through terminal. + + Args: + model_path (str): Path to pytorch model. + tokenizer_path (str): Path to tokenizer. + accel (str): Model accelerator. + max_new_tokens (int): Maximum number of tokens to generate. + temperature (float): Temperature for sampling. + top_p (float): Top p for sampling. + seed (int): Random seed. + use_fast_tokenizer (bool): Whether to use fast tokenizer. + This argument is directly pass to transformer's + ``AutoTokenizer.from_pretrained``. + Generally, user should choose to use fast tokenizers. + But if using fast raise some error, try to force using a slow one. + max_alloc (int): Maximum memory to allocate (for deepspeed). + max_session_len (int): Maximum number of tokens allowed for all chat sessions. + This include both history and current session. + log_file (str): Path to log file. + debug (bool): Whether to enable debug mode. + adapter (str): Force to use an adapter. + Generally user should not use this argument because adapter is selected based + on the type of model. Only when it is impossible, e.g. distinguishing llama 1/2 + based on `LlamaforCausalLM` class, this argument is required. + Currently, only "llama1" is acceptable for llama1 models. 
+ """ # noqa: E501 + from lmdeploy.pytorch.chat import main as run_torch_model + + run_torch_model(model_path, + tokenizer_path=tokenizer_path, + accel=accel, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + seed=seed, + use_fast_tokenizer=use_fast_tokenizer, + max_alloc=max_alloc, + max_session_len=max_session_len, + log_file=log_file, + debug=debug, + adapter=adapter) + + def turbomind(self, + model_path, + session_id: int = 1, + cap: str = 'chat', + tp=1, + stream_output=True, + **kwargs): + """Chat with turbomind model through terminal. + + Args: + model_path (str): the path of the deployed model + session_id (int): the identical id of a session + cap (str): the capability of a model. For example, codellama has + the ability among ['completion', 'infilling', 'chat', 'python'] + tp (int): GPU number used in tensor parallelism + stream_output (bool): indicator for streaming output or not + **kwarg (dict): other arguments for initializing model's chat + template + """ + from lmdeploy.turbomind.chat import main as run_turbomind_model + + run_turbomind_model(model_path, + session_id=session_id, + cap=cap, + tp=tp, + stream_output=stream_output, + **kwargs) diff --git a/lmdeploy/cli/cli.py b/lmdeploy/cli/cli.py new file mode 100644 index 0000000000..12babfc751 --- /dev/null +++ b/lmdeploy/cli/cli.py @@ -0,0 +1,139 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os + +import fire + +from .chat import SubCliChat +from .lite import SubCliLite +from .serve import SubCliServe + + +class CLI(object): + """LMDeploy Command Line Interface. + + The CLI provides a unified API for converting, compressing and deploying + large language models. + """ + + def convert(self, + model_name: str, + model_path: str, + model_format: str = None, + tokenizer_path: str = None, + dst_path: str = './workspace', + tp: int = 1, + quant_path: str = None, + group_size: int = 0, + **kwargs): + """Convert LLMs to lmdeploy format. + + Args: + model_name (str): The name of the to-be-deployed model, such as + llama-7b, llama-13b, vicuna-7b and etc. + model_path (str): The directory path of the model or huggingface + repo_id like 'internlm/internlm-chat-20b' + model_format (str): the format of the model, should choose from + ['llama', 'hf', 'awq', None]. 'llama' stands for META's llama + format, 'hf' means huggingface llama format, and 'awq' means + llama(hf) model quantized by lmdeploy/lite/quantization/awq.py. + the default value is None, which means the model_format will be + inferred based on model_name + tokenizer_path (str): The path of tokenizer model. + dst_path (str): The destination path that saves outputs. + tp (int): The number of GPUs used for tensor parallelism, which + should be 2^n. + quant_path (str): Path of the quantized model, which can be None. + group_size (int): A parameter used in AWQ to quantize fp16 weights + to 4 bits. + kwargs (dict): other params for convert + """ + from lmdeploy.turbomind.deploy.converter import main as convert + + convert(model_name, + model_path, + model_format=model_format, + tokenizer_path=tokenizer_path, + dst_path=dst_path, + tp=tp, + quant_path=quant_path, + group_size=group_size, + **kwargs) + + def list(self, engine: str = 'turbomind'): + """List supported model names. + + Examples 1: + lmdeploy list + + Examples 2: + lmdeploy list --engine pytorch + + Args: + engine (str): The backend for the model to run. Choice from + ['turbomind', 'pytorch']. 
+ """ + assert engine in ['turbomind', 'pytorch'] + if engine == 'pytorch': + model_names = ['llama', 'llama2', 'internlm-7b'] + elif engine == 'turbomind': + from lmdeploy.model import MODELS + model_names = list(MODELS.module_dict.keys()) + model_names = [n for n in model_names if n.lower() not in ['base']] + model_names.sort() + print('Supported model names:') + print('\n'.join(model_names)) + + def check_env(self, dump_file: str = None): + """Check env information. + + Args: + dump_file (str): Output file to save env info. + """ + + import importlib + + import mmengine + from mmengine.utils import get_git_hash + from mmengine.utils.dl_utils import collect_env + + from lmdeploy.version import __version__ + + env_info = collect_env() + env_info['LMDeploy'] = __version__ + '+' + get_git_hash()[:7] + + # remove some unnecessary info + remove_reqs = ['MMEngine', 'OpenCV'] + for req in remove_reqs: + if req in env_info: + env_info.pop(req) + + # extra important dependencies + extra_reqs = ['transformers', 'gradio', 'fastapi', 'pydantic'] + + for req in extra_reqs: + try: + env_info[req] = importlib.import_module(req).__version__ + except Exception: + env_info[req] = 'Not Found' + + # print env info + for k, v in env_info.items(): + print(f'{k}: {v}') + + # dump to local file + if dump_file is not None: + work_dir, _ = os.path.split(dump_file) + if work_dir: + os.makedirs(work_dir, exist_ok=True) + mmengine.dump(env_info, dump_file) + + +def run(): + """The entry point of running LMDeploy CLI.""" + + cli = CLI() + cli.lite = SubCliLite() + cli.chat = SubCliChat() + cli.serve = SubCliServe() + + fire.Fire(cli, name='lmdeploy') diff --git a/lmdeploy/cli/lite.py b/lmdeploy/cli/lite.py new file mode 100644 index 0000000000..4302765e28 --- /dev/null +++ b/lmdeploy/cli/lite.py @@ -0,0 +1,100 @@ +# Copyright (c) OpenMMLab. All rights reserved. + + +class SubCliLite(object): + """CLI for compressing LLMs.""" + + def auto_awq(self, + model: str, + work_dir: str, + w_bits: int = 4, + w_sym: bool = False, + w_group_size: int = 128, + device: str = 'cuda'): + """Perform weight quantization using AWQ algorithm. + + Args: + model (str): The path of model in hf format. + work_dir (str): The working directory to save results. + w_bits (int): Bit number for weight quantization. + w_sym (bool): Whether to do symmetric quantization. + w_group_size (int): Group size for weight quantization statistics. + device (str): Device type of running. + """ + from lmdeploy.lite.apis.auto_awq import auto_awq + + auto_awq(model, + work_dir, + w_bits=w_bits, + w_sym=w_sym, + w_group_size=w_group_size, + device=device) + + def calibrate(self, + model: str, + calib_dataset: str = 'c4', + calib_samples: int = 128, + calib_seqlen: int = 2048, + work_dir: str = './work_dir', + device: str = 'cuda') -> None: + """Perform calibration on a given dataset. + + Args: + model (str): The model to be loaded. + calib_dataset (str, optional): The calibration dataset name. + Defaults to 'c4'. + calib_samples (int, optional): The number of samples for + calibration. Defaults to 128. + calib_seqlen (int, optional): The sequence length for calibration. + Defaults to 2048. + work_dir (str): The working directory for outputs. + Defaults to './work_dir'. + device (str, optional): The device to be used for calculation. + Defaults to 'cuda'. 
+ """ + from lmdeploy.lite.apis.calibrate import calibrate + + calibrate(model, + calib_dataset=calib_dataset, + calib_samples=calib_samples, + calib_seqlen=calib_seqlen, + work_dir=work_dir, + device=device) + + def kv_qparams(self, + work_dir: str, + turbomind_dir: str, + kv_bits: int = 8, + kv_sym: bool = False, + num_tp: int = 1) -> None: + """Export key and value stats. + + Args: + work_dir (str): Directory path where the stats + are saved. + turbomind_dir (str): Directory path where to + save the results. + kv_bits (int, optional): Number of bits for quantization. + Defaults to 8. + kv_sym (bool, optional): Whether to use symmetric quantization. + Defaults to False. + num_tp (int, optional): Number of tensor parallelism. + Defaults to 1. + """ + from lmdeploy.lite.apis.kv_qparams import main as run_kv_qparams + + run_kv_qparams(work_dir, + turbomind_dir, + kv_bits=kv_bits, + kv_sym=kv_sym, + num_tp=num_tp) + + def get_small_sharded_hf(self, src_dir: str, dst_dir: str): + """Convert a hugging face model to the smallest sharded one. + + Args: + src_dir (str): The directory of the input HF model. + dst_dir (str): The directory to save new model. + """ + from lmdeploy.lite.apis.get_small_sharded_hf import main as run_sharded + run_sharded(src_dir, dst_dir) diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py new file mode 100644 index 0000000000..b7680e8c95 --- /dev/null +++ b/lmdeploy/cli/serve.py @@ -0,0 +1,123 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + + +class SubCliServe(object): + """Serve LLMs and interact on terminal or web UI.""" + + def gradio(self, + model_path_or_server: str, + server_name: str = '0.0.0.0', + server_port: int = 6006, + batch_size: int = 32, + tp: int = 1, + **kwargs): + """Serve LLMs with web ui using gradio. + + Example 1: + lmdeploy serve gradio ./workspace + + Example 2: + lmdeploy serve gradio http://0.0.0.0:23333 + --server_name 0.0.0.0 + --server_port 6006 + + Example 3: + lmdeploy serve gradio ${triton_server_ip_addresss}:33337 + + Args: + model_path_or_server (str): the path of the deployed model or the + tritonserver URL or restful api URL. The former is for directly + running service with gradio. The latter is for running with + tritonserver by default. + server_name (str): the ip address of gradio server + server_port (int): the port of gradio server + batch_size (int): batch size for running Turbomind directly + tp (int): tensor parallel for Turbomind + kwargs (dict): extra params to init + """ + from lmdeploy.serve.gradio.app import run + run(model_path_or_server, + server_name=server_name, + server_port=server_port, + batch_size=batch_size, + tp=tp, + **kwargs) + + def api_server(self, + model_path: str, + server_name: str = '0.0.0.0', + server_port: int = 23333, + instance_num: int = 64, + tp: int = 1, + allow_origins: List[str] = ['*'], + allow_credentials: bool = True, + allow_methods: List[str] = ['*'], + allow_headers: List[str] = ['*'], + **kwargs): + """Serve LLMs with restful api using fastapi. 
+ + Args: + model_path (str): the path of the deployed model + server_name (str): host ip for serving + server_port (int): server port + instance_num (int): number of instances of turbomind model + tp (int): tensor parallel + allow_origins (List[str]): a list of allowed origins for CORS + allow_credentials (bool): whether to allow credentials for CORS + allow_methods (List[str]): a list of allowed HTTP methods for CORS + allow_headers (List[str]): a list of allowed HTTP headers for CORS + kwargs (dict) extra params to init api server + """ + from lmdeploy.serve.openai.api_server import main as run_api_server + + run_api_server(model_path, + server_name=server_name, + server_port=server_port, + instance_num=instance_num, + tp=tp, + allow_origins=allow_origins, + allow_credentials=allow_credentials, + allow_methods=allow_methods, + allow_headers=allow_headers, + **kwargs) + + def api_client(self, restful_api_url: str, session_id: int = 0): + """Interact with restful api server in terminal. + + Args: + restful_api_url: The restful api URL. + session_id: The identical id of a session. + """ + from lmdeploy.serve.openai.api_client import main as run_api_client + run_api_client(restful_api_url, session_id=session_id) + + def triton_client(self, + tritonserver_addr: str, + session_id: int = 1, + cap: str = 'chat', + stream_output: bool = True, + **kwargs): + """Interact with Triton Server using gRPC protocol. + + Args: + tritonserver_addr (str): the address in format "ip:port" of + triton inference server + session_id (int): the identical id of a session + cap (str): the capability of a model. For example, codellama + has the ability among ['completion', 'infill', 'instruct', + 'python'] + stream_output (bool): indicator for streaming output or not + **kwargs (dict): other arguments for initializing model's + chat template + """ + + from lmdeploy.serve.client import main as run_triton_client + + run_triton_client( + tritonserver_addr, + session_id=session_id, + cap=cap, + stream_output=stream_output, + **kwargs, + ) diff --git a/lmdeploy/legacy/pytorch/chat.py b/lmdeploy/legacy/pytorch/chat.py index 706780092d..9a167cc52b 100644 --- a/lmdeploy/legacy/pytorch/chat.py +++ b/lmdeploy/legacy/pytorch/chat.py @@ -51,7 +51,6 @@ import logging from typing import Optional -import fire import torch from transformers import GenerationConfig, PreTrainedModel @@ -205,6 +204,8 @@ def main( def cli(): + import fire + fire.Fire(main) diff --git a/lmdeploy/legacy/pytorch/modules/linear.py b/lmdeploy/legacy/pytorch/modules/linear.py index bfde0d3d42..218a36407e 100644 --- a/lmdeploy/legacy/pytorch/modules/linear.py +++ b/lmdeploy/legacy/pytorch/modules/linear.py @@ -4,6 +4,11 @@ import torch from torch import nn +try: + import awq_inference_engine +except ModuleNotFoundError: + awq_inference_engine = None + class WeightOnlyQLinear(nn.Module): """This class implements weight only quantization linear. @@ -18,13 +23,15 @@ class WeightOnlyQLinear(nn.Module): bias (Tensor, optional): Defaults to None. 
""" - def __init__(self, - w_bit: int, - symmetry: bool, - group_size: int, - in_features: int, - out_features: int, - bias: Optional[torch.Tensor] = None) -> None: + def __init__( + self, + in_features: int, + out_features: int, + bias: Optional[torch.Tensor] = True, + w_bit: int = 4, + symmetry: bool = False, + group_size: int = 128, + ) -> None: super().__init__() if w_bit not in [2, 4, 8]: @@ -92,8 +99,8 @@ def from_linear(cls: Type['WeightOnlyQLinear'], out_features = linear.out_features bias = False if linear.bias is None else True - qlinear = cls(w_bit, symmetry, group_size, in_features, out_features, - bias) + qlinear = cls(in_features, out_features, bias, w_bit, symmetry, + group_size) qlinear.bias = linear.bias qparams = quantizer.calculate_qparams(linear.weight) @@ -124,3 +131,24 @@ def from_linear(cls: Type['WeightOnlyQLinear'], qlinear.to('cpu') return qlinear + + @torch.no_grad() + def forward(self, x): + if awq_inference_engine is None: + raise RuntimeError( + 'Run the following command to install ' + 'the kernel for 4bit inference\n\n' + 'git clone https://github.com/mit-han-lab/llm-awq.git\n' + 'cd awq/kernels\n' + 'python setup.py install\n') + out_shape = x.shape[:-1] + (self.out_features, ) + inputs = x.reshape(-1, x.shape[-1]) + + out = awq_inference_engine.gemm_forward_cuda(inputs.half(), + self.qweight, + self.scales.half(), + self.qzeros, + self.group_size) + out = out + self.bias if self.bias is not None else out + + return out.reshape(out_shape) diff --git a/lmdeploy/lite/apis/auto_awq.py b/lmdeploy/lite/apis/auto_awq.py index 3517f51b85..4a4f8ea983 100644 --- a/lmdeploy/lite/apis/auto_awq.py +++ b/lmdeploy/lite/apis/auto_awq.py @@ -2,27 +2,28 @@ from pathlib import Path -import fire import torch -from accelerate import (infer_auto_device_map, init_empty_weights, - load_checkpoint_in_model) from torch import nn -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from transformers import AutoTokenizer from lmdeploy.lite.quantization.awq import (FC_FCS_MAP, NORM_FCS_MAP, quant_weights, smooth_layers) -from lmdeploy.lite.utils import collect_target_modules +from lmdeploy.lite.utils import collect_target_modules, load_hf_from_pretrained + +# from lmdeploy.lite.utils.export_turbomind import export_turbomind_config LAYER_TYPE_MAP = { 'InternLMForCausalLM': 'InternLMDecoderLayer', 'QWenLMHeadModel': 'QWenBlock', - 'BaiChuanForCausalLM': 'DecoderLayer', + 'BaiChuanForCausalLM': 'DecoderLayer', # Baichuan 7B + 'BaichuanForCausalLM': 'DecoderLayer', # Baichuan2 7B 'LlamaForCausalLM': 'LlamaDecoderLayer', } NORM_TYPE_MAP = { 'InternLMForCausalLM': 'InternLMRMSNorm', 'QWenLMHeadModel': 'RMSNorm', - 'BaiChuanForCausalLM': 'RMSNorm', + 'BaiChuanForCausalLM': 'RMSNorm', # Baichuan 7B + 'BaichuanForCausalLM': 'RMSNorm', # Baichuan2 7B 'LlamaForCausalLM': 'LlamaRMSNorm', } @@ -34,39 +35,25 @@ def auto_awq(model: str, w_group_size: int = 128, device: str = 'cuda'): + assert model != work_dir, '$WORK_DIR and $HF_MODEL should be different' + model_path = model # noqa + # Load tokenizer and configuration tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, trust_remote_code=True) - hf_config = AutoConfig.from_pretrained(model, trust_remote_code=True) - checkpoint = hf_config._name_or_path - with init_empty_weights(): - # Load model - model = AutoModelForCausalLM.from_pretrained(model, - torch_dtype=torch.float16, - trust_remote_code=True) - model.config.use_cache = False + model = load_hf_from_pretrained(model, + torch_dtype=torch.float16, + 
trust_remote_code=True) layer_type = LAYER_TYPE_MAP[type(model).__name__] fc2fcs = FC_FCS_MAP[layer_type] norm2fcs = NORM_FCS_MAP[layer_type] - decoder_layers = collect_target_modules(model, layer_type) - - # Infer device map - device_map = infer_auto_device_map(model, - no_split_module_classes=[layer_type]) - for name in device_map.keys(): - if name in decoder_layers or 'lm_head' in name: - device_map[name] = 'cpu' - else: - device_map[name] = 0 - load_checkpoint_in_model(model, checkpoint, device_map) - work_dir = Path(work_dir) - act_scales = torch.load(work_dir / 'inputs_stats.pth')['absmean'] + act_scales = torch.load(work_dir / 'inputs_stats.pth')['absmax'] layers = collect_target_modules(model, layer_type) fcs = {} for l_name, layer in layers.items(): @@ -76,10 +63,16 @@ def auto_awq(model: str, smooth_layers(layers, fc2fcs, norm2fcs, act_scales, w_group_size, device) quant_weights(model, fcs, w_bits, w_sym, w_group_size, device) - model.save_pretrained(work_dir) + model.save_pretrained(work_dir, max_shard_size='2GB') tokenizer.save_pretrained(work_dir) + # export_turbomind_config(model_name, + # model_path, + # work_dir, + # group_size=w_group_size) + if __name__ == '__main__': + import fire fire.Fire(auto_awq) diff --git a/lmdeploy/lite/apis/calibrate.py b/lmdeploy/lite/apis/calibrate.py index 38b6429a19..b164c638f8 100644 --- a/lmdeploy/lite/apis/calibrate.py +++ b/lmdeploy/lite/apis/calibrate.py @@ -1,30 +1,103 @@ # Copyright (c) OpenMMLab. All rights reserved. from pathlib import Path +from typing import Union -import fire import torch -from accelerate import (infer_auto_device_map, init_empty_weights, - load_checkpoint_in_model) -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from torch import nn +from transformers import AutoTokenizer from lmdeploy.lite.quantization import CalibrationContext -from lmdeploy.lite.utils import collect_target_modules, get_calib_loaders +from lmdeploy.lite.utils import (collect_target_modules, get_calib_loaders, + load_hf_from_pretrained) LAYER_TYPE_MAP = { 'InternLMForCausalLM': 'InternLMDecoderLayer', 'QWenLMHeadModel': 'QWenBlock', - 'BaiChuanForCausalLM': 'DecoderLayer', + 'BaiChuanForCausalLM': 'DecoderLayer', # Baichuan 7B + 'BaichuanForCausalLM': 'DecoderLayer', # Baichuan2 7B 'LlamaForCausalLM': 'LlamaDecoderLayer', } NORM_TYPE_MAP = { 'InternLMForCausalLM': 'InternLMRMSNorm', 'QWenLMHeadModel': 'RMSNorm', - 'BaiChuanForCausalLM': 'RMSNorm', + 'BaiChuanForCausalLM': 'RMSNorm', # Baichuan 7B + 'BaichuanForCausalLM': 'RMSNorm', # Baichuan2 7B 'LlamaForCausalLM': 'LlamaRMSNorm', } +def _prepare_for_calibrate(model: nn.Module, + layer_type: Union[str, type], + head_name: str = 'lm_head', + device: str = 'cuda', + prefix: str = '') -> None: + """Prepare the model for calibration by moving specific modules to CPU. + + This function goes through each child of a given model and checks whether + it is an instance of a certain layer type or has the name equal to + `head_name`. + If yes, it moves the module to CPU, otherwise to the specified device + (default is CUDA). + + If the child contains the target layer type in its sub-modules, the + function performs the same operation recursively. + + Parameters + ---------- + model : nn.Module + The PyTorch model to prepare for calibration. + layer_type : Union[str, Type] + The type of the layer to be moved to CPU. Can be either a string of + class name or the class type itself. + head_name : str, optional + The name of the module to be moved to CPU. Default is 'lm_head'. 
+ device : str, optional + The device to which modules not matching the `layer_type` or + `head_name` will be moved. Default is 'cuda'. + prefix : str, optional + The prefix used when printing the names of the moved modules. + Default is ''. + + Raises + ------ + TypeError + If `layer_type` is neither a string nor a type. + """ + + for name, child in model.named_children(): + + # Check if the child is an instance of the given layer type + if isinstance(layer_type, str): + is_layer = type(child).__name__ == layer_type + elif isinstance(layer_type, type): + is_layer = isinstance(child, layer_type) + else: + raise TypeError( + 'layer_type should be a string (class name) or a type') + + # Check if the child contains the target module type + contain_layer = len( + collect_target_modules(child, layer_type, [head_name]).keys()) > 0 + + # Check if the child matches the head name + is_head = name == head_name + + mod_name = f'{prefix}.{name}' if prefix else name + + # If the child is either an instance of the layer type or has the + # head name, move it to CPU, otherwise move it to the specified device + if is_layer or is_head: + child.to('cpu') + print(f'Move {mod_name} to CPU.') + elif contain_layer: + _prepare_for_calibrate(child, layer_type, head_name, device, + mod_name) + else: + child.to(device) + print(f'Move {mod_name} to GPU.') + + def calibrate(model: str, calib_dataset: str = 'c4', calib_samples: int = 128, @@ -35,7 +108,7 @@ def calibrate(model: str, given dataset. Args: - model (str): The model to be loaded. + model (str): The name or path of the model to be loaded. calib_dataset (str, optional): The calibration dataset name. Defaults to 'c4'. calib_samples (int, optional): The number of samples for calibration. @@ -55,30 +128,31 @@ def calibrate(model: str, tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, trust_remote_code=True) - hf_config = AutoConfig.from_pretrained(model, trust_remote_code=True) - checkpoint = hf_config._name_or_path - with init_empty_weights(): - # Load model - model = AutoModelForCausalLM.from_pretrained(model, - torch_dtype=torch.float16, - trust_remote_code=True) - model.config.use_cache = False + model = load_hf_from_pretrained(model, + torch_dtype=torch.float16, + trust_remote_code=True) + + model_type = type(model).__name__ + if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP: + raise RuntimeError( + f'Currently, quantification and calibration of {model_type} are ' + f'not supported. 
The supported model types are ' + f"{', '.join(LAYER_TYPE_MAP.keys())}.") + + if model_type == 'QWenLMHeadModel': + try: + import flash_attn # noqa: F401 + except ImportError: + raise RuntimeError( + 'When using Qwen, you need to `pip install flash-attn` first, ' + 'otherwise calibration and quantification will not work ' + 'properly.') layer_type = LAYER_TYPE_MAP[type(model).__name__] norm_type = NORM_TYPE_MAP[type(model).__name__] - decoder_layers = collect_target_modules(model, layer_type) - - # Infer device map - device_map = infer_auto_device_map(model, - no_split_module_classes=[layer_type]) - for name in device_map.keys(): - if name in decoder_layers or 'lm_head' in name: - device_map[name] = 'cpu' - else: - device_map[name] = 0 - load_checkpoint_in_model(model, checkpoint, device_map) + _prepare_for_calibrate(model, layer_type, 'lm_head', device) print('Loading calibrate dataset ...') calib_loader, _ = get_calib_loaders(calib_dataset, @@ -107,4 +181,6 @@ def calibrate(model: str, if __name__ == '__main__': + import fire + fire.Fire(calibrate) diff --git a/lmdeploy/lite/apis/kv_qparams.py b/lmdeploy/lite/apis/kv_qparams.py index 7d43078daf..873bc5b047 100644 --- a/lmdeploy/lite/apis/kv_qparams.py +++ b/lmdeploy/lite/apis/kv_qparams.py @@ -1,17 +1,34 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os from pathlib import Path from typing import Union -import fire import numpy as np import torch +def _export_weight(into: str, + kv_qparams: np.array, + out_path: str, + tm_params: dict = None): + """Save kv_qparams to disk or copy to tm_params.""" + if tm_params is None: + print(into) + kv_qparams.tofile(out_path) + else: + name = os.path.basename(out_path) + src = torch.from_numpy(kv_qparams) + for tm_tensor in tm_params[name]: + tm_tensor.copy_from(src) + tm_params.pop(name) + + def _export_sym(key_stats: dict, value_stats: dict, bits: int, out_dir: Union[str, Path], - tp: int = 1) -> None: + tp: int = 1, + tm_params: dict = None) -> None: """Export symmetric quantization parameters to specified directory.""" keys_absmax = key_stats['absmax'] values_absmax = value_stats['absmax'] @@ -32,15 +49,16 @@ def _export_sym(key_stats: dict, kv_qparams = np.array([k_s, v_s], dtype=np.float32) out_path = out_dir / f'layers.{layer_idx}.past_kv_scale.{i}.weight' # noqa: E501 - kv_qparams.tofile(out_path) - print(f'Layer {layer_idx} MP {i} qparam: {k_s} \t{v_s}') + info = f'Layer {layer_idx} MP {i} qparam: {k_s} \t{v_s}' + _export_weight(info, kv_qparams, out_path, tm_params) def _export_asym(key_stats: dict, value_stats: dict, bits: int, out_dir: Union[str, Path], - tp: int = 1) -> None: + tp: int = 1, + tm_params: dict = None) -> None: """Export asymmetric quantization parameters to specified directory.""" keys_min = key_stats['min'] values_min = value_stats['min'] @@ -82,16 +100,17 @@ def _export_asym(key_stats: dict, kv_qparams = np.array([k_scale, k_zp, v_scale, v_zp], dtype=np.float32) out_path = out_dir / f'layers.{layer_idx}.past_kv_scale.{i}.weight' - kv_qparams.tofile(out_path) - print(f'Layer {layer_idx} MP {i} qparam: ' - f'\t{k_scale} \t{k_zp} \t{v_scale} \t{v_zp}') + info = f'Layer {layer_idx} MP {i} qparam: ' \ + f'\t{k_scale} \t{k_zp} \t{v_scale} \t{v_zp}' + _export_weight(info, kv_qparams, out_path, tm_params) def main(work_dir: str, turbomind_dir: str, kv_bits: int = 8, kv_sym: bool = False, - num_tp: int = 1) -> None: + num_tp: int = 1, + tm_params: dict = None) -> None: """Main function to export key and value stats. 
Args: @@ -103,6 +122,7 @@ def main(work_dir: str, kv_sym (bool, optional): Whether to use symmetric quantizaiton. Defaults to False. num_tp (int, optional): Number of tensor parallelism. Defaults to 1. + tm_params (dict): turbomind model weights. """ work_dir = Path(work_dir) @@ -114,11 +134,13 @@ def main(work_dir: str, value_stats = torch.load(work_dir / 'value_stats.pth') if kv_sym: - _export_sym(key_stats, value_stats, kv_bits, tm_dir, num_tp) + _export_sym(key_stats, value_stats, kv_bits, tm_dir, num_tp, tm_params) else: - _export_asym(key_stats, value_stats, kv_bits, tm_dir, num_tp) + _export_asym(key_stats, value_stats, kv_bits, tm_dir, num_tp, + tm_params) if __name__ == '__main__': + import fire fire.Fire(main) diff --git a/lmdeploy/lite/quantization/awq.py b/lmdeploy/lite/quantization/awq.py index 2816fc1d5f..91da1d4f0c 100644 --- a/lmdeploy/lite/quantization/awq.py +++ b/lmdeploy/lite/quantization/awq.py @@ -77,7 +77,7 @@ def smooth_ln_fcs(ln: torch.nn.Module, w_scales = get_weight_scale(concat_w, group_size) scales = (act_scales.pow(alpha) / - w_scales.pow(1 - alpha)).clamp(min=1e-4).to(device).to(dtype) + w_scales.pow(1 - alpha)).to(device).to(dtype) scales = scales / (scales.max() * scales.min()).sqrt() ln.weight.div_(scales) @@ -124,10 +124,10 @@ def smooth_fc_fcs(pre_fc: torch.nn.Module, w_scales = get_weight_scale(concat_w, group_size) scales = (act_scales.pow(alpha) / - w_scales.pow(1 - alpha)).clamp(min=1e-4).to(device).to(dtype) + w_scales.pow(1 - alpha)).to(device).to(dtype) scales = scales / (scales.max() * scales.min()).sqrt() - # (for qwen) pre_fc is packed QKV, only V needs to scale + # (for qwen&baichuan) pre_fc is packed QKV, only V needs to scale if size_pre_fc > size_a and size_pre_fc % size_a == 0 \ and size_pre_fc // size_a == 3: diff --git a/lmdeploy/lite/quantization/weight/quantizer.py b/lmdeploy/lite/quantization/weight/quantizer.py index 56cfda8f01..1d01696eb9 100644 --- a/lmdeploy/lite/quantization/weight/quantizer.py +++ b/lmdeploy/lite/quantization/weight/quantizer.py @@ -8,7 +8,7 @@ cal_qparams_per_group_absmax, cal_qparams_per_group_minmax, cal_qparams_per_tensor_absmax, - cal_qparams_per_tensor_minmax) + cal_qparams_per_tensor_minmax, precise_round) from lmdeploy.lite.utils.global_avail import GlobalAvailMixin @@ -119,8 +119,10 @@ def quant(self, torch.Tensor: The fake quantized weight tensor. 
""" + float_w = weight.float() + if qparams is None: - qparams = self.calculate_qparams(weight) + qparams = self.calculate_qparams(float_w) scales = qparams.scales zero_points = qparams.zero_points @@ -133,17 +135,18 @@ def quant(self, # per group scales shape: [out_c, in_c//group_size, 1] if len(scales.shape) > 2: # scales shape: [out_c, in_c//group_size, 1] - weight = weight.reshape(out_c, scales.shape[1], -1) + float_w = float_w.reshape(out_c, scales.shape[1], -1) if zero_points is None: assert self.symmetry - real_qweight = (weight / scales).round() + real_qweight = (float_w / scales).round() fake_qweight = real_qweight * scales else: assert not self.symmetry - real_qweight = (weight / scales).round() + zero_points + real_qweight = precise_round( + (float_w - float_w.min(-1, keepdim=True)[0]) / scales) fake_qweight = (real_qweight - zero_points) * scales if len(scales.shape) > 2: @@ -153,4 +156,4 @@ def quant(self, if real: return real_qweight.to(torch.int32) else: - return fake_qweight + return fake_qweight.to(weight.dtype) diff --git a/lmdeploy/lite/utils/__init__.py b/lmdeploy/lite/utils/__init__.py index c2b56287bd..2d539e83ac 100644 --- a/lmdeploy/lite/utils/__init__.py +++ b/lmdeploy/lite/utils/__init__.py @@ -6,17 +6,18 @@ cal_qparams_per_group_absmax, cal_qparams_per_group_minmax, cal_qparams_per_tensor_absmax, - cal_qparams_per_tensor_minmax) + cal_qparams_per_tensor_minmax, precise_round) from .calib_dataloader import get_calib_loaders from .collect import (bimap_name_mod, collect_target_modules, collect_target_weights) from .global_avail import GlobalAvailMixin +from .load import load_hf_from_pretrained __all__ = [ 'cal_qparams_per_channel_absmax', 'cal_qparams_per_channel_minmax', 'cal_qparams_per_group_absmax', 'cal_qparams_per_group_minmax', 'cal_qparams_per_tensor_absmax', 'cal_qparams_per_tensor_minmax', - 'QParams', 'get_calib_loaders', 'collect_target_modules', + 'QParams', 'get_calib_loaders', 'collect_target_modules', 'precise_round', 'collect_target_weights', 'GlobalAvailMixin', 'split_decoder_layer_inputs', - 'bimap_name_mod', 'concat_decoder_layer_outputs' + 'bimap_name_mod', 'concat_decoder_layer_outputs', 'load_hf_from_pretrained' ] diff --git a/lmdeploy/lite/utils/cal_qparams.py b/lmdeploy/lite/utils/cal_qparams.py index a682704a55..569297cdb5 100644 --- a/lmdeploy/lite/utils/cal_qparams.py +++ b/lmdeploy/lite/utils/cal_qparams.py @@ -11,16 +11,22 @@ class QParams(NamedTuple): zero_points: Optional[torch.Tensor] +@torch.no_grad() +def precise_round(x): + return x.sign() * (x.abs() + 0.5).floor() + + @torch.no_grad() def cal_qparams_per_channel_absmax(w: torch.Tensor, n_bits: int, return_stats: bool = False) -> QParams: """Calculate quantization parameters for each channel using absolute max value.""" + float_w = w.float() - absmax = w.abs().max(dim=-1, keepdim=True)[0] + absmax = float_w.abs().max(dim=-1, keepdim=True)[0] q_max = 2**(n_bits - 1) - 1 - scales = absmax.clamp(min=1e-5).div(q_max) + scales = absmax.div(q_max) if return_stats: return QParams(scales=scales, zero_points=None), absmax @@ -35,14 +41,16 @@ def cal_qparams_per_channel_minmax(w: torch.Tensor, """Calculate quantization parameters for each channel using min and max values.""" - w_min = w.min(dim=-1, keepdim=True)[0] - w_max = w.max(dim=-1, keepdim=True)[0] + float_w = w.float() + + w_min = float_w.min(dim=-1, keepdim=True)[0] + w_max = float_w.max(dim=-1, keepdim=True)[0] q_max = 2**n_bits - 1 scales = (w_max - w_min) - scales = scales.clamp_(min=1e-5).div_(q_max) + scales = 
scales.div_(q_max) - zero_points = (-w_min / scales).round() + zero_points = precise_round(-w_min / scales) if return_stats: return QParams(scales=scales, zero_points=zero_points), (w_min, w_max) @@ -63,9 +71,12 @@ def cal_qparams_per_group_absmax(w: torch.Tensor, 'Input channels should be greater than or equal to group_size.' assert inc % group_size == 0, \ 'Input channels should be divisible by group_size.' - absmax = w.abs().reshape(outc, -1, group_size).max(dim=-1, keepdim=True)[0] + + float_w = w.float() + absmax = float_w.abs().reshape(outc, -1, group_size).max(dim=-1, + keepdim=True)[0] q_max = 2**(n_bits - 1) - 1 - scales = absmax.clamp(min=1e-5).div(q_max) + scales = absmax.div(q_max) if return_stats: return QParams(scales=scales, zero_points=None), absmax else: @@ -85,14 +96,16 @@ def cal_qparams_per_group_minmax(w: torch.Tensor, 'Input channels should be greater than or equal to group_size.' assert inc % group_size == 0, \ 'Input channels should be divisible by group_size.' - w_group_wise = w.reshape(outc, -1, group_size) + + float_w = w.float() + w_group_wise = float_w.reshape(outc, -1, group_size) w_min = w_group_wise.min(dim=-1, keepdim=True)[0] w_max = w_group_wise.max(dim=-1, keepdim=True)[0] q_max = 2**n_bits - 1 scales = (w_max - w_min) - scales = scales.clamp_(min=1e-5).div_(q_max) - zero_points = (-w_min / scales).round() + scales = scales.div_(q_max) + zero_points = precise_round(-w_min / scales) if return_stats: return QParams(scales=scales, zero_points=zero_points), (w_min, w_max) else: @@ -106,13 +119,15 @@ def cal_qparams_per_tensor_minmax(w: torch.Tensor, """Calculate quantization parameters for the entire tensor using min and max values.""" - w_min = w.min() - w_max = w.max() + float_w = w.float() + + w_min = float_w.min() + w_max = float_w.max() q_max = 2**n_bits - 1 scales = (w_max - w_min) scales = scales.clamp_(min=1e-5).div_(q_max) - zero_points = (-w_min / scales).round() + zero_points = precise_round(-w_min / scales) if return_stats: return QParams(scales=scales, zero_points=zero_points), (w_min, w_max) else: @@ -125,9 +140,10 @@ def cal_qparams_per_tensor_absmax(w: torch.Tensor, return_stats: bool = False) -> QParams: """Calculate quantization parameters for the entire tensor using absolute max value.""" - absmax = w.abs().max() + float_w = w.float() + absmax = float_w.abs().max() q_max = 2**(n_bits - 1) - 1 - scales = absmax.clamp(min=1e-5).div(q_max) + scales = absmax.div(q_max) if return_stats: return QParams(scales=scales, zero_points=None), absmax diff --git a/lmdeploy/lite/utils/collect.py b/lmdeploy/lite/utils/collect.py index 8b2691a4a6..3b66ef6146 100644 --- a/lmdeploy/lite/utils/collect.py +++ b/lmdeploy/lite/utils/collect.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Dict, List, Tuple, Union -from mmengine.config.lazy import LazyAttr from torch import nn @@ -22,9 +21,6 @@ def collect_target_modules(model: nn.Module, A dictionary mapping from module names to module instances. """ - if isinstance(target, LazyAttr): - target = target.build() - if not isinstance(target, (type, str)): raise TypeError('Target must be a string (name of the module) ' 'or a type (class of the module)') diff --git a/lmdeploy/lite/utils/export_turbomind.py b/lmdeploy/lite/utils/export_turbomind.py new file mode 100644 index 0000000000..393a980041 --- /dev/null +++ b/lmdeploy/lite/utils/export_turbomind.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import json +import os +import shutil + +from huggingface_hub import snapshot_download + +from lmdeploy.turbomind.utils import get_hf_config_content + + +def export_turbomind_config(model_name: str, + model_path: str, + work_dir: str, + model_format: str = 'awq', + group_size: int = 128, + tp: int = 1): + """Export hf lmdeploy model and config.json.""" + import lmdeploy + from lmdeploy.model import MODELS + from lmdeploy.turbomind.deploy.converter import get_model_format + from lmdeploy.turbomind.deploy.source_model.base import INPUT_MODELS + from lmdeploy.turbomind.deploy.target_model.base import ( + OUTPUT_MODELS, TurbomindModelConfig) + + assert model_name in MODELS.module_dict.keys(), \ + f"'{model_name}' is not supported. " \ + f'The supported models are: {MODELS.module_dict.keys()}' + + if not os.path.exists(model_path): + model_path = snapshot_download(model_path, local_files_only=True) + + lmdeploy_dir = os.path.split(lmdeploy.__file__)[0] + hf_repo = os.path.join(lmdeploy_dir, 'turbomind', 'hf_repo') + files = os.listdir(hf_repo) + for file in files: + src = os.path.join(hf_repo, file) + dst = os.path.join(work_dir, file) + shutil.copy(src, dst) + + cfg = TurbomindModelConfig.from_dict({}, allow_none=True) + cfg.model_name = model_name + cfg.tensor_para_size = tp + cfg.rotary_embedding = cfg.size_per_head + cfg.group_size = group_size + cfg.weight_type = 'int4' + output_format = 'w4' + + inferred_model_format = get_model_format(model_name, model_format) + input_model = INPUT_MODELS.get(inferred_model_format)( + model_path=model_path, tokenizer_path=work_dir, ckpt_path=work_dir) + output_model = OUTPUT_MODELS.get(output_format)(input_model=input_model, + cfg=cfg, + to_file=False, + out_dir='') + + old_data = get_hf_config_content(model_path) + config = output_model.cfg.__dict__ + config_file = os.path.join(work_dir, 'config.json') + with open(config_file) as f: + data = json.load(f) + for k, v in old_data.items(): + if k in data: + data[f'__{k}'] = v + else: + data[k] = v + data['turbomind'] = config + from lmdeploy.version import __version__ + data['lmdeploy_version'] = __version__ + with open(config_file, 'w') as f: + f.write(json.dumps(data, indent=2) + '\n') diff --git a/lmdeploy/lite/utils/load.py b/lmdeploy/lite/utils/load.py new file mode 100644 index 0000000000..55aefd2389 --- /dev/null +++ b/lmdeploy/lite/utils/load.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import torch +from accelerate import infer_auto_device_map, init_empty_weights +from transformers import AutoConfig, AutoModelForCausalLM + +from lmdeploy.lite.utils import collect_target_modules +from lmdeploy.pytorch.model import LoadWoInit + +LAYER_TYPE_MAP = { + 'InternLMForCausalLM': 'InternLMDecoderLayer', + 'QWenLMHeadModel': 'QWenBlock', + 'BaiChuanForCausalLM': 'DecoderLayer', # Baichuan 7B + 'BaichuanForCausalLM': 'DecoderLayer', # Baichuan2 7B + 'LlamaForCausalLM': 'LlamaDecoderLayer', +} + + +def load_hf_from_pretrained(pretrained_model_name_or_path, **kwargs): + + kwargs.pop('config', None) + + hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, + torch_dtype=torch.float16, + trust_remote_code=True) + + # hard code for qwen, other configs do not have the `fp16` attribute. 
+ hf_config.fp16 = True + + with init_empty_weights(): + # Load model + model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path, config=hf_config, **kwargs) + model.config.use_cache = False + layer_type = LAYER_TYPE_MAP[type(model).__name__] + decoder_layers = collect_target_modules(model, layer_type) + # Infer device map + device_map = infer_auto_device_map(model, + no_split_module_classes=[layer_type]) + for name in device_map.keys(): + if name in decoder_layers or 'lm_head' in name: + device_map[name] = 'cpu' + else: + device_map[name] = 0 + if 'device_map' in kwargs: + kwargs.pop('device_map') + with LoadWoInit(): + model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path, + device_map=device_map, + config=hf_config, + **kwargs) + model.config.use_cache = False + + return model diff --git a/lmdeploy/model.py b/lmdeploy/model.py index b3fc86f999..8ccd5b7a3d 100644 --- a/lmdeploy/model.py +++ b/lmdeploy/model.py @@ -115,6 +115,7 @@ def update_input_ids(self, input_ids: List[int]): return input_ids +@MODELS.register_module(name='wizardlM') @MODELS.register_module(name='vicuna') class Vicuna(BaseModel): """Chat template of vicuna model.""" @@ -177,15 +178,16 @@ class InternLMChat7B(BaseModel): def __init__( self, - system='<|System|>', + system='<|System|>:', meta_instruction="""You are an AI assistant whose name is InternLM (书生·浦语). - InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless. - InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文. """, # noqa: E501 - user='<|User|>', - eoh='', - eoa='', - assistant='<|Bot|>', + user='<|User|>:', + eoh='\n', + eoa='\n', + eosys='\n', + assistant='<|Bot|>:', stop_words=[''], **kwargs): super().__init__(**kwargs) @@ -194,6 +196,7 @@ def __init__( self.user = user self.eoh = eoh self.eoa = eoa + self.eosys = eosys self.assistant = assistant self.stop_words = stop_words @@ -211,12 +214,12 @@ def decorate_prompt(self, prompt, sequence_start=True): assert self.capability == 'chat', \ f'{type(self).__name__} has no capability of {self.capability}' if sequence_start: - return f'{self.system}:{self.meta_instruction}\n' \ - f'{self.user}:{prompt}{self.eoh}\n' \ - f'{self.assistant}:' + return f'{self.system}{self.meta_instruction}{self.eosys}' \ + f'{self.user}{prompt}{self.eoh}' \ + f'{self.assistant}' else: - return f'\n{self.user}:{prompt}{self.eoh}\n' \ - f'{self.assistant}:' + return f'\n{self.user}{prompt}{self.eoh}' \ + f'{self.assistant}' def messages2prompt(self, messages, sequence_start=True): """Return the prompt that is concatenated with other elements in the @@ -227,17 +230,19 @@ def messages2prompt(self, messages, sequence_start=True): Returns: str: the concatenated prompt """ + if isinstance(messages, str): return self.get_prompt(messages, sequence_start) - system, users, assistants = self._translate_messages(messages) - system = self.meta_instruction if not system else system - ret = f'{self.system}:{system}\n' - for user, assistant in zip(users, assistants): - if assistant: - ret += f'{self.user}:{user}{self.eoh}\n{self.assistant}:' \ - f'{assistant}{self.eoa}\n' - else: - ret += f'{self.user}:{user}{self.eoh}\n{self.assistant}:' + eox_map = dict(user=self.eoh, assistant=self.eoa, system=self.eosys) + ret = '' + if self.meta_instruction: + ret += f'{self.system}:{self.meta_instruction}{self.eosys}' + + for message in messages: + 
role = message['role'] + content = message['content'] + ret += f'{eval(f"self.{role}")}{content}{eox_map[role]}' + ret += f'{self.assistant}' return ret @@ -368,7 +373,7 @@ def decorate_prompt(self, prompt, sequence_start=True): assert self.capability == 'chat', \ f'{type(self).__name__} has no capability of {self.capability}' if sequence_start: - return f'{self.system}{self.meta_instruction}{self.eosys}' \ + return f'{self.system}{self.meta_instruction}{self.eosys}' \ f'{self.user}{prompt}{self.eoh}' \ f'{self.assistant}' else: @@ -386,15 +391,16 @@ def messages2prompt(self, messages, sequence_start=True): """ if isinstance(messages, str): return self.get_prompt(messages, sequence_start) - system, users, assistants = self._translate_messages(messages) - system = self.system if not system else system - ret = f'{system}{self.meta_instruction}{self.eosys}' - for user, assistant in zip(users, assistants): - if assistant: - ret += f'{self.user}{user}{self.eoh}{self.assistant}' \ - f'{assistant}{self.eoa}' - else: - ret += f'{self.user}{user}{self.eoh}{self.assistant}' + eox_map = dict(user=self.eoh, assistant=self.eoa, system=self.eosys) + ret = '' + if self.meta_instruction: + ret += f'{self.system}{self.meta_instruction}{self.eosys}' + + for message in messages: + role = message['role'] + content = message['content'] + ret += f'{eval(f"self.{role}")}{content}{eox_map[role]}' + ret += f'{self.assistant}' return ret @@ -436,7 +442,7 @@ def decorate_prompt(self, prompt, sequence_start=True): assert self.capability == 'chat', \ f'{type(self).__name__} has no capability of {self.capability}' if sequence_start: - return f'{self.b_inst} ' \ + return f'{self.b_inst} ' \ f'{self.b_sys} {self.default_sys_prompt} {self.e_sys}' \ f'{prompt} {self.e_inst} ' @@ -455,7 +461,7 @@ def messages2prompt(self, messages, sequence_start=True): return self.get_prompt(messages, sequence_start) system, users, assistants = self._translate_messages(messages) system = self.default_sys_prompt if not system else system - ret = f'{self.b_inst} {self.b_sys} {system} {self.e_sys}' + ret = f'{self.b_inst} {self.b_sys} {system} {self.e_sys}' for i, (user, assistant) in enumerate(zip(users, assistants)): if i != 0: ret += f'{self.b_inst} ' @@ -571,16 +577,16 @@ def _infill_prompt(self, prompt): prefix, suffix = prompt.split('') if self.suffix_first: # format as "
<PRE> <SUF>{suf} <MID> {pre}"
-            prompt = f'<PRE> <SUF>{suffix} <MID> {prefix}'
+            prompt = f'<PRE> <SUF>{suffix} <MID> {prefix}'
         else:
             # format as "<PRE> {pre} <SUF>{suf} <MID>"
-            prompt = f'<PRE> {prefix} <SUF>{suffix} <MID>'
+            prompt = f'<PRE> {prefix} <SUF>{suffix} <MID>'
         return prompt
 
     def _get_prompt(self, prompt, sequence_start):
         prompt = prompt.strip()
         if sequence_start:
-            return f'{self.b_inst} ' \
+            return f'{self.b_inst} ' \
                    f'{self.b_sys}{self.default_sys_prompt}{self.e_sys}' \
                    f'{prompt} {self.e_inst}'
 
@@ -625,6 +631,214 @@ def update_input_ids(self, input_ids: List):
         return input_ids
 
 
+@MODELS.register_module(name='solar')
+class SOLAR(BaseModel):
+    """Chat template of SOLAR model.
+
+    `https://huggingface.co/upstage/SOLAR-0-70b-16bit`
+    """
+
+    def __init__(self,
+                 b_sys='### System:\n',
+                 e_sys='\n\n',
+                 user='### User:\n',
+                 eoh='\n\n',
+                 assistant='### Assistant:\n',
+                 eoa='\n\n',
+                 system='',
+                 session_len=2048,
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.b_sys = b_sys
+        self.e_sys = e_sys
+        self.user = user
+        self.eoh = eoh
+        self.assistant = assistant
+        self.eoa = eoa
+        self.system = system
+        self.session_len = session_len
+
+    def decorate_prompt(self, prompt, sequence_start=True):
+        """Return the prompt that is concatenated with other elements in the
+        chat template.
+
+        Args:
+            prompt (str): user's input prompt
+            sequence_start (bool): indicator for the first round chat of a
+               session sequence
+        Returns:
+            str: the concatenated prompt
+        """
+        assert self.capability == 'chat', \
+            f'{type(self).__name__} has no capability of {self.capability}'
+        if sequence_start:
+            return f'{self.b_sys}{self.system}{self.e_sys}' \
+                   f'{self.user}{prompt}{self.eoh}{self.assistant}'
+
+        return f'{self.user}{prompt}{self.eoh}{self.assistant}'
+
+    def messages2prompt(self, messages, sequence_start=True):
+        """Return the prompt that is concatenated with other elements in the
+        chat template.
+
+        Args:
+            messages (str | List): user's input prompt
+        Returns:
+            str: the concatenated prompt
+        """
+        if isinstance(messages, str):
+            return self.get_prompt(messages, sequence_start)
+        system, users, assistants = self._translate_messages(messages)
+        system = self.system if not system else system
+        ret = f'{self.b_sys}{system}{self.e_sys}'
+        for i, (user, assistant) in enumerate(zip(users, assistants)):
+            ret += f'{self.user}{user}{self.eoh}{self.assistant}'
+            if assistant:
+                ret += f'{assistant}{self.eoa}'
+        return ret
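# A rendering sketch for the SOLAR template above, assuming the class
# constructs with the defaults from __init__ and that _translate_messages
# pairs the trailing user turn with an empty assistant slot:
solar = SOLAR()
messages = [
    {'role': 'user', 'content': 'Hello'},
    {'role': 'assistant', 'content': 'Hi, how can I help?'},
    {'role': 'user', 'content': 'Tell me a joke.'},
]
prompt = solar.messages2prompt(messages)
# expected:
#   '### System:\n\n\n'
#   '### User:\nHello\n\n### Assistant:\nHi, how can I help?\n\n'
#   '### User:\nTell me a joke.\n\n### Assistant:\n'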
+
+
+@MODELS.register_module(name='ultracm')
+@MODELS.register_module(name='ultralm')
+class UltraChat(BaseModel):
+    """Template of UltraCM and UltraLM models.
+
+    `https://huggingface.co/openbmb/UltraCM-13b`
+    `https://huggingface.co/openbmb/UltraLM-13b`
+    """
+
+    def __init__(
+            self,
+            system="""User: A one-turn chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, very detailed, and polite answers to the user's questions.""",  # noqa: E501
+            eos='</s>',
+            user='User: ',
+            assistant='Assistant: ',
+            session_len=2048,
+            **kwargs):
+        super().__init__(**kwargs)
+        self.system = system
+        self.eos = eos
+        self.session_len = session_len
+        self.user = user
+        self.assistant = assistant
+
+    def decorate_prompt(self, prompt, sequence_start=True):
+        """Return the prompt that is concatenated with other elements in the
+        chat template.
+
+        Args:
+            prompt (str): the input prompt
+            sequence_start (bool): indicator for the first round chat of a
+               session sequence
+        Returns:
+            str: the concatenated prompt
+        """
+        assert self.capability == 'chat', \
+            f'{type(self).__name__} has no capability of {self.capability}'
+        if sequence_start:
+            return f'{self.system}\n{self.user}{prompt}{self.eos}' \
+                   f'\n{self.assistant}'
+
+        return f'\n{self.user}{prompt}{self.eos}' \
+               f'\n{self.assistant}'
+
+    def messages2prompt(self, messages, sequence_start=True):
+        """Return the prompt that is concatenated with other elements in the
+        chat template. Only the last instruction-completion pair is evaluated.
+
+        Args:
+            messages (str | List): user's input prompt
+        Returns:
+            str: the concatenated prompt
+        """
+        if isinstance(messages, str):
+            return self.get_prompt(messages, sequence_start)
+        system, users, assistants = self._translate_messages(messages)
+        system = self.system if not system else system
+        ret = f'{system}'
+        for user, assistant in zip(users, assistants):
+            if assistant:
+                ret += f'\n{self.user}{user}{self.eos}' \
+                       f'\n{self.assistant}{assistant}{self.eos}'
+            else:
+                ret += f'\n{self.user}{user}{self.eos}' \
+                       f'\n{self.assistant}'
+        return ret
+
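# A minimal sketch for the UltraLM/UltraCM template above, assuming the
# defaults from __init__; the prompt always ends with 'Assistant: ' so the
# model continues from there:
ultra = UltraChat()
prompt = ultra.messages2prompt([{'role': 'user', 'content': 'What is RLHF?'}])
# expected: ultra.system + f'\nUser: What is RLHF?{ultra.eos}\nAssistant: '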
+
+@MODELS.register_module(name='yi')
+class Yi(BaseModel):
+    """Chat template of Yi model."""
+
+    def __init__(self,
+                 system='<|im_start|>system\n',
+                 meta_instruction=None,
+                 user='<|im_start|>user\n',
+                 eoh='<|im_end|>\n',
+                 eoa='<|im_end|>\n',
+                 eosys='<|im_end|>\n',
+                 assistant='<|im_start|>assistant\n',
+                 stop_words=['<|im_end|>', '<|endoftext|>'],
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.system = system
+        self.meta_instruction = meta_instruction
+        self.user = user
+        self.eoh = eoh
+        self.eoa = eoa
+        self.eosys = eosys
+        self.assistant = assistant
+        self.stop_words = stop_words
+
+    def decorate_prompt(self, prompt, sequence_start=True):
+        """Return the prompt that is concatenated with other elements in the
+        chat template.
+
+        Args:
+            prompt (str): user's input prompt
+            sequence_start (bool): indicator for the first round chat of a
+               session sequence
+        Returns:
+            str: the concatenated prompt
+        """
+        assert self.capability == 'chat', \
+            f'{type(self).__name__} has no capability of {self.capability}'
+        if sequence_start:
+            if self.meta_instruction is None:
+                return f'{self.user}{prompt}{self.eoh}' \
+                   f'{self.assistant}'
+            return f'{self.system}{self.meta_instruction}{self.eosys}' \
+                   f'{self.user}{prompt}{self.eoh}' \
+                   f'{self.assistant}'
+        else:
+            return f'{self.user}{prompt}{self.eoh}' \
+                   f'{self.assistant}'
+
+    def messages2prompt(self, messages, sequence_start=True):
+        """Return the prompt that is concatenated with other elements in the
+        chat template.
+
+        Args:
+            messages (str | List): user's input prompt
+        Returns:
+            str: the concatenated prompt
+        """
+
+        if isinstance(messages, str):
+            return self.get_prompt(messages, sequence_start)
+        eox_map = dict(user=self.eoh, assistant=self.eoa, system=self.eosys)
+        ret = ''
+        if self.meta_instruction:
+            ret += f'{self.system}{self.meta_instruction}{self.eosys}'
+
+        for message in messages:
+            role = message['role']
+            content = message['content']
+            ret += f'{eval(f"self.{role}")}{content}{eox_map[role]}'
+        ret += f'{self.assistant}'
+        return ret
+
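# A sketch of the ChatML-style prompt the Yi template above produces when
# meta_instruction stays None (the system block is then omitted entirely):
yi = Yi()
prompt = yi.messages2prompt([{'role': 'user', 'content': 'Hi'}])
assert prompt == '<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n'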
+
 def main(model_name: str = 'test'):
     assert model_name in MODELS.module_dict.keys(), \
         f"'{model_name}' is not supported. " \
@@ -637,4 +851,5 @@ def main(model_name: str = 'test'):
 
 if __name__ == '__main__':
     import fire
+
     fire.Fire(main)
diff --git a/lmdeploy/pytorch/block.py b/lmdeploy/pytorch/block.py
index 744f5bddfa..8a136132ff 100644
--- a/lmdeploy/pytorch/block.py
+++ b/lmdeploy/pytorch/block.py
@@ -2,6 +2,8 @@
 # modify from: https://github.com/vllm-project/vllm
 from dataclasses import dataclass
 
+import numpy as np
+
 
 class LogicalTokenBlock:
     """Logical block used to count tokens per block."""
@@ -30,6 +32,110 @@ def append_tokens(self, num_tokens: int = 1):
         self.num_tokens += num_tokens
 
 
+def _div_up(x, n):
+    """perform div up."""
+    return (x + n - 1) // n
+
+
+def _round_up(x, n):
+    """perform round up."""
+    return _div_up(x, n) * n
+
+
+class LogicalTokenBlocks:
+    """Logical blocks."""
+    ALLOC_SIZE = 128
+
+    def __init__(self, block_size: int):
+        self._block_size = block_size
+        reserve_size = _round_up(block_size, self.ALLOC_SIZE)
+        self._blocks = np.zeros((reserve_size, ), dtype=np.int64)
+        self._last_block_size = 0
+        self._num_real = 0
+
+    def reserve(self, size: int):
+        """reserve cache size."""
+        num_blocks = self._blocks.size
+        if num_blocks >= size:
+            return
+        reserve_size = _round_up(size - num_blocks, self.ALLOC_SIZE)
+        self._blocks = np.pad(self._blocks, (0, reserve_size))
+
+    def __setitem__(self, *args, **kwargs):
+        """set values."""
+        return self.get_real_blocks().__setitem__(*args, **kwargs)
+
+    def __getitem__(self, *args, **kwargs):
+        """get values."""
+        return self.get_real_blocks().__getitem__(*args, **kwargs)
+
+    def get_real_blocks(self):
+        """get logical blocks."""
+        return self._blocks[:self._num_real]
+
+    def append(self, blocks: np.ndarray):
+        """append blocks."""
+        num_blocks = len(blocks)
+        self.reserve(num_blocks + self._num_real)
+        slice_start = self._num_real
+        slice_end = slice_start + num_blocks
+        self._num_real += num_blocks
+        self.__setitem__(slice(slice_start, slice_end), blocks)
+
+    def num_required_blocks(self, num_tokens: int):
+        """get num required blocks."""
+        if self._last_block_size == 0:
+            remain_tokens = num_tokens
+        else:
+            next_block_size = min(num_tokens,
+                                  self._block_size - self._last_block_size)
+            remain_tokens = num_tokens - next_block_size
+        return _div_up(remain_tokens, self._block_size)
+
+    def add_tokens(self, num_tokens: int):
+        """add tokens."""
+        total_tokens = self.num_tokens() + num_tokens
+        self._last_block_size = total_tokens % self._block_size
+        if self._last_block_size == 0:
+            self._last_block_size = self._block_size
+
+    def num_tokens(self):
+        """get num tokens."""
+        return max(
+            0, self._num_real - 1) * self._block_size + self._last_block_size
+
+    def __len__(self):
+        """get length."""
+        return self._num_real
+
+    def reshape_by_tokens(self, num_tokens: int):
+        """resize logical blocks by num tokens."""
+        assert num_tokens <= self.num_tokens()
+        self._num_real = _div_up(num_tokens, self._block_size)
+        self._last_block_size = num_tokens % self._block_size
+        if self._last_block_size == 0:
+            self._last_block_size = self._block_size
+
+    def reset(self):
+        """reset."""
+        self.reshape_by_tokens(0)
+
+    def get_block_size(self):
+        """get block size."""
+        return self._block_size
+
+    def last_block_size(self):
+        """get last block size."""
+        return self._last_block_size
+
+    def clone(self):
+        """clone logical blocks."""
+        ret = LogicalTokenBlocks(self.get_block_size())
+        ret.append(self[:])
+        ret.add_tokens(self.num_tokens())
+        return ret
+
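# A small bookkeeping sketch for LogicalTokenBlocks; in the scheduler the
# block ids come from the allocator, the values below are arbitrary:
import numpy as np

blocks = LogicalTokenBlocks(block_size=4)
assert blocks.num_required_blocks(6) == 2   # 6 tokens span two 4-token blocks
blocks.append(np.array([10, 11]))           # record the allocated block ids
blocks.add_tokens(6)
assert blocks.num_tokens() == 6
assert blocks.last_block_size() == 2        # 2 tokens live in the last block
assert blocks.num_required_blocks(2) == 0   # 2 more tokens still fit there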
+
 @dataclass
 class PhysicalTokenBlock:
     """Physical block used to schedule key value cache."""
diff --git a/lmdeploy/pytorch/chat.py b/lmdeploy/pytorch/chat.py
index 087a92fcc0..1017fda64b 100644
--- a/lmdeploy/pytorch/chat.py
+++ b/lmdeploy/pytorch/chat.py
@@ -53,7 +53,7 @@ def main(
         stream_output (bool): indicator for streaming output or not
     """
     # tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
-    tokenizer = Tokenizer(model_path, trust_remote_code)
+    tokenizer = Tokenizer(model_path)
     tm_model = tm.Engine(model_path,
                          tp=tp,
                          trust_remote_code=trust_remote_code)
@@ -100,10 +100,10 @@ def main(
                     sampling_param=sampling_param):
                 status, res, tokens = outputs
                 # decode res
-                response = tokenizer.decode(res)[response_size:]
+                response = tokenizer.decode(res, offset=response_size)
                 response = valid_str(response)
                 print(f'{response}', end='', flush=True)
-                response_size += len(response)
+                response_size = tokens
 
             # update step
             step += len(input_ids) + tokens
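# The decode-with-offset pattern used above, as a sketch: `offset` is the
# number of tokens already surfaced, so Tokenizer.decode returns only the new
# text and response_size tracks tokens rather than characters; `outputs` here
# stands in for the streaming generator used in main() above.
response_size = 0
for status, res, tokens in outputs:
    print(tokenizer.decode(res, offset=response_size), end='', flush=True)
    response_size = tokens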
diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py
index ef31bf7af4..09c53462fb 100644
--- a/lmdeploy/pytorch/config.py
+++ b/lmdeploy/pytorch/config.py
@@ -10,6 +10,7 @@ class SchedulerConfig:
     max_session_len: int
     max_request_output_len: int
     eviction_type: str = 'copy'
+    prefill_interval: int = 16
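# A configuration sketch mirroring the fallback SchedulerConfig that
# Engine.__init__ builds later in this patch, with the new prefill_interval
# field spelled out at its default; prefill_interval presumably bounds how
# many decoding steps may run between two prefill steps:
scheduler_config = SchedulerConfig(max_batches=128,
                                   max_session_len=4096,
                                   max_request_output_len=512,
                                   eviction_type='recompute',
                                   prefill_interval=16)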
 
 
 @dataclass
diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index 5a6dbd10d3..c9d00a1fe0 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -8,6 +8,7 @@
 import torch
 from transformers import AutoConfig
 
+from lmdeploy.tokenizer import Tokenizer
 from lmdeploy.utils import get_logger
 
 from ..config import CacheConfig, ModelConfig, SchedulerConfig
@@ -126,6 +127,14 @@ def _build_model_agent(model_path: str,
     return model_agent
 
 
+def _tensorlize_block_offsets(block_offsets):
+    """tensorlize block_offsets."""
+    from torch.nn.utils.rnn import pad_sequence
+    block_offsets = [torch.from_numpy(off) for off in block_offsets]
+    block_offsets = pad_sequence(block_offsets, batch_first=True)
+    return block_offsets
+
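# What _tensorlize_block_offsets produces on a toy input: ragged per-sequence
# block-id arrays are stacked into one right-padded LongTensor.
import numpy as np
offsets = _tensorlize_block_offsets([np.array([3, 7, 9]), np.array([5])])
# tensor([[3, 7, 9],
#         [5, 0, 0]])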
+
 class Engine:
     """The inference engine of lmdeploy pytorch.
 
@@ -141,14 +150,20 @@ def __init__(self,
                  scheduler_config: SchedulerConfig = None,
                  cache_config: CacheConfig = None,
                  tp: int = 1,
+                 model_name: str = None,
                  trust_remote_code=True) -> None:
 
         self.tp = tp
         self.gpu_count = tp
+        self.model_name = model_name
 
         scheduler_config = scheduler_config or SchedulerConfig(
-            max_batches=64, max_session_len=4096, max_request_output_len=512)
+            max_batches=128,
+            max_session_len=4096,
+            max_request_output_len=512,
+            eviction_type='recompute')
 
+        # block_size = 1 to enable unified paging
         cache_config = cache_config or CacheConfig(
             block_size=64, num_cpu_blocks=0, num_gpu_blocks=0)
 
@@ -182,6 +197,18 @@ def __init__(self,
         # create main thread
         self._start_loop()
 
+        self._create_buffers()
+        self.tokenizer = Tokenizer(model_path)
+
+    def _create_buffers(self):
+        scheduler_config = self.scheduler_config
+        max_batches = scheduler_config.max_batches
+
+        # buffers to create inputs
+        self._q_start_loc_buf = torch.arange(max_batches)
+        self._attention_mask_buf = torch.ones(max_batches, 1, dtype=torch.long)
+        self._seq_length_buf = torch.ones(max_batches, dtype=torch.long)
+
     def _bind_request_manager(self):
         """bind request manager."""
         req_manager = RequestManager()
@@ -218,12 +245,12 @@ def _on_stop_session(self, reqs: Request, **kwargs):
             resp_type = ResponseType.SESSION_NOT_EXIST
             if session_id in self.scheduler.sessions:
                 self.scheduler.stop_session(session_id)
-                self.scheduler.update()
                 resp_type = ResponseType.SUCCESS
             self.req_manager.response(
                 Response(type=resp_type,
                          sender_id=req.sender_id,
                          req_id=req.req_id))
+        self.scheduler.update()
 
     def _on_end_session(self, reqs: Request, **kwargs):
         """on end session callback."""
@@ -232,12 +259,12 @@ def _on_end_session(self, reqs: Request, **kwargs):
             resp_type = ResponseType.SESSION_NOT_EXIST
             if session_id in self.scheduler.sessions:
                 self.scheduler.end_session(session_id)
-                self.scheduler.update()
                 resp_type = ResponseType.SUCCESS
             self.req_manager.response(
                 Response(type=resp_type,
                          sender_id=req.sender_id,
                          req_id=req.req_id))
+        self.scheduler.update()
 
     def _on_add_message(self, reqs: Request, **kwargs):
         """on add message callback."""
@@ -265,10 +292,10 @@ def _on_add_message(self, reqs: Request, **kwargs):
                 msg.remain_output_len = req.data['max_request_output_len']
                 msg.sampling_param = req.data['sampling_param']
                 msg.status = MessageStatus.WAITING
-                self.scheduler.update()
 
             msg.sender_id = req.sender_id
             msg.req_id = req.req_id
+        self.scheduler.update()
 
     def create_instance(self, cuda_stream_id=0):
         """Create a turbomind instance.
@@ -310,14 +337,12 @@ def end_session(self, session_id: int):
                                       f'Error: {resp.type}.')):
             self.owned_sessions.remove(session_id)
 
-    def create_model_inputs(self,
-                            messages: List[SchedulerSequence],
-                            device: str = 'cuda'):
+    @torch.inference_mode()
+    def create_model_inputs(self, messages: List[SchedulerSequence]):
         """create model inputs from messages.
 
         Args:
             messages (List[SchedulerSequence]): The input messages.
-            device (str): Device name.
         """
         history_lengths = [msg.history_len for msg in messages]
 
@@ -329,39 +354,29 @@ def create_model_inputs(self,
             token_ids = [token_ids]
 
         batch_size = len(messages)
-        input_ids = token_ids
-        input_ids = torch.cat(input_ids).to(device)
+        input_ids = torch.cat(token_ids)
 
         is_decoding = input_ids.size(0) == batch_size
         if not is_decoding:
             seq_length = [tokens.size(0) for tokens in token_ids]
+            seq_length = torch.tensor(seq_length, dtype=torch.long)
             max_seq_len = max(seq_length)
-            q_start_loc = torch.tensor([0] +
-                                       seq_length).cumsum(0)[:-1].to(device)
-
-            attention_mask = torch.tensor([
-                seq_len * [1] + (max_seq_len - seq_len) * [0]
-                for seq_len in seq_length
-            ]).to(device)
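+            # build start locations and the padding mask with tensor ops
+            # instead of per-sequence python loops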
+            q_start_loc = seq_length.cumsum(0) - seq_length
+            mask_range = torch.arange(max_seq_len)[None, :]
+            attention_mask = (mask_range < seq_length[:, None]).long()
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids += position_ids.new_tensor(history_lengths).unsqueeze(
                 -1)
-            seq_length = torch.tensor(seq_length).to(device)
         else:
-            q_start_loc = torch.arange(batch_size, device=device)
-            attention_mask = torch.ones(batch_size,
-                                        1,
-                                        dtype=torch.long,
-                                        device=device)
+            q_start_loc = self._q_start_loc_buf[:batch_size]
+            attention_mask = self._attention_mask_buf[:batch_size]
+            seq_length = self._seq_length_buf[:batch_size]
             position_ids = q_start_loc.new_tensor(history_lengths).unsqueeze(
                 -1)
-            seq_length = torch.ones(batch_size,
-                                    dtype=torch.long,
-                                    device=device)
 
-        block_tables = self.scheduler.get_block_tables(messages)
-        block_offsets = [[block.block_id for block in block_table]
-                         for block_table in block_tables]
+        # TODO: getting block offsets is slow when block_size = 1
+        block_offsets = self.scheduler.get_block_tables(messages)
+        block_offsets = _tensorlize_block_offsets(block_offsets)
 
         # add batch dim [bs=1, seq_len]
         if input_ids.ndim == 1:
@@ -401,7 +416,7 @@ def _check_request_len(msg):
             return msg.remain_output_len <= 0
 
         def _check_session_len(msg, max_session_len):
-            session_len = sum(block.num_tokens for block in msg.logical_blocks)
+            session_len = msg.logical_blocks.num_tokens()
             return session_len >= max_session_len
 
         sampling_param = msg.sampling_param
@@ -445,8 +460,7 @@ def _sampling(grouped_params, split_logits, inputs):
                 new_logits = split_logits[idx]
                 new_logits = logits_processor(input_ids, new_logits)
                 argmax_ids = new_logits.argmax(-1).cpu()
-                for i, next_ids in zip(idx, argmax_ids):
-                    next_token_ids[i] = next_ids
+                next_token_ids[idx] = argmax_ids
             return next_token_ids
 
         logits = logits.cuda()
@@ -476,7 +490,7 @@ def update_running(self, running: List[SchedulerSequence],
             if self._stoping_criteria(msg, token):
                 msg.status = MessageStatus.STOPPED
 
-    def step(self, return_logits=False):
+    def step(self, is_prefill: bool, return_logits: bool = False):
         """one step inference. Used to perform streaming chat.
 
         Args:
@@ -487,7 +501,7 @@ def step(self, return_logits=False):
         """
 
         # schedule
-        schedule_output = self.scheduler.schedule()
+        schedule_output = self.scheduler.schedule(is_prefill=is_prefill)
 
         running: List[SchedulerSequence] = schedule_output.running
         swap_in_map = schedule_output.swap_in_map
@@ -520,7 +534,7 @@ def step(self, return_logits=False):
                 sender_id=msg.sender_id,
                 req_id=msg.req_id,
                 finish=(msg.status == MessageStatus.STOPPED),
-                token_ids=[next_id],
+                token_ids=[next_id.item()],
             )
             outputs[session_id] = out
 
@@ -638,7 +652,7 @@ def decode(self, prompt_token_ids: List[List[int]]):
             msgs.append(msg)
             self.scheduler.add_sequence(msg)
 
-        outputs = self.step(True)
+        outputs = self.step(return_logits=True)
 
         logits = dict((k, out.logits) for k, out in outputs.items())
 
@@ -659,6 +673,7 @@ def loop(self):
         send_resp_que = Queue()
 
         def _send_resp():
+            """send response callback."""
             while True:
                 step_tokens = send_resp_que.get()
                 for _, out in step_tokens.items():
@@ -676,6 +691,8 @@ def _send_resp():
 
         send_thread = Thread(target=_send_resp, daemon=True)
         send_thread.start()
+        prefill_interval = self.scheduler_config.prefill_interval
+        prefill_counter = prefill_interval
 
         while True:
             if not self.req_manager.has_requests(
@@ -687,8 +704,14 @@ def _send_resp():
 
             # forward
             if self.scheduler.has_unfinished():
-                with torch.no_grad():
-                    step_tokens: Dict[int, InferOutput] = self.step()
+                has_running = self.scheduler.has_running()
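+                # prefill when the interval counter reaches zero or when
+                # no sequence is currently running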
+                is_prefill = not prefill_counter or not has_running
+                if is_prefill:
+                    prefill_counter = prefill_interval
+                with torch.inference_mode():
+                    step_tokens: Dict[int, InferOutput] = self.step(
+                        is_prefill=is_prefill)
+                prefill_counter -= 1
 
                 # send response
                 send_resp_que.put(step_tokens)
diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index 844a70b13d..cb4dda8c3d 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -36,7 +36,7 @@ def _update_cache_config(model_config: ModelConfig,
         gpu_id (int): The GPU id to use.
     """
     GPU_MEM_PERCENT = 0.7
-    SWAP_SPACE = 4 * (1 << 30)
+    SWAP_SPACE = 8 * (1 << 30)
     gpu_mem_physical_free, _ = get_gpu_memory(gpu_id)
     gpu_mem = gpu_mem_physical_free * GPU_MEM_PERCENT
     cpu_mem = SWAP_SPACE
@@ -74,6 +74,17 @@ class ModelInputs:
     is_decoding: bool
     meta: Any
 
+    def to_device(self, device: str):
+        """to device."""
+        input_dict = asdict(self)
+        out_dict = dict()
+        for k, v in input_dict.items():
+            if isinstance(v, torch.Tensor):
+                v = v.to(device)
+            out_dict[k] = v
+
+        return ModelInputs(**out_dict)
+
 
 class StepContext:
     """context of Model.
@@ -95,7 +106,7 @@ def __init__(
         json_config: dict = None,
     ):
         self.inputs = inputs
-        self.block_offsets_list = inputs.block_offsets
+        self.block_offsets = inputs.block_offsets
         self.position_ids = inputs.position_ids
         self.q_start_loc = inputs.q_start_loc
         self.history_lengths = inputs.history_lengths
@@ -107,8 +118,6 @@ def __init__(
         # seq_len + history_length
         self.kv_seq_length = self.position_ids[..., -1] + 1
 
-        self.block_offsets = self.tensorlize_block_offsets(
-            self.block_offsets_list, device)
         self.position_ids_1d = self.get_position_ids_1d(
             self.position_ids, self.seq_length, device)
 
@@ -117,15 +126,18 @@ def __init__(
     @classmethod
     def tensorlize_block_offsets(cls, block_offsets, device):
         """tensorlize block_offsets."""
-        # padding zero
-        # torch.nn.utils.rnn.pad_sequence is slower than manually concate
+        import numpy as np
         offset_len = [len(offset) for offset in block_offsets]
         max_offsets_len = max(offset_len)
-        pad_block_offsets = [
-            offset + [0] * (max_offsets_len - off_len)
-            for offset, off_len in zip(block_offsets, offset_len)
-        ]
-        block_offsets = torch.tensor(pad_block_offsets).to(device)
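+        # pad the ragged offsets with zeros into a dense (batch, max_len)
+        # numpy array, then move it to the device in a single copy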
+        batch_size = len(offset_len)
+        pad_block_offsets = np.zeros((batch_size, max_offsets_len),
+                                     dtype=np.int64)
+
+        for pad_offset, offset, off_len in zip(pad_block_offsets,
+                                               block_offsets, offset_len):
+            pad_offset[:off_len] = offset
+
+        block_offsets = torch.from_numpy(pad_block_offsets).to(device)
         return block_offsets
 
     @classmethod
@@ -185,8 +197,9 @@ def model_forward(
 ):
     """perform model forward."""
     stream = stream or torch.cuda.current_stream()
-    with torch.no_grad(), torch.cuda.stream(stream):
+    with torch.inference_mode(), torch.cuda.stream(stream):
         # forward
+        inputs = inputs.to_device('cuda')
         context = StepContext(
             inputs=inputs,
             world_size=world_size,
@@ -319,6 +332,29 @@ def _tp_build_model(
     patched_model = None
     cache_engine = None
 
+    def _broadcast_config(cache_config):
+        """broadcast cache config, use minimum cache."""
+        if rank == 0:
+            gathered_configs = [None] * world_size
+            dist.gather_object(cache_config, gathered_configs)
+            num_gpu_blocks_list = [
+                config.num_gpu_blocks for config in gathered_configs
+            ]
+            num_cpu_blocks_list = [
+                config.num_cpu_blocks for config in gathered_configs
+            ]
+            min_num_gpu_blocks = min(num_gpu_blocks_list)
+            min_num_cpu_blocks = min(num_cpu_blocks_list)
+            cache_config.num_cpu_blocks = min_num_cpu_blocks
+            cache_config.num_gpu_blocks = min_num_gpu_blocks
+            config_list = [cache_config]
+        else:
+            gathered_configs = None
+            dist.gather_object(cache_config, gathered_configs)
+            config_list = [None]
+        dist.broadcast_object_list(config_list)
+        return config_list[0]
+
     try:
         config = AutoConfig.from_pretrained(
             model_path, trust_remote_code=trust_remote_code)
@@ -341,6 +377,7 @@ def _tp_build_model(
         )
 
         _update_cache_config(model_config, cache_config)
+        cache_config = _broadcast_config(cache_config)
         cache_engine = CacheEngine(cache_config,
                                    model_config,
                                    rank=rank,
diff --git a/lmdeploy/pytorch/engine/request.py b/lmdeploy/pytorch/engine/request.py
index eeb6875ed7..ad1d154fa4 100644
--- a/lmdeploy/pytorch/engine/request.py
+++ b/lmdeploy/pytorch/engine/request.py
@@ -67,6 +67,29 @@ def __init__(self, sender_id: int, req_que: Queue):
         self.resp_que = Queue()
         self.resp_dict = dict()
 
+    def _push_resp(self, req_id: int, resp: Response):
+        """push response."""
+        self.resp_dict.setdefault(req_id, [])
+        self.resp_dict[req_id].append(resp)
+
+    def _pop_resp(self, req_id: int, default: Any = None):
+        """pop response."""
+        if req_id not in self.resp_dict:
+            return default
+        resps = self.resp_dict[req_id]
+        ret = resps.pop(0)
+        if len(resps) == 0:
+            self.resp_dict.pop(req_id)
+        return ret
+
+    def _prefetch_resps(self):
+        """prefetch from resp que."""
+        num_resps = self.resp_que.qsize()
+        for _ in range(num_resps):
+            resp: Response = self.resp_que.get()
+            req_id = resp.req_id
+            self._push_resp(req_id, resp)
+
     def batched_send_async(self, req_types: List[RequestType],
                            data: List[Any]) -> List[int]:
         """Batched send request asynchronize."""
@@ -95,31 +118,33 @@ def send_async(self, req_type: RequestType, data: Any) -> int:
     def recv_any(self, que_timeout: float = None) -> Response:
         """receive any response."""
         # check resp dict
-        for req_id, resps in self.resp_dict.items():
-            ret = resps.pop(0)
-            if len(resps) == 0:
-                self.resp_dict.pop(req_id)
-            return ret
+        self._prefetch_resps()
+        for req_id in self.resp_dict:
+            ret = self._pop_resp(req_id, default=None)
+            if ret is not None:
+                return ret
 
         # check resp que
         return self.resp_que.get(timeout=que_timeout)
 
+    def recv_all(self, req_id: int):
+        """revceive all response with req_id."""
+        self._prefetch_resps()
+        resps = self.resp_dict.pop(req_id, [])
+        return resps
+
     def recv(self, req_id: int, que_timeout: float = None) -> Response:
         """receive response of given request id."""
         # check resp dict
-        if req_id in self.resp_dict:
-            resps = self.resp_dict[req_id]
-            ret = resps.pop(0)
-            if len(resps) == 0:
-                self.resp_dict.pop(req_id)
+        ret = self._pop_resp(req_id, default=None)
+        if ret is not None:
             return ret
 
         # check resp que
         while True:
             resp: Response = self.resp_que.get(timeout=que_timeout)
             if resp.req_id != req_id:
-                self.resp_dict.setdefault(req_id, [])
-                self.resp_dict[req_id].append(resp)
+                self._push_resp(req_id, resp)
             else:
                 return resp
 
diff --git a/lmdeploy/pytorch/kernels/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/apply_rotary_pos_emb.py
index d59022f649..f2dd3d92bc 100644
--- a/lmdeploy/pytorch/kernels/apply_rotary_pos_emb.py
+++ b/lmdeploy/pytorch/kernels/apply_rotary_pos_emb.py
@@ -107,6 +107,7 @@ def apply_rotary_pos_emb_qk_kernel(
     tl.store(K_EMB + k_offset_h, k_emb_h, mask=pos_mask)
 
 
+@torch.inference_mode()
 def apply_rotary_pos_emb(q: Tensor,
                          k: Tensor,
                          cos: Tensor,
diff --git a/lmdeploy/pytorch/kernels/fill_kv_cache.py b/lmdeploy/pytorch/kernels/fill_kv_cache.py
index 69d408d846..d0f72558af 100644
--- a/lmdeploy/pytorch/kernels/fill_kv_cache.py
+++ b/lmdeploy/pytorch/kernels/fill_kv_cache.py
@@ -59,29 +59,45 @@ def _fill_kv_cache_kernel(
         tl.store(vc_ptrs + idx, vs, mask=mask)
 
 
-def _create_fill_cache_info(is_decoding: bool, block_size: int,
-                            start_loc: Tensor, seq_length: Tensor,
-                            block_offsets: Tensor, history_lengths: Sequence,
-                            device: torch.device):
-    """create information for cache filling.
-
-    There are 4 tensor that we need to generate. Each one has shape (N) where
-    N is the number of blocks that need to fill data.
-    1. state_start: the token offset where we copy data from.
-    2. cache_start: the token offset in block that we want to copy to.
-    3. state_len: how many data (tokens) we want to copy
-    4. block_offset1d: which block we want to perform the filling.
-    """
-    batch_size = block_offsets.size(0)
-
+def _prefilling_cache_info(block_size: int, seq_length: Tensor,
+                           block_offsets: Tensor, history_lengths: Sequence,
+                           device: torch.device):
     # make sure history lengths is a tensor
     if not isinstance(history_lengths, Tensor):
         history_lengths = torch.tensor(history_lengths, device=device)
 
-    first_block_ids = history_lengths // block_size
-    token_ids_start = history_lengths % block_size
-    if not is_decoding:
-        # prefilling
+    def _unified_cache_info():
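+        # with block_size == 1 every token owns a block: copy one token
+        # per block, always starting at cache offset 0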
+        total_blocks = seq_length.sum().item()
+
+        # cache start
+        cache_start = torch.zeros((total_blocks, ),
+                                  dtype=torch.long,
+                                  device=device)
+
+        state_len = torch.ones((total_blocks, ),
+                               dtype=torch.long,
+                               device=device)
+
+        # state_start (0~cumed state len)
+        cum_state_len = state_len.cumsum(0)
+        state_start = torch.cat(
+            [state_len.new_zeros((1, )), cum_state_len[:-1]])
+
+        # block offsets
+        block_offsets1d = [
+            offs[first:last]
+            for offs, first, last in zip(block_offsets, history_lengths.cpu(),
+                                         (seq_length + history_lengths).cpu())
+        ]
+        block_offsets1d = torch.cat(block_offsets1d)
+        return dict(state_start=state_start,
+                    state_len=state_len,
+                    cache_start=cache_start,
+                    block_offsets1d=block_offsets1d)
+
+    def _default_cache_info():
+        first_block_ids = history_lengths // block_size
+        token_ids_start = history_lengths % block_size
 
         # initialize
         kv_seq_length = history_lengths + seq_length
@@ -119,23 +135,58 @@ def _create_fill_cache_info(is_decoding: bool, block_size: int,
                 block_offsets, first_block_ids.cpu(), last_block_ids.cpu())
         ]
         block_offsets1d = torch.cat(block_offsets1d)
+        return dict(state_start=state_start,
+                    state_len=state_len,
+                    cache_start=cache_start,
+                    block_offsets1d=block_offsets1d)
+
+    if block_size == 1:
+        return _unified_cache_info()
     else:
-        # decoding
-        state_start = start_loc
-        state_len = seq_length
-        cache_start = token_ids_start
-        batch_ids = torch.arange(batch_size, device=device)
-        block_offsets1d = block_offsets[batch_ids, first_block_ids]
+        return _default_cache_info()
+
+
+def _decoding_cache_info(block_size: int, start_loc: Tensor,
+                         seq_length: Tensor, block_offsets: Tensor,
+                         history_lengths: Sequence, device: torch.device):
+    batch_size = block_offsets.size(0)
+
+    # make sure history_lengths is a tensor
+    if not isinstance(history_lengths, Tensor):
+        history_lengths = torch.tensor(history_lengths, device=device)
+
+    first_block_ids = history_lengths // block_size
+    token_ids_start = history_lengths % block_size
+
+    batch_ids = torch.arange(batch_size, device=device)
+    return dict(state_start=start_loc,
+                state_len=seq_length,
+                cache_start=token_ids_start,
+                block_offsets1d=block_offsets[batch_ids, first_block_ids])
+
 
-    fill_cache_info = dict()
-    fill_cache_info['state_start'] = state_start
-    fill_cache_info['state_len'] = state_len
-    fill_cache_info['cache_start'] = cache_start
-    fill_cache_info['block_offsets1d'] = block_offsets1d
+def _create_fill_cache_info(is_decoding: bool, block_size: int,
+                            start_loc: Tensor, seq_length: Tensor,
+                            block_offsets: Tensor, history_lengths: Sequence,
+                            device: torch.device):
+    """create information for cache filling.
 
-    return fill_cache_info
+    There are 4 tensors that we need to generate. Each one has shape (N), where
+    N is the number of blocks that need to be filled.
+    1. state_start: the token offset we copy data from.
+    2. cache_start: the token offset in the block we copy data to.
+    3. state_len: how many tokens we want to copy.
+    4. block_offsets1d: the block in which the filling is performed.
+    """
+    if not is_decoding:
+        return _prefilling_cache_info(block_size, seq_length, block_offsets,
+                                      history_lengths, device)
+    else:
+        return _decoding_cache_info(block_size, start_loc, seq_length,
+                                    block_offsets, history_lengths, device)
 
 
+@torch.inference_mode()
 def fill_kv_cache(k_states: Tensor,
                   v_states: Tensor,
                   k_caches: Tensor,
@@ -162,6 +213,14 @@ def fill_kv_cache(k_states: Tensor,
         history_lengths (Sequence): The history lengths of each data in batch.
         context (Any): Context object of current step.
     """
+
+    def _kernel_meta():
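+        # collect the launch metadata (device and current cuda stream)
+        # expected by the triton kernel call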
+        device = k_states.device
+        device_idx = device.index
+        device_type = device.type
+        stream = get_cuda_stream(device_idx)
+        return dict(device=device, device_type=device_type, stream=stream)
+
     fill_cache_info = getattr(context, 'fill_cache_info', None)
 
     if fill_cache_info is None:
@@ -187,28 +246,21 @@ def fill_kv_cache(k_states: Tensor,
     BLOCK_M = k_caches.size(-3)
     BLOCK_N = min(128, k_caches.stride(-3), v_caches.stride(-3))
 
-    device = k_states.device
-    device_idx = device.index
-    device_type = device.type
-    stream = get_cuda_stream(device_idx)
-    _fill_kv_cache_kernel[grid](
-        k_states,
-        v_states,
-        k_caches,
-        v_caches,
-        state_start=state_start,
-        state_len=state_len,
-        cache_start=cache_start,
-        block_offsets1d=block_offsets1d,
-        stride_kss=k_states.stride(-3),
-        stride_vss=v_states.stride(-3),
-        stride_kcs=k_caches.stride(-3),
-        stride_vcs=v_caches.stride(-3),
-        BLOCK_M=BLOCK_M,
-        BLOCK_N=BLOCK_N,
-        num_warps=4,
-        num_stages=1,
-        stream=stream,
-        device=device_idx,
-        device_type=device_type,
-    )
+    kernel_meta = _kernel_meta()
+    _fill_kv_cache_kernel[grid](k_states,
+                                v_states,
+                                k_caches,
+                                v_caches,
+                                state_start=state_start,
+                                state_len=state_len,
+                                cache_start=cache_start,
+                                block_offsets1d=block_offsets1d,
+                                stride_kss=k_states.stride(-3),
+                                stride_vss=v_states.stride(-3),
+                                stride_kcs=k_caches.stride(-3),
+                                stride_vcs=v_caches.stride(-3),
+                                BLOCK_M=BLOCK_M,
+                                BLOCK_N=BLOCK_N,
+                                num_warps=4,
+                                num_stages=1,
+                                **kernel_meta)
diff --git a/lmdeploy/pytorch/kernels/mbgmm.py b/lmdeploy/pytorch/kernels/mbgmm.py
new file mode 100644
index 0000000000..206deeb50d
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/mbgmm.py
@@ -0,0 +1,267 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import triton
+import triton.language as tl
+from torch import Tensor
+
+
+def _next_pow_of_2(x):
+    """get next power of 2."""
+    return 1 << (x - 1).bit_length()
+
+
+@triton.jit
+def _x_a_mm_kernel(
+    X,
+    LoRA_A,
+    XA,
+    B_start_loc,
+    B_seq_lens,
+    B_rank_id,
+    Rank_page_table,
+    Ranks,
+    stride_xs,
+    stride_xh,
+    stride_las,
+    stride_lah,
+    stride_xas,
+    stride_xar,
+    stride_ptb,
+    BLOCK_M: tl.constexpr,
+    BLOCK_R: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+):
+    """xa mm kernel."""
+    cur_batch = tl.program_id(0)
+    start_m = tl.program_id(1)
+
+    r_off = tl.arange(0, BLOCK_R)
+
+    seq_len = tl.load(B_seq_lens + cur_batch)
+    if start_m * BLOCK_M >= seq_len:
+        return
+
+    start_loc = tl.load(B_start_loc + cur_batch)
+    rank_id = tl.load(B_rank_id + cur_batch)
+    rank = tl.load(Ranks + rank_id)
+
+    page_table_off = rank_id * stride_ptb + r_off
+    rank_mask = r_off < rank
+    page_table = tl.load(Rank_page_table + page_table_off, mask=rank_mask)
+
+    m_off = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    r_off = tl.arange(0, BLOCK_R)
+    dm_off = tl.arange(0, BLOCK_DMODEL)
+    rank_mask = r_off < rank
+
+    x_off = (start_loc + m_off) * stride_xs
+    xs_mask = m_off < seq_len
+    la_page_off = page_table * stride_las
+    acc = tl.zeros((BLOCK_M, BLOCK_R), dtype=tl.float32)
+
+    # compute acc
+    for start_h in range(0, BLOCK_H, BLOCK_DMODEL):
+        cur_dm_off = start_h + dm_off
+        h_mask = cur_dm_off < BLOCK_H
+
+        # load x
+        xh_off = cur_dm_off * stride_xh
+        x_mask = xs_mask[:, None] and h_mask[None, :]
+        x = tl.load(X + x_off[:, None] + xh_off[None, :],
+                    mask=x_mask,
+                    other=0.0)
+
+        # load lora a
+        lah_off = cur_dm_off * stride_lah
+        la_mask = rank_mask[None, :] and h_mask[:, None]
+        la = tl.load(LoRA_A + la_page_off[None, :] + lah_off[:, None],
+                     mask=la_mask,
+                     other=0.0)
+
+        # compute
+        acc += tl.dot(x, la)
+
+    acc = acc.to(X.dtype.element_ty)
+    xa_off = (start_loc + m_off) * stride_xas
+    xas_mask = xs_mask
+    xa_mask = xas_mask[:, None] and rank_mask[None, :]
+    tl.store(XA + xa_off[:, None] + r_off[None, :] * stride_xar,
+             acc,
+             mask=xa_mask)
+
+
+@triton.jit
+def _acc_b_mm_kernel(
+    XA,
+    LoRA_B,
+    Out,
+    B_start_loc,
+    B_seq_lens,
+    B_rank_id,
+    Rank_page_table,
+    Ranks,
+    stride_xas,
+    stride_xar,
+    stride_os,
+    stride_oh,
+    stride_lbs,
+    stride_lbh,
+    stride_ptb,
+    BLOCK_M: tl.constexpr,
+    BLOCK_R: tl.constexpr,
+    BLOCK_HO: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+):
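+    """acc b mm kernel."""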
+    cur_batch = tl.program_id(0)
+    start_m = tl.program_id(1)
+
+    r_off = tl.arange(0, BLOCK_R)
+
+    seq_len = tl.load(B_seq_lens + cur_batch)
+    if start_m * BLOCK_M >= seq_len:
+        return
+
+    start_loc = tl.load(B_start_loc + cur_batch)
+    rank_id = tl.load(B_rank_id + cur_batch)
+    rank = tl.load(Ranks + rank_id)
+
+    page_table_off = rank_id * stride_ptb + r_off
+    rank_mask = r_off < rank
+    page_table = tl.load(Rank_page_table + page_table_off, mask=rank_mask)
+
+    m_off = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    dm_off = tl.arange(0, BLOCK_DMODEL)
+    rank_mask = r_off < rank
+    lb_page_off = page_table * stride_lbs
+
+    xs_mask = m_off < seq_len
+    o_off = (start_loc + m_off) * stride_os
+    os_mask = xs_mask
+
+    xa_off = (start_loc + m_off) * stride_xas
+    xa_mask = xs_mask[:, None] and rank_mask[None, :]
+    acc = tl.load(XA + xa_off[:, None] + r_off[None, :] * stride_xar,
+                  mask=xa_mask,
+                  other=0.0)
+    acc = acc.to(LoRA_B.dtype.element_ty)
+
+    # compute output
+    for start_h in range(0, BLOCK_HO, BLOCK_DMODEL):
+        cur_dm_off = start_h + dm_off
+        h_mask = cur_dm_off < BLOCK_HO
+
+        # load lora b
+        lbh_off = cur_dm_off * stride_lbh
+        lb_mask = rank_mask[:, None] and h_mask[None, :]
+        lb = tl.load(LoRA_B + lb_page_off[:, None] + lbh_off[None, :],
+                     mask=lb_mask,
+                     other=0)
+
+        # compute
+        out = tl.dot(acc, lb)
+        out = out.to(lb.dtype)
+
+        # store o
+        oh_off = cur_dm_off * stride_oh
+        o_mask = os_mask[:, None] and h_mask[None, :]
+        tl.store(Out + o_off[:, None] + oh_off[None, :], out, mask=o_mask)
+
+
+@torch.inference_mode()
+def mbgmm_a(x: Tensor, lora_a: Tensor, b_start_loc: Tensor, b_seq_lens: Tensor,
+            b_rank_ids: Tensor, rank_page_table: Tensor, ranks: Tensor,
+            max_seq_len: int, max_rank: int):
+    """mbgmm_a."""
+    assert x.dim() == 2
+    assert lora_a.dim() == 2
+    assert rank_page_table.dim() == 2
+
+    head_size = x.size(-1)
+    batch_size = len(b_seq_lens)
+
+    BLOCK_M = 32
+    BLOCK_R = _next_pow_of_2(max_rank)
+    if BLOCK_R < 16:
+        BLOCK_R = 16
+    BLOCK_H = head_size
+    BLOCK_DMODEL = 64
+
+    num_warps = 4
+    grid = [batch_size, triton.cdiv(max_seq_len, BLOCK_M)]
+    xa = x.new_empty((x.size(0), BLOCK_R))
+    _x_a_mm_kernel[grid](
+        x,
+        lora_a,
+        xa,
+        b_start_loc,
+        b_seq_lens,
+        b_rank_ids,
+        Rank_page_table=rank_page_table,
+        Ranks=ranks,
+        stride_xs=x.stride(0),
+        stride_xh=x.stride(1),
+        stride_las=lora_a.stride(0),
+        stride_lah=lora_a.stride(1),
+        stride_xas=xa.stride(0),
+        stride_xar=xa.stride(1),
+        stride_ptb=rank_page_table.stride(0),
+        BLOCK_M=BLOCK_M,
+        BLOCK_R=BLOCK_R,
+        BLOCK_H=BLOCK_H,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return xa
+
+
+@torch.inference_mode()
+def mbgmm_b(xa: Tensor, lora_b: Tensor, b_start_loc: Tensor,
+            b_seq_lens: Tensor, b_rank_ids: Tensor, rank_page_table: Tensor,
+            ranks: Tensor, max_seq_len: int, max_rank: int):
+    """mbgmm_b."""
+
+    assert xa.dim() == 2
+    assert lora_b.dim() == 2
+    assert rank_page_table.dim() == 2
+
+    head_o_size = lora_b.size(-1)
+    batch_size = len(b_seq_lens)
+
+    BLOCK_M = 32
+    BLOCK_R = _next_pow_of_2(max_rank)
+    if BLOCK_R < 16:
+        BLOCK_R = 16
+    BLOCK_HO = head_o_size
+    BLOCK_DMODEL = 64
+
+    num_warps = 4
+    grid = [batch_size, triton.cdiv(max_seq_len, BLOCK_M)]
+    output = xa.new_empty((xa.size(0), BLOCK_HO))
+
+    _acc_b_mm_kernel[grid](
+        xa,
+        lora_b,
+        output,
+        b_start_loc,
+        b_seq_lens,
+        b_rank_ids,
+        Rank_page_table=rank_page_table,
+        Ranks=ranks,
+        stride_xas=xa.stride(0),
+        stride_xar=xa.stride(1),
+        stride_os=output.stride(0),
+        stride_oh=output.stride(1),
+        stride_lbs=lora_b.stride(0),
+        stride_lbh=lora_b.stride(1),
+        stride_ptb=rank_page_table.stride(0),
+        BLOCK_M=BLOCK_M,
+        BLOCK_R=BLOCK_R,
+        BLOCK_HO=BLOCK_HO,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+    return output
diff --git a/lmdeploy/pytorch/kernels/mbgmv.py b/lmdeploy/pytorch/kernels/mbgmv.py
new file mode 100644
index 0000000000..b45d2290ba
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/mbgmv.py
@@ -0,0 +1,223 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import triton
+import triton.language as tl
+from torch import Tensor
+
+
+def _next_pow_of_2(x):
+    """get next power of 2."""
+    return 1 << (x - 1).bit_length()
+
+
+@triton.jit
+def _x_a_mv_kernel(
+    X,
+    LoRA_A,
+    XA,
+    B_rank_id,
+    Rank_page_table,
+    Ranks,
+    stride_xs,
+    stride_xh,
+    stride_las,
+    stride_lah,
+    stride_xas,
+    stride_xar,
+    stride_ptb,
+    BLOCK_R: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+):
+    """xa mv kernel."""
+    cur_batch = tl.program_id(0)
+
+    r_off = tl.arange(0, BLOCK_R)
+    rank_id = tl.load(B_rank_id + cur_batch)
+    rank = tl.load(Ranks + rank_id)
+
+    page_table_off = rank_id * stride_ptb + r_off
+    rank_mask = r_off < rank
+    page_table = tl.load(Rank_page_table + page_table_off, mask=rank_mask)
+
+    dm_off = tl.arange(0, BLOCK_DMODEL)
+    rank_mask = r_off < rank
+
+    x_off = cur_batch * stride_xs
+    la_page_off = page_table * stride_las
+    acc = tl.zeros((BLOCK_R, ), dtype=tl.float32)
+
+    # compute acc
+    for start_h in range(0, BLOCK_H, BLOCK_DMODEL):
+        cur_dm_off = start_h + dm_off
+        h_mask = cur_dm_off < BLOCK_H
+
+        # load x
+        xh_off = cur_dm_off * stride_xh
+        x_mask = h_mask
+        x = tl.load(X + x_off + xh_off, mask=x_mask, other=0.0).to(tl.float32)
+
+        # load lora a
+        lah_off = cur_dm_off * stride_lah
+        la_mask = rank_mask[:, None] and h_mask[None, :]
+        la = tl.load(LoRA_A + la_page_off[:, None] + lah_off[None, :],
+                     mask=la_mask,
+                     other=0.0)
+
+        # compute
+        acc += tl.sum(x[None, :] * la, 1)
+
+    acc = acc.to(X.dtype.element_ty)
+    xa_off = cur_batch * stride_xas
+    tl.store(XA + xa_off + r_off * stride_xar, acc, mask=rank_mask)
+
+
+@triton.jit
+def _acc_b_mv_kernel(
+    XA,
+    LoRA_B,
+    Out,
+    B_rank_id,
+    Rank_page_table,
+    Ranks,
+    stride_xas,
+    stride_xar,
+    stride_os,
+    stride_oh,
+    stride_lbs,
+    stride_lbh,
+    stride_ptb,
+    BLOCK_R: tl.constexpr,
+    BLOCK_HO: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+):
+    """acc b mv kernel."""
+    cur_batch = tl.program_id(0)
+
+    r_off = tl.arange(0, BLOCK_R)
+    rank_id = tl.load(B_rank_id + cur_batch)
+    rank = tl.load(Ranks + rank_id)
+
+    page_table_off = rank_id * stride_ptb + r_off
+    rank_mask = r_off < rank
+    page_table = tl.load(Rank_page_table + page_table_off, mask=rank_mask)
+
+    dm_off = tl.arange(0, BLOCK_DMODEL)
+    rank_mask = r_off < rank
+    lb_page_off = page_table * stride_lbs
+
+    o_off = cur_batch * stride_os
+
+    xa_off = cur_batch * stride_xas
+    acc = tl.load(XA + xa_off + r_off * stride_xar, mask=rank_mask, other=0.0)
+
+    # compute output
+    for start_h in range(0, BLOCK_HO, BLOCK_DMODEL):
+        cur_dm_off = start_h + dm_off
+        h_mask = cur_dm_off < BLOCK_HO
+
+        # load lora b
+        lbh_off = cur_dm_off * stride_lbh
+        lb_mask = rank_mask[:, None] and h_mask[None, :]
+        lb = tl.load(LoRA_B + lb_page_off[:, None] + lbh_off[None, :],
+                     mask=lb_mask,
+                     other=0)
+
+        # compute
+        out = tl.sum(acc[:, None] * lb, 0)
+        out = out.to(lb.dtype)
+
+        # store o
+        oh_off = cur_dm_off * stride_oh
+        tl.store(Out + o_off + oh_off, out, mask=h_mask)
+
+
+@torch.inference_mode()
+def mbgmv_a(x: Tensor, lora_a: Tensor, b_rank_ids: Tensor,
+            rank_page_table: Tensor, ranks: Tensor, max_rank: int):
+    """mbgmv_a."""
+
+    assert x.dim() == 2
+    assert lora_a.dim() == 2
+    assert rank_page_table.dim() == 2
+
+    head_size = x.size(-1)
+    batch_size = x.size(0)
+
+    BLOCK_R = _next_pow_of_2(max_rank)
+    if BLOCK_R < 16:
+        BLOCK_R = 16
+    BLOCK_H = head_size
+    BLOCK_DMODEL = 64
+
+    num_warps = 4
+    grid = [batch_size]
+    xa = x.new_empty((x.size(0), BLOCK_R))
+
+    _x_a_mv_kernel[grid](
+        x,
+        lora_a,
+        xa,
+        b_rank_ids,
+        Rank_page_table=rank_page_table,
+        Ranks=ranks,
+        stride_xs=x.stride(0),
+        stride_xh=x.stride(1),
+        stride_las=lora_a.stride(0),
+        stride_lah=lora_a.stride(1),
+        stride_xas=xa.stride(0),
+        stride_xar=xa.stride(1),
+        stride_ptb=rank_page_table.stride(0),
+        BLOCK_R=BLOCK_R,
+        BLOCK_H=BLOCK_H,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return xa
+
+
+@torch.inference_mode()
+def mbgmv_b(xa: Tensor, lora_b: Tensor, b_rank_ids: Tensor,
+            rank_page_table: Tensor, ranks: Tensor, max_rank: int):
+    """mbgmv_b."""
+
+    assert xa.dim() == 2
+    assert lora_b.dim() == 2
+    assert rank_page_table.dim() == 2
+
+    head_o_size = lora_b.size(-1)
+    batch_size = xa.size(0)
+
+    BLOCK_R = _next_pow_of_2(max_rank)
+    if BLOCK_R < 16:
+        BLOCK_R = 16
+    BLOCK_HO = head_o_size
+    BLOCK_DMODEL = 64
+
+    num_warps = 4
+    grid = [batch_size]
+    output = xa.new_empty((xa.size(0), BLOCK_HO))
+
+    _acc_b_mv_kernel[grid](
+        xa,
+        lora_b,
+        output,
+        b_rank_ids,
+        Rank_page_table=rank_page_table,
+        Ranks=ranks,
+        stride_xas=xa.stride(0),
+        stride_xar=xa.stride(1),
+        stride_lbs=lora_b.stride(0),
+        stride_lbh=lora_b.stride(1),
+        stride_os=output.stride(0),
+        stride_oh=output.stride(1),
+        stride_ptb=rank_page_table.stride(0),
+        BLOCK_R=BLOCK_R,
+        BLOCK_HO=BLOCK_HO,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+    return output
diff --git a/lmdeploy/pytorch/kernels/pagedattention.py b/lmdeploy/pytorch/kernels/pagedattention.py
index 3d8fbb0237..9212eb9568 100644
--- a/lmdeploy/pytorch/kernels/pagedattention.py
+++ b/lmdeploy/pytorch/kernels/pagedattention.py
@@ -9,6 +9,16 @@
 assert triton.__version__ >= '2.1.0'
 
 
+@triton.jit
+def _load_block_offsets(offset_ptr, block_id, is_unified_paging: tl.constexpr,
+                        BLOCK: tl.constexpr):
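+    """Load kv-cache offsets for one BLOCK of tokens.
+
+    With unified paging every token has its own offset entry; otherwise a
+    single entry covers a whole BLOCK of contiguous tokens.
+    """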
+    offs_n = tl.arange(0, BLOCK)
+    if is_unified_paging:
+        return tl.load(offset_ptr + block_id * BLOCK + offs_n)
+    else:
+        return tl.load(offset_ptr + block_id) * BLOCK + offs_n
+
+
 @triton.jit
 def _fwd_split_kernel(
     Q,
@@ -34,6 +44,7 @@ def _fwd_split_kernel(
     stride_boffb,
     kv_group_num,
     block_per_cta,
+    is_unified_paging: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
@@ -53,10 +64,8 @@ def _fwd_split_kernel(
     offs_d = tl.arange(0, BLOCK_DMODEL)
     off_q = (cur_batch * stride_qbs + cur_head * stride_qh +
              offs_d * stride_qd)
-    off_k = (offs_n[:, None] * stride_kbs + cur_kv_head * stride_kh +
-             offs_d[None, :] * stride_kd)
-    off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
-             offs_d[None, :] * stride_vd)
+    off_k = (cur_kv_head * stride_kh + offs_d[None, :] * stride_kd)
+    off_v = (cur_kv_head * stride_vh + offs_d[None, :] * stride_vd)
 
     q = tl.load(Q + off_q).to(tl.float32)
 
@@ -76,7 +85,8 @@ def _fwd_split_kernel(
 
     # load block offset
     start_block_id = loop_start // BLOCK_N
-    b_offset = tl.load(block_offset_ptrs + start_block_id)
+    b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,
+                                   is_unified_paging, BLOCK_N)
 
     for start_n in range(loop_start, loop_end, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
@@ -85,13 +95,13 @@ def _fwd_split_kernel(
 
         # -- compute qk ----
         k = tl.load(
-            k_ptrs + b_offset * BLOCK_N * stride_kbs,
+            k_ptrs + b_offset[:, None] * stride_kbs,
             mask=mask,
             other=0.0,
         )
 
         v = tl.load(
-            v_ptrs + b_offset * BLOCK_N * stride_vbs,
+            v_ptrs + b_offset[:, None] * stride_vbs,
             mask=mask,
             other=0.0,
         )
@@ -99,7 +109,8 @@ def _fwd_split_kernel(
         # prefetch b_offset
         if start_n + BLOCK_N < loop_end:
             start_block_id += 1
-            b_offset = tl.load(block_offset_ptrs + start_block_id)
+            b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,
+                                           is_unified_paging, BLOCK_N)
 
         qk = tl.sum(q[None, :] * k, 1)
         qk *= sm_scale
@@ -208,6 +219,7 @@ def _fwd_kernel(
     stride_od,
     stride_boffb,
     kv_group_num,
+    is_unified_paging: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -232,10 +244,8 @@ def _fwd_kernel(
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
     off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
              cur_head * stride_qh + offs_d[None, :] * stride_qd)
-    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
-             offs_d[:, None] * stride_kd)
-    off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
-             offs_d[None, :] * stride_vd)
+    off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd)
+    off_v = (cur_kv_head * stride_vh + offs_d[None, :] * stride_vd)
 
     q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
 
@@ -251,25 +261,27 @@ def _fwd_kernel(
 
     block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
 
-    b_offset = tl.load(block_offset_ptrs)
+    b_offset = _load_block_offsets(block_offset_ptrs, 0, is_unified_paging,
+                                   BLOCK_N)
     for start_n in range(0, block_mask * cur_batch_kv_len, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
 
         # -- compute qk ----
         k = tl.load(
-            k_ptrs + b_offset * BLOCK_N * stride_kbs,
+            k_ptrs + b_offset[None, :] * stride_kbs,
             mask=(start_n + offs_n[None, :]) < cur_batch_kv_len,
             other=0.0,
         )
 
         v = tl.load(
-            v_ptrs + b_offset * BLOCK_N * stride_vbs,
+            v_ptrs + b_offset[:, None] * stride_vbs,
             mask=(start_n + offs_n[:, None]) < cur_batch_kv_len,
             other=0.0,
         )
         if start_n + BLOCK_N < cur_batch_kv_len:
             start_block_id = start_n // BLOCK_N + 1
-            b_offset = tl.load(block_offset_ptrs + start_block_id)
+            b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,
+                                           is_unified_paging, BLOCK_N)
 
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
@@ -305,7 +317,7 @@ def _fwd_kernel(
     tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
 
 
-@torch.no_grad()
+@torch.inference_mode()
 def paged_attention_fwd(
     q: Tensor,
     k: Tensor,
@@ -316,7 +328,6 @@ def paged_attention_fwd(
     b_seq_len: Tensor,
     b_kv_seq_len: Tensor,
     max_input_len: int,
-    BLOCK: int = 64,
 ):
     """Paged Attention forward.
 
@@ -333,6 +344,13 @@ def paged_attention_fwd(
-        BLOCK (int): The kernel block size.
     """
 
+    def _kernel_meta():
+        device = q.device
+        device_idx = device.index
+        device_type = device.type
+        stream = get_cuda_stream(device_idx)
+        return dict(device=device, device_type=device_type, stream=stream)
+
     # shape constraints
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk and Lk == Lv
@@ -344,10 +362,10 @@ def paged_attention_fwd(
 
     num_warps = 4 if Lk <= 64 else 8
 
-    device = q.device
-    device_idx = device.index
-    device_type = device.type
-    stream = get_cuda_stream(device_idx)
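+    # a cache block size of 1 means unified paging; use a fixed 64-token
+    # tile in that case, otherwise tile by the cache block size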
+    is_unified_paging = k.size(1) == 1
+    BLOCK = 64 if is_unified_paging else k.size(1)
+
+    kernel_meta = _kernel_meta()
     is_decoding = q.shape[-3] == b_seq_len.size(0)
     if not is_decoding:
         grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
@@ -374,14 +392,13 @@ def paged_attention_fwd(
                           o.stride(-1),
                           block_offsets.stride(0),
                           kv_group_num=kv_group_num,
+                          is_unified_paging=is_unified_paging,
                           BLOCK_M=BLOCK,
                           BLOCK_DMODEL=Lk,
                           BLOCK_N=BLOCK,
                           num_warps=num_warps,
                           num_stages=1,
-                          stream=stream,
-                          device=device_idx,
-                          device_type=device_type)
+                          **kernel_meta)
     else:
         SPLIT_K = 4
         grid = (batch, head, SPLIT_K)
@@ -410,13 +427,12 @@ def paged_attention_fwd(
                                 stride_boffb=block_offsets.stride(0),
                                 kv_group_num=kv_group_num,
                                 block_per_cta=block_per_cta,
+                                is_unified_paging=is_unified_paging,
                                 BLOCK_DMODEL=Lk,
                                 BLOCK_N=BLOCK,
                                 num_warps=4,
                                 num_stages=1,
-                                stream=stream,
-                                device=device_idx,
-                                device_type=device_type)
+                                **kernel_meta)
 
         grid = (batch, head)
         _reduce_split_kernel[grid](acc,
@@ -432,6 +448,4 @@ def paged_attention_fwd(
                                    BLOCK_DMODEL=Lk,
                                    num_warps=num_warps,
                                    num_stages=1,
-                                   stream=stream,
-                                   device=device_idx,
-                                   device_type=device_type)
+                                   **kernel_meta)
diff --git a/lmdeploy/pytorch/kernels/rerope_attention.py b/lmdeploy/pytorch/kernels/rerope_attention.py
index 2d600d5bb2..ad3633b189 100644
--- a/lmdeploy/pytorch/kernels/rerope_attention.py
+++ b/lmdeploy/pytorch/kernels/rerope_attention.py
@@ -1,6 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
-import torch.utils.benchmark as benchmark
 import triton
 import triton.language as tl
 
@@ -219,6 +218,7 @@ def rerope_attention_fwd(q1,
 if __name__ == '__main__':
 
     def test_rerope():
+        import torch.utils.benchmark as benchmark
         Z = 1
         H = 40
         N_CTX = 2176
diff --git a/lmdeploy/pytorch/kernels/rms_norm.py b/lmdeploy/pytorch/kernels/rms_norm.py
index d247c4424c..1a3a009061 100644
--- a/lmdeploy/pytorch/kernels/rms_norm.py
+++ b/lmdeploy/pytorch/kernels/rms_norm.py
@@ -27,9 +27,18 @@ def rms_norm_kernel(input, weight, output, input_row_stride, n_cols, eps,
     tl.store(out_ptr + offsets, out, mask=offsets < n_cols)
 
 
+@torch.inference_mode()
 def rms_norm(hidden_states: Tensor, weight: Tensor, eps: float = 1e-6):
     """rms norm."""
-    feat_size = weight.size(-1)
+
+    def _kernel_meta():
+        device = hidden_states.device
+        device_idx = device.index
+        device_type = device.type
+        stream = get_cuda_stream(device_idx)
+        return dict(device=device, device_type=device_type, stream=stream)
+
+    feat_size = weight.shape[0]
     seq_len = hidden_states.numel() // hidden_states.size(-1)
     input_stride = hidden_states.stride(-2)
 
@@ -37,13 +46,8 @@ def rms_norm(hidden_states: Tensor, weight: Tensor, eps: float = 1e-6):
 
     out = torch.empty_like(hidden_states)
 
-    device = hidden_states.device
-    device_idx = device.index
-    device_type = device.type
-    stream = get_cuda_stream(device_idx)
-    grid = [
-        seq_len,
-    ]
+    kernel_meta = _kernel_meta()
+    grid = (seq_len, )
     rms_norm_kernel[grid](hidden_states,
                           weight,
                           out,
@@ -54,9 +58,7 @@ def rms_norm(hidden_states: Tensor, weight: Tensor, eps: float = 1e-6):
                           BLOCK_N,
                           num_warps=4,
                           num_stages=2,
-                          stream=stream,
-                          device=device_idx,
-                          device_type=device_type)
+                          **kernel_meta)
 
     return out
 
diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py
index 2cc0cd7964..0290ee0407 100644
--- a/lmdeploy/pytorch/messages.py
+++ b/lmdeploy/pytorch/messages.py
@@ -3,12 +3,12 @@
 import time
 from copy import deepcopy
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Sequence
+from typing import Any, Dict, List
 
 import torch
 from torch import Tensor
 
-from .block import LogicalTokenBlock
+from .block import LogicalTokenBlocks
 
 
 class SamplingParam:
@@ -16,8 +16,8 @@ class SamplingParam:
 
     def __init__(
         self,
-        top_p: float = 0.8,
-        top_k: int = None,
+        top_p: float = 1.0,
+        top_k: int = 1,
         temperature: float = 0.8,
         repetition_penalty: float = 1.0,
         ignore_eos: bool = False,
@@ -59,16 +59,17 @@ def _new_msg_id():
 class SchedulerSession:
     """Scheduler session."""
 
-    def __init__(self, session_id: int) -> None:
+    def __init__(self, session_id: int, block_size: int) -> None:
         self.session_id = session_id
+        self.block_size = block_size
         self.status: MessageStatus = MessageStatus.RUNNING
         self.sequences: Dict[int, SchedulerSequence] = dict()
 
-    def add_sequence(
-            self,
-            token_ids: Tensor,
-            max_output_len: int = 512,
-            sampling_param: SamplingParam = None) -> 'SchedulerSequence':
+    def add_sequence(self,
+                     token_ids: Tensor,
+                     max_output_len: int = 512,
+                     sampling_param: SamplingParam = None,
+                     adapter_id: int = -1) -> 'SchedulerSequence':
         """Add a new message."""
         if not isinstance(token_ids, Tensor):
             token_ids = torch.tensor(token_ids)
@@ -80,9 +81,11 @@ def add_sequence(
         seq = SchedulerSequence(seq_id=_new_msg_id(),
                                 token_ids=token_ids,
                                 session=self,
+                                block_size=self.block_size,
                                 status=MessageStatus.WAITING,
                                 remain_output_len=max_output_len,
                                 sampling_param=sampling_param,
+                                adapter_id=adapter_id,
                                 arrive_time=time.time())
         self.sequences[seq.seq_id] = seq
         return seq
@@ -106,11 +109,13 @@ def fork_sequence(
             seq_id=_new_msg_id(),
             token_ids=token_ids,
             session=self,
+            block_size=self.block_size,
             history_token_ids=seq.history_token_ids.copy(),
-            status=seq.status,
             remain_output_len=max_output_len,
-            logical_blocks=deepcopy(seq.logical_blocks),
             sampling_param=sampling_param,
+            status=seq.status,
+            logical_blocks=seq.logical_blocks.clone(),
+            adapter_id=seq.adapter_id,
             arrive_time=time.time(),
             meta=deepcopy(seq.meta))
 
@@ -124,16 +129,22 @@ class SchedulerSequence:
     seq_id: int
     token_ids: Tensor
     session: SchedulerSession
+    block_size: int
     history_token_ids: list = field(default_factory=list)
     remain_output_len: int = 0
     sampling_param: SamplingParam = field(default_factory=SamplingParam)
     status: MessageStatus = MessageStatus.WAITING
-    logical_blocks: Sequence[LogicalTokenBlock] = field(default_factory=list)
+    logical_blocks: LogicalTokenBlocks = None
     sender_id: int = -1
     req_id: int = -1
+    adapter_id: int = -1
     arrive_time: float = 0.0
     meta: Any = None
 
+    def __post_init__(self):
+        self.logical_blocks = self.logical_blocks or LogicalTokenBlocks(
+            self.block_size)
+
     @property
     def history_len(self) -> int:
         """get history length."""
@@ -146,46 +157,27 @@ def session_id(self) -> int:
 
     def num_logical_tokens(self) -> int:
         """num logitcal tokens."""
-        if len(self.logical_blocks) == 0:
-            return 0
-        else:
-            return sum(block.num_tokens for block in self.logical_blocks)
+        return self.logical_blocks.num_tokens()
+
+    def num_all_tokens(self) -> int:
+        """num all tokens."""
+        return len(self.token_ids) + self.history_len
 
     def num_required_tokens(self) -> int:
         """num required tokens."""
-        num_all_tokens = len(self.token_ids) + self.history_len
+        num_all_tokens = self.num_all_tokens()
         num_logical_tokens = self.num_logical_tokens()
         return num_all_tokens - num_logical_tokens
 
-    def append_tokens(self, num_tokens: int, block_size: int):
-        """Append new tokens, update logical blocks.
-
-        Args:
-            num_tokens (int): Number of tokens.
-            block_size (int): Size of block.
-        """
-        if len(self.logical_blocks) == 0:
-            remain_num_tokens = num_tokens
-            next_block_id = 0
-        else:
-            last_block = self.logical_blocks[-1]
-            num_empty_slots = last_block.get_num_empty_slots()
-            num_append_slots = min(num_tokens, num_empty_slots)
-            last_block.append_tokens(num_append_slots)
-            remain_num_tokens = num_tokens - num_append_slots
-            next_block_id = last_block.block_id + 1
-
-        for block_id_offset, msg_offset in enumerate(
-                range(0, remain_num_tokens, block_size)):
-            num_tokens = min(remain_num_tokens - msg_offset, block_size)
-            logical_block = LogicalTokenBlock(next_block_id + block_id_offset,
-                                              block_size)
-            logical_block.append_tokens(num_tokens=num_tokens)
-            self.logical_blocks.append(logical_block)
-
-    def update_token_ids(self, token_ids: Tensor):
+    def num_required_blocks(self) -> int:
+        """num required blocks."""
+        return self.logical_blocks.num_required_blocks(
+            self.num_required_tokens())
+
+    def update_token_ids(self, token_ids: Tensor, update_history: bool = True):
         """Update token ids, old token ids will be added to history."""
-        self.history_token_ids += self.token_ids.tolist()
+        if update_history:
+            self.history_token_ids += self.token_ids.tolist()
         if not isinstance(token_ids, Tensor):
             token_ids = self.token_ids.new_tensor(token_ids)
         if token_ids.dim() == 0:
@@ -195,9 +187,10 @@ def update_token_ids(self, token_ids: Tensor):
 
     def set_step(self, step: int):
         """set step."""
-        assert step < self.history_len
-        history_token_ids = torch.cat(self.history_token_ids)
-        new_history_ids = [history_token_ids[:step]]
+        assert step <= self.history_len
+        history_token_ids = torch.tensor(self.history_token_ids,
+                                         dtype=torch.long)
+        new_history_ids = self.history_token_ids[:step]
         new_token_ids = torch.cat([history_token_ids[step:], self.token_ids])
         self.history_token_ids = new_history_ids
         self.token_ids = new_token_ids
diff --git a/lmdeploy/pytorch/models/chatglm2.py b/lmdeploy/pytorch/models/chatglm2.py
index 7c0ab727f4..98b18fe4ff 100644
--- a/lmdeploy/pytorch/models/chatglm2.py
+++ b/lmdeploy/pytorch/models/chatglm2.py
@@ -211,7 +211,6 @@ def _contiguous_batching_forward(
         context_layer = torch.empty_like(query_layer)
 
         block_offsets = context.block_offsets
-        block_size = cache_k.size(1)
 
         paged_attention_fwd(query_layer,
                             cache_k,
@@ -221,8 +220,7 @@ def _contiguous_batching_forward(
                             b_start_loc=q_start_loc,
                             b_seq_len=q_seq_length,
                             b_kv_seq_len=kv_seq_length,
-                            max_input_len=max_seq_len,
-                            BLOCK=block_size)
+                            max_input_len=max_seq_len)
 
         context_layer = context_layer.transpose(1, 0).flatten(-2)
 
diff --git a/lmdeploy/pytorch/models/falcon.py b/lmdeploy/pytorch/models/falcon.py
index 756df31164..a76a653fc1 100644
--- a/lmdeploy/pytorch/models/falcon.py
+++ b/lmdeploy/pytorch/models/falcon.py
@@ -287,8 +287,7 @@ def _contiguous_batching_forward(
                                 b_start_loc=q_start_loc,
                                 b_seq_len=q_seq_length,
                                 b_kv_seq_len=kv_seq_length,
-                                max_input_len=max_seq_len,
-                                BLOCK=block_size)
+                                max_input_len=max_seq_len)
 
         else:
             alibi_paged_attention_fwd(q=query_layer,
diff --git a/lmdeploy/pytorch/models/functional.py b/lmdeploy/pytorch/models/functional.py
index 215f163251..2ca7589fe4 100644
--- a/lmdeploy/pytorch/models/functional.py
+++ b/lmdeploy/pytorch/models/functional.py
@@ -88,7 +88,7 @@ def get_alibi_biases(n_heads: int, mask: torch.Tensor):
     return distance * m[None, :, None, None]
 
 
-@torch.no_grad()
+@torch.inference_mode()
 def attention_forward_with_paged_attention(
     hidden_states: Tensor,
     history_lengths: Sequence,
@@ -198,7 +198,6 @@ def attention_forward_with_paged_attention(
             b_seq_len=q_seq_length,
             b_kv_seq_len=kv_seq_length,
             max_input_len=max_seq_len,
-            BLOCK=block_size,
         )
     else:
         if bias_type == 'alibi':
diff --git a/lmdeploy/pytorch/paging/__init__.py b/lmdeploy/pytorch/paging/__init__.py
index 43d28974db..bcd27eee7c 100644
--- a/lmdeploy/pytorch/paging/__init__.py
+++ b/lmdeploy/pytorch/paging/__init__.py
@@ -1,4 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .scheduler import BlockTable, Scheduler
+from .scheduler import Scheduler
 
-__all__ = ['Scheduler', 'BlockTable']
+__all__ = ['Scheduler']
diff --git a/lmdeploy/pytorch/paging/block_manager.py b/lmdeploy/pytorch/paging/block_manager.py
index 6debdd792e..15fb39d643 100644
--- a/lmdeploy/pytorch/paging/block_manager.py
+++ b/lmdeploy/pytorch/paging/block_manager.py
@@ -1,75 +1,235 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 # modify from: https://github.com/vllm-project/vllm
-from typing import Dict, List
+from typing import Dict
+
+import numpy as np
 
-from ..block import PhysicalTokenBlock
 from ..messages import SchedulerSequence
 
 
-class BlockAllocator:
-    """The block allocator.
+class LogicalMemory:
+    """Logical memory blocks."""
 
-    The allocator won't allocate real memory. It is used to support
-    block manager.
+    def __init__(self, num_blocks: int) -> None:
+        self._num_blocks = num_blocks
 
-    Args:
-        block_size (int): The num tokens of each block.
-        block_num (int): Total blocks.
-        device (str): The device name.
+        self.phy_map: np.ndarray = np.zeros(self._num_blocks, dtype=np.int64)
+
+    def get_physical_blocks(self, logical_address: np.ndarray):
+        """get physical address."""
+        return self.phy_map[logical_address]
+
+    def num_blocks(self):
+        """get num blocks."""
+        return self._num_blocks
+
+
+class PhysicalMemory:
+    """physical memory blocks."""
+
+    def __init__(self, num_cpu_blocks: int, num_gpu_blocks: int) -> None:
+        self._num_cpu_blocks = num_cpu_blocks
+        self._num_gpu_blocks = num_gpu_blocks
+        self._num_blocks = num_cpu_blocks + num_gpu_blocks
+
+        self.ref_count: np.ndarray = np.zeros((self._num_blocks, ),
+                                              dtype=np.int64)
+        self.swap_count: np.ndarray = np.zeros((self._num_blocks, ),
+                                               dtype=np.int64)
+
+    def num_cpu_blocks(self):
+        """get num cpu blocks."""
+        return self._num_cpu_blocks
+
+    def num_gpu_blocks(self):
+        """get num gpu blocks."""
+        return self._num_gpu_blocks
+
+
+class PhysicalAllocator:
+    """The physical block allocator.
+
+    The allocator won't allocate real memory. It is used to support block
+    manager.
     """
 
-    def __init__(self, block_size: int, block_num: int, device: str):
-        self.block_size = block_size
-        self.block_num = block_num
-        self.device = device
+    def __init__(self,
+                 memory: PhysicalMemory,
+                 num_blocks: int,
+                 offset: int = 0):
+        self._mem = memory
+        self._num_blocks = num_blocks
+        self._offset = offset
 
-        free_blocks: List[PhysicalTokenBlock] = [
-            PhysicalTokenBlock(device, i, block_size) for i in range(block_num)
-        ]
-        self.free_blocks = free_blocks
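+        # the trailing `_free_count` entries of `_free_blocks` hold the
+        # ids of the currently free blocks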
+        self._free_blocks = np.arange(num_blocks, dtype=np.int64) + offset
+        self._free_count = num_blocks
 
-    def allocate(self):
+    def allocate(self, num_blocks: int):
         """Allocate block from block pool."""
-        if len(self.free_blocks) > 0:
-            block = self.free_blocks.pop(0)
-            block.ref_count += 1
-            return block
+        if self.get_num_free_blocks() >= num_blocks:
+            num_used = self._num_blocks - self._free_count
+            blocks = self._free_blocks[num_used:num_used + num_blocks]
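+            # newly allocated blocks get a reference count of 1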
+            self._mem.ref_count.put(blocks, 1)
+            self._free_count -= num_blocks
+            return blocks
         else:
-            raise MemoryError(f'No free {self.device} memory blocks.')
+            raise MemoryError('Not enough free memory blocks.')
 
-    def free(self, block: PhysicalTokenBlock):
+    def free(self, blocks: np.ndarray):
         """Free block to block pool."""
-        if block.ref_count == 0:
-            raise ValueError(f'Double free {block}.')
-        block.ref_count -= 1
-        if block.ref_count == 0:
-            self.free_blocks.append(block)
+        np.add.at(self._mem.ref_count, blocks, -1)
+        ref_count = self.get_ref_count(blocks)
+        freed_blocks = blocks[ref_count == 0]
+        num_freed_blocks = len(freed_blocks)
+        if num_freed_blocks > 0:
+            num_used = self._num_blocks - self._free_count
+            self._free_blocks[num_used -
+                              num_freed_blocks:num_used] = freed_blocks
+            self._free_count += num_freed_blocks
+        return freed_blocks
 
     def get_num_free_blocks(self):
         """Get numbers of free blocks."""
-        return len(self.free_blocks)
+        return self._free_count
+
+    def get_ref_count(self, blocks: np.ndarray):
+        """get ref count."""
+        return self._mem.ref_count[blocks]
+
 
+class LogicalAllocator:
+    """The logical block allocator."""
 
-BlockTable = List[PhysicalTokenBlock]
+    def __init__(self, num_cpu_blocks: int, num_gpu_blocks: int) -> None:
+        self._log_mem = LogicalMemory(num_cpu_blocks + num_gpu_blocks)
+        self._phy_mem = PhysicalMemory(num_cpu_blocks, num_gpu_blocks)
+
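+        # unified physical address space: gpu blocks occupy
+        # [0, num_gpu_blocks), cpu blocks [num_gpu_blocks, num_blocks)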
+        self._cpu_mem_offset = num_gpu_blocks
+        self._gpu_allocator = PhysicalAllocator(self._phy_mem, num_gpu_blocks,
+                                                0)
+        self._cpu_allocator = PhysicalAllocator(self._phy_mem, num_cpu_blocks,
+                                                self._cpu_mem_offset)
+
+        num_blocks = self._log_mem.num_blocks()
+        self._num_blocks = num_blocks
+        self._free_blocks = np.arange(num_blocks)
+        self._free_count = num_blocks
+
+    def get_phy_allocator(self, device: str):
+        """get allocator."""
+        if device == 'gpu':
+            return self._gpu_allocator
+        elif device == 'cpu':
+            return self._cpu_allocator
+        else:
+            raise ValueError(f'Unsupported device: {device}')
+
+    def allocate(self, num_blocks: int, device: str = 'gpu'):
+        """allocate logical blocks."""
+        if num_blocks == 0:
+            return np.empty((0, ), dtype=np.int64)
+        phy_allocator = self.get_phy_allocator(device)
+        logical_enable = self.get_num_free_blocks() >= num_blocks
+        physical_enable = phy_allocator.get_num_free_blocks() >= num_blocks
+        if logical_enable and physical_enable:
+            num_used = self._num_blocks - self._free_count
+            blocks = self._free_blocks[num_used:num_used + num_blocks]
+            phy_blocks = phy_allocator.allocate(num_blocks)
+            self._log_mem.phy_map.put(blocks, phy_blocks)
+            self._free_count -= num_blocks
+            return blocks.copy()
+        else:
+            raise MemoryError('Not enough free memory blocks.')
+
+    def free(self, blocks: np.ndarray):
+        """Free logical block."""
+        phy_blocks = self.get_physical_blocks(blocks)
+
+        cpu_blocks = phy_blocks[phy_blocks >= self._cpu_mem_offset]
+        gpu_blocks = phy_blocks[phy_blocks < self._cpu_mem_offset]
+        if len(cpu_blocks) > 0:
+            self._cpu_allocator.free(cpu_blocks)
+        if len(gpu_blocks) > 0:
+            self._gpu_allocator.free(gpu_blocks)
+
+        ref_count = self._phy_mem.ref_count[phy_blocks]
+        freed_blocks = blocks[ref_count == 0]
+        num_freed_blocks = len(freed_blocks)
+        if num_freed_blocks > 0:
+            num_used = self._num_blocks - self._free_count
+            self._free_blocks[num_used -
+                              num_freed_blocks:num_used] = freed_blocks
+            self._free_count += num_freed_blocks
+
+    def get_num_free_blocks(self):
+        """Get numbers of free blocks."""
+        return self._free_count
+
+    def get_physical_blocks(self, blocks: np.ndarray):
+        """get physical address."""
+        return self._log_mem.get_physical_blocks(blocks)
+
+    def get_ref_count(self, blocks: np.ndarray):
+        """get ref count."""
+        phy_blocks = self.get_physical_blocks(blocks)
+        return self._phy_mem.ref_count[phy_blocks]
+
+    def add_ref_count(self, blocks: np.ndarray, value: np.ndarray):
+        """update ref count."""
+        phy_blocks = self.get_physical_blocks(blocks)
+        np.add.at(self._phy_mem.ref_count, phy_blocks, value)
+
+    def cpu_mem_offset(self):
+        """get cpu mem offset in unified physical memory."""
+        return self._cpu_mem_offset
+
+    def count_cpu_blocks(self, blocks: np.ndarray):
+        """count cpu blocks."""
+        phy_blocks = self.get_physical_blocks(blocks)
+        return np.count_nonzero(phy_blocks >= self.cpu_mem_offset())
+
+    def count_gpu_blocks(self, blocks: np.ndarray):
+        """count gpu blocks."""
+        phy_blocks = self.get_physical_blocks(blocks)
+        return np.count_nonzero(phy_blocks < self.cpu_mem_offset())
+
+    def update_phy_map(self, log_blocks: np.ndarray, phy_blocks: np.ndarray):
+        """update physical map."""
+        assert len(phy_blocks) == len(log_blocks)
+        self._log_mem.phy_map.put(log_blocks, phy_blocks)
+
+    def on_device(self, blocks: np.ndarray, device: str):
+        """blocks on given device."""
+        if len(blocks) == 0:
+            return False
+
+        # TODO: check all blocks
+        cpu_mem_offset = self.cpu_mem_offset()
+
+        phy_blocks = self.get_physical_blocks(blocks[:1])
+        if phy_blocks[0] < cpu_mem_offset:
+            phy_device = 'gpu'
+        else:
+            phy_device = 'cpu'
+        return device == phy_device
+
+
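+# a block table is an array of physical block ids of a sequence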
+BlockTable = np.ndarray
 
 
 class BlockManager:
     """Manage the usage of blocks, generate block tables.
 
     Args:
-        block_size (int): The num tokens of each block.
         num_gpu_blocks (int): number of gpu blocks.
         num_cpu_blocks (int): number of cpu blocks.
     """
 
-    def __init__(self, block_size: int, num_gpu_blocks: int,
-                 num_cpu_blocks: int) -> None:
-        self.block_size = block_size
+    def __init__(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
         self.num_gpu_blocks = num_gpu_blocks
         self.num_cpu_blocks = num_cpu_blocks
 
-        self.gpu_allocator = BlockAllocator(block_size, num_gpu_blocks, 'gpu')
-        self.cpu_allocator = BlockAllocator(block_size, num_cpu_blocks, 'cpu')
+        self.allocator = LogicalAllocator(num_cpu_blocks, num_gpu_blocks)
 
         self.block_tables: Dict[int, BlockTable] = {}
 
@@ -79,181 +239,187 @@ def get_block_table(self, msg: SchedulerSequence):
         Args:
             msg (SchedulerSequence): The msg to get block table.
         """
-        seq_id = msg.seq_id
-        if seq_id in self.block_tables:
-            return self.block_tables[seq_id]
-        else:
-            return None
+        logical_blocks = msg.logical_blocks
+        return self.allocator.get_physical_blocks(
+            logical_blocks.get_real_blocks())
 
     def can_allocate(self, msg: SchedulerSequence):
         """Return if physical block can be allocated for given message."""
-        required_blocks = len(msg.logical_blocks)
-        return required_blocks <= self.gpu_allocator.get_num_free_blocks()
+        num_required_blocks = msg.num_required_blocks()
+        num_free_phy = self.get_num_free_gpu_blocks()
+        return num_required_blocks <= num_free_phy
 
     def allocate(self, msg: SchedulerSequence):
         """Allocate physical blocks for given message according to logical
         blocks."""
-        assert msg.seq_id not in self.block_tables
-        block_table: BlockTable = []
         logical_blocks = msg.logical_blocks
-
-        for _ in logical_blocks:
-            phy_block = self.gpu_allocator.allocate()
-            block_table.append(phy_block)
-
-        self.block_tables[msg.seq_id] = block_table
-
-    def _free_block_table(self, block_table: BlockTable):
-        """Free physical blocks of given block table."""
-        for block in block_table:
-            if block.device == 'cpu':
-                self.cpu_allocator.free(block)
-            elif block.device == 'gpu':
-                self.gpu_allocator.free(block)
-            else:
-                raise ValueError(f'Can not free block {block}.')
+        num_required_tokens = msg.num_required_tokens()
+        num_required_blocks = logical_blocks.num_required_blocks(
+            num_required_tokens)
+        if num_required_blocks > 0:
+            blocks = self.allocator.allocate(num_required_blocks, 'gpu')
+            logical_blocks.append(blocks)
+            logical_blocks.add_tokens(num_required_tokens)
 
     def free(self, msg: SchedulerSequence):
         """Free all physical blocks allocated for the session."""
-        seq_id = msg.seq_id
-        if seq_id not in self.block_tables:
-            return
-
-        block_table = self.block_tables[seq_id]
-        self._free_block_table(block_table)
-        self.block_tables.pop(seq_id)
+        self.allocator.free(msg.logical_blocks.get_real_blocks())
+        msg.logical_blocks.reset()
 
     def can_append_slot(self, msg: SchedulerSequence):
         """Return true if the message can append new slot."""
-        seq_id = msg.seq_id
-        num_blocks = len(msg.logical_blocks)
-        assert seq_id in self.block_tables
-        block_table = self.block_tables[seq_id]
-        gpu_block_table = [
-            block for block in block_table if block.device == 'gpu'
-        ]
-        return num_blocks - len(
-            gpu_block_table) <= self.gpu_allocator.get_num_free_blocks()
+        return self.can_allocate(msg)
 
     def append_slot(self, msg: SchedulerSequence):
         """Append new slot to message."""
-        seq_id = msg.seq_id
-        logical_blocks = msg.logical_blocks
-
-        assert seq_id in self.block_tables
-        block_table = self.block_tables[seq_id]
-
-        while len(logical_blocks) > len(block_table):
-            block = self.gpu_allocator.allocate()
-            block_table.append(block)
+        return self.allocate(msg)
 
     def can_fork(self, from_msg: SchedulerSequence):
         """Return true if blocks can be folked."""
-        seq_id = from_msg.seq_id
-        assert seq_id in self.block_tables
         logical_blocks = from_msg.logical_blocks
-        if logical_blocks[-1].is_full():
-            # every block can be shared
+        if logical_blocks.last_block_size() == logical_blocks.get_block_size():
             return True
 
-        block_table = self.block_tables[seq_id]
-        device = block_table[-1].device
-        if device == 'cpu':
-            allocator = self.cpu_allocator
-        elif device == 'gpu':
-            allocator = self.gpu_allocator
+        cpu_mem_offset = self.allocator.cpu_mem_offset()
+        phy_block = self.allocator.get_physical_blocks(logical_blocks[-1])
+        if phy_block < cpu_mem_offset:
+            device = 'gpu'
         else:
-            raise ValueError(f'Unknown device {device}')
-        return allocator.get_num_free_blocks() >= 1
+            device = 'cpu'
+        phy_allocator = self.allocator.get_phy_allocator(device)
+        return phy_allocator.get_num_free_blocks() >= 1
 
     def fork(self, from_msg: SchedulerSequence, to_msg: SchedulerSequence):
         """Fork new message."""
-        from_msg_id = from_msg.seq_id
-        from_block_table = self.block_tables[from_msg_id]
 
-        block_table: BlockTable = []
-        for block in from_block_table[:-1]:
-            block.ref_count += 1
-            block_table.append(block)
+        def _copy_last_block(logical_blocks, copy_map):
+            cpu_mem_offset = self.allocator.cpu_mem_offset()
+            phy_block = self.allocator.get_physical_blocks(logical_blocks[-1])
+            if phy_block < cpu_mem_offset:
+                device = 'gpu'
+            else:
+                device = 'cpu'
+            block = self.allocator.allocate(1, device)
+            new_phy_block = self.allocator.get_physical_blocks(block[0])
+            copy_map[phy_block] = new_phy_block
+            return block[0]
 
-        # process last block
-        from_logical_blocks = from_msg.logical_blocks
-        last_block = from_block_table[-1]
+        logical_blocks = from_msg.logical_blocks
         copy_map: Dict[int, int] = dict()
-        if from_logical_blocks[-1].is_full():
-            last_block.ref_count += 1
-            block_table.append(last_block)
+        if logical_blocks.last_block_size() == logical_blocks.get_block_size():
+            self.allocator.add_ref_count(logical_blocks, 1)
         else:
-            device = last_block.device
-            if device == 'cpu':
-                allocator = self.cpu_allocator
-            elif device == 'gpu':
-                allocator = self.gpu_allocator
-            block = allocator.allocate()
-            block_table.append(block)
-            copy_map[last_block.block_id] = block.block_id
-
-        self.block_tables[to_msg.seq_id] = block_table
+            new_logical_blocks = logical_blocks.clone()
+            self.allocator.add_ref_count(new_logical_blocks[:-1], 1)
+            block = _copy_last_block(logical_blocks, copy_map)
+            new_logical_blocks[-1] = block
+            to_msg.logical_blocks = new_logical_blocks
+
         return copy_map
 
-    def _can_swap(self, msg: SchedulerSequence, allocator: BlockAllocator):
-        """Check if swap can be performed."""
-        block_table = self.get_block_table(msg)
-        assert block_table is not None
-
-        num_free_blocks = allocator.get_num_free_blocks()
-        return num_free_blocks > len(block_table)
-
-    def can_swap_in(self, msg: SchedulerSequence):
-        """Check if the message can be swapped in."""
-        return self._can_swap(msg, self.gpu_allocator)
-
-    def swap_in(self, msg: SchedulerSequence):
-        """Swap the message into GPU."""
-        block_table = self.get_block_table(msg)
-        assert block_table is not None
-
-        swap_map: Dict[int, int] = {}
-        for i in range(len(block_table)):
-            block = block_table[i]
-            if block.device == 'cpu':
-                new_block = self.gpu_allocator.allocate()
-                swap_map[block.block_id] = new_block.block_id
-                block_table[i] = new_block
-                self.cpu_allocator.free(block)
-
-        return swap_map
-
-    def can_swap_out(self, msg: SchedulerSequence):
-        """Check if the message can be swap out."""
-        return self._can_swap(msg, self.cpu_allocator)
-
-    def swap_out(self, msg: SchedulerSequence):
-        """Swap the message out to host."""
-        block_table = self.get_block_table(msg)
-        assert block_table is not None
-
-        swap_map: Dict[int, int] = {}
-        for i in range(len(block_table)):
-            block = block_table[i]
-            if block.device == 'gpu':
-                new_block = self.cpu_allocator.allocate()
-                swap_map[block.block_id] = new_block.block_id
-                block_table[i] = new_block
-                self.gpu_allocator.free(block)
-
-        return swap_map
-
-    def reset(self) -> None:
-        """Reset block table."""
-        for block_table in self.block_tables.values():
-            self._free_block_table(block_table)
-        self.block_tables.clear()
+    def try_swap_out(self, msg: SchedulerSequence):
+        """Try swap msg out."""
+        swap_map = dict()
+        logical_blocks = msg.logical_blocks
+        cpu_mem_offset = self.allocator.cpu_mem_offset()
+        phy_blocks = self.allocator.get_physical_blocks(logical_blocks)
+        cpu_allocator = self.allocator.get_phy_allocator('cpu')
+        gpu_allocator = self.allocator.get_phy_allocator('gpu')
+
+        def _can_swap():
+            """check swap."""
+            if len(logical_blocks) == 0:
+                return False
+
+            # all blocks of a sequence must be on the same device
+            if phy_blocks[0] >= cpu_mem_offset:
+                return False
+
+            # not enough free cpu blocks
+            num_free = self.get_num_free_cpu_blocks()
+            if num_free < len(phy_blocks):
+                return False
+
+            # don't swap sequences with multiple references
+            ref_count = gpu_allocator.get_ref_count(phy_blocks)
+            if np.count_nonzero(ref_count != 1) > 0:
+                return False
+
+            return True
+
+        def _do_swap():
+            """perform swap."""
+            new_blocks = cpu_allocator.allocate(len(logical_blocks))
+
+            old_blocks = phy_blocks
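+            # cpu block ids are re-based to cpu-cache local ids; cpu blocks
+            # sit after gpu blocks in the unified address space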
+            swap_map = dict(zip(old_blocks, new_blocks - self.num_gpu_blocks))
+
+            gpu_allocator.free(old_blocks)
+            self.allocator.update_phy_map(logical_blocks.get_real_blocks(),
+                                          new_blocks)
+            return True, swap_map
+
+        if not _can_swap():
+            return False, swap_map
+        else:
+            return _do_swap()
+
+    def try_swap_in(self, msg: SchedulerSequence):
+        """Try swap msg in."""
+        swap_map = dict()
+        logical_blocks = msg.logical_blocks
+        cpu_mem_offset = self.allocator.cpu_mem_offset()
+        phy_blocks = self.allocator.get_physical_blocks(logical_blocks)
+        cpu_allocator = self.allocator.get_phy_allocator('cpu')
+        gpu_allocator = self.allocator.get_phy_allocator('gpu')
+
+        def _can_swap():
+            """check swap."""
+            if len(logical_blocks) == 0:
+                return False
+
+            # all blocks of a sequence must be on the same device
+            if phy_blocks[0] < cpu_mem_offset:
+                return False
+
+            # not enough free gpu blocks
+            num_free = self.get_num_free_gpu_blocks()
+            if num_free < len(phy_blocks):
+                return False
+
+            # don't swap sequences with multiple references
+            ref_count = cpu_allocator.get_ref_count(phy_blocks)
+            if np.count_nonzero(ref_count != 1) > 0:
+                return False
+
+            return True
+
+        def _do_swap():
+            """perform swap."""
+            new_blocks = gpu_allocator.allocate(len(logical_blocks))
+
+            old_blocks = phy_blocks
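+            # re-base the source cpu block ids to cpu-cache local ids
+            # before mapping them to the newly allocated gpu blocks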
+            swap_map = dict(zip(old_blocks - self.num_gpu_blocks, new_blocks))
+
+            cpu_allocator.free(old_blocks)
+            self.allocator.update_phy_map(logical_blocks.get_real_blocks(),
+                                          new_blocks)
+            return True, swap_map
+
+        if not _can_swap():
+            return False, swap_map
+        else:
+            return _do_swap()
 
     def get_num_free_gpu_blocks(self) -> int:
         """Get number of free gpu blocks."""
-        return self.gpu_allocator.get_num_free_blocks()
+        return self.allocator.get_phy_allocator('gpu').get_num_free_blocks()
 
     def get_num_free_cpu_blocks(self) -> int:
         """Get number of free cpu blocks."""
-        return self.cpu_allocator.get_num_free_blocks()
+        return self.allocator.get_phy_allocator('cpu').get_num_free_blocks()
+
+    def on_device(self, msg: SchedulerSequence, device: str):
+        """Check if the blocks of the message are on the given device."""
+        allocator = self.allocator
+        logical_blocks = msg.logical_blocks
+        return allocator.on_device(logical_blocks.get_real_blocks(), device)
diff --git a/lmdeploy/pytorch/paging/eviction_helper/base_eviction_helper.py b/lmdeploy/pytorch/paging/eviction_helper/base_eviction_helper.py
index 0a1b05da5a..707aca20e5 100644
--- a/lmdeploy/pytorch/paging/eviction_helper/base_eviction_helper.py
+++ b/lmdeploy/pytorch/paging/eviction_helper/base_eviction_helper.py
@@ -13,32 +13,28 @@ class BaseEvictionHelper:
     def __init__(self, block_manager: BlockManager):
         self.block_manager: BlockManager = block_manager
 
-    def can_swap_out(self, seq: SchedulerSequence):
-        """sequence can swap out."""
-        raise NotImplementedError('Not implemented.')
-
-    def can_swap_in(self, seq: SchedulerSequence):
-        """sequence can swap in."""
-        raise NotImplementedError('Not implemented.')
-
     def need_swap_in(self, seq: SchedulerSequence):
         """sequence need swap in."""
         raise NotImplementedError('Not implemented.')
 
-    def try_swap_out_seqs(self, seqs: SeqList, swap_out_map: Dict[int, int]):
-        """try swap sequence out."""
+    def try_swap_out(self, seq: SchedulerSequence, swap_out_map: Dict[int,
+                                                                      int]):
+        """try swap out."""
         raise NotImplementedError('Not implemented.')
 
-    def swap_in(self, seq: SchedulerSequence, swap_in_map: Dict[int, int]):
-        """sequence swap in."""
+    def try_swap_in(self, seq: SchedulerSequence, swap_in_map: Dict[int, int]):
+        """try swap in."""
         raise NotImplementedError('Not implemented.')
 
-    def swap_out(self, seq: SchedulerSequence, swap_out_map: Dict[int, int]):
-        """sequence swap out."""
-        raise NotImplementedError('Not implemented.')
+    def try_swap_out_seqs(self, seqs: SeqList, swap_out_map: Dict[int, int]):
+        """try swap sequence out."""
+        for seq in reversed(seqs):
+            if self.try_swap_out(seq, swap_out_map):
+                return True
+        return False
 
-    def try_swap_out(self, hanging: SeqList, waiting: SeqList,
-                     swap_out_map: Dict[int, int]):
+    def try_swap_out_unused(self, hanging: SeqList, waiting: SeqList,
+                            swap_out_map: Dict[int, int]):
         """try swap out hanging and waiting sequence."""
         if self.try_swap_out_seqs(hanging, swap_out_map):
             return True
diff --git a/lmdeploy/pytorch/paging/eviction_helper/copy_eviction_helper.py b/lmdeploy/pytorch/paging/eviction_helper/copy_eviction_helper.py
index 837906ae28..6f91580ea8 100644
--- a/lmdeploy/pytorch/paging/eviction_helper/copy_eviction_helper.py
+++ b/lmdeploy/pytorch/paging/eviction_helper/copy_eviction_helper.py
@@ -2,7 +2,7 @@
 from typing import Dict
 
 from ...messages import SchedulerSequence
-from .base_eviction_helper import BaseEvictionHelper, SeqList
+from .base_eviction_helper import BaseEvictionHelper
 
 
 class CopyEvictionHelper(BaseEvictionHelper):
@@ -11,44 +11,21 @@ class CopyEvictionHelper(BaseEvictionHelper):
     def __init__(self, block_manager):
         super().__init__(block_manager)
 
-    def can_swap_out(self, seq: SchedulerSequence):
-        """sequence can swap out."""
-        block_table = self.block_manager.get_block_table(seq)
-        if block_table is None or len(block_table) == 0:
-            return False
-        first_block = block_table[0]
-        device = first_block.device
-        return device == 'gpu'
-
-    def can_swap_in(self, seq: SchedulerSequence):
-        """sequence can swap in."""
-        return self.block_manager.can_swap_in(seq)
-
     def need_swap_in(self, seq: SchedulerSequence):
         """sequence need swap in."""
-        block_table = self.block_manager.get_block_table(seq)
-        if block_table is None or len(block_table) == 0:
-            return False
-        first_block = block_table[0]
-        device = first_block.device
-        return device == 'cpu'
-
-    def swap_in(self, seq: SchedulerSequence, swap_in_map: Dict[int, int]):
-        """sequence swap in."""
-        swap_in_map.update(self.block_manager.swap_in(seq))
-
-    def swap_out(self, seq: SchedulerSequence, swap_out_map: Dict[int, int]):
-        """sequence swap out."""
-        swap_out_map.update(self.block_manager.swap_out(seq))
-
-    def try_swap_out_seqs(self, seqs: SeqList, swap_out_map: Dict[int, int]):
-        """try swap sequence out."""
-        for seq in seqs:
-            if not self.can_swap_out(seq):
-                continue
-            if not self.block_manager.can_swap_out(seq):
-                continue
-            self.swap_out(seq, swap_out_map)
-            return True
-
-        return False
+        return self.block_manager.on_device(seq, 'cpu')
+
+    def try_swap_out(self, seq: SchedulerSequence, swap_out_map: Dict[int,
+                                                                      int]):
+        """try swap out."""
+        success, swap_map = self.block_manager.try_swap_out(seq)
+        if success:
+            swap_out_map.update(swap_map)
+        return success
+
+    def try_swap_in(self, seq: SchedulerSequence, swap_in_map: Dict[int, int]):
+        """try swap in."""
+        success, swap_map = self.block_manager.try_swap_in(seq)
+        if success:
+            swap_in_map.update(swap_map)
+        return success
diff --git a/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py b/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py
index 879f3cf9d9..3490b3e95a 100644
--- a/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py
+++ b/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py
@@ -4,7 +4,7 @@
 from lmdeploy.pytorch.paging.block_manager import BlockManager
 
 from ...messages import SchedulerSequence
-from .base_eviction_helper import BaseEvictionHelper, SeqList
+from .base_eviction_helper import BaseEvictionHelper
 
 
 class RecomputeEvictionHelper(BaseEvictionHelper):
@@ -13,14 +13,6 @@ class RecomputeEvictionHelper(BaseEvictionHelper):
     def __init__(self, block_manager: BlockManager):
         super().__init__(block_manager)
 
-    def can_swap_out(self, seq: SchedulerSequence):
-        """sequence can swap out."""
-        return True
-
-    def can_swap_in(self, seq: SchedulerSequence):
-        """sequence can swap in."""
-        return self.block_manager.can_allocate(seq)
-
     def need_swap_in(self, seq: SchedulerSequence):
         """sequence need swap in."""
         return False
@@ -34,13 +26,19 @@ def swap_out(self, seq: SchedulerSequence, swap_out_map: Dict[int, int]):
         self.block_manager.free(seq)
         seq.set_step(0)
 
-    def try_swap_out_seqs(self, seqs: SeqList, swap_out_map: Dict[int, int]):
-        """try swap sequence out."""
-        for seq in seqs:
-            if not self.can_swap_out(seq):
-                continue
-            self.swap_out(seq, )
-            swap_out_map.update(self.block_manager.swap_out(seq))
+    def try_swap_out(self, seq: SchedulerSequence, swap_out_map: Dict[int,
+                                                                      int]):
+        """try swap out."""
+        if seq.history_len > 0:
+            self.swap_out(seq, swap_out_map)
             return True
+        else:
+            return False
 
-        return False
+    def try_swap_in(self, seq: SchedulerSequence, swap_in_map: Dict[int, int]):
+        """try swap in."""
+        if self.block_manager.can_allocate(seq):
+            self.swap_in(seq, swap_in_map)
+            return True
+        else:
+            return False
diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py
index d6423b06db..8ee3e337d6 100644
--- a/lmdeploy/pytorch/paging/scheduler.py
+++ b/lmdeploy/pytorch/paging/scheduler.py
@@ -6,7 +6,6 @@
 
 from lmdeploy.utils import get_logger
 
-from ..block import PhysicalTokenBlock
 from ..config import CacheConfig, SchedulerConfig
 from ..messages import MessageStatus, SchedulerSequence, SchedulerSession
 from .block_manager import BlockManager
@@ -14,7 +13,6 @@
 logger = get_logger('lmdeploy')
 
 SeqList = List[SchedulerSequence]
-BlockTable = List[PhysicalTokenBlock]
 
 
 def _find_seq_with_session_id(group: SeqList, session_id: int):
@@ -29,7 +27,6 @@ class SchedulerOutput:
     swap_in_map: Dict[int, int]
     swap_out_map: Dict[int, int]
     copy_map: Dict[int, int]
-    block_tables: List[BlockTable]
 
 
 class Scheduler:
@@ -48,11 +45,9 @@ def __init__(self, scheduler_config: SchedulerConfig,
         self.waiting: SeqList = []
         self.running: SeqList = []
         self.hanging: SeqList = []
-        self.aborted: SeqList = []
         self.sessions: Dict[int, SchedulerSession] = OrderedDict()
 
         self.block_manager = BlockManager(
-            cache_config.block_size,
             cache_config.num_gpu_blocks,
             cache_config.num_cpu_blocks,
         )
@@ -88,7 +83,7 @@ def add_session(self, session_id: int):
             session_id (int): New session id.
         """
         assert session_id not in self.sessions
-        session = SchedulerSession(session_id)
+        session = SchedulerSession(session_id, self.cache_config.block_size)
         self.sessions[session_id] = session
         return session
 
@@ -108,46 +103,97 @@ def add_sequence(self, seq: SchedulerSequence):
                 self.scheduler_config.max_request_output_len
         self.waiting.append(seq)
 
-    def _schedule(self):
-        """Schedule next step.
+    def _schedule_prefill(self):
+        """Schedule for prefilling."""
 
-        Running is the messages to perform inference. Swap in/swap out is the
-        table used to perform memory paging (between host and device)
-
-        The schedule follow steps:
-        1. Try allocate resources for all running sequence. If there are no
-            enough resources, try `swap out` caches of hanging and waiting
-            sequence. If there are still no enough resources, move the sequence
-            to waiting.
-        2. Check if sequence in the waiting list can be moved to running
-        """
+        max_batches = self.scheduler_config.max_batches - len(self.running)
+        block_manager = self.block_manager
+        eviction_helper = self.eviction_helper
+        swap_out_map: Dict[int, int] = dict()
+        swap_in_map: Dict[int, int] = dict()
+        copy_map: Dict[int, int] = dict()
         running: SeqList = []
-        swap_out_map: Dict[int, int] = {}
-        swap_in_map: Dict[int, int] = {}
-        copy_map: Dict[int, int] = {}
+
+        def _to_running(seq: SchedulerSequence):
+            """to running."""
+            self._set_message_status(seq, MessageStatus.RUNNING)
+            running.append(seq)
+
+        def _evict_until_can_append(seq: SchedulerSequence):
+            """evict until can append."""
+            while eviction_helper.try_swap_out_unused(self.hanging,
+                                                      self.waiting[1:],
+                                                      swap_out_map):
+                if block_manager.can_append_slot(seq):
+                    return True
+            return False
+
+        def _reorder_waiting():
+            """reorder waiting."""
+            self.waiting = sorted(self.waiting,
+                                  key=lambda seq: seq.arrive_time)
+
+        if len(running) >= max_batches or len(self.waiting) == 0:
+            return running, swap_in_map, swap_out_map, copy_map
+
+        _reorder_waiting()
+        while len(self.waiting) > 0 and len(running) < max_batches:
+            seq = self.waiting[0]
+
+            if not block_manager.can_allocate(seq):
+                if not _evict_until_can_append(seq):
+                    break
+
+            if eviction_helper.need_swap_in(seq):
+                if not eviction_helper.try_swap_in(seq, swap_in_map):
+                    break
+            # allocate session memory
+            block_manager.allocate(seq)
+            self.waiting.pop(0)
+            _to_running(seq)
+
+        self.running += running
+        return running, swap_in_map, swap_out_map, copy_map
+
+    def _schedule_decoding(self):
+        """schedule decoding."""
+        assert len(self.running) != 0
+
         block_manager = self.block_manager
         eviction_helper = self.eviction_helper
+        swap_out_map: Dict[int, int] = dict()
+        swap_in_map: Dict[int, int] = dict()
+        copy_map: Dict[int, int] = dict()
+        running: SeqList = []
 
         def _to_running(seq: SchedulerSequence):
+            """to running."""
             self._set_message_status(seq, MessageStatus.RUNNING)
             running.append(seq)
 
         def _try_append_slot(seq):
             """try append slot."""
-            if self.block_manager.can_append_slot(seq):
-                self.block_manager.append_slot(seq)
+            if seq.num_required_blocks() == 0:
                 _to_running(seq)
                 return True
-            else:
-                return False
+            if block_manager.can_append_slot(seq):
+                block_manager.append_slot(seq)
+                _to_running(seq)
+                return True
+            return False
 
-        block_size = self.cache_config.block_size
+        def _evict_until_can_append(seq: SchedulerSequence):
+            """evict until can append."""
+            while eviction_helper.try_swap_out_unused(self.hanging,
+                                                      self.waiting,
+                                                      swap_out_map):
+                if block_manager.can_append_slot(seq):
+                    return True
+            return False
 
         # 1. running
         for seq in self.running:
             # token + 1
-            num_required_tokens = seq.num_required_tokens()
-            seq.append_tokens(num_required_tokens, block_size)
 
             if len(seq.logical_blocks) > self.block_manager.num_gpu_blocks:
                 # Reach max gpu cache size.
@@ -156,89 +202,32 @@ def _try_append_slot(seq):
                                'reach max gpu size.')
                 self._set_message_status(seq, MessageStatus.ABORTED)
                 self.block_manager.free(seq)
-                self.aborted.append(seq)
 
             if not _try_append_slot(seq):
-                # try free unused cache from hanging and waiting
-                do_running = False
-                while eviction_helper.try_swap_out(self.hanging, self.waiting,
-                                                   swap_out_map):
-                    if _try_append_slot(seq):
-                        do_running = True
-                        break
-                if not do_running:
+                # try to free unused cache from hanging and waiting sequences
+                if _evict_until_can_append(seq):
+                    _try_append_slot(seq)
+                else:
                     # move to waiting
                     self._set_message_status(seq, MessageStatus.WAITING)
-                    self.waiting.append(seq)
-
-        max_batches = self.scheduler_config.max_batches
-
-        # 2. waiting
-        self.waiting = sorted(self.waiting, key=lambda seq: seq.arrive_time)
-        while len(self.waiting) > 0 and len(running) < max_batches:
-            seq = self.waiting[0]
-            num_required_tokens = seq.num_required_tokens()
-            seq.append_tokens(num_required_tokens, block_size)
-
-            block_table = block_manager.get_block_table(seq)
-            if block_table is not None:
-                if not block_manager.can_append_slot(seq):
-                    can_append = False
-                    while eviction_helper.try_swap_out_seqs(
-                            self.hanging, swap_out_map):
-                        if block_manager.can_append_slot(seq):
-                            can_append = True
-                            break
-                    if not can_append:
-                        break
-                if eviction_helper.need_swap_in(seq):
-                    if eviction_helper.can_swap_in(seq):
-                        eviction_helper.swap_in(seq, swap_in_map)
-                    else:
-                        break
-                block_manager.append_slot(seq)
-                self.waiting.pop(0)
-                _to_running(seq)
-            else:
-                if not block_manager.can_allocate(seq):
-                    can_alloc = False
-                    while eviction_helper.try_swap_out_seqs(
-                            self.hanging, swap_out_map):
-                        if block_manager.can_allocate(seq):
-                            can_alloc = True
-                            break
-                    if not can_alloc:
-                        break
-                # allocate session memory
-                block_manager.allocate(seq)
-                self.waiting.pop(0)
-                _to_running(seq)
+                    self.waiting.insert(0, seq)
 
         self.running = running
-
-        running = [
-            msg for msg in self.running if msg.status == MessageStatus.RUNNING
-        ]
-        if len(running) == 0:
-            logger.warning('No enough resources. Free gpu blocks: '
-                           f'{self.block_manager.get_num_free_gpu_blocks()}, '
-                           'Please end sessions.')
         return running, swap_in_map, swap_out_map, copy_map
 
-    def schedule(self):
+    def schedule(self, is_prefill: bool):
         """Schedule inputs for next steps."""
-        running, swap_in_map, swap_out_map, copy_map = self._schedule()
-
-        block_tables = [
-            self.block_manager.get_block_table(seq) for seq in running
-        ]
+        if is_prefill:
+            output = self._schedule_prefill()
+        else:
+            output = self._schedule_decoding()
+        running, swap_in_map, swap_out_map, copy_map = output
 
         return SchedulerOutput(
             running=running,
             swap_in_map=swap_in_map,
             swap_out_map=swap_out_map,
             copy_map=copy_map,
-            block_tables=block_tables,
         )
 
     def _set_session_status(self, session_id: int, status: MessageStatus):
@@ -278,6 +267,9 @@ def has_unfinished(self):
         """Check if there are any unfinished message."""
         return self.waiting or self.running
 
+    def has_running(self):
+        """Check if there are running sequences."""
+        return len(self.running) > 0
+
     def _remove_sequence(self, seq: SchedulerSequence):
         """Remove sequence(unsafe)
 
diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py
index 9588b00da1..66ba6c68d2 100644
--- a/lmdeploy/serve/async_engine.py
+++ b/lmdeploy/serve/async_engine.py
@@ -1,12 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import asyncio
 import dataclasses
-import os.path as osp
 import random
 from contextlib import contextmanager
-from typing import Literal, Optional
-
-from lmdeploy.model import MODELS, BaseModel
+from typing import List, Literal, Optional
 
 
 @dataclasses.dataclass
@@ -28,26 +25,24 @@ class AsyncEngine:
         tp (int): tensor parallel
     """
 
-    def __init__(self, model_path, instance_num=32, tp=1) -> None:
+    def __init__(self, model_path, instance_num=32, tp=1, **kwargs) -> None:
         from lmdeploy import turbomind as tm
-        from lmdeploy.tokenizer import Tokenizer
-        tokenizer_model_path = osp.join(model_path, 'triton_models',
-                                        'tokenizer')
-        tokenizer = Tokenizer(tokenizer_model_path)
-        self.tm_model = tm.TurboMind(model_path,
-                                     eos_id=tokenizer.eos_token_id,
-                                     tp=tp)
-        self.tokenizer = tokenizer
+        self.tm_model = tm.TurboMind.from_pretrained(model_path,
+                                                     tp=tp,
+                                                     **kwargs)
+        self.tokenizer = self.tm_model.tokenizer
         self.generators = [
             self.tm_model.create_instance() for i in range(instance_num)
         ]
         self.instance_num = instance_num
-        self.model: BaseModel = MODELS.get(self.tm_model.model_name)()
+        self.model = self.tm_model.model
         self.available = [True] * instance_num
         self.starts = [None] * instance_num
         self.steps = {}
+        self.loop = asyncio.get_event_loop()
 
     def stop_session(self, session_id: int):
+        """Stop a session by a session_id."""
         instance_id = session_id % self.instance_num
         input_ids = self.tokenizer.encode('')
         for outputs in self.generators[instance_id].stream_infer(
@@ -60,8 +55,24 @@ def stop_session(self, session_id: int):
             pass
         self.available[instance_id] = True
 
+    def end_session(self, session_id: int):
+        """Clear a session by a session_id."""
+        instance_id = session_id % self.instance_num
+        input_ids = self.tokenizer.encode('')
+        for outputs in self.generators[instance_id].stream_infer(
+                session_id,
+                input_ids,
+                request_output_len=0,
+                sequence_start=False,
+                sequence_end=True,
+                stop=True):
+            pass
+        self.steps[str(session_id)] = 0
+        self.available[instance_id] = True
+
     @contextmanager
     def safe_run(self, instance_id: int, session_id: Optional[int] = None):
+        """A context manager to make sure server's safe running."""
         self.available[instance_id] = False
         try:
             yield
@@ -82,22 +93,80 @@ async def get_generator(self, instance_id: int, stop: bool = False):
                 await asyncio.sleep(0.1)
         return self.generators[instance_id]
 
+    def batch_infer(self,
+                    prompts: List[str],
+                    request_output_len=512,
+                    top_k=40,
+                    top_p=0.8,
+                    temperature=0.8,
+                    repetition_penalty=1.0,
+                    ignore_eos=False,
+                    do_preprocess=True,
+                    **kwargs):
+        """Inference a batch of prompts.
+
+        Args:
+            prompts (List[str]): a batch of prompts
+            request_output_len (int): output token nums
+            top_k (int): The number of the highest probability vocabulary
+              tokens to keep for top-k-filtering
+            top_p (float): If set to float < 1, only the smallest set of most
+              probable tokens with probabilities that add up to top_p or higher
+              are kept for generation.
+            temperature (float): to modulate the next token probability
+            repetition_penalty (float): The parameter for repetition penalty.
+              1.0 means no penalty
+            ignore_eos (bool): indicator for ignoring eos
+            do_preprocess (bool): whether pre-process the messages.
+        """
+        assert isinstance(prompts, List), 'prompts should be a list'
+        batch_size = len(prompts)
+        outputs = [''] * batch_size
+        generators = []
+        for i, prompt in enumerate(prompts):
+            generators.append(
+                self.generate(prompt,
+                              i,
+                              stream_response=True,
+                              sequence_start=True,
+                              sequence_end=True,
+                              request_output_len=request_output_len,
+                              top_k=top_k,
+                              top_p=top_p,
+                              temperature=temperature,
+                              ignore_eos=ignore_eos,
+                              repetition_penalty=repetition_penalty,
+                              do_preprocess=do_preprocess,
+                              **kwargs))
+
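+        # run all generators concurrently and gather streamed responses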
+        async def _inner_call(i, generator):
+            async for out in generator:
+                outputs[i] += out.response
+
+        async def gather():
+            await asyncio.gather(
+                *[_inner_call(i, generators[i]) for i in range(batch_size)])
+
+        self.loop.run_until_complete(gather())
+        return outputs
+
     async def generate(
-        self,
-        messages,
-        session_id,
-        stream_response=True,
-        sequence_start=True,
-        sequence_end=False,
-        step=0,
-        request_output_len=512,
-        stop=False,
-        top_k=40,
-        top_p=0.8,
-        temperature=0.8,
-        repetition_penalty=1.0,
-        ignore_eos=False,
-    ):
+            self,
+            messages,
+            session_id,
+            stream_response=True,
+            sequence_start=True,
+            sequence_end=True,  # no interactive mode by default
+            step=0,
+            request_output_len=512,
+            stop=False,
+            top_k=40,
+            top_p=0.8,
+            temperature=0.8,
+            repetition_penalty=1.0,
+            ignore_eos=False,
+            do_preprocess=True,
+            **kwargs):
         """Generate responses.
 
         Args:
@@ -109,15 +178,16 @@ async def generate(
             sequence_end (bool): indicator for ending a sequence
             step (int): the offset of the k/v cache
             stop (bool): whether stop inference
-            top_p (float): If set to float < 1, only the smallest set of most
-              probable tokens with probabilities that add up to top_p or higher
-            are kept for generation.
             top_k (int): The number of the highest probability vocabulary
               tokens to keep for top-k-filtering
+            top_p (float): If set to float < 1, only the smallest set of most
+              probable tokens with probabilities that add up to top_p or higher
+              are kept for generation.
             temperature (float): to modulate the next token probability
             repetition_penalty (float): The parameter for repetition penalty.
               1.0 means no penalty
             ignore_eos (bool): indicator for ignoring eos
+            do_preprocess (bool): whether pre-process the messages.
         """
         instance_id = session_id % self.instance_num
         if str(session_id) not in self.steps:
@@ -125,14 +195,18 @@ async def generate(
         if step != 0:
             self.steps[str(session_id)] = step
         seed = random.getrandbits(64)
-        prompt = self.model.messages2prompt(messages, sequence_start)
-        input_ids = self.tokenizer.encode(prompt)
+        prompt = messages
+        if do_preprocess:
+            prompt = self.model.messages2prompt(prompt, sequence_start)
+        input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
         finish_reason = 'stop' if stop else None
         if self.steps[str(session_id)] + len(
-                input_ids) >= self.tm_model.session_len:
+                input_ids) + request_output_len >= self.tm_model.session_len:
             finish_reason = 'length'
             yield GenOut('', self.steps[str(session_id)], len(input_ids), 0,
                          finish_reason)
+            if sequence_end is True and sequence_start is False:
+                self.end_session(session_id)
         else:
             generator = await self.get_generator(instance_id, stop)
             with self.safe_run(instance_id, session_id):
@@ -156,103 +230,23 @@ async def generate(
                     # decode res
                     response = self.tokenizer.decode(res.tolist(),
                                                      offset=response_size)
+                    # a '�' at the end means the byte sequence is potentially
+                    # unfinished; concatenate it with the next chunk and
+                    # decode them together
+                    if response.endswith('�'):
+                        continue
                     # response, history token len,
                     # input token len, gen token len
                     yield GenOut(response, self.steps[str(session_id)],
                                  len(input_ids), tokens, finish_reason)
                     response_size = tokens
 
-                # update step
-                self.steps[str(session_id)] += len(input_ids) + tokens
-                if sequence_end or stop:
-                    self.steps[str(session_id)] = 0
-
-    async def generate_openai(
-        self,
-        messages,
-        instance_id,
-        stream_response=True,
-        renew_session=False,
-        request_output_len=512,
-        stop=False,
-        top_k=40,
-        top_p=0.8,
-        temperature=0.8,
-        repetition_penalty=1.0,
-        ignore_eos=False,
-    ):
-        """Generate responses.
-
-        Args:
-            messages (str | List): chat history or prompt
-            instance_id (int): actually request host ip
-            stream_response (bool): whether return responses streamingly
-            renew_session (bool): renew the session
-            request_output_len (int): output token nums
-            stop (bool): whether stop inference
-            top_p (float): If set to float < 1, only the smallest set of most
-              probable tokens with probabilities that add up to top_p or higher
-            are kept for generation.
-            top_k (int): The number of the highest probability vocabulary
-              tokens to keep for top-k-filtering
-            temperature (float): to modulate the next token probability
-            repetition_penalty (float): The parameter for repetition penalty.
-              1.0 means no penalty
-            ignore_eos (bool): indicator for ignoring eos
-        """
-        session_id = instance_id
-        instance_id %= self.instance_num
-        sequence_start = False
-        generator = await self.get_generator(instance_id)
-        if renew_session:  # renew a session
-            empty_input_ids = self.tokenizer.encode('')
-            for outputs in generator.stream_infer(session_id=session_id,
-                                                  input_ids=[empty_input_ids],
-                                                  request_output_len=0,
-                                                  sequence_start=False,
-                                                  sequence_end=True,
-                                                  stop=True):
-                pass
-            self.steps[str(session_id)] = 0
-        if str(session_id) not in self.steps:
-            self.steps[str(session_id)] = 0
-        if self.steps[str(session_id)] == 0:
-            sequence_start = True
-        seed = random.getrandbits(64)
-        prompt = self.model.messages2prompt(messages, sequence_start)
-        input_ids = self.tokenizer.encode(prompt)
-        finish_reason = 'stop' if stop else None
-        if self.steps[str(session_id)] + len(
-                input_ids) >= self.tm_model.session_len:
-            finish_reason = 'length'
-            yield GenOut('', self.steps[str(session_id)], len(input_ids), 0,
-                         finish_reason)
-        else:
-            with self.safe_run(instance_id, session_id):
-                response_size = 0
-                async for outputs in generator.async_stream_infer(
-                        session_id=session_id,
-                        input_ids=[input_ids],
-                        stream_output=stream_response,
-                        request_output_len=request_output_len,
-                        sequence_start=(sequence_start),
-                        sequence_end=False,
-                        step=self.steps[str(session_id)],
-                        stop=stop,
-                        top_k=top_k,
-                        top_p=top_p,
-                        temperature=temperature,
-                        repetition_penalty=repetition_penalty,
-                        ignore_eos=ignore_eos,
-                        random_seed=seed if sequence_start else None):
-                    res, tokens = outputs[0]
-                    # decode res
-                    response = self.tokenizer.decode(res.tolist(),
-                                                     offset=response_size)
-                    # response, history len, input len, generation len
+                # `response_size` might not have been updated if the last
+                # chunk was skipped by `response.endswith('�')`
+                if response_size != tokens:
                     yield GenOut(response, self.steps[str(session_id)],
                                  len(input_ids), tokens, finish_reason)
-                    response_size = tokens
-
                 # update step
                 self.steps[str(session_id)] += len(input_ids) + tokens
+                if sequence_end or stop:
+                    self.steps[str(session_id)] = 0
diff --git a/lmdeploy/serve/client.py b/lmdeploy/serve/client.py
index 283e96e299..424e83143f 100644
--- a/lmdeploy/serve/client.py
+++ b/lmdeploy/serve/client.py
@@ -1,8 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os
 
-import fire
-
 from lmdeploy.serve.turbomind.chatbot import Chatbot
 
 
@@ -20,7 +18,6 @@ def input_prompt(model_name):
 def main(tritonserver_addr: str,
          session_id: int = 1,
          cap: str = 'chat',
-         sys_instruct: str = None,
          stream_output: bool = True,
          **kwargs):
     """An example to communicate with inference server through the command line
@@ -32,13 +29,11 @@ def main(tritonserver_addr: str,
         session_id (int): the identical id of a session
         cap (str): the capability of a model. For example, codellama has
             the ability among ['completion', 'infill', 'instruct', 'python']
-        sys_instruct (str): the content of 'system' role, which is used by
-            conversational model
         stream_output (bool): indicator for streaming output or not
         **kwargs (dict): other arguments for initializing model's chat template
     """
     log_level = os.environ.get('SERVICE_LOG_LEVEL', 'WARNING')
-    kwargs.update(capability=cap, system=sys_instruct)
+    kwargs.update(capability=cap)
     chatbot = Chatbot(tritonserver_addr,
                       log_level=log_level,
                       display=stream_output,
@@ -69,4 +64,6 @@ def main(tritonserver_addr: str,
 
 
 if __name__ == '__main__':
+    import fire
+
     fire.Fire(main)
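
With `fire` now imported lazily inside `__main__`, the triton client keeps a fast import path and can be used either from Python or from the shell. A hedged usage sketch (the tritonserver address below is a placeholder):

    # programmatic use
    from lmdeploy.serve.client import main

    main('0.0.0.0:33337', session_id=1, cap='chat', stream_output=True)

    # equivalent CLI, built by fire from `main`:
    #   python -m lmdeploy.serve.client 0.0.0.0:33337 --session_id 1
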
diff --git a/lmdeploy/serve/gradio/api_server_backend.py b/lmdeploy/serve/gradio/api_server_backend.py
new file mode 100644
index 0000000000..8dd92fa0fd
--- /dev/null
+++ b/lmdeploy/serve/gradio/api_server_backend.py
@@ -0,0 +1,186 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import time
+from threading import Lock
+from typing import Sequence
+
+import gradio as gr
+
+from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
+from lmdeploy.serve.openai.api_client import (get_model_list,
+                                              get_streaming_response)
+
+
+class InterFace:
+    api_server_url: str = None
+    global_session_id: int = 0
+    lock = Lock()
+
+
+def chat_stream_restful(instruction: str, state_chatbot: Sequence,
+                        cancel_btn: gr.Button, reset_btn: gr.Button,
+                        session_id: int):
+    """Chat with AI assistant.
+
+    Args:
+        instruction (str): user's prompt
+        state_chatbot (Sequence): the chatting history
+        session_id (int): the session id
+    """
+    state_chatbot = state_chatbot + [(instruction, None)]
+
+    yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
+
+    for response, tokens, finish_reason in get_streaming_response(
+            instruction,
+            f'{InterFace.api_server_url}/v1/chat/interactive',
+            session_id=session_id,
+            request_output_len=512,
+            interactive_mode=True):
+        if finish_reason == 'length':
+            gr.Warning('WARNING: exceed session max length.'
+                       ' Please restart the session by reset button.')
+        if tokens < 0:
+            gr.Warning('WARNING: running on the old session.'
+                       ' Please restart the session by reset button.')
+        if state_chatbot[-1][-1] is None:
+            state_chatbot[-1] = (state_chatbot[-1][0], response)
+        else:
+            state_chatbot[-1] = (state_chatbot[-1][0],
+                                 state_chatbot[-1][1] + response
+                                 )  # piece by piece
+        yield (state_chatbot, state_chatbot, enable_btn, disable_btn)
+
+    yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
+
+
+def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
+                       session_id: int):
+    """reset the session.
+
+    Args:
+        instruction_txtbox (str): user's prompt
+        state_chatbot (Sequence): the chatting history
+        session_id (int): the session id
+    """
+    state_chatbot = []
+    # end the session
+    for response, tokens, finish_reason in get_streaming_response(
+            '',
+            f'{InterFace.api_server_url}/v1/chat/interactive',
+            session_id=session_id,
+            request_output_len=0,
+            interactive_mode=False):
+        pass
+
+    return (
+        state_chatbot,
+        state_chatbot,
+        gr.Textbox.update(value=''),
+    )
+
+
+def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
+                        reset_btn: gr.Button, session_id: int):
+    """stop the session.
+
+    Args:
+        state_chatbot (Sequence): the chatting history
+        cancel_btn (gr.Button): the cancel button
+        reset_btn (gr.Button): the reset button
+        session_id (int): the session id
+    """
+    yield (state_chatbot, disable_btn, disable_btn)
+    # end the session
+    for out in get_streaming_response(
+            '',
+            f'{InterFace.api_server_url}/v1/chat/interactive',
+            session_id=session_id,
+            request_output_len=0,
+            stop=True):
+        pass
+    time.sleep(0.5)
+    messages = []
+    for qa in state_chatbot:
+        messages.append(dict(role='user', content=qa[0]))
+        if qa[1] is not None:
+            messages.append(dict(role='assistant', content=qa[1]))
+    for out in get_streaming_response(
+            messages,
+            f'{InterFace.api_server_url}/v1/chat/interactive',
+            session_id=session_id,
+            request_output_len=0,
+            interactive_mode=True):
+        pass
+    yield (state_chatbot, disable_btn, enable_btn)
+
+
+def run_api_server(api_server_url: str,
+                   server_name: str = 'localhost',
+                   server_port: int = 6006,
+                   batch_size: int = 32):
+    """chat with AI assistant through web ui.
+
+    Args:
+        api_server_url (str): restful api url
+        server_name (str): the ip address of gradio server
+        server_port (int): the port of gradio server
+        batch_size (int): batch size for running Turbomind directly
+    """
+    InterFace.api_server_url = api_server_url
+    model_names = get_model_list(f'{api_server_url}/v1/models')
+    model_name = ''
+    if isinstance(model_names, list) and len(model_names) > 0:
+        model_name = model_names[0]
+    else:
+        raise ValueError(
+            'gradio cannot find a suitable model from restful-api')
+
+    with gr.Blocks(css=CSS, theme=THEME) as demo:
+        state_chatbot = gr.State([])
+        state_session_id = gr.State(0)
+
+        with gr.Column(elem_id='container'):
+            gr.Markdown('## LMDeploy Playground')
+
+            chatbot = gr.Chatbot(elem_id='chatbot', label=model_name)
+            instruction_txtbox = gr.Textbox(
+                placeholder='Please input the instruction',
+                label='Instruction')
+            with gr.Row():
+                cancel_btn = gr.Button(value='Cancel', interactive=False)
+                reset_btn = gr.Button(value='Reset')
+
+        send_event = instruction_txtbox.submit(chat_stream_restful, [
+            instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
+            state_session_id
+        ], [state_chatbot, chatbot, cancel_btn, reset_btn])
+        instruction_txtbox.submit(
+            lambda: gr.Textbox.update(value=''),
+            [],
+            [instruction_txtbox],
+        )
+        cancel_btn.click(
+            cancel_restful_func,
+            [state_chatbot, cancel_btn, reset_btn, state_session_id],
+            [state_chatbot, cancel_btn, reset_btn],
+            cancels=[send_event])
+
+        reset_btn.click(reset_restful_func,
+                        [instruction_txtbox, state_chatbot, state_session_id],
+                        [state_chatbot, chatbot, instruction_txtbox],
+                        cancels=[send_event])
+
+        def init():
+            with InterFace.lock:
+                InterFace.global_session_id += 1
+            new_session_id = InterFace.global_session_id
+            return new_session_id
+
+        demo.load(init, inputs=None, outputs=[state_session_id])
+
+    print(f'server is gonna mount on: http://{server_name}:{server_port}')
+    demo.queue(concurrency_count=batch_size, max_size=100,
+               api_open=True).launch(
+                   max_threads=10,
+                   share=True,
+                   server_port=server_port,
+                   server_name=server_name,
+               )
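
`chat_stream_restful` drives the whole UI through the api_server's `/v1/chat/interactive` route with `interactive_mode=True`, so the dialogue history for a given `session_id` lives on the server. The same call pattern can be exercised without gradio; a minimal sketch, assuming an api_server is reachable at the placeholder address below:

    from lmdeploy.serve.openai.api_client import get_streaming_response

    api_server_url = 'http://0.0.0.0:23333'   # placeholder address
    for text, tokens, finish_reason in get_streaming_response(
            'Hello!',
            f'{api_server_url}/v1/chat/interactive',
            session_id=1,
            request_output_len=512,
            interactive_mode=True):           # history kept on the server
        print(text, end='', flush=True)
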
diff --git a/lmdeploy/serve/gradio/app.py b/lmdeploy/serve/gradio/app.py
index 71db7a2749..cf8815ad0f 100644
--- a/lmdeploy/serve/gradio/app.py
+++ b/lmdeploy/serve/gradio/app.py
@@ -1,542 +1,41 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os
-import threading
-import time
-from functools import partial
-from typing import Sequence
-
-import fire
-import gradio as gr
-
-from lmdeploy.serve.async_engine import AsyncEngine
-from lmdeploy.serve.gradio.css import CSS
-from lmdeploy.serve.openai.api_client import (get_model_list,
-                                              get_streaming_response)
-from lmdeploy.serve.openai.api_server import ip2id
-from lmdeploy.serve.turbomind.chatbot import Chatbot
-
-THEME = gr.themes.Soft(
-    primary_hue=gr.themes.colors.blue,
-    secondary_hue=gr.themes.colors.sky,
-    font=[gr.themes.GoogleFont('Inconsolata'), 'Arial', 'sans-serif'])
-
-enable_btn = gr.Button.update(interactive=True)
-disable_btn = gr.Button.update(interactive=False)
-
-
-def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
-                request: gr.Request):
-    """Chat with AI assistant.
-
-    Args:
-        instruction (str): user's prompt
-        state_chatbot (Sequence): the chatting history
-        llama_chatbot (Chatbot): the instance of a chatbot
-        request (gr.Request): the request from a user
-        model_name (str): the name of deployed model
-    """
-    instruction = state_chatbot[-1][0]
-    session_id = threading.current_thread().ident
-    if request is not None:
-        session_id = ip2id(request.kwargs['client']['host'])
-
-    bot_response = llama_chatbot.stream_infer(
-        session_id, instruction, f'{session_id}-{len(state_chatbot)}')
-
-    for status, tokens, _ in bot_response:
-        state_chatbot[-1] = (state_chatbot[-1][0], tokens)
-        yield (state_chatbot, state_chatbot, '')
-
-    return (state_chatbot, state_chatbot, '')
-
-
-def reset_all_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
-                   llama_chatbot: gr.State, triton_server_addr: str,
-                   model_name: str):
-    """reset the session."""
-    state_chatbot = []
-    log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
-    llama_chatbot = Chatbot(triton_server_addr,
-                            model_name,
-                            log_level=log_level,
-                            display=True)
-
-    return (
-        llama_chatbot,
-        state_chatbot,
-        state_chatbot,
-        gr.Textbox.update(value=''),
-    )
-
-
-def cancel_func(
-    instruction_txtbox: gr.Textbox,
-    state_chatbot: gr.State,
-    llama_chatbot: gr.State,
-):
-    """cancel the session."""
-    session_id = llama_chatbot._session.session_id
-    llama_chatbot.cancel(session_id)
-
-    return (
-        llama_chatbot,
-        state_chatbot,
-    )
-
-
-def add_instruction(instruction, state_chatbot):
-    state_chatbot = state_chatbot + [(instruction, None)]
-    return ('', state_chatbot)
-
-
-def run_server(triton_server_addr: str,
-               server_name: str = 'localhost',
-               server_port: int = 6006):
-    """chat with AI assistant through web ui.
-
-    Args:
-        triton_server_addr (str): the communication address of inference server
-        server_name (str): the ip address of gradio server
-        server_port (int): the port of gradio server
-    """
-    with gr.Blocks(css=CSS, theme=THEME) as demo:
-        log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
-        llama_chatbot = gr.State(
-            Chatbot(triton_server_addr, log_level=log_level, display=True))
-        state_chatbot = gr.State([])
-        model_name = llama_chatbot.value.model_name
-        reset_all = partial(reset_all_func,
-                            model_name=model_name,
-                            triton_server_addr=triton_server_addr)
-
-        with gr.Column(elem_id='container'):
-            gr.Markdown('## LMDeploy Playground')
-
-            chatbot = gr.Chatbot(elem_id='chatbot', label=model_name)
-            instruction_txtbox = gr.Textbox(
-                placeholder='Please input the instruction',
-                label='Instruction')
-            with gr.Row():
-                cancel_btn = gr.Button(value='Cancel')
-                reset_btn = gr.Button(value='Reset')
-
-        send_event = instruction_txtbox.submit(
-            add_instruction, [instruction_txtbox, state_chatbot],
-            [instruction_txtbox, state_chatbot]).then(
-                chat_stream, [state_chatbot, llama_chatbot],
-                [state_chatbot, chatbot])
-
-        cancel_btn.click(cancel_func,
-                         [instruction_txtbox, state_chatbot, llama_chatbot],
-                         [llama_chatbot, chatbot],
-                         cancels=[send_event])
-
-        reset_btn.click(
-            reset_all, [instruction_txtbox, state_chatbot, llama_chatbot],
-            [llama_chatbot, state_chatbot, chatbot, instruction_txtbox],
-            cancels=[send_event])
-
-    print(f'server is gonna mount on: http://{server_name}:{server_port}')
-    demo.queue(concurrency_count=4, max_size=100, api_open=True).launch(
-        max_threads=10,
-        share=True,
-        server_port=server_port,
-        server_name=server_name,
-    )
-
-
-# a IO interface mananing variables
-class InterFace:
-    async_engine: AsyncEngine = None  # for run_local
-    restful_api_url: str = None  # for run_restful
-
-
-def chat_stream_restful(
-    instruction: str,
-    state_chatbot: Sequence,
-    cancel_btn: gr.Button,
-    reset_btn: gr.Button,
-    request: gr.Request,
-):
-    """Chat with AI assistant.
-
-    Args:
-        instruction (str): user's prompt
-        state_chatbot (Sequence): the chatting history
-        request (gr.Request): the request from a user
-    """
-    session_id = threading.current_thread().ident
-    if request is not None:
-        session_id = ip2id(request.kwargs['client']['host'])
-    bot_summarized_response = ''
-    state_chatbot = state_chatbot + [(instruction, None)]
-
-    yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
-           f'{bot_summarized_response}'.strip())
-
-    for response, tokens, finish_reason in get_streaming_response(
-            instruction,
-            f'{InterFace.restful_api_url}/generate',
-            session_id=session_id,
-            request_output_len=512,
-            sequence_start=(len(state_chatbot) == 1),
-            sequence_end=False):
-        if finish_reason == 'length':
-            gr.Warning('WARNING: exceed session max length.'
-                       ' Please restart the session by reset button.')
-        if tokens < 0:
-            gr.Warning('WARNING: running on the old session.'
-                       ' Please restart the session by reset button.')
-        if state_chatbot[-1][-1] is None:
-            state_chatbot[-1] = (state_chatbot[-1][0], response)
-        else:
-            state_chatbot[-1] = (state_chatbot[-1][0],
-                                 state_chatbot[-1][1] + response
-                                 )  # piece by piece
-        yield (state_chatbot, state_chatbot, enable_btn, disable_btn,
-               f'{bot_summarized_response}'.strip())
-
-    yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
-           f'{bot_summarized_response}'.strip())
-
-
-def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
-                       request: gr.Request):
-    """reset the session.
-
-    Args:
-        instruction_txtbox (str): user's prompt
-        state_chatbot (Sequence): the chatting history
-        request (gr.Request): the request from a user
-    """
-    state_chatbot = []
-
-    session_id = threading.current_thread().ident
-    if request is not None:
-        session_id = ip2id(request.kwargs['client']['host'])
-    # end the session
-    for response, tokens, finish_reason in get_streaming_response(
-            '',
-            f'{InterFace.restful_api_url}/generate',
-            session_id=session_id,
-            request_output_len=0,
-            sequence_start=False,
-            sequence_end=True):
-        pass
-
-    return (
-        state_chatbot,
-        state_chatbot,
-        gr.Textbox.update(value=''),
-    )
-
-
-def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
-                        reset_btn: gr.Button, request: gr.Request):
-    """stop the session.
-
-    Args:
-        instruction_txtbox (str): user's prompt
-        state_chatbot (Sequence): the chatting history
-        request (gr.Request): the request from a user
-    """
-    session_id = threading.current_thread().ident
-    if request is not None:
-        session_id = ip2id(request.kwargs['client']['host'])
-    # end the session
-    for out in get_streaming_response('',
-                                      f'{InterFace.restful_api_url}/generate',
-                                      session_id=session_id,
-                                      request_output_len=0,
-                                      sequence_start=False,
-                                      sequence_end=False,
-                                      stop=True):
-        pass
-    time.sleep(0.5)
-    messages = []
-    for qa in state_chatbot:
-        messages.append(dict(role='user', content=qa[0]))
-        if qa[1] is not None:
-            messages.append(dict(role='assistant', content=qa[1]))
-    for out in get_streaming_response(messages,
-                                      f'{InterFace.restful_api_url}/generate',
-                                      session_id=session_id,
-                                      request_output_len=0,
-                                      sequence_start=True,
-                                      sequence_end=False):
-        pass
-    return (state_chatbot, disable_btn, enable_btn)
-
-
-def run_restful(restful_api_url: str,
-                server_name: str = 'localhost',
-                server_port: int = 6006,
-                batch_size: int = 32):
-    """chat with AI assistant through web ui.
-
-    Args:
-        restful_api_url (str): restufl api url
-        server_name (str): the ip address of gradio server
-        server_port (int): the port of gradio server
-        batch_size (int): batch size for running Turbomind directly
-    """
-    InterFace.restful_api_url = restful_api_url
-    model_names = get_model_list(f'{restful_api_url}/v1/models')
-    model_name = ''
-    if isinstance(model_names, list) and len(model_names) > 0:
-        model_name = model_names[0]
-    else:
-        raise ValueError('gradio can find a suitable model from restful-api')
-
-    with gr.Blocks(css=CSS, theme=THEME) as demo:
-        state_chatbot = gr.State([])
-
-        with gr.Column(elem_id='container'):
-            gr.Markdown('## LMDeploy Playground')
-
-            chatbot = gr.Chatbot(elem_id='chatbot', label=model_name)
-            instruction_txtbox = gr.Textbox(
-                placeholder='Please input the instruction',
-                label='Instruction')
-            with gr.Row():
-                cancel_btn = gr.Button(value='Cancel', interactive=False)
-                reset_btn = gr.Button(value='Reset')
-
-        send_event = instruction_txtbox.submit(
-            chat_stream_restful,
-            [instruction_txtbox, state_chatbot, cancel_btn, reset_btn],
-            [state_chatbot, chatbot, cancel_btn, reset_btn])
-        instruction_txtbox.submit(
-            lambda: gr.Textbox.update(value=''),
-            [],
-            [instruction_txtbox],
-        )
-        cancel_btn.click(cancel_restful_func,
-                         [state_chatbot, cancel_btn, reset_btn],
-                         [state_chatbot, cancel_btn, reset_btn],
-                         cancels=[send_event])
-
-        reset_btn.click(reset_restful_func,
-                        [instruction_txtbox, state_chatbot],
-                        [state_chatbot, chatbot, instruction_txtbox],
-                        cancels=[send_event])
-
-    print(f'server is gonna mount on: http://{server_name}:{server_port}')
-    demo.queue(concurrency_count=batch_size, max_size=100,
-               api_open=True).launch(
-                   max_threads=10,
-                   share=True,
-                   server_port=server_port,
-                   server_name=server_name,
-               )
-
-
-async def chat_stream_local(
-    instruction: str,
-    state_chatbot: Sequence,
-    cancel_btn: gr.Button,
-    reset_btn: gr.Button,
-    request: gr.Request,
-):
-    """Chat with AI assistant.
-
-    Args:
-        instruction (str): user's prompt
-        state_chatbot (Sequence): the chatting history
-        request (gr.Request): the request from a user
-    """
-    session_id = threading.current_thread().ident
-    if request is not None:
-        session_id = ip2id(request.kwargs['client']['host'])
-    bot_summarized_response = ''
-    state_chatbot = state_chatbot + [(instruction, None)]
-
-    yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
-           f'{bot_summarized_response}'.strip())
-
-    async for outputs in InterFace.async_engine.generate(
-            instruction,
-            session_id,
-            stream_response=True,
-            sequence_start=(len(state_chatbot) == 1)):
-        response = outputs.response
-        if outputs.finish_reason == 'length':
-            gr.Warning('WARNING: exceed session max length.'
-                       ' Please restart the session by reset button.')
-        if outputs.generate_token_len < 0:
-            gr.Warning('WARNING: running on the old session.'
-                       ' Please restart the session by reset button.')
-        if state_chatbot[-1][-1] is None:
-            state_chatbot[-1] = (state_chatbot[-1][0], response)
-        else:
-            state_chatbot[-1] = (state_chatbot[-1][0],
-                                 state_chatbot[-1][1] + response
-                                 )  # piece by piece
-        yield (state_chatbot, state_chatbot, enable_btn, disable_btn,
-               f'{bot_summarized_response}'.strip())
-
-    yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
-           f'{bot_summarized_response}'.strip())
-
-
-async def reset_local_func(instruction_txtbox: gr.Textbox,
-                           state_chatbot: gr.State, request: gr.Request):
-    """reset the session.
-
-    Args:
-        instruction_txtbox (str): user's prompt
-        state_chatbot (Sequence): the chatting history
-        request (gr.Request): the request from a user
-    """
-    state_chatbot = []
-
-    session_id = threading.current_thread().ident
-    if request is not None:
-        session_id = ip2id(request.kwargs['client']['host'])
-    # end the session
-    async for out in InterFace.async_engine.generate('',
-                                                     session_id,
-                                                     request_output_len=1,
-                                                     stream_response=True,
-                                                     sequence_start=False,
-                                                     sequence_end=True):
-        pass
-
-    return (
-        state_chatbot,
-        state_chatbot,
-        gr.Textbox.update(value=''),
-    )
-
-
-async def cancel_local_func(state_chatbot: gr.State, cancel_btn: gr.Button,
-                            reset_btn: gr.Button, request: gr.Request):
-    """stop the session.
-
-    Args:
-        instruction_txtbox (str): user's prompt
-        state_chatbot (Sequence): the chatting history
-        request (gr.Request): the request from a user
-    """
-    session_id = threading.current_thread().ident
-    if request is not None:
-        session_id = ip2id(request.kwargs['client']['host'])
-    # end the session
-    async for out in InterFace.async_engine.generate('',
-                                                     session_id,
-                                                     request_output_len=0,
-                                                     stream_response=True,
-                                                     sequence_start=False,
-                                                     sequence_end=False,
-                                                     stop=True):
-        pass
-    messages = []
-    for qa in state_chatbot:
-        messages.append(dict(role='user', content=qa[0]))
-        if qa[1] is not None:
-            messages.append(dict(role='assistant', content=qa[1]))
-    async for out in InterFace.async_engine.generate(messages,
-                                                     session_id,
-                                                     request_output_len=0,
-                                                     stream_response=True,
-                                                     sequence_start=True,
-                                                     sequence_end=False):
-        pass
-    return (state_chatbot, disable_btn, enable_btn)
-
-
-def run_local(model_path: str,
-              server_name: str = 'localhost',
-              server_port: int = 6006,
-              batch_size: int = 4,
-              tp: int = 1):
-    """chat with AI assistant through web ui.
-
-    Args:
-        model_path (str): the path of the deployed model
-        server_name (str): the ip address of gradio server
-        server_port (int): the port of gradio server
-        batch_size (int): batch size for running Turbomind directly
-        tp (int): tensor parallel for Turbomind
-    """
-    InterFace.async_engine = AsyncEngine(model_path=model_path,
-                                         instance_num=batch_size,
-                                         tp=tp)
-
-    with gr.Blocks(css=CSS, theme=THEME) as demo:
-        state_chatbot = gr.State([])
-
-        with gr.Column(elem_id='container'):
-            gr.Markdown('## LMDeploy Playground')
-
-            chatbot = gr.Chatbot(
-                elem_id='chatbot',
-                label=InterFace.async_engine.tm_model.model_name)
-            instruction_txtbox = gr.Textbox(
-                placeholder='Please input the instruction',
-                label='Instruction')
-            with gr.Row():
-                cancel_btn = gr.Button(value='Cancel', interactive=False)
-                reset_btn = gr.Button(value='Reset')
-
-        send_event = instruction_txtbox.submit(
-            chat_stream_local,
-            [instruction_txtbox, state_chatbot, cancel_btn, reset_btn],
-            [state_chatbot, chatbot, cancel_btn, reset_btn])
-        instruction_txtbox.submit(
-            lambda: gr.Textbox.update(value=''),
-            [],
-            [instruction_txtbox],
-        )
-        cancel_btn.click(cancel_local_func,
-                         [state_chatbot, cancel_btn, reset_btn],
-                         [state_chatbot, cancel_btn, reset_btn],
-                         cancels=[send_event])
-
-        reset_btn.click(reset_local_func, [instruction_txtbox, state_chatbot],
-                        [state_chatbot, chatbot, instruction_txtbox],
-                        cancels=[send_event])
-
-    print(f'server is gonna mount on: http://{server_name}:{server_port}')
-    demo.queue(concurrency_count=batch_size, max_size=100,
-               api_open=True).launch(
-                   max_threads=10,
-                   share=True,
-                   server_port=server_port,
-                   server_name=server_name,
-               )
 
 
 def run(model_path_or_server: str,
-        server_name: str = 'localhost',
+        server_name: str = '0.0.0.0',
         server_port: int = 6006,
         batch_size: int = 32,
         tp: int = 1,
-        restful_api: bool = False):
+        **kwargs):
     """chat with AI assistant through web ui.
 
     Args:
         model_path_or_server (str): the path of the deployed model or the
-            tritonserver URL or restful api URL. The former is for directly
-            running service with gradio. The latter is for running with
-            tritonserver by default. If the input URL is restful api. Please
-            enable another flag `restful_api`.
+            tritonserver URL or restful api URL. For example:
+            - ./workspace
+            - 0.0.0.0:23333
+            - http://0.0.0.0:23333
         server_name (str): the ip address of gradio server
         server_port (int): the port of gradio server
         batch_size (int): batch size for running Turbomind directly
         tp (int): tensor parallel for Turbomind
-        restufl_api (bool): a flag for model_path_or_server
     """
     if ':' in model_path_or_server:
-        if restful_api:
-            run_restful(model_path_or_server, server_name, server_port,
-                        batch_size)
+        if 'http:' in model_path_or_server:
+            from lmdeploy.serve.gradio.api_server_backend import run_api_server
+            run_api_server(model_path_or_server, server_name, server_port,
+                           batch_size)
         else:
-            run_server(model_path_or_server, server_name, server_port)
+            from lmdeploy.serve.gradio.triton_server_backend import \
+                run_triton_server
+            run_triton_server(model_path_or_server, server_name, server_port)
     else:
+        from lmdeploy.serve.gradio.turbomind_coupled import run_local
         run_local(model_path_or_server, server_name, server_port, batch_size,
-                  tp)
+                  tp, **kwargs)
 
 
 if __name__ == '__main__':
+    import fire
+
     fire.Fire(run)
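
`run` now infers the backend from the form of `model_path_or_server` instead of a `restful_api` flag: an `http://` URL selects the api_server backend, a bare `ip:port` the triton backend, and anything else is treated as a local turbomind workspace. A hedged sketch of the three entry points (paths and addresses are placeholders, and each call blocks on the gradio app, so pick one):

    from lmdeploy.serve.gradio.app import run

    run('./workspace')              # local turbomind engine (run_local)
    run('0.0.0.0:33337')            # triton inference server backend
    run('http://0.0.0.0:23333')     # RESTful api_server backend
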
diff --git a/lmdeploy/serve/gradio/constants.py b/lmdeploy/serve/gradio/constants.py
new file mode 100644
index 0000000000..891c572e5a
--- /dev/null
+++ b/lmdeploy/serve/gradio/constants.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import gradio as gr
+
+CSS = """
+#container {
+    width: 95%;
+    margin-left: auto;
+    margin-right: auto;
+}
+
+#chatbot {
+    height: 500px;
+    overflow: auto;
+}
+
+.chat_wrap_space {
+    margin-left: 0.5em
+}
+"""
+
+THEME = gr.themes.Soft(
+    primary_hue=gr.themes.colors.blue,
+    secondary_hue=gr.themes.colors.sky,
+    font=[gr.themes.GoogleFont('Inconsolata'), 'Arial', 'sans-serif'])
+
+enable_btn = gr.Button.update(interactive=True)
+disable_btn = gr.Button.update(interactive=False)
diff --git a/lmdeploy/serve/gradio/css.py b/lmdeploy/serve/gradio/css.py
deleted file mode 100644
index b3bd233222..0000000000
--- a/lmdeploy/serve/gradio/css.py
+++ /dev/null
@@ -1,18 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-
-CSS = """
-#container {
-    width: 95%;
-    margin-left: auto;
-    margin-right: auto;
-}
-
-#chatbot {
-    height: 500px;
-    overflow: auto;
-}
-
-.chat_wrap_space {
-    margin-left: 0.5em
-}
-"""
diff --git a/lmdeploy/serve/gradio/triton_server_backend.py b/lmdeploy/serve/gradio/triton_server_backend.py
new file mode 100644
index 0000000000..9148903cc5
--- /dev/null
+++ b/lmdeploy/serve/gradio/triton_server_backend.py
@@ -0,0 +1,143 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+from functools import partial
+from threading import Lock
+from typing import Sequence
+
+import gradio as gr
+
+from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
+from lmdeploy.serve.turbomind.chatbot import Chatbot
+
+
+class InterFace:
+    global_session_id: int = 0
+    lock = Lock()
+
+
+def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
+                cancel_btn: gr.Button, reset_btn: gr.Button, session_id: int):
+    """Chat with AI assistant.
+
+    Args:
+        instruction (str): user's prompt
+        state_chatbot (Sequence): the chatting history
+        llama_chatbot (Chatbot): the instance of a chatbot
+        cancel_btn (gr.Button): the cancel button
+        reset_btn (gr.Button): the reset button
+        session_id (int): the session id
+    """
+    instruction = state_chatbot[-1][0]
+
+    bot_response = llama_chatbot.stream_infer(
+        session_id, instruction, f'{session_id}-{len(state_chatbot)}')
+
+    for status, tokens, _ in bot_response:
+        state_chatbot[-1] = (state_chatbot[-1][0], tokens)
+        yield (state_chatbot, state_chatbot, enable_btn, disable_btn)
+
+    yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
+
+
+def reset_all_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
+                   llama_chatbot: gr.State, triton_server_addr: str,
+                   model_name: str):
+    """reset the session."""
+    state_chatbot = []
+    log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
+    llama_chatbot = Chatbot(triton_server_addr,
+                            model_name,
+                            log_level=log_level,
+                            display=True)
+
+    return (
+        llama_chatbot,
+        state_chatbot,
+        state_chatbot,
+        gr.Textbox.update(value=''),
+    )
+
+
+def cancel_func(
+    state_chatbot: gr.State,
+    llama_chatbot: gr.State,
+    cancel_btn: gr.Button,
+    reset_btn: gr.Button,
+):
+    """cancel the session."""
+    yield (llama_chatbot, state_chatbot, disable_btn, disable_btn)
+    session_id = llama_chatbot._session.session_id
+    llama_chatbot.cancel(session_id)
+
+    yield (llama_chatbot, state_chatbot, disable_btn, enable_btn)
+
+
+def add_instruction(instruction, state_chatbot):
+    state_chatbot = state_chatbot + [(instruction, None)]
+    return ('', state_chatbot)
+
+
+def run_triton_server(triton_server_addr: str,
+                      server_name: str = 'localhost',
+                      server_port: int = 6006):
+    """chat with AI assistant through web ui.
+
+    Args:
+        triton_server_addr (str): the communication address of inference server
+        server_name (str): the ip address of gradio server
+        server_port (int): the port of gradio server
+    """
+    with gr.Blocks(css=CSS, theme=THEME) as demo:
+        log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
+        llama_chatbot = gr.State(
+            Chatbot(triton_server_addr, log_level=log_level, display=True))
+        state_chatbot = gr.State([])
+        state_session_id = gr.State(0)
+        model_name = llama_chatbot.value.model_name
+        reset_all = partial(reset_all_func,
+                            model_name=model_name,
+                            triton_server_addr=triton_server_addr)
+
+        with gr.Column(elem_id='container'):
+            gr.Markdown('## LMDeploy Playground')
+
+            chatbot = gr.Chatbot(elem_id='chatbot', label=model_name)
+            instruction_txtbox = gr.Textbox(
+                placeholder='Please input the instruction',
+                label='Instruction')
+            with gr.Row():
+                cancel_btn = gr.Button(value='Cancel', interactive=False)
+                reset_btn = gr.Button(value='Reset')
+
+        send_event = instruction_txtbox.submit(
+            add_instruction, [instruction_txtbox, state_chatbot],
+            [instruction_txtbox, state_chatbot]).then(chat_stream, [
+                state_chatbot, llama_chatbot, cancel_btn, reset_btn,
+                state_session_id
+            ], [state_chatbot, chatbot, cancel_btn, reset_btn])
+
+        cancel_btn.click(cancel_func,
+                         [state_chatbot, llama_chatbot, cancel_btn, reset_btn],
+                         [llama_chatbot, chatbot, cancel_btn, reset_btn],
+                         cancels=[send_event])
+
+        reset_btn.click(
+            reset_all, [instruction_txtbox, state_chatbot, llama_chatbot],
+            [llama_chatbot, state_chatbot, chatbot, instruction_txtbox],
+            cancels=[send_event])
+
+        def init():
+            with InterFace.lock:
+                InterFace.global_session_id += 1
+            new_session_id = InterFace.global_session_id
+            return new_session_id
+
+        demo.load(init, inputs=None, outputs=[state_session_id])
+
+    print(f'server is gonna mount on: http://{server_name}:{server_port}')
+    demo.queue(concurrency_count=4, max_size=100, api_open=True).launch(
+        max_threads=10,
+        share=True,
+        server_port=server_port,
+        server_name=server_name,
+    )
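
The triton backend keeps the original Chatbot-driven flow, but the buttons are now toggled while streaming and the session id comes from the shared `state_session_id` instead of the client IP. Launching it directly (the tritonserver address is a placeholder):

    from lmdeploy.serve.gradio.triton_server_backend import run_triton_server

    # serves the gradio UI on port 6006, backed by a running tritonserver
    run_triton_server('0.0.0.0:33337',
                      server_name='0.0.0.0',
                      server_port=6006)
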
diff --git a/lmdeploy/serve/gradio/turbomind_coupled.py b/lmdeploy/serve/gradio/turbomind_coupled.py
new file mode 100644
index 0000000000..dfb38bf89f
--- /dev/null
+++ b/lmdeploy/serve/gradio/turbomind_coupled.py
@@ -0,0 +1,194 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from threading import Lock
+from typing import Sequence
+
+import gradio as gr
+
+from lmdeploy.serve.async_engine import AsyncEngine
+from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
+
+
+class InterFace:
+    async_engine: AsyncEngine = None
+    global_session_id: int = 0
+    lock = Lock()
+
+
+async def chat_stream_local(
+    instruction: str,
+    state_chatbot: Sequence,
+    cancel_btn: gr.Button,
+    reset_btn: gr.Button,
+    session_id: int,
+):
+    """Chat with AI assistant.
+
+    Args:
+        instruction (str): user's prompt
+        state_chatbot (Sequence): the chatting history
+        cancel_btn (gr.Button): the cancel button
+        reset_btn (gr.Button): the reset button
+        session_id (int): the session id
+    """
+    state_chatbot = state_chatbot + [(instruction, None)]
+
+    yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
+
+    async for outputs in InterFace.async_engine.generate(
+            instruction,
+            session_id,
+            stream_response=True,
+            sequence_start=(len(state_chatbot) == 1),
+            sequence_end=False):
+        response = outputs.response
+        if outputs.finish_reason == 'length':
+            gr.Warning('WARNING: exceed session max length.'
+                       ' Please restart the session by reset button.')
+        if outputs.generate_token_len < 0:
+            gr.Warning('WARNING: running on the old session.'
+                       ' Please restart the session by reset button.')
+        if state_chatbot[-1][-1] is None:
+            state_chatbot[-1] = (state_chatbot[-1][0], response)
+        else:
+            state_chatbot[-1] = (state_chatbot[-1][0],
+                                 state_chatbot[-1][1] + response
+                                 )  # piece by piece
+        yield (state_chatbot, state_chatbot, enable_btn, disable_btn)
+
+    yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
+
+
+async def reset_local_func(instruction_txtbox: gr.Textbox,
+                           state_chatbot: Sequence, session_id: int):
+    """reset the session.
+
+    Args:
+        instruction_txtbox (str): user's prompt
+        state_chatbot (Sequence): the chatting history
+        session_id (int): the session id
+    """
+    state_chatbot = []
+    # end the session
+    async for out in InterFace.async_engine.generate('',
+                                                     session_id,
+                                                     request_output_len=1,
+                                                     stream_response=True,
+                                                     sequence_start=False,
+                                                     sequence_end=True):
+        pass
+    return (state_chatbot, state_chatbot, gr.Textbox.update(value=''))
+
+
+async def cancel_local_func(state_chatbot: Sequence, cancel_btn: gr.Button,
+                            reset_btn: gr.Button, session_id: int):
+    """stop the session.
+
+    Args:
+        state_chatbot (Sequence): the chatting history
+        cancel_btn (gr.Button): the cancel button
+        reset_btn (gr.Button): the reset button
+        session_id (int): the session id
+    """
+    yield (state_chatbot, disable_btn, enable_btn)
+    async for out in InterFace.async_engine.generate('',
+                                                     session_id,
+                                                     request_output_len=0,
+                                                     stream_response=True,
+                                                     sequence_start=False,
+                                                     sequence_end=False,
+                                                     stop=True):
+        pass
+    messages = []
+    for qa in state_chatbot:
+        messages.append(dict(role='user', content=qa[0]))
+        if qa[1] is not None:
+            messages.append(dict(role='assistant', content=qa[1]))
+    async for out in InterFace.async_engine.generate(messages,
+                                                     session_id,
+                                                     request_output_len=0,
+                                                     stream_response=True,
+                                                     sequence_start=True,
+                                                     sequence_end=False):
+        pass
+    yield (state_chatbot, disable_btn, enable_btn)
+
+
+def run_local(model_path: str,
+              server_name: str = 'localhost',
+              server_port: int = 6006,
+              batch_size: int = 4,
+              tp: int = 1,
+              **kwargs):
+    """chat with AI assistant through web ui.
+
+    Args:
+        model_path (str): the path of the deployed model
+        server_name (str): the ip address of gradio server
+        server_port (int): the port of gradio server
+        batch_size (int): batch size for running Turbomind directly
+        tp (int): tensor parallel for Turbomind
+    """
+    InterFace.async_engine = AsyncEngine(model_path=model_path,
+                                         instance_num=batch_size,
+                                         tp=tp,
+                                         **kwargs)
+
+    with gr.Blocks(css=CSS, theme=THEME) as demo:
+        state_chatbot = gr.State([])
+        state_session_id = gr.State(0)
+
+        with gr.Column(elem_id='container'):
+            gr.Markdown('## LMDeploy Playground')
+
+            chatbot = gr.Chatbot(
+                elem_id='chatbot',
+                label=InterFace.async_engine.tm_model.model_name)
+            instruction_txtbox = gr.Textbox(
+                placeholder='Please input the instruction',
+                label='Instruction')
+            with gr.Row():
+                cancel_btn = gr.Button(value='Cancel', interactive=False)
+                reset_btn = gr.Button(value='Reset')
+
+        send_event = instruction_txtbox.submit(chat_stream_local, [
+            instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
+            state_session_id
+        ], [state_chatbot, chatbot, cancel_btn, reset_btn])
+        instruction_txtbox.submit(
+            lambda: gr.Textbox.update(value=''),
+            [],
+            [instruction_txtbox],
+        )
+        cancel_btn.click(
+            cancel_local_func,
+            [state_chatbot, cancel_btn, reset_btn, state_session_id],
+            [state_chatbot, cancel_btn, reset_btn],
+            cancels=[send_event])
+
+        reset_btn.click(reset_local_func,
+                        [instruction_txtbox, state_chatbot, state_session_id],
+                        [state_chatbot, chatbot, instruction_txtbox],
+                        cancels=[send_event])
+
+        def init():
+            with InterFace.lock:
+                InterFace.global_session_id += 1
+            new_session_id = InterFace.global_session_id
+            return new_session_id
+
+        demo.load(init, inputs=None, outputs=[state_session_id])
+
+    print(f'server is gonna mount on: http://{server_name}:{server_port}')
+    demo.queue(concurrency_count=batch_size, max_size=100,
+               api_open=True).launch(
+                   max_threads=10,
+                   share=True,
+                   server_port=server_port,
+                   server_name=server_name,
+               )
+
+
+if __name__ == '__main__':
+    import fire
+    fire.Fire(run_local)
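
All three gradio backends replace the old IP-derived session id with a per-page-load id: `demo.load` calls `init()` once per browser tab, which bumps a Lock-protected class counter. The pattern in isolation (a condensed sketch; `InterFace` here is the module-level state holder each backend defines):

    from threading import Lock

    class InterFace:
        global_session_id: int = 0
        lock = Lock()

    def init() -> int:
        # wired up via `demo.load(init, inputs=None, outputs=[state_session_id])`
        with InterFace.lock:
            InterFace.global_session_id += 1
            return InterFace.global_session_id
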
diff --git a/lmdeploy/serve/openai/api_client.py b/lmdeploy/serve/openai/api_client.py
index a8718331be..a1610e05ea 100644
--- a/lmdeploy/serve/openai/api_client.py
+++ b/lmdeploy/serve/openai/api_client.py
@@ -1,8 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import json
-from typing import Iterable, List
+from typing import Any, Dict, Iterable, List, Optional, Union
 
-import fire
 import requests
 
 
@@ -15,13 +14,306 @@ def get_model_list(api_url: str):
     return None
 
 
+class APIClient:
+    """Client of api_server, the RESTful inference service of LMDeploy.
+
+    Args:
+        api_server_url (str): communicating address 'http://<ip>:<port>' of
+            api_server
+    """
+
+    def __init__(self, api_server_url: str, **kwargs):
+        self.api_server_url = api_server_url
+        self.chat_intractive_v1_url = f'{api_server_url}/v1/chat/interactive'
+        self.chat_completions_v1_url = f'{api_server_url}/v1/chat/completions'
+        self.completions_v1_url = f'{api_server_url}/v1/completions'
+        self.models_v1_url = f'{api_server_url}/v1/models'
+        self._available_models = None
+
+    @property
+    def available_models(self):
+        """Show available models."""
+        if self._available_models is not None:
+            return self._available_models
+        response = requests.get(self.models_v1_url)
+        if hasattr(response, 'text'):
+            model_list = json.loads(response.text)
+            model_list = model_list.pop('data', [])
+            self._available_models = [item['id'] for item in model_list]
+            return self._available_models
+        return None
+
+    def chat_completions_v1(self,
+                            model: str,
+                            messages: Union[str, List[Dict[str, str]]],
+                            temperature: Optional[float] = 0.7,
+                            top_p: Optional[float] = 1.0,
+                            n: Optional[int] = 1,
+                            max_tokens: Optional[int] = 512,
+                            stop: Optional[bool] = False,
+                            stream: Optional[bool] = False,
+                            presence_penalty: Optional[float] = 0.0,
+                            frequency_penalty: Optional[float] = 0.0,
+                            user: Optional[str] = None,
+                            repetition_penalty: Optional[float] = 1.0,
+                            session_id: Optional[int] = -1,
+                            ignore_eos: Optional[bool] = False,
+                            **kwargs):
+        """Chat completion v1.
+
+        Args:
+            model: model name. Available from self.available_models.
+            messages: string prompt or chat history in OpenAI format.
+            temperature (float): to modulate the next token probability
+            top_p (float): If set to float < 1, only the smallest set of most
+                probable tokens with probabilities that add up to top_p or
+                higher are kept for generation.
+            n (int): How many chat completion choices to generate for each
+                input message. Only support one here.
+            stream: whether to stream the results or not. Default to false.
+            max_tokens (int): output token nums
+            repetition_penalty (float): The parameter for repetition penalty.
+                1.0 means no penalty
+            ignore_eos (bool): indicator for ignoring eos
+            session_id (int): if not specified, a random value will be used
+
+        Yields:
+            json objects in openai formats
+        """
+        pload = {
+            k: v
+            for k, v in locals().copy().items()
+            if k[:2] != '__' and k not in ['self']
+        }
+        headers = {'content-type': 'application/json'}
+        response = requests.post(self.chat_completions_v1_url,
+                                 headers=headers,
+                                 json=pload,
+                                 stream=stream)
+        for chunk in response.iter_lines(chunk_size=8192,
+                                         decode_unicode=False,
+                                         delimiter=b'\n'):
+            if chunk:
+                if stream:
+                    decoded = chunk.decode('utf-8')
+                    if decoded == 'data: [DONE]':
+                        continue
+                    if decoded[:6] == 'data: ':
+                        decoded = decoded[6:]
+                    output = json.loads(decoded)
+                    yield output
+                else:
+                    decoded = chunk.decode('utf-8')
+                    output = json.loads(decoded)
+                    yield output
+
+    def chat_interactive_v1(self,
+                            prompt: Union[str, List[Dict[str, str]]],
+                            session_id: int = -1,
+                            interactive_mode: bool = False,
+                            stream: bool = False,
+                            stop: bool = False,
+                            request_output_len: int = 512,
+                            top_p: float = 0.8,
+                            top_k: int = 40,
+                            temperature: float = 0.8,
+                            repetition_penalty: float = 1.0,
+                            ignore_eos: bool = False,
+                            **kwargs):
+        """Interactive completions.
+
+        - In interactive mode, the chat history is kept on the server. Please
+        set `interactive_mode = True`.
+        - In normal mode, no chat history is kept on the server. Set
+        `interactive_mode = False`.
+
+        Args:
+            prompt: the prompt to use for the generation.
+            session_id: determine which instance will be called.
+                If not specified with a value other than -1, a random value
+                will be used.
+            interactive_mode (bool): turn on interactive mode or not. In
+                interactive mode, the session history is kept on the server;
+                otherwise it is not.
+            stream: whether to stream the results or not.
+            stop: whether to stop the session response or not.
+            request_output_len (int): output token nums
+            top_p (float): If set to float < 1, only the smallest set of most
+                probable tokens with probabilities that add up to top_p or
+                higher are kept for generation.
+            top_k (int): The number of the highest probability vocabulary
+                tokens to keep for top-k-filtering
+            temperature (float): to modulate the next token probability
+            repetition_penalty (float): The parameter for repetition penalty.
+                1.0 means no penalty
+            ignore_eos (bool): indicator for ignoring eos
+
+        Yields:
+            json objects consist of text, tokens, finish_reason
+        """
+        pload = {
+            k: v
+            for k, v in locals().copy().items()
+            if k[:2] != '__' and k not in ['self']
+        }
+        headers = {'content-type': 'application/json'}
+        response = requests.post(self.chat_intractive_v1_url,
+                                 headers=headers,
+                                 json=pload,
+                                 stream=stream)
+        for chunk in response.iter_lines(chunk_size=8192,
+                                         decode_unicode=False,
+                                         delimiter=b'\n'):
+            if chunk:
+                decoded = chunk.decode('utf-8')
+                output = json.loads(decoded)
+                yield output
+
+    def completions_v1(
+            self,
+            model: str,
+            prompt: Union[str, List[Any]],
+            suffix: Optional[str] = None,
+            temperature: Optional[float] = 0.7,
+            n: Optional[int] = 1,
+            max_tokens: Optional[int] = 16,
+            stream: Optional[bool] = False,
+            top_p: Optional[float] = 1.0,
+            user: Optional[str] = None,
+            # additional argument of lmdeploy
+            repetition_penalty: Optional[float] = 1.0,
+            session_id: Optional[int] = -1,
+            ignore_eos: Optional[bool] = False,
+            **kwargs):
+        """Completion v1.
+
+        Args:
+            model (str): model name. Available from /v1/models.
+            prompt (str): the input prompt.
+            suffix (str): The suffix that comes after a completion of inserted
+                text.
+            max_tokens (int): output token nums
+            temperature (float): to modulate the next token probability
+            top_p (float): If set to float < 1, only the smallest set of most
+                probable tokens with probabilities that add up to top_p or
+                higher are kept for generation.
+            n (int): How many chat completion choices to generate for each
+                input message. Only support one here.
+            stream: whether to stream the results or not. Default to false.
+            repetition_penalty (float): The parameter for repetition penalty.
+                1.0 means no penalty
+            user (str): A unique identifier representing your end-user.
+            ignore_eos (bool): indicator for ignoring eos
+            session_id (int): if not specified, a random value will be used
+
+        Yields:
+            json objects in openai formats
+        """
+        pload = {
+            k: v
+            for k, v in locals().copy().items()
+            if k[:2] != '__' and k not in ['self']
+        }
+        headers = {'content-type': 'application/json'}
+        response = requests.post(self.completions_v1_url,
+                                 headers=headers,
+                                 json=pload,
+                                 stream=stream)
+        for chunk in response.iter_lines(chunk_size=8192,
+                                         decode_unicode=False,
+                                         delimiter=b'\n'):
+            if chunk:
+                if stream:
+                    decoded = chunk.decode('utf-8')
+                    if decoded == 'data: [DONE]':
+                        continue
+                    if decoded[:6] == 'data: ':
+                        decoded = decoded[6:]
+                    output = json.loads(decoded)
+                    yield output
+                else:
+                    decoded = chunk.decode('utf-8')
+                    output = json.loads(decoded)
+                    yield output
+
+    def chat(self,
+             prompt: str,
+             session_id: int,
+             request_output_len: int = 512,
+             stream: bool = False,
+             top_p: float = 0.8,
+             top_k: int = 40,
+             temperature: float = 0.8,
+             repetition_penalty: float = 1.0,
+             ignore_eos: bool = False):
+        """Chat with a unique session_id.
+
+        Args:
+            prompt: the prompt to use for the generation.
+            session_id: determine which instance will be called.
+                If not specified with a value other than -1, a random value
+                will be used.
+            stream: whether to stream the results or not.
+            request_output_len (int): output token nums
+            top_p (float): If set to float < 1, only the smallest set of most
+                probable tokens with probabilities that add up to top_p or
+                higher are kept for generation.
+            top_k (int): The number of the highest probability vocabulary
+                tokens to keep for top-k-filtering
+            temperature (float): to modulate the next token probability
+            repetition_penalty (float): The parameter for repetition penalty.
+                1.0 means no penalty
+            ignore_eos (bool): indicator for ignoring eos
+
+        Yields:
+            text, tokens, finish_reason
+        """
+        assert session_id != -1, 'please set a value other than -1'
+        for outputs in self.chat_interactive_v1(
+                prompt,
+                session_id=session_id,
+                request_output_len=request_output_len,
+                interactive_mode=True,
+                stream=stream,
+                top_k=top_k,
+                top_p=top_p,
+                temperature=temperature,
+                repetition_penalty=repetition_penalty,
+                ignore_eos=ignore_eos):
+            if outputs['finish_reason'] == 'length':
+                print('WARNING: exceeded session max length.'
+                      ' Please end the session.')
+            yield outputs['text'], outputs['tokens'], outputs['finish_reason']
+
+    def end_session(self, session_id: int):
+        """End the session with a unique session_id.
+
+        Args:
+            session_id: determine which instance will be called.
+                It should be the id of a running session.
+        """
+        for out in self.chat_interactive_v1(prompt='',
+                                            session_id=session_id,
+                                            request_output_len=0,
+                                            interactive_mode=False):
+            pass
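For reference, a minimal usage sketch of the client class above (the import path, server URL and session id are assumptions; it presumes an api_server is already running):

    from lmdeploy.serve.openai.api_client import APIClient

    api_client = APIClient('http://0.0.0.0:23333')
    # stream one interactive chat turn on session 1
    for text, tokens, finish_reason in api_client.chat('hi, how are you?',
                                                       session_id=1,
                                                       stream=True):
        print(text, end='')
    # free the server-side history of this session when done
    api_client.end_session(1)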
+
+
+def input_prompt():
+    """Input a prompt in the consolo interface."""
+    print('\ndouble enter to end input >>> ', end='')
+    sentinel = ''  # ends when this string is seen
+    return '\n'.join(iter(input, sentinel))
+
+
 def get_streaming_response(prompt: str,
                            api_url: str,
                            session_id: int,
                            request_output_len: int = 512,
                            stream: bool = True,
-                           sequence_start: bool = True,
-                           sequence_end: bool = True,
+                           interactive_mode: bool = False,
                            ignore_eos: bool = False,
                            stop: bool = False) -> Iterable[List[str]]:
     headers = {'User-Agent': 'Test Client'}
@@ -30,8 +322,7 @@ def get_streaming_response(prompt: str,
         'stream': stream,
         'session_id': session_id,
         'request_output_len': request_output_len,
-        'sequence_start': sequence_start,
-        'sequence_end': sequence_end,
+        'interactive_mode': interactive_mode,
         'ignore_eos': ignore_eos,
         'stop': stop
     }
@@ -50,43 +341,26 @@ def get_streaming_response(prompt: str,
             yield output, tokens, finish_reason
 
 
-def input_prompt():
-    """Input a prompt in the consolo interface."""
-    print('\ndouble enter to end input >>> ', end='')
-    sentinel = ''  # ends when this string is seen
-    return '\n'.join(iter(input, sentinel))
-
-
-def main(restful_api_url: str, session_id: int = 0):
-    nth_round = 1
+def main(api_server_url: str, session_id: int = 0):
+    api_client = APIClient(api_server_url)
     while True:
         prompt = input_prompt()
-        if prompt == 'exit':
-            for output, tokens, finish_reason in get_streaming_response(
-                    '',
-                    f'{restful_api_url}/generate',
-                    session_id=session_id,
-                    request_output_len=0,
-                    sequence_start=(nth_round == 1),
-                    sequence_end=True):
-                pass
-            exit(0)
+        if prompt in ['exit', 'end']:
+            api_client.end_session(session_id)
+            if prompt == 'exit':
+                exit(0)
         else:
-            for output, tokens, finish_reason in get_streaming_response(
+            for text, tokens, finish_reason in api_client.chat(
                     prompt,
-                    f'{restful_api_url}/generate',
                     session_id=session_id,
                     request_output_len=512,
-                    sequence_start=(nth_round == 1),
-                    sequence_end=False):
+                    stream=True):
                 if finish_reason == 'length':
-                    print('WARNING: exceed session max length.'
-                          ' Please end the session.')
                     continue
-                print(output, end='')
-
-            nth_round += 1
+                print(text, end='')
 
 
 if __name__ == '__main__':
+    import fire
+
     fire.Fire(main)
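For clarity, the raw request that `get_streaming_response` now issues can be sketched as follows (URL and prompt are placeholders; the field names follow the updated GenerateRequest schema, and each streamed line is parsed the same way as above):

    import json

    import requests

    api_url = 'http://0.0.0.0:23333/v1/chat/interactive'
    pload = {
        'prompt': 'hi',
        'session_id': 1,
        'interactive_mode': True,   # keep the chat history on the server
        'stream': True,
        'request_output_len': 512,
        'ignore_eos': False,
        'stop': False,
    }
    response = requests.post(api_url, json=pload, stream=True)
    for chunk in response.iter_lines(chunk_size=8192, delimiter=b'\n'):
        if chunk:
            data = json.loads(chunk.decode('utf-8'))
            print(data['text'], end='')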
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index 94271c4b9b..0b61f7967b 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -1,10 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import asyncio
 import os
+import random
 import time
 from http import HTTPStatus
 from typing import AsyncGenerator, List, Optional
 
-import fire
 import uvicorn
 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
@@ -14,12 +15,12 @@
 from lmdeploy.serve.openai.protocol import (  # noqa: E501
     ChatCompletionRequest, ChatCompletionResponse,
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
-    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingsRequest,
-    EmbeddingsResponse, ErrorResponse, GenerateRequest, GenerateResponse,
+    ChatCompletionStreamResponse, ChatMessage, CompletionRequest,
+    CompletionResponse, CompletionResponseChoice,
+    CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage,
+    EmbeddingsRequest, ErrorResponse, GenerateRequest, GenerateResponse,
     ModelCard, ModelList, ModelPermission, UsageInfo)
 
-os.environ['TM_LOG_LEVEL'] = 'ERROR'
-
 
 class VariableInterface:
     """A IO interface maintaining variables."""
@@ -105,9 +106,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
         1.0 means no penalty
 
     Additional arguments supported by LMDeploy:
-    - renew_session (bool): Whether renew the session. Can be used when the
-        session length is exceeded.
     - ignore_eos (bool): indicator for ignoring eos
+    - session_id (int): if not specified, a random value will be used
 
     Currently we do not support the following features:
     - function_call (Users should implement this by themselves)
@@ -115,20 +115,22 @@ async def chat_completions_v1(request: ChatCompletionRequest,
     - presence_penalty (replaced with repetition_penalty)
     - frequency_penalty (replaced with repetition_penalty)
     """
-    session_id = ip2id(raw_request.client.host)
+    if request.session_id == -1:
+        request.session_id = random.randint(1, 10086)
     error_check_ret = await check_request(request)
     if error_check_ret is not None:
         return error_check_ret
 
     model_name = request.model
-    request_id = str(session_id)
+    request_id = str(request.session_id)
     created_time = int(time.time())
 
-    result_generator = VariableInterface.async_engine.generate_openai(
+    result_generator = VariableInterface.async_engine.generate(
         request.messages,
-        session_id,
+        request.session_id,
         True,  # always use stream to enable batching
-        request.renew_session,
+        sequence_start=True,
+        sequence_end=True,
         request_output_len=request.max_tokens if request.max_tokens else 512,
         stop=request.stop,
         top_p=request.top_p,
@@ -189,7 +191,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
     async for res in result_generator:
         if await raw_request.is_disconnected():
             # Abort the request if the client disconnects.
-            VariableInterface.async_engine.stop_session(session_id)
+            VariableInterface.async_engine.stop_session(request.session_id)
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          'Client disconnected')
         final_res = res
@@ -223,43 +225,191 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
     return response
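To illustrate the new `session_id` argument, a minimal request against this endpoint might look like the sketch below (URL and model name are placeholders; omitting `session_id`, or passing -1, falls back to the random value chosen above):

    import requests

    url = 'http://0.0.0.0:23333/v1/chat/completions'
    body = {
        'model': 'internlm-chat-7b',  # pick a real name from /v1/models
        'messages': [{'role': 'user', 'content': 'hi'}],
        'max_tokens': 64,
        'session_id': 1,
    }
    print(requests.post(url, json=body).json())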
 
 
-@app.post('/v1/embeddings')
-async def create_embeddings(request: EmbeddingsRequest,
-                            raw_request: Request = None):
-    """Creates embeddings for the text."""
+@app.post('/v1/completions')
+async def completions_v1(request: CompletionRequest,
+                         raw_request: Request = None):
+    """Completion API similar to OpenAI's API.
+
+    Go to `https://platform.openai.com/docs/api-reference/completions/create`
+    for the API specification.
+
+    The request should be a JSON object with the following fields:
+    - model (str): model name. Available from /v1/models.
+    - prompt (str): the input prompt.
+    - suffix (str): The suffix that comes after a completion of inserted text.
+    - max_tokens (int): output token nums
+    - temperature (float): to modulate the next token probability
+    - top_p (float): If set to float < 1, only the smallest set of most
+        probable tokens with probabilities that add up to top_p or higher
+        are kept for generation.
+    - n (int): How many chat completion choices to generate for each input
+        message. Only one is supported here.
+    - stream: whether to stream the results or not. Defaults to false.
+    - repetition_penalty (float): The parameter for repetition penalty.
+        1.0 means no penalty
+    - user (str): A unique identifier representing your end-user.
+
+    Additional arguments supported by LMDeploy:
+    - ignore_eos (bool): indicator for ignoring eos
+    - session_id (int): if not specified, a random value will be used
+
+    Currently we do not support the following features:
+    - logprobs (not supported yet)
+    - presence_penalty (replaced with repetition_penalty)
+    - frequency_penalty (replaced with repetition_penalty)
+    """
+    if request.session_id == -1:
+        request.session_id = random.randint(1, 10086)
     error_check_ret = await check_request(request)
     if error_check_ret is not None:
         return error_check_ret
 
-    embedding = await VariableInterface.async_engine.get_embeddings(
-        request.input)
-    data = [{'object': 'embedding', 'embedding': embedding, 'index': 0}]
-    token_num = len(embedding)
-    return EmbeddingsResponse(
-        data=data,
-        model=request.model,
-        usage=UsageInfo(
-            prompt_tokens=token_num,
-            total_tokens=token_num,
-            completion_tokens=None,
-        ),
-    ).dict(exclude_none=True)
-
-
-@app.post('/generate')
-async def generate(request: GenerateRequest, raw_request: Request = None):
+    model_name = request.model
+    request_id = str(request.session_id)
+    created_time = int(time.time())
+    if isinstance(request.prompt, str):
+        request.prompt = [request.prompt]
+    generators = []
+    for i in range(len(request.prompt)):
+        result_generator = VariableInterface.async_engine.generate(
+            request.prompt[i],
+            request.session_id + i,
+            True,  # always use stream to enable batching
+            sequence_start=True,
+            sequence_end=True,
+            request_output_len=request.max_tokens
+            if request.max_tokens else 512,
+            stop=False,
+            top_p=request.top_p,
+            temperature=request.temperature,
+            repetition_penalty=request.repetition_penalty,
+            ignore_eos=request.ignore_eos,
+            do_preprocess=False)
+        generators.append(result_generator)
+
+    def create_stream_response_json(
+        index: int,
+        text: str,
+        finish_reason: Optional[str] = None,
+    ) -> str:
+        choice_data = CompletionResponseStreamChoice(
+            index=index,
+            text=text,
+            finish_reason=finish_reason,
+        )
+        response = CompletionStreamResponse(
+            id=request_id,
+            created=created_time,
+            model=model_name,
+            choices=[choice_data],
+        )
+        response_json = response.model_dump_json()
+
+        return response_json
+
+    async def completion_stream_generator() -> AsyncGenerator[str, None]:
+        # first, send an empty chunk for each choice
+        for generator in generators:
+            for i in range(request.n):
+                choice_data = CompletionResponseStreamChoice(
+                    index=i,
+                    text='',
+                    finish_reason=None,
+                )
+                chunk = CompletionStreamResponse(id=request_id,
+                                                 choices=[choice_data],
+                                                 model=model_name)
+                data = chunk.model_dump_json(exclude_unset=True)
+                yield f'data: {data}\n\n'
+
+            async for res in generator:
+                response_json = create_stream_response_json(
+                    index=0,
+                    text=res.response,
+                )
+                yield f'data: {response_json}\n\n'
+        yield 'data: [DONE]\n\n'
+
+    # Streaming response
+    if request.stream:
+        return StreamingResponse(completion_stream_generator(),
+                                 media_type='text/event-stream')
+
+    # Non-streaming response
+    usage = UsageInfo()
+    choices = []
+
+    async def _inner_call(i, generator):
+        final_res = None
+        text = ''
+        async for res in generator:
+            if await raw_request.is_disconnected():
+                # Abort the request if the client disconnects.
+                VariableInterface.async_engine.stop_session(request.session_id)
+                return create_error_response(HTTPStatus.BAD_REQUEST,
+                                             'Client disconnected')
+            final_res = res
+            text += res.response
+        assert final_res is not None
+        choice_data = CompletionResponseChoice(
+            index=0,
+            text=text,
+            finish_reason=final_res.finish_reason,
+        )
+        choices.append(choice_data)
+
+        total_tokens = sum([
+            final_res.history_token_len, final_res.input_token_len,
+            final_res.generate_token_len
+        ])
+        usage.prompt_tokens += final_res.input_token_len
+        usage.completion_tokens += final_res.generate_token_len
+        usage.total_tokens += total_tokens
+
+    await asyncio.gather(
+        *[_inner_call(i, generators[i]) for i in range(len(generators))])
+
+    response = CompletionResponse(
+        id=request_id,
+        created=created_time,
+        model=model_name,
+        choices=choices,
+        usage=usage,
+    )
+
+    return response
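As a quick sanity check of the endpoint above, a non-streaming request can be sketched as follows (URL and model name are placeholders; `prompt` may also be a list of strings, in which case each prompt gets its own generator as shown above):

    import requests

    url = 'http://0.0.0.0:23333/v1/completions'
    body = {
        'model': 'internlm-chat-7b',  # pick a real name from /v1/models
        'prompt': 'A fun fact about llamas is',
        'max_tokens': 64,
        'stream': False,
    }
    print(requests.post(url, json=body).json())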
+
+
+@app.post('/v1/embeddings', tags=['unsupported'])
+async def create_embeddings(request: EmbeddingsRequest,
+                            raw_request: Request = None):
+    """Creates embeddings for the text."""
+    return create_error_response(HTTPStatus.BAD_REQUEST,
+                                 'Unsupported by turbomind.')
+
+
+@app.post('/generate',
+          tags=['deprecated'],
+          description='please use /v1/chat/interactive')
+@app.post('/v1/chat/interactive')
+async def chat_interactive_v1(request: GenerateRequest,
+                              raw_request: Request = None):
     """Generate completion for the request.
 
+    - In interactive mode, the chat history is kept on the server. Please set
+    `interactive_mode = True`.
+    - In normal mode, no chat history is kept on the server. Set
+    `interactive_mode = False`.
+
     The request should be a JSON object with the following fields:
     - prompt: the prompt to use for the generation.
     - session_id: determine which instance will be called. If not specified
-        with a value other than -1, using host ip directly.
-    - sequence_start (bool): indicator for starting a sequence.
-    - sequence_end (bool): indicator for ending a sequence
+        with a value other than -1, a random value will be used.
+    - interactive_mode (bool): turn on interactive mode or not. In interactive
+        mode, session history is kept on the server (and vice versa).
     - stream: whether to stream the results or not.
     - stop: whether to stop the session response or not.
     - request_output_len (int): output token nums
-    - step (int): the offset of the k/v cache
     - top_p (float): If set to float < 1, only the smallest set of most
         probable tokens with probabilities that add up to top_p or higher
         are kept for generation.
@@ -271,15 +421,18 @@ async def generate(request: GenerateRequest, raw_request: Request = None):
     - ignore_eos (bool): indicator for ignoring eos
     """
     if request.session_id == -1:
-        session_id = ip2id(raw_request.client.host)
-        request.session_id = session_id
+        request.session_id = random.randint(10087, 23333)
+
+    async_engine = VariableInterface.async_engine
+    sequence_start = async_engine.steps.get(str(request.session_id), 0) == 0
+    sequence_end = not request.interactive_mode
 
-    generation = VariableInterface.async_engine.generate(
+    generation = async_engine.generate(
         request.prompt,
         request.session_id,
         stream_response=True,  # always use stream to enable batching
-        sequence_start=request.sequence_start,
-        sequence_end=request.sequence_end,
+        sequence_start=sequence_start,
+        sequence_end=sequence_end,
         request_output_len=request.request_output_len,
         top_p=request.top_p,
         top_k=request.top_k,
@@ -308,7 +461,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
         async for out in generation:
             if await raw_request.is_disconnected():
                 # Abort the request if the client disconnects.
-                VariableInterface.async_engine.stop_session(session_id)
+                async_engine.stop_session(request.session_id)
                 return create_error_response(HTTPStatus.BAD_REQUEST,
                                              'Client disconnected')
             text += out.response
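To make the flag handling above concrete, here is a small standalone sketch of how `interactive_mode` and the engine's per-session step counter translate into the sequence flags (the helper name is hypothetical; only the mapping mirrors the code above):

    def resolve_sequence_flags(interactive_mode: bool, cached_step: int):
        # a new sequence starts whenever nothing is cached for the session yet
        sequence_start = cached_step == 0
        # non-interactive requests always release the session afterwards
        sequence_end = not interactive_mode
        return sequence_start, sequence_end

    print(resolve_sequence_flags(True, 0))    # (True, False): start, keep history
    print(resolve_sequence_flags(True, 384))  # (False, False): continue the session
    print(resolve_sequence_flags(False, 0))   # (True, True): stateless request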
@@ -319,14 +472,16 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
 
 
 def main(model_path: str,
-         server_name: str = 'localhost',
+         server_name: str = '0.0.0.0',
          server_port: int = 23333,
-         instance_num: int = 32,
+         instance_num: int = 64,
          tp: int = 1,
          allow_origins: List[str] = ['*'],
          allow_credentials: bool = True,
          allow_methods: List[str] = ['*'],
-         allow_headers: List[str] = ['*']):
+         allow_headers: List[str] = ['*'],
+         log_level: str = 'ERROR',
+         **kwargs):
     """An example to perform model inference through the command line
     interface.
 
@@ -340,7 +495,10 @@ def main(model_path: str,
         allow_credentials (bool): whether to allow credentials for CORS
         allow_methods (List[str]): a list of allowed HTTP methods for CORS
         allow_headers (List[str]): a list of allowed HTTP headers for CORS
-    """
+        log_level (str): set the log level, whose value is among [CRITICAL, ERROR, WARNING, INFO, DEBUG]
+    """ # noqa E501
+    os.environ['TM_LOG_LEVEL'] = log_level
+
     if allow_origins:
         app.add_middleware(
             CORSMiddleware,
@@ -352,9 +510,15 @@ def main(model_path: str,
 
     VariableInterface.async_engine = AsyncEngine(model_path=model_path,
                                                  instance_num=instance_num,
-                                                 tp=tp)
+                                                 tp=tp,
+                                                 **kwargs)
+    for i in range(3):
+        print(f'HINT:    Please open \033[93m\033[1mhttp://{server_name}:'
+              f'{server_port}\033[0m in a browser for detailed api usage!!!')
     uvicorn.run(app=app, host=server_name, port=server_port, log_level='info')
 
 
 if __name__ == '__main__':
+    import fire
+
     fire.Fire(main)
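For completeness, a programmatic launch equivalent to the command-line entry point above could look like this sketch (the model path is a placeholder; extra engine options would travel through **kwargs):

    from lmdeploy.serve.openai.api_server import main

    # './workspace' stands for a converted turbomind model directory
    main('./workspace',
         server_name='0.0.0.0',
         server_port=23333,
         instance_num=64,
         tp=1,
         log_level='INFO')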
diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py
index 756af1a4ca..bee2e2c91c 100644
--- a/lmdeploy/serve/openai/protocol.py
+++ b/lmdeploy/serve/openai/protocol.py
@@ -70,7 +70,7 @@ class ChatCompletionRequest(BaseModel):
     user: Optional[str] = None
     # additional argument of lmdeploy
     repetition_penalty: Optional[float] = 1.0
-    renew_session: Optional[bool] = False
+    session_id: Optional[int] = -1
     ignore_eos: Optional[bool] = False
 
 
@@ -135,6 +135,10 @@ class CompletionRequest(BaseModel):
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     user: Optional[str] = None
+    # additional argument of lmdeploy
+    repetition_penalty: Optional[float] = 1.0
+    session_id: Optional[int] = -1
+    ignore_eos: Optional[bool] = False
 
 
 class CompletionResponseChoice(BaseModel):
@@ -175,7 +179,7 @@ class CompletionStreamResponse(BaseModel):
 class EmbeddingsRequest(BaseModel):
     """Embedding request."""
     model: str = None
-    input: Union[str, List[Any]]
+    input: Union[str, List[str]]
     user: Optional[str] = None
 
 
@@ -191,8 +195,7 @@ class GenerateRequest(BaseModel):
     """Generate request."""
     prompt: Union[str, List[Dict[str, str]]]
     session_id: int = -1
-    sequence_start: bool = True
-    sequence_end: bool = False
+    interactive_mode: bool = False
     stream: bool = False
     stop: bool = False
     request_output_len: int = 512
diff --git a/lmdeploy/serve/turbomind/chatbot.py b/lmdeploy/serve/turbomind/chatbot.py
index cc12fcff3b..e13aa9e4d4 100644
--- a/lmdeploy/serve/turbomind/chatbot.py
+++ b/lmdeploy/serve/turbomind/chatbot.py
@@ -67,7 +67,6 @@ class Chatbot:
         model_name (str): name of the to-be-deployed mode
         log_level (int): the level of the log
         display (bool): display the generated text on consolo or not
-        profile_generation (bool): profile token generation or not
     """
 
     def __init__(self,
@@ -76,8 +75,6 @@ def __init__(self,
                  ignore_eos: bool = False,
                  log_level: int = logging.INFO,
                  display: bool = False,
-                 profile_generation: bool = False,
-                 profile_serving: bool = False,
                  **model_kwargs):
         self.tritonserver_addr = tritonserver_addr
         self.model_name = model_name
@@ -97,6 +94,7 @@ def __init__(self,
         if ignore_eos:
             stop_words = None
             bad_words = np.array([[[self.eos_id], [1]]], dtype=np.int32)
+            self.eos_id = -1
         self.cfg = mmengine.Config(
             dict(session_len=self.model.session_len,
                  top_p=self.model.top_p,
@@ -107,8 +105,6 @@ def __init__(self,
                  bad_words=bad_words))
         self.log_level = log_level
         self.display = display
-        self.profile_generation = profile_generation
-        self.profile_serving = profile_serving
 
     def stream_infer(self,
                      session_id: int,
@@ -416,8 +412,6 @@ def _stop_words(self, stop_words: List[str]):
     def _get_prompt(self, prompt: str, sequence_start: bool):
         """return the concatenated prompt according to the model's chat
         template."""
-        if self.profile_generation or self.profile_serving:
-            return prompt
         return self.model.get_prompt(prompt, sequence_start)
 
     def _stream_infer(self,
@@ -459,10 +453,16 @@ def _stream_infer(self,
             session.sequence_length = 0
 
         input_ids, input_lengths = self.preprocess(prompt)
+        # got input_ids with default add_bos == True
+        if not sequence_start and input_ids[0][0] == self.bos_id:
+            input_ids = input_ids[:, 1:]
+            input_lengths = input_lengths - 1
+        # will crash if last_token_id == eos_id and send empty input_ids
+        if sequence_end and request_output_len == 0:
+            input_ids = np.array([[1]], dtype=np.uint32)
+            input_lengths = np.array([[1]], dtype=np.uint32)
         input_tokens = input_lengths.squeeze()
-        if self.profile_generation:
-            yield StatusCode.TRITON_STREAM_ING, \
-                  'ignore preprocessing during profiling generation', 0
+
         if request_output_len is None:
             request_output_len = max(
                 128,
@@ -498,8 +498,7 @@ def _stream_infer(self,
         producer.start()
         for status, res, n_token in self.stream_consumer(
                 self.postprocess, que, session, input_tokens, preseq_length,
-                cancel, logger, self.display, self.profile_generation,
-                self.eos_id):
+                cancel, logger, self.display, self.eos_id):
             yield status, res, n_token
 
         producer.join()
@@ -592,8 +591,7 @@ def _stream_producer(tritonserver_addr, session, que, cfg, input_ids,
 
     @staticmethod
     def stream_consumer(postprocess, res_queue, session, n_input_token,
-                        preseq_length, cancel, logger, display,
-                        profile_generation, eos_id):
+                        preseq_length, cancel, logger, display, eos_id):
         """Consume the response from the triton inference server.
 
         Args:
@@ -606,7 +604,6 @@ def stream_consumer(postprocess, res_queue, session, n_input_token,
             cancel (bool): indicator for cancelling the session
             logger (util.Logger):
             display (bool): display the text in the consolo interface or not
-            profile_generation (bool): indicator for profiling token generation
             eos_id (int): eos token id
 
         Yields:
@@ -650,15 +647,15 @@ def stream_consumer(postprocess, res_queue, session, n_input_token,
                     session.sequence_length = session.sequence_length - 1
                     output_ids = output_ids[:, :, :-1]
 
-                if profile_generation:
-                    yield (StatusCode.TRITON_STREAM_ING,
-                           'postprocessing is ignored during profiling '
-                           'token generation', output_ids.shape[-1])
-                    continue
                 output_str = postprocess(
                     output_ids, np.array([[n_token]], dtype=np.uint32))
-                n_token = output_ids.shape[-1]
                 text = output_str[0].decode()
+                # a trailing replacement char means a potentially unfinished
+                # utf-8 byte sequence; skip this round so it gets concatenated
+                # with the next chunk and decoded together
+                if text.endswith('�'):
+                    continue
+                n_token = output_ids.shape[-1]
                 if display:
                     print(text, end='', flush=True)
                 session.response += text
@@ -668,7 +665,10 @@ def stream_consumer(postprocess, res_queue, session, n_input_token,
                 logger.error(f'catch exception: {e}')
                 logger.error(
                     f'session {session.session_id}: prompt: {session.prompt}')
-
+        # `n_token` might not be updated because of `if text.endswith('�')`
+        if n_token != output_ids.shape[-1]:
+            n_token = output_ids.shape[-1]
+            session.response += text
         # put session back to queue so that `_stream_infer` can update it in
         # `self.sessions`
         while not res_queue.empty():
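The trailing-'�' check above exists because a multi-byte UTF-8 character can be split across two decoding steps; a tiny illustration of the failure mode, independent of the tokenizer:

    # '你' occupies 3 bytes in utf-8; decoding a truncated prefix leaves a
    # replacement character at the end, which signals 'wait for more bytes'
    data = '你好'.encode('utf-8')
    partial = data[:2].decode('utf-8', errors='replace')
    print(partial.endswith('�'))   # True -> the loop above would `continue`
    print(data.decode('utf-8'))    # '你好' once the full sequence arrives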
diff --git a/lmdeploy/serve/turbomind/deploy.py b/lmdeploy/serve/turbomind/deploy.py
deleted file mode 100644
index cc8db88f5c..0000000000
--- a/lmdeploy/serve/turbomind/deploy.py
+++ /dev/null
@@ -1,1046 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import configparser
-import json
-import os
-import os.path as osp
-import re
-import shutil
-import sys
-from pathlib import Path
-
-import fire
-import safetensors
-import torch
-from safetensors.torch import load_file
-from sentencepiece import SentencePieceProcessor
-
-import lmdeploy
-from lmdeploy.model import MODELS
-
-supported_formats = ['llama', 'hf', 'awq', 'qwen']
-
-
-def get_package_root_path():
-    import lmdeploy
-    return Path(lmdeploy.__file__).parent
-
-
-def create_workspace(_path: str):
-    """Create a workspace.
-
-    Args:
-        _path (str): the path of the workspace
-    Returns:
-        bool: success or not
-    """
-    try:
-        if osp.exists(_path):
-            shutil.rmtree(_path)
-        os.makedirs(_path)
-        print(f'create workspace in directory {_path}')
-        return True
-    except Exception as e:
-        print(f'create workspace in {_path} failed: {e}')
-        return False
-
-
-def destroy_workspace(_path: str):
-    """destroy workspace.
-
-    Args:
-        _path(str): the path of the workspace
-    Returns:
-        bool: success or not
-    """
-    try:
-        shutil.rmtree(_path)
-        print(f'destroy workspace in directory {_path}')
-        return True
-    except Exception as e:
-        print(f'destroy workspace in {_path} failed: {e}')
-        return False
-
-
-def copy_triton_model_templates(_path: str):
-    """copy triton model templates to the specified path.
-
-    Args:
-        _path (str): the target path
-    Returns:
-        str: the path of the triton models
-    """
-    try:
-        cur_path = osp.abspath(__file__)
-        dir_path = osp.dirname(cur_path)
-        triton_models_path = osp.join(dir_path, 'triton_models')
-        dst_path = osp.join(_path, 'triton_models')
-        shutil.copytree(triton_models_path, dst_path, symlinks=True)
-        print(f'copy triton model templates from "{triton_models_path}" to '
-              f'"{dst_path}" successfully')
-        shutil.copy(osp.join(dir_path, 'service_docker_up.sh'), _path)
-        return dst_path
-    except Exception as e:
-        print(f'copy triton model templates from "{triton_models_path}"'
-              f' to "{dst_path}" failed: {e}')
-        return None
-
-
-def tokenizer_info_sp(model_path: str):
-    """Return the vocabulary size, bos token id and eos token id.
-
-    Args:
-        model_path (str): the tokenizer model's path
-    Returns:
-        tuple: vocabulary size, bos token id and eos token id
-    """
-    assert os.path.isfile(model_path), model_path
-    sp_model = SentencePieceProcessor(model_file=model_path)
-    # BOS / EOS token IDs
-    n_words = sp_model.vocab_size()
-    bos_id = sp_model.bos_id()
-    eos_id = sp_model.eos_id()
-    return n_words, bos_id, eos_id
-
-
-def tokenizer_info_qwen(model_dir: str):
-    n_words = 151851
-    bos_id = 0
-    eos_id = 151643
-    return n_words, bos_id, eos_id
-
-
-def load_checkpoint(model_path):
-    """Load checkpoint files into torch format.
-
-    Args:
-        model_path (str): the checkpoint folder
-    Returns:
-        Dict[str, torch.Tensor]: weight in torch format
-    """
-    suffixes = ['.safetensors', '.bin']
-    for suffix in suffixes:
-        files = [
-            file for file in os.listdir(model_path) if file.endswith(suffix)
-        ]
-        if len(files) > 0:
-            break
-
-    assert len(files) > 0, f'could not find checkpoints in {model_path}'
-    files = sorted(files)
-    print(files)
-    params = {}
-    for file in files:
-        if file.endswith('.bin'):
-            tmp = torch.load(osp.join(model_path, file), map_location='cpu')
-        else:
-            tmp = load_file(osp.join(model_path, file))
-        params.update(tmp)
-    return params
-
-
-def export(model_name: str,
-           num_layer: int,
-           norm_eps: float,
-           kv_head_num: int,
-           model_params: dict,
-           tokenizer_path: str,
-           out_dir: str,
-           tp: int,
-           size_per_head: int = 128,
-           group_size: int = 0,
-           weight_type: str = 'fp16',
-           max_position_embeddings: int = 0,
-           use_dynamic_ntk: int = 0,
-           use_logn_attn: int = 0,
-           rope_theta: float = 10000.0,
-           tokenizer_info=tokenizer_info_sp):
-    """Export deploying information to a config file.
-
-    Args:
-        model_name (str): model's name
-        num_layer (int): the number of transformer blocks
-        norm_eps (float): norm epsilon
-        model_params (dict): parameters of a model
-        tokenizer_path (str): the tokenizer model's path
-        out_dir (str): the path of the output directory
-        tp (int): the number of tensor parallelism
-        size_per_head (int): the dimension of each head
-    """
-    out_dir = osp.join(out_dir, 'weights')
-    os.makedirs(out_dir, exist_ok=True)
-
-    def save_bin(param: torch.Tensor, name):
-        print(name, param.shape)
-        if param.dtype in [torch.float, torch.bfloat16]:
-            param = param.half()
-        param.contiguous().cpu().numpy().tofile(osp.join(out_dir, name))
-
-    attn_bias = False
-    inter_size = 0
-
-    tok_embeddings = model_params['tok_embeddings.weight']
-    _vocab_size, dim = tok_embeddings.shape
-    head_num = dim // size_per_head
-    if _vocab_size % tp != 0:
-        # Resolve https://github.com/InternLM/lmdeploy/issues/266
-        # Pad tok_embeddings and output weights, making their shape divisible by TP # noqa: E501
-        pad_size = (_vocab_size + tp - 1) // tp * tp - _vocab_size
-        # Pad weight at the bottom of dim 0
-        model_params['tok_embeddings.weight'] = torch.nn.functional.pad(
-            tok_embeddings, (0, 0, 0, pad_size), 'constant', 0)
-        # Pad output weight at the bottom of dim 0
-        model_params['output.weight'] = torch.nn.functional.pad(
-            model_params['output.weight'], (0, 0, 0, pad_size), 'constant', 0)
-
-    # reverse the splitting axes since the weights are transposed above
-    for param_name, param_data in model_params.items():
-        split_dim = None
-        key, ext = param_name.split('.')[-2:]
-        if key == 'w_qkv' and ext == 'bias':
-            attn_bias = True
-        copy = False
-        if key in ['w1', 'w3', 'w13', 'w_qkv']:
-            split_dim = -1
-            # TODO: move parameter extraction outside of the loop
-            if key == 'w1':
-                inter_size = max(inter_size, param_data.shape[-1])
-            elif key == 'w13':
-                inter_size = max(inter_size, param_data.shape[-1] // 2)
-        elif key in ['w2', 'wo']:
-            if ext in ['bias']:
-                copy = True
-            else:
-                split_dim = 0
-        if split_dim is not None:
-            print(f'*** splitting {param_name}, shape={param_data.shape}, '
-                  f'split_dim={split_dim}')
-            assert param_data.shape[split_dim] % tp == 0
-            split_size = param_data.shape[split_dim] // tp
-            splits = torch.split(param_data, split_size, dim=split_dim)
-            for i, split in enumerate(splits):
-                prefix, ext = osp.splitext(param_name)
-                save_bin(split, f'{prefix}.{i}{ext}')
-        elif copy:
-            print(f'### copying {param_name}, shape={param_data.shape}')
-            copies = [param_data] * tp
-            for i, copy in enumerate(copies):
-                prefix, ext = osp.splitext(param_name)
-                save_bin(copy, f'{prefix}.{i}{ext}')
-        else:
-            save_bin(param_data, param_name)
-
-    assert inter_size > 0
-
-    # export config and save it to {out_dir}/config.ini
-    model = MODELS.get(model_name)()
-    vocab_size, bos_id, eos_id = tokenizer_info(tokenizer_path)
-    assert _vocab_size >= vocab_size, \
-        f'different vocab size {_vocab_size} vs {vocab_size}'
-    cfg = dict(llama=dict(
-        model_name=model_name,
-        head_num=head_num,
-        kv_head_num=kv_head_num,
-        size_per_head=size_per_head,
-        vocab_size=_vocab_size,
-        num_layer=num_layer,
-        rotary_embedding=size_per_head,
-        rope_theta=rope_theta,
-        inter_size=inter_size,
-        norm_eps=norm_eps,
-        attn_bias=int(attn_bias),
-        start_id=bos_id,
-        end_id=eos_id,
-        weight_type=weight_type,
-        group_size=group_size,
-        # parameters for turbomind
-        max_batch_size=32,
-        max_context_token_num=4,
-        session_len=model.session_len + 8,
-        step_length=1,
-        cache_max_entry_count=48,
-        cache_chunk_size=1,
-        use_context_fmha=1,
-        quant_policy=0,
-        tensor_para_size=tp,
-        # extra attention params
-        max_position_embeddings=max_position_embeddings,
-        use_dynamic_ntk=int(use_dynamic_ntk),
-        use_logn_attn=int(use_logn_attn),
-    ))
-
-    config = configparser.ConfigParser()
-    for section, key_values in cfg.items():
-        config[section] = key_values
-
-    config_path = osp.join(out_dir, 'config.ini')
-    with open(config_path, 'w') as f:
-        config.write(f)
-    return True
-
-
-def merge_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp: int,
-              dim: int):
-
-    def reshape(x):
-        return x.view(x.size(0), tp, -1) if dim == 2 else x.view(tp, -1)
-
-    qkv = torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1)
-
-    # (input_dim, head_num + 2 * kv_head_num)
-    return qkv.view(q.size(0), -1)
-
-
-def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
-                 triton_models_path: str, tp: int):
-    """Deploy a model with huggingface transformers' format.
-
-    Args:
-        model_name (str): the name of the to-be-deployed model
-        model_path (str): the path of the directory where the model weight
-          files are
-        tokenizer_path (str): the path of the tokenizer model path
-        triton_models_path (str): the path of the exported triton models
-        tp (int): the number of tensor parallelism
-    """
-    if osp.exists(tokenizer_path):
-        shutil.copy(tokenizer_path,
-                    osp.join(triton_models_path, 'tokenizer/tokenizer.model'))
-        with get_package_root_path() as root_path:
-            shutil.copy(osp.join(root_path, 'tokenizer.py'),
-                        osp.join(triton_models_path, 'tokenizer'))
-    else:
-        print(f'tokenizer model {tokenizer_path} does not exist')
-        return False
-    # read model arguments from params.json
-    try:
-        params_path = osp.join(model_path, 'params.json')
-        with open(params_path) as f:
-            model_arg = json.load(f)
-            num_layer = model_arg['n_layers']
-            norm_eps = model_arg['norm_eps']
-            head_num = model_arg.get('n_heads', 32)
-            kv_head_num = model_arg.get('n_kv_heads', head_num)
-    except Exception as e:
-        print(f'get "n_layers" and "norm_eps" from {params_path} failed: {e}')
-        return False
-
-    # convert weights from llama to turbomind format
-    checkpoints = []
-    for pattern in ['*.pth', '*.pt']:
-        checkpoints += sorted(Path(model_path).glob(pattern))
-    print(checkpoints)
-    n_ckpt = len(checkpoints)
-    model_params = {}
-
-    def get_param(_name, _size):
-        print(_name, _size)
-        if _name not in model_params:
-            model_params[_name] = torch.zeros(_size,
-                                              dtype=torch.float16,
-                                              device='cpu')
-        return model_params[_name]
-
-    for i, ckpt_path in enumerate(checkpoints):
-        ckpt = torch.load(ckpt_path, map_location='cpu')
-        for param_name, param_data in ckpt.items():
-            key, ext = param_name.split('.')[-2:]
-            # column-parallel
-            if key in ['w1', 'w3', 'wq', 'wk', 'wv', 'output']:
-                size = param_data.size(0)
-                if ext == 'weight':
-                    param = get_param(
-                        param_name,
-                        [size * n_ckpt, param_data.size(1)])
-                    param.data[size * i:size * (i + 1), :] = param_data
-                else:  # bias
-                    param = get_param(param_name, [size * n_ckpt])
-                    param.data[size * i:size * (i + 1)] = param_data
-            # row-parallel
-            elif key in ['w2', 'wo', 'tok_embeddings']:
-                size = param_data.size(-1)
-                if ext == 'weight':
-                    param = get_param(param_name,
-                                      [param_data.size(0), size * n_ckpt])
-                    param.data[:, size * i:size * (i + 1)] = param_data
-                else:  # bias
-                    param = get_param(param_name, [size])
-                    param.data = param_data
-            elif i == 0:
-                param = get_param(param_name, param_data.size())
-                param.data = param_data
-        del ckpt
-
-    for name, param in model_params.items():
-        # transpose all weights as TurboMind is expecting column-major
-        # weights: (output_dims, input_dims) -> (input_dims, output_dims)
-        key = name.split('.')[-2]
-        if key in ['w1', 'w3', 'wq', 'wk', 'wv', 'w2', 'wo']:
-            param.data = param.data.t()
-
-    # concat qkv projection
-    for t in ['weight', 'bias']:
-        for i in range(1000):
-            _qkv = [
-                f'layers.{i}.attention.{k}.{t}' for k in ['wq', 'wk', 'wv']
-            ]
-            try:
-                qkv = tuple(map(model_params.pop, _qkv))
-            except KeyError:
-                break
-            # concat by heads
-            qkv = merge_qkv(*qkv, tp, dim=2 if t == 'weight' else 1)
-            print(f'layers.{i}.attention.w_qkv.{t}', qkv.shape)
-            model_params[f'layers.{i}.attention.w_qkv.{t}'] = qkv
-
-    assert i == 0 or num_layer == i, f'miss matched layers: {num_layer} vs {i}'
-
-    return export(model_name, num_layer, norm_eps, kv_head_num, model_params,
-                  tokenizer_path, triton_models_path, tp)
-
-
-def permute(x: torch.Tensor):
-    SIZE_PER_HEAD = 128
-    if x.shape[-1] > 1:
-        dim = x.shape[-1]
-        n_heads = dim // SIZE_PER_HEAD
-        return x.view(-1, n_heads, 2,
-                      dim // n_heads // 2).transpose(2, 3).reshape(-1, dim)
-    else:  # scales, zeros
-        dim = x.shape[0]
-        n_heads = dim // SIZE_PER_HEAD
-        return x.view(n_heads, 2, dim // n_heads // 2,
-                      1).transpose(1, 2).reshape(dim, 1)
-
-
-def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
-              triton_models_path: str, tp: int):
-    """Deploy a model with huggingface transformers' format.
-
-    Args:
-        model_name (str): the name of the to-be-deployed model
-        model_path (str): the path of the directory where the model weight
-          files are
-        tokenizer_path (str): the path of the tokenizer model path
-        triton_models_path (str): the path of the exported triton models
-        tp (int): the number of tensor parallelism
-    """
-    if tokenizer_path is None:
-        tokenizer_path = osp.join(model_path, 'tokenizer.model')
-    if osp.exists(tokenizer_path):
-        shutil.copy(tokenizer_path,
-                    osp.join(triton_models_path, 'tokenizer/tokenizer.model'))
-        for _file in os.listdir(model_path):
-            if _file.endswith('.json') or _file.endswith('.py'):
-                json_path = osp.join(model_path, _file)
-                shutil.copy(json_path,
-                            osp.join(triton_models_path, 'tokenizer', _file))
-        with get_package_root_path() as root_path:
-            shutil.copy(osp.join(root_path, 'tokenizer.py'),
-                        osp.join(triton_models_path, 'tokenizer'))
-    else:
-        print(f'tokenizer model {tokenizer_path} does not exist')
-        exit(-1)
-
-    # read model arguments from params.json
-    try:
-        params_path = osp.join(model_path, 'config.json')
-        with open(params_path) as f:
-            model_arg = json.load(f)
-            num_layer = model_arg['num_hidden_layers']
-            norm_eps = model_arg['rms_norm_eps']
-            rope_theta = float(model_arg.get('rope_theta', 10000.0))
-            max_position_embeddings = int(
-                model_arg.get('max_position_embeddings', 0))
-            repo_scaling = bool(model_arg.get('rope_scaling', False))
-            if 'num_key_value_heads' in model_arg:
-                kv_head_num = model_arg['num_key_value_heads']
-            else:
-                kv_head_num = model_arg['num_attention_heads']
-    except Exception as e:
-        print(f'get "num_hidden_layers" and "rms_norm_eps" from '
-              f'{params_path} failed: {e}')
-        return False
-
-    # convert weights from hf to turbomind
-    model_params = {}
-
-    _qweight = 'weight'
-    _suffixes = [_qweight, 'bias']
-
-    _params = load_checkpoint(model_path)
-
-    def get_tensor(name):
-        """return tensor according its name."""
-        return _params[name]
-
-    def get_tensor_transposed(name: str):
-        """return a transposed tensor according its name."""
-        if name not in _params and name.find('bias'):
-            return None
-        return _params[name].t()
-
-    w_pack = False
-    if 'model.layers.0.self_attn.W_pack.weight' in _params:
-        w_pack = True
-
-    for i in range(1000):
-        try:
-            # attention weights
-            for suffix in _suffixes:
-                if w_pack:
-                    _qkvo = [
-                        f'model.layers.{i}.self_attn.{t}'
-                        for t in ['W_pack', 'o_proj']
-                    ]
-                    qkv, o = map(get_tensor_transposed,
-                                 map(('{}.' + suffix).format, _qkvo))
-
-                    if qkv is None:
-                        continue
-                    _shape = qkv.shape[1] // 3
-                    _qkv = torch.split(qkv, [_shape, _shape, _shape], dim=1)
-                    q = _qkv[0]
-                    k = _qkv[1]
-                    v = _qkv[2]
-
-                else:
-                    _qkvo = [
-                        f'model.layers.{i}.self_attn.{t}_proj' for t in 'qkvo'
-                    ]
-                    q, k, v, o = map(get_tensor_transposed,
-                                     map(('{}.' + suffix).format, _qkvo))
-                if q is None:
-                    continue
-                # q, k has different layout for fb & hf, convert to fb's
-                # layout
-                q = permute(q)
-                k = permute(k)
-                if suffix == _qweight:  # weight, qweight
-                    qkv = merge_qkv(q, k, v, tp, dim=2)
-                    print(suffix, qkv.shape)
-                else:  # scales, zeros, bias
-                    qkv = merge_qkv(q, k, v, tp, dim=1)
-                    print(suffix, qkv.shape)
-                for k, v in [('w_qkv', qkv), ('wo', o)]:
-                    model_params[f'layers.{i}.attention.{k}.{suffix}'] = v
-            # ffn weights
-            _w123 = [
-                f'model.layers.{i}.mlp.{t}_proj'
-                for t in ['gate', 'down', 'up']
-            ]
-            for suffix in _suffixes:
-                w1, w2, w3 = map(get_tensor_transposed,
-                                 map(('{}.' + suffix).format, _w123))
-                if w1 is None:
-                    continue
-                if suffix in ['scales', 'zeros', 'bias']:
-                    w1, w2, w3 = map(lambda x: x.squeeze(dim=-1), [w1, w2, w3])
-                for k, v in [('w1', w1), ('w2', w2), ('w3', w3)]:
-                    model_params[f'layers.{i}.feed_forward.{k}.{suffix}'] = v
-            other = [('attention_norm.weight', 'input_layernorm.weight'),
-                     ('ffn_norm.weight', 'post_attention_layernorm.weight')]
-            for ft, hf in other:
-                model_params[f'layers.{i}.' +
-                             ft] = get_tensor(f'model.layers.{i}.' + hf)
-        except safetensors.SafetensorError:
-            break
-        except KeyError:
-            break
-
-    assert num_layer == i, f'miss matched layers: {num_layer} vs {i}'
-
-    other = [('tok_embeddings.weight', 'model.embed_tokens.weight'),
-             ('norm.weight', 'model.norm.weight'),
-             ('output.weight', 'lm_head.weight')]
-    for ft, hf in other:
-        model_params[ft] = get_tensor(hf)
-
-    if model_name == 'baichuan2-7b':
-        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/modeling_baichuan.py#L507
-        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/main/modeling_baichuan.py#L507
-        model_params['output.weight'] = torch.nn.functional.normalize(
-            model_params['output.weight'])
-
-    return export(model_name,
-                  num_layer,
-                  norm_eps,
-                  kv_head_num,
-                  model_params,
-                  tokenizer_path,
-                  triton_models_path,
-                  tp,
-                  max_position_embeddings=max_position_embeddings,
-                  use_dynamic_ntk=repo_scaling,
-                  rope_theta=rope_theta)
-
-
-def deploy_awq(model_name: str, model_path: str, tokenizer_path: str,
-               triton_models_path: str, tp: int, quant_path: str,
-               group_size: int):
-    """Deploy a model with huggingface transformers' format.
-
-    Args:
-        model_name (str): the name of the to-be-deployed model
-        model_path (str): the path of the directory where the model weight
-          files are
-        tokenizer_path (str): the path of the tokenizer model path
-        triton_models_path (str): the path of the exported triton models
-        tp (int): the number of tensor parallelism
-        quant_path (str): path of the quantized model, which can be None
-        group_size (int): a parameter used in AWQ to quantize fp16 weights
-            to 4 bits
-    """
-    if tokenizer_path is None:
-        tokenizer_path = osp.join(model_path, 'tokenizer.model')
-    if osp.exists(tokenizer_path):
-        shutil.copy(tokenizer_path,
-                    osp.join(triton_models_path, 'tokenizer/tokenizer.model'))
-        for _file in os.listdir(model_path):
-            if _file.endswith('.json') or _file.endswith('.py'):
-                json_path = osp.join(model_path, _file)
-                shutil.copy(json_path,
-                            osp.join(triton_models_path, 'tokenizer', _file))
-        with get_package_root_path() as root_path:
-            shutil.copy(osp.join(root_path, 'tokenizer.py'),
-                        osp.join(triton_models_path, 'tokenizer'))
-    else:
-        print(f'tokenizer model {tokenizer_path} does not exist')
-        exit(-1)
-
-    # read model arguments from params.json
-    try:
-        params_path = osp.join(model_path, 'config.json')
-        with open(params_path) as f:
-            model_arg = json.load(f)
-            num_layer = model_arg['num_hidden_layers']
-            norm_eps = model_arg['rms_norm_eps']
-            rope_theta = float(model_arg.get('rope_theta', 10000.0))
-            if 'num_key_value_heads' in model_arg:
-                kv_head_num = model_arg['num_key_value_heads']
-            else:
-                kv_head_num = model_arg['num_attention_heads']
-    except Exception as e:
-        print(f'get "num_hidden_layers" and "rms_norm_eps" from '
-              f'{params_path} failed: {e}')
-        return False
-
-    # convert weights from hf to turbomind
-    if quant_path is None:
-        _files = [
-            osp.join(model_path, file) for file in os.listdir(model_path)
-            if file.endswith('.bin')
-        ]
-        _files = sorted(_files)
-    else:
-        _files = [quant_path]
-
-    model_params = {}
-
-    _params = {}
-    for _file in _files:
-        _tmp = torch.load(_file, map_location='cpu')
-        _params.update(_tmp)
-
-    def get_tensor(name):
-        """return tensor according its name."""
-        return _params[name].cuda().contiguous()
-
-    # import _turbomind as _tm
-    # TODO: find another way import _turbomind
-    lmdeploy_dir = osp.split(lmdeploy.__file__)[0]
-    sys.path.append(osp.join(lmdeploy_dir, 'lib'))
-    import _turbomind as _tm  # noqa: E402
-
-    def transpose_qk_s4(src: torch.Tensor):
-        assert src.is_contiguous()
-        dst = torch.zeros_like(src)
-        _tm.transpose_qk_s4_k_m8(src, dst,
-                                 src.size(-1) * 8, src.size(0), group_size)
-        return dst
-
-    def fuse_w1_w3_s4(w1_qw: torch.Tensor, w1_qz: torch.Tensor,
-                      w1_s: torch.Tensor, w3_qw: torch.Tensor,
-                      w3_qz: torch.Tensor, w3_s: torch.Tensor):
-
-        def fuse(a: torch.Tensor, b: torch.Tensor):
-            ab = torch.cat((a, b)).contiguous()
-            _ab = torch.zeros_like(ab)
-            _tm.fuse_w1_w3_s4_k_m8(ab, _ab, a.size(-1) * 8, a.size(0))
-            return _ab.view(a.size(0), -1)
-
-        w13_qw = fuse(w1_qw, w3_qw)
-        w13_qz = fuse(w1_qz, w3_qz)
-
-        w13_s = torch.cat((w1_s, w3_s)).view(2, w1_s.size(0), -1)
-        w13_s = w13_s.permute(1, 2, 0).contiguous().view(w1_s.size(0), -1)
-
-        return w13_qw, w13_qz, w13_s
-
-    def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
-                   group_size: int):
-        assert qw.is_contiguous()
-        assert qz.is_contiguous()
-        assert s.is_contiguous()
-        _qw = torch.zeros_like(qw)
-        _sz = torch.zeros_like(s, dtype=torch.int32)  # half2
-        _ws = torch.zeros_like(s)
-        _tm.convert_s4_k_m8(_qw, _sz, _ws, qw, s, qz,
-                            qw.size(-1) * 8, qw.size(0), group_size)
-        return _qw, _sz
-
-    def tp_m_s4(x: torch.Tensor, tp: int):
-        return x.view(x.size(0) // 32, tp, -1, 128).permute(0, 2, 3,
-                                                            1).contiguous()
-
-    attn_bias = False
-
-    for i in range(num_layer):
-        print(i)
-
-        # attention weights
-        q_qw = get_tensor(f'model.layers.{i}.self_attn.q_proj.qweight')
-        k_qw = get_tensor(f'model.layers.{i}.self_attn.k_proj.qweight')
-        v_qw = get_tensor(f'model.layers.{i}.self_attn.v_proj.qweight')
-        o_qw = get_tensor(f'model.layers.{i}.self_attn.o_proj.qweight')
-
-        q_qz = get_tensor(f'model.layers.{i}.self_attn.q_proj.qzeros')
-        k_qz = get_tensor(f'model.layers.{i}.self_attn.k_proj.qzeros')
-        v_qz = get_tensor(f'model.layers.{i}.self_attn.v_proj.qzeros')
-        o_qz = get_tensor(f'model.layers.{i}.self_attn.o_proj.qzeros')
-
-        q_s = get_tensor(f'model.layers.{i}.self_attn.q_proj.scales')
-        k_s = get_tensor(f'model.layers.{i}.self_attn.k_proj.scales')
-        v_s = get_tensor(f'model.layers.{i}.self_attn.v_proj.scales')
-        o_s = get_tensor(f'model.layers.{i}.self_attn.o_proj.scales')
-
-        try:
-            q_b = get_tensor(f'model.layers.{i}.self_attn.q_proj.bias')
-            k_b = get_tensor(f'model.layers.{i}.self_attn.k_proj.bias')
-            v_b = get_tensor(f'model.layers.{i}.self_attn.v_proj.bias')
-            o_b = get_tensor(f'model.layers.{i}.self_attn.o_proj.bias')
-            attn_bias = True
-        except:  # noqa: E722
-            pass
-
-        q_qw = transpose_qk_s4(q_qw)
-        k_qw = transpose_qk_s4(k_qw)
-        q_qz = transpose_qk_s4(q_qz)
-        k_qz = transpose_qk_s4(k_qz)
-        q_s = permute(q_s)
-        k_s = permute(k_s)
-
-        qkv_qw = merge_qkv(q_qw, k_qw, v_qw, tp, dim=2)
-        qkv_qz = merge_qkv(q_qz, k_qz, v_qz, tp, dim=2)
-        qkv_s = merge_qkv(q_s, k_s, v_s, tp, dim=2)
-
-        qkv_qw, qkv_sz = convert_s4(qkv_qw, qkv_qz, qkv_s, group_size)
-
-        qkv_qw = tp_m_s4(qkv_qw, tp)
-
-        model_params[f'layers.{i}.attention.w_qkv.qweight'] = qkv_qw
-        model_params[f'layers.{i}.attention.w_qkv.scales_zeros'] = qkv_sz
-
-        o_qw, o_sz = convert_s4(o_qw, o_qz, o_s, group_size)
-
-        model_params[f'layers.{i}.attention.wo.qweight'] = o_qw
-        model_params[f'layers.{i}.attention.wo.scales_zeros'] = o_sz
-
-        if attn_bias:
-            q_b = permute(q_b)
-            k_b = permute(k_b)
-            qkv_b = merge_qkv(q_b, k_b, v_b, tp, dim=1)
-            model_params[f'layers.{i}.attention.w_qkv.bias'] = qkv_b
-            model_params[f'layers.{i}.attention.wo.bias'] = o_b
-
-        # ffn weights
-        w1_qw = get_tensor(f'model.layers.{i}.mlp.gate_proj.qweight')
-        w2_qw = get_tensor(f'model.layers.{i}.mlp.down_proj.qweight')
-        w3_qw = get_tensor(f'model.layers.{i}.mlp.up_proj.qweight')
-
-        w1_qz = get_tensor(f'model.layers.{i}.mlp.gate_proj.qzeros')
-        w2_qz = get_tensor(f'model.layers.{i}.mlp.down_proj.qzeros')
-        w3_qz = get_tensor(f'model.layers.{i}.mlp.up_proj.qzeros')
-
-        w1_s = get_tensor(f'model.layers.{i}.mlp.gate_proj.scales')
-        w2_s = get_tensor(f'model.layers.{i}.mlp.down_proj.scales')
-        w3_s = get_tensor(f'model.layers.{i}.mlp.up_proj.scales')
-
-        w13_qw, w13_qz, w13_s = fuse_w1_w3_s4(w1_qw, w1_qz, w1_s, w3_qw, w3_qz,
-                                              w3_s)
-
-        w13_qw, w13_sz = convert_s4(w13_qw, w13_qz, w13_s, group_size)
-        w2_qw, w2_sz = convert_s4(w2_qw, w2_qz, w2_s, group_size)
-
-        w13_qw = tp_m_s4(w13_qw, tp)
-
-        model_params[f'layers.{i}.feed_forward.w13.qweight'] = w13_qw
-        model_params[f'layers.{i}.feed_forward.w13.scales_zeros'] = w13_sz
-
-        model_params[f'layers.{i}.feed_forward.w2.qweight'] = w2_qw
-        model_params[f'layers.{i}.feed_forward.w2.scales_zeros'] = w2_sz
-
-        # norm weights
-        attn_norm = get_tensor(f'model.layers.{i}.input_layernorm.weight')
-        ffn_norm = get_tensor(
-            f'model.layers.{i}.post_attention_layernorm.weight')
-
-        model_params[f'layers.{i}.attention_norm.weight'] = attn_norm
-        model_params[f'layers.{i}.ffn_norm.weight'] = ffn_norm
-
-    other = [('tok_embeddings.weight', 'model.embed_tokens.weight'),
-             ('norm.weight', 'model.norm.weight'),
-             ('output.weight', 'lm_head.weight')]
-    for ft, hf in other:
-        model_params[ft] = get_tensor(hf)
-
-    return export(model_name,
-                  num_layer,
-                  norm_eps,
-                  kv_head_num,
-                  model_params,
-                  tokenizer_path,
-                  triton_models_path,
-                  tp,
-                  weight_type='int4',
-                  group_size=group_size,
-                  rope_theta=rope_theta)
-
-
-def deploy_qwen(model_name: str, model_path: str, tokenizer_path: str,
-                triton_models_path: str, tp: int):
-    """Deploy a model with huggingface transformers' format.
-
-    Args:
-        model_name (str): the name of the to-be-deployed model
-        model_path (str): the path of the directory where the model weight
-          files are
-        tokenizer_path (str): the path of the tokenizer model path
-        triton_models_path (str): the path of the exported triton models
-        tp (int): the number of tensor parallelism
-        quant_path (str): path of the quantized model, which can be None
-        group_size (int): a parameter used in AWQ to quantize fp16 weights
-            to 4 bits
-    """
-
-    if osp.exists(model_path):
-        shutil.copy(osp.join(model_path, 'qwen.tiktoken'),
-                    osp.join(triton_models_path, 'tokenizer'))
-        for _file in os.listdir(model_path):
-            if _file.endswith('.json') or _file.endswith('.py'):
-                json_path = osp.join(model_path, _file)
-                shutil.copy(json_path,
-                            osp.join(triton_models_path, 'tokenizer', _file))
-        with get_package_root_path() as root_path:
-            shutil.copy(osp.join(root_path, 'tokenizer.py'),
-                        osp.join(triton_models_path, 'tokenizer'))
-    else:
-        print(f'tokenizer model {tokenizer_path} does not exist')
-        exit(-1)
-
-    # read model arguments from params.json
-    try:
-        params_path = osp.join(model_path, 'config.json')
-        with open(params_path) as f:
-            config = json.load(f)
-            num_layer = config['num_hidden_layers']
-            norm_eps = config['layer_norm_epsilon']
-            rope_theta = float(config.get('rotary_emb_base', 10000.0))
-            if 'num_key_value_heads' in config:
-                kv_head_num = config['num_key_value_heads']
-            else:
-                kv_head_num = config['num_attention_heads']
-            seq_length = config['seq_length']
-            use_dynamic_ntk = config['use_dynamic_ntk']
-            use_logn_attn = config['use_logn_attn']
-    except Exception as e:
-        print(f'get "num_hidden_layers" and "layer_norm_epsilon" from '
-              f'{params_path} failed: {e}')
-        return False
-
-    # convert weights from hf to turbomind
-    model_params = {}
-
-    _params = load_checkpoint(model_path)
-
-    def get_tensor(name, trans=True):
-        """return a transposed tensor according its name."""
-        if trans:
-            return _params[name].cuda().t()
-        else:
-            return _params[name].cuda()
-
-    for i in range(num_layer):
-        print(i)
-
-        # qkv weights
-        qkv_w = get_tensor(f'transformer.h.{i}.attn.c_attn.weight')
-        q_w, k_w, v_w = torch.split(qkv_w, qkv_w.size(-1) // 3, dim=-1)
-        q_w, k_w = permute(q_w), permute(k_w)
-        qkv_w = merge_qkv(q_w, k_w, v_w, tp, dim=2)
-        model_params[f'layers.{i}.attention.w_qkv.weight'] = qkv_w
-
-        # qkv bias
-        qkv_b = get_tensor(f'transformer.h.{i}.attn.c_attn.bias')
-        q_b, k_b, v_b = torch.split(qkv_b, qkv_b.size(-1) // 3)
-        q_b, k_b = permute(q_b), permute(k_b)
-        qkv_b = merge_qkv(q_b, k_b, v_b, tp, dim=1)
-        model_params[f'layers.{i}.attention.w_qkv.bias'] = qkv_b
-
-        # o weights
-        o_w = get_tensor(f'transformer.h.{i}.attn.c_proj.weight')
-        model_params[f'layers.{i}.attention.wo.weight'] = o_w
-        model_params[f'layers.{i}.attention.wo.bias'] = torch.zeros_like(q_b)
-
-        # ffn weights
-        # ours: w2(silu(w1(x)) * w3(x))
-        # qwen: c_proj(w1(x) * silu(w2(x)))
-        w1 = get_tensor(f'transformer.h.{i}.mlp.w2.weight')
-        w3 = get_tensor(f'transformer.h.{i}.mlp.w1.weight')
-        w2 = get_tensor(f'transformer.h.{i}.mlp.c_proj.weight')
-        model_params[f'layers.{i}.feed_forward.w1.weight'] = w1
-        model_params[f'layers.{i}.feed_forward.w2.weight'] = w2
-        model_params[f'layers.{i}.feed_forward.w3.weight'] = w3
-
-        # norm weights
-        attn_norm = get_tensor(f'transformer.h.{i}.ln_1.weight')
-        ffn_norm = get_tensor(f'transformer.h.{i}.ln_2.weight')
-
-        model_params[f'layers.{i}.attention_norm.weight'] = attn_norm
-        model_params[f'layers.{i}.ffn_norm.weight'] = ffn_norm
-
-    other = [('tok_embeddings.weight', 'transformer.wte.weight'),
-             ('norm.weight', 'transformer.ln_f.weight'),
-             ('output.weight', 'lm_head.weight')]
-    for ft, hf in other:
-        model_params[ft] = get_tensor(hf, trans=False)
-
-    return export(model_name,
-                  num_layer,
-                  norm_eps,
-                  kv_head_num,
-                  model_params,
-                  model_path,
-                  triton_models_path,
-                  tp,
-                  max_position_embeddings=seq_length,
-                  use_dynamic_ntk=use_dynamic_ntk,
-                  use_logn_attn=use_logn_attn,
-                  rope_theta=rope_theta,
-                  tokenizer_info=tokenizer_info_qwen)
-
-
-def pack_model_repository(workspace_path: str):
-    """package the model repository.
-
-    Args:
-        workspace_path: the path of workspace
-    """
-    os.symlink(src='../../tokenizer',
-               dst=osp.join(workspace_path, 'triton_models', 'preprocessing',
-                            '1', 'tokenizer'))
-    os.symlink(src='../../tokenizer',
-               dst=osp.join(workspace_path, 'triton_models', 'postprocessing',
-                            '1', 'tokenizer'))
-    os.symlink(src='../../weights',
-               dst=osp.join(workspace_path, 'triton_models', 'interactive',
-                            '1', 'weights'))
-    model_repo_dir = osp.join(workspace_path, 'model_repository')
-    os.makedirs(model_repo_dir, exist_ok=True)
-    os.symlink(src=osp.join('../triton_models/interactive'),
-               dst=osp.join(model_repo_dir, 'turbomind'))
-    os.symlink(src=osp.join('../triton_models/preprocessing'),
-               dst=osp.join(model_repo_dir, 'preprocessing'))
-    os.symlink(src=osp.join('../triton_models/postprocessing'),
-               dst=osp.join(model_repo_dir, 'postprocessing'))
-
-
-def main(model_name: str,
-         model_path: str,
-         model_format: str = None,
-         tokenizer_path: str = None,
-         dst_path: str = './workspace',
-         tp: int = 1,
-         quant_path: str = None,
-         group_size: int = 0):
-    """deploy llama family models via turbomind.
-
-    Args:
-        model_name (str): the name of the to-be-deployed model, such as
-            llama-7b, llama-13b, vicuna-7b and etc
-        model_path (str): the directory path of the model
-        model_format (str): the format of the model, fb or hf. 'fb' stands for
-            META's llama format, and 'hf' means huggingface format
-        tokenizer_path (str): the path of tokenizer model
-        dst_path (str): the destination path that saves outputs
-        tp (int): the number of GPUs used for tensor parallelism, should be 2^n
-        quant_path (str): path of the quantized model, which can be None
-        group_size (int): a parameter used in AWQ to quantize fp16 weights
-            to 4 bits
-    """
-    assert model_name in MODELS.module_dict.keys(), \
-        f"'{model_name}' is not supported. " \
-        f'The supported models are: {MODELS.module_dict.keys()}'
-
-    assert ((tp & (tp - 1) == 0) and tp != 0), 'tp should be 2^n'
-
-    if model_format is None:
-        model_format = 'qwen' if model_name == 'qwen-7b' else 'hf'
-
-    if model_format not in supported_formats:
-        print(f'the model format "{model_format}" is not supported. '
-              f'The supported format are: {supported_formats}')
-        exit(-1)
-
-    if model_format == 'llama' and tokenizer_path is None:
-        print('The model is llama. Its tokenizer model path should be '
-              'specified')
-        exit(-1)
-
-    if not create_workspace(dst_path):
-        exit(-1)
-
-    triton_models_path = copy_triton_model_templates(dst_path)
-    if triton_models_path is None:
-        exit(-1)
-
-    if model_format == 'llama':
-        res = deploy_llama(model_name, model_path, tokenizer_path,
-                           triton_models_path, tp)
-    elif model_format == 'hf':
-        res = deploy_hf(model_name, model_path, tokenizer_path,
-                        triton_models_path, tp)
-    elif model_format == 'awq':
-        res = deploy_awq(model_name, model_path, tokenizer_path,
-                         triton_models_path, tp, quant_path, group_size)
-    elif model_format == 'qwen':
-        res = deploy_qwen(model_name, model_path, tokenizer_path,
-                          triton_models_path, tp)
-
-    # update `tensor_para_size` in `triton_models/interactive/config.pbtxt`
-    with open(osp.join(triton_models_path, 'interactive/config.pbtxt'),
-              'a') as f:
-        param = \
-            'parameters {\n  key: "tensor_para_size"\n  value: {\n    ' \
-            'string_value: ' + f'"{tp}"\n' + '  }\n}\n' + \
-            'parameters {\n  key: "model_name"\n  value: {\n    ' \
-            'string_value: ' + f'"{model_name}"\n' + '  }\n}\n'
-        f.write(param)
-    if not res:
-        print(f'deploy model "{model_name}" via turbomind failed')
-        destroy_workspace(dst_path)
-        exit(-1)
-
-    # pack model repository for triton inference server
-    pack_model_repository(dst_path)
-
-    # update the value of $TP in `service_docker_up.sh`
-    file_path = osp.join(dst_path, 'service_docker_up.sh')
-    with open(file_path, 'r') as f:
-        content = f.read()
-        content = re.sub('TP=1', f'TP={tp}', content)
-    with open(file_path, 'w') as f:
-        f.write(content)
-
-
-if __name__ == '__main__':
-    fire.Fire(main)
diff --git a/lmdeploy/serve/turbomind/triton_models/preprocessing/1/model.py b/lmdeploy/serve/turbomind/triton_models/preprocessing/1/model.py
index 77f51bfb3d..7e659fbae0 100644
--- a/lmdeploy/serve/turbomind/triton_models/preprocessing/1/model.py
+++ b/lmdeploy/serve/turbomind/triton_models/preprocessing/1/model.py
@@ -42,9 +42,7 @@ def initialize(self, args):
         self.model_config = model_config = json.loads(args['model_config'])
 
         # Parse model output configs and convert Triton types to numpy types
-        input_names = [
-            'INPUT_ID', 'REQUEST_INPUT_LEN', 'BAD_WORDS_IDS', 'STOP_WORDS_IDS'
-        ]
+        input_names = ['INPUT_ID', 'REQUEST_INPUT_LEN']
         for input_name in input_names:
             setattr(
                 self,
@@ -89,8 +87,6 @@ def execute(self, requests):
             # Get input tensors
             query = pb_utils.get_input_tensor_by_name(request,
                                                       'QUERY').as_numpy()
-            request_output_len = pb_utils.get_input_tensor_by_name(
-                request, 'REQUEST_OUTPUT_LEN').as_numpy()
 
             # Preprocessing input data.
             input_id, request_input_len = self._create_request(query)
@@ -104,8 +100,6 @@ def execute(self, requests):
                 'REQUEST_INPUT_LEN',
                 np.array(request_input_len).astype(
                     self.request_input_len_dtype))
-            request_output_len_tensor = pb_utils.Tensor(
-                'REQUEST_OUTPUT_LEN', request_output_len)
 
             # Create InferenceResponse. You can set an error here in case
             # there was a problem with handling this inference request.
@@ -114,10 +108,8 @@ def execute(self, requests):
             #
             # pb_utils.InferenceResponse(
             #    output_tensors=..., TritonError("An error occurred"))
-            inference_response = pb_utils.InferenceResponse(output_tensors=[
-                input_id_tensor, request_input_len_tensor,
-                request_output_len_tensor
-            ])
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=[input_id_tensor, request_input_len_tensor])
             responses.append(inference_response)
 
         # You should return a list of pb_utils.InferenceResponse. Length
@@ -140,10 +132,18 @@ def _create_request(self, query):
         Returns:
             tuple: token ids and their length
         """
-        start_ids = [
-            torch.IntTensor(self.tokenizer.encode(s[0].decode()))
-            for s in query
-        ]
+        start_ids = []
+        for s in query:
+            _s = s[0].decode()
+            if _s == '<BOS>':
+                start_id = [self.start_id
+                            ] if self.start_id is not None else [-1]
+            elif _s == '<EOS>':
+                start_id = [self.end_id] if self.end_id is not None else [-1]
+            else:
+                start_id = self.tokenizer.encode(_s)
+            start_ids.append(torch.IntTensor(start_id))
+
         start_lengths = torch.IntTensor([[len(ids)] for ids in start_ids])
         start_ids = pad_sequence(start_ids,
                                  batch_first=True,
diff --git a/lmdeploy/serve/turbomind/triton_models/preprocessing/config.pbtxt b/lmdeploy/serve/turbomind/triton_models/preprocessing/config.pbtxt
index a87abd98df..997ba399ba 100644
--- a/lmdeploy/serve/turbomind/triton_models/preprocessing/config.pbtxt
+++ b/lmdeploy/serve/turbomind/triton_models/preprocessing/config.pbtxt
@@ -7,23 +7,6 @@ input [
         name: "QUERY"
         data_type: TYPE_STRING
         dims: [ -1 ]
-    },
-    {
-        name: "BAD_WORDS_DICT"
-        data_type: TYPE_STRING
-        dims: [ -1 ]
-        optional: true
-    },
-    {
-        name: "STOP_WORDS_DICT"
-        data_type: TYPE_STRING
-        dims: [ -1 ]
-        optional: true
-    },
-    {
-        name: "REQUEST_OUTPUT_LEN"
-        data_type: TYPE_UINT32
-        dims: [ -1 ]
     }
 ]
 output [
@@ -36,26 +19,6 @@ output [
         name: "REQUEST_INPUT_LEN"
         data_type: TYPE_UINT32
         dims: [ 1 ]
-    },
-    {
-        name: "BAD_WORDS_IDS"
-        data_type: TYPE_INT32
-        dims: [ 2, -1 ]
-    },
-    {
-        name: "STOP_WORDS_IDS"
-        data_type: TYPE_INT32
-        dims: [ 2, -1 ]
-    },
-    {
-        name: "REQUEST_OUTPUT_LEN"
-        data_type: TYPE_UINT32
-        dims: [ -1 ]
-    },
-    {
-        name: "PROMPT_LEARNING_TASK_NAME_IDS"
-        data_type: TYPE_UINT32
-        dims: [ 1 ]
     }
 ]
 
diff --git a/lmdeploy/serve/turbomind/utils.py b/lmdeploy/serve/turbomind/utils.py
index bd1c3a16c2..802f6abaa4 100644
--- a/lmdeploy/serve/turbomind/utils.py
+++ b/lmdeploy/serve/turbomind/utils.py
@@ -48,11 +48,7 @@ def infer(self, prompts: Union[str, List[str]]) -> tuple:
                       f'{type(prompts)}'
 
         input0_data = np.array(input0).astype(object)
-        output0_len = np.ones_like(input0).astype(np.uint32)
-        inputs = [
-            prepare_tensor('QUERY', input0_data),
-            prepare_tensor('REQUEST_OUTPUT_LEN', output0_len)
-        ]
+        inputs = [prepare_tensor('QUERY', input0_data)]
 
         with grpcclient.InferenceServerClient(self.tritonserver_addr) as \
                 client:
diff --git a/lmdeploy/tokenizer.py b/lmdeploy/tokenizer.py
index 296d453ed4..2ddecdfd40 100644
--- a/lmdeploy/tokenizer.py
+++ b/lmdeploy/tokenizer.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import json
+import os
 import os.path as osp
 from typing import Optional, Sequence, Union
 
@@ -16,7 +17,7 @@ class SentencePieceTokenizer:
     def __init__(self, model_file: str):
         from sentencepiece import SentencePieceProcessor
         self.model = SentencePieceProcessor(model_file=model_file)
-        self._no_prefix_space_tokens = None
+        self._prefix_space_tokens = None
 
     @property
     def vocab_size(self):
@@ -34,24 +35,25 @@ def eos_token_id(self):
         return self.model.eos_id()
 
     @property
-    def no_prefix_space_tokens(self):
+    def prefix_space_tokens(self):
         """tokens without prefix space."""
-        if self._no_prefix_space_tokens is None:
+        if self._prefix_space_tokens is None:
             vocab = self.model.IdToPiece(list(range(self.vocab_size)))
-            self._no_prefix_space_tokens = {
+            self._prefix_space_tokens = {
                 i
-                for i, tok in enumerate(vocab) if not tok.startswith('▁')
+                for i, tok in enumerate(vocab) if tok.startswith('▁')
             }
-        return self._no_prefix_space_tokens
+        return self._prefix_space_tokens
 
     def _maybe_add_prefix_space(self, tokens, decoded):
         """maybe add prefix space for incremental decoding."""
-        if len(tokens) and tokens[0] not in self.no_prefix_space_tokens:
+        if len(tokens) and not decoded.startswith(' ') and\
+                tokens[0] in self.prefix_space_tokens:
             return ' ' + decoded
         else:
             return decoded
 
-    def encode(self, s: str):
+    def encode(self, s: str, add_bos: bool = True, **kwargs):
         """Tokenize a prompt.
 
         Args:
@@ -59,15 +61,7 @@ def encode(self, s: str):
         Returns:
             list[int]: token ids
         """
-        add_bos = False
-        add_eos = False
-        if s.find('<BOS>') != -1:
-            s = s.replace('<BOS>', '')
-            add_bos = True
-        if s == '<EOS>':
-            s = ''
-            add_eos = True
-        return self.model.Encode(s, add_bos=add_bos, add_eos=add_eos)
+        return self.model.Encode(s, add_bos=add_bos, **kwargs)
 
     def decode(self, t: Sequence[int], offset: Optional[int] = None):
         """De-tokenize.
@@ -110,31 +104,32 @@ class HuggingFaceTokenizer:
         model_dir (str): the directory of the tokenizer model
     """
 
-    def __init__(self, model_dir: str, trust_remote_code=True):
-        from transformers import (AutoTokenizer, CodeLlamaTokenizerFast,
-                                  LlamaTokenizerFast)
+    def __init__(self, model_dir: str):
+        from transformers import AutoTokenizer
         model_file = osp.join(model_dir, 'tokenizer.model')
         backend_tokenizer_file = osp.join(model_dir, 'tokenizer.json')
         model_file_exists = osp.exists(model_file)
         if not osp.exists(backend_tokenizer_file) and model_file_exists:
             print('WARNING: Can not find tokenizer.json. '
                   'It may take long time to initialize the tokenizer.')
-        self.model = AutoTokenizer.from_pretrained(
-            model_dir, trust_remote_code=trust_remote_code)
-        self.need_padding = isinstance(self.model, LlamaTokenizerFast) \
-            or isinstance(self.model, CodeLlamaTokenizerFast)
-        self._no_prefix_space_tokens = None
+        self.model = AutoTokenizer.from_pretrained(model_dir,
+                                                   trust_remote_code=True)
+        self._prefix_space_tokens = None
         # save tokenizer.json to reuse
         if not osp.exists(backend_tokenizer_file) and model_file_exists:
             if hasattr(self.model, 'backend_tokenizer'):
-                self.model.backend_tokenizer.save(backend_tokenizer_file)
+                if os.access(model_dir, os.W_OK):
+                    self.model.backend_tokenizer.save(backend_tokenizer_file)
 
         if self.model.eos_token_id is None:
             generation_config_file = osp.join(model_dir,
                                               'generation_config.json')
-            with open(generation_config_file, 'r') as f:
-                cfg = json.load(f)
-                self.model.eos_token_id = cfg['eos_token_id']
+            if osp.exists(generation_config_file):
+                with open(generation_config_file, 'r') as f:
+                    cfg = json.load(f)
+                    self.model.eos_token_id = cfg['eos_token_id']
+            elif hasattr(self.model, 'eod_id'):  # Qwen remote
+                self.model.eos_token_id = self.model.eod_id
 
     @property
     def vocab_size(self):
@@ -152,26 +147,27 @@ def eos_token_id(self):
         return self.model.eos_token_id
 
     @property
-    def no_prefix_space_tokens(self):
+    def prefix_space_tokens(self):
         """tokens without prefix space."""
-        if self._no_prefix_space_tokens is None:
+        if self._prefix_space_tokens is None:
             vocab = self.model.convert_ids_to_tokens(
                 list(range(self.vocab_size)))
-            self._no_prefix_space_tokens = {
+            self._prefix_space_tokens = {
                 i
-                for i, tok in enumerate(vocab) if not tok.startswith('▁')
+                for i, tok in enumerate(vocab)
+                if tok.startswith('▁' if isinstance(tok, str) else b' ')
             }
-        return self._no_prefix_space_tokens
+        return self._prefix_space_tokens
 
     def _maybe_add_prefix_space(self, tokens, decoded):
         """maybe add prefix space for incremental decoding."""
-        if self.need_padding and len(
-                tokens) and tokens[0] not in self.no_prefix_space_tokens:
+        if len(tokens) and not decoded.startswith(' ') and\
+                tokens[0] in self.prefix_space_tokens:
             return ' ' + decoded
         else:
             return decoded
 
-    def encode(self, s: str):
+    def encode(self, s: str, add_bos: bool = True, **kwargs):
         """Tokenize a prompt.
 
         Args:
@@ -179,14 +175,12 @@ def encode(self, s: str):
         Returns:
             list[int]: token ids
         """
-        add_special_tokens = False
-        if s.find('<BOS>') != -1:
-            s = s.replace('<BOS>', '')
-        if s == '<EOS>':
-            s = ''
-        if len(s) == 0:
-            add_special_tokens = True
-        return self.model.encode(s, add_special_tokens=add_special_tokens)
+        encoded = self.model.encode(s, **kwargs)
+        if not add_bos:
+            # in the middle of a session
+            if len(encoded) and encoded[0] == self.bos_token_id:
+                encoded = encoded[1:]
+        return encoded
 
     def decode(self, t: Sequence[int], offset: Optional[int] = None):
         """De-tokenize.
@@ -225,7 +219,7 @@ class Tokenizer:
         model_file (str): the path of the tokenizer model
     """
 
-    def __init__(self, model_file: str, trust_remote_code=True):
+    def __init__(self, model_file: str):
         if model_file.endswith('.model'):
             model_folder = osp.split(model_file)[0]
         else:
@@ -240,7 +234,7 @@ def __init__(self, model_file: str, trust_remote_code=True):
         if not use_hf_model:
             self.model = SentencePieceTokenizer(model_file)
         else:
-            self.model = HuggingFaceTokenizer(model_folder, trust_remote_code)
+            self.model = HuggingFaceTokenizer(model_folder)
 
     @property
     def vocab_size(self):
@@ -257,7 +251,7 @@ def eos_token_id(self):
         """end of the sentence token id."""
         return self.model.eos_token_id
 
-    def encode(self, s: str):
+    def encode(self, s: str, add_bos: bool = True, **kwargs):
         """Tokenize a prompt.
 
         Args:
@@ -265,7 +259,7 @@ def encode(self, s: str):
         Returns:
             list[int]: token ids
         """
-        return self.model.encode(s)
+        return self.model.encode(s, add_bos, **kwargs)
 
     def decode(self, t: Sequence[int], offset: Optional[int] = None):
         """De-tokenize.
diff --git a/lmdeploy/turbomind/chat.py b/lmdeploy/turbomind/chat.py
index de31a5daa7..c0d2c8f4ed 100644
--- a/lmdeploy/turbomind/chat.py
+++ b/lmdeploy/turbomind/chat.py
@@ -1,30 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import dataclasses
 import os
-import os.path as osp
 import random
 
-import fire
-
-from lmdeploy import turbomind as tm
-from lmdeploy.model import MODELS
-from lmdeploy.tokenizer import Tokenizer
+from lmdeploy.turbomind.utils import get_gen_param
 
 os.environ['TM_LOG_LEVEL'] = 'ERROR'
 
 
-@dataclasses.dataclass
-class GenParam:
-    top_p: float
-    top_k: float
-    temperature: float
-    repetition_penalty: float
-    sequence_start: bool = False
-    sequence_end: bool = False
-    step: int = 0
-    request_output_len: int = 512
-
-
 def input_prompt(model_name):
     """Input a prompt in the consolo interface."""
     if model_name == 'codellama':
@@ -46,36 +29,12 @@ def valid_str(string, coding='utf-8'):
     return ret
 
 
-def get_gen_param(cap,
-                  sampling_param,
-                  nth_round,
-                  step,
-                  request_output_len=512,
-                  **kwargs):
-    """return parameters used by token generation."""
-    gen_param = GenParam(**dataclasses.asdict(sampling_param),
-                         request_output_len=request_output_len)
-    # Fix me later. turbomind.py doesn't support None top_k
-    if gen_param.top_k is None:
-        gen_param.top_k = 40
-
-    if cap == 'chat':
-        gen_param.sequence_start = (nth_round == 1)
-        gen_param.sequence_end = False
-        gen_param.step = step
-    else:
-        gen_param.sequence_start = True
-        gen_param.sequence_end = True
-        gen_param.step = 0
-    return gen_param
-
-
 def main(model_path,
          session_id: int = 1,
          cap: str = 'chat',
-         sys_instruct: str = None,
-         tp=1,
-         stream_output=True,
+         tp: int = 1,
+         stream_output: bool = True,
+         request_output_len: int = 512,
          **kwargs):
     """An example to perform model inference through the command line
     interface.
@@ -85,24 +44,23 @@ def main(model_path,
         session_id (int): the identical id of a session
         cap (str): the capability of a model. For example, codellama has
             the ability among ['completion', 'infilling', 'chat', 'python']
-        sys_instruct (str): the content of 'system' role, which is used by
-            conversational model
         tp (int): GPU number used in tensor parallelism
         stream_output (bool): indicator for streaming output or not
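+        request_output_len (int): the max number of tokens to be generated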
         **kwarg (dict): other arguments for initializing model's chat template
     """
-    tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
-    tokenizer = Tokenizer(tokenizer_model_path)
-    tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id, tp=tp)
+    from lmdeploy import turbomind as tm
+    tm_model = tm.TurboMind.from_pretrained(model_path,
+                                            tp=tp,
+                                            capability=cap,
+                                            **kwargs)
+    tokenizer = tm_model.tokenizer
     generator = tm_model.create_instance()
 
     nth_round = 1
     step = 0
     seed = random.getrandbits(64)
     model_name = tm_model.model_name
-    model = MODELS.get(model_name)(capability=cap, **kwargs) \
-        if sys_instruct is None else MODELS.get(model_name)(
-            capability=cap, system=sys_instruct, **kwargs)
+    model = tm_model.model
 
     print(f'session {session_id}')
     while True:
@@ -112,26 +70,28 @@ def main(model_path,
         elif prompt == 'end':
             prompt = model.get_prompt('', nth_round == 1)
             input_ids = tokenizer.encode(prompt)
-            for outputs in generator.stream_infer(session_id=session_id,
-                                                  input_ids=[input_ids],
-                                                  request_output_len=512,
-                                                  sequence_start=False,
-                                                  sequence_end=True,
-                                                  stream_output=stream_output):
+            for outputs in generator.stream_infer(
+                    session_id=session_id,
+                    input_ids=[input_ids],
+                    request_output_len=request_output_len,
+                    sequence_start=False,
+                    sequence_end=True,
+                    stream_output=stream_output):
                 pass
             nth_round = 1
             step = 0
             seed = random.getrandbits(64)
         else:
             prompt = model.get_prompt(prompt, nth_round == 1)
-            input_ids = tokenizer.encode(prompt)
-            if step + len(input_ids) >= tm_model.session_len:
+            input_ids = tokenizer.encode(prompt, nth_round == 1)
+            if step + len(
+                    input_ids) + request_output_len >= tm_model.session_len:
                 print('WARNING: exceed session max length.'
                       ' Please end the session.')
                 continue
 
             gen_param = get_gen_param(cap, model.sampling_param, nth_round,
-                                      step, **kwargs)
+                                      step, request_output_len, **kwargs)
 
             print(f'{prompt} ', end='', flush=True)
             response_size = 0
@@ -145,6 +105,11 @@ def main(model_path,
                 res, tokens = outputs[0]
                 # decode res
                 response = tokenizer.decode(res.tolist(), offset=response_size)
+                # utf-8 char at the end means it's a potential unfinished
+                # byte sequence; continue to concatenate it with the next
+                # sequence and decode them together
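+                # (e.g. a 3-byte UTF-8 character split across two tokens
+                # decodes to '�' until the remaining bytes arrive)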
+                if response.endswith('�'):
+                    continue
                 response = valid_str(response)
                 print(f'{response}', end='', flush=True)
                 response_size = tokens
@@ -157,4 +122,6 @@ def main(model_path,
 
 
 if __name__ == '__main__':
+    import fire
+
     fire.Fire(main)
diff --git a/lmdeploy/turbomind/decode.py b/lmdeploy/turbomind/decode.py
index daef35298c..5ba4675c59 100644
--- a/lmdeploy/turbomind/decode.py
+++ b/lmdeploy/turbomind/decode.py
@@ -2,7 +2,6 @@
 import os
 import os.path as osp
 
-import fire
 import torch
 
 from lmdeploy import turbomind as tm
@@ -37,4 +36,6 @@ def main(model_path, inputs):
 
 
 if __name__ == '__main__':
+    import fire
+
     fire.Fire(main)
diff --git a/lmdeploy/turbomind/deploy/__init__.py b/lmdeploy/turbomind/deploy/__init__.py
new file mode 100644
index 0000000000..ef101fec61
--- /dev/null
+++ b/lmdeploy/turbomind/deploy/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) OpenMMLab. All rights reserved.
diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py
new file mode 100644
index 0000000000..5bcab7b537
--- /dev/null
+++ b/lmdeploy/turbomind/deploy/converter.py
@@ -0,0 +1,260 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+import re
+import shutil
+from pathlib import Path
+
+import fire
+from huggingface_hub import snapshot_download
+
+from lmdeploy.model import MODELS
+from lmdeploy.turbomind.utils import create_hf_download_args
+
+from .source_model.base import INPUT_MODELS
+from .target_model.base import OUTPUT_MODELS, TurbomindModelConfig
+
+supported_formats = ['llama', 'hf', 'awq', None]
+special_input_model_map = {
+    'qwen': 'qwen',
+    'baichuan': 'baichuan',
+    'baichuan2': 'baichuan2'
+}
+
+
+def get_package_root_path():
+    """Get lmdeploy root path."""
+    import lmdeploy
+    return Path(lmdeploy.__file__).parent
+
+
+def get_tokenizer_path(model_path: str, tokenizer_path: str):
+    """Get tokenizer path if not given."""
+    if tokenizer_path is not None:
+        assert osp.exists(tokenizer_path), f'{tokenizer_path} does not exist.'
+        return tokenizer_path
+    candidate = ['tokenizer.model', 'qwen.tiktoken']
+    for name in candidate:
+        tmp_path = osp.join(model_path, name)
+        if osp.exists(tmp_path):
+            tokenizer_path = tmp_path
+            break
+    assert tokenizer_path, 'please supply tokenizer path by --tokenizer-path'
+    return tokenizer_path
+
+
+def get_model_format(model_name: str, model_format: str):
+    """Get model format if not given or equal awq."""
+    # get model name prefix
+    if model_name.find('-') != -1:
+        model_name = model_name[:model_name.find('-')]
+    # rules:
+    # 1) llama -> match special -> hf (if not matched)
+    # 2) append awq (if model_format is awq)
+    inferred_model_format = model_format
+    if model_format in [None, 'hf']:
+        inferred_model_format = special_input_model_map.get(model_name, 'hf')
+    elif model_format == 'awq':
+        inferred_model_format = special_input_model_map.get(model_name,
+                                                            'hf') + '-awq'
+    return inferred_model_format
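+# Illustrative examples of the rules above (editorial sketch):
+#   get_model_format('qwen-7b', None)    -> 'qwen'
+#   get_model_format('qwen-7b', 'awq')   -> 'qwen-awq'
+#   get_model_format('llama2-7b', 'awq') -> 'hf-awq'  (prefix not in the map)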
+
+
+def create_workspace(_path: str):
+    """Create a workspace.
+
+    Args:
+        _path (str): the path of the workspace
+    """
+    if osp.exists(_path):
+        print(f'remove workspace in directory {_path}')
+        shutil.rmtree(_path)
+    print(f'create workspace in directory {_path}')
+    os.makedirs(_path)
+
+
+def copy_triton_model_templates(_path: str):
+    """copy triton model templates to the specified path.
+
+    Args:
+        _path (str): the target path
+    Returns:
+        str: the path of the triton models
+    """
+
+    root = get_package_root_path()
+    dir_path = osp.join(root, 'serve', 'turbomind')
+    triton_models_path = osp.join(dir_path, 'triton_models')
+    dst_path = osp.join(_path, 'triton_models')
+    print(f'copy triton model templates from "{triton_models_path}" to '
+          f'"{dst_path}"')
+    shutil.copytree(triton_models_path, dst_path, symlinks=True)
+    service_docker_up_file = osp.join(dir_path, 'service_docker_up.sh')
+    print(f'copy service_docker_up.sh from "{service_docker_up_file}" to '
+          f'"{_path}"')
+    shutil.copy(osp.join(dir_path, 'service_docker_up.sh'), _path)
+    return dst_path
+
+
+def copy_tokenizer(model_path: str, tokenizer_path: str,
+                   triton_models_path: str):
+    """Copy tokenizer."""
+    shutil.copy(
+        tokenizer_path,
+        osp.join(triton_models_path,
+                 osp.join('tokenizer', osp.basename(tokenizer_path))))
+    for _file in os.listdir(model_path):
+        if _file.endswith('.json') or _file.endswith('.py'):
+            json_path = osp.join(model_path, _file)
+            shutil.copy(json_path,
+                        osp.join(triton_models_path, 'tokenizer', _file))
+    with get_package_root_path() as root_path:
+        shutil.copy(osp.join(root_path, 'tokenizer.py'),
+                    osp.join(triton_models_path, 'tokenizer'))
+
+
+def pack_model_repository(workspace_path: str):
+    """package the model repository.
+
+    Args:
+        workspace_path: the path of workspace
+    """
+    os.symlink(src=osp.join('..', '..', 'tokenizer'),
+               dst=osp.join(workspace_path, 'triton_models', 'preprocessing',
+                            '1', 'tokenizer'))
+    os.symlink(src=osp.join('..', '..', 'tokenizer'),
+               dst=osp.join(workspace_path, 'triton_models', 'postprocessing',
+                            '1', 'tokenizer'))
+    os.symlink(src=osp.join('..', '..', 'weights'),
+               dst=osp.join(workspace_path, 'triton_models', 'interactive',
+                            '1', 'weights'))
+    model_repo_dir = osp.join(workspace_path, 'model_repository')
+    os.makedirs(model_repo_dir, exist_ok=True)
+    os.symlink(src=osp.join('..', 'triton_models', 'interactive'),
+               dst=osp.join(model_repo_dir, 'turbomind'))
+    os.symlink(src=osp.join('..', 'triton_models', 'preprocessing'),
+               dst=osp.join(model_repo_dir, 'preprocessing'))
+    os.symlink(src=osp.join('..', 'triton_models', 'postprocessing'),
+               dst=osp.join(model_repo_dir, 'postprocessing'))
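+    # Resulting layout (editorial sketch of the symlinks created above):
+    #   triton_models/preprocessing/1/tokenizer  -> ../../tokenizer
+    #   triton_models/postprocessing/1/tokenizer -> ../../tokenizer
+    #   triton_models/interactive/1/weights      -> ../../weights
+    #   model_repository/turbomind      -> ../triton_models/interactive
+    #   model_repository/preprocessing  -> ../triton_models/preprocessing
+    #   model_repository/postprocessing -> ../triton_models/postprocessing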
+
+
+def main(model_name: str,
+         model_path: str,
+         model_format: str = None,
+         tokenizer_path: str = None,
+         dst_path: str = 'workspace',
+         tp: int = 1,
+         quant_path: str = None,
+         group_size: int = 0,
+         **kwargs):
+    """deploy llama family models via turbomind.
+
+    Args:
+        model_name (str): the name of the to-be-deployed model, such as
+            llama-7b, llama-13b, vicuna-7b, etc.
+        model_path (str): the directory path of the model
+        model_format (str): the format of the model, should choose from
+            ['llama', 'hf', 'awq', None]. 'llama' stands for META's llama
+            format, 'hf' means huggingface llama format, and 'awq' means
+            llama(hf) model quantized by lmdeploy/lite/quantization/awq.py.
+            the default value is None, which means the model_format will be
+            inferred based on model_name
+        tokenizer_path (str): the path of tokenizer model
+        dst_path (str): the destination path that saves outputs
+        tp (int): the number of GPUs used for tensor parallelism, should be 2^n
+        quant_path (str): Path of the quantized model, which can be None.
+        group_size (int): a parameter used in AWQ to quantize fp16 weights
+            to 4 bits
+        kwargs (dict): other params for convert
+    """
+
+    assert model_name in MODELS.module_dict.keys(), \
+        f"'{model_name}' is not supported. " \
+        f'The supported models are: {MODELS.module_dict.keys()}'
+
+    assert ((tp & (tp - 1) == 0) and tp != 0), 'tp should be 2^n'
+
+    output_format = 'fp16'
+
+    # get input model format
+    assert model_format in supported_formats, 'the model format ' \
+        f'should be in {supported_formats}'
+
+    inferred_model_format = get_model_format(model_name, model_format)
+    if inferred_model_format not in INPUT_MODELS.module_dict.keys():
+        supported_keys = list(INPUT_MODELS.module_dict.keys())
+        print(f'with model name {model_name} and model format {model_format}, '
+              f'the inferred model format is {inferred_model_format}, '
+              f'which is not in supported list {supported_keys}')
+        exit(-1)
+
+    if not os.path.exists(model_path):
+        print(f'can\'t find model from local_path {model_path}, '
+              'try to download from huggingface')
+        download_kwargs = create_hf_download_args(**kwargs)
+        model_path = snapshot_download(model_path, **download_kwargs)
+        print(f'load model from {model_path}')
+
+    # get tokenizer path
+    tokenizer_path = get_tokenizer_path(model_path, tokenizer_path)
+
+    # create workspace
+    create_workspace(dst_path)
+
+    triton_models_path = copy_triton_model_templates(dst_path)
+
+    copy_tokenizer(model_path, tokenizer_path, triton_models_path)
+
+    # turbomind config
+    cfg = TurbomindModelConfig.from_dict({}, allow_none=True)
+    cfg.model_name = model_name
+    cfg.tensor_para_size = tp
+    cfg.rotary_embedding = cfg.size_per_head
+    cfg.group_size = group_size
+    if inferred_model_format.find('awq') != -1:
+        cfg.weight_type = 'int4'
+        output_format = 'w4'
+        assert group_size > 0, f'group_size: {group_size} should be > 0'
+
+    # convert
+    print('model_name            ', model_name)
+    print('model_format          ', model_format)
+    print('inferred_model_format ', inferred_model_format)
+    print('model_path            ', model_path)
+    print('tokenizer_path        ', tokenizer_path)
+    print('output_format         ', output_format)
+    weight_path = osp.join(triton_models_path, 'weights')
+    input_model = INPUT_MODELS.get(inferred_model_format)(
+        model_path=model_path,
+        tokenizer_path=tokenizer_path,
+        ckpt_path=quant_path)
+    output_model = OUTPUT_MODELS.get(output_format)(input_model=input_model,
+                                                    cfg=cfg,
+                                                    to_file=True,
+                                                    out_dir=weight_path)
+    output_model.export()
+
+    # update `tensor_para_size` in `triton_models/interactive/config.pbtxt`
+    with open(osp.join(triton_models_path, 'interactive', 'config.pbtxt'),
+              'a') as f:
+        param = \
+            'parameters {\n  key: "tensor_para_size"\n  value: {\n    ' \
+            'string_value: ' + f'"{tp}"\n' + '  }\n}\n' + \
+            'parameters {\n  key: "model_name"\n  value: {\n    ' \
+            'string_value: ' + f'"{model_name}"\n' + '  }\n}\n'
+        f.write(param)
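+    # The appended pbtxt block looks like (e.g. tp=2, model_name='llama-7b'):
+    #   parameters { key: "tensor_para_size" value: { string_value: "2" } }
+    #   parameters { key: "model_name" value: { string_value: "llama-7b" } }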
+
+    # pack model repository for triton inference server
+    pack_model_repository(dst_path)
+
+    # update the value of $TP in `service_docker_up.sh`
+    file_path = osp.join(dst_path, 'service_docker_up.sh')
+    with open(file_path, 'r') as f:
+        content = f.read()
+        content = re.sub('TP=1', f'TP={tp}', content)
+    with open(file_path, 'w') as f:
+        f.write(content)
+
+
+if __name__ == '__main__':
+    fire.Fire(main)
diff --git a/lmdeploy/turbomind/deploy/source_model/__init__.py b/lmdeploy/turbomind/deploy/source_model/__init__.py
new file mode 100644
index 0000000000..7c6627c770
--- /dev/null
+++ b/lmdeploy/turbomind/deploy/source_model/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .baichuan import Baichuan2Model, BaichuanModel  # noqa: F401
+from .baichuan_awq import Baichuan2AwqModel, BaichuanAwqModel  # noqa: F401
+from .llama import LlamaModel  # noqa: F401
+from .llama_awq import LlamaAwqModel  # noqa: F401
+from .meta_llama import MetaLlamaModel  # noqa: F401
+from .qwen import QwenModel  # noqa: F401
+from .qwen_awq import QwenAwqModel  # noqa: F401
diff --git a/lmdeploy/turbomind/deploy/source_model/baichuan.py b/lmdeploy/turbomind/deploy/source_model/baichuan.py
new file mode 100644
index 0000000000..46ccb6309d
--- /dev/null
+++ b/lmdeploy/turbomind/deploy/source_model/baichuan.py
@@ -0,0 +1,67 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch
+
+from .base import INPUT_MODELS
+from .llama import LlamaModel, LlamaReader
+
+
+class BaichuanReader(LlamaReader):
+    """BaichuanReader."""
+
+    def __init__(self, new_params: dict, unused_params: dict, last_bin: bool):
+        super().__init__(new_params, unused_params, last_bin)
+
+    def _attn(self, i: int, kind: str, size_dim: int, dim: int = 0):
+        """Get q, k, v, o kind for layer i."""
+        result = []
+        pack_key = f'model.layers.{i}.self_attn.W_pack.{kind}'
+        qkv = self.params[pack_key]
+        result.extend(torch.split(qkv, qkv.shape[size_dim] // 3, dim=dim))
+        o = self.params[f'model.layers.{i}.self_attn.o_proj.{kind}']
+        result.append(o)
+        return (*result, )
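+    # Editorial note: baichuan packs q, k and v into one W_pack tensor, so
+    # _attn() splits it into three equal chunks along `size_dim`; e.g. a
+    # weight of shape [3 * hidden_size, hidden_size] yields q, k and v of
+    # shape [hidden_size, hidden_size] each.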
+
+    def attn(self, i: int):
+        """Get q, k, v, o weight for layer i."""
+        return self._attn(i, 'weight', 0, 0)
+
+    def attn_bias(self, i: int):
+        """Get q, k, v, o bias for layer i."""
+        return (None, ) * 4
+
+
+class Baichuan2Reader(BaichuanReader):
+    """Baichuan2Reader."""
+
+    def __init__(self, new_params: dict, unused_params: dict, last_bin: bool):
+        super().__init__(new_params, unused_params, last_bin)
+
+    def output_weight(self):
+        """Get output."""
+        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/main/modeling_baichuan.py#L507
+        tensor = self.params.get('lm_head.weight', None)
+        if tensor is not None:
+            tensor = tensor.cuda()
+            tensor = torch.nn.functional.normalize(tensor)
+        return tensor
+
+
+@INPUT_MODELS.register_module(name='baichuan')
+class BaichuanModel(LlamaModel):
+    """Llama model in baichuan format."""
+
+    Reader = BaichuanReader
+
+    def __init__(self, model_path: str, tokenizer_path: str, **kwargs: dict):
+        super().__init__(model_path, tokenizer_path, **kwargs)
+
+
+@INPUT_MODELS.register_module(name='baichuan2')
+class Baichuan2Model(LlamaModel):
+    """Llama model in baichuan format."""
+
+    Reader = Baichuan2Reader
+
+    def __init__(self, model_path: str, tokenizer_path: str, **kwargs: dict):
+        super().__init__(model_path, tokenizer_path, **kwargs)
diff --git a/lmdeploy/turbomind/deploy/source_model/baichuan_awq.py b/lmdeploy/turbomind/deploy/source_model/baichuan_awq.py
new file mode 100644
index 0000000000..d5d60286a8
--- /dev/null
+++ b/lmdeploy/turbomind/deploy/source_model/baichuan_awq.py
@@ -0,0 +1,87 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from .baichuan import Baichuan2Model, BaichuanModel, BaichuanReader
+from .base import INPUT_MODELS
+from .llama_awq import ensure_fp16orint32
+
+
+class BaichuanAwqReader(BaichuanReader):
+    """BaichuanAwqReader."""
+
+    def __init__(self, new_params: dict, unused_params: dict, last_bin: bool):
+        super().__init__(new_params, unused_params, last_bin)
+
+    def attn(self, i: int):
+        """Get q, k, v, o qweight for layer i."""
+        return ensure_fp16orint32(self._attn(i, 'qweight', -1, -1))
+
+    def attn_zero(self, i: int):
+        """Get q, k, v, o qzeros for layer i."""
+        return ensure_fp16orint32(self._attn(i, 'qzeros', -1, -1))
+
+    def attn_scale(self, i: int):
+        """Get q, k, v, o scales for layer i."""
+        return ensure_fp16orint32(self._attn(i, 'scales', -1, -1))
+
+    def ffn(self, i: int):
+        """Get ffn qweight for layer i."""
+        return ensure_fp16orint32(self._ffn(i, 'qweight'))
+
+    def ffn_zero(self, i: int):
+        """Get ffn qzeros for layer i."""
+        return ensure_fp16orint32(self._ffn(i, 'qzeros'))
+
+    def ffn_scale(self, i: int):
+        """Get ffn scales for layer i."""
+        return ensure_fp16orint32(self._ffn(i, 'scales'))
+
+
+class Baichuan2AwqReader(BaichuanAwqReader):
+    """Baichuan2AwqReader."""
+
+    def __init__(self, new_params: dict, unused_params: dict, last_bin: bool):
+        super().__init__(new_params, unused_params, last_bin)
+
+    def output_weight(self):
+        """Get output."""
+        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/main/modeling_baichuan.py#L507
+        tensor = self.params.get('lm_head.weight', None)
+        if tensor is not None:
+            tensor = tensor.cuda()
+            tensor = torch.nn.functional.normalize(tensor)
+        return tensor
+
+
+@INPUT_MODELS.register_module(name='baichuan-awq')
+class BaichuanAwqModel(BaichuanModel):
+    """Baichuan awq model in hf format."""
+
+    Reader = BaichuanAwqReader
+
+    def __init__(self,
+                 model_path: str,
+                 tokenizer_path: str,
+                 ckpt_path: str = None,
+                 **kwargs):
+        super().__init__(model_path,
+                         tokenizer_path,
+                         ckpt_path=ckpt_path,
+                         **kwargs)
+
+
+@INPUT_MODELS.register_module(name='baichuan2-awq')
+class Baichuan2AwqModel(Baichuan2Model):
+    """Baichuan2 awq model in hf format."""
+
+    Reader = Baichuan2AwqReader
+
+    def __init__(self,
+                 model_path: str,
+                 tokenizer_path: str,
+                 ckpt_path: str = None,
+                 **kwargs):
+        super().__init__(model_path,
+                         tokenizer_path,
+                         ckpt_path=ckpt_path,
+                         **kwargs)
diff --git a/lmdeploy/turbomind/deploy/source_model/base.py b/lmdeploy/turbomind/deploy/source_model/base.py
new file mode 100644
index 0000000000..c335b4c10b
--- /dev/null
+++ b/lmdeploy/turbomind/deploy/source_model/base.py
@@ -0,0 +1,175 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import re
+from abc import ABC, abstractmethod
+from typing import Dict, Iterator, Tuple, Union
+
+import torch
+from mmengine import Registry
+
+INPUT_MODELS = Registry(
+    'source model', locations=['lmdeploy.turbomind.deploy.source_model.base'])
+
+
+class BaseReader(ABC):
+    """Base checkpoint manager."""
+
+    def __init__(self):
+        pass
+
+    @property
+    @abstractmethod
+    def start_layer_id(self) -> int:
+        """Get the start transformer layer number."""
+        pass
+
+    @property
+    @abstractmethod
+    def end_layer_id(self) -> int:
+        """Get the end transformer layer number."""
+        pass
+
+    @abstractmethod
+    def init_layer_id(self) -> None:
+        """Get start and end transformer layer number."""
+        self._start_layer_id = -1
+        self._end_layer_id = -1
+        layer_count = {}
+        for key in self.params:
+            layer_id = re.findall(self.attn_layer_patten, key)
+            if len(layer_id) == 0:
+                continue
+            layer_id = int(layer_id[0])
+            if layer_id not in layer_count:
+                layer_count[layer_id] = 0
+            layer_count[layer_id] += 1
+        if len(layer_count) == 0:
+            return
+        if not (len(layer_count) > 1 or self.last_bin):
+            return
+        max_count = max([layer_count[layer_id] for layer_id in layer_count])
+        valid_layer_id = [
+            layer_id for layer_id in layer_count
+            if layer_count[layer_id] == max_count
+        ]
+        self._start_layer_id = min(valid_layer_id)
+        self._end_layer_id = max(valid_layer_id) + 1
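+        # Editorial example: if a shard holds every parameter of layers 4-7
+        # but only part of layer 8, layers 4-7 reach max_count, so the
+        # resulting range is start=4, end=8 (end is exclusive).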
+
+    @abstractmethod
+    def clean_up(self, last: bool) -> None:
+        """Clean up unused params."""
+        if last:
+            self.params.clear()
+        else:
+            to_remove = []
+            for key in self.params:
+                layer_id = re.findall(self.attn_layer_patten, key)
+                if len(layer_id) == 0:
+                    # tok, norm, output
+                    to_remove.append(key)
+                else:
+                    layer_id = int(layer_id[0])
+                    if layer_id < self.end_layer_id:
+                        to_remove.append(key)
+            for key in to_remove:
+                self.params.pop(key, None)
+        torch.cuda.empty_cache()
+
+    @abstractmethod
+    def tok_embeddings(self) -> Union[torch.Tensor, None]:
+        """Get embeddings."""
+        pass
+
+    @abstractmethod
+    def norm_weight(self) -> Union[torch.Tensor, None]:
+        """Get norm."""
+        pass
+
+    @abstractmethod
+    def output_weight(self) -> Union[torch.Tensor, None]:
+        """Get output."""
+        pass
+
+    @abstractmethod
+    def attn(self, i: int) -> Tuple[torch.Tensor]:
+        """Get q, k, v, o weight for layer i."""
+        pass
+
+    @abstractmethod
+    def attn_bias(self, i: int) -> Tuple[torch.Tensor, None]:
+        """Get q, k, v, o bias for layer i."""
+        pass
+
+    @abstractmethod
+    def attn_zero(self, i: int) -> Tuple[torch.Tensor, None]:
+        """Get q, k, v, o zero point for layer i."""
+        pass
+
+    @abstractmethod
+    def attn_scale(self, i: int) -> Tuple[torch.Tensor, None]:
+        """Get q, k, v, o scale for layer i."""
+        pass
+
+    @abstractmethod
+    def attn_norm(self, i: int) -> torch.Tensor:
+        """Get attn norm for layer i."""
+        pass
+
+    @abstractmethod
+    def ffn(self, i: int) -> Tuple[torch.Tensor]:
+        """Get ffn weight for layer i."""
+        pass
+
+    @abstractmethod
+    def ffn_zero(self, i: int) -> Tuple[torch.Tensor, None]:
+        """Get ffn zero point for layer i."""
+        pass
+
+    @abstractmethod
+    def ffn_scale(self, i: int) -> Tuple[torch.Tensor, None]:
+        """Get ffn scale for layer i."""
+        pass
+
+    @abstractmethod
+    def ffn_norm(self, i: int) -> torch.Tensor:
+        """Get ffn norm for layer i."""
+        pass
+
+
+class BaseInputModel(ABC):
+    """Base class for input model."""
+
+    def __init__(self, model_path: str, tokenizer_path: str, **kwargs):
+        """Constructor for BaseInputModel.
+
+        Args:
+            model_path (str): the path of the model.
+            tokenizer_path (str): the path of the tokenizer model.
+        """
+        self.model_path = model_path
+        self.tokenizer_path = tokenizer_path
+
+    @property
+    @abstractmethod
+    def nmgrs(self) -> int:
+        """Get number of checkpoint."""
+        pass
+
+    @abstractmethod
+    def get_mgrs(self) -> Iterator[BaseReader]:
+        """Conctruct all BaseReader."""
+        pass
+
+    @abstractmethod
+    def tokenizer_info(self):
+        """Read tokenizer info."""
+        pass
+
+    @abstractmethod
+    def model_info(self) -> Dict:
+        """Read model info."""
+        pass
+
+    def bins(self) -> Iterator[BaseReader]:
+        """Get Reader."""
+        for mgr in self.get_mgrs():
+            yield mgr
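+
+
+# Editorial usage sketch (an assumption about how a converter consumes this
+# interface, not something defined in this patch):
+#
+#   for reader in input_model.bins():
+#       for i in range(reader.start_layer_id, reader.end_layer_id):
+#           qw, kw, vw, ow = reader.attn(i)
+#           w1, w2, w3 = reader.ffn(i)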
diff --git a/lmdeploy/turbomind/deploy/source_model/llama.py b/lmdeploy/turbomind/deploy/source_model/llama.py
new file mode 100644
index 0000000000..f800260467
--- /dev/null
+++ b/lmdeploy/turbomind/deploy/source_model/llama.py
@@ -0,0 +1,198 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os
+import os.path as osp
+
+import torch
+from safetensors.torch import load_file
+
+from lmdeploy.tokenizer import Tokenizer
+
+from .base import INPUT_MODELS, BaseInputModel, BaseReader
+
+
+class LlamaReader(BaseReader):
+    """LlamaReader."""
+
+    attn_layer_patten = r'model.layers.([0-9]+).'
+    tok_embeddings_key = 'model.embed_tokens.weight'
+    norm_weight_key = 'model.norm.weight'
+    output_weight_key = 'lm_head.weight'
+
+    def __init__(self, new_params: dict, unused_params: dict, last_bin: bool):
+        super().__init__()
+        self.params = unused_params
+        self.params.update(new_params)
+        self.last_bin = last_bin
+        self.init_layer_id()
+
+    def init_layer_id(self):
+        """Get start/end transformer layer id."""
+        super().init_layer_id()
+
+    def clean_up(self, last: bool) -> None:
+        """Clean up unused params."""
+        super().clean_up(last)
+
+    @property
+    def start_layer_id(self):
+        """Get start transformer layer id."""
+        return self._start_layer_id
+
+    @property
+    def end_layer_id(self):
+        """Get end transformer layer id."""
+        return self._end_layer_id
+
+    def tok_embeddings(self):
+        """Get embeddings."""
+        return self.params.get(self.tok_embeddings_key, None)
+
+    def norm_weight(self):
+        """Get norm."""
+        return self.params.get(self.norm_weight_key, None)
+
+    def output_weight(self):
+        """Get output."""
+        return self.params.get(self.output_weight_key, None)
+
+    def _attn(self, i: int, kind: str, allow_none=False):
+        """Get q, k, v, o kind for layer i."""
+        result = []
+        for key in ['q', 'k', 'v', 'o']:
+            tensor = self.params.get(
+                f'model.layers.{i}.self_attn.{key}_proj.{kind}')
+            if not allow_none:
+                assert tensor is not None
+            result.append(tensor)
+        return (*result, )
+
+    def attn(self, i: int):
+        """Get q, k, v, o weight for layer i."""
+        return self._attn(i, 'weight')
+
+    def attn_bias(self, i: int):
+        """Get q, k, v, o bias for layer i."""
+        return self._attn(i, 'bias', allow_none=True)
+
+    def attn_zero(self, i: int):
+        """Get q, k, v, o zero point for layer i."""
+        return (None, ) * 4
+
+    def attn_scale(self, i: int):
+        """Get q, k, v, o scale for layer i."""
+        return (None, ) * 4
+
+    def attn_norm(self, i: int):
+        """Get attn norm for layer i."""
+        return self.params[f'model.layers.{i}.input_layernorm.weight']
+
+    def _ffn(self, i: int, kind: str):
+        """Get ffn kind for layer i."""
+        result = []
+        for key in ['gate', 'down', 'up']:
+            tensor = self.params[f'model.layers.{i}.mlp.{key}_proj.{kind}']
+            result.append(tensor)
+        return (*result, )
+
+    def ffn(self, i: int):
+        """Get ffn weight for layer i."""
+        return self._ffn(i, 'weight')
+
+    def ffn_zero(self, i: int):
+        """Get ffn zero point for layer i."""
+        return (None, ) * 3
+
+    def ffn_scale(self, i: int):
+        """Get ffn scale for layer i."""
+        return (None, ) * 3
+
+    def ffn_norm(self, i: int):
+        """Get ffn norm for layer i."""
+        return self.params[f'model.layers.{i}.post_attention_layernorm.weight']
+
+
+@INPUT_MODELS.register_module(name='hf')
+class LlamaModel(BaseInputModel):
+    """Llama model in hf format."""
+
+    Reader = LlamaReader
+
+    def __init__(self, model_path: str, tokenizer_path: str, **kwargs: dict):
+        super().__init__(model_path, tokenizer_path)
+        ckpt_path = kwargs.get('ckpt_path')
+        if ckpt_path is None:
+            ckpt_path = model_path
+        self.ckpt_path = ckpt_path
+        self.ckpt_files = self.get_ckpt()
+
+    def get_ckpt(self):
+        """Get weight files."""
+        suffixes = ['.safetensors', '.bin']
+        files = []
+        for suffix in suffixes:
+            files = [
+                file for file in os.listdir(self.ckpt_path)
+                if file.endswith(suffix)
+            ]
+            if len(files) > 0:
+                break
+        files = sorted(files)
+        return files
+
+    @property
+    def nmgrs(self):
+        """Get number of checkpoint."""
+        return len(self.ckpt_files)
+
+    def get_mgrs(self):
+        """Conctruct all Reader."""
+        assert self.nmgrs > 0, \
+            f'could not find checkpoints in {self.ckpt_path}'
+        unused_params = {}
+        try:
+            for i, ckpt in enumerate(self.ckpt_files):
+                is_last_bin = i == len(self.ckpt_files) - 1
+                if ckpt.endswith('.bin'):
+                    new_params = torch.load(osp.join(self.ckpt_path, ckpt),
+                                            map_location='cpu')
+                else:
+                    new_params = load_file(osp.join(self.ckpt_path, ckpt))
+                ret = self.Reader(new_params, unused_params,
+                                  i == self.nmgrs - 1)
+                yield ret
+                ret.clean_up(is_last_bin)
+        except GeneratorExit:
+            ret.clean_up(True)
+
+    def tokenizer_info(self):
+        """Read tokenizer info."""
+        assert osp.isdir(self.model_path), self.model_path
+        tk_model = Tokenizer(self.model_path)
+        n_words = tk_model.vocab_size
+        bos_id = tk_model.bos_token_id
+        eos_id = tk_model.eos_token_id
+        return n_words, bos_id, eos_id
+
+    def model_info(self):
+        """Read model info."""
+        params_path = osp.join(self.model_path, 'config.json')
+        with open(params_path) as f:
+            model_arg = json.load(f)
+            num_layer = model_arg['num_hidden_layers']
+            norm_eps = model_arg['rms_norm_eps']
+            if 'num_key_value_heads' in model_arg:
+                kv_head_num = model_arg['num_key_value_heads']
+            else:
+                kv_head_num = model_arg['num_attention_heads']
+            rope_theta = float(model_arg.get('rope_theta', 10000.0))
+            max_position_embeddings = int(
+                model_arg.get('max_position_embeddings', 0))
+            rope_scaling = bool(model_arg.get('rope_scaling', False))
+
+        return dict(num_layer=num_layer,
+                    norm_eps=norm_eps,
+                    kv_head_num=kv_head_num,
+                    rope_theta=rope_theta,
+                    max_position_embeddings=max_position_embeddings,
+                    use_dynamic_ntk=int(rope_scaling))
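
The reader API above is easiest to see end to end. Below is a minimal, illustrative usage sketch (assuming this patch is installed; the checkpoint path is a placeholder directory holding HF-format *.bin or *.safetensors weights plus config.json):

# Hypothetical usage sketch of the HF-format Llama source model.
from lmdeploy.turbomind.deploy.source_model.llama import LlamaModel

model = LlamaModel(model_path='/path/to/llama-hf',
                   tokenizer_path='/path/to/llama-hf')
print(model.nmgrs, model.model_info())
for reader in model.bins():              # one LlamaReader per weight file
    for i in range(reader.start_layer_id, reader.end_layer_id):
        q, k, v, o = reader.attn(i)      # q/k/v/o projection weights of layer i
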
diff --git a/lmdeploy/turbomind/deploy/source_model/llama_awq.py b/lmdeploy/turbomind/deploy/source_model/llama_awq.py
new file mode 100644
index 0000000000..9d2ae8ac50
--- /dev/null
+++ b/lmdeploy/turbomind/deploy/source_model/llama_awq.py
@@ -0,0 +1,68 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from .base import INPUT_MODELS
+from .llama import LlamaModel, LlamaReader
+
+
+def ensure_fp16orint32(tensors: list):
+    """Ensure tensors in fp16/int32 format."""
+    result = []
+    for tensor in tensors:
+        if tensor is not None:
+            if tensor.dtype in [torch.float16, torch.float32, torch.bfloat16]:
+                result.append(tensor.half())
+            else:
+                assert tensor.dtype == torch.int32
+                result.append(tensor)
+        else:
+            result.append(None)
+    return (*result, )
+
+
+class LlamaAwqReader(LlamaReader):
+    """LlamaAwqReader."""
+
+    def __init__(self, new_params: dict, unused_params: dict, last_bin: bool):
+        super().__init__(new_params, unused_params, last_bin)
+
+    def attn(self, i: int):
+        """Get q, k, v, o qweight for layer i."""
+        return ensure_fp16orint32(self._attn(i, 'qweight'))
+
+    def attn_zero(self, i: int):
+        """Get q, k, v, o qzeros for layer i."""
+        return ensure_fp16orint32(self._attn(i, 'qzeros'))
+
+    def attn_scale(self, i: int):
+        """Get q, k, v, o scales for layer i."""
+        return ensure_fp16orint32(self._attn(i, 'scales'))
+
+    def ffn(self, i: int):
+        """Get ffn qweight for layer i."""
+        return ensure_fp16orint32(self._ffn(i, 'qweight'))
+
+    def ffn_zero(self, i: int):
+        """Get ffn qzeros for layer i."""
+        return ensure_fp16orint32(self._ffn(i, 'qzeros'))
+
+    def ffn_scale(self, i: int):
+        """Get ffn scales for layer i."""
+        return ensure_fp16orint32(self._ffn(i, 'scales'))
+
+
+@INPUT_MODELS.register_module(name='hf-awq')
+class LlamaAwqModel(LlamaModel):
+    """Llama Awq model in hf format."""
+
+    Reader = LlamaAwqReader
+
+    def __init__(self,
+                 model_path: str,
+                 tokenizer_path: str,
+                 ckpt_path: str = None,
+                 **kwargs):
+        super().__init__(model_path,
+                         tokenizer_path,
+                         ckpt_path=ckpt_path,
+                         **kwargs)
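
A small illustrative sketch of what ensure_fp16orint32 guarantees for AWQ tensors, using synthetic tensors only:

# Packed int32 qweights pass through unchanged; fp32 scales are cast to fp16.
import torch
from lmdeploy.turbomind.deploy.source_model.llama_awq import ensure_fp16orint32

qweight = torch.zeros(128, 16, dtype=torch.int32)   # packed 4-bit weights
scales = torch.rand(128, 16, dtype=torch.float32)
out_qw, out_s = ensure_fp16orint32([qweight, scales])
assert out_qw.dtype == torch.int32 and out_s.dtype == torch.float16
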
diff --git a/lmdeploy/turbomind/deploy/source_model/meta_llama.py b/lmdeploy/turbomind/deploy/source_model/meta_llama.py
new file mode 100644
index 0000000000..bc26361c73
--- /dev/null
+++ b/lmdeploy/turbomind/deploy/source_model/meta_llama.py
@@ -0,0 +1,224 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os.path as osp
+from pathlib import Path
+
+import torch
+from sentencepiece import SentencePieceProcessor
+
+from .base import INPUT_MODELS, BaseInputModel, BaseReader
+
+
+def reverse_permute(x: torch.Tensor, size_per_head: int = 128):
+    """reverse permute to hf format."""
+    if x.shape[-1] > 1:
+        dim = x.shape[-1]
+        n_heads = dim // size_per_head
+        return x.view(-1, n_heads, dim // n_heads // 2,
+                      2).transpose(2, 3).reshape(-1, dim)
+    else:  # scales, zeros
+        dim = x.shape[0]
+        n_heads = dim // size_per_head
+        return x.view(n_heads, dim // n_heads // 2, 2,
+                      1).transpose(1, 2).reshape(dim, 1)
+
+
+class MetaLlamaReader(BaseReader):
+    """MetaLlamaReader."""
+
+    def __init__(self, model_path: str, start_layer_id: int,
+                 end_layer_id: int):
+        super().__init__()
+        self._start_layer_id = start_layer_id
+        self._end_layer_id = end_layer_id
+        self.params = self.load_model(model_path)
+
+    def init_layer_id(self):
+        """Empty."""
+        pass
+
+    def load_model(self, model_path):
+        """Load all parameters."""
+        checkpoints = []
+        for pattern in ['*.pth', '*.pt']:
+            checkpoints += sorted(Path(model_path).glob(pattern))
+        n_ckpt = len(checkpoints)
+        model_params = {}
+
+        def get_param(_name, _size):
+            if _name not in model_params:
+                model_params[_name] = torch.zeros(_size,
+                                                  dtype=torch.float16,
+                                                  device='cpu')
+            return model_params[_name]
+
+        from tqdm import tqdm
+        pbar = tqdm(total=n_ckpt, desc='load meta ckpt', leave=False)
+        for i, ckpt_path in enumerate(checkpoints):
+            ckpt = torch.load(ckpt_path, map_location='cpu')
+            for param_name, param_data in ckpt.items():
+                key, ext = param_name.split('.')[-2:]
+                # column-parallel
+                if key in ['w1', 'w3', 'wq', 'wk', 'wv', 'output']:
+                    size = param_data.size(0)
+                    if ext == 'weight':
+                        param = get_param(
+                            param_name,
+                            [size * n_ckpt, param_data.size(1)])
+                        param.data[size * i:size * (i + 1), :] = param_data
+                    else:  # bias
+                        param = get_param(param_name, [size * n_ckpt])
+                        param.data[size * i:size * (i + 1)] = param_data
+                # row-parallel
+                elif key in ['w2', 'wo', 'tok_embeddings']:
+                    size = param_data.size(-1)
+                    if ext == 'weight':
+                        param = get_param(param_name,
+                                          [param_data.size(0), size * n_ckpt])
+                        param.data[:, size * i:size * (i + 1)] = param_data
+                    else:  # bias
+                        param = get_param(param_name, [size])
+                        param.data = param_data
+                elif i == 0:
+                    param = get_param(param_name, param_data.size())
+                    param.data = param_data
+            del ckpt
+            pbar.update(1)
+        pbar.close()
+
+        for name, param in model_params.items():
+            # transpose all weights as TurboMind is expecting column-major
+            # (output_dims, input_dims) -> (input_dims, output_dims)
+            key = name.split('.')[-2]
+            if key in ['w1', 'w3', 'wq', 'wk', 'wv', 'w2', 'wo']:
+                param.data = param.data.t()
+                if key in ['wq', 'wk']:
+                    param.data = reverse_permute(param.data)
+        return model_params
+
+    def clean_up(self, last: bool) -> None:
+        """Clean up unused params."""
+        self.params.clear()
+
+    @property
+    def start_layer_id(self):
+        """Get start transformer layer id."""
+        return self._start_layer_id
+
+    @property
+    def end_layer_id(self):
+        """Get end transformer layer id."""
+        return self._end_layer_id
+
+    def tok_embeddings(self):
+        """Get embeddings."""
+        return self.params.get('tok_embeddings.weight')
+
+    def norm_weight(self):
+        """Get norm."""
+        return self.params.get('norm.weight')
+
+    def output_weight(self):
+        """Get output."""
+        return self.params.get('output.weight')
+
+    def attn(self, i: int):
+        """Get q, k, v, o weight for layer i."""
+        result = []
+        for key in ['wq', 'wk', 'wv', 'wo']:
+            tensor = self.params[f'layers.{i}.attention.{key}.weight']
+            tensor = tensor.t() if tensor is not None else None
+            result.append(tensor)
+        return (*result, )
+
+    def attn_bias(self, i: int):
+        """Get q, k, v, o bias for layer i."""
+        result = []
+        for key in ['wq', 'wk', 'wv', 'wo']:
+            tensor = self.params.get(f'layers.{i}.attention.{key}.bias')
+            tensor = tensor.t() if tensor is not None else None
+            result.append(tensor)
+        return (*result, )
+
+    def attn_zero(self, i: int):
+        """Get q, k, v, o zero point for layer i."""
+        return (None, ) * 4
+
+    def attn_scale(self, i: int):
+        """Get q, k, v, o scale for layer i."""
+        return (None, ) * 4
+
+    def attn_norm(self, i: int):
+        """Get attn norm for layer i."""
+        return self.params[f'layers.{i}.attention_norm.weight']
+
+    def ffn(self, i: int):
+        """Get ffn weight for layer i."""
+        result = []
+        for key in ['w1', 'w2', 'w3']:
+            tensor = self.params[f'layers.{i}.feed_forward.{key}.weight']
+            result.append(tensor.t())
+        return (*result, )
+
+    def ffn_zero(self, i: int):
+        """Get ffn zero point for layer i."""
+        return (None, ) * 3
+
+    def ffn_scale(self, i: int):
+        """Get ffn scale for layer i."""
+        return (None, ) * 3
+
+    def ffn_norm(self, i: int):
+        """Get ffn norm for layer i."""
+        return self.params[f'layers.{i}.ffn_norm.weight']
+
+
+@INPUT_MODELS.register_module(name='llama')
+class MetaLlamaModel(BaseInputModel):
+    """Llama model in fb format."""
+
+    def __init__(self, model_path: str, tokenizer_path: str, **kwargs):
+        super().__init__(model_path, tokenizer_path, **kwargs)
+
+    @property
+    def nmgrs(self):
+        """Get number of checkpoint."""
+        return 1
+
+    def get_mgrs(self):
+        """Conctruct all BaseReader."""
+        end_layer_id = self.model_info()['num_layer']
+        try:
+            if hasattr(self, 'meta_reader'):
+                yield self.meta_reader
+            else:
+                self.meta_reader = MetaLlamaReader(self.model_path, 0,
+                                                   end_layer_id)
+                yield self.meta_reader
+        except GeneratorExit:
+            pass
+
+    def tokenizer_info(self):
+        """Read tokenizer info."""
+        assert osp.isfile(self.tokenizer_path), self.tokenizer_path
+        sp_model = SentencePieceProcessor(model_file=self.tokenizer_path)
+        # BOS / EOS token IDs
+        n_words = sp_model.vocab_size()
+        bos_id = sp_model.bos_id()
+        eos_id = sp_model.eos_id()
+        return n_words, bos_id, eos_id
+
+    def model_info(self):
+        """Read model info."""
+        params_path = osp.join(self.model_path, 'params.json')
+        with open(params_path) as f:
+            model_arg = json.load(f)
+            num_layer = model_arg['n_layers']
+            norm_eps = model_arg['norm_eps']
+            head_num = model_arg.get('n_heads', 32)
+            kv_head_num = model_arg.get('n_kv_heads', head_num)
+
+        return dict(num_layer=num_layer,
+                    norm_eps=norm_eps,
+                    head_num=head_num,
+                    kv_head_num=kv_head_num)
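
The layout conversion in reverse_permute pairs with the permute helper added later in this patch (target_model/base.py). An illustrative check on a synthetic weight, assuming an installed lmdeploy build (the target_model package pulls in the compiled _turbomind extension):

import torch

from lmdeploy.turbomind.deploy.source_model.meta_llama import reverse_permute
from lmdeploy.turbomind.deploy.target_model.base import permute

wq = torch.randn(4096, 4096)   # e.g. 32 heads x 128 dims per head
# permute undoes reverse_permute, so the round trip is the identity.
assert torch.equal(permute(reverse_permute(wq)), wq)
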
diff --git a/lmdeploy/turbomind/deploy/source_model/qwen.py b/lmdeploy/turbomind/deploy/source_model/qwen.py
new file mode 100644
index 0000000000..09ff93afc5
--- /dev/null
+++ b/lmdeploy/turbomind/deploy/source_model/qwen.py
@@ -0,0 +1,113 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os.path as osp
+
+import torch
+
+from .base import INPUT_MODELS
+from .llama import LlamaModel, LlamaReader
+
+
+class QwenReader(LlamaReader):
+    """QwenReader."""
+
+    attn_layer_patten = r'transformer.h.([0-9]+).'
+    tok_embeddings_key = 'transformer.wte.weight'
+    norm_weight_key = 'transformer.ln_f.weight'
+    output_weight_key = 'lm_head.weight'
+
+    def __init__(self, new_params: dict, unused_params: dict, last_bin: bool):
+        super().__init__(new_params, unused_params, last_bin)
+
+    def _attn(self, i: int, kind: str, size_dim: int, dim: int = 0):
+        """Get q, k, v, o kind for layer i."""
+        qkv = self.params[f'transformer.h.{i}.attn.c_attn.{kind}']
+        q, k, v = torch.split(qkv, qkv.size(size_dim) // 3, dim=dim)
+        o = self.params.get(f'transformer.h.{i}.attn.c_proj.{kind}', None)
+        if o is None:
+            o = torch.zeros_like(q)
+        return q, k, v, o
+
+    def attn(self, i: int):
+        """Get q, k, v, o weight for layer i."""
+        return self._attn(i, 'weight', 0, 0)
+
+    def attn_bias(self, i: int):
+        """Get q, k, v, o bias for layer i."""
+        return self._attn(i, 'bias', -1, 0)
+
+    def attn_zero(self, i: int):
+        """Get q, k, v, o zero point for layer i."""
+        return (None, ) * 4
+
+    def attn_scale(self, i: int):
+        """Get q, k, v, o scale for layer i."""
+        return (None, ) * 4
+
+    def attn_norm(self, i: int):
+        """Get attn norm for layer i."""
+        return self.params[f'transformer.h.{i}.ln_1.weight']
+
+    def _ffn(self, i: int, kind: str):
+        """Get ffn kind for layer i."""
+        result = []
+        for key in ['w2', 'c_proj', 'w1']:
+            tensor = self.params[f'transformer.h.{i}.mlp.{key}.{kind}']
+            result.append(tensor)
+        return (*result, )
+
+    def ffn(self, i: int):
+        """Get ffn weight for layer i."""
+        return self._ffn(i, 'weight')
+
+    def ffn_zero(self, i: int):
+        """Get ffn zero point for layer i."""
+        return (None, ) * 3
+
+    def ffn_scale(self, i: int):
+        """Get ffn scale for layer i."""
+        return (None, ) * 3
+
+    def ffn_norm(self, i: int):
+        """Get ffn norm for layer i."""
+        return self.params[f'transformer.h.{i}.ln_2.weight']
+
+
+@INPUT_MODELS.register_module(name='qwen')
+class QwenModel(LlamaModel):
+    """Qwen model in hf format."""
+
+    Reader = QwenReader
+
+    def __init__(self, model_path: str, tokenizer_path: str, **kwargs):
+        super().__init__(model_path, tokenizer_path, **kwargs)
+
+    def tokenizer_info(self):
+        """Read tokenizer info."""
+        n_words = 151851
+        bos_id = 0
+        eos_id = 151643
+        return n_words, bos_id, eos_id
+
+    def model_info(self):
+        """Read model info."""
+        params_path = osp.join(self.model_path, 'config.json')
+        with open(params_path) as f:
+            config = json.load(f)
+            num_layer = config['num_hidden_layers']
+            norm_eps = config['layer_norm_epsilon']
+            rope_theta = float(config.get('rotary_emb_base', 10000.0))
+            if 'num_key_value_heads' in config:
+                kv_head_num = config['num_key_value_heads']
+            else:
+                kv_head_num = config['num_attention_heads']
+            seq_length = config['seq_length']
+            use_dynamic_ntk = int(config['use_dynamic_ntk'])
+            use_logn_attn = int(config['use_logn_attn'])
+        return dict(num_layer=num_layer,
+                    norm_eps=norm_eps,
+                    kv_head_num=kv_head_num,
+                    rope_theta=rope_theta,
+                    max_position_embeddings=seq_length,
+                    use_dynamic_ntk=int(use_dynamic_ntk),
+                    use_logn_attn=use_logn_attn)
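
Unlike Llama, Qwen stores the attention projections as one fused c_attn tensor, which _attn above splits into equal thirds. A standalone sketch with synthetic shapes:

import torch

hidden = 4096
qkv = torch.randn(3 * hidden, hidden)          # fused [q; k; v] c_attn weight
q, k, v = torch.split(qkv, qkv.size(0) // 3, dim=0)
assert q.shape == k.shape == v.shape == (hidden, hidden)
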
diff --git a/lmdeploy/turbomind/deploy/source_model/qwen_awq.py b/lmdeploy/turbomind/deploy/source_model/qwen_awq.py
new file mode 100644
index 0000000000..04df2ac729
--- /dev/null
+++ b/lmdeploy/turbomind/deploy/source_model/qwen_awq.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import INPUT_MODELS
+from .llama_awq import ensure_fp16orint32
+from .qwen import QwenModel, QwenReader
+
+
+class QwenAwqReader(QwenReader):
+    """QwenAwqReader."""
+
+    def __init__(self, new_params: dict, unused_params: dict, last_bin: bool):
+        super().__init__(new_params, unused_params, last_bin)
+
+    def attn(self, i: int):
+        """Get q, k, v, o qweight for layer i."""
+        return ensure_fp16orint32(self._attn(i, 'qweight', -1, -1))
+
+    def attn_bias(self, i: int):
+        """Get q, k, v, o bias for layer i."""
+        return ensure_fp16orint32(self._attn(i, 'bias', -1, 0))
+
+    def attn_zero(self, i: int):
+        """Get q, k, v, o qzeros for layer i."""
+        return ensure_fp16orint32(self._attn(i, 'qzeros', -1, -1))
+
+    def attn_scale(self, i: int):
+        """Get q, k, v, o scales for layer i."""
+        return ensure_fp16orint32(self._attn(i, 'scales', -1, -1))
+
+    def ffn(self, i: int):
+        """Get ffn qweight for layer i."""
+        # ours: w2(silu(w1(x)) * w3(x))
+        # qwen: c_proj(w1(x) * silu(w2(x)))
+        return ensure_fp16orint32(self._ffn(i, 'qweight'))
+
+    def ffn_zero(self, i: int):
+        """Get ffn qzeros for layer i."""
+        return ensure_fp16orint32(self._ffn(i, 'qzeros'))
+
+    def ffn_scale(self, i: int):
+        """Get ffn scales for layer i."""
+        return ensure_fp16orint32(self._ffn(i, 'scales'))
+
+
+@INPUT_MODELS.register_module(name='qwen-awq')
+class QwenAwqModel(QwenModel):
+    """Qwen awq model in hf format."""
+
+    Reader = QwenAwqReader
+
+    def __init__(self,
+                 model_path: str,
+                 tokenizer_path: str,
+                 ckpt_path: str = None,
+                 **kwargs):
+        super().__init__(model_path,
+                         tokenizer_path,
+                         ckpt_path=ckpt_path,
+                         **kwargs)
diff --git a/lmdeploy/turbomind/deploy/target_model/__init__.py b/lmdeploy/turbomind/deploy/target_model/__init__.py
new file mode 100644
index 0000000000..fe03500e45
--- /dev/null
+++ b/lmdeploy/turbomind/deploy/target_model/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .fp import TurbomindModel  # noqa: F401
+from .w4 import TurbomindW4Model  # noqa: F401
diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py
new file mode 100644
index 0000000000..92e6232301
--- /dev/null
+++ b/lmdeploy/turbomind/deploy/target_model/base.py
@@ -0,0 +1,269 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import configparser
+import inspect
+import os.path as osp
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+import torch
+import tqdm
+from mmengine import Registry
+
+from lmdeploy.model import MODELS
+
+from ..source_model.base import BaseInputModel, BaseReader
+
+OUTPUT_MODELS = Registry(
+    'target model', locations=['lmdeploy.turbomind.deploy.target_model.base'])
+
+
+def tprint(*args, **kwargs):
+    to_file = kwargs.pop('to_file', False)
+    if not to_file:
+        return
+    from io import StringIO
+    s = StringIO()
+    print(*args, **kwargs, file=s, end='')
+    tqdm.tqdm.write(s.getvalue())
+
+
+@dataclass
+class TurbomindModelConfig:
+    """Config for turbomind model."""
+    model_name: str
+    tensor_para_size: int
+    head_num: int
+    kv_head_num: int
+    vocab_size: int
+    num_layer: int
+    inter_size: int
+    norm_eps: float
+    attn_bias: int
+    start_id: int
+    end_id: int
+    session_len: int
+    weight_type: str = 'fp16'
+    rotary_embedding: int = 128
+    rope_theta: float = 10000.0
+    size_per_head: int = 128
+    group_size: int = 0
+    max_batch_size: int = 64
+    max_context_token_num: int = 1
+    step_length: int = 1
+    cache_max_entry_count: float = 0.5
+    cache_block_seq_len: int = 128
+    cache_chunk_size: int = 1
+    use_context_fmha: int = 1
+    quant_policy: int = 0
+    max_position_embeddings: int = 0
+    rope_scaling_factor: float = 0.0
+    use_logn_attn: int = 0
+
+    @classmethod
+    def from_dict(cls, env, allow_none=False):
+        """Construct from dict."""
+        params = inspect.signature(cls).parameters
+        used = {k: v for k, v in env.items() if k in params and v is not None}
+        if not allow_none:
+            return cls(**used)
+        else:
+            default = {
+                k: None
+                for k in params.keys() if params[k].default is inspect._empty
+            }
+            default.update(used)
+            return cls(**default)
+
+    @property
+    def valid(self):
+        """Check if cfg is valid."""
+        for _, v in self.__dict__.items():
+            if v is None:
+                return False
+        return True
+
+
+class BaseOutputModel(ABC):
+    """Base output model."""
+
+    def __init__(self,
+                 input_model: BaseInputModel,
+                 cfg: TurbomindModelConfig,
+                 to_file: bool = True,
+                 out_dir: str = ''):
+        super().__init__()
+        self.input_model = input_model
+        self.cfg = cfg
+        if not cfg.valid:
+            self.cfg = self.get_config(cfg)
+        assert self.cfg.valid
+        self.to_file = to_file
+        self.out_dir = out_dir
+        self.tm_params = {}
+
+    @abstractmethod
+    def get_config(self, cfg: TurbomindModelConfig) -> TurbomindModelConfig:
+        """Generate turbomind model config (config.ini)."""
+        _, bos_id, eos_id = self.input_model.tokenizer_info()
+        model = MODELS.get(cfg.model_name)()
+        final_cfg = cfg.__dict__
+        final_cfg.update(
+            dict(start_id=bos_id,
+                 end_id=eos_id,
+                 session_len=model.session_len + 8))
+        final_cfg.update(self.input_model.model_info())
+
+        # head_num, vocab_size
+        for bin in self.input_model.bins():
+            emb = bin.tok_embeddings()
+            if emb is not None:
+                _vocab_size, dim = emb.shape
+                head_num = dim // cfg.size_per_head
+                break
+        final_cfg.update(dict(head_num=head_num, vocab_size=_vocab_size))
+        return TurbomindModelConfig.from_dict(final_cfg, allow_none=True)
+
+    def export_config(self) -> None:
+        """export turbomind config."""
+        if self.to_file:
+            config = configparser.ConfigParser()
+            cfg = dict(llama=self.cfg.__dict__)
+            for section, key_values in cfg.items():
+                config[section] = key_values
+            config_path = osp.join(self.out_dir, 'config.ini')
+            with open(config_path, 'w') as f:
+                config.write(f)
+
+    def export_weight(self, param: torch.Tensor, name: str) -> None:
+        """export turbomind weight."""
+        if self.to_file:
+            if param.dtype in [torch.float, torch.bfloat16]:
+                param = param.half()
+            tprint(name, param.shape)
+            param.contiguous().cpu().numpy().tofile(
+                osp.join(self.out_dir, name))
+        elif len(self.tm_params) > 0:
+            tm_params = self.tm_params
+            weight_type = self.cfg.weight_type
+            assert weight_type in ['fp16', 'fp32', 'int4']
+
+            # currently, the tensor dtype should be in
+            # [torch.int32, torch.float, torch.half, torch.bfloat16]
+            torch_tensor = param.cuda().contiguous()
+            assert torch_tensor.dtype in [
+                torch.int32, torch.float, torch.half, torch.bfloat16
+            ]
+            if torch_tensor.dtype != torch.int32:
+                if weight_type in ['fp16', 'int4']:
+                    torch_tensor = torch_tensor.half()
+                else:
+                    torch_tensor = torch_tensor.float()
+            for tm_tensor in tm_params[name]:
+                tm_tensor.copy_from(torch_tensor)
+            tm_params.pop(name)
+        else:
+            tprint('skip export', name, param.shape)
+
+    def save_split(self,
+                   tensor: torch.Tensor,
+                   name: str,
+                   split_dim=None,
+                   copy=False) -> None:
+        """save split."""
+        tp = self.cfg.tensor_para_size
+        if split_dim is not None:
+            tprint(
+                f'*** splitting {name}, shape={tensor.shape}, '
+                f'split_dim={split_dim}, tp={tp}',
+                to_file=self.to_file)
+            assert tensor.shape[split_dim] % tp == 0
+            split_size = tensor.shape[split_dim] // tp
+            splits = torch.split(tensor, split_size, dim=split_dim)
+            for i, split in enumerate(splits):
+                prefix, ext = osp.splitext(name)
+                self.export_weight(split, f'{prefix}.{i}{ext}')
+        elif copy:
+            tprint(f'### copying {name}, shape={tensor.shape}',
+                   to_file=self.to_file)
+            copies = [tensor] * tp
+            for i, copy in enumerate(copies):
+                prefix, ext = osp.splitext(name)
+                self.export_weight(copy, f'{prefix}.{i}{ext}')
+        else:
+            self.export_weight(tensor, name)
+
+    def export(self) -> None:
+        """Export to turbomind model format."""
+        num_layer = self.cfg.num_layer
+        from tqdm import tqdm
+        pbar = tqdm(total=num_layer,
+                    desc='Convert to turbomind format',
+                    leave=self.to_file)
+        self.export_config()
+        for bin in self.input_model.bins():
+            self.export_misc(bin)
+            for i in range(bin.start_layer_id, bin.end_layer_id):
+                self.export_transformer_block(bin, i)
+                pbar.update(1)
+        pbar.close()
+        # manually clean up meta reader
+        if hasattr(self.input_model, 'meta_reader'):
+            self.input_model.meta_reader.clean_up(True)
+            del self.input_model.meta_reader
+            torch.cuda.empty_cache()
+
+    def export_misc(self, bin: BaseReader) -> None:
+        """Export embedding, norm, output weight."""
+        emb = bin.tok_embeddings()
+        norm_weight = bin.norm_weight()
+        output_weight = bin.output_weight()
+
+        def pad_weight(tensor):
+            pad_size = None
+            vocab_size = self.cfg.vocab_size
+            tp = self.cfg.tensor_para_size
+            if vocab_size % tp != 0:
+                pad_size = (vocab_size + tp - 1) // tp * tp - vocab_size
+
+            if pad_size is None:
+                return tensor
+            return torch.nn.functional.pad(tensor, (0, 0, 0, pad_size),
+                                           'constant', 0)
+
+        if emb is not None:
+            emb = pad_weight(emb)
+            self.export_weight(emb, 'tok_embeddings.weight')
+        if norm_weight is not None:
+            self.export_weight(norm_weight, 'norm.weight')
+        if output_weight is not None:
+            output_weight = pad_weight(output_weight)
+            self.export_weight(output_weight, 'output.weight')
+
+    @abstractmethod
+    def export_transformer_block(self, bin: BaseReader, i: int) -> None:
+        """Export transformer block."""
+        pass
+
+
+def permute(x: torch.Tensor, size_per_head: int = 128):
+    if x.shape[-1] > 1:
+        dim = x.shape[-1]
+        n_heads = dim // size_per_head
+        return x.view(-1, n_heads, 2,
+                      dim // n_heads // 2).transpose(2, 3).reshape(-1, dim)
+    else:  # scales, zeros
+        dim = x.shape[0]
+        n_heads = dim // size_per_head
+        return x.view(n_heads, 2, dim // n_heads // 2,
+                      1).transpose(1, 2).reshape(dim, 1)
+
+
+def merge_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp: int,
+              dim: int):
+
+    def reshape(x):
+        return x.view(x.size(0), tp, -1) if dim == 2 else x.view(tp, -1)
+
+    qkv = torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1)
+    # (input_dim, head_num + 2 * kv_head_num)
+    return qkv.view(q.size(0), -1)
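
A short illustrative sketch of two building blocks above, assuming an installed lmdeploy build: TurbomindModelConfig.from_dict with allow_none=True leaves mandatory fields as None until get_config() fills them, and merge_qkv regroups per-rank q/k/v slices so each tensor-parallel rank gets a contiguous chunk.

import torch

from lmdeploy.turbomind.deploy.target_model.base import (TurbomindModelConfig,
                                                         merge_qkv)

cfg = TurbomindModelConfig.from_dict(dict(model_name='llama2', head_num=32),
                                     allow_none=True)
print(cfg.head_num, cfg.valid)                  # 32 False (other fields are None)

q = k = v = torch.randn(4096, 4096)
print(merge_qkv(q, k, v, tp=2, dim=2).shape)    # torch.Size([4096, 12288])
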
diff --git a/lmdeploy/turbomind/deploy/target_model/fp.py b/lmdeploy/turbomind/deploy/target_model/fp.py
new file mode 100644
index 0000000000..d9a7783436
--- /dev/null
+++ b/lmdeploy/turbomind/deploy/target_model/fp.py
@@ -0,0 +1,80 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import torch
+
+from ..source_model.base import BaseInputModel, BaseReader
+from .base import (OUTPUT_MODELS, BaseOutputModel, TurbomindModelConfig,
+                   merge_qkv, permute)
+
+
+def transpose_tensor(input: List[torch.Tensor]):
+    """Transpose tensor."""
+    output = [x.cuda().t() for x in input]
+    return output
+
+
+@OUTPUT_MODELS.register_module(name='fp16')
+class TurbomindModel(BaseOutputModel):
+    """Export to turbomind fp16 format."""
+
+    def __init__(self,
+                 input_model: BaseInputModel,
+                 cfg: TurbomindModelConfig,
+                 to_file: bool = True,
+                 out_dir: str = ''):
+        super().__init__(input_model, cfg, to_file, out_dir)
+
+    def get_config(self, cfg: TurbomindModelConfig):
+        """Get turbomind config."""
+        final_cfg = super().get_config(cfg).__dict__
+
+        # attn_bias, inter_size
+        visit = False
+        attn_bias = 0
+        for bin in self.input_model.bins():
+            for i in range(bin.start_layer_id, bin.end_layer_id):
+                visit = True
+                w1, _, _ = bin.ffn(i)
+                inter_size = w1.t().shape[-1]
+                qb, _, _, _ = bin.attn_bias(i)
+                if qb is not None:
+                    attn_bias = 1
+                break
+            if visit:
+                break
+        final_cfg.update(dict(attn_bias=attn_bias, inter_size=inter_size))
+        return TurbomindModelConfig.from_dict(final_cfg)
+
+    def export_transformer_block(self, bin: BaseReader, i: int):
+        """Export transformer layer i."""
+        assert bin.start_layer_id <= i < bin.end_layer_id
+        tp = self.cfg.tensor_para_size
+        size_per_head = self.cfg.size_per_head
+        # attn
+        qw, kw, vw, ow = bin.attn(i)
+        qw, kw, vw, ow = transpose_tensor([qw, kw, vw, ow])
+        qw = permute(qw, size_per_head)
+        kw = permute(kw, size_per_head)
+        qkv_w = merge_qkv(qw, kw, vw, tp, dim=2)
+        self.save_split(qkv_w, f'layers.{i}.attention.w_qkv.weight', -1)
+        self.save_split(ow, f'layers.{i}.attention.wo.weight', 0)
+        qb, kb, vb, ob = bin.attn_bias(i)
+        if qb is not None:
+            qb, kb, vb, ob = transpose_tensor([qb, kb, vb, ob])
+            qb = permute(qb, size_per_head)
+            kb = permute(kb, size_per_head)
+            qkv_b = merge_qkv(qb, kb, vb, tp, dim=1)
+            self.save_split(qkv_b, f'layers.{i}.attention.w_qkv.bias', -1)
+            self.save_split(ob, f'layers.{i}.attention.wo.bias', copy=True)
+        # ffn
+        w1, w2, w3 = bin.ffn(i)
+        w1, w2, w3 = transpose_tensor([w1, w2, w3])
+        self.save_split(w1, f'layers.{i}.feed_forward.w1.weight', -1)
+        self.save_split(w3, f'layers.{i}.feed_forward.w3.weight', -1)
+        self.save_split(w2, f'layers.{i}.feed_forward.w2.weight', 0)
+        # norm
+        attn_norm = bin.attn_norm(i)
+        ffn_norm = bin.ffn_norm(i)
+        self.save_split(attn_norm, f'layers.{i}.attention_norm.weight')
+        self.save_split(ffn_norm, f'layers.{i}.ffn_norm.weight')
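
Putting the source and target sides together, an illustrative sketch of how a registered exporter could be driven directly; paths and model_name are placeholders, and the intended entry point is the deploy converter added elsewhere in this patch:

from lmdeploy.turbomind.deploy.source_model.llama import LlamaModel
from lmdeploy.turbomind.deploy.target_model.base import (OUTPUT_MODELS,
                                                         TurbomindModelConfig)

cfg = TurbomindModelConfig.from_dict(dict(model_name='llama2',
                                          tensor_para_size=1),
                                     allow_none=True)
input_model = LlamaModel('/path/to/llama-hf', '/path/to/llama-hf')
output_model = OUTPUT_MODELS.get('fp16')(input_model=input_model,
                                         cfg=cfg,
                                         out_dir='./workspace')
output_model.export()                     # writes config.ini and weight files
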
diff --git a/lmdeploy/turbomind/deploy/target_model/w4.py b/lmdeploy/turbomind/deploy/target_model/w4.py
new file mode 100644
index 0000000000..282c7df607
--- /dev/null
+++ b/lmdeploy/turbomind/deploy/target_model/w4.py
@@ -0,0 +1,162 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import sys
+
+import torch
+
+import lmdeploy
+
+from ..source_model.base import BaseInputModel, BaseReader
+from .base import (OUTPUT_MODELS, BaseOutputModel, TurbomindModelConfig,
+                   merge_qkv, permute)
+
+# import _turbomind as _tm
+# TODO: find another way import _turbomind
+lmdeploy_dir = osp.split(lmdeploy.__file__)[0]
+sys.path.append(osp.join(lmdeploy_dir, 'lib'))
+import _turbomind as _tm  # noqa: E402
+
+
+def transpose_qk_s4(src: torch.Tensor, group_size):
+    assert src.is_contiguous()
+    dst = torch.zeros_like(src)
+    _tm.transpose_qk_s4_k_m8(src, dst,
+                             src.size(-1) * 8, src.size(0), group_size)
+    return dst
+
+
+def fuse_w1_w3_s4(w1_qw: torch.Tensor, w1_qz: torch.Tensor, w1_s: torch.Tensor,
+                  w3_qw: torch.Tensor, w3_qz: torch.Tensor,
+                  w3_s: torch.Tensor):
+
+    def fuse(a: torch.Tensor, b: torch.Tensor):
+        ab = torch.cat((a, b)).contiguous()
+        _ab = torch.zeros_like(ab)
+        _tm.fuse_w1_w3_s4_k_m8(ab, _ab, a.size(-1) * 8, a.size(0))
+        return _ab.view(a.size(0), -1)
+
+    w13_qw = fuse(w1_qw, w3_qw)
+    w13_qz = fuse(w1_qz, w3_qz)
+
+    w13_s = torch.cat((w1_s, w3_s)).view(2, w1_s.size(0), -1)
+    w13_s = w13_s.permute(1, 2, 0).contiguous().view(w1_s.size(0), -1)
+
+    return w13_qw, w13_qz, w13_s
+
+
+def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
+               group_size: int):
+    assert qw.is_contiguous()
+    assert qz.is_contiguous()
+    assert s.is_contiguous()
+    _qw = torch.zeros_like(qw)
+    _sz = torch.zeros_like(s, dtype=torch.int32)  # half2
+    _ws = torch.zeros_like(s)
+    _tm.convert_s4_k_m8(_qw, _sz, _ws, qw, s, qz,
+                        qw.size(-1) * 8, qw.size(0), group_size)
+    return _qw, _sz
+
+
+def tp_m_s4(x: torch.Tensor, tp: int):
+    return x.view(x.size(0) // 32, tp, -1, 128).permute(0, 2, 3,
+                                                        1).contiguous()
+
+
+def get_cuda_tensor(tensors):
+    """Get cuda tensor."""
+    result = map(lambda x: x.cuda() if x is not None else x, tensors)
+    return (*result, )
+
+
+@OUTPUT_MODELS.register_module(name='w4')
+class TurbomindW4Model(BaseOutputModel):
+    """Export to turbomind w4a16 format."""
+
+    def __init__(self,
+                 input_model: BaseInputModel,
+                 cfg: TurbomindModelConfig,
+                 to_file: bool = True,
+                 out_dir: str = ''):
+        super().__init__(input_model, cfg, to_file, out_dir)
+
+    def get_config(self, cfg: TurbomindModelConfig):
+        """Get turbomind config."""
+        final_cfg = super().get_config(cfg).__dict__
+
+        # attn_bias, inter_size
+        visit = False
+        attn_bias = 0
+        for bin in self.input_model.bins():
+            for i in range(bin.start_layer_id, bin.end_layer_id):
+                visit = True
+                w1s, _, _ = bin.ffn_scale(i)
+                inter_size = w1s.shape[-1]
+                qb, _, _, _ = bin.attn_bias(i)
+                if qb is not None:
+                    attn_bias = 1
+                break
+            if visit:
+                break
+        final_cfg.update(dict(attn_bias=attn_bias, inter_size=inter_size))
+        return TurbomindModelConfig.from_dict(final_cfg)
+
+    def export_transformer_block(self, bin: BaseReader, i: int):
+        """Export transformer layer i."""
+        group_size = self.cfg.group_size
+        tp = self.cfg.tensor_para_size
+        size_per_head = self.cfg.size_per_head
+        # attn
+        q_qw, k_qw, v_qw, o_qw = get_cuda_tensor(bin.attn(i))
+        q_qz, k_qz, v_qz, o_qz = get_cuda_tensor(bin.attn_zero(i))
+        q_s, k_s, v_s, o_s = get_cuda_tensor(bin.attn_scale(i))
+
+        q_qw = transpose_qk_s4(q_qw, group_size)
+        k_qw = transpose_qk_s4(k_qw, group_size)
+        q_qz = transpose_qk_s4(q_qz, group_size)
+        k_qz = transpose_qk_s4(k_qz, group_size)
+        q_s = permute(q_s, size_per_head)
+        k_s = permute(k_s, size_per_head)
+
+        qkv_qw = merge_qkv(q_qw, k_qw, v_qw, tp, dim=2)
+        qkv_qz = merge_qkv(q_qz, k_qz, v_qz, tp, dim=2)
+        qkv_s = merge_qkv(q_s, k_s, v_s, tp, dim=2)
+
+        qkv_qw, qkv_sz = convert_s4(qkv_qw, qkv_qz, qkv_s, group_size)
+        qkv_qw = tp_m_s4(qkv_qw, tp)
+        self.save_split(qkv_qw, f'layers.{i}.attention.w_qkv.qweight', -1)
+        self.save_split(qkv_sz, f'layers.{i}.attention.w_qkv.scales_zeros', -1)
+
+        o_qw, o_sz = convert_s4(o_qw, o_qz, o_s, group_size)
+        self.save_split(o_qw, f'layers.{i}.attention.wo.qweight', 0)
+        self.save_split(o_sz, f'layers.{i}.attention.wo.scales_zeros', 0)
+
+        q_b, k_b, v_b, o_b = get_cuda_tensor(bin.attn_bias(i))
+        if q_b is not None:
+            q_b = permute(q_b, size_per_head)
+            k_b = permute(k_b, size_per_head)
+            qkv_b = merge_qkv(q_b, k_b, v_b, tp, dim=1)
+            self.save_split(qkv_b, f'layers.{i}.attention.w_qkv.bias', -1)
+            self.save_split(o_b, f'layers.{i}.attention.wo.bias', copy=True)
+
+        # ffn weights
+        w1_qw, w2_qw, w3_qw = get_cuda_tensor(bin.ffn(i))
+        w1_qz, w2_qz, w3_qz = get_cuda_tensor(bin.ffn_zero(i))
+        w1_s, w2_s, w3_s = get_cuda_tensor(bin.ffn_scale(i))
+
+        w13_qw, w13_qz, w13_s = fuse_w1_w3_s4(w1_qw, w1_qz, w1_s, w3_qw, w3_qz,
+                                              w3_s)
+        w13_qw, w13_sz = convert_s4(w13_qw, w13_qz, w13_s, group_size)
+        w13_qw = tp_m_s4(w13_qw, tp)
+        self.save_split(w13_qw, f'layers.{i}.feed_forward.w13.qweight', -1)
+        self.save_split(w13_sz, f'layers.{i}.feed_forward.w13.scales_zeros',
+                        -1)
+
+        w2_qw, w2_sz = convert_s4(w2_qw, w2_qz, w2_s, group_size)
+        self.save_split(w2_qw, f'layers.{i}.feed_forward.w2.qweight', 0)
+        self.save_split(w2_sz, f'layers.{i}.feed_forward.w2.scales_zeros', 0)
+
+        # norm
+        attn_norm = bin.attn_norm(i)
+        ffn_norm = bin.ffn_norm(i)
+        self.save_split(attn_norm, f'layers.{i}.attention_norm.weight')
+        self.save_split(ffn_norm, f'layers.{i}.ffn_norm.weight')
diff --git a/lmdeploy/turbomind/generate_gemm_config.py b/lmdeploy/turbomind/generate_gemm_config.py
index 328f182158..9a4f0e8c4d 100644
--- a/lmdeploy/turbomind/generate_gemm_config.py
+++ b/lmdeploy/turbomind/generate_gemm_config.py
@@ -2,8 +2,6 @@
 
 import subprocess
 
-import fire
-
 
 def get_llama_gemm():
     import os.path as osp
@@ -30,4 +28,6 @@ def main(head_num: int = 32,
 
 
 if __name__ == '__main__':
+    import fire
+
     fire.Fire(main)
diff --git a/lmdeploy/turbomind/hf_repo/config.json b/lmdeploy/turbomind/hf_repo/config.json
new file mode 100644
index 0000000000..9778905e33
--- /dev/null
+++ b/lmdeploy/turbomind/hf_repo/config.json
@@ -0,0 +1,11 @@
+{
+    "architectures": [
+        "LMDeployForCausalLM"
+    ],
+    "auto_map": {
+        "AutoConfig": "configuration_lmdeploy.LMDeployConfig",
+        "AutoModel": "modeling_lmdeploy.LMDeployForCausalLM",
+        "AutoModelForCausalLM": "modeling_lmdeploy.LMDeployForCausalLM"
+    },
+    "turbomind": {}
+}
diff --git a/lmdeploy/turbomind/hf_repo/configuration_lmdeploy.py b/lmdeploy/turbomind/hf_repo/configuration_lmdeploy.py
new file mode 100644
index 0000000000..880ad66e81
--- /dev/null
+++ b/lmdeploy/turbomind/hf_repo/configuration_lmdeploy.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+from transformers import PretrainedConfig
+
+from lmdeploy.turbomind.deploy.target_model.base import TurbomindModelConfig
+from lmdeploy.version import __version__ as lm_version
+
+
+class LMDeployConfig(PretrainedConfig):
+    """Lmdeploy config."""
+
+    def __init__(self, turbomind: dict = None, **kwargs):
+        default_tm_cfg = copy.deepcopy(
+            TurbomindModelConfig.from_dict({}, allow_none=True).__dict__)
+        if turbomind is not None:
+            default_tm_cfg.update(turbomind)
+        self.turbomind = default_tm_cfg
+        self.lmdeploy_version = lm_version
+        super().__init__(**kwargs)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
+        config, kwargs = super().from_pretrained(pretrained_model_name_or_path,
+                                                 return_unused_kwargs=True,
+                                                 **kwargs)
+        for k, v in kwargs.items():
+            if k in config.turbomind.keys():
+                config.turbomind[k] = v
+        if 'tp' in kwargs:
+            config.turbomind['tensor_para_size'] = kwargs['tp']
+        if return_unused_kwargs:
+            return config, kwargs
+        else:
+            return config
diff --git a/lmdeploy/turbomind/hf_repo/modeling_lmdeploy.py b/lmdeploy/turbomind/hf_repo/modeling_lmdeploy.py
new file mode 100644
index 0000000000..ffb9b05613
--- /dev/null
+++ b/lmdeploy/turbomind/hf_repo/modeling_lmdeploy.py
@@ -0,0 +1,226 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import dataclasses
+import os
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from itertools import count
+from queue import Queue
+from typing import List, Optional, Tuple, Union
+
+from huggingface_hub import snapshot_download
+from transformers import PretrainedConfig
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+
+from lmdeploy.turbomind import TurboMind
+from lmdeploy.turbomind.utils import get_gen_param
+
+from .configuration_lmdeploy import LMDeployConfig
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class Session:
+    _count = count()
+    _session_id: int = None
+    _message: List[Tuple[str, str]] = field(default_factory=list)
+    _step: int = 0
+    _nth_round: int = 0
+    _error: int = 0
+
+    def __init__(self):
+        self._session_id = next(Session._count)
+        self._message = []
+        self._step = 0
+        self._nth_round = 0
+
+    @property
+    def session_id(self):
+        return self._session_id
+
+    @property
+    def message(self):
+        return self._message
+
+    @property
+    def step(self):
+        return self._step
+
+    @property
+    def nth_round(self):
+        return self._nth_round
+
+    @property
+    def error(self):
+        return self._error
+
+
+class LMDeployForCausalLM(PreTrainedModel):
+    config_class = LMDeployConfig
+
+    def __init__(self,
+                 config: LMDeployConfig,
+                 *inputs,
+                 model_path: str = None,
+                 **kwargs):
+        super().__init__(config)
+        self.tm_model = TurboMind.from_pretrained(model_path, **kwargs)
+        que = Queue()
+        for _ in range(config.turbomind['max_batch_size']):
+            que.put(self.tm_model.create_instance())
+        self.que = que
+
+    @classmethod
+    def from_pretrained(cls,
+                        pretrained_model_name_or_path,
+                        *model_args,
+                        config: Optional[Union[PretrainedConfig, str,
+                                               os.PathLike]] = None,
+                        cache_dir: Optional[Union[str, os.PathLike]] = None,
+                        force_download: bool = False,
+                        local_files_only: bool = False,
+                        token: Optional[Union[str, bool]] = None,
+                        revision: str = 'main',
+                        **kwargs):
+        """Instantiate a LM model with turbomind backend."""
+
+        resume_download = kwargs.pop('resume_download', True)
+        proxies = kwargs.pop('proxies', None)
+
+        if os.path.isdir(pretrained_model_name_or_path):
+            local_folder = pretrained_model_name_or_path
+        else:
+            local_folder = snapshot_download(
+                pretrained_model_name_or_path,
+                revision=revision,
+                cache_dir=cache_dir,
+                proxies=proxies,
+                resume_download=resume_download,
+                force_download=force_download,
+                token=token,
+                local_files_only=local_files_only,
+            )
+
+        if not isinstance(config, PretrainedConfig):
+            config_path = config if config is not None else local_folder
+            kwargs.pop('return_unused_kwargs', None)
+            config, model_kwargs = cls.config_class.from_pretrained(
+                config_path, return_unused_kwargs=True, **kwargs)
+        else:
+            model_kwargs = kwargs
+
+        model = cls(config,
+                    *model_args,
+                    model_path=local_folder,
+                    **model_kwargs)
+
+        generation_config = model.tm_model.model.sampling_param
+        for k, v in dataclasses.asdict(generation_config).items():
+            if hasattr(model.generation_config, k):
+                base_value = getattr(model.generation_config, k)
+                setattr(generation_config, k, base_value)
+            if k in kwargs:
+                setattr(generation_config, k, v)
+        model.generation_config = generation_config
+
+        return model
+
+    @contextmanager
+    def managed_generator(self, session: Session):
+        generator = self.que.get()
+        try:
+            yield generator
+        except:  # noqa E722
+            for _ in generator.stream_infer(session.session_id, [0],
+                                            request_output_len=0,
+                                            sequence_start=False,
+                                            sequence_end=False,
+                                            stop=True):
+                pass
+            session._error = 1
+        finally:
+            self.que.put(generator)
+
+    def generate(
+        self,
+        input_ids: List[int],
+        session: Session,
+        **kwargs,
+    ):
+        """Generates sequences of token ids for models with a language modeling
+        head.
+
+        Args:
+            input_ids (List(int)): list of input token ids
+            session (Session): session information
+            kwargs (dict): ad hoc parametrization of generation
+        """
+        with self.managed_generator(session) as generator:
+            for outputs in generator.stream_infer(
+                    session_id=session.session_id,
+                    input_ids=[input_ids],
+                    **kwargs,
+            ):
+                res, tokens = outputs[0]
+                yield res, tokens
+
+    def chat(
+        self,
+        query: str,
+        session: Optional[Session] = None,
+        cap: str = 'chat',
+        request_output_len: int = 512,
+        stream_output: bool = False,
+        ignore_eos=False,
+        random_seed: Optional[int] = None,
+        **kwargs,
+    ) -> Tuple[str, Session]:
+        """chat."""
+
+        if session is None:
+            session = Session()
+        assert session._error == 0, 'An error occurred before, ' \
+            'please start a new session.'
+
+        session._message.append([query, ''])
+
+        prompt = self.tm_model.model.get_prompt(query, session.nth_round == 0)
+        input_ids = self.tm_model.tokenizer.encode(prompt)
+
+        if len(
+                input_ids
+        ) + session.step + request_output_len >= self.tm_model.session_len:
+            logger.error(
+                f'session_length exceeded {self.tm_model.session_len}')
+            session._error = 1
+            yield '', session
+        else:
+            gen_param = get_gen_param(cap, self.generation_config,
+                                      session.nth_round + 1, session.step,
+                                      request_output_len, **kwargs)
+            gen_kwargs = dataclasses.asdict(gen_param)
+            gen_kwargs.update(
+                random_seed=random_seed if session.nth_round == 0 else None,
+                stream_output=stream_output,
+                ignore_eos=ignore_eos,
+                **kwargs)
+
+            _step = session._step
+            _nth_round = session._nth_round
+            response_size = 0
+
+            for res, tokens in self.generate(input_ids,
+                                             session=session,
+                                             **gen_kwargs):
+                response = self.tm_model.tokenizer.decode(res.tolist(),
+                                                          offset=response_size)
+                if response.endswith('�'):
+                    continue
+                response_size = tokens
+
+                session._message[-1][-1] += response
+                session._nth_round = _nth_round + 1
+                session._step = _step + response_size
+
+                yield response, session
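
Once a model is exported, the hf_repo files above are meant to be copied into the resulting repository, where config.json's auto_map lets transformers load the turbomind backend through trust_remote_code. An illustrative usage sketch (the repo path and prompt are placeholders):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained('/path/to/exported-repo',
                                             trust_remote_code=True)
session = None
for response, session in model.chat('hi, please introduce yourself',
                                     session=session):
    print(response, end='', flush=True)
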
diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py
index dcfc499e89..3c46b270ea 100644
--- a/lmdeploy/turbomind/turbomind.py
+++ b/lmdeploy/turbomind/turbomind.py
@@ -1,36 +1,51 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import asyncio
+import copy
+import io
+import json
+import logging
 import os.path as osp
 import sys
 from configparser import ConfigParser
 from contextlib import contextmanager
 from queue import Queue
 from threading import Thread
-from typing import Iterable, List
+from typing import Iterable, List, Optional
 
 import numpy as np
 import torch
+from huggingface_hub import snapshot_download
 from torch.nn.utils.rnn import pad_sequence
 
 import lmdeploy
-from lmdeploy.model import MODELS
+from lmdeploy.model import MODELS, BaseModel
 from lmdeploy.tokenizer import Tokenizer
 from lmdeploy.utils import get_logger
 
+from .deploy.converter import get_model_format, supported_formats
+from .deploy.source_model.base import INPUT_MODELS
+from .deploy.target_model.base import OUTPUT_MODELS, TurbomindModelConfig
+from .utils import (ModelSource, check_tm_model_input, create_hf_download_args,
+                    get_hf_config_content, get_model_source)
+
 # TODO: find another way import _turbomind
 lmdeploy_dir = osp.split(lmdeploy.__file__)[0]
 sys.path.append(osp.join(lmdeploy_dir, 'lib'))
 import _turbomind as _tm  # noqa: E402
 
+logger = logging.getLogger(__name__)
+
 
 def _stop_words(stop_words: List[str], tokenizer: Tokenizer):
     """return list of stop-words to numpy.ndarray."""
     if stop_words is None:
         return None
     assert isinstance(stop_words, List) and \
-           all(isinstance(elem, str) for elem in stop_words), \
-           f'stop_words must be a list but got {type(stop_words)}'
-    stop_words = [tokenizer.encode(stop_word)[-1] for stop_word in stop_words]
+        all(isinstance(elem, str) for elem in stop_words), \
+        f'stop_words must be a list but got {type(stop_words)}'
+    stop_words = [
+        tokenizer.encode(stop_word, False)[-1] for stop_word in stop_words
+    ]
     assert isinstance(stop_words, List) and all(
         isinstance(elem, int) for elem in stop_words), 'invalid stop_words'
     # each id in stop_words represents a stop word
@@ -74,74 +89,289 @@ class TurboMind:
 
     Args:
         model_path (str): the path of turbomind's model
-        eos_id (int): eos token id
+        model_source (ModelSource): the source of the model
+        model_name (str): needed when model_path is a hf model and not
+            managed by lmdeploy
+        model_format (str): needed when model_path is a hf model and not
+            managed by lmdeploy
+        group_size (int): needed when model_path is a hf model and not
+            managed by lmdeploy
         tp (int): tensor parallel
     """
 
-    def __init__(self, model_path: str, eos_id: int = 2, tp: int = 1):
-        self.eos_id = eos_id
+    def __init__(self,
+                 model_path: str,
+                 model_source: ModelSource = ModelSource.WORKSPACE,
+                 model_name: Optional[str] = None,
+                 model_format: Optional[str] = None,
+                 group_size: Optional[int] = None,
+                 tp: Optional[int] = None,
+                 **kwargs):
+        if tp is not None:
+            assert ((tp & (tp - 1) == 0) and tp != 0), 'tp should be 2^n'
+        self.gpu_count = tp if tp is not None else 1
+
+        if model_source == ModelSource.WORKSPACE:
+            tokenizer_model_path = osp.join(model_path, 'triton_models',
+                                            'tokenizer')
+            self.tokenizer = Tokenizer(tokenizer_model_path)
+            self.model_comm = self._from_workspace(model_path)
+        else:
+            self.tokenizer = Tokenizer(model_path)
+            self.model_comm = self._from_hf(model_source=model_source,
+                                            model_path=model_path,
+                                            model_name=model_name,
+                                            model_format=model_format,
+                                            group_size=group_size,
+                                            tp=tp,
+                                            **kwargs)
+
+        self.eos_id = self.tokenizer.eos_token_id
+        self.model: BaseModel = MODELS.get(self.model_name)(**kwargs)
+        self.session_len = self.model.session_len
+        self.stop_words = _stop_words(self.model.stop_words, self.tokenizer)
+
+    def _create_weight(self, model_comm):
+        """Allocate weight buffer, load params if from_workspace."""
 
         # TODO: support mpi
-        node_id = 0
-        node_num = 1
-
-        # read meta from model path
-        assert ((tp & (tp - 1) == 0) and tp != 0), 'tp should be 2^n'
-        self.gpu_count = tp
-        self.session_len = 2048
-        data_type = 'fp16'
-        ini_path = osp.join(model_path, 'triton_models/weights/config.ini')
-        with open(ini_path, 'r') as f:
-            parser = ConfigParser()
-            parser.read_file(f)
-            section_name = ''
-            if 'turbomind' in parser:
-                section_name = 'turbomind'
-            elif 'llama' in parser:
-                section_name = 'llama'
-
-            if len(section_name) > 0:
-                tp_cfg = parser.getint(section_name, 'tensor_para_size')
-                self.session_len = parser.getint(section_name, 'session_len')
-                if tp_cfg != 1 and tp_cfg != tp:
-                    get_logger('turbomind').info(
-                        f'found tp={tp_cfg} in config.ini.')
-                    self.gpu_count = tp_cfg
-            self.model_name = parser.get(section_name, 'model_name')
-            data_type = parser.get(section_name, 'weight_type')
-        model = MODELS.get(self.model_name)()
-        tokenizer_model_path = osp.join(model_path, 'triton_models',
-                                        'tokenizer')
-        tokenizer = Tokenizer(tokenizer_model_path)
-        self.stop_words = _stop_words(model.stop_words, tokenizer)
-
-        # params
-        self.node_id = node_id
-        self.node_num = node_num
-        self.world_size = self.node_num * self.gpu_count
-
-        # create model
-        weight_dir = osp.join(model_path, 'triton_models', 'weights')
-        model = _tm.AbstractTransformerModel.create_llama_model(
-            weight_dir, tensor_para_size=self.gpu_count, data_type=data_type)
-        self.model = model
-        self.nccl_params = model.create_nccl_params(self.node_id)
+        self.node_id = 0
+        self.node_num = 1
+        self.nccl_params = model_comm.create_nccl_params(self.node_id)
         torch.cuda.synchronize()
 
         # create weight
-        def _create_weight(device_id):
+        def _create_weight_func(device_id):
             with cuda_ctx(device_id):
                 rank = self.node_id * self.gpu_count + device_id
-                model.create_shared_weights(device_id, rank)
+                model_comm.create_shared_weights(device_id, rank)
 
         threads = []
         for device_id in range(self.gpu_count):
-            t = Thread(target=_create_weight, args=(device_id, ))
+            t = Thread(target=_create_weight_func, args=(device_id, ))
             t.start()
             threads.append(t)
         for t in threads:
             t.join()
 
+    def _load_kv_qparams(self, model_path, tm_params, **kwargs):
+        """Load kv qparams when loading from hf."""
+        if self.config.quant_policy:
+            logger.warning('loading kv_cache quant scale')
+            from lmdeploy.lite.apis.kv_qparams import main as kv_loader
+            kv_sym = kwargs.get('kv_sym', False)
+            kv_bits = kwargs.get('kv_bits', 8)
+            tp = self.config.tensor_para_size
+            kv_loader(model_path, model_path, kv_bits, kv_sym, tp, tm_params)
+        else:
+            for key in list(tm_params.keys()):
+                if 'past_kv_scale' in key:
+                    tm_params.pop(key)
+
+    def _get_model_params(self, model_comm, tm_params):
+        """Get turbomind model params when loading from hf."""
+
+        def _get_params(device_id, que):
+            with cuda_ctx(device_id):
+                rank = self.node_id * self.gpu_count + device_id
+                out = model_comm.get_params(device_id, rank)
+                que.put(out)
+
+        que = Queue()
+        threads = []
+        for device_id in range(self.gpu_count):
+            t = Thread(target=_get_params, args=(device_id, que))
+            t.start()
+            threads.append(t)
+        for t in threads:
+            t.join()
+
+        for _ in range(self.gpu_count):
+            tensor_map = que.get()
+            for k, v in tensor_map.items():
+                if k not in tm_params:
+                    tm_params[k] = []
+                tm_params[k].append(v)
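+        # After draining the queue, tm_params maps each turbomind weight name
+        # to a list holding one tensor per GPU rank (gpu_count entries).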
+
+    def _from_hf(self,
+                 model_source: ModelSource,
+                 model_path: str,
+                 model_name: Optional[str] = None,
+                 model_format: Optional[str] = None,
+                 group_size: Optional[int] = None,
+                 tp: Optional[int] = None,
+                 **kwargs):
+        """Load model which is in hf format."""
+        # get model_name and group_size if the model is managed by lmdeploy.
+        if model_source == ModelSource.HF_LMDEPLOY:
+            config = get_hf_config_content(model_path, local_files_only=True)
+            tm_config = config['turbomind']
+            tm_config.update(kwargs)
+            var_should_be_none = dict(model_name=model_name,
+                                      model_format=model_format,
+                                      group_size=group_size)
+            for key, value in var_should_be_none.items():
+                assert value is None, f'{key} should be None when model is '\
+                    f'from {model_source}'
+            model_name = tm_config['model_name']
+            group_size = tm_config['group_size']
+            if tm_config['weight_type'] == 'int4':
+                model_format = 'awq'
+        else:
+            assert model_name is not None, 'please supply model_name when ' \
+                f'model is from {model_source}'
+            if osp.exists(osp.join(model_path, 'outputs_stats.pth')):
+                model_format = 'awq' if model_format is None else model_format
+                group_size = 128 if group_size is None else group_size
+            tm_config = kwargs
+
+        assert model_name in MODELS.module_dict.keys(), \
+            f"'{model_name}' is not supported. " \
+            f'The supported models are: {MODELS.module_dict.keys()}'
+        assert model_format in supported_formats, 'the model format ' \
+            f'should be in {supported_formats}'
+
+        data_type = 'fp16'
+        output_format = 'fp16'
+        inferred_model_format = get_model_format(model_name, model_format)
+        cfg = TurbomindModelConfig.from_dict(tm_config, allow_none=True)
+
+        # overwrite with input params
+        cfg.model_name = model_name
+        cfg.tensor_para_size = 1 if tp is None else tp
+        cfg.rotary_embedding = cfg.size_per_head
+        cfg.group_size = group_size
+        if inferred_model_format.find('awq') != -1:
+            cfg.weight_type = 'int4'
+            output_format = 'w4'
+            data_type = 'int4'
+            assert group_size > 0, f'group_size: {group_size} should be > 0'
+
+        self.config = cfg
+        self.model_name = model_name
+        self.data_type = data_type
+
+        input_model = INPUT_MODELS.get(inferred_model_format)(
+            model_path=model_path, tokenizer_path=model_path, ckpt_path=None)
+
+        output_model = OUTPUT_MODELS.get(output_format)(
+            input_model=input_model, cfg=cfg, to_file=False, out_dir='')
+
+        config = copy.deepcopy(output_model.cfg.__dict__)
+        logger.warning(f'model_config:\n{json.dumps(config, indent=2)}')
+        parser = ConfigParser()
+        parser['llama'] = config
+        with io.StringIO() as ss:
+            parser.write(ss)
+            ss.seek(0)
+            config = ss.read()
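+        # The model config is serialized to an INI string (a single [llama]
+        # section) so create_llama_model can consume it directly instead of
+        # reading a config.ini file from disk.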
+
+        model_comm = _tm.AbstractTransformerModel.create_llama_model(
+            model_dir='',
+            config=config,
+            tensor_para_size=self.gpu_count,
+            data_type=data_type)
+
+        # create empty weight
+        self._create_weight(model_comm)
+
+        # copy hf model weight to turbomind weight
+        tm_params = output_model.tm_params
+        self._get_model_params(model_comm, tm_params)
+        logger.warning(f'get {len(tm_params)} model params')
+        output_model.export()
+
+        # load kv qparams
+        self._load_kv_qparams(model_path, tm_params, **kwargs)
+        assert len(tm_params) == 0, f'missing {tm_params.keys()}'
+
+        return model_comm
+
+    def _from_workspace(self, model_path: str):
+        """Load model which is converted by `lmdeploy convert`"""
+        ini_path = osp.join(model_path, 'triton_models', 'weights',
+                            'config.ini')
+        with open(ini_path, 'r') as f:
+            parser = ConfigParser()
+            parser.read_file(f)
+            section_name = 'llama'
+            tp_cfg = parser.getint(section_name, 'tensor_para_size')
+
+            if tp_cfg != 1 and tp_cfg != self.gpu_count:
+                get_logger('turbomind').info(
+                    f'found tp={tp_cfg} in config.ini.')
+                self.gpu_count = tp_cfg
+            self.model_name = parser.get(section_name, 'model_name')
+            self.data_type = parser.get(section_name, 'weight_type')
+            cfg = parser._sections[section_name]
+            cfg = TurbomindModelConfig.from_dict(cfg)
+            self.config = cfg
+
+        # create model
+        weight_dir = osp.join(model_path, 'triton_models', 'weights')
+        model_comm = _tm.AbstractTransformerModel.create_llama_model(
+            weight_dir,
+            tensor_para_size=self.gpu_count,
+            data_type=self.data_type)
+
+        # create weight and load params
+        self._create_weight(model_comm)
+        return model_comm
+
+    @classmethod
+    def from_pretrained(cls,
+                        pretrained_model_name_or_path: str,
+                        model_name: Optional[str] = None,
+                        model_format: Optional[str] = None,
+                        group_size: Optional[int] = None,
+                        tp: Optional[int] = None,
+                        **kwargs):
+        """LMDeploy's turbomind inference engine.
+
+        Args:
+            pretrained_model_name_or_path (str):
+                It could be one of the following options:
+                    - i) A local directory path of a turbomind model which is
+                      converted by `lmdeploy convert` command or downloaded
+                      from ii) and iii)
+                    - ii) The model_id of a lmdeploy-quantized model hosted
+                      inside a model repo on huggingface.co, such as
+                      "InternLM/internlm-chat-20b-4bit",
+                      "lmdeploy/llama2-chat-70b-4bit", etc.
+                    - iii) The model_id of a model hosted inside a model repo
+                      on huggingface.co, such as "InternLM/internlm-chat-7b",
+                      "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
+                      and so on.
+            model_name (str): needed when pretrained_model_name_or_path
+                is iii)
+            model_format (str): format of the model, e.g. 'awq'
+            group_size (int): group size used by a quantized model
+            tp (int): tensor parallel size
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to update configuration when initializing
+                the engine.
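+
+        Examples:
+            >>> # illustrative only; model id and registered name are examples
+            >>> tm = TurboMind.from_pretrained('InternLM/internlm-chat-7b',
+            ...                                model_name='internlm-chat-7b')
+            >>> generator = tm.create_instance()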
+        """
+        model_source = get_model_source(pretrained_model_name_or_path)
+        if model_source == ModelSource.WORKSPACE:
+            local_path = pretrained_model_name_or_path
+        else:
+            check_tm_model_input(pretrained_model_name_or_path,
+                                 model_name=model_name,
+                                 **kwargs)
+            if not osp.exists(pretrained_model_name_or_path):
+                download_kwargs = create_hf_download_args(**kwargs)
+                local_path = snapshot_download(pretrained_model_name_or_path,
+                                               **download_kwargs)
+            else:
+                local_path = pretrained_model_name_or_path
+
+        logger.warning(f'model_source: {model_source}')
+        return cls(model_source=model_source,
+                   model_path=local_path,
+                   model_name=model_name,
+                   model_format=model_format,
+                   group_size=group_size,
+                   tp=tp,
+                   **kwargs)
+
     def create_instance(self, cuda_stream_id=0):
         """Create a turbomind instance.
 
@@ -161,7 +391,7 @@ class TurboMindInstance:
         cuda_stream_id(int): identity of a cuda stream
     """
 
-    def __init__(self, tm_model, cuda_stream_id=0):
+    def __init__(self, tm_model: TurboMind, cuda_stream_id: int = 0):
         self.tm_model = tm_model
         self.cuda_stream_id = cuda_stream_id
 
@@ -175,8 +405,6 @@ def __init__(self, tm_model, cuda_stream_id=0):
         self.session_len = tm_model.session_len
 
         self.nccl_params = tm_model.nccl_params
-        self.instance_comm = tm_model.model.create_instance_comm(
-            self.gpu_count)
 
         # create model instances
         model_insts = [None] * self.gpu_count
@@ -196,7 +424,7 @@ def __init__(self, tm_model, cuda_stream_id=0):
     def _create_model_instance(self, device_id, model_insts):
         with cuda_ctx(device_id):
             rank = self.node_id * self.gpu_count + device_id
-            model_inst = self.tm_model.model.create_model_instance(
+            model_inst = self.tm_model.model_comm.create_model_instance(
                 device_id, rank, self.cuda_stream_id, self.nccl_params)
             model_insts[device_id] = model_inst
 
@@ -204,16 +432,20 @@ def _forward_callback(self, result, ctx):
         self.que.put((False, result))
 
     def _forward_thread(self, inputs):
+        instance_comm = self.tm_model.model_comm.create_instance_comm(
+            self.gpu_count)
 
         def _func(device_id, enque_output):
             with cuda_ctx(device_id):
                 output = self.model_insts[device_id].forward(
-                    inputs, self.instance_comm)
+                    inputs, instance_comm)
                 if enque_output:
                     self.que.put((True, output))
 
         for device_id in range(self.gpu_count):
-            t = Thread(target=_func, args=(device_id, device_id == 0))
+            t = Thread(target=_func,
+                       args=(device_id, device_id == 0),
+                       daemon=True)
             t.start()
             self.threads[device_id] = t
 
@@ -262,11 +494,11 @@ def stream_infer(self,
             random_seed (int): seed used by sampling
             stream_output (bool): indicator for stream output
         """
-        if stream_output:
+        if stream_output and not stop:
             self.model_insts[0].register_callback(self._forward_callback)
 
         if len(input_ids) == 0:
-            input_ids = []
+            input_ids = [[]]
         if isinstance(input_ids[0], int):
             input_ids = [input_ids]
 
@@ -330,6 +562,7 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
         tm_inputs = _np_dict_to_tm_dict(inputs)
 
         # start forward thread
+        self.que = Queue()
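+        # a fresh queue per request so results left over from a previous call
+        # are not consumed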
         self._forward_thread(tm_inputs)
 
         seq_start = input_lengths + input_lengths.new_tensor(step)
@@ -344,7 +577,7 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
             outputs = _tm_dict_to_torch_dict(tm_outputs)
 
             output_ids = outputs['output_ids'][:, 0, :]
-            sequence_length = outputs['sequence_length'].long()[:, 0].cpu()
+            sequence_length = outputs['sequence_length'].long()[:, 0]
             output_ids = [
                 output_id[s:l] for output_id, s, l in zip(
                     output_ids, seq_start, sequence_length)
@@ -354,13 +587,13 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
             outputs = []
             for output, len_ in zip(output_ids, sequence_length):
                 output, len_ = output, len_.item()
-                if len(output) > 0 and output[-1].item() == self.eos_id:
+                if len(output) > 0 and output[-1].item(
+                ) == self.eos_id and not ignore_eos:
                     outputs.append((output[:-1], len_ - 1))
                 elif len(output) > 0 and output[-1].item() in self.stop_tokens:
                     outputs.append((output[:-1], len_))
                 else:
                     outputs.append((output, len_))
-
             yield outputs
 
             if finish:
@@ -370,7 +603,7 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
                     self.que.get()
                 break
 
-        if stream_output:
+        if stream_output and not stop:
             self.model_insts[0].unregister_callback()
 
     def decode(self, input_ids):
@@ -381,7 +614,7 @@ def decode(self, input_ids):
         """
 
         if len(input_ids) == 0:
-            input_ids = []
+            input_ids = [[]]
         if isinstance(input_ids[0], int):
             input_ids = [input_ids]
 
diff --git a/lmdeploy/turbomind/utils.py b/lmdeploy/turbomind/utils.py
new file mode 100644
index 0000000000..20540c1df3
--- /dev/null
+++ b/lmdeploy/turbomind/utils.py
@@ -0,0 +1,120 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import dataclasses
+import json
+import logging
+import os
+
+from huggingface_hub import hf_hub_download
+from transformers.utils import ExplicitEnum
+
+logger = logging.getLogger(__name__)
+
+
+class ModelSource(ExplicitEnum):
+    """Turbomind model source."""
+    WORKSPACE = 'workspace'
+    HF_MODEL = 'hf_model'
+    HF_LMDEPLOY = 'hf_lmdeploy'
+
+
+def create_hf_download_args(**kwargs) -> dict:
+    download_kwargs = {
+        'revision': None,
+        'cache_dir': None,
+        'proxies': None,
+        'resume_download': True,
+        'force_download': False,
+        'token': None,
+        'local_files_only': False
+    }
+    for k in download_kwargs.keys():
+        if k in kwargs:
+            download_kwargs[k] = kwargs[k]
+    return download_kwargs
+
+
+def get_hf_config_path(pretrained_model_name_or_path, **kwargs) -> str:
+    """Get local hf config local file path."""
+    if os.path.exists(pretrained_model_name_or_path):
+        config_path = os.path.join(pretrained_model_name_or_path,
+                                   'config.json')
+    else:
+        download_kwargs = create_hf_download_args(**kwargs)
+        config_path = hf_hub_download(pretrained_model_name_or_path,
+                                      'config.json', **download_kwargs)
+    return config_path
+
+
+def get_hf_config_content(pretrained_model_name_or_path, **kwargs) -> dict:
+    """Get config content of a hf model."""
+    config_path = get_hf_config_path(pretrained_model_name_or_path, **kwargs)
+    with open(config_path, 'r') as f:
+        config = json.load(f)
+    return config
+
+
+def get_model_source(pretrained_model_name_or_path: str,
+                     **kwargs) -> ModelSource:
+    """Get model source."""
+    triton_model_path = os.path.join(pretrained_model_name_or_path,
+                                     'triton_models')
+    if os.path.exists(triton_model_path):
+        return ModelSource.WORKSPACE
+    config = get_hf_config_content(pretrained_model_name_or_path, **kwargs)
+    model_source = ModelSource.HF_LMDEPLOY if 'turbomind' in config \
+        else ModelSource.HF_MODEL
+    return model_source
+
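+# Note on the resolution order above: a 'triton_models' subdirectory marks a
+# converted WORKSPACE model; otherwise a 'turbomind' section in config.json
+# marks an lmdeploy-managed HF model (HF_LMDEPLOY); anything else is treated
+# as a plain HF model (HF_MODEL).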
+
+def check_tm_model_input(pretrained_model_name_or_path, **kwargs):
+    """Check if single input pretrained_model_name_or_path is enough to use."""
+    if kwargs.get('model_name', None):
+        return
+
+    model_source = get_model_source(pretrained_model_name_or_path, **kwargs)
+    if model_source == ModelSource.WORKSPACE:
+        return
+
+    config = get_hf_config_content(pretrained_model_name_or_path, **kwargs)
+    if 'turbomind' in config and config['turbomind']['model_name'] != '':
+        return
+
+    assert (0), '\nCan not get model name from the input model, ' \
+        'please supply the model name with the --model-name argument, ' \
+        'you can list the supported models by `lmdeploy list`'
+
+
+@dataclasses.dataclass
+class GenParam:
+    top_p: float
+    top_k: float
+    temperature: float
+    repetition_penalty: float
+    sequence_start: bool = False
+    sequence_end: bool = False
+    step: int = 0
+    request_output_len: int = 512
+
+
+def get_gen_param(cap,
+                  sampling_param,
+                  nth_round,
+                  step,
+                  request_output_len=512,
+                  **kwargs):
+    """return parameters used by token generation."""
+    gen_param = GenParam(**dataclasses.asdict(sampling_param),
+                         request_output_len=request_output_len)
+    # Fix me later. turbomind.py doesn't support None top_k
+    if gen_param.top_k is None:
+        gen_param.top_k = 40
+
+    if cap == 'chat':
+        gen_param.sequence_start = (nth_round == 1)
+        gen_param.sequence_end = False
+        gen_param.step = step
+    else:
+        gen_param.sequence_start = True
+        gen_param.sequence_end = True
+        gen_param.step = 0
+    return gen_param
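+
+
+# Example (illustrative): in an interactive chat, the first round calls
+# get_gen_param('chat', sampling_param, nth_round=1, step=0), which yields
+# sequence_start=True and sequence_end=False; later rounds keep the session
+# open and pass the accumulated `step`. Any other capability starts and ends
+# its own sequence with step=0.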
diff --git a/lmdeploy/version.py b/lmdeploy/version.py
index 417dc76768..1109a4bcdc 100644
--- a/lmdeploy/version.py
+++ b/lmdeploy/version.py
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Tuple
 
-__version__ = '0.0.11'
+__version__ = '0.1.0a2'
 short_version = __version__
 
 
diff --git a/requirements.txt b/requirements.txt
index 9eacb498fb..91d38808f1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,19 +1,4 @@
-accelerate
-datasets
-fastapi
-fire
-gradio
-mmengine
-numpy
-pybind11
-
-pycuda
-safetensors
-sentencepiece
-setuptools
-shortuuid
-tiktoken
-torch
-transformers>=4.33.0
-tritonclient[all]
-uvicorn
+-r requirements/build.txt
+-r requirements/runtime.txt
+-r requirements/lite.txt
+-r requirements/serve.txt
diff --git a/requirements/build.txt b/requirements/build.txt
new file mode 100644
index 0000000000..b4430ae374
--- /dev/null
+++ b/requirements/build.txt
@@ -0,0 +1,2 @@
+pybind11
+setuptools
diff --git a/requirements/lite.txt b/requirements/lite.txt
new file mode 100644
index 0000000000..bd10933103
--- /dev/null
+++ b/requirements/lite.txt
@@ -0,0 +1,3 @@
+accelerate
+datasets
+flash-attn
diff --git a/requirements/readthedocs.txt b/requirements/readthedocs.txt
index 2f4885a20b..1fdf851e2b 100644
--- a/requirements/readthedocs.txt
+++ b/requirements/readthedocs.txt
@@ -1,4 +1,4 @@
-mmengine
+mmengine-lite
 torch
 transformers
 urllib3<2.0.0
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
new file mode 100644
index 0000000000..9acb819813
--- /dev/null
+++ b/requirements/runtime.txt
@@ -0,0 +1,9 @@
+fire
+mmengine-lite
+numpy
+pycuda
+safetensors
+sentencepiece
+tiktoken
+torch
+transformers>=4.33.0
diff --git a/requirements/serve.txt b/requirements/serve.txt
new file mode 100644
index 0000000000..11350c2900
--- /dev/null
+++ b/requirements/serve.txt
@@ -0,0 +1,5 @@
+fastapi
+gradio<4.0.0
+pydantic>2.0.0
+shortuuid
+uvicorn
diff --git a/requirements/test.txt b/requirements/test.txt
new file mode 100644
index 0000000000..2125b2daaa
--- /dev/null
+++ b/requirements/test.txt
@@ -0,0 +1,5 @@
+allure-pytest
+coverage
+pynvml
+pytest
+pyyaml
diff --git a/setup.py b/setup.py
index 09ae1e31c2..ff009b735f 100644
--- a/setup.py
+++ b/setup.py
@@ -121,26 +121,36 @@ def gen_packages_items():
 
 if __name__ == '__main__':
     lmdeploy_package_data = ['lmdeploy/bin/llama_gemm']
-    setup(name='lmdeploy',
-          version=get_version(),
-          description='A toolset for compressing, deploying and serving LLM',
-          long_description=readme(),
-          long_description_content_type='text/markdown',
-          author='OpenMMLab',
-          author_email='openmmlab@gmail.com',
-          packages=find_packages(exclude=()),
-          package_data={
-              'lmdeploy': lmdeploy_package_data,
-          },
-          include_package_data=True,
-          install_requires=parse_requirements('requirements.txt'),
-          has_ext_modules=check_ext_modules,
-          classifiers=[
-              'Programming Language :: Python :: 3.8',
-              'Programming Language :: Python :: 3.9',
-              'Programming Language :: Python :: 3.10',
-              'Programming Language :: Python :: 3.11',
-              'Intended Audience :: Developers',
-              'Intended Audience :: Education',
-              'Intended Audience :: Science/Research',
-          ])
+    setup(
+        name='lmdeploy',
+        version=get_version(),
+        description='A toolset for compressing, deploying and serving LLM',
+        long_description=readme(),
+        long_description_content_type='text/markdown',
+        author='OpenMMLab',
+        author_email='openmmlab@gmail.com',
+        packages=find_packages(exclude=()),
+        package_data={
+            'lmdeploy': lmdeploy_package_data,
+        },
+        include_package_data=True,
+        setup_requires=parse_requirements('requirements/build.txt'),
+        tests_require=parse_requirements('requirements/test.txt'),
+        install_requires=parse_requirements('requirements/runtime.txt'),
+        extras_require={
+            'all': parse_requirements('requirements.txt'),
+            'lite': parse_requirements('requirements/lite.txt'),
+            'serve': parse_requirements('requirements/serve.txt')
+        },
+        has_ext_modules=check_ext_modules,
+        classifiers=[
+            'Programming Language :: Python :: 3.8',
+            'Programming Language :: Python :: 3.9',
+            'Programming Language :: Python :: 3.10',
+            'Programming Language :: Python :: 3.11',
+            'Intended Audience :: Developers',
+            'Intended Audience :: Education',
+            'Intended Audience :: Science/Research',
+        ],
+        entry_points={'console_scripts': ['lmdeploy = lmdeploy.cli:run']},
+    )
diff --git a/src/turbomind/kernels/CMakeLists.txt b/src/turbomind/kernels/CMakeLists.txt
index 7c014845dd..a7593e3de9 100644
--- a/src/turbomind/kernels/CMakeLists.txt
+++ b/src/turbomind/kernels/CMakeLists.txt
@@ -71,3 +71,4 @@ set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE  ON)
 set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)
 
 add_subdirectory(gemm_s_f16)
+add_subdirectory(decoder_multihead_attention)
diff --git a/src/turbomind/kernels/bert_preprocess_kernels.cu b/src/turbomind/kernels/bert_preprocess_kernels.cu
index 0b97c44fba..495ec6e4a9 100644
--- a/src/turbomind/kernels/bert_preprocess_kernels.cu
+++ b/src/turbomind/kernels/bert_preprocess_kernels.cu
@@ -48,7 +48,9 @@ __global__ void getPaddingOffsetAndCuSeqLensKernel(size_t*    h_valid_word_num,
     if (calculate_cu_seqlens) {
         cu_seqlens[batch_size] = total_seq_len;
     }
-    h_valid_word_num[0] = (size_t)total_seq_len;
+    if (h_valid_word_num) {
+        h_valid_word_num[0] = (size_t)total_seq_len;
+    }
 }
 
 void invokeGetPaddingOffsetAndCuSeqLens(size_t*      h_pinned_token_num,
@@ -60,15 +62,19 @@ void invokeGetPaddingOffsetAndCuSeqLens(size_t*      h_pinned_token_num,
                                         const int    max_seq_len,
                                         cudaStream_t stream)
 {
-    h_pinned_token_num[0] = 0;
+    if (h_pinned_token_num) {
+        h_pinned_token_num[0] = 0;
+    }
     getPaddingOffsetAndCuSeqLensKernel<<<1, 1, 0, stream>>>(
         h_pinned_token_num, tmp_mask_offset, cu_seqlens, sequence_lengths, batch_size, max_seq_len);
+    if (h_pinned_token_num) {
 #ifdef _MSC_VER
-    cudaStreamSynchronize(stream);
+        cudaStreamSynchronize(stream);
 #else
-    while (((volatile size_t*)h_pinned_token_num)[0] == 0) {};
+        while (((volatile size_t*)h_pinned_token_num)[0] == 0) {};
 #endif
-    h_token_num[0] = h_pinned_token_num[0];
+        h_token_num[0] = h_pinned_token_num[0];
+    }
     sync_check_cuda_error();
 }
 
diff --git a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
index 370594a274..e6ea907dd1 100644
--- a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
+++ b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
@@ -43,6 +43,12 @@
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+// cudaFuncAttributes attr{};                                                                                         \
+// cudaFuncGetAttributes(&attr, func);                                                                                \
+// std::cout << "static_smem_sz: " << attr.sharedSizeBytes << std::endl;                                              \
+// std::cout << "max_dynamic_smem: " << attr.maxDynamicSharedSizeBytes << std::endl;                                  \
+// std::cout << "dynamic_smem_sz: " << smem_sz << std::endl;                                                          \
+
 template
 void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
 {
diff --git a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
index a0f7490e00..85ece1fa99 100644
--- a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
+++ b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
@@ -1472,6 +1472,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params
         }
         // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.
 
+        printf("QK_last[%d] = %f\n", hi, qk);
+
         qk_max                        = qk;
         qk_smem[tlength - first_step] = qk;
         // qk_smem[params.timestep] = qk;
@@ -1596,6 +1598,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params
 
                 qk += mul(params.linear_bias_slopes[hi], dist);
             }
+            // printf("QK_%d = %f\n", (int)ti, qk);
             qk_max                   = is_mask ? qk_max : fmaxf(qk_max, qk);
             qk_smem[ti - first_step] = qk;
         }
@@ -1632,6 +1635,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params
     // Broadcast to all the threads in the warp.
     qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
 
+    // if (threadIdx.x == 0) {
+    //     printf("QK_MAX[%d] = %f\n", hi, (float)qk_max);
+    // }
+
     // Compute the logits and start the sum.
     float sum = 0.f;
     // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
@@ -1657,6 +1664,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params
     // Compute the sum.
     sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum);
 
+    // if (threadIdx.x == 0) {
+    //     printf("SUM[%d] = %f\n", hi, (float)sum);
+    // }
+
     // Normalize the logits.
     float inv_sum = __fdividef(1.f, sum + 1.e-6f);
 
diff --git a/src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt b/src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt
new file mode 100644
index 0000000000..61e5245ffc
--- /dev/null
+++ b/src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+add_library(decoder_multihead_attention STATIC decoder_multihead_attention.cu kv_cache.cu)
+# target_compile_options(decoder_multihead_attention PRIVATE
+#   --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr --keep)
+set_property(TARGET decoder_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON)
+set_property(TARGET decoder_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
+target_link_libraries(decoder_multihead_attention PRIVATE nvidia::cutlass::cutlass)
+
+add_executable(test_decoder_multihead_attention test_utils.cu test_decoder_multihead_attention.cu)
+# target_compile_options(test_decoder_multihead_attention PRIVATE
+#   --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr)
+target_link_libraries(test_decoder_multihead_attention PRIVATE
+    decoder_multihead_attention
+    decoder_masked_multihead_attention
+    cublas)
diff --git a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
new file mode 100644
index 0000000000..5a1300ff2d
--- /dev/null
+++ b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
@@ -0,0 +1,490 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+
+#include "src/turbomind/kernels/gemm_s_f16/common.h"
+#include <cfloat>
+#include <type_traits>
+
+namespace turbomind {
+
+namespace ops {
+
+template<typename T>
+struct plus {
+    __device__ T operator()(T a, T b)
+    {
+        return a + b;
+    }
+};
+
+template<typename T>
+struct minus {
+    __device__ T operator()(T a, T b)
+    {
+        return a - b;
+    }
+};
+
+template<typename T>
+struct multiplies {
+    __device__ T operator()(T a, T b)
+    {
+        return a * b;
+    }
+};
+
+template
+inline __device__ Array binary_op_vv(const Array& a, const Array& b, Op op)
+{
+    Array c;
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        c[i] = op(a[i], b[i]);
+    }
+    return c;
+}
+
+template
+inline __device__ Array binary_op_sv(const T& a, const Array& b, Op op)
+{
+    Array c;
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        c[i] = op(a, b[i]);
+    }
+    return c;
+}
+
+template
+inline __device__ Array binary_op_vs(const Array& a, const T& b, Op op)
+{
+    Array c;
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        c[i] = op(a[i], b);
+    }
+    return c;
+}
+
+template
+inline __device__ Array operator+(const Array& a, const Array& b)
+{
+    return binary_op_vv(a, b, plus{});
+}
+
+template
+inline __device__ Array operator*(const Array& a, const Array& b)
+{
+    return binary_op_vv(a, b, multiplies{});
+}
+
+template
+inline __device__ Array operator*(const Array& a, const T& b)
+{
+    return binary_op_vs(a, b, multiplies{});
+}
+
+}  // namespace ops
+
+template<typename To, typename From, int N>
+inline __device__ Array<To, N> cast(const Array<From, N>& src)
+{
+    Array<To, N> dst;
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        dst[i] = (To)src[i];
+    }
+    return dst;
+}
+
+template<int N>
+struct RotaryEmbedding {
+
+    static_assert(N % 2 == 0);
+
+    Array<float, N> cs_;
+
+    __device__ RotaryEmbedding(float base, int dims, int timestep, int2 offset)
+    {
+        PRAGMA_UNROLL
+        for (int i = 0; i < N; i += 2) {
+            const float2 tmp = get_coefficient(offset.x + i, dims, base, timestep);
+            cs_[i]           = tmp.x;
+            cs_[i + 1]       = tmp.y;
+        }
+    }
+
+    static __device__ inline float2 get_coefficient(int idx, int dims, float base, int timestep)
+    {
+        const float inv_freq = timestep / powf(base, idx / (float)dims);
+        float2      cs;
+        sincosf(inv_freq, &cs.y, &cs.x);
+        return cs;
+    }
+
+    template<typename T>
+    __device__ void apply(Array<T, N>& x)
+    {
+        PRAGMA_UNROLL
+        for (int i = 0; i < N; i += 2) {
+            float tmp0 = cs_[i] * (float)x[i] - cs_[i + 1] * (float)x[i + 1];
+            float tmp1 = cs_[i] * (float)x[i + 1] + cs_[i + 1] * (float)x[i];
+            x[i]       = (T)tmp0;
+            x[i + 1]   = (T)tmp1;
+        }
+    }
+};
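+// The constructor above precomputes (cos, sin) pairs for the angle
+// timestep / base^(idx / dims) of each channel pair, and apply() rotates the
+// pairs (x[i], x[i+1]) in registers, i.e. the usual RoPE rotation.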
+
+struct LogNScaling {
+
+    float scale_;
+
+    __device__ static float get_scale(int seq_len, int max_position_embeddings)
+    {
+        if (seq_len <= max_position_embeddings) {
+            return 1.f;
+        }
+        else {
+            return log2f(seq_len) / log2f(max_position_embeddings);
+        }
+    }
+
+    __device__ LogNScaling(int seq_len, int max_position_embeddings)
+    {
+        scale_ = get_scale(seq_len, max_position_embeddings);
+    }
+
+    template<typename T, int N>
+    __device__ void apply(Array<T, N>& x) const
+    {
+        PRAGMA_UNROLL
+        for (int i = 0; i < N; ++i) {
+            x[i] = (T)((float)x[i] * scale_);
+        }
+    }
+};
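+// Example: seq_len = 4096 with max_position_embeddings = 2048 gives a scale
+// of log2(4096) / log2(2048) = 12 / 11 ~= 1.09; sequences within the trained
+// context keep a scale of 1.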
+
+template
+inline __device__ void Store(T* dst, const Array& src)
+{
+    static_assert(sizeof(Array) <= sizeof(uint4));
+
+    if constexpr (sizeof(Array) == sizeof(uint4)) {
+        *(uint4*)dst = (const uint4&)src;
+    }
+    else if constexpr (sizeof(Array) == sizeof(uint2)) {
+        *(uint2*)dst = (const uint2&)src;
+    }
+    else if constexpr (sizeof(Array) == sizeof(uint1)) {
+        *(uint1*)dst = (const uint1&)src;
+    }
+    else {
+        static_assert(!std::is_same_v);
+    }
+}
+
+template
+inline __device__ void Ldg(Array& dst, const T* src)
+{
+    static_assert(sizeof(Array) <= sizeof(uint4));
+
+    if constexpr (sizeof(Array) == sizeof(uint4)) {
+        (uint4&)dst = __ldg((const uint4*)src);
+    }
+    else if constexpr (sizeof(Array) == sizeof(uint2)) {
+        (uint2&)dst = __ldg((const uint2*)src);
+    }
+    else if constexpr (sizeof(Array) == sizeof(uint)) {
+        (uint&)dst = __ldg((const uint*)src);
+    }
+    else {
+        static_assert(!std::is_same_v);
+    }
+}
+
+template
+inline __device__ void Lds(Array& dst, const T* src)
+{
+    static_assert(sizeof(Array) <= sizeof(uint4));
+
+    if constexpr (sizeof(Array) == sizeof(uint4)) {
+        (uint4&)dst = *(const uint4*)src;
+    }
+    else if constexpr (sizeof(Array) == sizeof(uint2)) {
+        (uint2&)dst = *(const uint2*)src;
+    }
+    else if constexpr (sizeof(Array) == sizeof(uint)) {
+        (uint1&)dst = *(const uint1*)src;
+    }
+    else {
+        static_assert(!std::is_same_v);
+    }
+}
+
+template
+inline __device__ Accum qk_dot(const Array (&q)[V], const Array (&k)[V])
+{
+    Accum accum{};
+
+    PRAGMA_UNROLL
+    for (int vi = 0; vi < V; ++vi) {
+        PRAGMA_UNROLL
+        for (int i = 0; i < N; ++i) {
+            accum += Accum(Compute(q[vi][i]) * Compute(k[vi][i]));
+        }
+    }
+
+    PRAGMA_UNROLL
+    for (int mask = kThreadGroupSize / 2; mask >= 1; mask /= 2) {
+        accum += __shfl_xor_sync((uint32_t)-1, accum, mask);
+    }
+
+    return accum;
+}
+
+template
+inline __device__ Accum qk_dot(const Array& q, const Array& k)
+{
+    Accum accum{};
+
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        accum += Accum(Compute(q[i]) * Compute(k[i]));
+    }
+
+    PRAGMA_UNROLL
+    for (int mask = kThreadGroupSize / 2; mask >= 1; mask /= 2) {
+        accum += __shfl_xor_sync((uint32_t)-1, accum, mask);
+    }
+
+    return accum;
+}
+
+template
+inline __device__ void fma_pv(Tp pr, const Array (&v)[M], Array (&o)[M])
+{
+    PRAGMA_UNROLL
+    for (int m = 0; m < M; ++m) {
+        PRAGMA_UNROLL
+        for (int n = 0; n < N; ++n) {
+            o[m][n] += To(ComputeType(v[m][n]) * ComputeType(pr));
+        }
+    }
+}
+
+template
+inline __device__ Array qk_max(Array val, T* smem_red, int warp_id, int lane_id)
+{
+    constexpr int kWarpCount = ThreadMap::kWarpCount;
+
+    // warp maximum
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        PRAGMA_UNROLL
+        for (int mask = WARP_SIZE / 2; mask >= ThreadMap::kWarpThreadC; mask /= 2) {
+            val[i] = fmaxf(val[i], __shfl_xor_sync((uint32_t)-1, val[i], mask));
+        }
+        if (lane_id == 0) {
+            smem_red[i * kWarpCount + warp_id] = val[i];
+        }
+    }
+
+    __syncthreads();
+
+    // block maximum
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        val[i] = lane_id < kWarpCount ? smem_red[i * kWarpCount + lane_id] : -FLT_MAX;
+        PRAGMA_UNROLL
+        for (int mask = kWarpCount >> 1; mask >= 1; mask >>= 1) {
+            val[i] = fmaxf(val[i], __shfl_xor_sync((uint32_t)-1, val[i], mask));
+        }
+        // broadcast to all threads
+        val[i] = __shfl_sync((uint32_t)-1, val[i], 0);
+    }
+
+    return val;
+}
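+// Reduction pattern above: a warp-level max over lanes of the same key-thread
+// group via __shfl_xor_sync, then a block-level max through shared memory,
+// broadcast back so every lane holds the same row maximum.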
+
+template
+inline __device__ Array blockSum(Array val, T* smem_red, int warp_id, int lane_id)
+{
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        PRAGMA_UNROLL
+        for (int mask = WARP_SIZE >> 1; mask >= 1; mask >>= 1) {
+            val[i] += __shfl_xor_sync((uint32_t)-1, val[i], mask);
+        }
+        if (lane_id == 0) {
+            smem_red[i * kWarpCount + warp_id] = val[i];
+        }
+    }
+
+    __syncthreads();
+
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        val[i] = lane_id < kWarpCount ? smem_red[i * kWarpCount + lane_id] : T{};
+        PRAGMA_UNROLL
+        for (int mask = kWarpCount >> 1; mask >= 1; mask >>= 1) {
+            val[i] += __shfl_xor_sync((uint32_t)-1, val[i], mask);
+        }
+        val[i] = __shfl_sync((uint32_t)-1, val[i], 0);
+    }
+
+    return val;
+}
+
+//////////////////////////////////////////////////////////////////////////////////////////////////
+
+// generic case for floating point -> floating point / integer -> integer conversion
+template
+struct ConvertKvCache {
+    __device__ __host__ ConvertKvCache(float, float) {}
+    template
+    inline __device__ auto operator()(const Array& vi) const -> Array
+    {
+        Array vo;
+        PRAGMA_UNROLL
+        for (int i = 0; i < N; ++i) {
+            vo[i] = (To)vi[i];
+        }
+        return vo;
+    }
+};
+
+// generic case for converting to same type, bypass
+template
+struct ConvertKvCache {
+    __device__ __host__ ConvertKvCache(float, float) {}
+    template
+    inline __device__ auto operator()(const Array& v) const -> Array
+    {
+        return v;
+    }
+};
+
+template
+struct ConvertKvCache {
+
+    float scale_;
+    float zero_;
+
+    __device__ __host__ ConvertKvCache(float scale, float zero): scale_(scale), zero_(zero) {}
+
+    inline __device__ uint8_t round(float x) const
+    {
+        uint32_t y;
+        asm("cvt.rni.sat.u8.f32 %0, %1;\n" : "=r"(y) : "f"(x));
+        return y;
+    }
+
+    template
+    inline __device__ auto operator()(const Array& vi) const -> Array
+    {
+        Array vo;
+        PRAGMA_UNROLL
+        for (int i = 0; i < N; ++i) {
+            // convert to unsigned int by offsetting +128
+            (uint8_t&)vo[i] = round(((float)vi[i] - zero_) / scale_ + 128.f);
+        }
+        return vo;
+    }
+};
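+// Example: with scale = 0.1 and zero = 0, a value of 1.0 quantizes to
+// round(1.0 / 0.1 + 128) = 138; the +128 offset maps the signed range onto
+// uint8 storage and is undone on dequantization (see the 32896 = 32768 + 128
+// correction below).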
+
+inline __device__ Array fast_i2f_f32_s8(const Array& x)
+{
+    union {
+        Array    f32x4;
+        Array u32x4;
+    };
+
+    auto& i8s = (const uint32_t&)x;
+
+    // 00000000111111112222222233333333
+    // 01234567012345670123456701234567
+    // SEEEEEEEEMMMMMMMMMMMMMMMMMMMMMMM
+    // 0????????_______XXXXXXXX________
+    // (1 + x / 2^15) * 2^(e - 127) -> e - 127 == 15 -> e = 142
+    //                                       7 6 5 4
+    static constexpr uint32_t f32_magic = 0x47000000;  // 2^15 = 32768
+    static constexpr uint32_t m0        = 0x7604;
+    static constexpr uint32_t m1        = 0x7614;
+    static constexpr uint32_t m2        = 0x7624;
+    static constexpr uint32_t m3        = 0x7634;
+
+    asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[0]) : "r"(i8s), "n"(f32_magic), "n"(m0));
+    asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[1]) : "r"(i8s), "n"(f32_magic), "n"(m1));
+    asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[2]) : "r"(i8s), "n"(f32_magic), "n"(m2));
+    asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[3]) : "r"(i8s), "n"(f32_magic), "n"(m3));
+
+    if (0) {  // fused with dequantization
+        PRAGMA_UNROLL
+        for (int i = 0; i < 4; ++i) {
+            f32x4[i] -= 32896.f;  // 32768 + 128
+        }
+    }
+
+    return f32x4;
+}
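+// The prmt.b32 instructions splice each stored byte b into bits 8..15 of the
+// mantissa of 2^15 (0x47000000), producing the float 32768 + b without a
+// per-lane int-to-float conversion; the 32896 (= 32768 + 128) correction is
+// folded into the dequantization zero-point in the converters below.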
+
+template<>
+struct ConvertKvCache {
+
+    float scale_;
+    float zero_;
+
+    __device__ __host__ ConvertKvCache(float scale, float zero): scale_(scale), zero_(zero)
+    {
+        zero_ = zero_ - 32896.f * scale_;
+    }
+
+    template
+    inline __device__ auto operator()(const Array& vi) const -> Array
+    {
+        Array vo;
+        PRAGMA_UNROLL
+        for (int i = 0; i < N; i += 4) {
+            auto& vec = (Array&)vo[i];
+            vec       = fast_i2f_f32_s8((const Array&)vi[i]);
+            PRAGMA_UNROLL
+            for (int j = 0; j < 4; ++j) {
+                vec[j] = vec[j] * scale_ + zero_;
+                // vec[j] = vec[j] * scale_ + (zero_ - 32896.f * scale_);
+            }
+        }
+        return vo;
+    }
+};
+
+template<>
+struct ConvertKvCache {
+
+    float scale_;
+    float zero_;
+
+    __device__ __host__ ConvertKvCache(float scale, float zero): scale_(scale), zero_(zero)
+    {
+        zero_ = zero_ - 32896.f * scale_;
+    }
+
+    template
+    inline __device__ auto operator()(const Array& vi) const -> Array
+    {
+        Array vo;
+        PRAGMA_UNROLL
+        for (int i = 0; i < N; i += 4) {
+            auto& vec = (Array&)vo[i];
+            auto  tmp = fast_i2f_f32_s8((const Array&)vi[i]);
+            PRAGMA_UNROLL
+            for (int j = 0; j < 4; ++j) {
+                vec[j] = half(tmp[j] * scale_ + zero_);
+                // vec[j] = half(tmp[j] * scale_ + (zero_ - 32896.f * scale_));
+            }
+        }
+        return vo;
+    }
+};
+
+}  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
new file mode 100644
index 0000000000..02cc827694
--- /dev/null
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
@@ -0,0 +1,115 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include "decoder_multihead_attention_template.h"
+#include "src/turbomind/models/llama/llama_utils.h"
+#include "src/turbomind/utils/cuda_utils.h"
+
+#include 
+
+namespace turbomind {
+
+namespace {
+
+template<typename MHAType>
+bool Print(size_t dynamic_smem_size)
+{
+    using MapKv = typename MHAType::MapKv;
+
+    std::cout << "     warps: " << MapKv::kWarpCount << "\n";
+    std::cout << "     shape: (" << MapKv::kC << ", " << MapKv::kS << ")\n";
+    std::cout << "    access: (" << MapKv::kAccessC << ", " << 1 << ")\n";
+    std::cout << "warpThread: (" << MapKv::kWarpThreadC << ", " << MapKv::kWarpThreadS << ")\n";
+    std::cout << "warpAccess: (" << MapKv::kWarpAccessC << ", " << MapKv::kWarpAccessS << ")\n";
+    std::cout << "  warpIter: (" << MapKv::kWarpIterC << ", " << MapKv::kWarpIterS << ")\n";
+    std::cout << "      warp: (" << MapKv::kWarpC << ", " << MapKv::kWarpS << ")\n";
+    std::cout << "      iter: (" << MapKv::kIterC << ", " << MapKv::kIterS << ")\n";
+    std::cout << " footprint: (" << MapKv::kFootprintC << ", " << MapKv::kFootprintS << ")\n";
+    std::cout << "     delta: (" << MapKv::kDeltaC << ", " << MapKv::kDeltaS << ")\n";
+    std::cout << "dynamic smem size: " << dynamic_smem_size << "\n";
+
+    return true;
+}
+
+}  // namespace
+
+template
+void invokeDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams& params)
+{
+    auto invoke = [&](auto* type) {
+        using Attn = std::remove_reference_t;
+
+        static const size_t kDynSmemSize = Attn::GetDynamicSmemSize();
+
+        // [[maybe_unused]] static const bool _ = Print(kDynSmemSize);
+
+        const int slice_count = (params.max_seq_len + Attn::kSliceLen - 1) / Attn::kSliceLen;
+        const int max_split_k = std::min(params.max_split_k, std::max(1, slice_count));
+
+        dim3 block(Attn::kWarpCount * WARP_SIZE);
+        dim3 grid(params.num_heads / HeadPerCta, params.batch_size, max_split_k);
+
+        // if (params.layer_offset == 0) {
+        //     std::cout << "max_split_k' = " << max_split_k << ", arch = " << params.arch << "\n";
+        // }
+
+        cudaFuncSetAttribute(
+            decoder_multihead_attention, cudaFuncAttributeMaxDynamicSharedMemorySize, kDynSmemSize);
+
+        decoder_multihead_attention<<>>(params);
+
+        if (max_split_k > 1) {
+            dim3 grid(params.num_heads, params.batch_size);
+            decoder_multihead_attention_reduce<<>>(params);
+        }
+    };
+
+    if (params.arch >= 80) {
+        // DecoderMultiHeadAttentionKernel;  // 64k
+
+        using Type = DecoderMultiHeadAttentionKernel;
+        invoke((Type*)0);
+    }
+    else {
+        // DecoderMultiHeadAttentionKernel; // 34k
+        // DecoderMultiHeadAttentionKernel;  // 34k
+
+        using Type = DecoderMultiHeadAttentionKernel;
+        invoke((Type*)0);
+    }
+}
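+// Launch shape used above: grid.x = head groups (num_heads / HeadPerCta),
+// grid.y = batch, grid.z = split-k partitions along the sequence; when
+// max_split_k > 1 a second reduction kernel merges the per-partition results.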
+
+template
+void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams& params)
+{
+    static constexpr int HeadDim = 128;
+
+    FT_CHECK(params.size_per_head == HeadDim);
+
+    if constexpr (std::is_same_v) {
+        if (params.quant_policy & QuantPolicy::kCacheKVInt8) {
+            invokeDecoderMultiheadAttention(params);
+            return;
+        }
+
+        int group_size = params.num_heads / params.num_kv_heads;
+
+        if (0) {}
+        // else if (group_size % 8 == 0) {
+        //     invokeDecoderMultiheadAttention(params);
+        // }
+        else if (group_size % 4 == 0) {
+            invokeDecoderMultiheadAttention(params);
+        }
+        else if (group_size % 2 == 0) {
+            invokeDecoderMultiheadAttention(params);
+        }
+        else {
+            invokeDecoderMultiheadAttention(params);
+        }
+    }
+}
+
+template void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams& params);
+template void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams& params);
+
+}  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h
new file mode 100644
index 0000000000..5f7024c49c
--- /dev/null
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h
@@ -0,0 +1,12 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+
+#include "decoder_multihead_attention_params.h"
+
+namespace turbomind {
+
+template<typename T>
+void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params);
+
+}
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
new file mode 100644
index 0000000000..ebe46d9773
--- /dev/null
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
@@ -0,0 +1,69 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+#include 
+
+namespace turbomind {
+
+template<typename T>
+struct DecoderMultiHeadAttentionParams {
+    // token-level buffers, [B, qH + 2kvH, D] or [B, kvH, D]
+    T* __restrict__ out;
+    T* __restrict__ q;
+    T* __restrict__ k;
+    T* __restrict__ v;
+    int stride;
+
+    // bias, [qH, D] or [kvH, D]
+    T* __restrict__ q_bias;
+    T* __restrict__ k_bias;
+    T* __restrict__ v_bias;
+
+    // sequence-level buffers
+    const int* __restrict__ context_length;
+    const bool* __restrict__ finished;
+    const float* __restrict__ rope_theta;
+
+    // kv cache
+    size_t layer_offset;
+
+    /// cache layout M,[N,H,x,D]
+    /// S: [s0/x, s1/x, s2/x, ..., sn-1/x], si <- block
+    /// 1. [L,sum(S),H,x,D]
+    void** __restrict__ k_cache_block_ptrs;  // X,[H,x,D]
+    void** __restrict__ v_cache_block_ptrs;  // X,[H,x,D]
+    int* __restrict__ cu_block_cnts;         // [B+1]
+    int kv_cache_block_size;
+
+    // batch-level params
+    int batch_size;
+    int max_seq_len;
+
+    // instance-level params
+    int   num_heads;
+    int   num_kv_heads;
+    int   size_per_head;
+    float inv_sqrt_dh;
+
+    // rotary embedding
+    int   rotary_embedding_dim;
+    float rotary_embedding_base;
+    int   max_position_embeddings;
+    // bool  use_dynamic_ntk;
+
+    // log(n) attention
+    bool use_logn_attn;
+
+    int   quant_policy;
+    float kv_quant_params[4];
+
+    int    max_split_k;
+    float* partial_O;
+    float* partial_M;
+    float* partial_L;
+
+    int          arch;
+    cudaStream_t stream;
+};
+
+}  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
new file mode 100644
index 0000000000..98ff678870
--- /dev/null
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
@@ -0,0 +1,932 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+
+#include "array_ops.h"
+#include "iterator.h"
+#include "src/turbomind/kernels/gemm_s_f16/common.h"
+#include "thread_map.h"
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include "decoder_multihead_attention_params.h"
+
+namespace turbomind {
+
+template
+struct DecoderMultiHeadAttentionKernel {
+    using ParamType = DecoderMultiHeadAttentionParams;
+
+    static constexpr int  kWarpCount  = 4;
+    static constexpr int  kHeadPerCta = HeadPerCta;
+    static constexpr int  kMaxHeadDim = MaxHeadDim;
+    static constexpr int  kKeyPerIter = KeyPerIter;
+    static constexpr int  kHeadDim    = HeadDim;
+    static constexpr int  kStages     = Stages;
+    static constexpr bool kSplitK     = SplitK;
+
+    static constexpr int kSliceLen     = SliceLen;
+    static constexpr int kIterPerSlice = kSliceLen / kKeyPerIter;
+
+    static constexpr int kVecKvSize    = sizeof(uint4) / sizeof(Tkv);
+    static constexpr int kThreadPerKey = 8;
+
+    using VecKv      = Array;
+    using VecKvFloat = Array;
+
+    static constexpr bool kUseBlockIter = true;
+
+    using MapKv  = ThreadMapKv;
+    using IterKv = turbomind::Iterator;
+
+    static constexpr size_t GetDynamicSmemSize()
+    {
+        size_t smem_kv_cache = IterKv::kSmemByteSize;
+        // size_t smem_kv_align = 128;
+        size_t smem_kv_align = 0;
+        size_t smem_qk       = sizeof(float) * kHeadPerCta * kSliceLen;
+        size_t smem_pr       = sizeof(float) * kHeadPerCta * kSliceLen;
+        return smem_kv_align + smem_kv_cache + std::max(smem_qk, smem_pr);
+    }
+
+    using QkAccumType   = float;
+    using QkComputeType = float;
+
+    using PvAccumType   = float;
+    using PvComputeType = float;
+
+    struct SharedStorage {
+        __align__(16) T Q[kHeadPerCta * kMaxHeadDim];
+        __align__(16) float O[kHeadPerCta * kMaxHeadDim];
+        float M[kHeadPerCta];  // max{dot(Q,  K^T  )}
+        float L[kHeadPerCta];  // sum{exp(s - S_max)}
+        float red_max[kHeadPerCta * kWarpCount];
+        float red_sum[kHeadPerCta * kWarpCount];
+    };
+
+    const ParamType& params_;
+
+    int head_idx_;
+    int batch_idx_;
+    int warp_id_;
+    int lane_id_;
+
+    int  kv_head_idx_;
+    bool is_gqa_leader_;
+
+    int step_begin_;
+    int step_end_;
+
+    int timestep_;
+    Tkv* __restrict__ k_cache_;  // [S, D]
+    Tkv* __restrict__ v_cache_;  // [S, D]
+
+    const void** __restrict__ k_cache_ptrs_;
+    const void** __restrict__ v_cache_ptrs_;
+
+    Tkv* __restrict__ smem_Kv_;
+    float* __restrict__ smem_S_;
+    float* __restrict__ smem_P_;
+    T* __restrict__ smem_Q_;
+    float* __restrict__ smem_M_;
+    float* __restrict__ smem_L_;
+    float* __restrict__ smem_O_;
+    float* __restrict__ smem_red_max_;
+    float* __restrict__ smem_red_sum_;
+
+    // avoid redundant type cast for KV8
+    using KLoadType = std::conditional_t, float, T>;
+    using VLoadType = std::conditional_t, float, T>;
+
+    ConvertKvCache         conv_k_store_;
+    ConvertKvCache         conv_v_store_;
+    ConvertKvCache conv_k_;
+    ConvertKvCache conv_v_;
+
+    __device__ bool thread0()
+    {
+        return blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0;
+    }
+
+    __device__ DecoderMultiHeadAttentionKernel(const ParamType& params, SharedStorage& smem, uint8_t* dsmem):
+        params_(params),
+        conv_k_store_{params_.kv_quant_params[0], params_.kv_quant_params[1]},
+        conv_v_store_{params_.kv_quant_params[2], params_.kv_quant_params[3]},
+        conv_k_{params_.kv_quant_params[0], params_.kv_quant_params[1]},
+        conv_v_{params_.kv_quant_params[2], params_.kv_quant_params[3]}
+    {
+        smem_Kv_      = (Tkv*)dsmem;
+        smem_S_       = (float*)(smem_Kv_ + IterKv::kSizePerTile * kStages);  // [HeadPerCta * kSliceLen]
+        smem_P_       = smem_S_;  // ! reusing only works when S and P has same dtype
+        smem_Q_       = smem.Q;
+        smem_M_       = smem.M;
+        smem_L_       = smem.L;
+        smem_O_       = smem.O;
+        smem_red_max_ = smem.red_max;
+        smem_red_sum_ = smem.red_sum;
+
+        head_idx_  = blockIdx.x * kHeadPerCta;
+        batch_idx_ = blockIdx.y;
+        warp_id_   = threadIdx.x / WARP_SIZE;
+        lane_id_   = threadIdx.x % WARP_SIZE;
+
+        const int gqa_group_size = params.num_heads / params.num_kv_heads;
+        kv_head_idx_             = head_idx_ / gqa_group_size;
+        is_gqa_leader_           = head_idx_ % gqa_group_size == 0;
+
+        timestep_ = params_.context_length[batch_idx_] - 1;
+
+        if (kSplitK && params.max_split_k > 1) {
+            const int slice_count     = (timestep_ + kSliceLen - 1) / kSliceLen;
+            const int slice_per_split = (slice_count + params_.max_split_k - 1) / params_.max_split_k;
+
+            step_begin_ = slice_per_split * get_split_k_idx() * kSliceLen;
+            step_end_   = min(timestep_, step_begin_ + slice_per_split * kSliceLen);
+        }
+        else {
+            step_begin_ = 0;
+            step_end_   = timestep_;
+        }
+
+        if constexpr (kUseBlockIter) {
+            k_cache_ptrs_ = params_.k_cache_block_ptrs + params_.cu_block_cnts[batch_idx_];
+            v_cache_ptrs_ = params_.v_cache_block_ptrs + params_.cu_block_cnts[batch_idx_];
+        }
+        else {
+            k_cache_ = (T*)params_.per_sample_k_cache[batch_idx_] + params.layer_offset
+                       + kv_head_idx_ * params_.max_seq_len * params_.size_per_head;
+            v_cache_ = (T*)params_.per_sample_v_cache[batch_idx_] + params.layer_offset
+                       + kv_head_idx_ * params_.max_seq_len * params_.size_per_head;
+        }
+    }
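+
+    // Note on the split-k partition computed above (illustrative numbers only, not a fixed
+    // configuration): with kSliceLen = 2048, timestep_ = 5000 and max_split_k = 4,
+    //   slice_count = ceil(5000 / 2048) = 3,   slice_per_split = ceil(3 / 4) = 1,
+    // so split 0 covers steps [0, 2048), split 1 covers [2048, 4096), split 2 covers
+    // [4096, 5000), and split 3 ends up with step_begin_ >= step_end_ and exits early in Run().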
+
+    __device__ void Prolugue()
+    {
+        // - Each warp is handling a row of Q
+        // - K/V are loaded redundantly only for the current step
+        static_assert(kMaxHeadDim % WARP_SIZE == 0);
+        static constexpr int kVecQSize = kMaxHeadDim / WARP_SIZE;
+
+        using VecQ      = Array<T, kVecQSize>;
+        using VecQFloat = Array<float, kVecQSize>;
+
+        using MapQ = ThreadMapQ;
+
+        static constexpr int kQVecPerThread  = MapQ::kIterC;
+        static constexpr int kQHeadPerThread = MapQ::kIterS;  // > 1 when #warp < kCtaPerHead
+
+        static_assert(kQVecPerThread == 1);
+
+        int2 offset   = MapQ::get_offset(warp_id_, lane_id_);
+        bool is_valid = offset.x < kMaxHeadDim && offset.y < kHeadPerCta;
+
+        if (!is_valid) {
+            return;
+        }
+
+        VecQ frag_Q[kQHeadPerThread];
+        VecQ frag_K;
+        VecQ frag_V;
+
+        // load qkv
+        PRAGMA_UNROLL
+        for (int s = 0; s < kQHeadPerThread; ++s) {
+            int di = offset.x;
+            int qi = offset.y + s;
+            Ldg(frag_Q[s], &params_.q[batch_idx_ * params_.stride + (head_idx_ + qi) * kHeadDim + di]);
+        }
+        Ldg(frag_K, &params_.k[batch_idx_ * params_.stride + kv_head_idx_ * kHeadDim + offset.x]);
+        Ldg(frag_V, &params_.v[batch_idx_ * params_.stride + kv_head_idx_ * kHeadDim + offset.x]);
+
+        if (params_.q_bias) {
+            // load biases
+            VecQ bias_Q[kQHeadPerThread];
+            PRAGMA_UNROLL
+            for (int s = 0; s < kQHeadPerThread; ++s) {
+                int di = offset.x;
+                int qi = offset.y + s;
+                Ldg(bias_Q[s], &params_.q_bias[(head_idx_ + qi) * kHeadDim + di]);
+            }
+            VecQ bias_K;
+            VecQ bias_V;
+            Ldg(bias_K, &params_.k_bias[kv_head_idx_ * kHeadDim + offset.x]);
+            Ldg(bias_V, &params_.v_bias[kv_head_idx_ * kHeadDim + offset.x]);
+
+            using namespace ops;
+            // apply biases
+            PRAGMA_UNROLL
+            for (int s = 0; s < kQHeadPerThread; ++s) {
+                frag_Q[s] = frag_Q[s] + bias_Q[s];
+            }
+            frag_K = frag_K + bias_K;
+            frag_V = frag_V + bias_V;
+        }
+
+        // for (int i = 0; i < kVecQSize; ++i) {
+        //     printf("q[%2d][%3d] = %f\n", (int)head_idx_, (int)(offset.x + i), (float)frag_Q[0][i]);
+        // }
+
+        float rotary_embedding_base =
+            params_.rope_theta ? params_.rope_theta[batch_idx_] : params_.rotary_embedding_base;
+
+        // Apply rotary embedding
+        RotaryEmbedding rotary_emb(rotary_embedding_base, params_.rotary_embedding_dim, timestep_, offset);
+
+        PRAGMA_UNROLL
+        for (int s = 0; s < kQHeadPerThread; ++s) {
+            rotary_emb.apply(frag_Q[s]);
+        }
+        rotary_emb.apply(frag_K);
+
+        if (params_.use_logn_attn) {
+            LogNScaling logn_scaling(timestep_ + 1, params_.max_position_embeddings);
+            PRAGMA_UNROLL
+            for (int s = 0; s < kQHeadPerThread; ++s) {
+                logn_scaling.apply(frag_Q[s]);
+            }
+        }
+
+        if (kSplitK && step_begin_) {  // Split idx > 0
+            PRAGMA_UNROLL
+            for (int s = 0; s < kQHeadPerThread; ++s) {
+                int qi = offset.y + s;
+                if (lane_id_ == 0) {
+                    smem_M_[qi] = -std::numeric_limits<float>::infinity();
+                    smem_L_[qi] = 0.f;
+                }
+                Store(&smem_Q_[qi * kMaxHeadDim + offset.x], frag_Q[s]);
+                Store(&smem_O_[qi * kMaxHeadDim + offset.x], VecQFloat{});
+            }
+            return;
+        }
+
+        ////////////////////////////////////////////////////////
+        // Split 0 computes last step and stores to k/v cache
+
+        PRAGMA_UNROLL
+        for (int s = 0; s < kQHeadPerThread; ++s) {
+            int         qi = offset.y + s;
+            QkAccumType qk = qk_dot(frag_Q[s], frag_K);
+            if (lane_id_ == 0) {
+                qk *= params_.inv_sqrt_dh;
+                smem_M_[qi] = qk;
+                smem_L_[qi] = 1.f;
+                // printf("qk[%2d] = %f\n", head_idx_, qk);
+            }
+            // write Q and O
+            Store(&smem_Q_[qi * kMaxHeadDim + offset.x], frag_Q[s]);
+            Store(&smem_O_[qi * kMaxHeadDim + offset.x], cast<float>(frag_V));
+        }
+
+        auto frag_K_store = conv_k_store_(frag_K);
+        auto frag_V_store = conv_v_store_(frag_V);
+
+        // store
+        if (warp_id_ == 0 && is_gqa_leader_) {
+            if constexpr (kUseBlockIter) {
+                int block_index  = timestep_ / params_.kv_cache_block_size;
+                int block_offset = timestep_ % params_.kv_cache_block_size;
+                // if (thread0()) {
+                //     printf("%d %d %p %p\n", block_index, block_offset, k_cache_ptrs_, v_cache_ptrs_);
+                // }
+                k_cache_ = (Tkv*)k_cache_ptrs_[block_index] + params_.layer_offset
+                           + kv_head_idx_ * params_.kv_cache_block_size * kHeadDim;
+                v_cache_ = (Tkv*)v_cache_ptrs_[block_index] + params_.layer_offset
+                           + kv_head_idx_ * params_.kv_cache_block_size * kHeadDim;
+                Store(&k_cache_[block_offset * kHeadDim + offset.x], frag_K_store);
+                Store(&v_cache_[block_offset * kHeadDim + offset.x], frag_V_store);
+            }
+            else {
+                Store(&k_cache_[timestep_ * kHeadDim + offset.x], frag_K_store);
+                Store(&v_cache_[timestep_ * kHeadDim + offset.x], frag_V_store);
+            }
+        }
+    }
+
+    __device__ void PrefetchKvCache(IterKv& iter)
+    {
+        PRAGMA_UNROLL
+        for (int stage = 0; stage < kStages - 1; ++stage) {
+            iter.PrefetchStage();
+            CpAsyncCommit();
+        }
+    }
+
+    __device__ void CpAsyncWait()
+    {
+        __pipeline_wait_prior(kStages - 2);
+    }
+
+    __device__ void CpAsyncCommit()
+    {
+        __pipeline_commit();
+    }
+
+    __device__ void CpAsyncFlush()
+    {
+        __pipeline_commit();
+        __pipeline_wait_prior(0);
+    }
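+
+    // Note on the pipeline depth: __pipeline_wait_prior(kStages - 2) blocks until all but the
+    // newest kStages - 2 committed cp.async groups have landed in shared memory, so the oldest
+    // prefetched stage is safe to read while newer stages may still be in flight.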
+
+    static constexpr int kKvVecPerThread = MapKv::kIterC;
+    static constexpr int kKvKeyPerThread = MapKv::kIterS;
+
+    struct FragmentQ {
+        VecKv data[kHeadPerCta][kKvVecPerThread];
+    };
+
+    struct State {
+        // Double buffering to hide smem/dequant latency
+        Array<KLoadType, kVecKvSize> frag_K_buf[2][kKvVecPerThread];
+        Array<VLoadType, kVecKvSize> frag_V_buf[2][kKvVecPerThread];
+
+        Array<Tkv, kVecKvSize> frag_Kv_tmp_buf[2][kKvVecPerThread];
+    };
+
+    static constexpr int kPrefetchCount = (IterKv::kIterCount + MapKv::kIterS - 1) / MapKv::kIterS;
+
+    __device__ void ComputeSlice(FragmentQ& frag_Q, State& state, const int2& offset, int step, int iter_length)
+    {
+
+        Array<float, kHeadPerCta> frag_M;
+        PRAGMA_UNROLL
+        for (int i = 0; i < kHeadPerCta; ++i) {
+            frag_M[i] = smem_M_[i];
+        }
+
+        IterKv iter_K;
+
+        if constexpr (kUseBlockIter) {
+            iter_K = {k_cache_ptrs_,
+                      params_.kv_cache_block_size,
+                      params_.layer_offset,
+                      kv_head_idx_,
+                      smem_Kv_,
+                      step,
+                      step + iter_length,
+                      warp_id_,
+                      lane_id_};
+        }
+        else {
+            iter_K = {k_cache_, smem_Kv_, step, step + iter_length, warp_id_, lane_id_};
+        }
+
+        PrefetchKvCache(iter_K);
+        CpAsyncWait();
+
+        iter_K.Load(state.frag_Kv_tmp_buf[0]);
+        PRAGMA_UNROLL
+        for (int vi = 0; vi < kKvVecPerThread; ++vi) {
+            state.frag_K_buf[0][vi] = conv_k_(state.frag_Kv_tmp_buf[0][vi]);
+        }
+
+        iter_K.PrefetchBatch(0, kPrefetchCount);
+        if (kKvKeyPerThread == 1) {
+            CpAsyncCommit();
+            CpAsyncWait();
+            iter_K.AdvancePrefetchStage();
+            iter_K.AdvanceComputeStage();
+        }
+
+        ///////////////////////////////////////////////////////////////////////////////////////////
+        /// Compute QK(Q, S) = Q(Q, D) * K^T(D, S)
+
+        PRAGMA_NO_UNROLL
+        for (int _it = 0; _it < iter_length; _it += kKeyPerIter) {
+            PRAGMA_UNROLL
+            for (int si = 0; si < kKvKeyPerThread; ++si) {
+                const int next = (si + 1) % 2;
+                // smem -> rmem for next iter
+                iter_K.Load(state.frag_Kv_tmp_buf[next]);
+                PRAGMA_UNROLL
+                for (int vi = 0; vi < kKvVecPerThread; ++vi) {
+                    state.frag_K_buf[next][vi] = conv_k_(state.frag_Kv_tmp_buf[next][vi]);
+                }
+
+                // current iter's K fragment
+                auto& frag_K = state.frag_K_buf[si % 2];
+
+                const int local_offset = offset.y + _it + si * MapKv::kWarpAccessS;
+
+                PRAGMA_UNROLL
+                for (int qi = 0; qi < kHeadPerCta; ++qi) {
+
+                    auto qk = qk_dot(frag_Q.data[qi], frag_K);
+
+                    // if (ti == 16) {
+                    //     for (int vi = 0; vi < kKvVecPerThread; ++vi) {
+                    //         for (int i = 0; i < kVecKvSize; ++i) {
+                    //             printf("frag_Q = %f, frag_K[%d] = %f\n",
+                    //                    (float)frag_Q.data[qi][vi][i],
+                    //                    offset.x + vi * kVecKvSize + i,
+                    //                    (float)frag_K[vi][i]);
+                    //         }
+                    //     }
+                    // }
+
+                    qk *= params_.inv_sqrt_dh;
+
+                    if (step + local_offset < timestep_) {
+
+                        // group leader writes to smem
+                        if (threadIdx.x % kThreadPerKey == 0) {
+                            // printf("qk_%d = %f\n", step + local_offset, (float)qk);
+
+                            smem_S_[kSliceLen * qi + local_offset] = qk;
+
+                            // local max
+                            frag_M[qi] = fmaxf(frag_M[qi], qk);
+                        }
+                    }
+                }
+
+                iter_K.PrefetchBatch((si + 1) % kKvKeyPerThread, kPrefetchCount);
+
+                if (kKvKeyPerThread == 1 || si == kKvKeyPerThread - 2) {
+                    CpAsyncCommit();
+                    CpAsyncWait();
+                    iter_K.AdvancePrefetchStage();
+                    iter_K.AdvanceComputeStage();
+                }
+            }
+
+            // handle special case
+            if (kKvKeyPerThread == 1) {
+                for (int vi = 0; vi < kKvVecPerThread; ++vi) {
+                    state.frag_K_buf[0][vi] = state.frag_K_buf[1][vi];
+                }
+            }
+        }
+
+        CpAsyncFlush();
+
+        __syncthreads();
+
+        Array<float, kHeadPerCta> exp_M_diff;
+        PRAGMA_UNROLL
+        for (int i = 0; i < kHeadPerCta; ++i) {
+            exp_M_diff[i] = smem_M_[i];
+        }
+
+        /// block synchronization
+        frag_M = qk_max(frag_M, smem_red_max_, warp_id_, lane_id_);
+
+        // wait while smem_red_ is being used.
+        // __syncthreads();
+
+        PRAGMA_UNROLL
+        for (int i = 0; i < kHeadPerCta; ++i) {
+            // if (thread0()) {
+            //     printf("%f %f %f\n", (float)exp_M_diff[i], (float)frag_M[i], (float)__expf(exp_M_diff[i] -
+            //     frag_M[i]));
+            // }
+            // exp(m1 - m2)
+            exp_M_diff[i] = __expf(exp_M_diff[i] - frag_M[i]);
+
+            if (threadIdx.x == 0) {
+                smem_M_[i] = frag_M[i];
+            }
+        }
+
+        // if (threadIdx.x == 0 && step + iter_length == timestep_) {
+        //     printf("frag_M[%2d] = %f\n", head_idx_, (float)frag_M[0]);
+        // }
+
+        // __syncthreads();  // DEBUG
+
+        /////////////////////////////////////////////////////////////////////////////////////////
+        /// Compute softmax P(Q, S)
+        Array<float, kHeadPerCta> frag_L{};
+
+        for (int ti = threadIdx.x; ti < iter_length; ti += kWarpCount * WARP_SIZE) {
+            PRAGMA_UNROLL
+            for (int qi = 0; qi < kHeadPerCta; ++qi) {
+                int   idx = qi * kSliceLen + ti;
+                float qk  = smem_S_[idx];
+                float pr  = expf(qk - frag_M[qi]);
+                // printf("smem_P[%d] = %f\n", ti, pr);
+                smem_P_[idx] = pr;
+                frag_L[qi] += pr;
+            }
+        }
+
+        /// block synchronization
+        frag_L = blockSum(frag_L, smem_red_sum_, warp_id_, lane_id_);
+
+        for (int qi = 0; qi < kHeadPerCta; ++qi) {
+            // exp(m1 - m2) * l1
+            frag_L[qi] += exp_M_diff[qi] * smem_L_[qi];
+        }
+
+        __syncthreads();
+
+        for (int qi = 0; qi < kHeadPerCta; ++qi) {
+            if (threadIdx.x == 0) {
+                smem_L_[qi] = frag_L[qi];
+            }
+        }
+
+        if (threadIdx.x == 0 && step == timestep_ - kSliceLen) {
+            // printf("frag_L'[%d] = %f\n", head_idx_, (float)frag_L[0]);
+        }
+
+        /////////////////////////////////////////////////////////////////////////////////////////
+        /// Compute O[H,D] = P[H,S] * V[S,D]
+        VecKvFloat frag_O[kHeadPerCta][kKvVecPerThread]{};  // value initialize
+                                                            // float      frag_Pr_buf[2][kHeadPerCta];
+
+        // ti = step + offset.y;
+
+        // int ti = step + offset.y;
+
+        // PRAGMA_UNROLL
+        // for (int qi = 0; qi < kHeadPerCta; ++qi) {
+        //     // prefetch Pr for first warp iter
+        //     frag_Pr_buf[0][qi] = smem_P_[qi * kSliceLen + ti];
+        // }
+
+        IterKv iter_V;
+
+        if constexpr (kUseBlockIter) {
+            iter_V = {v_cache_ptrs_,
+                      params_.kv_cache_block_size,
+                      params_.layer_offset,
+                      kv_head_idx_,
+                      smem_Kv_,
+                      step,
+                      step + iter_length,
+                      warp_id_,
+                      lane_id_};
+        }
+        else {
+            iter_V = {v_cache_, smem_Kv_, step, step + iter_length, warp_id_, lane_id_};
+        }
+
+        PrefetchKvCache(iter_V);
+        CpAsyncWait();
+
+        iter_V.Load(state.frag_Kv_tmp_buf[0]);
+        PRAGMA_UNROLL
+        for (int vi = 0; vi < kKvVecPerThread; ++vi) {
+            state.frag_V_buf[0][vi] = conv_v_(state.frag_Kv_tmp_buf[0][vi]);
+        }
+
+        iter_V.PrefetchBatch(0, kPrefetchCount);
+        if (kKvKeyPerThread == 1) {
+            CpAsyncCommit();
+            CpAsyncWait();
+            iter_V.AdvancePrefetchStage();
+            iter_V.AdvanceComputeStage();
+        }
+
+        PRAGMA_NO_UNROLL
+        for (int _it = 0; _it < iter_length; _it += kKeyPerIter) {
+            PRAGMA_UNROLL
+            for (int si = 0; si < kKvKeyPerThread; ++si) {
+                const int next = (si + 1) % 2;
+                // Load value cache for next warp iter
+                iter_V.Load(state.frag_Kv_tmp_buf[next]);
+                PRAGMA_UNROLL
+                for (int vi = 0; vi < kKvVecPerThread; ++vi) {
+                    state.frag_V_buf[next][vi] = conv_v_(state.frag_Kv_tmp_buf[next][vi]);
+                }
+
+                // Load Pr for next warp iter
+                // PRAGMA_UNROLL
+                // for (int qi = 0; qi < kHeadPerCta; ++qi) {
+                //     frag_Pr_buf[(si + 1) % 2][qi] = smem_P_[qi * kSliceLen + (ti + MapKv::kWarpAccessS)];
+                // }
+
+                auto& frag_V = state.frag_V_buf[si % 2];
+                // auto& frag_P = frag_Pr_buf[si % 2];
+
+                const int local_offset = offset.y + _it + si * MapKv::kWarpAccessS;
+
+                float frag_P[kHeadPerCta];
+                PRAGMA_UNROLL
+                for (int qi = 0; qi < kHeadPerCta; ++qi) {
+                    frag_P[qi] = smem_P_[qi * kSliceLen + local_offset];
+                }
+
+                if (step + local_offset < timestep_) {
+                    PRAGMA_UNROLL
+                    for (int qi = 0; qi < kHeadPerCta; ++qi) {
+                        fma_pv(frag_P[qi], frag_V, frag_O[qi]);
+                    }
+                    // for (int i = 0; i < kKvVecPerThread; ++i) {
+                    //     for (int j = 0; j < kVecKvSize; ++j) {
+                    //         printf("frag_V %f\n", (float)frag_V[i][j]);
+                    //     }
+                    // }
+                    // if (threadIdx.x % MapKv::kWarpThreadC == 0) {
+                    //     printf("frag_P[%d] %f\n", ti, frag_P[0]);
+                    // }
+                }
+
+                iter_V.PrefetchBatch((si + 1) % kKvKeyPerThread, kPrefetchCount);
+
+                if (kKvKeyPerThread == 1 || si == kKvKeyPerThread - 2) {
+                    CpAsyncCommit();
+                    CpAsyncWait();
+                    iter_V.AdvancePrefetchStage();
+                    iter_V.AdvanceComputeStage();
+                }
+            }
+
+            // handle special case
+            if (kKvKeyPerThread == 1) {
+                for (int vi = 0; vi < kKvVecPerThread; ++vi) {
+                    state.frag_V_buf[0][vi] = state.frag_V_buf[1][vi];
+                }
+                // PRAGMA_UNROLL
+                // for (int qi = 0; qi < kHeadPerCta; ++qi) {
+                //     frag_Pr_buf[0][qi] = frag_Pr_buf[1][qi];
+                // }
+            }
+        }
+
+        /// warp reduce over S dim
+        PRAGMA_UNROLL
+        for (int qi = 0; qi < kHeadPerCta; ++qi) {
+            PRAGMA_UNROLL
+            for (int vi = 0; vi < kKvVecPerThread; ++vi) {
+                PRAGMA_UNROLL
+                for (int i = 0; i < kVecKvSize; ++i) {
+                    // reduce over warp thread S
+                    PRAGMA_UNROLL
+                    for (int mask = WARP_SIZE / 2; mask >= MapKv::kWarpThreadC; mask /= 2) {
+                        frag_O[qi][vi][i] += __shfl_xor_sync(uint32_t(-1), frag_O[qi][vi][i], mask);
+                    }
+                }
+            }
+        }
+
+        // __syncthreads();
+
+        PRAGMA_UNROLL
+        for (int gi = 0; gi < MapKv::kS; gi += MapKv::kFootprintS) {
+            PRAGMA_UNROLL
+            for (int qi = 0; qi < kHeadPerCta; ++qi) {
+                PRAGMA_UNROLL
+                for (int vi = 0; vi < kKvVecPerThread; ++vi) {
+                    if (offset.y == gi) {
+                        // bank conflict
+                        auto& smem_O = (VecKvFloat&)smem_O_[qi * kMaxHeadDim + offset.x + vi * MapKv::kDeltaC];
+                        using namespace ops;
+                        auto tmp_O = smem_O;
+                        if (offset.y == 0) {
+                            tmp_O = tmp_O * exp_M_diff[qi];
+                        }
+                        // bank conflict
+                        smem_O = tmp_O + frag_O[qi][vi];
+                    }
+                }
+            }
+            __syncthreads();
+        }
+
+        CpAsyncFlush();
+    }
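+
+    // Reading aid for ComputeSlice(): it performs one slice of the online-softmax recurrence
+    //   M_new = max(M_old, max_i s_i)                                 (qk_max block reduction)
+    //   L_new = exp(M_old - M_new) * L_old + sum_i exp(s_i - M_new)   (blockSum + rescale)
+    //   O_new = exp(M_old - M_new) * O_old + sum_i exp(s_i - M_new) * V_i
+    // with s_i = (Q . K_i) * inv_sqrt_dh. The division by L is deferred to Epilogue() (or to
+    // Reduce() when split-k is active).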
+
+    __device__ void LoopKv()
+    {
+        const int2 offset = MapKv::get_offset(warp_id_, lane_id_);
+
+        ///////////////////////////////////////////////////////////////////////////////////////////
+        /// Load Q from shared memory.
+        /// NOTE: There will be bank-conflict when sizeof(VecKv) > 16 (e.g. KV is quantized)
+        FragmentQ frag_Q;
+
+        PRAGMA_UNROLL
+        for (int qi = 0; qi < kHeadPerCta; ++qi) {
+            PRAGMA_UNROLL
+            for (int c = 0; c < kKvVecPerThread; ++c) {
+                const int di       = offset.x + MapKv::kDeltaC * c;
+                frag_Q.data[qi][c] = (VecKv&)smem_Q_[qi * kMaxHeadDim + di];
+            }
+        }
+
+        State state;
+
+        PRAGMA_NO_UNROLL
+        for (int step = step_begin_; step < step_end_; step += kSliceLen) {
+            int iter_count = min(step_end_ - step, kSliceLen);
+            ComputeSlice(frag_Q, state, offset, step, iter_count);
+        }
+    }
+
+    __device__ void Run()
+    {
+        if constexpr (0) {
+            for (int i = threadIdx.x; i < kStages * IterKv::kSizePerTile; i += blockDim.x) {
+                smem_Kv_[i] = T(0);
+            }
+            __syncthreads();
+        }
+
+        // early exit if this split is out of bounds
+        if (kSplitK && step_begin_ && step_begin_ >= step_end_) {
+            return;
+        }
+
+        // early exit if finished flag is set
+        if (params_.finished[batch_idx_]) {
+            return;
+        }
+
+        // Compute attention for current step
+        Prolugue();
+
+        __syncthreads();
+
+        // Iterate over K/V
+        LoopKv();
+
+        __syncthreads();
+
+        // Normalize outputs & write to device memory
+        Epilogue();
+    }
+
+    __device__ void Epilogue()
+    {
+        static constexpr int kVecQSize = kMaxHeadDim / WARP_SIZE;
+
+        using VecQFloat = Array<float, kVecQSize>;
+
+        using MapQ = ThreadMapQ;
+
+        static constexpr int kQkvHeadPerThread = MapQ::kIterS;
+
+        int2 offset = MapQ::get_offset(warp_id_, lane_id_);
+
+        if (offset.x >= kMaxHeadDim || offset.y >= kHeadPerCta) {
+            return;
+        }
+
+        using namespace ops;
+
+        if (!kSplitK || (step_begin_ == 0 && step_end_ == timestep_)) {  // non-split-k
+            PRAGMA_UNROLL
+            for (int s = 0; s < kQkvHeadPerThread; ++s) {
+                const int di = offset.x;
+                const int qi = offset.y + s;
+
+                const float     scale  = __fdividef(1.f, smem_L_[qi] + 1e-8f);
+                const VecQFloat frag_O = (VecQFloat&)smem_O_[qi * kMaxHeadDim + di] * scale;
+
+                Store(&params_.out[batch_idx_ * params_.num_heads * kHeadDim + (head_idx_ + qi) * kHeadDim + di],
+                      cast<T>(frag_O));
+            }
+        }
+        else {
+            PRAGMA_UNROLL
+            for (int s = 0; s < kQkvHeadPerThread; ++s) {  // split-k
+                const int di = offset.x;
+                const int qi = offset.y + s;
+
+                const VecQFloat frag_O = (VecQFloat&)smem_O_[qi * kMaxHeadDim + di];
+
+                // [B, H, k, D]
+                const int index = batch_idx_ * params_.num_heads * params_.max_split_k
+                                  + (head_idx_ + qi) * params_.max_split_k + get_split_k_idx();
+                Store(&params_.partial_O[index * kHeadDim + di], cast<float>(frag_O));
+
+                if (di == 0) {
+                    params_.partial_M[index] = smem_M_[qi];
+                    params_.partial_L[index] = smem_L_[qi];
+                }
+            }
+        }
+    }
+
+    static __device__ void Reduce(const ParamType& params)
+    {
+        const int batch_idx       = get_batch_idx();
+        const int head_idx        = get_head_idx();
+        const int timestep        = params.context_length[batch_idx] - 1;
+        const int max_split_k     = params.max_split_k;
+        const int slice_count     = get_slice_count(timestep);
+        const int slice_per_split = (slice_count + max_split_k - 1) / max_split_k;
+        const int split_k         = (slice_count + slice_per_split - 1) / slice_per_split;
+
+        if (split_k == 1) {
+            return;
+        }
+
+        // [B, H, k, D]
+        const int index = batch_idx * params.num_heads * max_split_k + head_idx * max_split_k + threadIdx.x;
+
+        __shared__ float smem_global_M;
+        __shared__ float smem_global_L;
+        __shared__ __align__(16) float smem_expdiff_M[WARP_SIZE];
+        __shared__ __align__(16) float smem_scale_O[WARP_SIZE];
+
+        {
+            float global_M = threadIdx.x < split_k ? params.partial_M[index] : -std::numeric_limits<float>::infinity();
+            PRAGMA_UNROLL
+            for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+                global_M = fmaxf(global_M, __shfl_xor_sync((uint32_t)-1, global_M, mask));
+            }
+
+            if (threadIdx.x == 0) {
+                smem_global_M = global_M;
+            }
+        }
+
+        __syncthreads();
+
+        {
+            float global_L = threadIdx.x < split_k ? params.partial_L[index] : 0.f;
+
+            if (threadIdx.x < split_k) {
+                auto expdiff_M = expf(params.partial_M[index] - smem_global_M);
+                global_L *= expdiff_M;
+                smem_expdiff_M[threadIdx.x] = expdiff_M;
+            }
+
+            PRAGMA_UNROLL
+            for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+                global_L += __shfl_xor_sync((uint32_t)-1, global_L, mask);
+            }
+
+            if (threadIdx.x == 0) {
+                smem_global_L = global_L;
+            }
+        }
+
+        __syncthreads();
+
+        if (threadIdx.x < split_k) {
+            smem_scale_O[threadIdx.x] = smem_expdiff_M[threadIdx.x] / (smem_global_L + 1e-8f);
+        }
+
+        __syncthreads();
+
+        int   idx = (batch_idx * params.num_heads * max_split_k + head_idx * max_split_k) * kHeadDim + threadIdx.x;
+        float accum_O{};
+
+        const bool is_valid = threadIdx.x < kHeadDim;
+
+        for (int k = 0; k < split_k; ++k) {
+            if (is_valid) {
+                accum_O += smem_scale_O[k] * params.partial_O[idx];
+            }
+            idx += kHeadDim;
+        }
+        if (is_valid) {
+            params.out[batch_idx * params.num_heads * kHeadDim + head_idx * kHeadDim + threadIdx.x] = (T)accum_O;
+        }
+    }
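+
+    // Reading aid for Reduce(): with M_k, L_k, O_k denoting the per-split partials written by
+    // Epilogue(), the merged output is
+    //   M = max_k M_k,   L = sum_k exp(M_k - M) * L_k,   out = sum_k (exp(M_k - M) / L) * O_k
+    // which matches what a single non-split pass would have produced (up to the 1e-8 epsilon).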
+
+    static __device__ int get_slice_count(int timestep)
+    {
+        return (timestep + kSliceLen - 1) / kSliceLen;
+    }
+
+    static __device__ int get_head_idx()
+    {
+        return blockIdx.x;
+    }
+
+    static __device__ int get_batch_idx()
+    {
+        return blockIdx.y;
+    }
+
+    static __device__ int get_split_k_idx()
+    {
+        return blockIdx.z;
+    }
+};
+
+extern __shared__ uint8_t dynamic_smem[];
+
+template<typename MHAType, typename ParamType>
+__global__ void decoder_multihead_attention(ParamType params)
+{
+    __shared__ typename MHAType::SharedStorage shared_storage;
+
+    uint8_t* smem_ptr = dynamic_smem;
+
+    MHAType{params, shared_storage, smem_ptr}.Run();
+}
+
+template<typename MHAType, typename ParamType>
+__global__ void decoder_multihead_attention_reduce(ParamType params)
+{
+    MHAType::Reduce(params);
+}
+
+}  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_multihead_attention/iterator.h b/src/turbomind/kernels/decoder_multihead_attention/iterator.h
new file mode 100644
index 0000000000..683d95b589
--- /dev/null
+++ b/src/turbomind/kernels/decoder_multihead_attention/iterator.h
@@ -0,0 +1,333 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+
+#include "../gemm_s_f16/common.h"
+#include "array_ops.h"
+
+namespace turbomind {
+
+#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4))
+#define L2_CACHEHINT(size) ".L2::" #size "B"
+#else
+#define L2_CACHEHINT(size)
+#endif
+
+struct BlockIterator {
+    const void** ptrs_;
+    const void*  prefetch_;
+
+    BlockIterator() = default;
+
+    __device__ BlockIterator(const void** block_ptrs): ptrs_{block_ptrs}
+    {
+        // prefetch first ptr
+        prefetch_ = *ptrs_++;
+    }
+
+    __device__ const void* Next()
+    {
+        // return prefetched ptr
+        const void* ret = prefetch_;
+        // prefetch next ptr
+        prefetch_ = *ptrs_++;
+
+        return ret;
+    }
+};
+
+template<typename T, typename ThreadMap, int Stages, bool kUseBlockIter>
+struct Iterator {
+
+    using ElementType = T;
+    using AccessType  = Array;
+
+    static constexpr int kElementSize = sizeof(ElementType);
+    static constexpr int kAccessSize  = sizeof(AccessType);
+
+    static constexpr int kSizePerTile  = ThreadMap::kS * ThreadMap::kC;
+    static constexpr int kSmemByteSize = kElementSize * Stages * kSizePerTile;
+
+    BlockIterator block_iterator_;
+
+    static constexpr int kIterCount = ThreadMap::kIterS * ThreadMap::kIterC;
+
+    static constexpr int kStepC = ThreadMap::kDeltaC;
+    static constexpr int kStepS = ThreadMap::kDeltaS * ThreadMap::kC - ThreadMap::kIterC * kStepC;
+    static constexpr int kStepK =
+        ThreadMap::kS * ThreadMap::kC - ThreadMap::kIterS * ThreadMap::kDeltaS * ThreadMap::kC;
+
+    // (C, S, K) = (64, 384, 1536)
+
+    // initial offset, used to reset src_offset when switching to a new block
+    int init_offset_;
+
+    int src_offset_;
+    int dst_offset_;
+
+    int iter_c_;
+    int iter_b_;
+
+    int  seq_len_;
+    int  offset_s_;
+    bool is_valid_s_;
+
+    int block_size_;
+    int block_k_;
+    int layer_offset_;
+
+    int head_idx_;
+
+    const T* __restrict__ src_;
+    T* __restrict__ smem_;
+
+    int smem_read_offset_;
+
+    struct __align__(sizeof(AccessType)) SharedStorage
+    {
+        T smem_[Stages][kSizePerTile];
+    };
+
+    Iterator() = default;
+
+    __device__ Iterator(T* src, T* smem, int step, int seq_len, int warp_id, int lane_id)
+    {
+        src_  = src;
+        smem_ = smem;
+
+        int2 init_offset_cs = ThreadMap::get_offset(warp_id, lane_id);
+
+        init_offset_ = init_offset_cs.x + init_offset_cs.y * ThreadMap::kC;
+
+        src_offset_       = init_offset_ + step * ThreadMap::kC;
+        dst_offset_       = init_offset_;
+        smem_read_offset_ = init_offset_;
+
+        iter_c_ = 0;
+        iter_b_ = 0;
+
+        seq_len_    = seq_len;
+        offset_s_   = init_offset_cs.y + step;
+        is_valid_s_ = offset_s_ < seq_len;
+    }
+
+    __device__ Iterator(const void** block_ptrs,
+                        int          block_size,
+                        int          layer_offset,
+                        int          head_idx,
+                        T*           smem,
+                        int          step,
+                        int          seqlen,
+                        int          warp_id,
+                        int          lane_id)
+    {
+        // src_  = src;
+        int block_index = step / block_size;
+        block_size_     = block_size;
+        block_k_        = (block_index + 1) * block_size - step;  // offset to next block
+        layer_offset_   = layer_offset;
+        head_idx_       = head_idx;
+
+        block_iterator_ = BlockIterator(block_ptrs + block_index);
+
+        src_ = (const T*)block_iterator_.Next() + layer_offset_ + head_idx_ * block_size_ * ThreadMap::kC;
+
+        smem_ = smem;
+
+        int2 init_offset_cs = ThreadMap::get_offset(warp_id, lane_id);
+
+        init_offset_ = init_offset_cs.x + init_offset_cs.y * ThreadMap::kC;
+
+        src_offset_       = init_offset_ + (step - block_index * block_size) * ThreadMap::kC;
+        dst_offset_       = init_offset_;
+        smem_read_offset_ = init_offset_;
+
+        iter_c_ = 0;
+        iter_b_ = 0;
+
+        seq_len_    = seqlen;
+        offset_s_   = init_offset_cs.y + step;
+        is_valid_s_ = offset_s_ < seqlen;
+    }
+
+    __device__ void PrefetchStage()
+    {
+        PRAGMA_UNROLL
+        for (int i = 0; i < kIterCount; ++i) {
+            Prefetch(is_valid_s_);
+            ++(*this);
+        }
+        AdvancePrefetchStage();
+    }
+
+    __device__ void PrefetchBatch(int batch_idx, int batch_size)
+    {
+        PRAGMA_UNROLL
+        for (int i = 0; i < batch_size; ++i) {
+            if (batch_idx * batch_size + i < kIterCount) {
+                Prefetch(is_valid_s_);
+                ++(*this);
+            }
+        }
+    }
+
+    __device__ Iterator& operator++()
+    {
+        src_offset_ += kStepC;
+        dst_offset_ += kStepC;
+        ++iter_c_;
+        if (iter_c_ < ThreadMap::kIterC) {
+            return *this;
+        }
+
+        iter_c_ = 0;
+        src_offset_ += kStepS;
+        dst_offset_ += kStepS;
+
+        offset_s_ += ThreadMap::kDeltaS;
+        is_valid_s_ = offset_s_ < seq_len_;
+
+        return *this;
+    }
+
+    __device__ void AdvancePrefetchStage()
+    {
+        src_offset_ += kStepK;
+        dst_offset_ += kStepK;
+
+        offset_s_ += ThreadMap::kS - ThreadMap::kIterS * ThreadMap::kDeltaS;
+
+        is_valid_s_ = offset_s_ < seq_len_;
+
+        if constexpr (kUseBlockIter) {
+            if (is_valid_s_) {
+                block_k_ -= ThreadMap::kS;
+                if (block_k_ == 0) {
+                    src_ = (const T*)block_iterator_.Next() + layer_offset_ + head_idx_ * block_size_ * ThreadMap::kC;
+                    block_k_    = block_size_;
+                    src_offset_ = init_offset_;
+                }
+            }
+            // if (blockIdx.x == 0 && threadIdx.x == 0) {
+            //     printf("%d %d %d\n", offset_s_, src_offset_ / ThreadMap::kC, block_k_);
+            // }
+        }
+
+        // if (init_offset_ / ThreadMap::kC == 0) {
+        //     int k = dst_offset_ / (ThreadMap::kS * ThreadMap::kC);
+        //     int s = dst_offset_ % (ThreadMap::kS * ThreadMap::kC) / ThreadMap::kC;
+        //     int c = dst_offset_ % ThreadMap::kC;
+        //     printf("tid=%d, k=%d, s=%d, c=%d, offset_s=%d, valid_s=%d, init_s=%d\n",
+        //            threadIdx.x,
+        //            k,
+        //            s,
+        //            c,
+        //            offset_s_,
+        //            (int)is_valid_s_,
+        //            init_offset_ / ThreadMap::kC);
+        // }
+
+        // if (threadIdx.x == 0 && blockIdx.x == 0) {
+        //     printf("next stage %d\n", offset_s_);
+        // }
+
+        if (dst_offset_ >= Stages * kSizePerTile) {
+            dst_offset_ -= Stages * kSizePerTile;
+        }
+
+        // if constexpr (Chained) {
+        //     bool is_last_stage = *signal_iterator_;
+
+        //     ++signal_iterator_;
+
+        //     if (is_last_stage) {
+        //         AdvancePrefetchSlice();
+        //     }
+        // }
+    }
+
+#if 0
+    __device__ void AdvancePrefetchSlice()
+    {
+        src_        = (const T*)block_iterator_.Next();
+        src_offset_ = init_offset_;
+
+        ++iter_b_;
+        offset_s_   = iter_b_ / 2 * BlockLen + init_offset_ / ThreadMap::kC;
+        is_valid_s_ = offset_s_ < seq_len_;
+    }
+#endif
+
+    static __device__ void CpAsync(T* __restrict__ dst, const T* __restrict__ src, bool mask)
+    {
+        const int     smem_int_ptr = cast_smem_ptr_to_uint(dst);
+        constexpr int cp_size      = sizeof(AccessType);
+#if TURBOMIND_ARCH_SM80
+        // clang-format off
+        asm volatile("{\n"
+                     "  .reg .pred p;\n"
+                     "  setp.ne.b32 p, %0, 0;\n"
+                     "  @p cp.async.ca.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;\n"
+                     "}\n" ::"r"((int)mask),
+                     "r"(smem_int_ptr),
+                     "l"(src),
+                     "n"(cp_size));
+        // clang-format on
+#else
+        assert(TURBOMIND_ARCH_SM80);
+#endif
+    }
+
+    static __device__ void Copy(T* __restrict__ dst, const T* __restrict__ src, bool mask)
+    {
+        if (mask) {
+            Ldg(*(AccessType*)dst, src);
+        }
+    }
+
+    __device__ void Prefetch(bool mask)
+    {
+        if constexpr (TURBOMIND_ARCH_SM80) {
+            CpAsync(smem_ + dst_offset_, src_ + src_offset_, mask);
+        }
+        else {
+            Copy(smem_ + dst_offset_, src_ + src_offset_, mask);
+        }
+    }
+
+    __device__ void Load(AccessType (&frag)[ThreadMap::kIterC])
+    {
+
+        // if (init_offset_ / ThreadMap::kC == 0) {
+        //     int k = smem_read_offset_ / (ThreadMap::kS * ThreadMap::kC);
+        //     int s = smem_read_offset_ % (ThreadMap::kS * ThreadMap::kC) / ThreadMap::kC;
+        //     int c = smem_read_offset_ % ThreadMap::kC;
+        //     printf("tid=%d, k=%d, s=%d, c=%d, init_s=%d\n", threadIdx.x, k, s, c, init_offset_ / ThreadMap::kC);
+        // }
+
+        for (int vi = 0; vi < ThreadMap::kIterC; ++vi) {
+
+            // int offset = smem_read_offset_ + vi * ThreadMap::kDeltaC;
+            // if (offset >= Stages * kSizePerTile || offset % sizeof(AccessType)) {
+            //     int c = offset % ThreadMap::kC;
+            //     int s = offset / ThreadMap::kC;
+            //     printf("%d %d %d\n", c, s, offset);
+            // }
+
+            Lds(frag[vi], smem_ + smem_read_offset_ + vi * ThreadMap::kDeltaC);
+        }
+
+        smem_read_offset_ += ThreadMap::kDeltaS * ThreadMap::kC;
+    }
+
+    __device__ void AdvanceComputeStage()
+    {
+        smem_read_offset_ += kStepK;
+
+        if (smem_read_offset_ >= Stages * kSizePerTile) {
+            smem_read_offset_ -= Stages * kSizePerTile;
+        }
+    }
+};
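+
+// How this iterator is typically driven (a sketch of how ComputeSlice() in the decoder
+// attention kernel uses it, not an additional API contract):
+//
+//   Iterator it{...};                          // global -> smem prefetch iterator
+//   for (int s = 0; s < Stages - 1; ++s) {     // fill the software pipeline
+//       it.PrefetchStage();
+//       __pipeline_commit();
+//   }
+//   __pipeline_wait_prior(Stages - 2);         // oldest stage is now resident in smem
+//   it.Load(frag);                             // smem -> registers for the current tile
+//   it.PrefetchBatch(b, n);                    // issue the next batch of cp.async copies
+//   __pipeline_commit();
+//   __pipeline_wait_prior(Stages - 2);
+//   it.AdvancePrefetchStage();                 // wrap the write offset to the next tile
+//   it.AdvanceComputeStage();                  // wrap the smem read offset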
+
+}  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu
new file mode 100644
index 0000000000..d9a46c40a7
--- /dev/null
+++ b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu
@@ -0,0 +1,481 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include "../gemm_s_f16/common.h"
+#include "src/turbomind/kernels/decoder_multihead_attention/array_ops.h"
+#include "src/turbomind/models/llama/llama_utils.h"
+#include "src/turbomind/utils/debug_utils.h"
+#include 
+#include 
+
+namespace turbomind {
+
+// [S/x, H, x, D] <-> [S/y, H, y, D]
+
+template<typename Tin,
+         typename Tout,
+         typename SrcBlockLen,
+         typename DstBlockLen,
+         typename HeadDim,
+         typename Transform = ConvertKvCache<Tin, Tout>>
+__inline__ __device__ void ConvertBlockSize(const Tin** __restrict__ src_block_ptrs,
+                                            Tout** __restrict__ dst_block_ptrs,
+                                            const int* __restrict__ src_cu_block_cnts,
+                                            const int* __restrict__ dst_cu_block_cnts,
+                                            const int* __restrict__ seq_lens,
+                                            int         src_offset,
+                                            int         dst_offset,
+                                            SrcBlockLen src_block_len,
+                                            DstBlockLen dst_block_len,
+                                            HeadDim     head_dim,
+                                            Transform   transform = {1.f, 0.f})
+{
+    constexpr int kVecSize = sizeof(uint4) / std::max(sizeof(Tin), sizeof(Tout));
+
+    const int hi = blockIdx.y;
+    const int bi = blockIdx.z;
+
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    /// TODO: use cutlass fast div/mod
+    const int di = idx * kVecSize % head_dim;
+    const int si = idx * kVecSize / head_dim;
+
+    if (si >= seq_lens[bi]) {
+        return;
+    }
+
+    // compute indices into src
+    int src_block_index  = si / src_block_len + src_cu_block_cnts[bi];
+    int src_block_offset = src_offset + hi * src_block_len * head_dim + si % src_block_len * head_dim + di;
+
+    // compute indices into dst
+    int dst_block_index  = si / dst_block_len + dst_cu_block_cnts[bi];
+    int dst_block_offset = dst_offset + hi * dst_block_len * head_dim + si % dst_block_len * head_dim + di;
+
+    // printf("%d %d\n", src_block_index, dst_block_index);
+
+    const Tin* __restrict__ src_block = src_block_ptrs[src_block_index];
+    Tout* __restrict__ dst_block      = dst_block_ptrs[dst_block_index];
+
+    // uint4 data = __ldg(reinterpret_cast(src_block + src_block_offset));
+
+    Array<Tin, kVecSize> src_vec;
+    Ldg(src_vec, src_block + src_block_offset);
+
+    Array<Tout, kVecSize> dst_vec = transform(src_vec);
+    Store(dst_block + dst_block_offset, dst_vec);
+
+    // *reinterpret_cast(dst_block + dst_block_offset) = data;
+}
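+
+// Index bookkeeping above, spelled out for one element (illustrative numbers only): with
+// head_dim = 128, src_block_len = 128, dst_block_len = 64 and the element (bi, hi, si = 130, di = 8),
+// the source lives in per-sequence block si / 128 = 1 at offset
+//   src_offset + hi * 128 * 128 + (130 % 128) * 128 + 8,
+// and the destination in per-sequence block si / 64 = 2 at offset
+//   dst_offset + hi * 64 * 128 + (130 % 64) * 128 + 8;
+// src_cu_block_cnts / dst_cu_block_cnts then convert these per-sequence block indices into
+// indices into the flat block-pointer arrays.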
+
+template<typename T>
+__global__ void LinearToBlocksKernel(const T*   src,
+                                     T**        dst_block_ptrs,
+                                     const int* dst_cu_block_cnts,
+                                     const int* seq_lens,
+                                     int        dst_offset,
+                                     int        src_block_len,
+                                     int        dst_block_len,
+                                     int        head_num,
+                                     int        head_dim,
+                                     int        batch_size)
+{
+    extern __shared__ void* smem[];
+
+    const T** src_block_ptrs    = (const T**)smem;
+    int*      src_cu_block_cnts = (int*)(src_block_ptrs + batch_size);
+
+    for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
+        src_cu_block_cnts[i] = i;
+        src_block_ptrs[i]    = src + blockIdx.z * head_num * src_block_len * head_dim;
+    }
+
+    __syncthreads();
+
+    ConvertBlockSize(src_block_ptrs,
+                     dst_block_ptrs,
+                     src_cu_block_cnts,
+                     dst_cu_block_cnts,
+                     seq_lens,
+                     0,
+                     dst_offset,
+                     src_block_len,
+                     dst_block_len,
+                     head_dim);
+}
+
+template<typename T>
+void ConvertLinearToBlocks(const T*     src,
+                           T**          dst_block_ptrs,
+                           const int*   dst_cu_block_cnts,
+                           const int*   seq_lens,
+                           int          dst_offset,
+                           int          src_max_len,
+                           int          dst_block_len,
+                           int          head_num,
+                           int          head_dim,
+                           int          batch_size,
+                           cudaStream_t st)
+{
+    constexpr int kVecSize = sizeof(uint4) / sizeof(T);
+
+    constexpr int threads = 128;
+    const dim3    blocks((src_max_len * head_dim / kVecSize + threads - 1) / threads, head_num, batch_size);
+
+    const auto smem_sz = (sizeof(void*) + sizeof(int)) * batch_size;
+
+    auto fn = [&](auto head_dim) {
+        LinearToBlocksKernel<<<blocks, threads, smem_sz, st>>>(src,
+                                                               dst_block_ptrs,
+                                                               dst_cu_block_cnts,
+                                                               seq_lens,
+                                                               dst_offset,
+                                                               src_max_len,
+                                                               dst_block_len,
+                                                               head_num,
+                                                               head_dim,
+                                                               batch_size);
+    };
+
+    switch (head_dim) {
+        case 128:
+            fn(std::integral_constant<int, 128>{});
+            break;
+        default:
+            fn(head_dim);
+    }
+}
+
+template void ConvertLinearToBlocks(const half*  src,
+                                    half**       dst_block_ptrs,
+                                    const int*   dst_cu_block_cnts,
+                                    const int*   seq_lens,
+                                    int          dst_offset,
+                                    int          src_seq_len,
+                                    int          dst_block_len,
+                                    int          head_num,
+                                    int          head_dim,
+                                    int          batch_size,
+                                    cudaStream_t st);
+
+template<typename T, typename HeadDim>
+__global__ void BlocksToLinearKernel(const T**  src_block_ptrs,
+                                     T*         dst,
+                                     const int* src_cu_block_cnts,
+                                     const int* seq_lens,
+                                     int        src_offset,
+                                     int        src_block_len,
+                                     int        dst_block_len,
+                                     int        head_num,
+                                     HeadDim    head_dim,
+                                     int        batch_size)
+{
+    extern __shared__ void* smem[];
+
+    T**  dst_block_ptrs    = (T**)smem;
+    int* dst_cu_block_cnts = (int*)(dst_block_ptrs + batch_size);
+
+    for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
+        dst_cu_block_cnts[i] = i;
+        dst_block_ptrs[i]    = dst + blockIdx.z * head_num * dst_block_len * head_dim;
+    }
+
+    __syncthreads();
+
+    ConvertBlockSize(src_block_ptrs,
+                     dst_block_ptrs,
+                     src_cu_block_cnts,
+                     dst_cu_block_cnts,
+                     seq_lens,
+                     src_offset,
+                     0,
+                     src_block_len,
+                     dst_block_len,
+                     head_dim);
+}
+
+template<typename T>
+void ConvertBlocksToLinear(const T**    src_block_ptrs,
+                           T*           dst,
+                           const int*   src_cu_block_cnts,
+                           const int*   seq_lens,
+                           int          src_offset,
+                           int          src_block_len,
+                           int          dst_max_len,
+                           int          head_num,
+                           int          head_dim,
+                           int          batch_size,
+                           cudaStream_t st)
+{
+    constexpr int kVecSize = sizeof(uint4) / sizeof(T);
+
+    constexpr int threads = 256;
+    const dim3    blocks((dst_max_len * head_dim / kVecSize + threads - 1) / threads, head_num, batch_size);
+
+    const auto smem_sz = (sizeof(void*) + sizeof(int)) * batch_size;
+
+    auto fn = [&](auto head_dim) {
+        BlocksToLinearKernel<<<blocks, threads, smem_sz, st>>>(src_block_ptrs,
+                                                               dst,
+                                                               src_cu_block_cnts,
+                                                               seq_lens,
+                                                               src_offset,
+                                                               src_block_len,
+                                                               dst_max_len,
+                                                               head_num,
+                                                               head_dim,
+                                                               batch_size);
+    };
+
+    switch (head_dim) {
+        case 128:
+            fn(std::integral_constant<int, 128>{});
+            break;
+        default:
+            fn(head_dim);
+    }
+}
+
+template void ConvertBlocksToLinear(const half** src_block_ptrs,
+                                    half*        dst,
+                                    const int*   src_cu_block_cnts,
+                                    const int*   seq_lens,
+                                    int          src_offset,
+                                    int          src_block_len,
+                                    int          dst_max_seq_len,
+                                    int          head_num,
+                                    int          head_dim,
+                                    int          batch_size,
+                                    cudaStream_t st);
+
+template<typename T, typename SrcBlockLen, typename DstBlockLen, typename HeadDim>
+__global__ void KvCacheBlocksToLinearKernel(const T**   src_k_block_ptrs,
+                                            const T**   src_v_block_ptrs,
+                                            T**         dst_k_ptrs,
+                                            T**         dst_v_ptrs,
+                                            const int*  src_cu_block_cnts,
+                                            const int*  seq_lens,
+                                            int         src_offset,
+                                            SrcBlockLen src_block_len,
+                                            DstBlockLen dst_block_len,
+                                            int         head_num,
+                                            HeadDim     head_dim,
+                                            int         batch_size)
+{
+    extern __shared__ int dst_cu_block_cnts[];
+
+    for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
+        dst_cu_block_cnts[i] = i;
+    }
+
+    __syncthreads();
+
+    ConvertBlockSize(src_k_block_ptrs,
+                     dst_k_ptrs,
+                     src_cu_block_cnts,
+                     dst_cu_block_cnts,
+                     seq_lens,
+                     src_offset,
+                     0,
+                     src_block_len,
+                     dst_block_len,
+                     head_dim);
+
+    ConvertBlockSize(src_v_block_ptrs,
+                     dst_v_ptrs,
+                     src_cu_block_cnts,
+                     dst_cu_block_cnts,
+                     seq_lens,
+                     src_offset,
+                     0,
+                     src_block_len,
+                     dst_block_len,
+                     head_dim);
+}
+
+void ConvertKvCacheBlocksToLinear(const void** src_k_block_ptrs,
+                                  const void** src_v_block_ptrs,
+                                  void**       dst_k_ptrs,
+                                  void**       dst_v_ptrs,
+                                  const int*   src_cu_block_cnts,
+                                  const int*   seq_lens,
+                                  int          src_offset,
+                                  int          src_block_len,
+                                  int          dst_block_len,
+                                  int          head_num,
+                                  int          head_dim,
+                                  int          batch_size,
+                                  int          elem_bits,
+                                  cudaStream_t st)
+{
+    auto fn = [&](auto value) {
+        using T = decltype(value);
+
+        constexpr int kVecSize = sizeof(uint4) / sizeof(T);
+        constexpr int kThreads = 256;
+
+        const dim3 blocks((dst_block_len * head_dim / kVecSize + kThreads - 1) / kThreads, head_num, batch_size);
+        const auto smem_sz = sizeof(int) * batch_size;
+
+        KvCacheBlocksToLinearKernel<<<blocks, kThreads, smem_sz, st>>>((const T**)src_k_block_ptrs,
+                                                                       (const T**)src_v_block_ptrs,
+                                                                       (T**)dst_k_ptrs,
+                                                                       (T**)dst_v_ptrs,
+                                                                       src_cu_block_cnts,
+                                                                       seq_lens,
+                                                                       src_offset,
+                                                                       src_block_len,
+                                                                       dst_block_len,
+                                                                       head_num,
+                                                                       head_dim,
+                                                                       batch_size);
+    };
+
+    switch (elem_bits) {
+        case 8:
+            fn(uint8_t{});
+            break;
+        case 16:
+            fn(uint16_t{});
+            break;
+        case 32:
+            fn(uint32_t{});
+            break;
+        default:
+            fprintf(stderr, "unsupported elem bits: %d\n", elem_bits);
+    }
+}
+
+template<typename Tin,
+         typename Tout,
+         typename SrcBlockLen,
+         typename DstBlockLen,
+         typename HeadDim,
+         typename TransformK,
+         typename TransformV>
+__global__ void KvCacheBlocksToLinearKernel2(const Tin** src_k_block_ptrs,
+                                             const Tin** src_v_block_ptrs,
+                                             Tout**      dst_k_ptrs,
+                                             Tout**      dst_v_ptrs,
+                                             const int*  src_cu_block_cnts,
+                                             const int*  seq_lens,
+                                             int         src_offset,
+                                             SrcBlockLen src_block_len,
+                                             DstBlockLen dst_block_len,
+                                             int         head_num,
+                                             HeadDim     head_dim,
+                                             int         batch_size,
+                                             TransformK  transform_k,
+                                             TransformV  transform_v)
+{
+    extern __shared__ int dst_cu_block_cnts[];
+
+    for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
+        dst_cu_block_cnts[i] = i;
+    }
+
+    __syncthreads();
+
+    ConvertBlockSize(src_k_block_ptrs,
+                     dst_k_ptrs,
+                     src_cu_block_cnts,
+                     dst_cu_block_cnts,
+                     seq_lens,
+                     src_offset,
+                     0,
+                     src_block_len,
+                     dst_block_len,
+                     head_dim,
+                     transform_k);
+
+    ConvertBlockSize(src_v_block_ptrs,
+                     dst_v_ptrs,
+                     src_cu_block_cnts,
+                     dst_cu_block_cnts,
+                     seq_lens,
+                     src_offset,
+                     0,
+                     src_block_len,
+                     dst_block_len,
+                     head_dim,
+                     transform_v);
+}
+
+template<typename T>
+void ConvertKvCacheBlocksToLinear2(const void** src_k_block_ptrs,
+                                   const void** src_v_block_ptrs,
+                                   T**          dst_k_ptrs,
+                                   T**          dst_v_ptrs,
+                                   const int*   src_cu_block_cnts,
+                                   const int*   seq_lens,
+                                   int          src_offset,
+                                   int          src_block_len,
+                                   int          dst_block_len,
+                                   int          head_num,
+                                   int          head_dim,
+                                   int          batch_size,
+                                   int          quant_policy,
+                                   const float* kv_params,
+                                   cudaStream_t st)
+{
+    auto fn = [&](auto tin) {
+        using Tin = decltype(tin);
+
+        constexpr int kVecSize = sizeof(uint4) / sizeof(T);
+        constexpr int kThreads = 256;
+
+        const dim3 blocks((dst_block_len * head_dim / kVecSize + kThreads - 1) / kThreads, head_num, batch_size);
+        const auto smem_sz = sizeof(int) * batch_size;
+
+        KvCacheBlocksToLinearKernel2<<<blocks, kThreads, smem_sz, st>>>(
+            (const Tin**)src_k_block_ptrs,
+            (const Tin**)src_v_block_ptrs,
+            (T**)dst_k_ptrs,
+            (T**)dst_v_ptrs,
+            src_cu_block_cnts,
+            seq_lens,
+            src_offset,
+            src_block_len,
+            dst_block_len,
+            head_num,
+            head_dim,
+            batch_size,
+            ConvertKvCache<Tin, T>{kv_params[0], kv_params[1]},
+            ConvertKvCache<Tin, T>{kv_params[2], kv_params[3]});
+    };
+
+    (quant_policy & QuantPolicy::kCacheKVInt8) ? fn(int8_t{}) : fn(T{});
+}
+
+template void ConvertKvCacheBlocksToLinear2(const void** src_k_block_ptrs,
+                                            const void** src_v_block_ptrs,
+                                            float**      dst_k_ptrs,
+                                            float**      dst_v_ptrs,
+                                            const int*   src_cu_block_cnts,
+                                            const int*   seq_lens,
+                                            int          src_offset,
+                                            int          src_block_len,
+                                            int          dst_block_len,
+                                            int          head_num,
+                                            int          head_dim,
+                                            int          batch_size,
+                                            int          quant_policy,
+                                            const float* kv_params,
+                                            cudaStream_t st);
+
+template void ConvertKvCacheBlocksToLinear2(const void** src_k_block_ptrs,
+                                            const void** src_v_block_ptrs,
+                                            half**       dst_k_ptrs,
+                                            half**       dst_v_ptrs,
+                                            const int*   src_cu_block_cnts,
+                                            const int*   seq_lens,
+                                            int          src_offset,
+                                            int          src_block_len,
+                                            int          dst_block_len,
+                                            int          head_num,
+                                            int          head_dim,
+                                            int          batch_size,
+                                            int          quant_policy,
+                                            const float* kv_params,
+                                            cudaStream_t st);
+
+}  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_multihead_attention/kv_cache.h b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.h
new file mode 100644
index 0000000000..f72d58c135
--- /dev/null
+++ b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.h
@@ -0,0 +1,67 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+
+#include <cuda_runtime.h>
+
+namespace turbomind {
+
+template<typename T>
+void ConvertLinearToBlocks(const T*     src,
+                           T**          dst_block_ptrs,
+                           const int*   dst_cu_block_cnts,
+                           const int*   seq_lens,
+                           int          dst_offset,
+                           int          src_seq_len,
+                           int          dst_block_len,
+                           int          head_num,
+                           int          head_dim,
+                           int          batch_size,
+                           cudaStream_t st);
+
+template<typename T>
+void ConvertBlocksToLinear(const T**    src_block_ptrs,
+                           T*           dst,
+                           const int*   src_cu_block_cnts,
+                           const int*   seq_lens,
+                           int          src_offset,
+                           int          src_block_len,
+                           int          dst_max_seq_len,
+                           int          head_num,
+                           int          head_dim,
+                           int          batch_size,
+                           cudaStream_t st);
+
+void ConvertKvCacheBlocksToLinear(const void** src_k_block_ptrs,
+                                  const void** src_v_block_ptrs,
+                                  void**       dst_k_ptrs,
+                                  void**       dst_v_ptrs,
+                                  const int*   src_cu_block_cnts,
+                                  const int*   seq_lens,
+                                  int          src_offset,
+                                  int          src_block_len,
+                                  int          dst_block_len,  // max{seq_lens}
+                                  int          head_num,
+                                  int          head_dim,
+                                  int          batch_size,
+                                  int          elem_bits,
+                                  cudaStream_t st);
+
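+// Gathers paged KV blocks into per-sequence linear K/V buffers, dequantizing int8 blocks to T when
+// quant_policy enables kCacheKVInt8; kv_params is assumed to hold the K and V dequant factors (two floats each).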
+template<typename T>
+void ConvertKvCacheBlocksToLinear2(const void** src_k_block_ptrs,
+                                   const void** src_v_block_ptrs,
+                                   T**          dst_k_ptrs,
+                                   T**          dst_v_ptrs,
+                                   const int*   src_cu_block_cnts,
+                                   const int*   seq_lens,
+                                   int          src_offset,
+                                   int          src_block_len,
+                                   int          dst_block_len,
+                                   int          head_num,
+                                   int          head_dim,
+                                   int          batch_size,
+                                   int          quant_policy,
+                                   const float* kv_params,
+                                   cudaStream_t st);
+
+}  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
new file mode 100644
index 0000000000..a3cc2568b8
--- /dev/null
+++ b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
@@ -0,0 +1,328 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include "decoder_multihead_attention.h"
+#include "kv_cache.h"
+#include "test_utils.h"
+#include <cmath>
+#include <cstdio>
+#include <iostream>
+#include <thrust/universal_vector.h>
+
+#include <algorithm>
+#include <iomanip>
+#include <numeric>
+#include <random>
+
+using namespace turbomind;
+
+template<typename T>
+T* align(T* ptr, size_t alignment)
+{
+    size_t misalign = (uintptr_t)ptr % alignment;
+    std::cout << "misalignment: " << misalign << "\n";
+    if (misalign) {
+        return (T*)((uint8_t*)ptr + alignment - misalign);
+    }
+    return ptr;
+}
+
+// [S/S, H, S, D] <-> [S/b, H, b, D]
+
+void TestBlocks(thrust::universal_vector<half>&  linear,          // linear data
+                thrust::universal_vector<half>&  _blocks,         // block data
+                thrust::universal_vector<half*>& _ptrs,           // block ptrs
+                thrust::universal_vector<int>&   _cu_block_cnts,  // cumulative block counts
+                int                              head_num,
+                int                              head_dim,
+                int                              block_size,
+                int                              batch_size)
+{
+    int seq_len  = linear.size() / (head_dim * head_num * batch_size);
+    int n_blocks = (seq_len + block_size - 1) / block_size;
+
+    std::cout << "batch_size = " << batch_size << ", seq_len = " << seq_len << ", block_num = " << n_blocks
+              << ", block_size = " << block_size << "\n";
+
+    thrust::universal_vector<half>  blocks(batch_size * n_blocks * head_num * block_size * head_dim);
+    thrust::universal_vector<half*> ptrs(batch_size * n_blocks + 1);  // +1 padding
+
+    std::vector<int> idxs(batch_size * n_blocks);
+    std::iota(idxs.begin(), idxs.end(), 0);
+
+    std::random_device rd;
+    std::mt19937       g(rd());
+    std::shuffle(idxs.begin(), idxs.end(), g);
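+    // Hand out the blocks in shuffled order so the pointer table mimics a non-contiguous paged allocation.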
+
+    for (size_t i = 0; i < idxs.size(); ++i) {
+        ptrs[i] = blocks.data().get() + idxs[i] * head_num * block_size * head_dim;
+    }
+
+    thrust::universal_vector<int> seq_lens(batch_size);
+    thrust::fill(seq_lens.begin(), seq_lens.end(), seq_len);
+
+    std::vector<int>              n_blocks_vec(batch_size + 1, n_blocks);
+    thrust::universal_vector<int> cu_block_cnts(batch_size + 1);
+    std::exclusive_scan(n_blocks_vec.begin(), n_blocks_vec.end(), cu_block_cnts.begin(), 0);
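+    // Every sequence uses the same number of blocks here, so the scan yields cu_block_cnts[i] = i * n_blocks.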
+
+    for (int i = 0; i < 10; ++i) {
+        ConvertLinearToBlocks((const half*)linear.data().get(),
+                              ptrs.data().get(),
+                              cu_block_cnts.data().get(),
+                              seq_lens.data().get(),
+                              0,
+                              seq_len,
+                              block_size,
+                              head_num,
+                              head_dim,
+                              batch_size,
+                              0);
+    }
+    thrust::universal_vector<half> _linear(linear.size());
+
+    for (int i = 0; i < 10; ++i) {
+        ConvertBlocksToLinear((const half**)ptrs.data().get(),
+                              _linear.data().get(),
+                              cu_block_cnts.data().get(),
+                              seq_lens.data().get(),
+                              0,
+                              block_size,
+                              seq_len,
+                              head_num,
+                              head_dim,
+                              batch_size,
+                              0);
+    }
+    cudaDeviceSynchronize();
+
+    if (0) {
+        std::cout << ">>> Compare\n";
+        Compare(_linear.data().get(), linear.data().get(), head_dim, head_dim, batch_size * head_num * seq_len);
+        std::cout << "<<< Compare\n";
+    }
+
+    _blocks.swap(blocks);
+    _ptrs.swap(ptrs);
+    _cu_block_cnts.swap(cu_block_cnts);
+}
+
+int main(int argc, char* argv[])
+{
+
+    DecoderMultiHeadAttentionParams<half> params{};
+
+    constexpr int kHeadNum   = 32;
+    constexpr int kHeadDim   = 128;
+    constexpr int KvHeadNum  = 32;
+    constexpr int kBatchSize = 1;
+    // constexpr int kContextLen = 7306;
+    constexpr int kSequenceLen = 1024;
+    constexpr int kContextLen  = kSequenceLen + 1;
+    constexpr int kBlockSz     = 128;
+    constexpr int kTestIter    = 10;
+    constexpr int kMaxSplitK   = 1;
+
+    RNG rng{};
+
+    thrust::universal_vector<half>  output(kBatchSize * kHeadNum * kHeadDim);
+    thrust::universal_vector<half>  qkv(kBatchSize * (kHeadNum + KvHeadNum * 2) * kHeadDim);
+    thrust::universal_vector<bool>  finished(kBatchSize);
+    thrust::universal_vector<half>  k_cache(kBatchSize * kContextLen * KvHeadNum * kHeadDim);
+    thrust::universal_vector<half>  v_cache(kBatchSize * kContextLen * KvHeadNum * kHeadDim);
+    thrust::universal_vector<int>   context_length(kBatchSize);
+    thrust::universal_vector<int>   sequence_length(kBatchSize);
+    thrust::universal_vector<void*> k_cache_ptrs(kBatchSize);
+    thrust::universal_vector<void*> v_cache_ptrs(kBatchSize);
+
+    thrust::universal_vector<float> partial_M(kBatchSize * kHeadNum * kMaxSplitK);
+    thrust::universal_vector<float> partial_L(kBatchSize * kHeadNum * kMaxSplitK);
+    thrust::universal_vector<float> partial_O(kBatchSize * kHeadNum * kMaxSplitK * kHeadDim);
+
+    rng.GenerateNormal(qkv.data().get(), qkv.size(), 1.f, 0.f);
+
+    if (kSequenceLen) {
+        rng.GenerateNormal(k_cache.data().get(), kBatchSize * KvHeadNum * kContextLen * kHeadDim);
+        rng.GenerateNormal(v_cache.data().get(), kBatchSize * KvHeadNum * kContextLen * kHeadDim);
+
+        cudaMemset2DAsync(k_cache.data().get() + kSequenceLen * kHeadDim,
+                          sizeof(half) * kContextLen * kHeadDim,
+                          0,
+                          sizeof(half) * kHeadDim,
+                          kBatchSize * KvHeadNum);
+        if constexpr (0) {
+            for (int b = 0; b < kBatchSize; ++b) {
+                for (int h = 0; h < KvHeadNum; ++h) {
+                    for (int s = 0; s < kContextLen; ++s) {
+                        for (int d = 0; d < kHeadDim; ++d) {
+                            std::cout << std::setw(7) << std::setprecision(4) << std::fixed
+                                      << (float)k_cache[b * KvHeadNum * kContextLen * kHeadDim
+                                                        + h * kContextLen * kHeadDim + s * kHeadDim + d]
+                                      << " ";
+                        }
+                        std::cout << "\n";
+                    }
+                    std::cout << "\n";
+                }
+                std::cout << "\n";
+            }
+            std::exit(0);
+        }
+
+        cudaMemset2DAsync(v_cache.data().get() + kSequenceLen * kHeadDim,
+                          sizeof(half) * kContextLen * kHeadDim,
+                          0,
+                          sizeof(half) * kHeadDim,
+                          kBatchSize * KvHeadNum);
+    }
+
+    thrust::universal_vector<half>  k_blocks;
+    thrust::universal_vector<half*> k_ptrs;
+    thrust::universal_vector<int>   cu_block_cnts;
+
+    TestBlocks(k_cache, k_blocks, k_ptrs, cu_block_cnts, KvHeadNum, kHeadDim, kBlockSz, kBatchSize);
+
+    thrust::universal_vector<half>  v_blocks;
+    thrust::universal_vector<half*> v_ptrs;
+
+    TestBlocks(v_cache, v_blocks, v_ptrs, cu_block_cnts, KvHeadNum, kHeadDim, kBlockSz, kBatchSize);
+
+    thrust::universal_vector<half>  k_cache_ref = k_cache;
+    thrust::universal_vector<half>  v_cache_ref = v_cache;
+    thrust::universal_vector<half>  output_ref  = output;
+    thrust::universal_vector<void*> k_cache_ref_ptrs(kBatchSize);
+    thrust::universal_vector<void*> v_cache_ref_ptrs(kBatchSize);
+
+    cudaDeviceSynchronize();
+
+    for (int i = 0; i < kBatchSize; ++i) {
+        sequence_length[i]  = kSequenceLen;
+        context_length[i]   = kContextLen;
+        k_cache_ptrs[i]     = k_cache.data().get() + i * k_cache.size() / kBatchSize;
+        v_cache_ptrs[i]     = v_cache.data().get() + i * v_cache.size() / kBatchSize;
+        k_cache_ref_ptrs[i] = k_cache_ref.data().get() + i * k_cache_ref.size() / kBatchSize;
+        v_cache_ref_ptrs[i] = v_cache_ref.data().get() + i * v_cache_ref.size() / kBatchSize;
+
+        // align(k_cache_ptrs[i], 256);
+        // align(v_cache_ptrs[i], 256);
+    }
+
+    // getchar();
+
+    params.out    = output_ref.data().get();
+    params.q      = qkv.data().get();
+    params.k      = params.q + kHeadNum * kHeadDim;
+    params.v      = params.k + KvHeadNum * kHeadDim;
+    params.stride = (kHeadNum + 2 * KvHeadNum) * kHeadDim;
+
+    params.batch_size    = kBatchSize;
+    params.max_seq_len   = kSequenceLen;
+    params.cu_block_cnts = cu_block_cnts.data().get();
+
+    printf("%d %d\n", (int)k_ptrs.size(), (int)v_ptrs.size());
+    params.k_cache_block_ptrs  = (void**)k_ptrs.data().get();
+    params.v_cache_block_ptrs  = (void**)v_ptrs.data().get();
+    params.kv_cache_block_size = kBlockSz;
+
+    params.finished       = finished.data().get();
+    params.context_length = context_length.data().get();
+    params.layer_offset   = 0;
+
+    params.num_heads     = kHeadNum;
+    params.num_kv_heads  = KvHeadNum;
+    params.size_per_head = kHeadDim;
+    params.inv_sqrt_dh   = 1.f / std::sqrt((float)params.size_per_head);
+
+    params.rotary_embedding_dim  = kHeadDim;
+    params.rotary_embedding_base = 10000.f;
+
+    params.partial_L = partial_L.data().get();
+    params.partial_M = partial_M.data().get();
+    params.partial_O = partial_O.data().get();
+
+    params.max_split_k = kMaxSplitK;
+    params.arch        = 80;
+
+    for (int i = 0; i < kTestIter; ++i) {
+        mmha_ft_reference(params,
+                          (half**)k_cache_ref_ptrs.data().get(),
+                          (half**)v_cache_ref_ptrs.data().get(),
+                          sequence_length.data().get(),
+                          kContextLen,
+                          cudaStream_t{});
+    }
+
+    cudaDeviceSynchronize();
+    if (auto err = cudaGetLastError(); err != cudaSuccess) {
+        std::cout << cudaGetErrorString(err) << "\n";
+        return -1;
+    }
+    std::cout << "---------------------------------------------------\n";
+
+    params.out = output.data().get();
+
+    std::vector<thrust::universal_vector<half>> outputs;
+
+    for (int i = 0; i < std::max(kTestIter, 1); ++i) {
+        DispatchDecoderMultiheadAttention(params);
+        if (auto err = cudaGetLastError(); err != cudaSuccess) {
+            std::cout << cudaGetErrorString(err) << "\n";
+            return -1;
+        }
+        if (1) {
+            outputs.push_back(output);
+        }
+    }
+
+    if (1) {
+        ConvertBlocksToLinear((const half**)k_ptrs.data().get(),
+                              k_cache.data().get(),
+                              cu_block_cnts.data().get(),
+                              context_length.data().get(),
+                              0,
+                              kBlockSz,
+                              kContextLen,
+                              KvHeadNum,
+                              kHeadDim,
+                              kBatchSize,
+                              0);
+        ConvertBlocksToLinear((const half**)v_ptrs.data().get(),
+                              v_cache.data().get(),
+                              cu_block_cnts.data().get(),
+                              context_length.data().get(),
+                              0,
+                              kBlockSz,
+                              kContextLen,
+                              KvHeadNum,
+                              kHeadDim,
+                              kBatchSize,
+                              0);
+    }
+
+    cudaDeviceSynchronize();
+
+    if (outputs.size() > 1) {
+        std::cout << "Evaluating consistency..." << std::endl;
+        for (size_t i = 1; i < outputs.size(); ++i) {
+            Compare(outputs[i].data().get(), outputs[0].data().get(), kHeadDim, kHeadDim, kHeadNum);
+        }
+    }
+
+    std::cout << "---------------------------------------------------\n";
+
+    Compare(output.data().get(), output_ref.data().get(), kHeadDim, kHeadDim, kHeadNum, false);
+
+    // [H, S, D]
+
+    Compare(k_cache.data().get() + kSequenceLen * kHeadDim,
+            k_cache_ref.data().get() + kSequenceLen * kHeadDim,
+            kContextLen * kHeadDim,
+            kHeadDim,
+            KvHeadNum);
+
+    Compare(v_cache.data().get() + kSequenceLen * kHeadDim,
+            v_cache_ref.data().get() + kSequenceLen * kHeadDim,
+            kContextLen * kHeadDim,
+            kHeadDim,
+            KvHeadNum);
+
+    return 0;
+}
diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_utils.cu b/src/turbomind/kernels/decoder_multihead_attention/test_utils.cu
new file mode 100644
index 0000000000..098df7f1d4
--- /dev/null
+++ b/src/turbomind/kernels/decoder_multihead_attention/test_utils.cu
@@ -0,0 +1,252 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include "test_utils.h"
+#include <cmath>
+#include <cublas_v2.h>
+#include <curand_kernel.h>
+#include <fstream>
+#include <iostream>
+
+#define _CG_ABI_EXPERIMENTAL
+#include <cooperative_groups.h>
+#include <cooperative_groups/memcpy_async.h>
+#include <cooperative_groups/reduce.h>
+
+#include "src/turbomind/kernels/decoder_masked_multihead_attention.h"
+
+namespace turbomind {
+
+cublasHandle_t cublas_handle{};
+cudaStream_t   cublas_stream{};
+
+template<typename T>
+void Compare(const T* src, const T* ref, size_t stride, int m, int n, bool show, float rtol, float atol)
+{
+    float asums{};
+    float rsums{};
+    int   outliers{};
+    for (int nn = 0; nn < n; ++nn) {
+        float abs_diff_sum{};
+        float rel_diff_sum{};
+        for (int mm = 0; mm < m; ++mm) {
+            auto x = float(src[nn * stride + mm]);
+            auto y = float(ref[nn * stride + mm]);
+            // if (show) {
+            //     std::cout << x << "\t" << y << std::endl;
+            // }
+            auto abs_diff = std::abs(x - y);
+            auto rel_diff = abs_diff / std::abs(y + 1e-6f);
+            if (abs_diff > atol + rtol * std::abs(y)) {
+                ++outliers;
+                if (show) {
+                    std::cout << nn << "," << mm << "\t" << x << "\t" << y << std::endl;
+                }
+            }
+            abs_diff_sum += abs_diff;
+            rel_diff_sum += rel_diff;
+        }
+        asums += abs_diff_sum / m;
+        rsums += rel_diff_sum / m;
+    }
+    std::cout << "abs_diff = " << asums / n << " rel_diff = " << rsums / n << " outliers = " << outliers / (float)n
+              << std::endl;
+}
+
+template void Compare(const half* src, const half* ref, size_t stride, int m, int n, bool show, float rtol, float atol);
+template void
+Compare(const float* src, const float* ref, size_t stride, int m, int n, bool show, float rtol, float atol);
+
+void LoadBinary(const std::string& path, size_t size, void* dst)
+{
+    std::ifstream ifs(path, std::ios::binary | std::ios::in);
+    if (!ifs.is_open()) {
+        std::cerr << "failed to open " << path << "\n";
+        std::abort();
+    }
+    ifs.seekg(0, ifs.end);
+    auto actual_size_in_bytes = ifs.tellg();
+    ifs.seekg(0, ifs.beg);
+    if (size != actual_size_in_bytes) {
+        std::cerr << "[warning] file " << path << " has " << actual_size_in_bytes << " bytes, while " << size
+                  << " bytes is requested\n";
+    }
+    ifs.read((char*)dst, size);
+    std::cerr << "[info] " << path << " " << size << "\n";
+}
+
+namespace cg = cooperative_groups;
+
+__global__ void curand_init(curandState* state)
+{
+    auto tid = cg::this_grid().thread_rank();
+    curand_init(0xe4c45822e90461ddULL, tid, 0, state + tid);
+}
+
+template<typename T>
+__global__ void curand_uniform(curandState* state, size_t count, T* result, float scale, float shift)
+{
+    auto grid = cg::this_grid();
+    for (auto i = grid.thread_rank(); i < count; i += grid.size()) {
+        float tmp = curand_uniform(state + grid.thread_rank());
+        result[i] = T(scale * tmp + shift);
+    }
+}
+
+template<typename T>
+__global__ void curand_normal(curandState* state, size_t count, T* result, float scale, float shift)
+{
+    auto grid = cg::this_grid();
+    for (auto i = grid.thread_rank(); i < count; i += grid.size()) {
+        float tmp = curand_normal(state + grid.thread_rank());
+        result[i] = T(scale * tmp + shift);
+    }
+}
+
+__global__ void curand_bytes(curandState* state, size_t count, uint* result)
+{
+    auto grid = cg::this_grid();
+    for (auto i = grid.thread_rank(); i < count; i += grid.size()) {
+        result[i] = curand(state + grid.thread_rank());
+    }
+}
+
+struct RNG::Impl {
+
+    curandState* states{};
+
+    Impl()
+    {
+        cudaMalloc(&states, sizeof(curandState) * 64 * 64);
+        curand_init<<<64, 64>>>(states);
+    }
+
+    ~Impl()
+    {
+        cudaFree(states);
+    }
+
+    void GenerateUInt(uint* out, size_t count)
+    {
+        curand_bytes<<<64, 64>>>(states, count, out);
+    }
+
+    template<typename T>
+    void GenerateUniform(T* out, size_t count, float scale, float shift)
+    {
+        curand_uniform<<<64, 64>>>(states, count, out, scale, shift);
+    }
+
+    template<typename T>
+    void GenerateNormal(T* out, size_t count, float scale, float shift)
+    {
+        curand_normal<<<64, 64>>>(states, count, out, scale, shift);
+    }
+};
+
+RNG::RNG(): impl_(std::make_unique<Impl>()) {}
+
+RNG::~RNG() = default;
+
+void RNG::GenerateUInt(uint* out, size_t count)
+{
+    impl_->GenerateUInt(out, count);
+}
+
+template<typename T>
+void RNG::GenerateUniform(T* out, size_t count, float scale, float shift)
+{
+    std::cout << count << std::endl;
+    impl_->GenerateUniform(out, count, scale, shift);
+}
+
+template<typename T>
+void RNG::GenerateNormal(T* out, size_t count, float scale, float shift)
+{
+    impl_->GenerateNormal(out, count, scale, shift);
+}
+
+template void RNG::GenerateUniform(half* out, size_t count, float scale, float shift);
+template void RNG::GenerateUniform(float* out, size_t count, float scale, float shift);
+
+template void RNG::GenerateNormal(half* out, size_t count, float scale, float shift);
+template void RNG::GenerateNormal(float* out, size_t count, float scale, float shift);
+
+template<typename T>
+struct SATypeConverter {
+    using Type = T;
+};
+
+template<>
+struct SATypeConverter<half> {
+    using Type = uint16_t;
+};
+
+template<typename T>
+void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p,
+                       T**                                       per_sample_k_cache,
+                       T**                                       per_sample_v_cache,
+                       const int*                                sequence_length,
+                       int                                       max_memory_len,
+                       cudaStream_t                              st)
+{
+    using DataType = typename SATypeConverter<T>::Type;
+
+    // Prepare the parameters.
+    Masked_multihead_attention_params<DataType> params{};
+    params.q_bias = reinterpret_cast<const DataType*>(p.q_bias);
+    params.k_bias = reinterpret_cast<const DataType*>(p.k_bias);
+    params.v_bias = reinterpret_cast<const DataType*>(p.v_bias);
+
+    // Set the output buffer.
+    params.out = reinterpret_cast<DataType*>(p.out);
+
+    // Set the input buffers.
+    // [B, nH + kvH, D]
+    params.q = reinterpret_cast<const DataType*>(p.q);
+    params.k = reinterpret_cast<const DataType*>(p.k);
+    params.v = reinterpret_cast<const DataType*>(p.v);
+
+    params.stride   = p.stride;
+    params.finished = (bool*)p.finished;
+
+    params.k_cache_per_sample         = reinterpret_cast<DataType**>(per_sample_k_cache);
+    params.v_cache_per_sample         = reinterpret_cast<DataType**>(per_sample_v_cache);
+    params.kv_cache_per_sample_offset = p.layer_offset;
+    params.batch_size                 = p.batch_size;
+    params.beam_width                 = 1;
+    params.memory_max_len             = max_memory_len;
+    params.prefix_prompt_lengths      = 0;
+    params.max_prefix_prompt_length   = 0;
+    params.length_per_sample          = sequence_length;  // max_input_length + current output length
+
+    for (int i = 0; i < p.batch_size; ++i) {
+        params.timestep = std::max(sequence_length[i], params.timestep);
+    }
+
+    std::cout << "timestep = " << params.timestep << "\n";
+
+    params.num_heads    = p.num_heads;
+    params.num_kv_heads = p.num_kv_heads;
+
+    params.hidden_size_per_head    = p.size_per_head;
+    params.rotary_embedding_dim    = p.rotary_embedding_dim;
+    params.max_position_embeddings = p.max_position_embeddings;
+    params.use_dynamic_ntk         = false;
+    params.use_logn_attn           = p.use_logn_attn;
+
+    // Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust)
+    params.inv_sqrt_dh = 1.F / (sqrtf((float)params.hidden_size_per_head) * 1.f);
+
+    params.int8_mode = 0;
+
+    masked_multihead_attention(params, st);
+}
+
+template void mmha_ft_reference(const DecoderMultiHeadAttentionParams<half>& params,
+                                half**                                       per_sample_k_cache,
+                                half**                                       per_sample_v_cache,
+                                const int*                                   sequence_length,
+                                int                                          max_memory_len,
+                                cudaStream_t                                 st);
+
+}  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_utils.h b/src/turbomind/kernels/decoder_multihead_attention/test_utils.h
new file mode 100644
index 0000000000..caf81784d2
--- /dev/null
+++ b/src/turbomind/kernels/decoder_multihead_attention/test_utils.h
@@ -0,0 +1,43 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+
+#include "decoder_multihead_attention.h"
+#include "src/turbomind/macro.h"
+#include <memory>
+#include <string>
+
+namespace turbomind {
+
+template<typename T>
+void Compare(
+    const T* src, const T* ref, size_t stride, int m, int n, bool show = false, float rtol = 1e-2, float atol = 1e-4);
+
+void LoadBinary(const std::string& path, size_t size, void* dst);
+
+class RNG {
+public:
+    RNG();
+    ~RNG();
+    void GenerateUInt(uint* out, size_t count);
+
+    template<typename T>
+    void GenerateUniform(T* out, size_t count, float scale = 1.f, float shift = 0.f);
+
+    template<typename T>
+    void GenerateNormal(T* out, size_t count, float scale = 1.f, float shift = 0.f);
+
+private:
+    struct Impl;
+    std::unique_ptr<Impl> impl_;
+};
+
+template<typename T>
+void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& params,
+                       T**                                       per_sample_k_cache,
+                       T**                                       per_sample_v_cache,
+                       const int*                                sequence_length,
+                       int                                       max_memory_len,
+                       cudaStream_t                              st);
+
+}  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_multihead_attention/thread_map.h b/src/turbomind/kernels/decoder_multihead_attention/thread_map.h
new file mode 100644
index 0000000000..f4c2be1da2
--- /dev/null
+++ b/src/turbomind/kernels/decoder_multihead_attention/thread_map.h
@@ -0,0 +1,98 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+
+#include "../gemm_s_f16/common.h"
+
+namespace turbomind {
+
+template<int C, int S, int AccessC, int WarpCount>
+struct ThreadMapQ {
+    static constexpr int kWarpCount = WarpCount;
+    static constexpr int kAccessC   = AccessC;
+
+    static constexpr int kWarpThreadC = C / kAccessC;
+    static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;
+
+    static_assert(kWarpThreadC <= WARP_SIZE);
+
+    static constexpr int kWarpAccessC = kWarpThreadC * kAccessC;  // C
+    static constexpr int kWarpAccessS = kWarpThreadS;
+
+    static constexpr int kWarpIterC = C / kWarpAccessC;  // 1
+    static constexpr int kWarpIterS = S / kWarpAccessS;
+
+    static constexpr int kWarpC = 1;
+    static constexpr int kWarpS = kWarpCount;
+
+    static constexpr int kIterC = kWarpIterC / kWarpC;  // 1
+    static constexpr int kIterS = std::max(kWarpIterS / kWarpS, 1);
+
+    static constexpr int kFootprintC = kWarpAccessC * kIterC;  // C
+    static constexpr int kFootprintS = kWarpAccessS * kIterS;
+
+    static constexpr int kDeltaC = kWarpAccessC;
+    static constexpr int kDeltaS = kWarpAccessS;
+
+    __device__ static int2 get_offset(int warp_id, int lane_id)
+    {
+        int warp_offset_c = warp_id % kWarpC;
+        int warp_offset_s = warp_id / kWarpC;
+
+        int warp_thread_offset_c = lane_id % kWarpThreadC;
+        int warp_thread_offset_s = lane_id / kWarpThreadC;
+
+        int cta_thread_offset_c = kFootprintC * warp_offset_c + warp_thread_offset_c * kAccessC;
+        int cta_thread_offset_s = kFootprintS * warp_offset_s + warp_thread_offset_s;
+
+        return {cta_thread_offset_c, cta_thread_offset_s};
+    }
+};
+
+template<int C, int S, int AccessC, int WarpThreadC, int WarpCount>
+struct ThreadMapKv {
+    static constexpr int kC = C;
+    static constexpr int kS = S;
+
+    static constexpr int kWarpCount = WarpCount;
+    static constexpr int kAccessC   = AccessC;
+
+    static constexpr int kWarpThreadC = WarpThreadC;
+    static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;
+
+    static_assert(kWarpThreadC <= WARP_SIZE);
+
+    static constexpr int kWarpAccessC = kWarpThreadC * kAccessC;
+    static constexpr int kWarpAccessS = kWarpThreadS;
+
+    static constexpr int kWarpIterC = C / kWarpAccessC;
+    static constexpr int kWarpIterS = S / kWarpAccessS;
+
+    static constexpr int kWarpC = 1;
+    static constexpr int kWarpS = kWarpCount;
+
+    static constexpr int kIterC = kWarpIterC / kWarpC;
+    static constexpr int kIterS = std::max(kWarpIterS / kWarpS, 1);
+
+    static constexpr int kFootprintC = kWarpAccessC * kIterC;
+    static constexpr int kFootprintS = kWarpAccessS * kIterS;
+
+    static constexpr int kDeltaC = kWarpAccessC;
+    static constexpr int kDeltaS = kWarpAccessS;
+
+    __device__ static int2 get_offset(int warp_id, int lane_id)
+    {
+        int warp_offset_c = warp_id % kWarpC;
+        int warp_offset_s = warp_id / kWarpC;
+
+        int warp_thread_offset_c = lane_id % kWarpThreadC;
+        int warp_thread_offset_s = lane_id / kWarpThreadC;
+
+        int cta_thread_offset_c = kFootprintC * warp_offset_c + warp_thread_offset_c * kAccessC;
+        int cta_thread_offset_s = kFootprintS * warp_offset_s + warp_thread_offset_s;
+
+        return {cta_thread_offset_c, cta_thread_offset_s};
+    }
+};
+
+}  // namespace turbomind
diff --git a/src/turbomind/kernels/gemm_s_f16/common.h b/src/turbomind/kernels/gemm_s_f16/common.h
index 1556982cf1..8dd4f3bb46 100644
--- a/src/turbomind/kernels/gemm_s_f16/common.h
+++ b/src/turbomind/kernels/gemm_s_f16/common.h
@@ -236,15 +236,88 @@ __inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
 
 template<typename T, int N>
 struct Array {
-    T a[N];
 
-    __device__ __host__ constexpr T& operator[](int i) noexcept
+    using value_type      = T;
+    using size_type       = int;
+    using difference_type = int;
+    using reference       = value_type&;
+    using const_reference = const value_type&;
+    using pointer         = value_type*;
+    using const_pointer   = const value_type*;
+    using iterator        = pointer;
+    using const_iterator  = const_pointer;
+
+    static_assert(N > 0);
+
+    T __a[N];
+
+    __device__ __host__ constexpr reference operator[](size_type i) noexcept
+    {
+        return __a[i];
+    }
+    __device__ __host__ constexpr const_reference operator[](size_type i) const noexcept
+    {
+        return __a[i];
+    }
+
+    __device__ __host__ constexpr reference front() noexcept
+    {
+        return *begin();
+    }
+
+    __device__ __host__ constexpr const_reference front() const noexcept
+    {
+        return *begin();
+    }
+
+    __device__ __host__ constexpr reference back() noexcept
+    {
+        return *(end() - 1);
+    }
+
+    __device__ __host__ constexpr const_reference back() const noexcept
+    {
+        return *(end() - 1);
+    }
+
+    __device__ __host__ constexpr pointer data() noexcept
     {
-        return a[i];
+        return &__a[0];
     }
-    __device__ __host__ constexpr const T& operator[](int i) const noexcept
+
+    __device__ __host__ constexpr const_pointer data() const noexcept
+    {
+        return &__a[0];
+    }
+
+    __device__ __host__ constexpr iterator begin() noexcept
+    {
+        return data();
+    }
+
+    __device__ __host__ constexpr const_iterator begin() const noexcept
+    {
+        return data();
+    }
+
+    __device__ __host__ constexpr iterator end() noexcept
+    {
+        return data() + N;
+    }
+
+    __device__ __host__ constexpr const_iterator end() const noexcept
+    {
+        return data() + N;
+    }
+
+    __device__ __host__ constexpr std::integral_constant<int, N> size() const noexcept
+    {
+        return {};
+    }
+
+    __device__ __host__ constexpr std::false_type empty() const noexcept
     {
-        return a[i];
+        return {};
     }
 };
 
diff --git a/src/turbomind/kernels/gemm_s_f16/cta_iterator.h b/src/turbomind/kernels/gemm_s_f16/cta_iterator.h
index 48cf9ace2c..0c13ae3116 100644
--- a/src/turbomind/kernels/gemm_s_f16/cta_iterator.h
+++ b/src/turbomind/kernels/gemm_s_f16/cta_iterator.h
@@ -3,6 +3,7 @@
 #pragma once
 
 #include "common.h"
+#include 
 #include 
 
 namespace turbomind {
@@ -236,7 +237,13 @@ struct IteratorA {
 
     __device__ void prefetch(bool mask)
     {
+#if TURBOMIND_ARCH_SM80
         cp_async_cg_A(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
+#else
+        if (mask) {
+            *(AccessType*)((uint8_t*)smem_ + dst_offset_) = __ldg((const AccessType*)src_ + src_offset_);
+        }
+#endif
     }
 };
 
@@ -417,7 +424,13 @@ struct IteratorQ {
 
     __device__ void prefetch(bool mask)
     {
+#if TURBOMIND_ARCH_SM80
         cp_async_ca(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
+#else
+        if (mask) {
+            *(AccessType*)((uint8_t*)smem_ + dst_offset_) = __ldg((const AccessType*)src_ + src_offset_);
+        }
+#endif
     }
 };
 
@@ -613,8 +626,14 @@ struct IteratorB {
 
     __device__ void prefetch(bool mask)
     {
+#if TURBOMIND_ARCH_SM80
         cp_async_cg_B(
             smem_int_ptr_ + tmp_dst_offset_, (const AccessType*)(src_ + tmp_src_offset_), is_valid_n_ && mask);
+#else
+        if (is_valid_n_ && mask) {
+            *(AccessType*)((uint8_t*)smem_ + tmp_dst_offset_) = __ldg((const AccessType*)(src_ + tmp_src_offset_));
+        }
+#endif
     }
 };
 
diff --git a/src/turbomind/kernels/gemm_s_f16/gemm_template.h b/src/turbomind/kernels/gemm_s_f16/gemm_template.h
index a429ba8536..0e3e9bca9d 100644
--- a/src/turbomind/kernels/gemm_s_f16/gemm_template.h
+++ b/src/turbomind/kernels/gemm_s_f16/gemm_template.h
@@ -9,6 +9,23 @@
 
 namespace turbomind {
 
+__inline__ __device__ void
+mma_m16n8k8_row_col(Array<float, 4>& d, const Array<half, 4>& a, const Array<half, 2>& b, Array<float, 4>& c)
+{
+#if TURBOMIND_ARCH_SM75
+    uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);
+    uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);
+    float const*    C = reinterpret_cast<float const*>(&c);
+    float*          D = reinterpret_cast<float*>(&d);
+    asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32  {%0,%1,%2,%3}, "
+        "{%4,%5}, {%6}, {%7,%8,%9,%10};\n"
+        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
+        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
+#else
+    assert(TURBOMIND_ARCH_SM75);
+#endif
+}
+
 __inline__ __device__ void
 mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<half, 4>& b, Array<float, 4>& c)
 {
@@ -22,7 +39,10 @@ mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<h
+    const Array<half, 4>* _a = (const Array<half, 4>*)&a;
+    const Array<half, 2>* _b = (const Array<half, 2>*)&b;
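+    // Emulate the k16 MMA with two k8 MMAs: the first accumulates onto c, the second onto the partial result in d.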
+    mma_m16n8k8_row_col(d, _a[0], _b[0], c);
+    mma_m16n8k8_row_col(d, _a[1], _b[1], d);
 #endif
 }
 
diff --git a/src/turbomind/kernels/sampling_penalty_kernels.cu b/src/turbomind/kernels/sampling_penalty_kernels.cu
index 4877bdb1a0..f7ebfeff03 100644
--- a/src/turbomind/kernels/sampling_penalty_kernels.cu
+++ b/src/turbomind/kernels/sampling_penalty_kernels.cu
@@ -446,10 +446,16 @@ void invokeBatchApplyRepetitionPenalty(T*                    logits,
     dim3   grid(local_batch_size);
     size_t smem_size = step * (sizeof(float) + sizeof(int));
     if (penalty_type == RepetitionPenaltyType::Additive) {
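+        // step * (sizeof(float) + sizeof(int)) bytes of dynamic shared memory can exceed the default 48 KB
+        // per-block limit for long contexts, so opt in to a larger allocation before launching.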
+        check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive>,
+                                              cudaFuncAttributeMaxDynamicSharedMemorySize,
+                                              smem_size));
         batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, smem_size, stream>>>(
             logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
     }
     else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
+        check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative>,
+                                              cudaFuncAttributeMaxDynamicSharedMemorySize,
+                                              smem_size));
         batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative><<<grid, block, smem_size, stream>>>(
             logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
     }
diff --git a/src/turbomind/kernels/unfused_attention_kernels.cu b/src/turbomind/kernels/unfused_attention_kernels.cu
index b2450c8675..040f7204bf 100644
--- a/src/turbomind/kernels/unfused_attention_kernels.cu
+++ b/src/turbomind/kernels/unfused_attention_kernels.cu
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-#include "src/turbomind/kernels/decoder_masked_multihead_attention_utils.h"
+#include "src/turbomind/kernels/decoder_multihead_attention/array_ops.h"
 #include "src/turbomind/kernels/reduce_kernel_utils.cuh"
 #include "src/turbomind/kernels/unfused_attention_kernels.h"
 #include "src/turbomind/utils/cuda_type_utils.cuh"
@@ -854,19 +854,20 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
                                                    T* v_buf,
                                                    T* QKV,
                                                    const T* __restrict qkv_bias,
-                                                   const int* padding_offset,
-                                                   const int* history_length,
-                                                   const int* input_length,
-                                                   int        batch_size,
-                                                   int        seq_len,
-                                                   int        head_num,
-                                                   int        kv_head_num,
-                                                   int        size_per_head,
-                                                   int        rotary_embedding_dim,
-                                                   float      rotary_embedding_base,
-                                                   int        max_position_embeddings,
-                                                   bool       use_dynamic_ntk,
-                                                   bool       use_logn_attn)
+                                                   const int*   padding_offset,
+                                                   const int*   context_length,
+                                                   const int*   input_length,
+                                                   const float* rope_theta,
+                                                   int          batch_size,
+                                                   int          seq_len,
+                                                   int          head_num,
+                                                   int          kv_head_num,
+                                                   int          size_per_head,
+                                                   int          rotary_embedding_dim,
+                                                   float        rotary_embedding_base,
+                                                   int          max_position_embeddings,
+                                                   bool         use_dynamic_ntk,
+                                                   bool         use_logn_attn)
 {
     // This kernel add bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head], and
     // QKV split to 3 split buffer q, k, v and transpose them to [batch_size, head_num, seq_len, size_per_head].
@@ -907,12 +908,18 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
     Vec_t q, k, v;
     Vec_t q_bias, k_bias, v_bias;
 
+    using Vec = Array<T, vec_size>;
+
+    static_assert(sizeof(Vec_t) == sizeof(Vec));
+
+    using namespace ops;
+
     // load Q and apply bias
     if (!is_masked) {
         q = *reinterpret_cast<const Vec_t*>(&QKV[src_q_idx]);
         if (qkv_bias) {
-            q_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx]);
-            q      = mmha::add(q, q_bias);
+            q_bias  = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx]);
+            (Vec&)q = (Vec&)q + (Vec&)q_bias;
         }
     }
 
@@ -921,35 +928,32 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
         k = *reinterpret_cast<const Vec_t*>(&QKV[src_k_idx]);
         v = *reinterpret_cast<const Vec_t*>(&QKV[src_v_idx]);
         if (qkv_bias) {
-            k_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + k_offset]);
-            v_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + v_offset]);
-            k      = mmha::add(k, k_bias);
-            v      = mmha::add(v, v_bias);
+            k_bias  = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + k_offset]);
+            v_bias  = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + v_offset]);
+            (Vec&)k = (Vec&)k + (Vec&)k_bias;
+            (Vec&)v = (Vec&)v + (Vec&)v_bias;
         }
     }
 
-    const int history_len = history_length[batch_idx];
-    const int context_len = history_len + input_length[batch_idx];
+    const int context_len = context_length[batch_idx];
+    const int history_len = context_len - input_length[batch_idx];
     const int timestep    = history_len + seq_idx;
 
-    if (use_dynamic_ntk) {
-        rotary_embedding_base = mmha::rotary_embedding_get_base(
-            context_len, max_position_embeddings, rotary_embedding_dim, rotary_embedding_base);
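+    // A per-sequence RoPE base provided by the caller (e.g. a precomputed dynamic-NTK theta) overrides the default.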
+    if (rope_theta) {
+        rotary_embedding_base = rope_theta[batch_idx];
     }
 
-    // TODO: unused computation on k if GQA is used
-    mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, rotary_embedding_base, timestep);
+    RotaryEmbedding rotary_emb(rotary_embedding_base, rotary_embedding_dim, timestep, {tidx * vec_size, 0});
+    rotary_emb.apply((Array<T, vec_size>&)q);
+
+    if (head_idx < kv_head_num) {
+        rotary_emb.apply((Array<T, vec_size>&)k);
+    }
 
     if (use_logn_attn) {
         // +1 to convert to context length at the timestep
-        float logn_scaling = mmha::logn_attn_get_scaling(timestep + 1, max_position_embeddings);
-        if constexpr (std::is_same_v<T, float>) {
-            q = mmha::mul(logn_scaling, q);
-        }
-        else if constexpr (std::is_same_v<T, half>) {
-            half tmp = __float2half(logn_scaling);
-            q        = mmha::mul((uint16_t&)tmp, q);
-        }
+        LogNScaling logn_scaling(timestep + 1, max_position_embeddings);
+        logn_scaling.apply((Array<T, vec_size>&)q);
     }
 
     if (!is_masked && !q_buf) {  // also skip modifying QKV if q/k/v_buf are present
@@ -982,8 +986,9 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
                                                                                              QKV,                      \
                                                                                              qkv_bias,                 \
                                                                                              padding_offset,           \
-                                                                                             history_length,           \
+                                                                                             context_length,           \
                                                                                              input_length,             \
+                                                                                             rope_theta,               \
                                                                                              batch_size,               \
                                                                                              seq_len,                  \
                                                                                              head_num,                 \
@@ -1002,8 +1007,9 @@ void invokeAddFusedQKVBiasTranspose(T*           q_buf,
                                     T*           QKV,
                                     const T*     qkv_bias,
                                     const int*   padding_offset,
-                                    const int*   history_length,
+                                    const int*   context_length,
                                     const int*   input_length,
+                                    const float* rope_theta,
                                     const int    batch_size,
                                     const int    seq_len,
                                     const int    token_num,
@@ -1034,6 +1040,7 @@ void invokeAddFusedQKVBiasTranspose(T*           q_buf,
                                                  const int*   padding_offset,                                          \
                                                  const int*   history_length,                                          \
                                                  const int*   input_length,                                            \
+                                                 const float* rope_theta,                                              \
                                                  const int    batch_size,                                              \
                                                  const int    seq_len,                                                 \
                                                  const int    token_num,                                               \
diff --git a/src/turbomind/kernels/unfused_attention_kernels.h b/src/turbomind/kernels/unfused_attention_kernels.h
index b5c37b5d48..758fe7fba0 100644
--- a/src/turbomind/kernels/unfused_attention_kernels.h
+++ b/src/turbomind/kernels/unfused_attention_kernels.h
@@ -70,8 +70,9 @@ void invokeAddFusedQKVBiasTranspose(T*           q_buf,
                                     T*           QKV,
                                     const T*     qkv_bias,
                                     const int*   padding_offset,
-                                    const int*   history_length,
+                                    const int*   context_length,
                                     const int*   input_length,
+                                    const float* rope_theta,
                                     const int    batch_size,
                                     const int    seq_len,
                                     const int    token_num,
diff --git a/src/turbomind/layers/DynamicDecodeLayer.cc b/src/turbomind/layers/DynamicDecodeLayer.cc
index b4932eae2b..baefc2c2e2 100644
--- a/src/turbomind/layers/DynamicDecodeLayer.cc
+++ b/src/turbomind/layers/DynamicDecodeLayer.cc
@@ -188,6 +188,7 @@ void DynamicDecodeLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_
      *
      * output_tensors:
      *   \param  output_ids [max_seq_len, batch_size]
+     *   \param  curand_state [local_batch_size]
      *   \param  finished [batch_size * beam_width], optional
      *   \param  should_stop [1] on cpu
      *   \param  cum_log_probs [batch_size * beam_width], necessary in beam search
@@ -276,7 +277,8 @@ void DynamicDecodeLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_
                 {"input_lengths", input_lengths.slice({local_batch_size, beam_width}, local_batch_offset)});
         }
 
-        TensorMap decode_output_tensors({{"output_ids", output_tensors->at("output_ids")}});
+        TensorMap decode_output_tensors({{"output_ids", output_tensors->at("output_ids")},  //
+                                         {"curand_state", output_tensors->at("curand_state")}});
         if (output_tensors->isExist("sequence_length")) {
             Tensor sequence_length = output_tensors->at("sequence_length");
             decode_output_tensors.insert(
diff --git a/src/turbomind/layers/DynamicDecodeLayer.h b/src/turbomind/layers/DynamicDecodeLayer.h
index cae2118c19..a6fdbc3aa5 100644
--- a/src/turbomind/layers/DynamicDecodeLayer.h
+++ b/src/turbomind/layers/DynamicDecodeLayer.h
@@ -53,15 +53,6 @@ class DynamicDecodeLayer: public BaseLayer {
     int* h_pinned_finished_sum_ = nullptr;
 
 public:
-    curandState_t* topk_curandstate_buf()
-    {
-        return static_cast<TopKSamplingLayer<T>*>(topk_decode_)->curandstate_buf();
-    }
-    curandState_t* topp_curandstate_buf()
-    {
-        return static_cast<TopPSamplingLayer<T>*>(topp_decode_)->curandstate_buf();
-    }
-
     DynamicDecodeLayer(size_t           vocab_size,
                        size_t           vocab_size_padded,
                        int              end_id,
diff --git a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc
index c8c36f65da..1c9ae099d9 100644
--- a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc
+++ b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc
@@ -30,10 +30,6 @@ template<typename T>
 void BaseSamplingLayer<T>::allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p)
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-    curandstate_buf_ = reinterpret_cast<curandState_t*>(
-        allocator_->reMalloc(curandstate_buf_, sizeof(curandState_t) * batch_size, false));
-    random_seeds_buf_ = reinterpret_cast<unsigned long long*>(
-        allocator_->reMalloc(random_seeds_buf_, sizeof(unsigned long long) * batch_size, false));
     temperature_buf_ =
         reinterpret_cast<float*>(allocator_->reMalloc(temperature_buf_, sizeof(float) * batch_size, false));
     repetition_penalty_buf_ =
@@ -58,8 +54,6 @@ void BaseSamplingLayer<T>::freeBuffer()
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     if (is_allocate_buffer_) {
-        allocator_->free((void**)(&curandstate_buf_));
-        allocator_->free((void**)(&random_seeds_buf_));
         allocator_->free((void**)(&temperature_buf_));
         allocator_->free((void**)(&repetition_penalty_buf_));
         allocator_->free((void**)(&min_lengths_buf_));
@@ -128,32 +122,6 @@ void BaseSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt
     Tensor runtime_top_p = runtime_args->isExist("runtime_top_p") ? runtime_args->at("runtime_top_p") : Tensor();
     allocateBuffer(batch_size, runtime_top_k, runtime_top_p);
 
-    // If runtime argument has single random seed, using this random seed to initialize the random table of all
-    // sentences. If the argument has [batch_size] random seeds, initializing the random table by different random seeds
-    // respectively. If no random seed, initialize the random table of all sentences by 0 directly.
-    if (runtime_args->isExist("random_seed")) {
-        Tensor random_seeds = runtime_args->at("random_seed");
-        FT_CHECK_WITH_INFO(random_seeds.shape.size() == 1
-                               && (random_seeds.size() == 1 || random_seeds.size() == batch_size),
-                           fmtstr("random_seeds must be of shape [1] or [batch_size(%ld)], got random_seeds.shape=%s",
-                                  batch_size,
-                                  vec2str(random_seeds.shape).c_str()));
-        if (random_seeds.size() == 1) {
-            invokeCurandInitialize(curandstate_buf_, batch_size, random_seeds.getVal<unsigned long long>(), stream_);
-            sync_check_cuda_error();
-        }
-        else {
-            unsigned long long* random_seed_ptr = random_seeds.getPtr<unsigned long long>();
-            cudaAutoCpy(random_seeds_buf_, random_seed_ptr, batch_size, stream_);
-            invokeCurandBatchInitialize(curandstate_buf_, batch_size, random_seeds_buf_, stream_);
-            sync_check_cuda_error();
-        }
-    }
-    else {
-        // Initialize curand states using the default seed 0.
-        invokeCurandInitialize(curandstate_buf_, batch_size, 0, stream_);
-    }
-
     // Setup penalties.
     const float default_temperature = 1.0f;
     Tensor      temperature         = runtime_args->isExist("temperature") ?
diff --git a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h
index 71dff01834..29462e16a2 100644
--- a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h
+++ b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h
@@ -33,10 +33,8 @@ class BaseSamplingLayer: public DynamicDecodeBaseLayer {
     size_t vocab_size_;
     size_t vocab_size_padded_;
 
-    size_t              sampling_workspace_size_;
-    void*               sampling_workspace_ = nullptr;
-    curandState_t*      curandstate_buf_    = nullptr;
-    unsigned long long* random_seeds_buf_   = nullptr;
+    size_t sampling_workspace_size_;
+    void*  sampling_workspace_ = nullptr;
 
     float* temperature_buf_        = nullptr;
     float* repetition_penalty_buf_ = nullptr;
@@ -59,11 +57,6 @@ class BaseSamplingLayer: public DynamicDecodeBaseLayer {
     virtual void allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p);
 
 public:
-    curandState_t* curandstate_buf()
-    {
-        return curandstate_buf_;
-    }
-
     BaseSamplingLayer(size_t             max_batch_size,
                       size_t             vocab_size,
                       size_t             vocab_size_padded,
diff --git a/src/turbomind/layers/sampling_layers/TopKSamplingLayer.cu b/src/turbomind/layers/sampling_layers/TopKSamplingLayer.cu
index 614b1a68ce..0bcd7e12a2 100644
--- a/src/turbomind/layers/sampling_layers/TopKSamplingLayer.cu
+++ b/src/turbomind/layers/sampling_layers/TopKSamplingLayer.cu
@@ -16,6 +16,7 @@
  */
 
 #include <float.h>
+#include <curand_kernel.h>
 
 #include "src/turbomind/kernels/sampling_topk_kernels.h"
 #include "src/turbomind/kernels/sampling_topp_kernels.h"
@@ -199,6 +200,7 @@ void TopKSamplingLayer<T>::runSampling(TensorMap* output_tensors, TensorMap* inp
 
     // output_tensors:
     //      output_ids [max_seq_len, batch_size]
+    //      curand_state [local_batch_size]
     //      finished [local_batch_size], optional
     //      sequence_length [local_batch_size], optional
     //      cum_log_probs [batch_size], must be float*, optional
@@ -255,7 +257,7 @@ void TopKSamplingLayer::runSampling(TensorMap* output_tensors, TensorMap* inp
         output_tensors->at("finished", Tensor{MEMORY_GPU, TYPE_INVALID, {}, nullptr}).getPtr(),
         cum_log_probs,
         output_log_probs,
-        curandstate_buf_ + ite * local_batch_size,
+        output_tensors->at("curand_state").getPtr<curandState_t>() + ite * local_batch_size,
         (int)runtime_max_top_k_,  // useless because runtime_top_k_buf_ is never nullptr. Keep for legacy.
         (int*)(runtime_top_k_buf_ + ite * local_batch_size),
         1.0f,  // useless because runtime_top_p_buf_ is never nullptr. Keep for legacy.
diff --git a/src/turbomind/layers/sampling_layers/TopKSamplingLayer.h b/src/turbomind/layers/sampling_layers/TopKSamplingLayer.h
index 013a6dfccb..b306e33cd5 100644
--- a/src/turbomind/layers/sampling_layers/TopKSamplingLayer.h
+++ b/src/turbomind/layers/sampling_layers/TopKSamplingLayer.h
@@ -40,8 +40,6 @@ class TopKSamplingLayer: public BaseSamplingLayer {
 
     using BaseSamplingLayer::sampling_workspace_size_;
     using BaseSamplingLayer::sampling_workspace_;
-    using BaseSamplingLayer::curandstate_buf_;
-    using BaseSamplingLayer::random_seeds_buf_;
     using BaseSamplingLayer::skip_decode_buf_;
     using BaseSamplingLayer::skip_decode_;
     using BaseSamplingLayer::skip_any_;
diff --git a/src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu b/src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu
index 8e7e97314f..b3fa0767cd 100644
--- a/src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu
+++ b/src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu
@@ -132,7 +132,7 @@ void TopPSamplingLayer::allocateBuffer(size_t batch_size, Tensor top_k, Tenso
                           topp_id_vals_buf_,
                           topp_offset_buf_,
                           begin_topp_offset_buf_,
-                          curandstate_buf_,
+                          nullptr,  // not used when workspace is null
                           batch_size,
                           vocab_size_padded_,
                           nullptr,
@@ -267,6 +267,7 @@ void TopPSamplingLayer::runSampling(TensorMap* output_tensors, TensorMap* inp
 
     * output_tensors:
     *   \param  output_ids [max_seq_len, batch_size]
+    *   \param  curand_state [local_batch_size]
     *   \param  finished [local_batch_size], optional
     *   \param  sequence_length [local_batch_size], optional
     *   \param  cum_log_probs [batch_size], must be float*, optional
@@ -319,7 +320,7 @@ void TopPSamplingLayer::runSampling(TensorMap* output_tensors, TensorMap* inp
         topp_id_vals_buf_,
         topp_offset_buf_,
         begin_topp_offset_buf_,
-        curandstate_buf_ + ite * local_batch_size,
+        output_tensors->at("curand_state").getPtr() + ite * local_batch_size,
         local_batch_size,
         vocab_size_padded_,
         input_tensors->at("end_id").getPtr(),
diff --git a/src/turbomind/layers/sampling_layers/TopPSamplingLayer.h b/src/turbomind/layers/sampling_layers/TopPSamplingLayer.h
index 8005e83215..5e91e18236 100644
--- a/src/turbomind/layers/sampling_layers/TopPSamplingLayer.h
+++ b/src/turbomind/layers/sampling_layers/TopPSamplingLayer.h
@@ -48,8 +48,6 @@ class TopPSamplingLayer: public BaseSamplingLayer {
 
     using BaseSamplingLayer::sampling_workspace_size_;
     using BaseSamplingLayer::sampling_workspace_;
-    using BaseSamplingLayer::curandstate_buf_;
-    using BaseSamplingLayer::random_seeds_buf_;
     using BaseSamplingLayer::skip_decode_buf_;
     using BaseSamplingLayer::skip_decode_;
     using BaseSamplingLayer::skip_any_;
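Taken together, the hunks above remove the sampling layers' private curand buffers: per-sequence RNG state is now passed in through the `curand_state` entry of `output_tensors`, so the batch code can persist and restore it when sequences are swapped. The sketch below is only an analogy of that ownership change, with std::mt19937 standing in for curandState_t and plain structs standing in for the TensorMap plumbing; none of these names are the library's API.

// Hypothetical sketch: RNG state owned by the caller, not by the sampler.
#include <cstdio>
#include <random>
#include <vector>

struct Sampler {
    // The sampler no longer owns RNG state; it receives it per call,
    // mirroring output_tensors->at("curand_state") in the diff above.
    int sample(std::mt19937& state, int vocab_size) const
    {
        return std::uniform_int_distribution<int>(0, vocab_size - 1)(state);
    }
};

int main()
{
    std::vector<std::mt19937> per_sequence_state;  // persisted across swap in/out
    per_sequence_state.emplace_back(1234);         // seeded like "random_seed"
    per_sequence_state.emplace_back(5678);

    Sampler sampler;
    for (int step = 0; step < 3; ++step) {
        for (std::size_t i = 0; i < per_sequence_state.size(); ++i) {
            std::printf("seq %zu -> token %d\n", i, sampler.sample(per_sequence_state[i], 32000));
        }
    }
}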
diff --git a/src/turbomind/models/llama/Barrier.h b/src/turbomind/models/llama/Barrier.h
index 6eb0df9585..509290a5fb 100644
--- a/src/turbomind/models/llama/Barrier.h
+++ b/src/turbomind/models/llama/Barrier.h
@@ -2,6 +2,7 @@
 
 #pragma once
 
+#include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/logger.h"
 #ifndef _MSC_VER
 #include <pthread.h>
@@ -33,10 +34,11 @@ class Barrier {
 
 class Barrier {
 public:
-    Barrier(unsigned count)
+    Barrier(unsigned count): count_(count)
     {
-        TM_LOG_INFO("Barrier(%d)", (int)count);
-        pthread_barrier_init(&barrier_, nullptr, count);
+        if (count_ > 1) {
+            pthread_barrier_init(&barrier_, nullptr, count);
+        }
     }
 
     Barrier(const Barrier&) = delete;
@@ -46,15 +48,20 @@ class Barrier {
 
     void wait()
     {
-        pthread_barrier_wait(&barrier_);
+        if (count_ > 1) {
+            pthread_barrier_wait(&barrier_);
+        }
     }
 
     ~Barrier()
     {
-        pthread_barrier_destroy(&barrier_);
+        if (count_ > 1) {
+            pthread_barrier_destroy(&barrier_);
+        }
     }
 
 private:
+    const int         count_;
     pthread_barrier_t barrier_{};
 };
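The guarded Barrier above only creates and waits on the pthread barrier when more than one participant is expected, presumably so a single-worker configuration pays no synchronization cost. A minimal standalone sketch of the same pattern and how it is typically used (POSIX toolchain assumed; the names here are illustrative, not the project's):

#include <cstdio>
#include <pthread.h>
#include <thread>
#include <vector>

class GuardedBarrier {  // same idea as the guarded turbomind::Barrier above
public:
    explicit GuardedBarrier(unsigned count): count_(count)
    {
        if (count_ > 1) {
            pthread_barrier_init(&barrier_, nullptr, count_);
        }
    }
    void wait()
    {
        if (count_ > 1) {
            pthread_barrier_wait(&barrier_);
        }
    }
    ~GuardedBarrier()
    {
        if (count_ > 1) {
            pthread_barrier_destroy(&barrier_);
        }
    }

private:
    const unsigned    count_;
    pthread_barrier_t barrier_{};
};

int main()
{
    GuardedBarrier           barrier(4);
    std::vector<std::thread> workers;
    for (int rank = 0; rank < 4; ++rank) {
        workers.emplace_back([&, rank] {
            std::printf("rank %d before barrier\n", rank);
            barrier.wait();  // all ranks reach this point before any proceeds
            std::printf("rank %d after barrier\n", rank);
        });
    }
    for (auto& w : workers) {
        w.join();
    }
}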
 
diff --git a/src/turbomind/models/llama/BlockManager.cc b/src/turbomind/models/llama/BlockManager.cc
new file mode 100644
index 0000000000..e2e9c39dd5
--- /dev/null
+++ b/src/turbomind/models/llama/BlockManager.cc
@@ -0,0 +1,286 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include "src/turbomind/models/llama/BlockManager.h"
+#include "src/turbomind/utils/cuda_utils.h"
+#include "src/turbomind/utils/debug_utils.h"
+#include "src/turbomind/utils/logger.h"
+#include "src/turbomind/utils/string_utils.h"
+#include 
+#include 
+#include 
+
+namespace turbomind {
+
+BlockManager::BlockManager(size_t block_size, double block_count, int chunk_size, IAllocator* allocator):
+    block_size_(block_size), allocator_(allocator)
+{
+    if (block_count < 1.) {
+        max_block_count_ = GetBlockCount(block_size, block_count);
+    }
+    else {
+        max_block_count_ = block_count;
+    }
+
+    if (chunk_size == 0) {
+        chunk_size_ = static_cast<int>(std::sqrt(max_block_count_));
+    }
+    else if (chunk_size < 0) {
+        chunk_size_ = max_block_count_;
+    }
+    else {
+        chunk_size_ = chunk_size;
+    }
+
+    TM_LOG_INFO("[BlockManager] block_size = %lu MB", (unsigned long)block_size_ >> 20);
+    TM_LOG_INFO("[BlockManager] max_block_count = %d", max_block_count_);
+    TM_LOG_INFO("[BlockManager] chunk_size = %d", chunk_size_);
+
+    blocks_.reserve(max_block_count_);
+
+    active_ids_.reserve(max_block_count_);
+    cached_ids_.reserve(max_block_count_);
+    free_ids_.reserve(max_block_count_);
+
+    // pre-allocate first chunk
+    Malloc();
+    dbg(free_ids_);
+}
+
+BlockManager::~BlockManager()
+{
+    for (auto& chunk : chunks_) {
+        allocator_->free(&chunk);
+    }
+}
+
+bool BlockManager::Malloc()
+{
+    auto chunk_size = std::min<int>(chunk_size_, max_block_count_ - blocks_.size());
+
+    if (!chunk_size) {
+        return false;
+    }
+
+    auto ptr = (std::byte*)allocator_->malloc(block_size_ * chunk_size);
+    if (!ptr) {
+        return false;
+    }
+
+    chunks_.push_back(ptr);
+
+    for (int i = 0; i < chunk_size; ++i, ptr += block_size_) {
+        auto& block     = blocks_.emplace_back();
+        block.use_count = 0;
+        block.id        = (int)blocks_.size() - 1;
+        block.timestamp = 0;
+        block.data      = ptr;
+
+        free_ids_.push_back(block.id);
+    }
+
+    return true;
+}
+
+size_t BlockManager::GetBlockCount(size_t block_size, double ratio)
+{
+    size_t free{};
+    size_t total{};
+    check_cuda_error(cudaMemGetInfo(&free, &total));
+    return static_cast<size_t>(total * ratio) / block_size;
+}
+
+void BlockManager::Move(std::vector<int>& src, const std::vector<int>& delta, std::vector<int>& dst)
+{
+    FT_CHECK(src.size() >= delta.size());
+    std::vector<int> src1(src.size() - delta.size());
+    {
+        auto end = std::set_difference(src.begin(), src.end(), delta.begin(), delta.end(), src1.begin());
+        FT_CHECK(end == src1.end());
+    }
+    src.swap(src1);
+
+    std::vector<int> dst1(dst.size() + delta.size());
+    {
+        auto end = std::set_union(dst.begin(), dst.end(), delta.begin(), delta.end(), dst1.begin());
+        FT_CHECK(end == dst1.end());
+    }
+    dst.swap(dst1);
+}
+
+auto BlockManager::Allocate(int count) -> std::pair<BlockIds, UniqueIds>
+{
+    while (free_ids_.size() < count) {
+        if (!Malloc()) {
+            throw std::runtime_error("out of memory");
+        }
+    }
+
+    BlockIds  block_ids(count);
+    UniqueIds unique_ids(count);
+
+    for (int i = 0; i < count; ++i) {
+        int   idx = free_ids_[i];
+        auto& b   = blocks_[idx];
+        FT_CHECK(is_free(b));  // pre-condition: uc == 0 && ts == 0
+        b.use_count = 1;
+        b.unique_id = unique_id_++;
+        FT_CHECK(is_active(b));  // post-condition
+        block_ids[i]  = idx;
+        unique_ids[i] = b.unique_id;
+    }
+
+    Move(free_ids_, block_ids, active_ids_);
+
+    dbg(free_ids_, active_ids_);
+
+    return {block_ids, unique_ids};
+}
+
+void BlockManager::Evict(int count)
+{
+    FT_CHECK(count <= cached_ids_.size());
+    std::vector<int> idxs(cached_ids_);
+    // get first `count` cached ids according to timestamp
+    std::nth_element(idxs.begin(), idxs.begin() + count, idxs.end(), [&](int i, int j) {
+        return blocks_[i].timestamp < blocks_[j].timestamp;
+    });
+    idxs.resize(count);
+
+    // sort the retrieved ids
+    std::sort(idxs.begin(), idxs.end());
+
+    // set as free
+    for (const auto& idx : idxs) {
+        auto& b = blocks_[idx];
+        FT_CHECK(is_cached(b));
+        b.unique_id = 0;
+        b.timestamp = 0;
+        FT_CHECK(is_free(b));
+    }
+
+    Move(cached_ids_, idxs, free_ids_);
+
+    dbg(cached_ids_, free_ids_);
+}
+
+void BlockManager::Free(BlockIds ids)
+{
+    std::sort(ids.begin(), ids.end());
+
+    for (const auto& i : ids) {
+        auto& b = blocks_[i];
+        FT_CHECK(is_cached(b));  // uc == 0 && ts != 0
+        b.unique_id = 0;
+        b.timestamp = 0;
+        FT_CHECK(is_free(b));
+    }
+
+    Move(cached_ids_, ids, free_ids_);
+}
+
+int BlockManager::Unlock(const BlockIds& ids)
+{
+    BlockIds unlock;
+    unlock.reserve(ids.size());
+
+    for (const auto& i : ids) {
+        auto& b = blocks_[i];
+        FT_CHECK(is_active(b));  // pre-condition: uc > 0
+        if (--b.use_count == 0) {
+            unlock.push_back(b.id);
+            FT_CHECK(is_cached(b));  // post-condition
+        }
+    }
+
+    std::sort(unlock.begin(), unlock.end());
+
+    Move(active_ids_, unlock, cached_ids_);
+
+    dbg(active_ids_, cached_ids_);
+    return unlock.size();
+}
+
+int BlockManager::Lock(const BlockIds& ids)
+{
+    BlockIds lock;
+    lock.reserve(ids.size());
+
+    for (const auto& i : ids) {
+        auto& b = blocks_[i];
+        FT_CHECK(is_cached(b));
+        if (++b.use_count == 1) {
+            lock.push_back(i);
+            FT_CHECK(is_active(b));
+        }
+    }
+
+    std::sort(lock.begin(), lock.end());
+
+    Move(cached_ids_, lock, active_ids_);
+
+    // dbg(cached_ids_, active_ids_);
+
+    return lock.size();
+}
+
+void BlockManager::Touch(const BlockIds& ids)
+{
+    std::for_each(ids.crbegin(), ids.crend(), [this](int i) {
+        FT_CHECK(is_active(blocks_[i]));
+        blocks_[i].timestamp = timestamp_++;
+    });
+}
+
+int BlockManager::Verify(const std::vector<int>& block_ids, const std::vector<uint64_t>& unique_ids)
+{
+    FT_CHECK(block_ids.size() == unique_ids.size());
+    int valid = block_ids.size();
+    for (int i = 0; i < block_ids.size(); ++i) {
+        if (unique_id(block_ids[i]) != unique_ids[i]) {
+            valid = i;
+            break;
+        }
+    }
+    int miss = 0;
+    for (int i = valid; i < block_ids.size(); ++i) {
+        miss += (unique_id(block_ids[i]) != unique_ids[i]);
+    }
+    // All later blocks should have been invalidated
+    FT_CHECK_WITH_INFO(miss == (int)block_ids.size() - valid,
+                       fmtstr("count = %d, valid = %d, miss = %d", (int)block_ids.size(), valid, miss));
+    return valid;
+}
+
+Snapshot BlockManager::TakeSnapshot()
+{
+    std::vector<int> use_count(blocks_.size());
+    for (const auto& idx : active_ids_) {
+        use_count[idx] = blocks_[idx].use_count;
+    }
+    return {active_count(), cached_count(), free_count(), std::move(use_count)};
+}
+
+std::ostream& operator<<(std::ostream& os, const BlockManager& manager)
+{
+    os << "block_size: " << manager.block_size_ << ", ";
+    os << "max_block_count: " << manager.max_block_count_ << ", ";
+    os << "chunk_size: " << manager.chunk_size_ << ", ";
+    os << "chunks: " << manager.chunks_.size() << ", ";
+    os << "active_ids: " << manager.active_ids_.size() << ", ";
+    os << "cached_ids: " << manager.cached_ids_.size() << ", ";
+    os << "free_ids: " << manager.free_ids_.size() << ", ";
+    os << "blocks: " << manager.blocks_.size() << ", ";
+    os << "unique_id: " << manager.unique_id_ << ", ";
+    os << "timestamp: " << manager.timestamp_ << ", ";
+    os << "allocator: " << manager.allocator_;
+    return os;
+}
+
+std::ostream& operator<<(std::ostream& os, const Block& block)
+{
+    os << "id=" << block.id << ", use_count=" << block.use_count << ", unique_id=" << block.unique_id
+       << ", timestamp=" << block.timestamp << ", data=" << block.data;
+    return os;
+}
+
+}  // namespace turbomind
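BlockManager above keeps every block in exactly one of three index sets, with `use_count` and `timestamp` encoding the state: free -> active via Allocate/Lock, active -> cached via Unlock, and cached -> free via Evict or Free, where Evict picks the oldest timestamps first. The following is a simplified standalone model of that state machine (no CUDA, no allocator), meant only as a sketch of the invariants the FT_CHECKs above assert, not as the real class.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

struct MiniBlock {
    int      use_count{0};  // > 0 -> active
    uint64_t timestamp{0};  // == 0 with use_count == 0 -> free, otherwise cached
};

int main()
{
    std::vector<MiniBlock> blocks(4);  // everything starts free
    uint64_t               clock = 1;

    // "Allocate": free -> active
    blocks[0].use_count = 1;
    blocks[1].use_count = 1;

    // "Touch": stamp active blocks so eviction can order them later
    blocks[0].timestamp = clock++;
    blocks[1].timestamp = clock++;

    // "Unlock": active -> cached (use_count drops to 0, timestamp kept)
    blocks[0].use_count--;
    blocks[1].use_count--;

    // "Evict" the least recently used cached block: cached -> free
    auto victim = std::min_element(blocks.begin(), blocks.begin() + 2,
                                   [](const MiniBlock& a, const MiniBlock& b) { return a.timestamp < b.timestamp; });
    victim->timestamp = 0;  // now free again

    for (std::size_t i = 0; i < blocks.size(); ++i) {
        const auto& b    = blocks[i];
        const char* what = b.use_count > 0 ? "active" : (b.timestamp != 0 ? "cached" : "free");
        std::printf("block %zu: %s\n", i, what);
    }
    assert(blocks[0].timestamp == 0);  // block 0 carried the older timestamp, so it was evicted
}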
diff --git a/src/turbomind/models/llama/BlockManager.h b/src/turbomind/models/llama/BlockManager.h
new file mode 100644
index 0000000000..c9ec2d06dc
--- /dev/null
+++ b/src/turbomind/models/llama/BlockManager.h
@@ -0,0 +1,153 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+
+#include "src/turbomind/utils/allocator.h"
+#include "src/turbomind/utils/cuda_utils.h"
+#include "src/turbomind/utils/logger.h"
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace turbomind {
+
+// [L, H, S, D]
+
+// [L, S/x, H, x, D]
+
+struct Block {
+    int      id;         // fixed linear id in the pool
+    int      use_count;  // active sequences using the block
+    uint64_t unique_id;  // unique for every block allocation
+    uint64_t timestamp;
+    void*    data;
+
+    friend std::ostream& operator<<(std::ostream& os, const Block& block);
+    friend std::string   to_string(const Block& b)
+    {
+        std::stringstream ss;
+        ss << b;
+        return ss.str();
+    }
+};
+
+using BlockIds  = std::vector<int>;
+using UniqueIds = std::vector<uint64_t>;
+
+inline bool is_active(const Block& block)
+{
+    // timestamp may be 0 for a newly allocated block that has not been written yet
+    return block.use_count > 0;
+}
+
+inline bool is_cached(const Block& block)
+{
+    return block.use_count == 0 && block.timestamp != 0;
+}
+
+inline bool is_free(const Block& block)
+{
+    return block.use_count == 0 && block.timestamp == 0;
+}
+
+struct Snapshot {
+    int              active;
+    int              cached;
+    int              free;
+    std::vector<int> use_count;
+};
+
+class BlockManager {
+public:
+    explicit BlockManager(size_t block_size, double block_count, int chunk_size, IAllocator* allocator);
+
+    ~BlockManager();
+
+    // free -> active (use_count = 1, ref_count = 1)
+    [[nodiscard]] std::pair<BlockIds, UniqueIds> Allocate(int count);
+
+    // cached -> active (use_count += 1)
+    [[maybe_unused]] int Lock(const BlockIds& ids);
+
+    // active -> cached (use_count -= 1)
+    [[maybe_unused]] int Unlock(const BlockIds& ids);
+
+    // cached -> free (ref_count = 0)
+    void Evict(int count);
+
+    // cached -> free (ref_count -= 1)
+    void Free(BlockIds bs);
+
+    // assign increasing timestamps, visiting the ids in reverse order
+    void Touch(const BlockIds& bs);
+
+    [[nodiscard]] int Verify(const BlockIds& block_ids, const UniqueIds& unique_ids);
+
+    Snapshot TakeSnapshot();
+
+    int max_block_count() const noexcept
+    {
+        return max_block_count_;
+    }
+
+    int active_count() const noexcept
+    {
+        return active_ids_.size();
+    }
+
+    int cached_count() const noexcept
+    {
+        return cached_ids_.size();
+    }
+
+    int free_count() const noexcept
+    {
+        return (max_block_count_ - blocks_.size()) + free_ids_.size();
+    }
+
+    Block& block(int idx)
+    {
+        return blocks_[idx];
+    }
+
+    uint64_t unique_id(int idx)
+    {
+        return blocks_[idx].unique_id;
+    }
+
+    friend std::ostream& operator<<(std::ostream& os, const BlockManager&);
+
+private:
+    static size_t GetBlockCount(size_t block_size, double ratio);
+
+    // move indices between sets
+    static void Move(BlockIds& src, const BlockIds& delta, BlockIds& dst);
+
+    // allocate a chunk of blocks
+    bool Malloc();
+
+private:
+    size_t      block_size_;
+    int         max_block_count_{};
+    int         chunk_size_{};
+    IAllocator* allocator_;
+
+    std::vector<void*> chunks_;
+
+    BlockIds active_ids_;
+    BlockIds cached_ids_;
+    BlockIds free_ids_;
+
+    std::vector<Block> blocks_;  // < 100k
+
+    uint64_t unique_id_{1};
+    uint64_t timestamp_{1};
+};
+
+}  // namespace turbomind
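In the constructor above, a `block_count` below 1.0 is interpreted as a fraction of total device memory (queried via cudaMemGetInfo), and the default `chunk_size` is the integer square root of the resulting block count. A small worked example of that arithmetic; the 80 GB device and 64 MB block size are assumptions for illustration only.

#include <cmath>
#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t total_bytes = 80ull << 30;  // assumed 80 GB device (cudaMemGetInfo in the real code)
    const std::size_t block_size  = 64ull << 20;  // assumed 64 MB per KV block
    const double      ratio       = 0.5;          // block_count passed as 0.5

    const std::size_t max_block_count = static_cast<std::size_t>(total_bytes * ratio) / block_size;
    const int         chunk_size      = static_cast<int>(std::sqrt(max_block_count));

    std::printf("max_block_count = %zu\n", max_block_count);  // 640 with the numbers above
    std::printf("default chunk_size = %d\n", chunk_size);     // 25
}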
diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt
index 10b93fb9ec..6274625457 100644
--- a/src/turbomind/models/llama/CMakeLists.txt
+++ b/src/turbomind/models/llama/CMakeLists.txt
@@ -9,14 +9,13 @@ find_package(CUDAToolkit REQUIRED)
 add_library(Llama STATIC
         LlamaV2.cc
         LlamaBatch.cc
-        LlamaCacheManager.cc
-        LlamaContextDecoder.cc
-        LlamaContextAttentionLayer.cc
-        LlamaDecoderSelfAttentionLayer.cc
-        LlamaDecoder.cc
+        BlockManager.cc
+        SequenceManager.cc
         LlamaWeight.cc
         LlamaDecoderLayerWeight.cc
         LlamaFfnLayer.cc
+        unified_decoder.cc
+        unified_attention_layer.cc
         llama_kernels.cu
         llama_decoder_kernels.cu
         llama_utils.cu)
@@ -28,6 +27,7 @@ target_link_libraries(Llama PUBLIC CUDA::cudart
         DynamicDecodeLayer
         activation_kernels
         decoder_masked_multihead_attention
+        decoder_multihead_attention
         bert_preprocess_kernels
         decoding_kernels
         unfused_attention_kernels
@@ -48,4 +48,11 @@ endif()
 
 add_executable(llama_gemm llama_gemm.cc)
 target_link_libraries(llama_gemm PUBLIC CUDA::cudart gpt_gemm_func memory_utils cuda_utils logger)
+
 install(TARGETS llama_gemm DESTINATION ${CMAKE_SOURCE_DIR}/lmdeploy/bin)
+
+find_package(Catch2 3 QUIET)
+if (Catch2_FOUND)
+        add_executable(test_cache_manager test_cache_manager.cc)
+        target_link_libraries(test_cache_manager PRIVATE Llama Catch2::Catch2WithMain)
+endif ()
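The optional Catch2 target above links test_cache_manager.cc against the Llama library. As a hedged illustration of what such a test can express with Catch2 v3, the sketch below exercises only the `is_free`/`is_active`/`is_cached` predicates declared in BlockManager.h; it assumes the project's include paths (and the header's CUDA dependencies) are available to the test target, and the test name is illustrative.

#include <catch2/catch_test_macros.hpp>

#include "src/turbomind/models/llama/BlockManager.h"

TEST_CASE("block state predicates are mutually exclusive", "[BlockManager]")
{
    turbomind::Block block{};
    block.use_count = 0;
    block.timestamp = 0;
    REQUIRE(turbomind::is_free(block));

    block.use_count = 1;  // allocated / locked
    REQUIRE(turbomind::is_active(block));
    REQUIRE_FALSE(turbomind::is_cached(block));

    block.use_count = 0;  // unlocked after being written to
    block.timestamp = 42;
    REQUIRE(turbomind::is_cached(block));
    REQUIRE_FALSE(turbomind::is_free(block));
}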
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 5d8d7d0411..8ca8917fa5 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -2,65 +2,118 @@
 
 #include "src/turbomind/models/llama/LlamaBatch.h"
 #include "src/turbomind/kernels/decoding_kernels.h"
+#include "src/turbomind/kernels/sampling_topk_kernels.h"
 #include "src/turbomind/macro.h"
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
 #include "src/turbomind/models/llama/LlamaV2.h"
 #include "src/turbomind/models/llama/Request.h"
+#include "src/turbomind/models/llama/SequenceManager.h"
+#include "src/turbomind/models/llama/copy.h"
+#include "src/turbomind/models/llama/llama_kernels.h"
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/Tensor.h"
+#include "src/turbomind/utils/cuda_utils.h"
+#include "src/turbomind/utils/debug_utils.h"
+#include "src/turbomind/utils/gemm_test/gemm_func.h"
 #include "src/turbomind/utils/logger.h"
+#include 
+#include 
+#include 
 #include 
+#include 
 #include 
+#include 
+#include 
+#include 
 #include 
 #include 
+#include 
 
 namespace turbomind {
 
+void PrintDecodeTokens(
+    const int* token_ids, int max_seq_len, int batch_size, cudaStream_t stream, const std::string& msg)
+{
+    // tokens in [S, B] layout
+    std::vector<int> tokens(max_seq_len * batch_size);
+    check_cuda_error(cudaMemcpyAsync(tokens.data(), token_ids, sizeof(int) * tokens.size(), cudaMemcpyDefault, stream));
+    check_cuda_error(cudaStreamSynchronize(stream));
+
+    printf("[%s] ", msg.c_str());
+    for (int j = 0; j < max_seq_len; ++j) {
+        printf("%5d ", j);
+    }
+    printf("\n");
+    for (int i = 0; i < batch_size; ++i) {
+        printf("[%s] ", msg.c_str());
+        for (int j = 0; j < max_seq_len; ++j) {
+            printf("%5d ", tokens[j * batch_size + i]);
+        }
+        }
+        printf("\n");
+    }
+}
+void ClearState(BatchState& s)
+{
+    std::fill_n(s.requests.begin(), s.size, nullptr);
+    std::fill_n(s.sequences.begin(), s.size, nullptr);
+    s.size = s.active_size = 0;
+}
+
 template<typename T>
-void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_reqs,
-                                   std::vector<std::shared_ptr<Request>>& infer_reqs)
+void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_reqs)
 {
     std::unordered_map occurrence;
 
-    auto count_occurrence = [&occurrence](const std::vector>& rs) {
+    auto count_occurrence = [&occurrence](const Requests& rs) {
         for (const auto& r : rs) {
             ++occurrence[r->id];
         }
     };
 
-    auto invalidate = [](const char* type, std::shared_ptr& req, int ec) {
-        TM_LOG_WARNING("[verifyRequests] Skipping invalid %s request for id %ld, code = %d", type, (long)req->id, ec);
-        // We don't need a barrier there because
-        // this lambda is called only for new requests
-        // which are visible only for rank = 0 thread.
+    auto reject = [](const char* type, std::shared_ptr<Request>& req, int ec) {
+        TM_LOG_WARNING(
+            "[RejectInvalidRequests] Skipping invalid %s request for id %ld, code = %d", type, (long)req->id, ec);
         req->signal.set_value(ec);
         req.reset();
     };
 
-    auto handle_conflict_or_invalid = [this, &occurrence, &invalidate](std::vector>& rs,
-                                                                       const char*                            type) {
+    auto handle_conflict_or_invalid = [this, &occurrence, &reject](Requests& rs, const char* type) {
         for (auto& r : rs) {
             if (r) {
                 int ec = 0;
 
+                const int  input_length = r->inputs[rank_].getVal<int>("input_lengths", 0);
+                const auto get_offset   = [&](int token_count) {
+                    return std::max(0, std::min(token_count, r->inputs[rank_].getVal<int>("step", token_count)));
+                };
+
                 if (occurrence[r->id] != 1) {
                     ec = Request::kConflict;
                 }
                 else if (r->start_flag && r->stop_flag) {
                     ec = Request::kInvalid;
                 }
-                else if (!r->start_flag && !llama_->kv_cache_mgr_->contains(r->id)) {
-                    ec = Request::kInvalid;
+                else if (input_length > session_len_) {
+                    ec = Request::kTooLong;
+                }
+                else if (!r->start_flag) {
+                    if (auto seq = sequence_manager_->Get(r->id); seq == nullptr) {
+                        ec = Request::kInvalid;
+                    }
+                    else if (get_offset(seq->tokens.size()) + input_length > session_len_) {
+                        ec = Request::kTooLong;
+                    }
                 }
 
                 if (ec) {
-                    invalidate(type, r, ec);
+                    reject(type, r, ec);
                 }
             }
         }
     };
 
-    auto drop_invalid = [](std::vector>& rs) {
+    auto drop_invalid = [](Requests& rs) {
         int count = 0;
         for (int i = 0; i < rs.size(); ++i) {
             if (rs[i]) {
@@ -80,14 +133,14 @@ void LlamaBatch::verifyRequests(std::vector>& stop_r
         for (auto& r : stop_reqs) {
             if (r && r->end_flag == false) {
                 int ec = Request::kInactive;
-                for (int i = 0; i < batch_size_; ++i) {
-                    if (requests_[i] && requests_[i]->id == r->id) {
+                for (int i = 0; i < state_->size; ++i) {
+                    if (state_->requests[i] && state_->requests[i]->id == r->id) {
                         ec = 0;
                         break;
                     }
                 }
                 if (ec) {
-                    invalidate("stop", r, ec);
+                    reject("stop", r, ec);
                 }
             }
         }
@@ -101,9 +154,9 @@ void LlamaBatch::verifyRequests(std::vector>& stop_r
         // invalidate requests for busy sequences
         for (auto& r : infer_reqs) {
             if (r) {
-                for (int i = 0; i < batch_size_; ++i) {
-                    if (requests_[i] && requests_[i]->id == r->id) {
-                        invalidate("infer", r, Request::kBusy);
+                for (int i = 0; i < state_->size; ++i) {
+                    if (state_->requests[i] && state_->requests[i]->id == r->id) {
+                        reject("infer", r, Request::kBusy);
                         break;
                     }
                 }
@@ -115,53 +168,485 @@ void LlamaBatch::verifyRequests(std::vector>& stop_r
 }
 
 template<typename T>
-void LlamaBatch<T>::handleStopRequests(const std::vector<std::shared_ptr<Request>>& requests)
+auto LlamaBatch<T>::ProcessStopRequests(const Requests& requests) -> std::vector<Signal>
 {
+    NvtxScope           scope("stop_request");
+    std::vector<Signal> signals;
+    int                 count = 0;
     for (const auto& r : requests) {
         int ec = Request::kFail;
         // find matching active sequence
-        for (int i = 0; i < batch_size_; ++i) {
+        for (int i = 0; i < state_->size; ++i) {
             // stop & optionally erase active sequence
-            if (requests_[i] && requests_[i]->id == r->id) {
+            if (state_->requests[i] && state_->requests[i]->id == r->id) {
                 ec = 0;
-                finishRequest(i, r->end_flag);
+                signals.push_back(Interrupt(i, true, r->end_flag));
+                ++count;
                 break;
             }
         }
-        // mismatch, try erase inactive sequence
+        // mismatch; try to erase an inactive sequence (there is no active request to interrupt in this case)
         if (ec && r->end_flag) {
-            ec = 0;
-            llama_->kv_cache_mgr_->erase(r->id);
+            if (sequence_manager_->Erase(r->id)) {
+                ec = 0;
+            }
         }
-        // clear output buffers (prevent leaking conversations) if request is successful
-        if (ec == 0) {
-            auto& output_ids      = r->outputs[rank_].at("output_ids");
-            auto& sequence_length = r->outputs[rank_].at("sequence_length");
-            check_cuda_error(
-                cudaMemsetAsync(output_ids.getPtr(), 0, sizeof(int) * output_ids.shape.at(2), stream_));
-            check_cuda_error(cudaMemsetAsync(sequence_length.getPtr(), 0, sizeof(int), stream_));
-            check_cuda_error(cudaStreamSynchronize(stream_));
+        signals.push_back([=] {
+            if (rank_ == 0) {
+                r->signal.set_value(ec);
+            }
+        });
+    }
+    if (count) {
+        check_cuda_error(cudaStreamSynchronize(stream_));
+    }
+    return signals;
+}
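ProcessStopRequests above no longer fulfils request promises inline; it returns deferred callbacks that the caller fires after the stream has been synchronized, and only rank 0 sets the promise value. A minimal standalone sketch of that deferred-signal pattern, with std::promise standing in for Request::signal and the Signal alias assumed to be a plain std::function<void()>:

#include <cstdio>
#include <functional>
#include <future>
#include <vector>

using Signal = std::function<void()>;

int main()
{
    const int rank = 0;

    std::promise<int>   signal;  // stands in for Request::signal
    std::future<int>    result = signal.get_future();
    std::vector<Signal> signals;

    // Collect the outcome now, deliver it later.
    int ec = 0;
    signals.push_back([&, ec] {
        if (rank == 0) {
            signal.set_value(ec);
        }
    });

    // ... synchronize the stream / finish batch bookkeeping here ...

    for (auto& s : signals) {
        s();  // now it is safe to unblock the waiting client thread
    }
    std::printf("request finished with code %d\n", result.get());
}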
+
+template<typename T>
+void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
+{
+    NvtxScope scope("infer_request");
+    auto&     state = *incoming_;
+
+    FT_CHECK(state.size == 0);
+    FT_CHECK(state.active_size == 0);
+
+    std::vector<int> existing_idx;
+
+    int idx = 0;
+    for (const auto& r : requests) {
+        FT_CHECK(!state.requests[idx]);
+
+        if (rank_ == 0) {
+            TM_LOG_WARNING("[ProcessInferRequests] Request for %ld received.", (long)r->id);
+        }
+
+        state.requests[idx] = r;
+
+        // get sequence for the request
+        state.sequences[idx] = r->start_flag ? sequence_manager_->Create(r->id) : sequence_manager_->Get(r->id);
+        FT_CHECK(state.sequences[idx]);
+
+        auto& seq = *state.sequences[idx];
+
+        if (int step = r->inputs[rank_].getVal("step", -1); step >= 0) {
+            if (step <= seq.tokens.size()) {
+                seq.tokens.resize(step);
+                seq.cache_len = std::min(seq.cache_len, step);
+            }
+            else if (rank_ == 0) {
+                TM_LOG_WARNING(
+                    "[ProcessInferRequests] Skipping invalid step (%d) setting for ID %ld", step, (long)seq.id);
+            }
         }
 
-        // When the signal is set threads from LlamaV2::forward can exit
-        // and free inputs/outputs tensors.
-        // Therefore we need to make sure that no threads from LlamaV2::internalThreadEntry
-        // are accessing the tensors.
-        llama_->shared_state_->barrier->wait();
+        const int  input_length = r->inputs[rank_].getVal<int>("input_lengths");
+        const int* input_ids    = r->inputs[rank_].getPtr<int>("input_ids");
+
+        // `output_ids` contains all token ids of the sequences
+        const auto output_ids_base = state.output_ids + session_len_ * idx;
+        auto       output_ids      = output_ids_base;
+
+        // copy history tokens
+        if (!seq.tokens.empty()) {
+            output_ids = Copy(seq.tokens.data(), seq.tokens.size(), output_ids);
+        }
+
+        // copy input tokens
+        if (input_length) {
+            output_ids = Copy(input_ids, input_length, output_ids);
+        }
+
+        // total context length (history + input)
+        state.h_context_length[idx] = output_ids - output_ids_base;
+        state.h_finished[idx]       = false;
+
+        const int request_output_len = state.requests[idx]->inputs[rank_].getVal("request_output_len");
+        state.seq_len_limit[idx]     = state.h_context_length[idx] + request_output_len;
+        // `length_criterion` sets finish flag when step >= seq_limit_len, however when step == seq_limit_len
+        // the actual sequence length is seq_limit_len + 1, hence seq_limit_len must be truncated to session_len - 1
+        if (state.seq_len_limit[idx] >= session_len_) {
+            state.seq_len_limit[idx] = session_len_ - 1;
+            if (rank_ == 0) {
+                const int trunc_output_len = state.seq_len_limit[idx] - state.h_context_length[idx];
+                TM_LOG_WARNING(
+                    "[ProcessInferRequests] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `request_output_len` is truncated to %d",
+                    (long)seq.id,
+                    state.h_context_length[idx],
+                    request_output_len,
+                    (int)session_len_,
+                    trunc_output_len);
+            }
+        }
+
+        // compute rope scaling factor
+        if (r->start_flag) {
+            seq.rope_theta      = model_->attn_params_.rotary_embedding_base;
+            auto scaling_factor = 1.f;
+            if (r->inputs[rank_].isExist("rope_scaling_factor")) {  // runtime scaling factor
+                scaling_factor = r->inputs[rank_].getVal("rope_scaling_factor");
+            }
+            else if (model_->attn_params_.rope_scaling_factor >= 1.f) {  // infer by `seq_len_limit`
+                scaling_factor   = model_->attn_params_.rope_scaling_factor;
+                auto max_seq_len = state.seq_len_limit[idx];
+                auto max_pos_emb = model_->attn_params_.max_position_embeddings;
+                if (max_seq_len > max_pos_emb) {
+                    scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1);
+                    // scaling_factor = std::max(exp2f(ceilf(log2f((float)max_seq_len / max_pos_emb) + 1.f))
+                    // - 1.f, 1.f);
+                }
+            }
+            if (scaling_factor != 1.f) {
+                float rope_dim = model_->attn_params_.rotary_embedding_dim;
+                seq.rope_theta *= powf(scaling_factor, rope_dim / (rope_dim - 2.f));
+                TM_LOG_INFO("[ProcessInferRequests] %ld rope_scaling_factor: %f, rope_theta = %f",
+                            (long)seq.id,
+                            scaling_factor,
+                            seq.rope_theta);
+            }
+        }
+        state.h_rope_theta[idx] = seq.rope_theta;
+
+        if (r->start_flag) {
+            // prepare to initialize random state for new sequence
+            h_random_seed_[idx] = r->inputs[rank_].getVal("random_seed", 0);
+        }
+        else {
+            // Recover device states if not a new sequence
+            h_curand_state_[existing_idx.size()] = *(curandState_t*)seq.random_state.data();
+            existing_idx.push_back(idx);
+        }
+
+        // ! SHARED STATE IS MODIFIED, BARRIER SYNCHRONIZATION REQUIRED
+        // assign priority based on arrival time
         if (rank_ == 0) {
-            r->signal.set_value(ec);
+            r->unique_id = request_count_++;
+        }
+
+        // advance to the next slot in the incoming batch state
+        idx++;
+    }
+
+    state.size = idx;
+
+    // when there are new sequences
+    if (state.size != existing_idx.size()) {
+        // copy random seeds to device
+        Copy(h_random_seed_, state.size, d_random_seed_);
+        // initialize random states
+        invokeCurandBatchInitialize(state.curand_state, state.size, d_random_seed_, stream_);
+        sync_check_cuda_error();
+    }
+
+    if (!existing_idx.empty()) {
+        // copy existing curand states to device
+        Copy(h_curand_state_, existing_idx.size(), d_curand_state_);
+        // insert the states to their correct positions in the batch
+        IndexedCopy({}, existing_idx, std::tuple{d_curand_state_, state.curand_state, 1});
+    }
+}
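When the requested sequence length exceeds `max_position_embeddings`, ProcessInferRequests above recomputes the scaling factor and multiplies `rope_theta` by `scaling_factor^(d/(d-2))`, where d is the rotary embedding dimension. A worked example of the same formula; the base 10000, head dimension 128, 2048 positions, and 4096-token request are assumed values for illustration only.

#include <cmath>
#include <cstdio>

int main()
{
    // Assumed parameters for illustration only.
    float       rope_theta     = 10000.f;  // rotary_embedding_base
    const float rope_dim       = 128.f;    // rotary_embedding_dim
    const int   max_pos_emb    = 2048;     // max_position_embeddings
    const int   max_seq_len    = 4096;     // seq_len_limit of the request
    float       scaling_factor = 2.f;      // rope_scaling_factor from the model config

    if (max_seq_len > max_pos_emb) {
        // same adjustment as in ProcessInferRequests above
        scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1.f);
    }
    if (scaling_factor != 1.f) {
        rope_theta *= powf(scaling_factor, rope_dim / (rope_dim - 2.f));
    }
    std::printf("scaling_factor = %g, rope_theta = %g\n", scaling_factor, rope_theta);
}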
+
+template<typename T>
+void LlamaBatch<T>::AdjustMaxInputCount(GenerationState&                    g,
+                                        const std::vector<const Sequence*>& sequences,
+                                        const std::vector<int>&             context_length)
+{
+    int input_count = 0;
+    for (int i = 0; i < sequences.size(); ++i) {
+        input_count += context_length[i] - sequences[i]->cache_len;
+    }
+    const int batch_size = sequences.size();
+    input_count -= batch_size;
+
+    // min tokens per iter for satisfying max prefill iters constraint
+    input_count = (input_count + max_prefill_iters_ - 1) / max_prefill_iters_;
+
+    if (g.min_input_count.empty()) {
+        g.min_input_count.resize(max_prefill_iters_);
+    }
+    g.min_input_count.pop_front();
+    g.min_input_count.push_back(input_count);
+    /// TODO: sub-optimal when there are inactive sequences due to memory constraint
+    for (auto& x : g.min_input_count) {
+        x = std::max(x, input_count);
+    }
+
+    input_count = std::max(g.min_input_count.front() + batch_size, num_tokens_per_iter_);
+    input_count = std::min(input_count, max_context_token_num_);
+    // update max input count
+    g.max_input_count1 = input_count;
+    g.max_input_count2 = std::min(input_count + extra_tokens_per_iter_, max_context_token_num_);
+}
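AdjustMaxInputCount above spreads the pending prefill tokens over at most `max_prefill_iters` iterations by keeping a sliding window of per-iteration minimums, then clamps the per-iteration budget between `num_tokens_per_iter` and `max_context_token_num`. A worked sketch of the same arithmetic with assumed engine parameters; the std::deque mirrors `g.min_input_count`.

#include <algorithm>
#include <cstdio>
#include <deque>

int main()
{
    // Assumed engine parameters for illustration.
    const int max_prefill_iters     = 4;
    const int num_tokens_per_iter   = 256;
    const int max_context_token_num = 4096;

    // Assumed pending work: 3 sequences with 3000 uncached prompt tokens in total.
    const int batch_size  = 3;
    int       input_count = 3000 - batch_size;

    // Minimum tokens per iteration so the prefill finishes within max_prefill_iters.
    input_count = (input_count + max_prefill_iters - 1) / max_prefill_iters;

    std::deque<int> min_input_count(max_prefill_iters, 0);  // mirrors g.min_input_count
    min_input_count.pop_front();
    min_input_count.push_back(input_count);
    for (auto& x : min_input_count) {
        x = std::max(x, input_count);
    }

    int budget = std::max(min_input_count.front() + batch_size, num_tokens_per_iter);
    budget     = std::min(budget, max_context_token_num);

    std::printf("per-iteration prefill budget = %d tokens\n", budget);  // 753 with the numbers above
}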
+
+template<typename T>
+void LlamaBatch<T>::Initialize(GenerationState& g)
+{
+    NvtxScope                                scope("initialize");
+    std::vector<const Sequence*>             sequences;
+    std::vector<Sequence::Status>            status;
+    std::vector<uint64_t>                    priorities;
+    std::vector<int>                         context_lengths;
+    std::vector<std::pair<BatchState*, int>> coords;
+
+    // count the holes introduced by requests that finished in the previous iteration or were stopped in the
+    // current iteration
+    int holes{};
+    int active_holes{};
+    for (int i = 0; i < state_->size; ++i) {
+        if (!state_->requests[i]) {
+            ++holes;
+            if (i < state_->active_size) {
+                ++active_holes;
+            }
+        }
+    }
+
+    auto process = [&](BatchState* state) {
+        for (int i = 0; i < state->size; ++i) {
+            if (auto& r = state->requests[i]) {
+                sequences.push_back(state->sequences[i]);
+                status.push_back(state->sequences[i]->status);
+                priorities.push_back(r->unique_id);
+                context_lengths.push_back(state->h_context_length[i]);
+                coords.emplace_back(state, i);
+            }
+        }
+    };
+
+    process(state_);
+    process(incoming_);
+
+    auto adjust = [this, &g](const Sequences&        sequences,
+                             const std::vector<int>& context_length) -> std::pair<int, int> {
+        AdjustMaxInputCount(g, sequences, context_length);
+        return {g.max_input_count1, g.max_input_count2};
+    };
+
+    // TM_LOG_INFO("max_input_count %d", max_input_count);
+    auto outcome = sequence_manager_->Materialize(sequences, context_lengths, priorities, step_length_, adjust);
+
+    if (outcome.allocation || outcome.swap_in || outcome.swap_out) {
+        dbg(outcome);
+    }
+
+    bool exchange = outcome.swap_in + outcome.swap_out > 0;
+
+    std::vector<int> idxs(sequences.size());
+    std::iota(idxs.begin(), idxs.end(), 0);
+
+    if (exchange || holes || incoming_->size) {
+        // put active ones first
+        auto active_end = std::stable_partition(idxs.begin(), idxs.end(), [&](int idx) {
+            return sequences[idx]->status == Sequence::kActive;  // current status
+        });
+
+        // even all the blocks together are not enough to hold a single sequence
+        if (!sequences.empty()) {
+            FT_CHECK_WITH_INFO(active_end != idxs.begin(), "Not enough blocks.");
+        }
+
+        // move the partial seq to the back
+        auto partial_beg = std::stable_partition(idxs.begin(), active_end, [&](int i) {
+            return sequences[i]->cache_len + sequences[i]->input_length == context_lengths[i];
+        });
+        FT_CHECK(active_end - partial_beg <= 1);
+
+        auto swapin_beg = std::stable_partition(idxs.begin(), partial_beg, [&](int i) {
+            return status[i] == Sequence::kActive;  // past status
+        });
+
+        // sort swap-ins according to input length
+        if (swapin_beg != partial_beg) {
+            std::stable_sort(swapin_beg, partial_beg, [&](int i, int j) {
+                return sequences[i]->input_length < sequences[j]->input_length;
+            });
+        }
+
+        // Copy sequence states to back buffer
+        FT_CHECK(back_->size == 0 && back_->active_size == 0);
+        std::vector<std::tuple<BatchState*, BatchState*, int, int>> cpys;
+        for (const auto& i : idxs) {
+            auto& s = *sequences[i];
+            if (s.status == Sequence::kActive) {
+                ++back_->active_size;
+            }
+            cpys.emplace_back(coords[i].first, back_, coords[i].second, back_->size++);
+        }
+        CopyState(cpys);
+        // Swap the buffers
+        std::swap(state_, back_);
+
+        ClearState(*back_);
+        ClearState(*incoming_);
+    }
+
+    FT_CHECK(state_->size <= max_batch_size_);
+
+    /// Update block ptrs when there were
+    //  1. swap-in or swap-out
+    //  2. holes in the active buffer
+    //  3. new allocations (for existing active sequences)
+    if (exchange || active_holes || outcome.allocation) {
+        // Prepare intermediate buffers
+        h_cu_block_counts_[0] = 0;
+
+        auto k_ptrs = h_k_block_ptrs_;
+        auto v_ptrs = h_v_block_ptrs_;
+
+        const int batch_size = state_->active_size;
+
+        for (int i = 0; i < batch_size; ++i) {
+            const auto& seq = *state_->sequences[i];
+
+            // cumulative num of blocks
+            h_cu_block_counts_[i + 1] = h_cu_block_counts_[i] + seq.blocks.size();
+
+            FT_CHECK_WITH_INFO(h_cu_block_counts_[i + 1] <= sequence_manager_->max_block_count(),
+                               std::to_string(h_cu_block_counts_[i + 1]));
+
+            k_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), k_ptrs, [&](int block_id) {
+                return reinterpret_cast<uintptr_t>(sequence_manager_->GetKeyPtr(block_id));
+            });
+            v_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), v_ptrs, [&](int block_id) {
+                return reinterpret_cast<uintptr_t>(sequence_manager_->GetValPtr(block_id));
+            });
         }
+
+        static_assert(sizeof(uintptr_t) == sizeof(void*));
+
+        Copy(h_cu_block_counts_, batch_size + 1, cu_block_counts_);
+        Copy(h_k_block_ptrs_, h_cu_block_counts_[batch_size], k_block_ptrs_);
+        Copy(h_v_block_ptrs_, h_cu_block_counts_[batch_size], v_block_ptrs_);
+    }
+
+    const int batch_size = state_->active_size;
+
+    // check if the last sequence is partial
+    int partial     = 0;
+    int partial_len = -1;
+    if (state_->active_size) {
+        const int i = state_->active_size - 1;
+        partial = state_->sequences[i]->cache_len + state_->sequences[i]->input_length != state_->h_context_length[i];
+        if (partial) {
+            // backup full context length of partial
+            partial_len = state_->h_context_length[i];
+            // replace with partial context length
+            state_->h_context_length[i] = state_->sequences[i]->cache_len + state_->sequences[i]->input_length;
+        }
+    }
+
+    const int max_context_len = *std::max_element(state_->h_context_length, state_->h_context_length + batch_size);
+
+    std::vector<uint64_t> unique_ids(batch_size);
+    for (int i = 0; i < batch_size; ++i) {
+        unique_ids[i] = state_->requests[i]->unique_id;
+    }
+
+    // Real-time context length that will change during generation
+    Copy(state_->h_context_length, batch_size, context_length_buf_);
+    Copy(state_->h_finished, batch_size, finished_buf_);
+    Copy(state_->h_rope_theta, batch_size, rope_theta_);
+
+    // used for dispatching split-k decoding kernels
+    const int sum_seq_len =
+        std::accumulate(state_->h_context_length, state_->h_context_length + batch_size, -batch_size);
+    const int max_seq_len = *std::max_element(state_->h_context_length, state_->h_context_length + batch_size) - 1;
+
+    // TM_LOG_INFO(
+    //     "[init] batch_size = %d, max_ctx_len = %d, partial = %d", (int)batch_size, (int)max_context_len, partial);
+
+    bool skip_init_sampling = std::equal(g.unique_ids.begin(),  //
+                                         g.unique_ids.end() - g.partial,
+                                         unique_ids.begin(),
+                                         unique_ids.end() - partial);
+
+    g.sum_seq_len            = sum_seq_len;
+    g.max_seq_len            = max_seq_len;
+    g.partial                = partial;
+    g.partial_context_legnth = partial_len;
+    g.unique_ids             = std::move(unique_ids);
+    g.finished_count         = 0;
+
+    if (!skip_init_sampling) {
+        g.max_init_ctx_len = max_context_len;
+        g.step             = max_context_len;
+        InitializeSampling(g);
     }
 }
 
 template<typename T>
-void LlamaBatch<T>::allocateBuffer(size_t batch_size, size_t session_len)
+void LlamaBatch<T>::CopyState(const std::vector<std::tuple<BatchState*, BatchState*, int, int>>& desc)
+{
+    if (desc.empty()) {
+        return;
+    }
+
+    std::vector<int> idxs(desc.size());
+    std::iota(idxs.begin(), idxs.end(), 0);
+
+    std::sort(idxs.begin(), idxs.end(), [&](int i, int j) { return desc[i] < desc[j]; });
+
+    auto get_signature = [&](int i) -> std::pair<BatchState*, BatchState*> {
+        return std::make_pair(std::get<0>(desc[idxs[i]]), std::get<1>(desc[idxs[i]]));
+    };
+
+    std::vector<int> offsets;
+    auto             current = get_signature(0);
+    offsets.push_back(0);
+    for (int i = 0; i < idxs.size(); ++i) {
+        if (auto signature = get_signature(i); signature != current) {
+            current = signature;
+            offsets.push_back(i);
+        }
+    }
+    offsets.push_back(idxs.size());
+
+    for (int bi = 1; bi < offsets.size(); ++bi) {
+        int beg = offsets[bi - 1];
+        int end = offsets[bi];
+
+        if (beg == end) {
+            continue;
+        }
+
+        auto [s, d] = get_signature(beg);
+
+        std::vector<int> s_idx;
+        std::vector<int> d_idx;
+        for (int i = beg; i < end; ++i) {
+            s_idx.push_back(std::get<2>(desc[idxs[i]]));
+            d_idx.push_back(std::get<3>(desc[idxs[i]]));
+        }
+
+        IndexedCopy(s_idx,
+                    d_idx,
+                    std::tuple{s->output_ids, d->output_ids, session_len_},
+                    std::tuple{s->curand_state, d->curand_state, 1});
+    }
+
+    for (const auto& [s, d, si, di] : desc) {
+        d->h_context_length[di] = s->h_context_length[si];
+        d->h_finished[di]       = s->h_finished[si];
+        d->h_rope_theta[di]     = s->h_rope_theta[si];
+        d->seq_len_limit[di]    = s->seq_len_limit[si];
+        d->sequences[di]        = s->sequences[si];
+        d->requests[di]         = s->requests[si];
+    }
+}
+
+template<typename T>
+void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     const size_t batchxbeam = batch_size;
 
-    const size_t hidden_units = llama_->hidden_units_;
-    const size_t vocab_size   = llama_->vocab_size_padded_;
+    const size_t hidden_units      = model_->hidden_units_;
+    const size_t vocab_size        = model_->vocab_size_padded_;
+    const size_t head_dim          = model_->size_per_head_;
+    const size_t local_kv_head_num = model_->local_kv_head_num_;
+    // +1 padding, BlockIterator does not use predicate
+    const size_t max_block_count = sequence_manager_->max_block_count() + 1;
 
     context_decoder_input_buf_ =
         (T*)allocator_->reMalloc(context_decoder_input_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false);
@@ -170,86 +655,123 @@ void LlamaBatch::allocateBuffer(size_t batch_size, size_t session_len)
     context_decoder_ids_buf_ =
         (int*)allocator_->reMalloc(context_decoder_ids_buf_, sizeof(int) * max_context_token_num_, false);
 
+    tmp_k_cache_buf_ = (T*)allocator_->reMalloc(
+        tmp_k_cache_buf_, sizeof(T) * max_context_token_num_ * local_kv_head_num * head_dim, false);
+    tmp_v_cache_buf_ = (T*)allocator_->reMalloc(
+        tmp_v_cache_buf_, sizeof(T) * max_context_token_num_ * local_kv_head_num * head_dim, false);
+
+    tmp_k_ptrs_ = (void**)allocator_->reMalloc(tmp_k_ptrs_, sizeof(void*) * batch_size, false);
+    tmp_v_ptrs_ = (void**)allocator_->reMalloc(tmp_v_ptrs_, sizeof(void*) * batch_size, false);
+
     decoder_input_buf_  = (T*)allocator_->reMalloc(decoder_input_buf_, sizeof(T) * batchxbeam * hidden_units, false);
     decoder_output_buf_ = (T*)allocator_->reMalloc(decoder_output_buf_, sizeof(T) * batchxbeam * hidden_units, false);
 
-    input_ids_buf_      = (int*)allocator_->reMalloc(input_ids_buf_, sizeof(int) * batchxbeam * session_len, true);
-    input_length_buf_   = (int*)allocator_->reMalloc(input_length_buf_, sizeof(int) * batchxbeam);
-    history_length_buf_ = (int*)allocator_->reMalloc(history_length_buf_, sizeof(int) * batchxbeam);
-    context_length_buf_ = (int*)allocator_->reMalloc(context_length_buf_, sizeof(int) * batchxbeam);
+    input_ids_buf_       = (int*)allocator_->reMalloc(input_ids_buf_, sizeof(int) * batchxbeam * session_len, true);
+    input_length_buf_    = (int*)allocator_->reMalloc(input_length_buf_, sizeof(int) * batchxbeam);
+    context_length_buf_  = (int*)allocator_->reMalloc(context_length_buf_, sizeof(int) * batchxbeam);
+    init_context_length_ = (int*)allocator_->reMalloc(init_context_length_, sizeof(int) * batchxbeam);
 
-    total_padding_count_ = (int*)allocator_->reMalloc(total_padding_count_, sizeof(int) * batchxbeam, false);
-    sequence_lengths_    = (int*)allocator_->reMalloc(sequence_lengths_, sizeof(int) * batchxbeam, false);
+    sequence_lengths_ = (int*)allocator_->reMalloc(sequence_lengths_, sizeof(int) * batchxbeam, false);
 
-    k_cache_ptr_buf_ = (uint64_t*)allocator_->reMalloc(k_cache_ptr_buf_, sizeof(uint64_t) * batchxbeam);
-    v_cache_ptr_buf_ = (uint64_t*)allocator_->reMalloc(v_cache_ptr_buf_, sizeof(uint64_t) * batchxbeam);
+    cu_block_counts_ = (int*)allocator_->reMalloc(cu_block_counts_, sizeof(int) * (batch_size + 1));
+    k_block_ptrs_    = (uintptr_t*)allocator_->reMalloc(k_block_ptrs_, sizeof(uintptr_t) * max_block_count);
+    v_block_ptrs_    = (uintptr_t*)allocator_->reMalloc(v_block_ptrs_, sizeof(uintptr_t) * max_block_count);
 
     logits_buf_       = (float*)allocator_->reMalloc(logits_buf_, sizeof(float) * batchxbeam * vocab_size, false);
     local_logits_buf_ = (float*)allocator_->reMalloc(local_logits_buf_, sizeof(float) * batchxbeam * vocab_size, false);
 
     token_ids_buf_ = (int*)allocator_->reMalloc(token_ids_buf_, sizeof(int) * batchxbeam * session_len * 2, true);
 
-    end_ids_buf_   = (int*)allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false);
     finished_buf_  = (bool*)allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false);
     seq_limit_len_ = (uint32_t*)allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false);
 
+    rope_theta_ = (float*)allocator_->reMalloc(rope_theta_, sizeof(float) * batch_size, false);
+
     is_allocate_buffer_ = true;
 }
 
 template<typename T>
-void LlamaBatch<T>::allocatePersistantBuffer(size_t max_batch_size)
+void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
 {
-    output_ids_buf_ = (int*)allocator_->reMalloc(output_ids_buf_, sizeof(int) * max_batch_size * session_len_, true);
-
-    stop_words_buf_ =
-        (int*)allocator_->reMalloc(stop_words_buf_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true);
-    bad_words_buf_ =
-        (int*)allocator_->reMalloc(bad_words_buf_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true);
+    d_stop_words_ = (int*)allocator_->reMalloc(d_stop_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true);
+    d_bad_words_  = (int*)allocator_->reMalloc(d_bad_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true);
+    h_stop_words_ =
+        (int*)allocator_->reMalloc(h_stop_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true, true);
+    h_bad_words_ =
+        (int*)allocator_->reMalloc(h_bad_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true, true);
 
     h_runtime_top_k_ = (int*)allocator_->reMalloc(h_runtime_top_k_, sizeof(int) * max_batch_size, true, true);
     h_runtime_top_p_ = (float*)allocator_->reMalloc(h_runtime_top_p_, sizeof(float) * max_batch_size, true, true);
     h_temperature_   = (float*)allocator_->reMalloc(h_temperature_, sizeof(float) * max_batch_size, true, true);
     h_repetition_penalty_ =
         (float*)allocator_->reMalloc(h_repetition_penalty_, sizeof(float) * max_batch_size, true, true);
-    h_random_seed_ = (uint64_t*)allocator_->reMalloc(h_random_seed_, sizeof(uint64_t) * max_batch_size, true, true);
 
-    sampling_params_ = {{"stop_words_list", stop_words_buf_},
-                        {"bad_words_list", bad_words_buf_},
-                        {"runtime_top_k", h_runtime_top_k_},
-                        {"runtime_top_p", h_runtime_top_p_},
-                        {"temperature", h_temperature_},
-                        {"repetition_penalty", h_repetition_penalty_},
-                        {"random_seed", h_random_seed_}};
+    h_random_seed_ = (unsigned long long*)allocator_->reMalloc(
+        h_random_seed_, sizeof(unsigned long long) * max_batch_size, true, true);
+    d_random_seed_ = (unsigned long long*)allocator_->reMalloc(
+        d_random_seed_, sizeof(unsigned long long) * max_batch_size, true, false);
+
+    h_curand_state_ =
+        (curandState_t*)allocator_->reMalloc(h_curand_state_, sizeof(curandState_t) * max_batch_size, true, true);
+    d_curand_state_ =
+        (curandState_t*)allocator_->reMalloc(d_curand_state_, sizeof(curandState_t) * max_batch_size, true, false);
+
+    d_end_ids_buf_ = (int*)allocator_->reMalloc(d_end_ids_buf_, sizeof(int) * max_batch_size, false);
+    h_end_ids_buf_ = (int*)allocator_->reMalloc(h_end_ids_buf_, sizeof(int) * max_batch_size, false, true);
+
+    sampling_params_ = {
+        {"stop_words_list", (std::byte*)h_stop_words_, (std::byte*)d_stop_words_},
+        {"bad_words_list", (std::byte*)h_bad_words_, (std::byte*)d_bad_words_},
+        {"runtime_top_k", (std::byte*)h_runtime_top_k_, nullptr},
+        {"runtime_top_p", (std::byte*)h_runtime_top_p_, nullptr},
+        {"temperature", (std::byte*)h_temperature_, nullptr},
+        {"repetition_penalty", (std::byte*)h_repetition_penalty_, nullptr},
+    };
+
+    for (auto& s : states_) {
+        s.output_ids = (int*)allocator_->reMalloc(s.output_ids, sizeof(int) * max_batch_size * session_len_, true);
+        s.curand_state =
+            (curandState_t*)allocator_->reMalloc(s.curand_state, sizeof(curandState_t) * max_batch_size, true);
+    }
 
-    topk_curandstate_buf_ = allocator_->reMalloc(topk_curandstate_buf_, sizeof(curandState_t) * max_batch_size, true);
-    topp_curandstate_buf_ = allocator_->reMalloc(topp_curandstate_buf_, sizeof(curandState_t) * max_batch_size, true);
+    const size_t max_block_count = sequence_manager_->max_block_count();
 
     {
-        NcclGuard barrier(llama_->tensor_para_, stream_, true);
+        NcclGuard barrier(model_->tensor_para_, stream_, true);
         h_input_ids_buf_ =
             (int*)allocator_->reMalloc(h_input_ids_buf_, sizeof(int) * max_batch_size * session_len_, false, true);
         h_input_length_buf_ =
             (int*)allocator_->reMalloc(h_input_length_buf_, sizeof(int) * max_batch_size, false, true);
-        h_history_length_buf_ =
-            (int*)allocator_->reMalloc(h_history_length_buf_, sizeof(int) * max_batch_size, false, true);
-        h_context_length_buf_ =
-            (int*)allocator_->reMalloc(h_context_length_buf_, sizeof(int) * max_batch_size, false, true);
-        h_sequence_lengths_ =
-            (int*)allocator_->reMalloc(h_sequence_lengths_, sizeof(int) * max_batch_size, false, true);
-        h_k_cache_ptr_buf_ =
-            (uintptr_t*)allocator_->reMalloc(h_k_cache_ptr_buf_, sizeof(uintptr_t) * max_batch_size, true, true);
-        h_v_cache_ptr_buf_ =
-            (uintptr_t*)allocator_->reMalloc(h_v_cache_ptr_buf_, sizeof(uintptr_t) * max_batch_size, true, true);
-        h_finished_buf_ = (bool*)allocator_->reMalloc(h_finished_buf_, sizeof(bool) * max_batch_size, false, true);
+
+        h_tmp_k_ptrs_ = (void**)allocator_->reMalloc(h_tmp_k_ptrs_, sizeof(void*) * max_batch_size, false, true);
+        h_tmp_v_ptrs_ = (void**)allocator_->reMalloc(h_tmp_v_ptrs_, sizeof(void*) * max_batch_size, false, true);
+
+        h_cu_block_counts_ =
+            (int*)allocator_->reMalloc(h_cu_block_counts_, sizeof(int) * (max_batch_size + 1), false, true);
+        h_k_block_ptrs_ =
+            (uintptr_t*)allocator_->reMalloc(h_k_block_ptrs_, sizeof(uintptr_t) * max_block_count, false, true);
+        h_v_block_ptrs_ =
+            (uintptr_t*)allocator_->reMalloc(h_v_block_ptrs_, sizeof(uintptr_t) * max_block_count, false, true);
+
+        for (auto& s : states_) {
+            s.h_context_length =
+                (int*)allocator_->reMalloc(s.h_context_length, sizeof(int) * max_batch_size, false, true);
+            s.h_finished   = (bool*)allocator_->reMalloc(s.h_finished, sizeof(bool) * max_batch_size * 2, false, true);
+            s.h_rope_theta = (float*)allocator_->reMalloc(s.h_rope_theta, sizeof(float) * max_batch_size, false, true);
+        }
+
         h_seq_limit_len_ =
             (uint32_t*)allocator_->reMalloc(h_seq_limit_len_, sizeof(uint32_t) * max_batch_size, false, true);
+
+        h_output_ids_ =
+            (int*)allocator_->reMalloc(h_output_ids_, sizeof(int) * max_batch_size * session_len_, false, true);
     }
 
     is_allocate_persistant_buffer_ = true;
 }
 
 template<typename T>
-void LlamaBatch<T>::freeBuffer()
+void LlamaBatch<T>::FreeBuffer()
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     if (is_allocate_buffer_) {
@@ -257,19 +779,24 @@ void LlamaBatch::freeBuffer()
         allocator_->free((void**)&context_decoder_output_buf_);
         allocator_->free((void**)&context_decoder_ids_buf_);
 
+        allocator_->free((void**)&tmp_k_cache_buf_);
+        allocator_->free((void**)&tmp_v_cache_buf_);
+        allocator_->free((void**)&tmp_k_ptrs_);
+        allocator_->free((void**)&tmp_v_ptrs_);
+
         allocator_->free((void**)&decoder_input_buf_);
         allocator_->free((void**)&decoder_output_buf_);
 
         allocator_->free((void**)&input_ids_buf_);
         allocator_->free((void**)&input_length_buf_);
-        allocator_->free((void**)&history_length_buf_);
         allocator_->free((void**)&context_length_buf_);
+        allocator_->free((void**)&init_context_length_);
 
-        allocator_->free((void**)&total_padding_count_);
         allocator_->free((void**)&sequence_lengths_);
 
-        allocator_->free((void**)&k_cache_ptr_buf_);
-        allocator_->free((void**)&v_cache_ptr_buf_);
+        allocator_->free((void**)&cu_block_counts_);
+        allocator_->free((void**)&k_block_ptrs_);
+        allocator_->free((void**)&v_block_ptrs_);
 
         allocator_->free((void**)&logits_buf_);
         allocator_->free((void**)&local_logits_buf_);
@@ -283,854 +810,747 @@ void LlamaBatch::freeBuffer()
 
         allocator_->free((void**)&token_ids_buf_);
 
-        allocator_->free((void**)&end_ids_buf_);
+        allocator_->free((void**)&d_end_ids_buf_);
+        allocator_->free((void**)&h_end_ids_buf_, true);
+
         allocator_->free((void**)&finished_buf_);
         allocator_->free((void**)&seq_limit_len_);
 
+        allocator_->free((void**)&rope_theta_);
+
         is_allocate_buffer_ = false;
     }
 
     if (is_allocate_persistant_buffer_) {
+
+        allocator_->free((void**)&d_stop_words_);
+        allocator_->free((void**)&h_stop_words_, true);
+        allocator_->free((void**)&d_bad_words_);
+        allocator_->free((void**)&h_bad_words_, true);
+        allocator_->free((void**)&d_random_seed_);
+        allocator_->free((void**)&h_random_seed_, true);
+        allocator_->free((void**)&d_curand_state_);
+        allocator_->free((void**)&h_curand_state_, true);
+
+        for (auto& s : states_) {
+            allocator_->free((void**)&s.h_context_length, true);
+            allocator_->free((void**)&s.h_finished, true);
+            allocator_->free((void**)&s.h_rope_theta, true);
+            allocator_->free((void**)&s.output_ids);
+            allocator_->free((void**)&s.curand_state);
+        }
+        allocator_->free((void**)&h_tmp_k_ptrs_, true);
+        allocator_->free((void**)&h_tmp_v_ptrs_, true);
+        allocator_->free((void**)&h_cu_block_counts_, true);
+        allocator_->free((void**)&h_k_block_ptrs_, true);
+        allocator_->free((void**)&h_v_block_ptrs_, true);
         allocator_->free((void**)&h_input_ids_buf_, true);
         allocator_->free((void**)&h_input_length_buf_, true);
-        allocator_->free((void**)&h_history_length_buf_, true);
-        allocator_->free((void**)&h_context_length_buf_, true);
-        allocator_->free((void**)&h_sequence_lengths_, true);
-        allocator_->free((void**)&h_k_cache_ptr_buf_, true);
-        allocator_->free((void**)&h_v_cache_ptr_buf_, true);
         allocator_->free((void**)&h_seq_limit_len_, true);
-        allocator_->free((void**)&h_finished_buf_, true);
 
-        allocator_->free((void**)&output_ids_buf_);
+        allocator_->free((void**)&h_output_ids_, true);
 
         is_allocate_persistant_buffer_ = false;
     }
 }
 
 template<typename T>
-LlamaBatch<T>::LlamaBatch(int max_batch_size, int max_context_token_num, int session_len, LlamaV2<T>* llama):
-    max_batch_size_(max_batch_size),
-    max_context_token_num_(max_context_token_num),
-    session_len_(session_len),
-    rank_(llama->tensor_para_.rank_),
-    debug_(llama->debug_),
-    llama_(llama),
-    data_type_(getTensorType<T>())
+LlamaBatch<T>::LlamaBatch(const EngineParams& params, int cache_block_seq_len, int quant_policy, LlamaV2<T>* model):
+    max_batch_size_(params.max_batch_size),
+    max_context_token_num_(params.max_context_token_num),
+    session_len_(params.session_len),
+    rank_(model->tensor_para_.rank_),
+    debug_(model->debug_),
+    step_length_(params.step_length),
+    model_(model),
+    data_type_(getTensorType<T>()),
+    num_tokens_per_iter_(params.num_tokens_per_iter),
+    extra_tokens_per_iter_(params.extra_tokens_per_iter),
+    max_prefill_iters_(params.max_prefill_iters)
 {
-    stream_         = llama_->stream_;
-    allocator_      = llama_->allocator_;
-    cublas_wrapper_ = llama_->cublas_wrapper_;
+    stream_         = model_->stream_;
+    allocator_      = model_->allocator_;
+    cublas_wrapper_ = model_->cublas_wrapper_;
+
+    const size_t elem_bits = (quant_policy & QuantPolicy::kCacheKVInt8) ? 8 : sizeof(T) * 8;
+
+    sequence_manager_.reset(new SequenceManager{model_->num_layer_,
+                                                model_->local_kv_head_num_,
+                                                model_->size_per_head_,
+                                                (size_t)cache_block_seq_len,
+                                                params.cache_max_block_count,
+                                                params.cache_chunk_size,
+                                                elem_bits,
+                                                model->tensor_para_.rank_,
+                                                allocator_});
+
+    const size_t max_session_len = sequence_manager_->max_block_count() * cache_block_seq_len;
+    if (max_session_len < session_len_) {
+        if (rank_ == 0) {
+            TM_LOG_WARNING("No enough blocks for `session_len` (%d), `session_len` truncated to %d.",
+                           session_len_,
+                           max_session_len);
+        }
+        session_len_ = max_session_len;
+    }
+
+    for (auto& s : states_) {
+        s.requests.resize(max_batch_size_);
+        s.sequences.resize(max_batch_size_);
+        s.seq_len_limit.resize(max_batch_size_);
+    }
 
-    requests_.resize(max_batch_size);
-    request_seq_len_limit_.resize(max_batch_size);
-    cached_seq_.resize(max_batch_size);
+    state_    = &states_[0];
+    back_     = &states_[1];
+    incoming_ = &states_[2];
 
-    allocatePersistantBuffer(max_batch_size);
+    AllocateBuffer(max_batch_size_, session_len_);
+    AllocatePersistantBuffer(max_batch_size_);
 }
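
For reference, a stand-alone sketch of the clamping done in the constructor above: the KV-cache block pool bounds the longest representable sequence, and the requested session length is truncated to that bound. The function and variable names below are illustrative only, not part of the patch.

#include <cstddef>
#include <cstdio>

// Illustrative only: clamp a requested session length to what the KV-cache
// block pool can hold (max_block_count blocks of cache_block_seq_len tokens),
// mirroring the warning branch in the constructor above.
static int ClampSessionLen(int requested_session_len, size_t max_block_count, size_t cache_block_seq_len)
{
    const size_t max_session_len = max_block_count * cache_block_seq_len;
    if (max_session_len < (size_t)requested_session_len) {
        std::printf("session_len %d truncated to %zu\n", requested_session_len, max_session_len);
        return (int)max_session_len;
    }
    return requested_session_len;
}

int main()
{
    // e.g. 512 blocks of 128 tokens -> at most 65536 tokens per sequence
    std::printf("%d\n", ClampSessionLen(100000, 512, 128));
}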
 
 template<typename T>
-void LlamaBatch<T>::initializeSampling(int infer_request_count)
+void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
 {
+    NvtxScope _("InitSampling");
+    const int batch_size = state_->active_size - g.partial;
+    if (batch_size == 0) {
+        return;
+    }
+
+    // Context length at initialization, will stay constant until re-initialization
+    Copy(context_length_buf_, batch_size, init_context_length_);
+
+    Copy(context_length_buf_, batch_size, sequence_lengths_);
+    // `sequence_lengths_` will be increased by dynamic decode
+    // note that in decoder and in output "sequence length" has different semantics
+    // - in decoder it means length of sequence that has kv cache already computed
+    // - in output it means length of all tokens (the last generated token does not have k/v cache computed yet)
+    invokePlusScalar(sequence_lengths_, -1, batch_size, stream_);
+    sync_check_cuda_error();
+
+    Clear(token_ids_buf_, batch_size * session_len_);
+    invokeTransposeAxis01(token_ids_buf_, state_->output_ids, batch_size, session_len_, 1, stream_);
+    sync_check_cuda_error();
+
+    // token_ids_buf_[s, b]
+    // ABCDe            ABCDe     e
+    // ABCDEFGHIJk      ABCDEFGHIJk
+    // ABCDEFGHi    ->  ABCDEFGHi i
+    // ABCDEFGh         ABCDEFGh  h
+    // ABCd             ABCd      d
+    invokePadLastTokenIds(token_ids_buf_, init_context_length_, g.max_init_ctx_len, batch_size, stream_);
+    sync_check_cuda_error();
+
+    // seq_limit_len_, will be compared to `step` instead of `sequence_length`, so padding len should be accounted for
+    for (int i = 0; i < batch_size; ++i) {
+        h_seq_limit_len_[i] = state_->seq_len_limit[i] + (g.max_init_ctx_len - state_->h_context_length[i]);
+    }
+    Copy(h_seq_limit_len_, batch_size, seq_limit_len_);
+
     TensorMap inputs;
-    for (const auto& param : sampling_params_) {
+    for (const auto& [name, h_ptr, d_ptr] : sampling_params_) {
+        // find an exemplar that matches the param name
         const Tensor* ptr{};
-        for (int i = 0; i < batch_size_; ++i) {
-            if (requests_[i]->inputs[rank_].isExist(param.first)) {
-                ptr = &requests_[i]->inputs[rank_].at(param.first);
+        for (int i = 0; i < batch_size; ++i) {
+            if (state_->requests[i]->inputs[rank_].isExist(name)) {
+                ptr = &state_->requests[i]->inputs[rank_].at(name);
                 break;
             }
         }
+        // fill the batch of the param
         if (ptr) {
             const auto& ref   = *ptr;
             auto        shape = ref.shape;
             FT_CHECK(shape[0] == 1);
-            shape[0]                = batch_size_;
+            shape[0]                = batch_size;
             const int size_in_bytes = ref.sizeBytes();
-            check_cuda_error(cudaMemsetAsync(param.second, 0, size_in_bytes * batch_size_, stream_));
-            for (int i = 0; i < batch_size_; ++i) {
-                if (requests_[i]->inputs[rank_].isExist(param.first)) {
-                    auto& src = requests_[i]->inputs[rank_].at(param.first);
+            memset(h_ptr, 0, size_in_bytes * batch_size);
+            for (int i = 0; i < batch_size; ++i) {
+                FT_CHECK(state_->requests[i] != nullptr);
+                if (state_->requests[i]->inputs[rank_].isExist(name)) {
+                    Tensor& src = state_->requests[i]->inputs[rank_].at(name);
                     FT_CHECK(ref.shape == src.shape);
-                    check_cuda_error(cudaMemcpyAsync((uint8_t*)param.second + size_in_bytes * i,
-                                                     src.getPtr(),
-                                                     size_in_bytes,
-                                                     cudaMemcpyDefault,
-                                                     stream_));
+                    std::copy_n(src.getPtr<std::byte>(), size_in_bytes, h_ptr + size_in_bytes * i);
                 }
             }
-            inputs.insert({param.first, {ref.where, ref.type, shape, param.second}});
+            if (d_ptr) {
+                Copy(h_ptr, batch_size * size_in_bytes, d_ptr);
+            }
+            inputs.insert({name, {d_ptr ? MEMORY_GPU : MEMORY_CPU, ref.type, shape, d_ptr ? d_ptr : h_ptr}});
             if (debug_ && rank_ == 0) {
-                TM_LOG_INFO("[initializeSampling] %s", format({param.first, inputs.at(param.first)}).c_str());
+                TM_LOG_INFO("[initializeSampling] %s", format({name, inputs.at(name)}).c_str());
             }
         }
     }
 
-    inputs_ = std::move(inputs);
+    // init for eos
+    std::fill_n(h_end_ids_buf_, batch_size, model_->end_id_);
+    Copy(h_end_ids_buf_, batch_size, d_end_ids_buf_);
+    inputs.insert({"end_id", {MEMORY_GPU, TYPE_INT32, {(size_t)batch_size}, d_end_ids_buf_}});
 
-    llama_->dynamic_decode_layer_->setup(batch_size_, 1, &inputs_);
-
-    for (int i = 0; i < batch_size_; ++i) {
-        // recover random states if not a new request or new request w/o "random_seed"
-        if (i < batch_size_ - infer_request_count || !requests_[i]->inputs[rank_].isExist("random_seed")) {
-            check_cuda_error(cudaMemcpyAsync(llama_->dynamic_decode_layer_->topk_curandstate_buf() + i,
-                                             (curandState_t*)topk_curandstate_buf_ + i,
-                                             sizeof(curandState_t),
-                                             cudaMemcpyDefault,
-                                             stream_));
-            check_cuda_error(cudaMemcpyAsync(llama_->dynamic_decode_layer_->topp_curandstate_buf() + i,
-                                             (curandState_t*)topp_curandstate_buf_ + i,
-                                             sizeof(curandState_t),
-                                             cudaMemcpyDefault,
-                                             stream_));
-        }
-    }
+    inputs_ = std::move(inputs);
 
-    handleOptArg(&inputs_, "end_id", end_ids_buf_, llama_->end_id_, batch_size_);
-    cudaStreamSynchronize(0);
+    model_->dynamic_decode_layer_->setup(batch_size, 1, &inputs_);
 }
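
The `[s, b]` diagram above is easier to follow on a host-side toy: tokens are stored time-major, and each sequence's last token is copied to the row just before the common decoding step so that all sequences can start decoding together. This is a hypothetical host sketch, not the `invokePadLastTokenIds` CUDA kernel itself.

#include <cstdio>
#include <vector>

// token_ids[s * batch_size + b] holds step s of sequence b (time-major, as in
// token_ids_buf_). Copy each sequence's last valid token to the row just
// before the common start step, so all sequences decode from the same step.
static void PadLastTokenIds(std::vector<int>& token_ids, const std::vector<int>& context_length,
                            int max_init_ctx_len, int batch_size)
{
    for (int b = 0; b < batch_size; ++b) {
        const int last = context_length[b] - 1;  // last valid step of sequence b
        token_ids[(max_init_ctx_len - 1) * batch_size + b] = token_ids[last * batch_size + b];
    }
}

int main()
{
    const int batch_size = 2, max_len = 4;
    // seq 0 = {11, 12}, seq 1 = {21, 22, 23, 24}, stored time-major
    std::vector<int> ids = {11, 21, 12, 22, 0, 23, 0, 24};
    PadLastTokenIds(ids, {2, 4}, max_len, batch_size);
    std::printf("row %d: %d %d\n", max_len - 1, ids[(max_len - 1) * batch_size + 0], ids[(max_len - 1) * batch_size + 1]);
}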
 
 template<typename T>
-void LlamaBatch<T>::initializeGeneration()
+void LlamaBatch<T>::OutputContextLogits(T*                      context_decoder_output,
+                                        const std::vector<int>& indices,
+                                        const std::vector<int>& lengths)
 {
-    max_context_len_ = *std::max_element(h_context_length_buf_, h_context_length_buf_ + batch_size_);
-
-    check_cuda_error(cudaMemsetAsync(token_ids_buf_, 0, sizeof(int) * batch_size_ * session_len_ * 2, stream_));
-    invokeTransposeAxis01(token_ids_buf_, output_ids_buf_, batch_size_, session_len_, 1, stream_);
-    sync_check_cuda_error();
-
-    // token_ids_buf_[s, b]
-    // ABCDe            ABCDe     e
-    // ABCDEFGHIJk      ABCDEFGHIJk
-    // ABCDEFGHi    ->  ABCDEFGHi i
-    // ABCDEFGh         ABCDEFGh  h
-    // ABCd             ABCd      d
-    for (int i = 0; i < batch_size_; ++i) {
-        auto token_ids = token_ids_buf_ + i;
-        auto p_src     = h_context_length_buf_[i] - 1;
-        auto p_dst     = max_context_len_ - 1;
-        if (p_src != p_dst) {  // dst and src of `cudaMemcpyAsync` must not overlap
-            check_cuda_error(cudaMemcpyAsync(token_ids + p_dst * batch_size_,
-                                             token_ids + p_src * batch_size_,
-                                             sizeof(int),
-                                             cudaMemcpyDefault,
-                                             stream_));
+    std::vector<float*> output_logits;
+    int                 num_token = 0;
+    {
+        bool is_return_logits = false;
+        for (int k = 0; k < indices.size(); ++k) {
+            auto& request = state_->requests[indices[k]];
+            output_logits.push_back(request->outputs[rank_].getPtr<float>("logits", nullptr));
+            num_token += lengths[k];
+            if (output_logits.back()) {
+                is_return_logits = true;
+            }
+        }
+        if (!is_return_logits) {
+            return;
         }
     }
 
-    check_cuda_error(cudaMemcpyAsync(
-        context_length_buf_, h_context_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
-    check_cuda_error(cudaMemcpyAsync(
-        k_cache_ptr_buf_, h_k_cache_ptr_buf_, sizeof(uintptr_t) * batch_size_, cudaMemcpyDefault, stream_));
-    check_cuda_error(cudaMemcpyAsync(
-        v_cache_ptr_buf_, h_v_cache_ptr_buf_, sizeof(uintptr_t) * batch_size_, cudaMemcpyDefault, stream_));
-
-    check_cuda_error(
-        cudaMemcpyAsync(sequence_lengths_, context_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
-    // `sequence_lengths_` will be increased by dynamic decode
-    // note that in decoder and in output "sequence length" has different semantic
-    // - in decoder it means length of sequence that has kv cache already computed
-    // - in output it means length of all tokens (the last generated token does not have k/v cache computed yet)
-    invokePlusScalar(sequence_lengths_, -1, batch_size_, stream_);
-    sync_check_cuda_error();
-
-    // total_padding_count_
-    // decoding starts at max_context_len
-    check_cuda_error(cudaMemsetAsync(total_padding_count_, 0, sizeof(int) * batch_size_, stream_));
-    invokeUpdatePaddingCount(total_padding_count_,  //
-                             context_length_buf_,
-                             max_context_len_,
-                             batch_size_,
-                             1,
-                             stream_);
-    sync_check_cuda_error();
-
-    // seq_limit_len_, will be compared to `step` instead of `sequence_length`, so padding len should be accounted for
-    for (int i = 0; i < batch_size_; ++i) {
-        h_seq_limit_len_[i] = request_seq_len_limit_[i] + (max_context_len_ - h_context_length_buf_[i]);
-        // mask finished sequences
-        h_finished_buf_[i] = max_context_len_ >= h_seq_limit_len_[i];
+    if (context_logits_buf_ == nullptr) {
+        NcclGuard guard(model_->tensor_para_, stream_, true);
+        context_logits_buf_ =
+            (float*)allocator_->malloc(sizeof(float) * model_->vocab_size_padded_ * max_context_token_num_);
+        const auto tp = model_->tensor_para_.world_size_;
+        if (tp > 1) {
+            FT_CHECK(model_->vocab_size_padded_ % tp == 0);
+            const auto local_vocab_size = model_->vocab_size_padded_ / tp;
+            local_context_logits_buf_ =
+                (float*)allocator_->malloc(sizeof(float) * local_vocab_size * max_context_token_num_);
+        }
     }
-    check_cuda_error(
-        cudaMemcpyAsync(seq_limit_len_, h_seq_limit_len_, sizeof(uint32_t) * batch_size_, cudaMemcpyDefault, stream_));
-    check_cuda_error(
-        cudaMemcpyAsync(finished_buf_, h_finished_buf_, sizeof(bool) * batch_size_, cudaMemcpyDefault, stream_));
 
-    // ! range of step_ [1, 2 * session_len]
-    // consider a sequence with context_len == session_len and another sequence with context_len == 1 and
-    // request_output_len == session_len - 1 => step_ will loop in [session_len, 2 * session_len)
-    step_ = max_context_len_;
+    model_->postDecodeEmbedding(context_logits_buf_, local_context_logits_buf_, context_decoder_output, num_token);
 
-    if (rank_ == 0) {
-        TM_LOG_INFO("[initGen] batch_size = %d", (int)batch_size_);
-        TM_LOG_INFO("[initGen] max_context_len = %d", (int)max_context_len_);
-
-        TM_LOG_INFO("[initGen] slot  sequence_id  context_len  seq_limit_len  finished");
-        for (int i = 0; i < batch_size_; ++i) {
-            TM_LOG_INFO("[initGen] %4d  %11ld  %11d  %13d  %8d",
-                        i,
-                        (long)cached_seq_[i].id,
-                        h_context_length_buf_[i],
-                        (int)h_seq_limit_len_[i],
-                        (int)h_finished_buf_[i]);
+    auto logits = context_logits_buf_;
+
+    for (int k = 0; k < indices.size(); ++k) {
+        if (output_logits[k]) {
+            Copy(logits, model_->vocab_size_ * lengths[k], output_logits[k]);
         }
+        logits += model_->vocab_size_padded_ * lengths[k];
     }
 }
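
The copy loop above walks a logits buffer whose rows use the padded vocab stride while clients only receive the real vocab entries. Below is a host-side sketch of that extraction, done token by token for clarity (the device path above issues one copy per request at a padded offset); all names here are illustrative.

#include <algorithm>
#include <cstdio>
#include <vector>

// `src` packs logits for a batch of tokens with a padded vocab stride; only
// the first `vocab_size` entries of each row are meaningful and are copied out
// into the per-request destination buffers.
static void CopyUnpaddedLogits(const std::vector<float>& src, int vocab_size_padded, int vocab_size,
                               const std::vector<int>& lengths, std::vector<std::vector<float>>& dst)
{
    size_t offset = 0;
    for (size_t k = 0; k < lengths.size(); ++k) {
        dst[k].resize((size_t)lengths[k] * vocab_size);
        for (int t = 0; t < lengths[k]; ++t) {
            std::copy_n(src.begin() + offset + (size_t)t * vocab_size_padded, vocab_size,
                        dst[k].begin() + (size_t)t * vocab_size);
        }
        offset += (size_t)lengths[k] * vocab_size_padded;
    }
}

int main()
{
    std::vector<float>              src = {1, 2, 0, 0, 3, 4, 0, 0, 5, 6, 0, 0};  // padded stride 4, vocab 2
    std::vector<std::vector<float>> dst(2);
    CopyUnpaddedLogits(src, 4, 2, {1, 2}, dst);
    std::printf("%zu %zu\n", dst[0].size(), dst[1].size());  // 2 4
}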
 
 template<typename T>
-bool LlamaBatch<T>::generate()
+auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
 {
-    constexpr int kLogInterval = 10;
-    if (rank_ == 0 && (step_ - 1) % kLogInterval == 0) {
-        TM_LOG_INFO("------------------------- step = %d -------------------------", step_ - 1);
+    NvtxScope scope("Finish");
+    const int batch_size = state_->active_size;
+
+    if (batch_size - g.partial) {
+        FT_CHECK(g.step >= 0);
+
+        // [s,b] -> [b,s] and skip padding in [context_len, max_context_len)
+        invokeGatherOutput(state_->output_ids,
+                           token_ids_buf_,
+                           init_context_length_,
+                           g.max_init_ctx_len,
+                           g.step,
+                           session_len_,
+                           batch_size - g.partial,
+                           stream_);
+        sync_check_cuda_error();
     }
 
-    const bool is_first_step = step_ == max_context_len_;
-
-    std::vector<int> prev;
-    if (debug_ && rank_ == 0 && is_first_step) {
-        prev.resize(batch_size_);
-        cudaMemcpyAsync(prev.data(),
-                        token_ids_buf_ + (step_ - 1) * batch_size_,
-                        sizeof(int) * batch_size_,
-                        cudaMemcpyDefault,
-                        stream_);
-    }
+    Copy(state_->output_ids, batch_size * session_len_, h_output_ids_);
+    Copy(finished_buf_, batch_size, state_->h_finished);
+    Copy(sequence_lengths_, batch_size, state_->h_context_length);
 
-    // embeddingLookup(step_ - 1);
-    llama_->embeddingLookup(decoder_input_buf_,  //
-                            token_ids_buf_,
-                            batch_size_,
-                            step_ - 1);
-
-    llama_->decoderForward(decoder_output_buf_,
-                           k_cache_ptr_buf_,
-                           v_cache_ptr_buf_,
-                           decoder_input_buf_,
-                           sequence_lengths_,
-                           total_padding_count_,
-                           finished_buf_,
-                           step_,
-                           0,
-                           session_len_,
-                           batch_size_);
+    check_cuda_error(cudaStreamSynchronize(stream_));
 
-    llama_->postDecodeEmbedding(logits_buf_,  //
-                                local_logits_buf_,
-                                decoder_output_buf_,
-                                batch_size_);
+    // invariant: context_length = sequence_length + 1, so that h_context_length includes all tokens (including the
+    // one just generated)
+    for (int i = 0; i < batch_size; ++i) {
+        ++state_->h_context_length[i];
+    }
 
-    // stop-words & bad-words require the matched tokens to be contiguous, so item size > 1 is
-    // not supported yet.
-    bool should_stop{};
-    llama_->dynamicDecode(token_ids_buf_,
-                          finished_buf_,
-                          sequence_lengths_,
-                          &should_stop,
-                          &inputs_,
-                          &outputs_,
-                          logits_buf_,
-                          seq_limit_len_,
-                          context_length_buf_,
-                          end_ids_buf_,
-                          step_,
-                          0,
-                          max_context_len_,
-                          session_len_ * 2,
-                          batch_size_);
+    {  // set output tokens ids and sequence length
+        int* output_ptr = h_output_ids_;
+        for (int i = 0; i < batch_size - g.partial; ++i) {
+            if (state_->requests[i] && (state_->requests[i]->stream_cb || state_->h_finished[i])) {
+                auto      output_ids = state_->requests[i]->outputs[rank_].getPtr<int>("output_ids");
+                auto      output_len = state_->requests[i]->outputs[rank_].getPtr<int>("sequence_length");
+                const int count      = state_->h_context_length[i];
+                // TODO: sync history output tokens when receiving the request and copy only the last token here
+                std::copy(output_ptr, output_ptr + count, output_ids);
+                *output_len = count;
+            }
+            output_ptr += session_len_;
+        }
+    }
 
     if (debug_ && rank_ == 0) {
-        std::vector<int> curr(batch_size_);
-
-        cudaMemcpyAsync(
-            curr.data(), token_ids_buf_ + step_ * batch_size_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_);
-        cudaStreamSynchronize(stream_);
-
-        if (is_first_step) {
-            std::stringstream sprev;
-            for (int k = 0; k < prev.size(); ++k) {
-                sprev << std::setw(6) << prev[k];
+        for (int i = 0; i < batch_size; ++i) {
+            // ss << (i ? ", " : "") << "(" << state_->h_context_length[i] << "," << state_->h_finished[i] << ")";
+            std::vector<int> tokens(state_->h_context_length[i]);
+            Copy(state_->output_ids + i * session_len_, tokens.size(), tokens.data());
+            cudaStreamSynchronize(stream_);
+            std::stringstream ss;
+            for (const auto& t : tokens) {
+                ss << " " << t;
             }
-            TM_LOG_INFO("[ lookup ] step = %d, [%s]", step_ - 1, sprev.str().c_str());
+            TM_LOG_INFO("[Finish] slot %d, tokens [%s]", i, ss.str().c_str());
         }
+    }
 
-        std::stringstream scurr;
-        for (int k = 0; k < curr.size(); ++k) {
-            scurr << std::setw(6) << curr[k];
+    std::vector<Signal> signals;
+    {
+        NvtxScope _("stream_and_completion_signal");
+        for (int i = 0; i < batch_size - g.partial; ++i) {
+            if (state_->requests[i]) {
+                if (state_->h_finished[i]) {
+                    // Interrupt finished sequences and move the request handle into the signal closure
+                    signals.push_back(Interrupt(i));
+                    ++g.finished_count;
+                }
+                else if (state_->requests[i]->stream_cb) {
+                    // Create signals by copying the request handles for non-finished streaming requests
+                    signals.push_back([this, r = state_->requests[i]] {
+                        if (rank_ == 0) {
+                            r->stream_cb(&r->outputs[rank_].get());
+                        }
+                    });
+                }
+            }
+        }
+        if (g.finished_count) {
+            // synchronize for interrupted sequences
+            check_cuda_error(cudaStreamSynchronize(stream_));
         }
-        TM_LOG_INFO("[generate] step = %d, [%s]", step_ - 1, scurr.str().c_str());
     }
 
-    ////////////////////////////////////////////////
-    /// ! increase the step counter
-    ++step_;
+    if (g.partial) {
+        const int i = batch_size - 1;
+        // recover full context length of partial
+        state_->h_context_length[i] = g.partial_context_legnth;
+    }
 
-    return !should_stop;
+    return signals;
 }
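
`Finish()` returns completion work as `Signal` closures rather than notifying clients inline. Below is a minimal sketch of that pattern, assuming a simplified request type with a `std::promise` standing in for the real `Request`:

#include <functional>
#include <future>
#include <memory>
#include <vector>

// Minimal sketch of the "signal" pattern used above: instead of notifying the
// client inside the GPU loop, Finish()/Interrupt() return closures that own a
// reference to the request; the closures are run later on the output thread.
struct FakeRequest {
    std::promise<int> signal;  // completion channel, as in Request::signal
};

using Signal = std::function<void()>;

Signal MakeCompletionSignal(std::shared_ptr<FakeRequest> r)
{
    // the request handle is captured by the closure and stays alive until the
    // signal has been delivered
    return [r] { r->signal.set_value(0); };
}

int main()
{
    auto req    = std::make_shared<FakeRequest>();
    auto future = req->signal.get_future();

    std::vector<Signal> signals;
    signals.push_back(MakeCompletionSignal(req));

    for (auto& s : signals) {
        s();  // normally done by the output thread
    }
    return future.get();
}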
 
 template<typename T>
-void LlamaBatch<T>::initialize(const std::vector<std::shared_ptr<Request>>& infer_requests)
+auto LlamaBatch<T>::Interrupt(int index, bool force_stop, bool force_end) -> Signal
 {
-    FT_CHECK(batch_size_ + infer_requests.size() <= max_batch_size_);
+    if (rank_ == 0) {
+        TM_LOG_INFO("[Interrupt] slot = %d, id = %lu", index, (long)state_->requests[index]->id);
+    }
 
-    const int infer_request_count = infer_requests.size();
+    if (debug_ && rank_ == 0) {
+        std::vector<int> tokens(state_->h_context_length[index]);
+        Copy(state_->output_ids + index * session_len_, tokens.size(), tokens.data());
+        cudaStreamSynchronize(stream_);
+        std::stringstream ss;
+        for (const auto& t : tokens) {
+            ss << " " << t;
+        }
+        TM_LOG_INFO("[Interrupt] slot %d, tokens [%s]", index, ss.str().c_str());
+    }
 
-    allocateBuffer(batch_size_ + infer_request_count, session_len_);
+    if (state_->requests[index]->end_flag || force_end) {
+        // Sequence is ending this round or a stop request is issued to end it
+        FT_CHECK(sequence_manager_->Erase(state_->requests[index]->id));
+    }
+    else {
+        const int output_len = state_->h_context_length[index];
+        auto&     seq        = *state_->sequences[index];
 
-    // handle infer requests
-    std::vector<int>                         tmp_input_length(infer_request_count);
-    std::vector<LlamaCacheManager::Sequence> tmp_cached_seq;
-    tmp_cached_seq.reserve(infer_request_count);
+        // Update token IDs
+        seq.tokens.resize(output_len);
+        const auto output_ids_data = state_->requests[index]->outputs[rank_].at("output_ids").getPtr<int>();
+        std::copy_n(output_ids_data, output_len, seq.tokens.data());
 
-    int tmp_max_input_length = 0;
-    for (int i = 0; i < infer_request_count; ++i) {
-        auto& r = *infer_requests[i];
+        // Save random state in host memory
+        seq.random_state.resize(sizeof(curandState_t));
+        // This async copy must be synchronized by the caller
+        Copy(state_->curand_state + index, 1, (curandState_t*)seq.random_state.data());
 
-        LlamaCacheManager::Sequence seq{};
-        if (r.start_flag) {
-            seq = llama_->kv_cache_mgr_->create(r.id, stream_);
-        }
-        else {
-            seq = llama_->kv_cache_mgr_->fetch(r.id, stream_);
-        }
+        // Set unlock flag for corresponding blocks, will be unlocked in the next `Materialize()`
+        sequence_manager_->UpdateAndSetUnlock(seq);
+    }
 
-        const int step = r.inputs[rank_].getVal<int>("step", -1);
-        if (step >= 0) {
-            if (step <= seq.token_ids.size()) {
-                seq.token_ids.resize(step);
-                seq.cache_len = std::min(seq.cache_len, (size_t)step);
-            }
-            else if (rank_ == 0) {
-                TM_LOG_WARNING("[initialize] Skipping invalid step (%d) setting for ID %ld", step, (long)seq.id);
-            }
+    state_->sequences[index] = nullptr;
+
+    // move the request handle into the signal
+    return [this, r = std::move(state_->requests[index])] {
+        if (rank_ == 0) {
+            r->signal.set_value(0);
         }
+    };
+}
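
A trimmed-down illustration of what `Interrupt()` persists for a suspended sequence: the generated tokens plus an opaque blob of sampling state, so the request can later resume deterministically. The struct and helper below are hypothetical, with a plain value standing in for `curandState_t`:

#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical snapshot of a suspended sequence: token history plus raw bytes
// of the random-number generator state, mirroring Sequence::tokens and
// Sequence::random_state above.
struct SequenceSnapshot {
    uint64_t             id;
    std::vector<int>     tokens;
    std::vector<uint8_t> random_state;  // raw bytes of e.g. a curandState_t
};

template<class StateT>
SequenceSnapshot Suspend(uint64_t id, const int* output_ids, int output_len, const StateT& state)
{
    SequenceSnapshot s{id};
    s.tokens.assign(output_ids, output_ids + output_len);
    s.random_state.resize(sizeof(StateT));
    std::memcpy(s.random_state.data(), &state, sizeof(StateT));
    return s;
}

int main()
{
    int    ids[]          = {1, 2, 3};
    double fake_rng_state = 0.5;  // stand-in for the device-side curand state
    auto   snap           = Suspend(42u, ids, 3, fake_rng_state);
    return snap.tokens.size() == 3 ? 0 : 1;
}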
 
-        // input length with missing cache accounted for
-        int actual_input_len = r.inputs[rank_].getVal<int>("input_lengths") + (seq.token_ids.size() - seq.cache_len);
 template<typename T>
+void LlamaBatch<T>::InternalThreadEntry(int device_id)
+{
+    // TM_LOG_INFO("[InternalThreadEntry] %d", (int)rank_);
+    check_cuda_error(cudaSetDevice(device_id));
 
-        // insert `start_id` for empty sequences
-        if (seq.token_ids.empty() && actual_input_len == 0) {
-            seq.token_ids.push_back(llama_->start_id_);
-            seq.cache_len    = 0;
-            actual_input_len = seq.token_ids.size() - seq.cache_len;
-        }
+    auto& shared_state = model_->shared_state_;
 
-        tmp_input_length[i] = actual_input_len;
+    auto& request_queue  = shared_state->request_queue;
+    auto& infer_requests = shared_state->infer_requests;
+    auto& stop_requests  = shared_state->stop_requests;
 
-        tmp_max_input_length = std::max((int)tmp_max_input_length, actual_input_len);
-        tmp_cached_seq.push_back(std::move(seq));
-    }
+    GenerationState g{};
 
-    FT_CHECK(tmp_max_input_length > 0);
-    const int max_input_length = tmp_max_input_length;
+    constexpr int request_interval = 1;
+    long          request_counter  = 0;
 
-    // arrange requests in ascending order w.r.t actual input lengths, so that requests need context decoding will
-    // be together
-    {
-        std::vector<int> idxs(tmp_input_length.size());
-        std::iota(idxs.begin(), idxs.end(), 0);
-        std::sort(idxs.begin(), idxs.end(), [&](int i, int j) { return tmp_input_length[i] < tmp_input_length[j]; });
-        for (int i = 0; i < idxs.size(); ++i) {
-            requests_[batch_size_ + i]   = infer_requests[idxs[i]];
-            cached_seq_[batch_size_ + i] = tmp_cached_seq[idxs[i]];
+    while (1) {
+        if (rank_ == 0) {
+            const int  free_slot_count = max_batch_size_ - state_->size + g.finished_count;
+            const bool is_empty        = (free_slot_count == max_batch_size_);
+            stop_requests.clear();
+            infer_requests.clear();
+            if (is_empty || request_counter % request_interval == 0) {
+                // Block if batch is empty
+                request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty, shared_state->abort);
+                if (!shared_state->abort) {
+                    RejectInvalidRequests(stop_requests, infer_requests);
+                }
+            }
         }
-    }
 
-    const int count = batch_size_ + infer_requests.size();
+        NvtxScope scope("mainloop");
 
-    std::vector<int> tmp_input_len(count);
+        // wait while rank-0 is dequeueing
+        shared_state->barrier->wait();
 
-    for (int i = batch_size_; i < count; ++i) {
-        const auto& seq = cached_seq_[i];
+        if (shared_state->abort) {
+            TM_LOG_INFO("[InternalThreadEntry] stop requested.");
+            return;
+        }
 
-        h_input_length_buf_[i] = requests_[i]->inputs[rank_].getVal<int>("input_lengths");
-        tmp_input_len[i]       = h_input_length_buf_[i];
-        // prepare output ids
-        // <--------> max_context_len
-        // aaaAAAA
-        // bbbbBBBBBB
-        // ccCCC
-        auto output_ids_ptr = output_ids_buf_ + i * session_len_;
+        auto signals = ProcessStopRequests(stop_requests);
 
-        // clear the persistent buffer to prevent leaking previous conversation
-        check_cuda_error(cudaMemsetAsync(output_ids_ptr, 0, sizeof(int) * session_len_, stream_));
+        // Shared `priority` field will be assigned by rank-0
+        ProcessInferRequests(infer_requests);
 
-        if (!seq.token_ids.empty()) {
-            check_cuda_error(cudaMemcpyAsync(output_ids_ptr,  //
-                                             seq.token_ids.data(),
-                                             sizeof(int) * seq.token_ids.size(),
-                                             cudaMemcpyDefault,
-                                             stream_));
-            output_ids_ptr += seq.token_ids.size();
-        }
+        // Wait while shared `requests` is being used
+        shared_state->barrier->wait();
 
-        if (h_input_length_buf_[i]) {
-            auto input_ids_ptr = requests_[i]->inputs[rank_].getPtr<int>("input_ids");
-            check_cuda_error(cudaMemcpyAsync(output_ids_ptr,  //
-                                             input_ids_ptr,
-                                             sizeof(int) * h_input_length_buf_[i],
-                                             cudaMemcpyDefault,
-                                             stream_));
-        }
+        SendSignals(std::move(signals));
 
-        if (!requests_[i]->start_flag && !seq.random_state_.empty()) {
-            check_cuda_error(cudaMemcpyAsync((curandState_t*)topk_curandstate_buf_ + i,
-                                             seq.random_state_.data(),
-                                             sizeof(curandState_t),
-                                             cudaMemcpyDefault,
-                                             stream_));
-            check_cuda_error(cudaMemcpyAsync((curandState_t*)topp_curandstate_buf_ + i,
-                                             seq.random_state_.data() + sizeof(curandState_t),
-                                             sizeof(curandState_t),
-                                             cudaMemcpyDefault,
-                                             stream_));
-        }
-    }
+        Initialize(g);
 
-    for (int i = batch_size_; i < count; ++i) {
-        const auto& seq           = cached_seq_[i];
-        const int   missed        = (int)seq.token_ids.size() - seq.cache_len;
-        auto        input_ids_buf = input_ids_buf_ + i * session_len_;
-        FT_CHECK(missed >= 0);
-        if (missed > 0) {
-            check_cuda_error(cudaMemcpyAsync(input_ids_buf,  //
-                                             seq.token_ids.data() + seq.cache_len,
-                                             sizeof(int) * missed,
-                                             cudaMemcpyDefault,
-                                             stream_));
-            input_ids_buf += missed;
-        }
-        auto& input_ids = requests_[i]->inputs[rank_].at("input_ids");
-        check_cuda_error(cudaMemcpyAsync(input_ids_buf,  //
-                                         input_ids.getPtr(),
-                                         sizeof(int) * h_input_length_buf_[i],
-                                         cudaMemcpyDefault,
-                                         stream_));
-        h_input_length_buf_[i] += missed;
-        h_history_length_buf_[i] = seq.cache_len;
-        h_context_length_buf_[i] = h_input_length_buf_[i] + h_history_length_buf_[i];
-
-        const int request_output_len = requests_[i]->inputs[rank_].getVal<int>("request_output_len");
-        request_seq_len_limit_[i]    = h_context_length_buf_[i] + request_output_len;
-        // `length_criterion` sets finish flag when step >= seq_limit_len, however when step == seq_limit_len
-        // the actual sequence length is seq_limit_len + 1, hence seq_limit_len must truncated to session_len - 1
-        if (request_seq_len_limit_[i] >= session_len_) {
-            request_seq_len_limit_[i] = session_len_ - 1;
-            if (rank_ == 0) {
-                const int trunc_output_len = request_seq_len_limit_[i] - h_context_length_buf_[i];
-                TM_LOG_WARNING(
-                    "[initialize] [%ld] total sequence length (%d + %d) exceeds session_len (%d), request_output_len is truncated to %d",
-                    (long)seq.id,
-                    h_context_length_buf_[i],
-                    request_output_len,
-                    (int)session_len_,
-                    trunc_output_len);
+        FT_CHECK(step_length_ == 1);
+
+        if (state_->active_size) {
+            for (int i = 0; i < step_length_; ++i) {
+                //
+                auto cont = Forward(g, i);
+                //
+                if (auto signals = Finish(g); !signals.empty()) {
+                    if (g.finished_count) {
+                        // Finished requests and corresponding output tensors will be released when notified
+                        // wait for all ranks to ensure no rank (except for output thread) will access related
+                        // resources
+                        shared_state->barrier->wait();
+                    }
+                    SendSignals(std::move(signals));
+                }
+                if (!cont) {  // early exit
+                    break;
+                }
             }
         }
 
-        h_k_cache_ptr_buf_[i] = (uint64_t)seq.k_cache;
-        h_v_cache_ptr_buf_[i] = (uint64_t)seq.v_cache;
+        ++request_counter;
     }
 
-    const int max_context_len = *std::max_element(h_context_length_buf_ + batch_size_, h_context_length_buf_ + count);
-
-    batch_size_      = count;
-    max_context_len_ = max_context_len;
-    step_            = max_context_len;
-
-    check_cuda_error(
-        cudaMemcpyAsync(input_length_buf_, h_input_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
-    check_cuda_error(cudaMemcpyAsync(
-        history_length_buf_, h_history_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
-    check_cuda_error(cudaMemcpyAsync(
-        context_length_buf_, h_context_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
-    check_cuda_error(cudaMemcpyAsync(
-        k_cache_ptr_buf_, h_k_cache_ptr_buf_, sizeof(uintptr_t) * batch_size_, cudaMemcpyDefault, stream_));
-    check_cuda_error(cudaMemcpyAsync(
-        v_cache_ptr_buf_, h_v_cache_ptr_buf_, sizeof(uintptr_t) * batch_size_, cudaMemcpyDefault, stream_));
-
-    if (llama_->tensor_para_.rank_ == 0) {
-        TM_LOG_INFO("[init] infer_request_count = %d", (int)infer_request_count);
-        TM_LOG_INFO("[init] batch_size = %d", (int)batch_size_);
-        TM_LOG_INFO("[init] session_len = %d", (int)session_len_);
-        TM_LOG_INFO("[init] max_input_length = %d", (int)max_input_length);
-        TM_LOG_INFO("[init] max_context_len = %d", (int)max_context_len);
-        TM_LOG_INFO(
-            "[init] slot  sequence_id  history_len  input_len  context_len  tmp_input_len  token_ids.size  cache_len");
-        for (int i = batch_size_ - infer_request_count; i < batch_size_; ++i) {
-            TM_LOG_INFO("[init] %4d  %11ld  %11d  %9d  %11d  %13d  %14d  %9d",
-                        i,
-                        (int)cached_seq_[i].id,
-                        h_history_length_buf_[i],
-                        h_input_length_buf_[i],
-                        h_context_length_buf_[i],
-                        tmp_input_len[i],
-                        (int)cached_seq_[i].token_ids.size(),
-                        (int)cached_seq_[i].cache_len);
-        }
-    }
+    FT_CHECK(0);
 }
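
The loop above has rank 0 dequeue requests while the other ranks block on a barrier, after which every rank sees the same batch. Below is a schematic stand-alone version of that hand-off, using C++20 `std::barrier` in place of the project's barrier wrapper; all names are illustrative.

#include <barrier>
#include <cstdio>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Schematic rank-0 dequeue + barrier hand-off: only rank 0 touches the shared
// queue, the barrier makes the batch visible to every rank before processing.
int main()
{
    constexpr int    world_size = 2;
    std::barrier     barrier(world_size);
    std::mutex       queue_mutex;
    std::queue<int>  request_queue({1, 2, 3});
    std::vector<int> batch;  // shared, written by rank 0 only

    auto worker = [&](int rank) {
        for (int step = 0; step < 3; ++step) {
            if (rank == 0) {
                std::lock_guard<std::mutex> lock(queue_mutex);
                batch.assign(1, request_queue.front());
                request_queue.pop();
            }
            barrier.arrive_and_wait();  // batch is now valid on all ranks
            std::printf("rank %d processes request %d\n", rank, batch[0]);
            barrier.arrive_and_wait();  // all ranks are done using shared state
        }
    };

    std::thread t0(worker, 0), t1(worker, 1);
    t0.join();
    t1.join();
}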
 
 template<typename T>
-void LlamaBatch<T>::contextDecode()
+void LlamaBatch<T>::SendSignals(std::vector<Signal> signals)
 {
-    int base = -1;
-    for (int i = 0; i < batch_size_; ++i) {
-        if (h_input_length_buf_[i] > 1) {
-            base = i;
-            break;
-        }
+    if (rank_ != 0 || signals.empty()) {
+        return;
     }
-    if (base >= 0) {
-        check_cuda_error(cudaStreamSynchronize(stream_));
-        const auto tick = std::chrono::high_resolution_clock::now();
-
-        const int context_decode_count = batch_size_ - base;
-        if (rank_ == 0) {
-            TM_LOG_INFO("[decodeContext] base = %d, count = %d", base, context_decode_count);
-        }
-        invokePlusScalar(input_length_buf_ + base, -1, context_decode_count, stream_);
-        invokePlusScalar(context_length_buf_ + base, -1, context_decode_count, stream_);
-
-        auto get_input_len   = [this](int index) { return h_input_length_buf_[index] - 1; };
-        auto get_context_len = [this](int index) { return h_context_length_buf_[index] - 1; };
-
-        std::vector<int> decode_indices{base};
-        std::vector<int> decode_lengths{get_input_len(base)};
-
-        auto token_num       = get_input_len(base);
-        auto max_input_len   = get_input_len(base);
-        auto max_context_len = get_context_len(base);
-        auto offset          = base;
-        for (int i = offset + 1; i <= batch_size_; ++i) {
-            if (i == batch_size_ || token_num + h_context_length_buf_[i] > max_context_token_num_) {
-                const int context_decode_batch_size = i - offset;
-                if (rank_ == 0) {
-                    TM_LOG_INFO(
-                        "[decodeContext] offset = %d, batch_size = %d, token_num = %d, max_input_len = %d, max_context_len = %d",
-                        base,
-                        context_decode_batch_size,
-                        token_num,
-                        max_input_len,
-                        max_context_len);
-                }
-                // construct context_decoder_ids w/o padding
-                // aaaa____
-                // bb______ -> aaaabbcccccccc
-                // cccccccc
-                auto context_decoder_ids = context_decoder_ids_buf_;
-                for (int j = offset; j < i; ++j) {
-                    check_cuda_error(cudaMemcpyAsync(context_decoder_ids,
-                                                     input_ids_buf_ + j * session_len_,
-                                                     sizeof(int) * get_input_len(j),
-                                                     cudaMemcpyDefault,
-                                                     stream_));
-                    context_decoder_ids += get_input_len(j);
-                }
-                llama_->contextDecode(nullptr,
-                                      k_cache_ptr_buf_ + offset,
-                                      v_cache_ptr_buf_ + offset,
-                                      context_decoder_input_buf_,
-                                      context_decoder_output_buf_,
-                                      context_decoder_ids_buf_,
-                                      input_length_buf_ + offset,
-                                      history_length_buf_ + offset,
-                                      context_length_buf_ + offset,
-                                      token_num,
-                                      max_input_len,
-                                      max_context_len,
-                                      session_len_,
-                                      context_decode_batch_size);
-
-                // compute logits of inputs if requested
-                outputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths);
-
-                if (i < batch_size_) {
-                    // initialize next sub-batch
-                    token_num       = get_input_len(i);
-                    max_input_len   = get_input_len(i);
-                    max_context_len = get_context_len(i);
-                    offset          = i;
-
-                    decode_indices = {i};
-                    decode_lengths = {get_input_len(i)};
-                }
-            }
-            else {
-                // add to current sub-batch
-                token_num += get_input_len(i);
-                max_input_len   = std::max(max_input_len, get_input_len(i));
-                max_context_len = std::max(max_context_len, get_context_len(i));
-
-                decode_indices.push_back(i);
-                decode_lengths.push_back(get_input_len(i));
-            }
-        }
-
-        invokePlusScalar(context_length_buf_ + base, 1, context_decode_count, stream_);
-        invokePlusScalar(input_length_buf_ + base, 1, context_decode_count, stream_);
-
-        for (int i = offset; i < batch_size_; ++i) {
-            h_input_length_buf_[i] = 0;
-        }
-
-        check_cuda_error(cudaStreamSynchronize(stream_));
-        const auto tock = std::chrono::high_resolution_clock::now();
-        if (rank_ == 0) {
-            TM_LOG_INFO("[decodeContext] %.2f ms", std::chrono::duration(tock - tick).count());
-        }
+    {
+        std::lock_guard lock{output_mutex_};
+        output_signals_.insert(output_signals_.end(),  //
+                               std::move_iterator{signals.begin()},
+                               std::move_iterator{signals.end()});
     }
-    else if (rank_ == 0) {
-        TM_LOG_INFO("[decodeContext] Context decoding is not needed.");
+    output_cv_.notify_one();
+}
+
+template<typename T>
+void LlamaBatch<T>::Start()
+{
+    TM_LOG_INFO("LlamaBatch::Start()");
+    int device_id = -1;
+    check_cuda_error(cudaGetDevice(&device_id));
+    internal_thread_ = std::thread(&LlamaBatch::InternalThreadEntry, this, device_id);
+    if (rank_ == 0) {
+        output_thread_ = std::thread(&LlamaBatch::OutputThreadEntry, this);
     }
 }
 
 template<typename T>
-void LlamaBatch<T>::outputContextLogits(T*                      context_decoder_output,
-                                        const std::vector<int>& indices,
-                                        const std::vector<int>& lengths)
+void LlamaBatch<T>::OutputThreadEntry()
 {
-    std::vector<float*> output_logits;
-    int                 num_token = 0;
-    {
-        bool is_return_logits = false;
-        for (int k = 0; k < indices.size(); ++k) {
-            auto& request = requests_[indices[k]];
-            output_logits.push_back(request->outputs[rank_].getPtr<float>("logits", nullptr));
-            num_token += lengths[k];
-            if (output_logits.back()) {
-                is_return_logits = true;
+    while (true) {
+        std::vector signals;
+        {
+            // Wait for signals to come
+            std::unique_lock lock(output_mutex_);
+            output_cv_.wait(lock, [&] { return !output_signals_.empty() || output_stop_token_; });
+            if (output_stop_token_) {
+                TM_LOG_INFO("[OutputThreadEntry] stop requested.");
+                return;
             }
+            signals = std::move(output_signals_);
         }
-        if (!is_return_logits) {
-            return;
+        if (rank_ == 0 && model_->ffi_lock_) {
+            model_->ffi_lock_(1);
         }
-    }
-
-    if (context_logits_buf_ == nullptr) {
-        NcclGuard guard(llama_->tensor_para_, stream_, true);
-        context_logits_buf_ =
-            (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
-        const auto tp = llama_->tensor_para_.world_size_;
-        if (tp > 1) {
-            FT_CHECK(llama_->vocab_size_padded_ % tp == 0);
-            const auto local_vocab_size = llama_->vocab_size_padded_ / tp;
-            local_context_logits_buf_ =
-                (float*)allocator_->malloc(sizeof(float) * local_vocab_size * max_context_token_num_);
+        // invoke stream cbs & signals
+        for (const auto& s : signals) {
+            s();
         }
-    }
-
-    llama_->postDecodeEmbedding(context_logits_buf_, local_context_logits_buf_, context_decoder_output, num_token);
-
-    auto logits = context_logits_buf_;
-
-    for (int k = 0; k < indices.size(); ++k) {
-        if (output_logits[k]) {
-            check_cuda_error(cudaMemcpyAsync(output_logits[k],
-                                             logits,
-                                             sizeof(float) * llama_->vocab_size_ * lengths[k],
-                                             cudaMemcpyDefault,
-                                             stream_));
+        if (rank_ == 0 && model_->ffi_lock_) {
+            model_->ffi_lock_(0);
         }
-        logits += llama_->vocab_size_padded_ * lengths[k];
     }
 }
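
`SendSignals()` and `OutputThreadEntry()` form a small producer/consumer pair around `output_mutex_` and `output_cv_`. Below is a self-contained sketch of the same hand-off, with illustrative names and without the ffi lock:

#include <condition_variable>
#include <cstdio>
#include <functional>
#include <mutex>
#include <thread>
#include <vector>

// Producers append closures under a mutex and notify; the consumer thread
// swaps the whole vector out and runs the closures outside the lock.
using Signal = std::function<void()>;

std::mutex              mtx;
std::condition_variable cv;
std::vector<Signal>     pending;
bool                    stop = false;

void Send(std::vector<Signal> signals)
{
    {
        std::lock_guard<std::mutex> lock(mtx);
        pending.insert(pending.end(), std::make_move_iterator(signals.begin()),
                       std::make_move_iterator(signals.end()));
    }
    cv.notify_one();
}

void Consume()
{
    while (true) {
        std::vector<Signal> signals;
        {
            std::unique_lock<std::mutex> lock(mtx);
            cv.wait(lock, [] { return !pending.empty() || stop; });
            if (pending.empty()) {  // woken by stop with nothing left to run
                return;
            }
            signals = std::move(pending);
            pending.clear();
        }
        for (auto& s : signals) {
            s();  // run callbacks without holding the lock
        }
    }
}

int main()
{
    std::thread consumer(Consume);
    Send({[] { std::printf("signal delivered\n"); }});
    {
        std::lock_guard<std::mutex> lock(mtx);
        stop = true;
    }
    cv.notify_one();
    consumer.join();
}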
 
 template<typename T>
-void LlamaBatch<T>::finish()
+bool LlamaBatch<T>::Forward(GenerationState& g, int iter)
 {
-    // secure info needed by `synchronize()`
-    check_cuda_error(
-        cudaMemcpyAsync(h_finished_buf_, finished_buf_, sizeof(bool) * batch_size_, cudaMemcpyDefault, stream_));
-    check_cuda_error(
-        cudaMemcpyAsync(h_sequence_lengths_, sequence_lengths_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
+    NvtxScope _("Forward");
 
-    setOutputTensors(step_);
+    FT_CHECK(max_context_token_num_ >= max_batch_size_);
 
-    check_cuda_error(cudaStreamSynchronize(stream_));
+    const int active_size = state_->active_size;
 
-    if (rank_ == 0 && llama_->ffi_lock_) {
-        llama_->ffi_lock_(1);
+    constexpr int kLogInterval = 10;
+    if (rank_ == 0 && (g.step - 1) % kLogInterval == 0) {
+        TM_LOG_INFO("------------------------- step = %d -------------------------", g.step - 1);
     }
-    for (int i = 0; i < batch_size_; ++i) {
-        FT_CHECK(requests_[i] != nullptr);
-        if (requests_[i]->stream_cb && rank_ == 0) {
-            requests_[i]->stream_cb(&requests_[i]->outputs[rank_].get());
+
+    int               pf_offset = -1;
+    std::vector<int*> input_d_ptrs(active_size);
+
+    if (iter == 0) {  // The first iter may have pre-fill tokens
+        for (int i = 0; i < active_size; ++i) {
+            const auto& seq = *state_->sequences[i];
+            // const int   missing    = state_->h_context_length[i] - seq.cache_len;
+            FT_CHECK(seq.input_length >= 1);
+            h_input_length_buf_[i] = seq.input_length;
+            input_d_ptrs[i]        = state_->output_ids + i * session_len_ + seq.cache_len;
+            if (seq.input_length > 1 && pf_offset < 0) {
+                pf_offset = i;
+            }
+        }
+        if (pf_offset < 0) {
+            pf_offset = active_size;
         }
     }
-    if (rank_ == 0 && llama_->ffi_lock_) {
-        llama_->ffi_lock_(0);
+    else {
+        for (int i = 0; i < active_size; ++i) {
+            h_input_length_buf_[i] = 1;
+            input_d_ptrs[i]        = state_->output_ids + i * session_len_ + state_->h_context_length[i] - 1;
+        }
+        pf_offset = active_size;
     }
 
-    if (debug_ && rank_ == 0) {
-        std::stringstream ss;
-        for (int i = 0; i < batch_size_; ++i) {
-            ss << (i ? ", " : "") << "(" << h_sequence_lengths_[i] << "," << h_finished_buf_[i] << ")";
-        }
-        TM_LOG_INFO("[finish] [%s]", ss.str().c_str());
+    // These buffers are only accessed when there are prefill workloads
+    if (pf_offset != active_size) {
+        Copy(state_->h_context_length, active_size, context_length_buf_);
+        Copy(h_input_length_buf_, active_size, input_length_buf_);
     }
 
-    for (int i = 0; i < batch_size_; ++i) {
-        if (h_finished_buf_[i]) {
-            finishRequest(i, false);
-            ++finished_count_;
+    // Find mini-batch offsets: input length > 1 ? prefill() : decode()
+    // Constraints on mini-batches
+    // - `context_decoder_input` and `context_decoder_output` can hold `max_context_token_num_` tokens w/o padding
+    // - prefill() uses `tmp_k_cache_buf_` and `tmp_v_cache_buf_`, they can hold `max_context_token_num_` tokens
+    //     but each sequence is padded to the maximum context length in the batch
+    std::vector<int> offsets{0};
+    std::vector<int> max_context_cnts;
+    // initialize first mini-batch with decode tokens
+    int accum_size        = pf_offset;
+    int accum_token_count = pf_offset;
+    int max_context_count = 0;
+    for (int i = pf_offset; i < active_size; ++i) {
+        FT_CHECK(iter == 0);
+        int size          = accum_size + 1;
+        int input_count   = accum_token_count + h_input_length_buf_[i];
+        int context_count = std::max(max_context_count, state_->h_context_length[i]);
+        // correct pre-fill batch size for the first batch
+        int pf_size = offsets.size() == 1 ? size - pf_offset : size;
+        // we have `cu_seqlens` on q so no padding for input is needed
+        // prefill kernels are expecting uniform k/v cache length -> `max_context_count * size <=
+        // max_context_token_num_`
+        if (input_count <= max_context_token_num_ && context_count * pf_size <= max_context_token_num_) {
+            accum_size        = size;
+            accum_token_count = input_count;
+            max_context_count = context_count;
+        }
+        else {
+            offsets.push_back(i);
+            max_context_cnts.push_back(max_context_count);
+            accum_size        = 1;
+            accum_token_count = h_input_length_buf_[i];
+            max_context_count = state_->h_context_length[i];
         }
     }
-}
-
-template
-void LlamaBatch::synchronize()
-{
-    // compact
-    int idx = 0;
-    for (int i = 0; i < batch_size_; ++i) {
-        if (requests_[i]) {
-            h_input_length_buf_[idx]   = 0;
-            h_history_length_buf_[idx] = 0;
-
-            h_context_length_buf_[idx] = h_sequence_lengths_[i] + 1;
-            h_sequence_lengths_[idx]   = h_context_length_buf_[idx];
-
-            check_cuda_error(cudaMemcpyAsync((curandState_t*)topk_curandstate_buf_ + idx,
-                                             llama_->dynamic_decode_layer_->topk_curandstate_buf() + i,
-                                             sizeof(curandState_t),
-                                             cudaMemcpyDefault,
-                                             stream_));
-            check_cuda_error(cudaMemcpyAsync((curandState_t*)topp_curandstate_buf_ + idx,
-                                             llama_->dynamic_decode_layer_->topp_curandstate_buf() + i,
-                                             sizeof(curandState_t),
-                                             cudaMemcpyDefault,
-                                             stream_));
-
-            if (i != idx) {
-                h_finished_buf_[idx]        = h_finished_buf_[i];
-                request_seq_len_limit_[idx] = request_seq_len_limit_[i];
-
-                h_k_cache_ptr_buf_[idx] = h_k_cache_ptr_buf_[i];
-                h_v_cache_ptr_buf_[idx] = h_v_cache_ptr_buf_[i];
-
-                requests_[idx]   = std::move(requests_[i]);
-                cached_seq_[idx] = std::move(cached_seq_[i]);
-                check_cuda_error(cudaMemcpyAsync(output_ids_buf_ + idx * session_len_,
-                                                 output_ids_buf_ + i * session_len_,
-                                                 sizeof(int) * h_context_length_buf_[idx],
-                                                 cudaMemcpyDefault,
-                                                 stream_));
+    offsets.push_back(active_size);
+    max_context_cnts.push_back(max_context_count);
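
A host-only sketch of the greedy split computed above: sequences are appended to the current mini-batch while the total input-token count and the padded temporary K/V footprint (maximum context length times mini-batch size) both stay within the token budget. It ignores the decode-only prefix handled by `pf_offset` and uses hypothetical names:

#include <algorithm>
#include <cstdio>
#include <vector>

// Greedy mini-batch split: extend the current batch while both budgets hold,
// otherwise close it and start a new batch at the current sequence.
std::vector<int> SplitPrefill(const std::vector<int>& input_len,
                              const std::vector<int>& context_len,
                              int                     token_budget)
{
    std::vector<int> offsets{0};
    int              tokens = 0, max_ctx = 0, size = 0;
    for (int i = 0; i < (int)input_len.size(); ++i) {
        const int new_tokens  = tokens + input_len[i];
        const int new_max_ctx = std::max(max_ctx, context_len[i]);
        if (size == 0 || (new_tokens <= token_budget && new_max_ctx * (size + 1) <= token_budget)) {
            tokens  = new_tokens;
            max_ctx = new_max_ctx;
            ++size;
        }
        else {  // current sequence does not fit, close the batch before it
            offsets.push_back(i);
            tokens  = input_len[i];
            max_ctx = context_len[i];
            size    = 1;
        }
    }
    offsets.push_back((int)input_len.size());
    return offsets;
}

int main()
{
    // prints the batch boundaries, e.g. "0 1 2 3" for a budget of 512 tokens
    for (int o : SplitPrefill({100, 300, 50}, {100, 400, 600}, 512)) {
        std::printf("%d ", o);
    }
    std::printf("\n");
}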
+
+    // forward on mini-batches
+    for (int p = 0; p < (int)offsets.size() - 1; ++p) {
+        int  first           = offsets[p];
+        int  last            = offsets[p + 1];
+        int  mini_batch_size = last - first;
+        T*   k_ptr           = tmp_k_cache_buf_;
+        T*   v_ptr           = tmp_v_cache_buf_;
+        int  max_input_len{};
+        auto input_ids = context_decoder_ids_buf_;
+        //
+        std::vector<int> decode_indices{};
+        std::vector<int> decode_lengths{};
+
+        BatchedCopy batched_copy;
+        for (int i = first; i < last; ++i) {
+            input_ids = batched_copy.Add(input_d_ptrs[i], h_input_length_buf_[i], input_ids);
+            dbg(i, h_input_length_buf_[i]);
+            // allocate tmp k/v buffer for pre-fill sequences
+            if (i < pf_offset) {
+                h_tmp_k_ptrs_[i] = h_tmp_v_ptrs_[i] = nullptr;
             }
-            ++idx;
+            else {
+                h_tmp_k_ptrs_[i] = k_ptr;
+                h_tmp_v_ptrs_[i] = v_ptr;
+                k_ptr += model_->local_kv_head_num_ * max_context_cnts[p] * model_->size_per_head_;
+                v_ptr += model_->local_kv_head_num_ * max_context_cnts[p] * model_->size_per_head_;
+            }
+            decode_indices.push_back(i);
+            decode_lengths.push_back(h_input_length_buf_[i]);
+            max_input_len = std::max(max_input_len, h_input_length_buf_[i]);
         }
-    }
-    batch_size_ = idx;
+        int token_count = input_ids - context_decoder_ids_buf_;
 
-    if (rank_ == 0) {
-        TM_LOG_INFO("[synchronize] batch_size = %d", (int)batch_size_);
-    }
+        batched_copy.Submit(stream_);
 
-    finished_count_ = 0;
-}
+        Copy(h_tmp_k_ptrs_ + first, mini_batch_size, tmp_k_ptrs_ + first);
+        Copy(h_tmp_v_ptrs_ + first, mini_batch_size, tmp_v_ptrs_ + first);
 
-template
-void LlamaBatch::setOutputTensors(int max_gen_step)
-{
-    // [s,b] -> [b,s] and skip padding in [context_len, max_context_len)
-    invokeGatherOutput(output_ids_buf_,
-                       token_ids_buf_,
-                       context_length_buf_,
-                       max_context_len_,
-                       max_gen_step,
-                       session_len_,
-                       batch_size_,
-                       stream_);
-    sync_check_cuda_error();
+        const int dc_batch_size = p ? 0 : pf_offset;
+        const int pf_batch_size = mini_batch_size - dc_batch_size;
 
-    /// TODO: fuse the loop into a single kernel
-    for (int i = 0; i < batch_size_; ++i) {
-        if (requests_[i]) {
-            auto& output_ids      = requests_[i]->outputs[rank_].at("output_ids");
-            auto& sequence_length = requests_[i]->outputs[rank_].at("sequence_length");
-            check_cuda_error(cudaMemcpyAsync(output_ids.getPtr<int>(),
-                                             output_ids_buf_ + i * session_len_,
-                                             sizeof(int) * output_ids.shape.at(2),
-                                             cudaMemcpyDefault,
-                                             stream_));
-            check_cuda_error(cudaMemcpyAsync(
-                sequence_length.getPtr<int>(), sequence_lengths_ + i, sizeof(int), cudaMemcpyDefault, stream_));
-            if (max_gen_step > max_context_len_) {  // +1 for newly generated token
-                invokePlusScalar(sequence_length.getPtr<int>(), 1, 1, stream_);
+        if (rank_ == 0) {
+            if (pf_batch_size) {
+                TM_LOG_INFO("[Forward] [%d, %d), dc_bsz = %d, pf_bsz = %d, n_tok = %d, max_q = %d, max_k = %d",
+                            first,
+                            last,
+                            dc_batch_size,
+                            pf_batch_size,
+                            token_count,
+                            max_input_len,
+                            max_context_cnts[p]);
             }
         }
+
+        model_->forwardUnified(decoder_output_buf_ + first * model_->hidden_units_,
+                               context_decoder_output_buf_,  // temp
+                               context_decoder_input_buf_,   // temp
+                               (void**)k_block_ptrs_,
+                               (void**)v_block_ptrs_,
+                               context_decoder_ids_buf_,  // temp
+                               cu_block_counts_ + first,
+                               rope_theta_ + first,
+                               finished_buf_ + first,
+                               input_length_buf_ + first,
+                               context_length_buf_ + first,
+                               (T**)tmp_k_ptrs_ + first,
+                               (T**)tmp_v_ptrs_ + first,
+                               token_count,
+                               dc_batch_size,
+                               g.step,
+                               g.sum_seq_len,
+                               g.max_seq_len,
+                               pf_batch_size,
+                               max_input_len,
+                               max_context_cnts[p],
+                               max_context_cnts[p]);
+
+        if (iter == 0) {
+            // compute logits of inputs if requested
+            OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths);
+        }
     }
-}
 
-template<typename T>
-void LlamaBatch<T>::finishRequest(int index, bool force_end)
-{
-    if (rank_ == 0) {
-        TM_LOG_INFO("[finishRequest] slot = %d, id = %lu", index, (long)requests_[index]->id);
+    std::fill(h_input_length_buf_, h_input_length_buf_ + active_size, 0);
+
+    // `SequenceManager` needs real-time value of cache length
+    for (int i = 0; i < active_size; ++i) {
+        if (state_->requests[i]) {
+            FT_CHECK(state_->sequences[i]);
+            state_->sequences[i]->cache_len += state_->sequences[i]->input_length;
+        }
+    }
+
+    bool should_stop{};
+
+    if (active_size > g.partial) {
+        model_->postDecodeEmbedding(logits_buf_, local_logits_buf_, decoder_output_buf_, active_size - g.partial);
+
+        FT_CHECK(g.step >= 0);
+
+        // TM_LOG_INFO("dyn decode bsz %d, partial %d", active_size, g.partial);
+
+        // stop-words & bad-words require the matched tokens to be contiguous, so item size > 1 is
+        // not supported yet.
+        model_->dynamicDecode(token_ids_buf_,
+                              finished_buf_,
+                              sequence_lengths_,
+                              &should_stop,
+                              state_->curand_state,
+                              &inputs_,
+                              &outputs_,
+                              logits_buf_,
+                              seq_limit_len_,
+                              init_context_length_,
+                              d_end_ids_buf_,
+                              g.step,
+                              0,
+                              g.max_init_ctx_len,
+                              session_len_ * 2,
+                              active_size - g.partial);
     }
 
     if (debug_ && rank_ == 0) {
-        std::vector<int> tokens(h_sequence_lengths_[index] + 1);
-        cudaMemcpyAsync(tokens.data(),
-                        output_ids_buf_ + index * session_len_,
-                        sizeof(int) * tokens.size(),
-                        cudaMemcpyDefault,
-                        stream_);
+        std::vector<int> curr(active_size);
+        Copy(token_ids_buf_ + g.step * active_size, active_size, curr.data());
         cudaStreamSynchronize(stream_);
-        std::stringstream ss;
-        for (const auto& t : tokens) {
-            ss << " " << t;
+        std::stringstream scurr;
+        for (int k = 0; k < curr.size(); ++k) {
+            scurr << std::setw(6) << curr[k];
         }
-        TM_LOG_INFO("[finishRequest] slot %d, tokens [%s]", index, ss.str().c_str());
+        TM_LOG_INFO("[Forward] step = %d, [%s]", g.step - 1, scurr.str().c_str());
     }
 
-    auto&      output_ids_tensor = requests_[index]->outputs[rank_].at("output_ids");
-    const auto output_ids_data   = output_ids_tensor.getPtr<int>();
-    if (requests_[index]->end_flag || force_end) {
-        llama_->kv_cache_mgr_->erase(requests_[index]->id);
-    }
-    else {
-        // the last generated token is not processed by the decoder, thus it has no k/v cache entry
-        const int n_steps    = step_ - max_context_len_;
-        const int cache_len  = h_sequence_lengths_[index];
-        const int output_len = n_steps > 0 ? cache_len + 1 : cache_len;
-
-        auto& seq = cached_seq_[index];
-
-        seq.cache_len = cache_len;
-
-        // update token IDs
-        seq.token_ids.resize(output_len);
-        check_cuda_error(cudaMemcpyAsync(
-            seq.token_ids.data(), output_ids_data, sizeof(int) * output_len, cudaMemcpyDefault, stream_));
-
-        // update random states
-        seq.random_state_.resize(sizeof(curandState_t) * 2);
-        check_cuda_error(cudaMemcpyAsync(seq.random_state_.data(),
-                                         llama_->dynamic_decode_layer_->topk_curandstate_buf() + index,
-                                         sizeof(curandState_t),
-                                         cudaMemcpyDefault,
-                                         stream_));
-        check_cuda_error(cudaMemcpyAsync(seq.random_state_.data() + sizeof(curandState_t),
-                                         llama_->dynamic_decode_layer_->topp_curandstate_buf() + index,
-                                         sizeof(curandState_t),
-                                         cudaMemcpyDefault,
-                                         stream_));
+    // check_cuda_error(cudaStreamSynchronize(stream_));
 
-        check_cuda_error(cudaStreamSynchronize(stream_));
-
-        llama_->kv_cache_mgr_->update(cached_seq_[index], stream_);
-    }
+    ////////////////////////////////////////////////
+    /// ! increase the counters
+    g.step += 1;
+    g.max_seq_len += 1;
+    g.sum_seq_len += state_->active_size;
 
-    // When the signal is set threads from LlamaV2::forward can exit
-    // and free inputs/outputs tensors.
-    // Therefore we need to make sure that no threads from LlamaV2::internalThreadEntry
-    // are accessing the tensors.
-    llama_->shared_state_->barrier->wait();
-    if (rank_ == 0) {
-        requests_[index]->signal.set_value(0);
-    }
+    // PrintDecodeTokens(token_ids_buf_, g.step, active_size, stream_, "Forward");
 
-    requests_[index] = nullptr;
+    return !should_stop;
 }
 
 template class LlamaBatch;
diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h
index 280562ffb1..0173ddfceb 100644
--- a/src/turbomind/models/llama/LlamaBatch.h
+++ b/src/turbomind/models/llama/LlamaBatch.h
@@ -2,146 +2,277 @@
 
 #pragma once
 
-#include "src/turbomind/models/llama/LlamaCacheManager.h"
+// #include "src/turbomind/models/llama/LlamaCacheManager.h"
+#include "src/turbomind/layers/sampling_layers/BaseSamplingLayer.h"
+#include "src/turbomind/models/llama/Barrier.h"
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
 #include "src/turbomind/models/llama/Request.h"
+#include "src/turbomind/models/llama/SequenceManager.h"
+#include "src/turbomind/models/llama/llama_kernels.h"
+#include "src/turbomind/models/llama/llama_params.h"
 #include "src/turbomind/utils/allocator.h"
 #include "src/turbomind/utils/cublasMMWrapper.h"
+#include "src/turbomind/utils/cuda_utils.h"
+#include 
+#include 
+#include 
 
 namespace turbomind {
 
+struct BatchState {
+    int*  h_context_length;
+    bool* h_finished;
+
+    curandState_t* curand_state;
+    int*           output_ids;  // output ids in [B, S]
+
+    float* h_rope_theta;
+
+    std::vector<int> seq_len_limit;
+
+    std::vector<const Sequence*>          sequences;
+    std::vector<std::shared_ptr<Request>> requests;
+
+    // |<-- existing -->|<-- swap-in -->|
+    // |<----------- active ----------->|<-- inactive -->|
+    int active_size;
+    int size;
+};
+
 template<typename T>
 class LlamaV2;
 
+struct GenerationState {
+    int max_init_ctx_len;
+    int step;
+
+    int sum_seq_len;
+    int max_seq_len;
+
+    int partial;
+    int partial_context_legnth;
+
+    std::vector<uint64_t> unique_ids;
+
+    int max_input_count1;
+    int max_input_count2;
+
+    std::deque<int> min_input_count;
+
+    int finished_count;
+};
+
 template<typename T>
 class LlamaBatch {
 public:
-    int size() const noexcept
-    {
-        return batch_size_;
-    };
+    void AllocateBuffer(size_t batch_size, size_t session_len);
+    void AllocatePersistantBuffer(size_t max_batch_size);
+    void FreeBuffer();
 
-    int maxSize() const noexcept
-    {
-        return max_batch_size_;
-    }
+    using Requests = std::vector<std::shared_ptr<Request>>;
+    using Signal   = std::function<void()>;
 
-    int finishedCount() const noexcept
-    {
-        return finished_count_;
-    }
+    void RejectInvalidRequests(Requests& stop_reqs, Requests& infer_reqs);
 
-    void verifyRequests(std::vector<std::shared_ptr<Request>>& stop_reqs,
-                        std::vector<std::shared_ptr<Request>>& infer_reqs);
-    void handleStopRequests(const std::vector<std::shared_ptr<Request>>& requests);
+    [[nodiscard]] auto ProcessStopRequests(const Requests& requests) -> std::vector<Signal>;
 
-    void allocateBuffer(size_t batch_size, size_t session_len);
-    void allocatePersistantBuffer(size_t max_batch_size);
-    void freeBuffer();
+    void ProcessInferRequests(const Requests& requests);
 
-    void initializeSampling(int infer_request_count);
+    void AdjustMaxInputCount(GenerationState&                    g,
+                             const std::vector<const Sequence*>& sequences,
+                             const std::vector<int>&             context_length);
 
-    void initialize(const std::vector<std::shared_ptr<Request>>& infer_requests);
-    void contextDecode();
+    void Initialize(GenerationState& g);
 
-    void initializeGeneration();
-    bool generate();
+    void InitializeSampling(const GenerationState& g);
 
-    void finish();
-    void finishRequest(int index, bool force_end);
+    [[nodiscard]] bool Forward(GenerationState& g, int iter);
 
-    void synchronize();
+    [[nodiscard]] auto Finish(GenerationState& g) -> std::vector<Signal>;
 
-    void setOutputTensors(int max_gen_step);
+    [[nodiscard]] Signal Interrupt(int index, bool force_stop = false, bool force_end = false);
 
     void
-    outputContextLogits(T* context_decoder_output, const std::vector<int>& indices, const std::vector<int>& lengths);
+    OutputContextLogits(T* context_decoder_output, const std::vector<int>& indices, const std::vector<int>& lengths);
 
-    explicit LlamaBatch(int max_batch_size, int max_context_token_num, int session_len, LlamaV2<T>* llama);
+    explicit LlamaBatch(const EngineParams& params, int cache_block_seq_len, int quant_policy, LlamaV2<T>* model);
 
     ~LlamaBatch()
     {
-        freeBuffer();
+        TM_LOG_INFO("~LlamaBatch()");
+        model_->shared_state_->request_queue.close();
+
+        internal_thread_.join();
+
+        if (output_thread_.joinable()) {
+            {
+                std::lock_guard lock{output_mutex_};
+                output_stop_token_ = true;
+            }
+            output_cv_.notify_one();
+            output_thread_.join();
+        }
+
+        FreeBuffer();
+    }
+
+    void Start();
+
+private:
+    void InternalThreadEntry(int device_id);
+
+    void OutputThreadEntry();
+
+    void CopyState(const std::vector>& desc);
+
+    void SendSignals(std::vector<Signal> signals);
+
+    // analogous to `std::copy_n`
+    template<typename U>
+    U* Copy(const U* src, size_t count, U* dst)
+    {
+        check_cuda_error(cudaMemcpyAsync(dst, src, sizeof(U) * count, cudaMemcpyDefault, stream_));
+        return dst += count;
+    }
+
+    template<typename U>
+    U* Clear(U* data, size_t count)
+    {
+        check_cuda_error(cudaMemsetAsync(data, 0, sizeof(U) * count, stream_));
+        return data += count;
+    }
+
+    template<typename... Ts>
+    void IndexedCopyImpl(const int* src_idx, const int* dst_idx, int count, const std::tuple<Ts*, Ts*, int>&... cpys)
+    {
+        if (!count) {
+            return;
+        }
+        constexpr int N = sizeof...(Ts);
+        static_assert((!std::is_same_v<Ts, void> && ...));
+        std::array<void*, N> src_ptr{std::get<0>(cpys)...};
+        std::array<void*, N> dst_ptr{std::get<1>(cpys)...};
+        std::array<int, N>   elem_sz{int(sizeof(Ts) * std::get<2>(cpys))...};
+        invokeIndexedCopy(src_ptr.data(),  //
+                          dst_ptr.data(),
+                          elem_sz.data(),
+                          src_idx,
+                          dst_idx,
+                          count,
+                          N,
+                          stream_);
+        sync_check_cuda_error();
+    }
+
+    template<typename... Ts>
+    void IndexedCopy(const std::vector<int>& src_idx,
+                     const std::vector<int>& dst_idx,
+                     const std::tuple<Ts*, Ts*, int>&... cpys)
+    {
+        // src_idx and dst_idx have the same size, or one of them is empty
+        FT_CHECK(src_idx.size() == dst_idx.size() || (src_idx.empty() ^ dst_idx.empty()));
+        IndexedCopyImpl(src_idx.empty() ? nullptr : src_idx.data(),
+                        dst_idx.empty() ? nullptr : dst_idx.data(),
+                        std::max(src_idx.size(), dst_idx.size()),
+                        cpys...);
+    }
+
+    template<typename... Ts>
+    void IndexedCopy(int count, const std::tuple<Ts*, Ts*, int>&... cpys)
+    {
+        IndexedCopyImpl(nullptr, nullptr, count, cpys...);
     }
 
 private:
     const int  max_batch_size_;
     const int  max_context_token_num_;
-    const int  session_len_;
+    int        session_len_;
     const int  rank_;
     const bool debug_;
+    const int  step_length_;
 
-    LlamaV2<T>* const llama_;
-
-    // active requests
-    std::vector<std::shared_ptr<Request>> requests_;
-
-    T*   context_decoder_input_buf_{};   // CTXDEC
-    T*   context_decoder_output_buf_{};  // CTXDEC
-    int* context_decoder_ids_buf_{};
-
-    T* decoder_input_buf_{};   // CTXDEC, GENERATE
-    T* decoder_output_buf_{};  // CTXDEC, GENERATE
+    LlamaV2<T>* const model_;
 
-    int* input_ids_buf_{};       // input token ids + cache missed token ids, CTXDEC
-    int* input_length_buf_{};    // input + cache missed length, CTXDEC, GENERATE
-    int* history_length_buf_{};  // history length, CTXDEC
-    int* context_length_buf_{};  // history length + input_length, CTXDEC, GENERATE
+    std::unique_ptr<SequenceManager> sequence_manager_;
 
-    int* total_padding_count_{};  // GENERATE
-    int* sequence_lengths_{};     // current sequence length
+    ///////////////////////////////////////////////////////////////////
+    // k/v cache block buffers
+    int*       cu_block_counts_{};
+    uintptr_t* k_block_ptrs_{};
+    uintptr_t* v_block_ptrs_{};
 
-    uint64_t* k_cache_ptr_buf_{};
-    uint64_t* v_cache_ptr_buf_{};
+    ////////////////////////////////////////////////////////////////////
+    // context decoding temp buffers
+    T*   context_decoder_input_buf_{};
+    T*   context_decoder_output_buf_{};
+    int* context_decoder_ids_buf_{};
+    int* input_ids_buf_{};
+    // lengths
+    int* input_length_buf_{};    // input + cache missed length
+    int* context_length_buf_{};  // history length + input_length
+    int* init_context_length_{};
+    // temp buffers used for block->linear kv-cache conversion
+    T*     tmp_k_cache_buf_{};
+    T*     tmp_v_cache_buf_{};
+    void** tmp_k_ptrs_{};
+    void** tmp_v_ptrs_{};
+    void** h_tmp_k_ptrs_{};
+    void** h_tmp_v_ptrs_{};
+
+    T*   decoder_input_buf_{};
+    T*   decoder_output_buf_{};
+    int* sequence_lengths_{};  // current sequence length
+    int* init_ctx_lens_{};
 
     float* logits_buf_{};        // combined logits
     float* local_logits_buf_{};  // tensor parallel local logits
     float* context_logits_buf_{};
     float* local_context_logits_buf_{};
 
+    float* rope_theta_{};
+
     // used by dynamic decoder
-    int*      token_ids_buf_{};   // all token IDs in [S, B], indexed using `step`
-    int*      output_ids_buf_{};  // output ids in [B, S]
-    int*      end_ids_buf_{};
+    int*      token_ids_buf_{};  // all token IDs in [S, B], indexed using `step`
     bool*     finished_buf_{};
     uint32_t* seq_limit_len_{};
+    int*      h_end_ids_buf_{};
+    int*      d_end_ids_buf_{};
 
     // pinned buffers
     int*       h_input_ids_buf_{};
     int*       h_input_length_buf_{};
-    int*       h_history_length_buf_{};
-    int*       h_context_length_buf_{};
-    int*       h_sequence_lengths_{};
-    bool*      h_finished_buf_{};
-    uintptr_t* h_k_cache_ptr_buf_{};
-    uintptr_t* h_v_cache_ptr_buf_{};
     uint32_t*  h_seq_limit_len_{};
+    int*       h_cu_block_counts_{};
+    uintptr_t* h_k_block_ptrs_{};
+    uintptr_t* h_v_block_ptrs_{};
 
-    int*      stop_words_buf_{};  // [batch_size, 2, kMaxStopWordsLen]
-    int*      bad_words_buf_{};
-    int*      h_runtime_top_k_{};
-    float*    h_runtime_top_p_{};
-    float*    h_temperature_{};
-    float*    h_repetition_penalty_{};
-    uint64_t* h_random_seed_{};
+    int*   h_runtime_top_k_{};
+    float* h_runtime_top_p_{};
+    float* h_temperature_{};
+    float* h_repetition_penalty_{};
+    int*   h_stop_words_{};  // [batch_size, 2, kMaxStopWordsLen]
+    int*   h_bad_words_{};
+    int*   d_stop_words_{};  // [batch_size, 2, kMaxStopWordsLen]
+    int*   d_bad_words_{};
 
-    void* topk_curandstate_buf_{};
-    void* topp_curandstate_buf_{};
+    unsigned long long* h_random_seed_{};
+    unsigned long long* d_random_seed_{};
 
-    // hard limits for persistent buffers
-    static constexpr int kMaxStopBadWordsLen = 32;
+    curandState_t* h_curand_state_{};
+    curandState_t* d_curand_state_{};
 
-    using CachedSeq = LlamaCacheManager::Sequence;
+    std::array<BatchState, 3> states_{};
 
-    std::vector<CachedSeq> cached_seq_;
-    std::vector<int>       request_seq_len_limit_;
+    BatchState* state_{};
+    BatchState* back_{};
+    BatchState* incoming_{};
 
-    const DataType data_type_{};
+    uint64_t request_count_{0};
 
-    int batch_size_{};
-    int max_context_len_{};
-    int step_{};
-    int finished_count_{};
+    // hard limits for persistent buffers
+    static constexpr int kMaxStopBadWordsLen = 32;
+
+    const DataType data_type_{};
 
     bool is_allocate_persistant_buffer_ = false;
     bool is_allocate_buffer_            = false;
@@ -149,11 +280,26 @@ class LlamaBatch {
     TensorMap inputs_;
     TensorMap outputs_;
 
-    std::unordered_map sampling_params_;
+    std::vector> sampling_params_;
 
     cudaStream_t     stream_{};
     cublasMMWrapper* cublas_wrapper_{};
     IAllocator*      allocator_{};
+
+    std::thread internal_thread_;
+
+    // async stream callback utils
+    std::thread             output_thread_;
+    std::mutex              output_mutex_;
+    std::condition_variable output_cv_;
+    std::vector<Signal>     output_signals_;
+    bool                    output_stop_token_{false};
+
+    int* h_output_ids_{};
+
+    const int num_tokens_per_iter_;
+    const int extra_tokens_per_iter_;
+    const int max_prefill_iters_;
 };
 
 }  // namespace turbomind
diff --git a/src/turbomind/models/llama/LlamaCacheManager.cc b/src/turbomind/models/llama/LlamaCacheManager.cc
deleted file mode 100644
index 1b929baa35..0000000000
--- a/src/turbomind/models/llama/LlamaCacheManager.cc
+++ /dev/null
@@ -1,192 +0,0 @@
-// Copyright (c) OpenMMLab. All rights reserved.
-
-#include "src/turbomind/models/llama/LlamaCacheManager.h"
-#include "src/turbomind/utils/cuda_utils.h"
-#include "src/turbomind/utils/logger.h"
-
-namespace turbomind {
-
-LlamaCacheManager::~LlamaCacheManager()
-{
-    for (auto& p : device_mem_) {
-        allocator_->free(&p, false);
-    }
-}
-
-void* LlamaCacheManager::allocate(bool is_preallocte)
-{
-    if (rank_ == 0) {
-        TM_LOG_INFO("[LlamaCacheManager][allocate]");
-    }
-
-    void* mem_ptr{};
-
-    if (!device_free_.empty()) {
-        mem_ptr = device_free_.front();
-        device_free_.pop();
-
-        if (rank_ == 0) {
-            TM_LOG_INFO("[LlamaCacheManager][allocate] free = %d", (int)device_free_.size());
-        }
-    }
-    else if (entry_count_ < max_entry_count_) {
-        const auto   alloc_count     = std::min(chunk_size_, max_entry_count_ - entry_count_);
-        const size_t entry_byte_size = 2 * cache_byte_size_;  // 2 for k,v
-
-        if (rank_ == 0) {
-            TM_LOG_INFO("[LlamaCacheManager][allocate] malloc %d", (int)alloc_count);
-        }
-        const auto chunk_ptr = allocator_->malloc(alloc_count * entry_byte_size, false);
-        FT_CHECK(chunk_ptr);
-        device_mem_.push_back(chunk_ptr);
-        entry_count_ += alloc_count;
-        if (rank_ == 0) {
-            TM_LOG_INFO("[LlamaCacheManager][allocate] count = %d", entry_count_);
-        }
-
-        for (int i = 0; i < alloc_count; ++i) {
-            device_free_.push((uint8_t*)chunk_ptr + entry_byte_size * i);
-        }
-
-        if (!is_preallocte) {
-            mem_ptr = device_free_.front();
-            device_free_.pop();
-        }
-
-        if (rank_ == 0) {
-            TM_LOG_INFO("[LlamaCacheManager][allocate] free = %d", (int)device_free_.size());
-        }
-    }
-    else {
-        mem_ptr = evict();
-        FT_CHECK_WITH_INFO(mem_ptr, "No enough cache entries.");
-    }
-
-    return mem_ptr;
-}
-
-auto LlamaCacheManager::create(uint64_t id, cudaStream_t stream) -> Sequence
-{
-    if (rank_ == 0) {
-        TM_LOG_INFO("[LlamaCacheManager][create] %ld", (long)id);
-    }
-
-    for (const auto& e : device_cache_) {
-        if (e.id == id) {
-            if (rank_ == 0) {
-                TM_LOG_WARNING("[LlamaCacheManager][create] Removing conflicting id %ld", (long)id);
-            }
-            erase(id);
-        }
-    }
-
-    const auto mem_ptr = (uint8_t*)allocate(false);
-    check_cuda_error(cudaMemsetAsync(mem_ptr, 0, cache_byte_size_ * 2, stream));
-
-    device_cache_.push_back({
-        id,
-        max_seq_len_,
-        {},
-        0,
-        mem_ptr,
-        mem_ptr + cache_byte_size_,
-        {},
-        static_cast<uint64_t>(-1),
-    });
-
-    return device_cache_.back();
-}
-
-auto LlamaCacheManager::getEntryOrThrow(uint64_t id) -> std::vector<Sequence>::iterator
-{
-    auto pred = [&](const Sequence& s) { return s.id == id; };
-    auto it   = std::find_if(device_cache_.begin(), device_cache_.end(), pred);
-    if (it == device_cache_.end()) {
-        TM_LOG_ERROR("[LlamaCacheManager] %ld not found.\n", (long)id);
-        FT_CHECK(0);
-    }
-    return it;
-}
-
-auto LlamaCacheManager::fetch(uint64_t id, cudaStream_t stream) -> Sequence
-{
-    if (rank_ == 0) {
-        TM_LOG_INFO("[LlamaCacheManager][fetch] %ld", (long)id);
-    }
-
-    auto entry = getEntryOrThrow(id);
-
-    if (entry->k_cache == nullptr) {
-        FT_CHECK(entry->cache_len == 0);
-        const auto mem_ptr = allocate(false);
-        check_cuda_error(cudaMemsetAsync(mem_ptr, 0, cache_byte_size_ * 2, stream));
-        entry->k_cache = mem_ptr;
-        entry->v_cache = (uint8_t*)entry->k_cache + cache_byte_size_;
-    }
-
-    entry->timestamp = static_cast<uint64_t>(-1);
-    return *entry;
-}
-
-void LlamaCacheManager::update(const Sequence& seq, cudaStream_t stream)
-{
-    if (rank_ == 0) {
-        TM_LOG_INFO("[LlamaCacheManager][update] %ld", (long)seq.id);
-    }
-
-    auto entry = getEntryOrThrow(seq.id);
-
-    entry->timestamp = ++timestamp_;
-    entry->token_ids = seq.token_ids;
-    entry->cache_len = seq.cache_len;
-    FT_CHECK(seq.k_cache == entry->k_cache && seq.v_cache == entry->v_cache);
-}
-
-void LlamaCacheManager::erase(uint64_t id)
-{
-    if (rank_ == 0) {
-        TM_LOG_INFO("[LlamaCacheManager][erase] %ld", (long)id);
-    }
-
-    auto entry = getEntryOrThrow(id);
-
-    if (entry->k_cache) {
-        device_free_.push(entry->k_cache);
-        if (rank_ == 0) {
-            TM_LOG_INFO("[LlamaCacheManager][erase] free = %d", (int)device_free_.size());
-        }
-    }
-    device_cache_.erase(entry);
-}
-
-void* LlamaCacheManager::evict()
-{
-    FT_CHECK(!device_cache_.empty());
-    auto it = std::min_element(device_cache_.begin(), device_cache_.end(), [](const auto& a, const auto& b) {
-        return a.timestamp < b.timestamp;
-    });
-
-    if (it->timestamp == static_cast<uint64_t>(-1)) {
-        return nullptr;
-    }
-
-    if (rank_ == 0) {
-        TM_LOG_INFO("[LlamaCacheManager][evict] %ld", (long)it->id);
-    }
-
-    FT_CHECK(it->k_cache);
-    auto mem_ptr = it->k_cache;
-    it->k_cache = it->v_cache = nullptr;
-    it->cache_len             = 0;
-    it->timestamp             = static_cast<uint64_t>(-1);
-    return mem_ptr;
-}
-
-bool LlamaCacheManager::contains(uint64_t id) const noexcept
-{
-    auto pred = [&](const Sequence& s) { return s.id == id; };
-    auto it   = std::find_if(device_cache_.begin(), device_cache_.end(), pred);
-    return it != device_cache_.end();
-}
-
-}  // namespace turbomind
diff --git a/src/turbomind/models/llama/LlamaCacheManager.h b/src/turbomind/models/llama/LlamaCacheManager.h
deleted file mode 100644
index 2dd539550c..0000000000
--- a/src/turbomind/models/llama/LlamaCacheManager.h
+++ /dev/null
@@ -1,102 +0,0 @@
-// Copyright (c) OpenMMLab. All rights reserved.
-
-#include "src/turbomind/utils/allocator.h"
-#include "src/turbomind/utils/logger.h"
-#include 
-#include 
-#include 
-#include 
-#include 
-
-namespace turbomind {
-
-// k-cache layout [L, H, D/x, S[s:], x]
-// v-cache layout [L, H, S[s:], D/x, x]
-
-class LlamaCacheManager {
-public:
-    LlamaCacheManager(size_t      layer_num,
-                      size_t      head_num,
-                      size_t      size_per_head,
-                      size_t      max_seq_len,
-                      size_t      elem_bits,
-                      size_t      max_entry_count,
-                      size_t      chunk_size,
-                      int         rank,
-                      IAllocator* allocator):
-        layer_num_(layer_num),
-        head_num_(head_num),
-        size_per_head_(size_per_head),
-        max_seq_len_(max_seq_len),
-        elem_bits_(elem_bits),
-        cache_byte_size_(layer_num_ * head_num_ * max_seq_len_ * size_per_head_ * elem_bits_ / 8),
-        max_entry_count_(max_entry_count),
-        chunk_size_(chunk_size),
-        rank_(rank),
-        allocator_(allocator)
-    {
-        if (rank == 0) {
-            TM_LOG_INFO("[LlamaCacheManager] max_entry_count = %d", (int)max_entry_count_);
-            TM_LOG_INFO("[LlamaCacheManager] chunk_size = %d", (int)chunk_size_);
-        }
-        allocate(true);
-    }
-
-    ~LlamaCacheManager();
-
-    struct Sequence {
-        // header
-        uint64_t id;
-        size_t   max_seq_len;
-
-        // payloads
-        std::vector<int> token_ids;  // all token ids
-        size_t           cache_len;  // cache_len == 0 -> cache miss
-        void*            k_cache;
-        void*            v_cache;
-
-        std::vector<uint8_t> random_state_;  // states for RNGs
-
-        // for LRU policy
-        uint64_t timestamp;
-    };
-
-    Sequence create(uint64_t id, cudaStream_t stream);
-
-    Sequence fetch(uint64_t id, cudaStream_t stream);
-
-    void update(const Sequence& seq, cudaStream_t stream);
-
-    void erase(uint64_t id);
-
-    bool contains(uint64_t id) const noexcept;
-
-private:
-    std::vector<Sequence>::iterator getEntryOrThrow(uint64_t id);
-
-    void* allocate(bool is_preallocte);
-
-    void* evict();
-
-private:
-    const size_t layer_num_{};
-    const size_t head_num_{};
-    const size_t size_per_head_{};
-    const size_t max_seq_len_{};
-    const size_t elem_bits_{};
-    const size_t cache_byte_size_{};
-    const size_t max_entry_count_{};
-    const size_t chunk_size_{};
-    const int    rank_{};
-    IAllocator*  allocator_{};
-
-    std::queue<void*>  device_free_;
-    std::vector<void*> device_mem_;
-    int                entry_count_{};
-
-    uint64_t timestamp_{};
-
-    std::vector<Sequence> device_cache_;
-};
-
-}  // namespace turbomind
diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
deleted file mode 100644
index 881582acea..0000000000
--- a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
+++ /dev/null
@@ -1,423 +0,0 @@
-/*
- * Copyright (c) OpenMMLab. All rights reserved.
- * Copyright (c) 2021-2023, NVIDIA CORPORATION.  All rights reserved.
- * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Modified from
-// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/layers/attention_layers/GptContextAttentionLayer.cc
-
-#include "src/turbomind/models/llama/LlamaContextAttentionLayer.h"
-#include "src/turbomind/kernels/bert_preprocess_kernels.h"
-#include "src/turbomind/kernels/unfused_attention_kernels.h"
-#include "src/turbomind/macro.h"
-#include "src/turbomind/models/llama/LlamaNcclGuard.h"
-#include "src/turbomind/models/llama/llama_kernels.h"
-#include "src/turbomind/models/llama/llama_utils.h"
-#include "src/turbomind/utils/Tensor.h"
-#include "src/turbomind/utils/cuda_utils.h"
-#include "src/turbomind/utils/logger.h"
-
-namespace turbomind {
-
-template<typename T>
-void LlamaContextAttentionLayer<T>::allocateBuffer(size_t batch_size,
-                                                   size_t num_token,
-                                                   size_t max_q_len,
-                                                   size_t max_k_len)
-{
-    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-
-    const int local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_;
-
-    // no padding
-    qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * local_q_kv_head_num * size_per_head_, true);
-
-    // padding is rebuilt for q/k/v_buf_2_
-    // [qH + 2kvH, B, S, D]
-    q_buf_2_ = (T*)allocator_->reMalloc(
-        q_buf_2_, sizeof(T) * local_q_kv_head_num * batch_size * max_q_len * size_per_head_, true);
-    k_buf_2_ = q_buf_2_ + local_head_num_ * batch_size * max_q_len * size_per_head_;
-    v_buf_2_ = k_buf_2_ + local_kv_head_num_ * batch_size * max_q_len * size_per_head_;
-
-    if (use_fmha_) {
-        FlashAttentionOp flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);
-        if (flash_attention.get_workspace_size() > 0) {
-            qk_buf_float_ = (float*)allocator_->reMalloc(qk_buf_float_, flash_attention.get_workspace_size(), true);
-        }
-    }
-    else {
-        // kv heads are repeated for unfused attention
-        k_cache_buf_ = (T*)allocator_->reMalloc(
-            k_cache_buf_, 2 * sizeof(T) * batch_size * local_head_num_ * max_k_len * size_per_head_, true);
-        v_cache_buf_ = k_cache_buf_ + batch_size * local_head_num_ * max_k_len * size_per_head_;
-
-        qk_buf_ =
-            (T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * local_head_num_ * max_q_len * max_k_len, true);
-
-        // qkv_buf_2_ has padding
-        qkv_buf_2_ = (T*)allocator_->reMalloc(
-            qkv_buf_2_, sizeof(T) * batch_size * max_q_len * local_head_num_ * size_per_head_, true);
-    }
-
-    // qkv_buf_3_ padding is removed
-    qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * num_token * local_head_num_ * size_per_head_, true);
-
-    is_allocate_buffer_ = true;
-}
-
-template<typename T>
-void LlamaContextAttentionLayer<T>::freeBuffer()
-{
-    if (is_allocate_buffer_) {
-        TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-
-        allocator_->free((void**)(&qkv_buf_));
-        allocator_->free((void**)(&q_buf_2_));
-        if (use_fmha_) {
-            allocator_->free((void**)&qk_buf_float_);
-        }
-        else {
-            allocator_->free((void**)(&k_cache_buf_));
-            allocator_->free((void**)(&qk_buf_));
-            allocator_->free((void**)(&qkv_buf_2_));
-        }
-        allocator_->free((void**)(&qkv_buf_3_));
-
-        is_allocate_buffer_ = false;
-    }
-}
-
-template<typename T>
-inline void LlamaContextAttentionLayer<T>::forward(TensorMap*                     output_tensors,
-                                                   const TensorMap*               input_tensors,
-                                                   const LlamaAttentionWeight<T>* weights)
-{
-    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-
-    /**
-     * input_tensors:
-     *   \param input_query [token_num, hidden_dim]
-     *   \param attention_mask [batch_size, 1, max_q_len, max_kv_len]
-     *   \param padding_offset [token_num], int
-     *   \param input_lengths [batch_size], int
-     *   \param history_lengths [batch_size], int
-     *   \param context_lengths [batch_size], int
-     *   \param cu_seqlens [batch_size+1], int
-     *   \param max_seq_len [1], int on cpu
-     *   \param is_final_layer [1], bool on cpu
-     *   \param layer_id [1], int on cpu
-     *
-     * output_tensors:
-     *   \param hidden_features [token_num, hidden_dim]
-     *   \param key_cache [batch_size], uint64
-     *   \param value_cache [batch_size], uint64
-     */
-
-    /////////////////////////////////////////////
-    /// parse inputs
-    const int batch_size = input_tensors->at("attention_mask").shape[0];
-    const int max_q_len  = input_tensors->at("attention_mask").shape[2];
-    const int max_k_len  = input_tensors->at("attention_mask").shape[3];
-    const int layer_id   = input_tensors->getVal<int>("layer_id");
-
-    const int num_token = input_tensors->at("input_query").shape[0];
-
-    const int max_seq_len = input_tensors->at("max_seq_len").getVal<int>();
-
-    T* attention_out   = output_tensors->at("hidden_features").getPtr<T>();
-    T* attention_input = input_tensors->at("input_query").getPtr<T>();
-    T* attention_mask  = input_tensors->at("attention_mask").getPtr<T>();
-
-    const auto input_length   = input_tensors->at("input_lengths").getPtr<int>();
-    const auto history_length = input_tensors->at("history_lengths").getPtr<int>();
-    const auto context_length = input_tensors->at("context_lengths").getPtr<int>();
-    int*       cu_seqlens     = input_tensors->at("cu_seqlens").getPtr<int>();
-
-    const auto padding_offset = input_tensors->at("padding_offset").getPtr<int>();
-
-    /////////////////////////////////////////////
-    /// allocate buffers
-    allocateBuffer(batch_size, num_token, max_q_len, max_k_len);
-
-    //////////////////////////////////////////////
-    /// qkv gemm
-    // [token_num, hidden_dim] -> [token_num, 3, local_hidden_dim]
-    linear_.forward(qkv_buf_, attention_input, num_token, weights->qkv);
-
-    //////////////////////////////////////////////
-    /// transpose qkv & apply rotary embedding & rebuild padding
-    /// qkv [B, s, H + 2kvH, D] -> (q [B, H, s, D], k [B, kvH, s, D], v [B, kvH, s, D])
-    invokeAddFusedQKVBiasTranspose(q_buf_2_,
-                                   k_buf_2_,
-                                   v_buf_2_,
-                                   qkv_buf_,
-                                   weights->qkv.bias,
-                                   padding_offset,  // padding_offset,
-                                   history_length,  // used for applying rotary embedding
-                                   input_length,
-                                   batch_size,
-                                   max_q_len,  // seq_len
-                                   num_token,  // batch_size * seq_len
-                                   local_head_num_,
-                                   local_kv_head_num_,
-                                   size_per_head_,
-                                   params_.rotray_embedding_dim,
-                                   params_.rotary_embedding_base,
-                                   params_.max_position_embeddings,
-                                   params_.use_dynamic_ntk,
-                                   params_.use_logn_attn,
-                                   stream_);
-    sync_check_cuda_error();
-
-    const size_t layer_offset = layer_id * local_kv_head_num_ * max_seq_len * size_per_head_;
-
-    auto k_cache_ptrs = output_tensors->getPtr<T*>("key_cache");
-    auto v_cache_ptrs = output_tensors->getPtr<T*>("value_cache");
-    //////////////////////////////////////////////////////////
-    /// insert the k/v computed from inputs into k/v cache
-    /// transpose kv -> kv cache
-    // put k/v_buf from shape [B, kvH, s, D] to
-    // k_buf_2 [B, kvH, s, D] -> key_cache [B, kvH, S[t:t+s], D/x, x]
-    // v_buf_2 [B, kvH, s, D] -> val_cache [B, kvH, S[t:t+s], D/x, x]
-    invokeExtendKVCache(k_cache_ptrs,
-                        v_cache_ptrs,
-                        layer_offset,
-                        k_buf_2_,
-                        v_buf_2_,
-                        batch_size,
-                        input_length,
-                        max_q_len,
-                        history_length,
-                        max_seq_len,
-                        size_per_head_,
-                        local_kv_head_num_,
-                        stream_,
-                        quant_policy_,
-                        weights->past_kv_scale.data());
-
-    sync_check_cuda_error();
-    if (use_fmha_) {
-        fusedMultiHeadAttention(k_cache_ptrs,
-                                v_cache_ptrs,
-                                layer_offset,
-                                attention_mask,
-                                cu_seqlens,
-                                input_tensors->at("context_lengths").getPtr(),
-                                batch_size,
-                                max_q_len,
-                                max_k_len,
-                                max_seq_len);
-    }
-    else {
-        unfusedMultiHeadAttention(k_cache_ptrs,
-                                  v_cache_ptrs,
-                                  layer_offset,
-                                  attention_mask,
-                                  padding_offset,
-                                  context_length,
-                                  batch_size,
-                                  num_token,
-                                  max_q_len,
-                                  max_k_len,
-                                  max_seq_len,
-                                  quant_policy_,
-                                  weights->past_kv_scale.data());
-    }
-
-    //////////////////////////////////////////////
-    /// output gemm  -> 
-    linear_.forward(attention_out, qkv_buf_3_, num_token, weights->output);
-
-    if (tensor_para_.world_size_ > 1) {
-        NcclGuard nccl_guard(tensor_para_, stream_);
-        ftNcclAllReduceSum(attention_out, attention_out, num_token * hidden_units_, tensor_para_, stream_);
-        sync_check_cuda_error();
-    }
-
-    if (is_free_buffer_after_forward_ == true) {
-        freeBuffer();
-    }
-    sync_check_cuda_error();
-}
-
-template<typename T>
-void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T**    key_cache_ptrs,
-                                                            T**    val_cache_ptrs,
-                                                            size_t cache_layer_offset,
-                                                            T*     attention_mask,
-                                                            int*   cu_seqlens,
-                                                            int*   context_lengths,
-                                                            int    batch_size,
-                                                            int    max_q_len,
-                                                            int    max_k_len,
-                                                            int    max_seq_len)
-{
-    //////////////////////////////////////////////
-    // flash attention
-    // flash attention 2 only support half inputs
-    using AttentionOp = FlashAttentionOp;
-    using Layout      = typename AttentionOp::AttentionLayout;
-    Layout layout_q{
-        int(local_head_num_ * max_q_len * size_per_head_), int(size_per_head_), int(max_q_len * size_per_head_)};
-    Layout layout_k{int(local_head_num_ * max_seq_len * size_per_head_),
-                    int(size_per_head_),
-                    int(max_seq_len * size_per_head_),
-                    false,
-                    cache_layer_offset,
-                    key_cache_ptrs};
-    Layout layout_v{int(local_head_num_ * max_seq_len * size_per_head_),
-                    int(size_per_head_),
-                    int(max_seq_len * size_per_head_),
-                    false,
-                    cache_layer_offset,
-                    val_cache_ptrs};
-    Layout layout_o{
-        int(local_head_num_ * max_q_len * size_per_head_),
-        int(local_head_num_ * size_per_head_),
-        int(size_per_head_),
-        true,
-    };
-    size_t                       group_size = size_t(local_head_num_ / local_kv_head_num_);
-    AttentionOp                  flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);
-    typename AttentionOp::Params attn_params{qkv_buf_3_,
-                                             q_buf_2_,
-                                             k_cache_buf_,
-                                             v_cache_buf_,
-                                             attention_mask,
-                                             qk_buf_float_,
-                                             cu_seqlens,
-                                             nullptr,
-                                             nullptr,
-                                             context_lengths,
-                                             group_size,
-                                             layout_q,
-                                             layout_k,
-                                             layout_v,
-                                             layout_o};
-
-    //
-    flash_attention(attn_params, stream_);
-}
-
-template<typename T>
-void LlamaContextAttentionLayer<T>::unfusedMultiHeadAttention(T**          key_cache_ptrs,
-                                                              T**          val_cache_ptrs,
-                                                              size_t       cache_layer_offset,
-                                                              const T*     attention_mask,
-                                                              const int*   padding_offset,
-                                                              const int*   context_length,
-                                                              int          batch_size,
-                                                              int          num_token,
-                                                              int          max_q_len,
-                                                              int          max_k_len,
-                                                              int          max_seq_len,
-                                                              int          quant,
-                                                              const float* kv_scale)
-{
-    // key_cache [B, kvH, S[:t+s], D/x, x] -> [B, qH, t+s, D]
-    // val_cache [B, kvH, S[:t+s], D/x, x] -> [B, qH, t+s, D]
-    invokeTransposeKVCache(k_cache_buf_,
-                           v_cache_buf_,
-                           (const T**)key_cache_ptrs,
-                           (const T**)val_cache_ptrs,
-                           cache_layer_offset,
-                           batch_size,
-                           context_length,  // history_len + input_len = context_len
-                           max_k_len,
-                           max_seq_len,
-                           size_per_head_,
-                           local_head_num_,
-                           head_n_rep_,
-                           stream_,
-                           quant,
-                           kv_scale);
-    sync_check_cuda_error();
-
-    const T qk_scale = static_cast<T>(1.f / sqrtf(size_per_head_ * 1.f));
-
-    //////////////////////////////////////////////
-    /// Q*K batch gemm
-    /// -> [B, H, s, t + s]
-    cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_T,
-                                        CUBLAS_OP_N,
-                                        max_k_len,                      // m
-                                        max_q_len,                      // n
-                                        size_per_head_,                 // k
-                                        k_cache_buf_,                   // A
-                                        size_per_head_,                 // lda
-                                        max_k_len * size_per_head_,     // strideA
-                                        q_buf_2_,                       // B
-                                        size_per_head_,                 // ldb
-                                        max_q_len * size_per_head_,     // strideB
-                                        qk_buf_,                        // C
-                                        max_k_len,                      // ldc
-                                        max_q_len * max_k_len,          // strideC
-                                        batch_size * local_head_num_);  // batchCount
-
-    //////////////////////////////////////////////
-    /// ! masked softmax (kernel asserts k_length <= 4096)
-    MaskedSoftmaxParam<T, T> param{};
-    param.attention_score    = qk_buf_;
-    param.qk                 = qk_buf_;
-    param.attention_mask     = attention_mask;
-    param.batch_size         = batch_size;
-    param.q_length           = max_q_len;
-    param.k_length           = max_k_len;
-    param.num_heads          = local_head_num_;
-    param.qk_scale           = qk_scale;
-    param.linear_bias_slopes = nullptr;
-    invokeMaskedSoftmax(param, stream_);
-    sync_check_cuda_error();
-
-    //////////////////////////////////////////////
-    /// softmax(QK)*V batch gemm
-    // -> [B, H, S, D]
-    cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_N,
-                                        CUBLAS_OP_N,
-                                        size_per_head_,                 // m
-                                        max_q_len,                      // n
-                                        max_k_len,                      // k
-                                        v_cache_buf_,                   // A
-                                        size_per_head_,                 // lda
-                                        max_k_len * size_per_head_,     // strideA,
-                                        qk_buf_,                        // B
-                                        max_k_len,                      // ldb
-                                        max_k_len * max_q_len,          // strideB
-                                        qkv_buf_2_,                     // C
-                                        size_per_head_,                 // ldc,
-                                        max_q_len * size_per_head_,     // strideC
-                                        batch_size * local_head_num_);  // batchCount
-
-    //////////////////////////////////////////////
-    /// transpose  -> 
-    invokeTransposeAttentionOutRemovePadding(qkv_buf_2_,
-                                             qkv_buf_3_,
-                                             num_token,
-                                             batch_size,
-                                             max_q_len,
-                                             local_head_num_,
-                                             size_per_head_,
-                                             padding_offset,
-                                             nullptr,
-                                             0,
-                                             stream_);
-    sync_check_cuda_error();
-}
-
-template class LlamaContextAttentionLayer<float>;
-template class LlamaContextAttentionLayer<half>;
-
-}  // namespace turbomind
diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.h b/src/turbomind/models/llama/LlamaContextAttentionLayer.h
deleted file mode 100644
index f79eaa4ef2..0000000000
--- a/src/turbomind/models/llama/LlamaContextAttentionLayer.h
+++ /dev/null
@@ -1,130 +0,0 @@
-/*
- * Copyright (c) OpenMMLab. All rights reserved.
- * Copyright (c) 2021-2023, NVIDIA CORPORATION.  All rights reserved.
- * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Modified from
-// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/layers/attention_layers/GptContextAttentionLayer.h
-
-#pragma once
-
-#include "src/turbomind/models/llama/LlamaDenseWeight.h"
-#include "src/turbomind/models/llama/LlamaLinear.h"
-#include "src/turbomind/models/llama/llama_params.h"
-#include "src/turbomind/utils/Tensor.h"
-#include "src/turbomind/utils/nccl_utils.h"
-
-namespace turbomind {
-
-template<typename T>
-class LlamaContextAttentionLayer {
-public:
-    void freeBuffer();
-    void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len);
-
-    LlamaContextAttentionLayer(size_t               head_num,
-                               size_t               kv_head_num,
-                               size_t               size_per_head,
-                               LlamaAttentionParams attn_params,
-                               NcclParam            tensor_para,
-                               cudaStream_t         stream,
-                               cublasMMWrapper*     cublas_wrapper,
-                               IAllocator*          allocator,
-                               bool                 is_free_buffer_after_forward,
-                               bool                 use_fmha,
-                               int                  quant_policy):
-        head_num_(head_num),
-        size_per_head_(size_per_head),
-        hidden_units_(head_num * size_per_head),
-        local_head_num_(head_num / tensor_para.world_size_),
-        local_kv_head_num_(kv_head_num / tensor_para.world_size_),
-        head_n_rep_(head_num / kv_head_num),
-        params_(attn_params),
-        tensor_para_(tensor_para),
-        stream_(stream),
-        cublas_wrapper_(cublas_wrapper),
-        linear_(cublas_wrapper, stream),
-        allocator_(allocator),
-        is_free_buffer_after_forward_(is_free_buffer_after_forward),
-        use_fmha_(use_fmha),
-        quant_policy_(quant_policy)
-    {
-        FT_CHECK(head_num % kv_head_num == 0);
-    }
-
-    void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const LlamaAttentionWeight<T>* weights);
-
-    void fusedMultiHeadAttention(T**    key_cache_ptrs,
-                                 T**    val_cache_ptrs,
-                                 size_t cache_layer_offset,
-                                 T*     attention_mask,
-                                 int*   cu_seqlens,
-                                 int*   context_lengths,
-                                 int    batch_size,
-                                 int    max_q_len,
-                                 int    max_k_len,
-                                 int    max_seq_len);
-
-    void unfusedMultiHeadAttention(T**          key_cache_ptrs,
-                                   T**          val_cache_ptrs,
-                                   size_t       cache_layer_offset,
-                                   const T*     attention_mask,
-                                   const int*   padding_offset,
-                                   const int*   context_length,
-                                   int          batch_size,
-                                   int          num_token,
-                                   int          max_q_len,
-                                   int          max_k_len,
-                                   int          max_seq_len,
-                                   int          quant_policy,
-                                   const float* kv_scale);
-
-private:
-    const size_t head_num_;
-    const size_t size_per_head_;
-    const size_t hidden_units_;
-    const size_t local_kv_head_num_;
-    const size_t local_head_num_;
-    const size_t head_n_rep_;
-    const bool   is_free_buffer_after_forward_;
-
-    const LlamaAttentionParams params_;
-
-    const bool use_fmha_;
-    const int  quant_policy_;
-
-    NcclParam tensor_para_;
-
-    cudaStream_t     stream_;
-    IAllocator*      allocator_;
-    cublasMMWrapper* cublas_wrapper_;
-    LlamaLinear<T>   linear_;
-
-    T*     qkv_buf_{};
-    T*     q_buf_2_{};
-    T*     k_buf_2_{};
-    T*     v_buf_2_{};
-    T*     k_cache_buf_{};
-    T*     v_cache_buf_{};
-    T*     qk_buf_{};
-    float* qk_buf_float_{};
-    T*     qkv_buf_2_{};
-    T*     qkv_buf_3_{};
-
-    bool is_allocate_buffer_ = false;
-};
-
-}  // namespace turbomind
diff --git a/src/turbomind/models/llama/LlamaContextDecoder.cc b/src/turbomind/models/llama/LlamaContextDecoder.cc
deleted file mode 100644
index f914063a70..0000000000
--- a/src/turbomind/models/llama/LlamaContextDecoder.cc
+++ /dev/null
@@ -1,290 +0,0 @@
-/*
- * Copyright (c) OpenMMLab. All rights reserved.
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Modified from
-// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/models/multi_gpu_gpt/ParallelGptContextDecoder.cc
-
-#include "src/turbomind/models/llama/LlamaContextDecoder.h"
-#include "src/turbomind/kernels/bert_preprocess_kernels.h"
-#include "src/turbomind/kernels/gpt_kernels.h"
-#include "src/turbomind/macro.h"
-#include "src/turbomind/models/llama/LlamaContextDecoder.h"
-#include "src/turbomind/models/llama/llama_decoder_kernels.h"
-#include "src/turbomind/models/llama/llama_kernels.h"
-#include "src/turbomind/utils/Tensor.h"
-
-namespace turbomind {
-
-template<typename T>
-void LlamaContextDecoder<T>::allocateBuffer()
-{
-    FT_CHECK(false);
-}
-
-template<typename T>
-void LlamaContextDecoder<T>::allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len)
-{
-    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-
-    attention_mask_ = (T*)allocator_->reMalloc(attention_mask_, sizeof(T) * batch_size * max_q_len * max_kv_len, false);
-    padding_offset_ = (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * max_q_len, false);
-    cu_seqlens_     = (int*)allocator_->reMalloc(cu_seqlens_, sizeof(int) * (batch_size + 1), false);
-
-    is_allocate_buffer_ = true;
-}
-
-template<typename T>
-void LlamaContextDecoder<T>::freeBuffer()
-{
-    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-    if (is_allocate_buffer_) {
-        allocator_->free((void**)&padding_offset_);
-        allocator_->free((void**)&cu_seqlens_);
-        allocator_->free((void**)&attention_mask_);
-        allocator_->free((void**)&h_pinned_token_num_ptr_, true);
-        is_allocate_buffer_ = false;
-    }
-}
-
-template<typename T>
-void LlamaContextDecoder<T>::initialize(const LlamaAttentionParams& attn_params,
-                                        size_t                      kv_head_num,
-                                        bool                        use_fmha,
-                                        int                         quant_policy)
-{
-    h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
-
-    context_attention_layer_ = new LlamaContextAttentionLayer<T>(head_num_,
-                                                                 kv_head_num,
-                                                                 size_per_head_,
-                                                                 attn_params,
-                                                                 tensor_para_,
-                                                                 stream_,
-                                                                 cublas_wrapper_,
-                                                                 allocator_,
-                                                                 is_free_buffer_after_forward_,
-                                                                 use_fmha,
-                                                                 quant_policy);
-
-    silu_ffn_layer_ = new LlamaFfnLayer<T>(head_num_,
-                                           size_per_head_,
-                                           inter_size_,
-                                           tensor_para_,
-                                           stream_,
-                                           cublas_wrapper_,
-                                           allocator_,
-                                           is_free_buffer_after_forward_);
-}
-
-template<typename T>
-void LlamaContextDecoder<T>::forwardSelfAttn(const Session&                                 sess,
-                                             T*                                             attn_io,
-                                             const std::unordered_map<std::string, Tensor>* input_tensors,
-                                             int                                            layer,
-                                             bool                                           is_final)
-{
-    // TM_LOG_ERROR(__PRETTY_FUNCTION__);
-    TensorMap self_attention_input_tensors{
-        {"input_query", Tensor{MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, attn_io}},
-        {"attention_mask",
-         {MEMORY_GPU, data_type_, {sess.batch_size, 1, sess.max_query_len, sess.max_key_len}, attention_mask_}},
-        {"layer_id", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &layer}},
-        {"is_final_layer", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &is_final}},
-        {"padding_offset", {MEMORY_GPU, TYPE_INT32, {sess.token_num}, padding_offset_}},
-        {"cu_seqlens", {MEMORY_GPU, TYPE_INT32, {sess.batch_size + 1}, cu_seqlens_}},
-        {"input_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.input_length}},
-        {"history_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.history_length}},
-        {"context_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.context_length}},
-        {"max_seq_len", input_tensors->at("max_seq_len")}};
-
-    auto& k_cache = *sess.k_cache;
-    auto& v_cache = *sess.v_cache;
-
-    TensorMap self_attention_output_tensors{
-        {"hidden_features", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, attn_io}},
-        {"key_cache", k_cache},
-        {"value_cache", v_cache},
-    };
-
-    context_attention_layer_->forward(&self_attention_output_tensors,  //
-                                      &self_attention_input_tensors,
-                                      &sess.weights->at(layer)->self_attn_weights);
-}
-
-template<typename T>
-LlamaContextDecoder<T>::LlamaContextDecoder(size_t                      head_num,
-                                            size_t                      kv_head_num,
-                                            size_t                      size_per_head,
-                                            size_t                      inter_size,
-                                            size_t                      num_layer,
-                                            const LlamaAttentionParams& attn_params,
-                                            float                       rmsnorm_eps,
-                                            NcclParam                   tensor_para,
-                                            cudaStream_t                stream,
-                                            cublasMMWrapper*            cublas_wrapper,
-                                            IAllocator*                 allocator,
-                                            bool                        is_free_buffer_after_forward,
-                                            bool                        use_fmha,
-                                            int                         quant_policy):
-    BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward),
-    head_num_(head_num),
-    size_per_head_(size_per_head),
-    inter_size_(inter_size),
-    hidden_units_(head_num * size_per_head),
-    num_layer_(num_layer),
-    rmsnorm_eps_(rmsnorm_eps),
-    tensor_para_(tensor_para),
-    data_type_(getTensorType<T>())
-{
-    initialize(attn_params, kv_head_num, use_fmha, quant_policy);
-}
-
-template<typename T>
-LlamaContextDecoder<T>::~LlamaContextDecoder()
-{
-    delete context_attention_layer_;
-    delete silu_ffn_layer_;
-    freeBuffer();
-}
-
-template<typename T>
-void LlamaContextDecoder<T>::forward(std::vector<Tensor>*                            output_tensors,
-                                     const std::vector<Tensor>*                      input_tensors,
-                                     const std::vector<LlamaDecoderLayerWeight<T>*>* decoder_layer_weights)
-{
-    FT_CHECK(false);
-}
-
-template<typename T>
-void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*        output_tensors,
-                                     const std::unordered_map<std::string, Tensor>*  input_tensors,
-                                     const std::vector<LlamaDecoderLayerWeight<T>*>* decoder_layer_weights)
-{
-    /**
-     * input tensors:
-     *   \param decoder_input [num_token, hidden_units], float
-     *   \param input_lengths [batch_size], int
-     *   \param history_lengths [batch_size], int
-     *   \param context_lengths [batch_size], int
-     *   \param output_norm_weight [hidden_dims], float
-     *   \param max_q_len [1], int on cpu
-     *   \param max_kv_len [1], int on cpu
-     *   \param max_seq_len [1], int on cpu
-     *
-     * output tensors:
-     *   \param decoder_output [num_token, hidden_units],
-     *   \param key_cache [num_layer, batch, local_head_num, size_per_head // x, max_seq_len, x]
-     *   \param value_cache [num_layer, batch, local_head_num, max_seq_len, size_per_head]
-     *   \param last_token_hidden_units [batch_size, hidden_units]
-     */
-
-    Session sess{};
-
-    sess.token_num     = input_tensors->at("decoder_input").shape[0];
-    sess.batch_size    = input_tensors->at("input_lengths").shape[0];
-    sess.max_query_len = input_tensors->at("max_q_len").getVal<int>();
-    sess.max_key_len   = input_tensors->at("max_kv_len").getVal<int>();
-    sess.weights       = decoder_layer_weights;
-
-    sess.input_length   = input_tensors->at("input_lengths").getPtr<int>();
-    sess.history_length = input_tensors->at("history_lengths").getPtr<int>();
-    sess.context_length = input_tensors->at("context_lengths").getPtr<int>();
-
-    T* decoder_input_output = input_tensors->at("decoder_input").getPtr<T>();
-    T* decoder_output       = output_tensors->at("decoder_output").getPtr<T>();
-
-    sess.k_cache = &output_tensors->at("key_cache");
-    sess.v_cache = &output_tensors->at("value_cache");
-
-    allocateBuffer(sess.batch_size, sess.token_num, sess.max_query_len, sess.max_key_len);
-
-    size_t tmp_token_num{};
-    invokeGetPaddingOffsetAndCuSeqLens(h_pinned_token_num_ptr_,
-                                       &tmp_token_num,  // updated token num
-                                       padding_offset_,
-                                       cu_seqlens_,
-                                       input_tensors->at("input_lengths").getPtr<int>(),
-                                       sess.batch_size,
-                                       sess.max_query_len,
-                                       stream_);
-    sync_check_cuda_error();
-    FT_CHECK(tmp_token_num == sess.token_num);
-
-    invokeCreateCausalMasks(attention_mask_,
-                            sess.input_length,
-                            sess.context_length,
-                            sess.max_query_len,
-                            sess.max_key_len,
-                            sess.batch_size,
-                            stream_);
-    sync_check_cuda_error();
-
-    /////////////////////////////////////////////
-    /// RMSNorm
-    invokeRootMeanSquareNorm(decoder_output,
-                             decoder_input_output,
-                             decoder_layer_weights->at(0)->self_attn_norm_weights,
-                             rmsnorm_eps_,
-                             sess.token_num,
-                             hidden_units_,
-                             stream_);
-    sync_check_cuda_error();
-
-    for (size_t layer = 0; layer < num_layer_; ++layer) {
-        /////////////////////////////////////////////
-        /// self-attention
-        forwardSelfAttn(sess, decoder_output, input_tensors, layer, false);
-
-        invokeFusedAddBiasResidualRMSNorm(decoder_input_output,
-                                          decoder_output,
-                                          decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
-                                          decoder_layer_weights->at(layer)->ffn_norm_weights,
-                                          rmsnorm_eps_,
-                                          sess.token_num,
-                                          hidden_units_,
-                                          stream_);
-        sync_check_cuda_error();
-
-        ////////////////////////////////////////////
-        /// feed-forward network
-        TensorMap ffn_inputs{{"ffn_input", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, decoder_output}}};
-        TensorMap ffn_outputs{
-            {"ffn_output", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, decoder_output}}};
-        silu_ffn_layer_->forward(&ffn_outputs, &ffn_inputs, &decoder_layer_weights->at(layer)->ffn_weights);
-
-        auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
-                                                     input_tensors->at("output_norm_weight").getPtr<T>();
-        invokeFusedAddBiasResidualRMSNorm(decoder_input_output,  //
-                                          decoder_output,
-                                          decoder_layer_weights->at(layer)->ffn_weights.output.bias,
-                                          scale_weight,
-                                          rmsnorm_eps_,
-                                          sess.token_num,
-                                          hidden_units_,
-                                          stream_);
-        sync_check_cuda_error();
-    }
-
-    if (is_free_buffer_after_forward_) {
-        freeBuffer();
-    }
-}
-
-template class LlamaContextDecoder<float>;
-template class LlamaContextDecoder<half>;
-
-}  // namespace turbomind
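
LlamaContextDecoder<T>::forward above relies on invokeGetPaddingOffsetAndCuSeqLens to flatten the padded [batch, max_q_len] layout into a token-contiguous one before building the causal masks. The sketch below is a CPU-side reference of the two quantities involved, under the assumption that cu_seqlens is the prefix sum of the per-sequence lengths and padding_offset[i] counts the padding slots that precede token i in the padded layout; it is illustrative only, not the CUDA kernel from the deleted file.

#include <cstdio>
#include <vector>

// CPU reference of cumulative sequence lengths and per-token padding offsets.
void get_padding_offset_and_cu_seqlens(const std::vector<int>& input_lengths,
                                       int                     max_q_len,
                                       std::vector<int>&       padding_offset,
                                       std::vector<int>&       cu_seqlens)
{
    const int batch_size = (int)input_lengths.size();
    cu_seqlens.assign(batch_size + 1, 0);
    padding_offset.clear();

    int total_padding = 0;
    for (int b = 0; b < batch_size; ++b) {
        cu_seqlens[b + 1] = cu_seqlens[b] + input_lengths[b];
        for (int t = 0; t < input_lengths[b]; ++t) {
            padding_offset.push_back(total_padding);  // padding slots before this token
        }
        total_padding += max_q_len - input_lengths[b];
    }
}

int main()
{
    std::vector<int> lengths{3, 1, 2}, offsets, cu;
    get_padding_offset_and_cu_seqlens(lengths, /*max_q_len=*/4, offsets, cu);
    for (int v : cu) std::printf("%d ", v);       // 0 3 4 6
    std::printf("| ");
    for (int v : offsets) std::printf("%d ", v);  // 0 0 0 1 4 4
    std::printf("\n");
    return 0;
}
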
diff --git a/src/turbomind/models/llama/LlamaContextDecoder.h b/src/turbomind/models/llama/LlamaContextDecoder.h
deleted file mode 100644
index da6264176f..0000000000
--- a/src/turbomind/models/llama/LlamaContextDecoder.h
+++ /dev/null
@@ -1,112 +0,0 @@
-/*
- * Copyright (c) OpenMMLab. All rights reserved.
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Modified from
-// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/models/multi_gpu_gpt/ParallelGptContextDecoder.h
-
-#pragma once
-
-#include "src/turbomind/layers/BaseLayer.h"
-#include "src/turbomind/models/llama/LlamaContextAttentionLayer.h"
-#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"
-#include "src/turbomind/models/llama/LlamaFfnLayer.h"
-#include "src/turbomind/models/llama/llama_params.h"
-#include "src/turbomind/utils/Tensor.h"
-#include "src/turbomind/utils/allocator.h"
-#include "src/turbomind/utils/cublasMMWrapper.h"
-#include "src/turbomind/utils/custom_ar_comm.h"
-#include "src/turbomind/utils/nccl_utils.h"
-
-namespace turbomind {
-
-template<typename T>
-class LlamaContextDecoder: public BaseLayer {
-protected:
-    void allocateBuffer() override;
-    void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len);
-    void freeBuffer() override;
-
-    void initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, bool use_fmha, int quant_policy);
-
-    size_t head_num_;
-    size_t size_per_head_;
-    size_t inter_size_;
-    size_t num_layer_;
-    size_t hidden_units_;
-    float  rmsnorm_eps_;
-
-    NcclParam tensor_para_;
-
-    T*   attention_mask_{};
-    int* padding_offset_{};
-    int* cu_seqlens_{};  // cu for cumulative
-
-    size_t* h_pinned_token_num_ptr_{};
-
-    LlamaContextAttentionLayer<T>* context_attention_layer_{};
-    LlamaFfnLayer<T>*              silu_ffn_layer_{};
-
-    const DataType data_type_;
-
-    struct Session {
-        size_t  batch_size;
-        size_t  token_num;
-        size_t  max_query_len;
-        size_t  max_key_len;
-        Tensor* k_cache;
-        Tensor* v_cache;
-        int*    input_length{};
-        int*    history_length{};
-        int*    context_length{};
-
-        const std::vector<LlamaDecoderLayerWeight<T>*>* weights;
-    };
-
-    void forwardSelfAttn(const Session&                                 sess,
-                         T*                                             attn_io,
-                         const std::unordered_map<std::string, Tensor>* input_tensors,
-                         int                                            layer,
-                         bool                                           is_final);
-
-public:
-    LlamaContextDecoder(size_t                      head_num,
-                        size_t                      kv_head_num,
-                        size_t                      size_per_head,
-                        size_t                      inter_size,
-                        size_t                      num_layer,
-                        const LlamaAttentionParams& attn_params,
-                        float                       rmsnorm_eps,
-                        NcclParam                   tensor_para,
-                        cudaStream_t                stream,
-                        cublasMMWrapper*            cublas_wrapper,
-                        IAllocator*                 allocator,
-                        bool                        is_free_buffer_after_forward,
-                        bool                        use_fmha,
-                        int                         quant_policy);
-
-    ~LlamaContextDecoder() override;
-
-    virtual void forward(std::unordered_map<std::string, Tensor>*        output_tensors,
-                         const std::unordered_map<std::string, Tensor>*  input_tensors,
-                         const std::vector<LlamaDecoderLayerWeight<T>*>* decoder_layer_weights);
-
-    virtual void forward(std::vector<Tensor>*                            output_tensors,
-                         const std::vector<Tensor>*                      input_tensors,
-                         const std::vector<LlamaDecoderLayerWeight<T>*>* decoder_layer_weights);
-};
-
-}  // namespace turbomind
diff --git a/src/turbomind/models/llama/LlamaDecoder.cc b/src/turbomind/models/llama/LlamaDecoder.cc
deleted file mode 100644
index 73e95b1353..0000000000
--- a/src/turbomind/models/llama/LlamaDecoder.cc
+++ /dev/null
@@ -1,247 +0,0 @@
-/*
- * Copyright (c) OpenMMLab. All rights reserved.
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
- * Copyright (c) 2022, SK Telecom Authored by A. Dialog
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Modified from
-// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/models/multi_gpu_gpt/ParallelGptDecoder.cc
-
-#include "src/turbomind/models/llama/LlamaDecoder.h"
-#include "src/turbomind/macro.h"
-#include "src/turbomind/models/llama/llama_decoder_kernels.h"
-#include "src/turbomind/models/llama/llama_kernels.h"
-#include "src/turbomind/models/llama/llama_params.h"
-#include "src/turbomind/models/llama/llama_utils.h"
-
-namespace turbomind {
-
-template<typename T>
-LlamaDecoder<T>::LlamaDecoder(size_t                      head_num,
-                              size_t                      kv_head_num,
-                              size_t                      size_per_head,
-                              size_t                      inter_size,
-                              size_t                      num_layer,
-                              const LlamaAttentionParams& attn_params,
-                              float                       rmsnorm_eps,
-                              NcclParam                   tensor_para,
-                              cudaStream_t                stream,
-                              cublasMMWrapper*            cublas_wrapper,
-                              IAllocator*                 allocator,
-                              bool                        is_free_buffer_after_forward,
-                              int                         quant_policy):
-    BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward),
-    head_num_(head_num),
-    size_per_head_(size_per_head),
-    inter_size_(inter_size),
-    num_layer_(num_layer),
-    hidden_units_(head_num * size_per_head),
-    rmsnorm_eps_(rmsnorm_eps),
-    tensor_para_(tensor_para),
-    data_type_(getTensorType<T>())
-{
-    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-    initialize(attn_params, kv_head_num, quant_policy);
-}
-
-template<typename T>
-LlamaDecoder<T>::~LlamaDecoder()
-{
-    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-    delete self_attention_layer_;
-    delete silu_ffn_layer_;
-}
-
-template<typename T>
-void LlamaDecoder<T>::initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, int quant_policy)
-{
-    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-
-    self_attention_layer_ = new LlamaDecoderSelfAttentionLayer<T>(head_num_,
-                                                                  kv_head_num,
-                                                                  size_per_head_,
-                                                                  attn_params,
-                                                                  tensor_para_,
-                                                                  stream_,
-                                                                  cublas_wrapper_,
-                                                                  allocator_,
-                                                                  is_free_buffer_after_forward_,
-                                                                  quant_policy);
-
-    silu_ffn_layer_ = new LlamaFfnLayer<T>(head_num_,
-                                           size_per_head_,
-                                           inter_size_,
-                                           tensor_para_,
-                                           stream_,
-                                           cublas_wrapper_,
-                                           allocator_,
-                                           is_free_buffer_after_forward_);
-}
-
-template<typename T>
-void LlamaDecoder<T>::allocateBuffer()
-{
-    FT_CHECK(false);
-}
-
-template<typename T>
-void LlamaDecoder<T>::allocateBuffer(size_t batch_size)
-{
-    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-    is_allocate_buffer_ = true;
-}
-
-template<typename T>
-void LlamaDecoder<T>::freeBuffer()
-{
-    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-    if (is_allocate_buffer_) {
-        is_allocate_buffer_ = false;
-    }
-}
-
-template<typename T>
-void LlamaDecoder<T>::forwardSelfAttn(const LlamaDecoder::Session&                   sess,
-                                      T*                                             attn_io,
-                                      const std::unordered_map<std::string, Tensor>* input_tensors,
-                                      size_t                                         layer)
-{
-    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-    TensorMap self_attention_input_tensors(*input_tensors);
-    self_attention_input_tensors.insert("input_query",
-                                        {MEMORY_GPU, data_type_, {sess.batch_size, hidden_units_}, attn_io});
-    const int layer_id = layer;
-    self_attention_input_tensors.insert("layer_id", {MEMORY_CPU, TYPE_INT32, {1}, &layer_id});
-    auto& k_cache = *sess.k_cache;
-    auto& v_cache = *sess.v_cache;
-
-    TensorMap self_attention_output_tensors{
-        {"attention_output", {MEMORY_GPU, data_type_, {sess.batch_size, hidden_units_}, attn_io}},
-        {"key_cache", k_cache},
-        {"value_cache", v_cache},
-    };
-
-    self_attention_layer_->forward(&self_attention_output_tensors,  //
-                                   &self_attention_input_tensors,
-                                   &sess.weights->at(layer)->self_attn_weights);
-}
-
-template<typename T>
-void LlamaDecoder<T>::forwardFfn(const LlamaDecoder::Session& sess, T* ffn_io, size_t layer)
-{
-    TensorMap ffn_inputs{{"ffn_input", {MEMORY_GPU, data_type_, {sess.batch_size, hidden_units_}, ffn_io}}};
-    TensorMap ffn_outputs{{"ffn_output", {MEMORY_GPU, data_type_, {sess.batch_size, hidden_units_}, ffn_io}}};
-    silu_ffn_layer_->forward(&ffn_outputs, &ffn_inputs, &sess.weights->at(layer)->ffn_weights);
-}
-
-template<typename T>
-void LlamaDecoder<T>::forward(std::vector<Tensor>*                            output_tensors,
-                              const std::vector<Tensor>*                      input_tensors,
-                              const std::vector<LlamaDecoderLayerWeight<T>*>* decoder_layer_weights)
-{
-    FT_CHECK(false);
-}
-
-template<typename T>
-void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>*        output_tensors,
-                              const std::unordered_map<std::string, Tensor>*  input_tensors,
-                              const std::vector<LlamaDecoderLayerWeight<T>*>* decoder_layer_weights)
-{
-    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-    /**
-     * input_tensors:
-     *   \param decoder_input [batch_size, hidden_dims]
-     *   \param sequence_lengths [batch_size] int
-     *   \param output_norm_weight [hidden_dims]
-     *   \param step [1] on cpu
-     *   \param ite [1] on cpu
-     *   \param finished [batch_size] bool
-     *   \param total_padding_tokens [batch_size], int
-     *   \param max_seq_len [1] on cpu
-     *   \param masked_tokens [batch_size, memory_len] bool (optional), NOT USED YET
-     *
-     * output_tensors:
-     *   \param decoder_output [batch_size, hidden_dimension]
-     *   \param key_cache [batch_size] uint64_t
-     *   \param value_cache [batch_size] uint64_t
-     */
-
-    // for the shape of key cache, refer to decoder_masked_multihead_attention_template.hpp
-
-    Session sess{};
-    sess.batch_size = input_tensors->at("decoder_input").shape[0];
-    sess.weights    = decoder_layer_weights;
-
-    allocateBuffer(sess.batch_size);
-
-    sess.ite     = input_tensors->at("ite").getVal<int>();
-    sess.k_cache = &output_tensors->at("key_cache");
-    sess.v_cache = &output_tensors->at("value_cache");
-
-    sess.max_memory_len = input_tensors->at("max_seq_len").getVal<int>();
-
-    T* decoder_input  = input_tensors->at("decoder_input").getPtr<T>();
-    T* decoder_output = output_tensors->at("decoder_output").getPtr<T>();
-
-    ////////////////////////////////////////////
-    /// RMSNorm
-    invokeRootMeanSquareNorm(decoder_output,
-                             decoder_input,
-                             decoder_layer_weights->at(0)->self_attn_norm_weights,
-                             rmsnorm_eps_,
-                             sess.batch_size,
-                             hidden_units_,
-                             stream_);
-    sync_check_cuda_error();
-
-    for (size_t layer = 0; layer < num_layer_; ++layer) {
-        // output: self_attn_output_, k_cache, v_cache = self_attn(decoder_normed_input_)
-        forwardSelfAttn(sess, decoder_output, input_tensors, layer);
-
-        invokeFusedAddBiasResidualRMSNorm(decoder_input,
-                                          decoder_output,
-                                          decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
-                                          decoder_layer_weights->at(layer)->ffn_norm_weights,
-                                          rmsnorm_eps_,
-                                          sess.batch_size,
-                                          hidden_units_,
-                                          stream_);
-        sync_check_cuda_error();
-
-        // decoder_layer_output_ = ffn(decoder_normed_input_)
-        forwardFfn(sess, decoder_output, layer);
-
-        auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
-                                                     input_tensors->at("output_norm_weight").getPtr<T>();
-        invokeFusedAddBiasResidualRMSNorm(decoder_input,  //
-                                          decoder_output,
-                                          decoder_layer_weights->at(layer)->ffn_weights.output.bias,
-                                          scale_weight,
-                                          rmsnorm_eps_,
-                                          sess.batch_size,
-                                          hidden_units_,
-                                          stream_);
-        sync_check_cuda_error();
-    }
-
-    if (is_free_buffer_after_forward_) {
-        freeBuffer();
-    }
-}
-
-template class LlamaDecoder<float>;
-template class LlamaDecoder<half>;
-
-}  // namespace turbomind
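
Both decoder paths deleted above open with invokeRootMeanSquareNorm and then interleave invokeFusedAddBiasResidualRMSNorm around each attention and FFN sub-layer. The plain RMSNorm step computes y_i = w_i * x_i / sqrt(mean_j(x_j^2) + eps); the snippet below is a minimal CPU sketch of that formula only (the fused kernel additionally folds in the residual add and bias before normalizing), not the CUDA implementation.

#include <cmath>
#include <cstdio>
#include <vector>

// CPU sketch of RMSNorm: y[i] = w[i] * x[i] / sqrt(mean(x^2) + eps).
void rmsnorm(const std::vector<float>& x, const std::vector<float>& w, float eps, std::vector<float>& y)
{
    float sum_sq = 0.f;
    for (float v : x) sum_sq += v * v;
    const float inv_rms = 1.f / std::sqrt(sum_sq / x.size() + eps);
    y.resize(x.size());
    for (size_t i = 0; i < x.size(); ++i) y[i] = w[i] * x[i] * inv_rms;
}

int main()
{
    std::vector<float> x{1.f, 2.f, 3.f, 4.f}, w(4, 1.f), y;
    rmsnorm(x, w, 1e-6f, y);
    for (float v : y) std::printf("%.4f ", v);
    std::printf("\n");
    return 0;
}
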
diff --git a/src/turbomind/models/llama/LlamaDecoder.h b/src/turbomind/models/llama/LlamaDecoder.h
deleted file mode 100644
index 091c2ba55a..0000000000
--- a/src/turbomind/models/llama/LlamaDecoder.h
+++ /dev/null
@@ -1,96 +0,0 @@
-/*
- * Copyright (c) OpenMMLab. All rights reserved.
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
- * Copyright (c) 2022, SK Telecom Authored by A. Dialog
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Modified from
-// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/models/multi_gpu_gpt/ParallelGptDecoder.h
-
-#include "src/turbomind/layers/BaseLayer.h"
-#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"
-#include "src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h"
-#include "src/turbomind/models/llama/LlamaFfnLayer.h"
-#include "src/turbomind/models/llama/llama_params.h"
-#include "src/turbomind/utils/custom_ar_comm.h"
-#include "src/turbomind/utils/nccl_utils.h"
-
-namespace turbomind {
-
-template<typename T>
-class LlamaDecoder: public BaseLayer {
-protected:
-    void allocateBuffer() override;  // deprecated
-    void allocateBuffer(size_t batch_size);
-    void freeBuffer() override;
-    void initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, int quant_policy);
-
-    size_t head_num_;
-    size_t size_per_head_;
-    size_t inter_size_;
-    size_t num_layer_;
-    size_t hidden_units_;
-    float  rmsnorm_eps_;
-
-    NcclParam tensor_para_;
-
-    LlamaDecoderSelfAttentionLayer<T>* self_attention_layer_{};
-    LlamaFfnLayer<T>*                  silu_ffn_layer_{};
-
-    const DataType data_type_;
-
-    struct Session {
-        size_t                                          batch_size;
-        int                                             ite;
-        size_t                                          max_memory_len;
-        Tensor*                                         k_cache;
-        Tensor*                                         v_cache;
-        const std::vector<LlamaDecoderLayerWeight<T>*>* weights;
-    };
-
-    void forwardSelfAttn(const Session&                                 sess,
-                         T*                                             attn_io,
-                         const std::unordered_map<std::string, Tensor>* input_tensors,
-                         size_t                                         layer);
-
-    void forwardFfn(const LlamaDecoder::Session& sess, T* ffn_io, size_t layer);
-
-public:
-    LlamaDecoder(size_t                      head_num,
-                 size_t                      kv_head_num,
-                 size_t                      size_per_head,
-                 size_t                      inter_size,
-                 size_t                      num_layer,
-                 const LlamaAttentionParams& attn_params,
-                 float                       rmsnorm_eps,
-                 NcclParam                   tensor_para,
-                 cudaStream_t                stream,
-                 cublasMMWrapper*            cublas_wrapper,
-                 IAllocator*                 allocator,
-                 bool                        is_free_buffer_after_forward,
-                 int                         quant_policy);
-
-    ~LlamaDecoder() override;
-
-    virtual void forward(std::unordered_map<std::string, Tensor>*        output_tensors,
-                         const std::unordered_map<std::string, Tensor>*  input_tensors,
-                         const std::vector<LlamaDecoderLayerWeight<T>*>* decoder_layer_weights);
-
-    virtual void forward(std::vector<Tensor>*                            output_tensors,
-                         const std::vector<Tensor>*                      input_tensors,
-                         const std::vector<LlamaDecoderLayerWeight<T>*>* decoder_layer_weights);
-};
-
-}  // namespace turbomind
diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
index 0f65eb1e12..ab3eb783c4 100644
--- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
+++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
@@ -110,6 +110,47 @@ void mallocWeights(LlamaDenseWeight<T>& weights, bool bias)
     }
 }
 
+template<typename FirstArg, typename... Args>
+std::string concat(FirstArg&& first, Args&&... args)
+{
+    std::stringstream stream;
+    stream << first;
+    ((stream << "." << args), ...);
+    return stream.str();
+}
+
+template<typename T>
+void getWeightTensor(LlamaDenseWeight<T>& weights, bool bias, const std::string& prefix, TensorMap& output)
+{
+    auto get_name = [=](const std::string& name) { return concat(prefix, name); };
+
+    if (bias) {
+        output.insert(get_name("bias"),
+                      Tensor{MEMORY_GPU, getTensorType<T>(), {weights.output_dims * sizeof(T)}, weights.bias});
+    }
+    const size_t bit_size = getBitSize(weights.type);
+    if (bit_size >= 16) {
+        output.insert(get_name("weight"),
+                      Tensor{MEMORY_GPU,
+                             getTensorType<T>(),
+                             {weights.input_dims * weights.output_dims * sizeof(T)},
+                             weights.kernel});
+    }
+    else {  // int8, int4
+        const int factor = sizeof(float) * 8 / bit_size;
+        output.insert(get_name("qweight"),
+                      Tensor{MEMORY_GPU,
+                             TYPE_INT32,
+                             {weights.input_dims * weights.output_dims * sizeof(int) / factor},
+                             weights.kernel});
+        output.insert(get_name("scales_zeros"),
+                      Tensor{MEMORY_GPU,
+                             getTensorType<T>(),
+                             {weights.input_dims / weights.group_size * weights.output_dims * 2 * sizeof(T)},
+                             weights.scales_and_zeros});
+    }
+}
+
 template<typename T>
 void loadWeights(LlamaDenseWeight<T>& w,
                  std::string          prefix,
@@ -226,6 +267,7 @@ void LlamaDecoderLayerWeight<T>::mallocWeights()
 
     turbomind::mallocWeights(self_attn_weights.qkv, attn_bias_);
     turbomind::mallocWeights(self_attn_weights.output, attn_bias_);
+    self_attn_weights.past_kv_scale = {1.f, 0.f, 1.f, 0.f};
 
     if (weight_type_ == WeightType::kINT4) {
         turbomind::mallocWeights(ffn_weights.fused_gating_intermediate, false);
@@ -294,16 +336,43 @@ void LlamaDecoderLayerWeight<T>::loadModel(std::string dir_path, FtCudaDataType
     loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type, tensor_para_size_, 0);
 
     // load kv_cache quant scale
-    // if file not exist, get empty vector
     std::string   scale_path = dir_path + ".past_kv_scale." + rank_spec + ".weight";
     std::ifstream in(scale_path, std::ios::in);
     if (in.is_open()) {
         in.close();
         self_attn_weights.past_kv_scale = loadArrayFromBin({4}, scale_path);
     }
+}
+
+template<typename T>
+TensorMap LlamaDecoderLayerWeight<T>::getParams(std::string prefix)
+{
+    TensorMap output;
+
+    output.insert(concat(prefix, "attention_norm.weight"),
+                  Tensor{MEMORY_GPU, getTensorType<T>(), {hidden_units_ * sizeof(T)}, self_attn_norm_weights});
+
+    output.insert(concat(prefix, "ffn_norm.weight"),
+                  Tensor{MEMORY_GPU, getTensorType<T>(), {hidden_units_ * sizeof(T)}, ffn_norm_weights});
+
+    auto get_prefix = [=](std::string_view name) { return concat(prefix, name, tensor_para_rank_); };
+
+    getWeightTensor(self_attn_weights.qkv, attn_bias_, get_prefix("attention.w_qkv"), output);
+
+    getWeightTensor(self_attn_weights.output, attn_bias_, get_prefix("attention.wo"), output);
+
+    if (weight_type_ == WeightType::kINT4) {
+        getWeightTensor(ffn_weights.fused_gating_intermediate, false, get_prefix("feed_forward.w13"), output);
+    }
     else {
-        self_attn_weights.past_kv_scale = {};
+        getWeightTensor(ffn_weights.gating, false, get_prefix("feed_forward.w1"), output);
+        getWeightTensor(ffn_weights.intermediate, false, get_prefix("feed_forward.w3"), output);
     }
+    getWeightTensor(ffn_weights.output, false, get_prefix("feed_forward.w2"), output);
+    output.insert(concat(prefix, "past_kv_scale", tensor_para_rank_, "weight"),
+                  Tensor{MEMORY_CPU, TYPE_FP32, {4 * sizeof(float)}, self_attn_weights.past_kv_scale.data()});
+
+    return output;
 }
 
 template struct LlamaDecoderLayerWeight<float>;
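
The getParams method added above keys the exported TensorMap by dotted names built with the variadic concat helper from the first hunk, e.g. prefix.attention.w_qkv.<rank> with a .qweight or .scales_zeros suffix appended by getWeightTensor. Below is a self-contained sketch of the same fold-expression helper and the kind of names it yields; the prefix string and rank value are illustrative, not taken from the patch.

#include <cstdio>
#include <sstream>
#include <string>

// Same idea as the concat() helper added in this hunk: join the arguments with '.'.
template<typename FirstArg, typename... Args>
std::string concat(FirstArg&& first, Args&&... args)
{
    std::stringstream stream;
    stream << first;
    ((stream << "." << args), ...);
    return stream.str();
}

int main()
{
    const std::string prefix = "layers.0";  // illustrative layer prefix
    const int         rank   = 1;           // illustrative tensor-parallel rank
    std::printf("%s\n", concat(prefix, "attention_norm.weight").c_str());
    std::printf("%s\n", concat(prefix, "attention.w_qkv", rank).c_str());
    std::printf("%s\n", concat(prefix, "past_kv_scale", rank, "weight").c_str());
    return 0;
}
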
diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h
index 2141f72e7f..169a3aa9e6 100644
--- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h
+++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h
@@ -21,6 +21,7 @@
 #pragma once
 
 #include "src/turbomind/models/llama/LlamaDenseWeight.h"
+#include "src/turbomind/utils/Tensor.h"
 
 namespace turbomind {
 
@@ -43,6 +44,8 @@ struct LlamaDecoderLayerWeight {
 
     void loadModel(std::string dir_path, FtCudaDataType model_file_type);
 
+    TensorMap getParams(std::string prefix);
+
     T*                      self_attn_norm_weights{};
     T*                      ffn_norm_weights{};
     LlamaAttentionWeight<T> self_attn_weights{};
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
deleted file mode 100644
index 103b32e88f..0000000000
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
+++ /dev/null
@@ -1,309 +0,0 @@
-/*
- * Copyright (c) OpenMMLab. All rights reserved.
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Modified from
-// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/layers/attention_layers/DecoderSelfAttentionLayer.cc
-#include "src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h"
-#include "src/turbomind/kernels/decoder_masked_multihead_attention.h"
-#include "src/turbomind/macro.h"
-#include "src/turbomind/models/llama/LlamaNcclGuard.h"
-#include "src/turbomind/models/llama/llama_kernels.h"
-#include "src/turbomind/models/llama/llama_utils.h"
-#include "src/turbomind/utils/cuda_utils.h"
-#include "src/turbomind/utils/logger.h"
-#include "src/turbomind/utils/nvtx_utils.h"
-#include 
-// #include 
-
-namespace turbomind {
-
-template<typename T>
-struct SATypeConverter {
-    using Type = T;
-};
-
-template<>
-struct SATypeConverter<half> {
-    using Type = uint16_t;
-};
-
-template<typename T>
-static inline void fusedQKV_masked_attention_dispatch(const T*     qkv_buf,
-                                                      const T*     qkv_bias,
-                                                      const T*     relative_attention_bias,
-                                                      T*           key_cache,
-                                                      T*           value_cache,
-                                                      T**          k_cache_per_sample,
-                                                      T**          v_cache_per_sample,
-                                                      size_t       kv_cache_per_sample_offset,
-                                                      const int*   cache_indir,
-                                                      T*           context_buf,
-                                                      const bool*  finished,
-                                                      const int*   sequence_lengths,
-                                                      const int    max_batch_size,
-                                                      const int    inference_batch_size,
-                                                      const int    beam_width,
-                                                      const int    head_num,
-                                                      const int    kv_head_num,
-                                                      const int    size_per_head,
-                                                      const int    rotary_embedding_dim,
-                                                      const float  rotary_embedding_base,
-                                                      const int    max_position_embeddings,
-                                                      const bool   use_dynamic_ntk,
-                                                      const bool   use_logn_attn,
-                                                      const int    memory_max_len,
-                                                      const int*   prefix_prompt_lengths,
-                                                      const int    max_prefix_prompt_length,
-                                                      const int    max_input_len,
-                                                      const int*   total_padding_tokens,
-                                                      const int    step,
-                                                      const float  q_scaling,
-                                                      const int    relative_attention_bias_stride,
-                                                      const T*     linear_bias_slopes,
-                                                      const bool*  masked_tokens,
-                                                      const int*   ia3_tasks,
-                                                      const T*     ia3_key_weights,
-                                                      const T*     ia3_value_weights,
-                                                      const float* qkv_scale_out,
-                                                      const float* attention_out_scale,
-                                                      const int    int8_mode,
-                                                      const float* attention_kv_scale,
-                                                      cudaStream_t stream)
-{
-    using DataType = typename SATypeConverter<T>::Type;
-    // Prepare the parameters.
-    Masked_multihead_attention_params<DataType> params;
-    memset(&params, 0, sizeof(params));
-    // int hidden_units = head_num * size_per_head;
-    if (qkv_bias != nullptr) {
-        params.q_bias = reinterpret_cast<const DataType*>(qkv_bias);
-        params.k_bias = reinterpret_cast<const DataType*>(qkv_bias) + head_num * size_per_head;
-        params.v_bias = reinterpret_cast<const DataType*>(qkv_bias) + (head_num + kv_head_num) * size_per_head;
-    }
-    else {
-        params.q_bias = nullptr;
-        params.k_bias = nullptr;
-        params.v_bias = nullptr;
-    }
-
-    // Set the output buffer.
-    params.out = reinterpret_cast<DataType*>(context_buf);
-
-    // Set the input buffers.
-    // [B, nH + kvH, D]
-    params.q = reinterpret_cast<const DataType*>(qkv_buf);
-    params.k = reinterpret_cast<const DataType*>(qkv_buf) + head_num * size_per_head;
-    params.v = reinterpret_cast<const DataType*>(qkv_buf) + (head_num + kv_head_num) * size_per_head;
-
-    params.stride   = (head_num + 2 * kv_head_num) * size_per_head;
-    params.finished = const_cast<bool*>(finished);
-
-    FT_CHECK(k_cache_per_sample && v_cache_per_sample);
-
-    params.k_cache_per_sample         = reinterpret_cast<DataType**>(k_cache_per_sample);
-    params.v_cache_per_sample         = reinterpret_cast<DataType**>(v_cache_per_sample);
-    params.kv_cache_per_sample_offset = kv_cache_per_sample_offset;
-    params.batch_size                 = inference_batch_size;
-    params.beam_width                 = beam_width;
-    params.memory_max_len             = memory_max_len;
-    params.prefix_prompt_lengths      = prefix_prompt_lengths;
-    params.max_prefix_prompt_length   = max_prefix_prompt_length;
-    params.length_per_sample          = sequence_lengths;  // max_input_length + current output length
-    // timestep adding max_prefix_prompt_length for shared memory size calculation and rotary embedding computation
-    params.timestep     = step + max_prefix_prompt_length - 1;
-    params.num_heads    = head_num;
-    params.num_kv_heads = kv_head_num;
-
-    params.hidden_size_per_head    = size_per_head;
-    params.rotary_embedding_dim    = rotary_embedding_dim;
-    params.rotary_embedding_base   = rotary_embedding_base;
-    params.max_position_embeddings = max_position_embeddings;
-    params.use_dynamic_ntk         = use_dynamic_ntk;
-    params.use_logn_attn           = use_logn_attn;
-
-    // Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust)
-    params.inv_sqrt_dh = 1.F / (sqrtf((float)params.hidden_size_per_head) * q_scaling);
-
-    params.total_padding_tokens = total_padding_tokens;
-    if (relative_attention_bias != nullptr) {
-        params.relative_attention_bias = reinterpret_cast<const DataType*>(relative_attention_bias);
-    }
-    params.relative_attention_bias_stride = relative_attention_bias_stride;
-    params.masked_tokens                  = masked_tokens;
-
-    // The slope of linear position bias per head, e.g., ALiBi.
-    if (linear_bias_slopes != nullptr) {
-        params.linear_bias_slopes = reinterpret_cast<const DataType*>(linear_bias_slopes);
-    }
-    params.max_input_length = max_input_len;
-
-    params.int8_mode = int8_mode;
-
-    if (int8_mode & QuantPolicy::kCacheKVInt8) {
-        params.attention_k_scale = attention_kv_scale[0];
-        params.attention_k_zp    = attention_kv_scale[1];
-        params.attention_v_scale = attention_kv_scale[2];
-        params.attention_v_zp    = attention_kv_scale[3];
-    }
-
-    PUSH_RANGE("scaled dot-product fusion");
-    masked_multihead_attention(params, stream);
-    POP_RANGE;
-}
-
-template<typename T>
-void LlamaDecoderSelfAttentionLayer<T>::allocateBuffer(size_t batch_size, int key_len, int max_memory_len)
-{
-    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-
-    const size_t local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_;
-
-    qkv_buf_ = reinterpret_cast<T*>(
-        allocator_->reMalloc(qkv_buf_, sizeof(T) * batch_size * local_q_kv_head_num * size_per_head_, false));
-    context_buf_ =
-        reinterpret_cast<T*>(allocator_->reMalloc(context_buf_, sizeof(T) * batch_size * local_hidden_units_, false));
-
-    is_allocate_buffer_ = true;
-}
-
-template<typename T>
-void LlamaDecoderSelfAttentionLayer<T>::freeBuffer()
-{
-    if (is_allocate_buffer_) {
-        allocator_->free((void**)(&qkv_buf_));
-        allocator_->free((void**)(&context_buf_));
-        is_allocate_buffer_ = false;
-    }
-}
-
-template<typename T>
-void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     output_tensors,
-                                                const TensorMap*               input_tensors,
-                                                const LlamaAttentionWeight<T>* weights)
-{
-    /**
-     * input tensors:
-     *    \param input_query [batch_size, hidden_units],
-     *    \param sequence_lengths [batch_size]
-     *    \param step [1] on cpu
-     *    \param finished [batch_size]
-     *    \param total_padding_tokens [batch_size]
-     *    \param layer_id [1], int on cpu
-     *    \param max_seq_len [1] on cpu
-     *    \param masked_tokens [batch_size, memory_len], (optional), NOT USED YET
-     *    \param cache_indirection [batch_size / beam_width, beam_width, memory_max_len] (optional)
-     *
-     * output tensors:
-     *    \param attention_output [batch_size, hidden_units],
-     *    \param key_cache [batch, local_head_num, memory_max_len, size_per_head]
-     *    \param value_cache [batch, local_head_num, memory_max_len, size_per_head]
-     */
-
-    const T*    input_query_data      = input_tensors->getPtr<T>("input_query");
-    const int*  sequence_lengths_data = input_tensors->getPtr<int>("sequence_lengths");
-    const int*  total_padding_len     = input_tensors->getPtr<int>("total_padding_tokens");
-    const bool* finished_data         = input_tensors->getPtr<bool>("finished", nullptr);
-    const bool* masked_tokens_data    = input_tensors->getPtr<bool>("masked_tokens", nullptr);
-    const int*  cache_indir           = input_tensors->getPtr<int>("cache_indirection", nullptr);
-
-    T*  hidden_features_data = output_tensors->getPtr<T>("attention_output");
-    T** key_cache_ptrs       = output_tensors->getPtr<T*>("key_cache");
-    T** value_cache_ptrs     = output_tensors->getPtr<T*>("value_cache");
-
-    const int layer_id = input_tensors->getVal<int>("layer_id");
-
-    const int max_seq_len = input_tensors->getVal<int>("max_seq_len");
-    const int step        = input_tensors->getVal<int>("step");
-
-    const int step_1 = step - 1;
-
-    const int batch_size = input_tensors->at("input_query").shape[0];
-    const int beam_width = cache_indir != nullptr ? input_tensors->at("cache_indirection").shape[1] : 1;
-
-    allocateBuffer(batch_size, step, max_seq_len);
-
-    PUSH_RANGE("qkv_gemm");
-    linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv);
-    POP_RANGE;
-
-    const auto kv_cache_layer_offset = layer_id * local_kv_head_num_ * max_seq_len * size_per_head_;
-    const int  memory_len            = max_seq_len;
-
-    fusedQKV_masked_attention_dispatch<T>(
-        qkv_buf_,
-        weights->qkv.bias,  // query_weight.bias,
-        nullptr,            // relative_attention_bias,
-        nullptr,
-        nullptr,
-        key_cache_ptrs,
-        value_cache_ptrs,
-        kv_cache_layer_offset,
-        cache_indir,
-        context_buf_,
-        finished_data,
-        sequence_lengths_data,  // NOTE: current seq len including padding (fixed after meeting the finished id)
-        batch_size,
-        batch_size,
-        beam_width,
-        local_head_num_,
-        local_kv_head_num_,
-        size_per_head_,
-        params_.rotray_embedding_dim,
-        params_.rotary_embedding_base,
-        params_.max_position_embeddings,
-        params_.use_dynamic_ntk,
-        params_.use_logn_attn,
-        memory_len,
-        nullptr,  // prefix_prompt_lengths
-        0,        // max_prefix_prompt_length
-        0,        // max_input_length, not used w/o linear_bias_slopes
-        input_tensors->getPtr<int>("total_padding_tokens", nullptr),
-        step,
-        1.f,                            // q_scaling
-        0,                              // relative_attention_bias_stride
-        nullptr,                        // linear_bias_slopes
-        nullptr,                        //  masked_tokens_data,
-        nullptr,                        // ia3_tasks
-        nullptr,                        // ia3_key_weights
-        nullptr,                        // ia3_value_weights
-        nullptr,                        // qkv_scale_out
-        nullptr,                        // attention_out_scale
-        quant_policy_,                  // int8_mode
-        weights->past_kv_scale.data(),  // attention kv scale
-        stream_);
-    sync_check_cuda_error();
-
-    linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output);
-
-    if (tensor_para_.world_size_ > 1) {
-        NcclGuard nccl_guard(tensor_para_, stream_);
-        ftNcclAllReduceSum(
-            hidden_features_data, hidden_features_data, batch_size * hidden_units_, tensor_para_, stream_);
-        sync_check_cuda_error();
-    }
-
-    if (is_free_buffer_after_forward_) {
-        freeBuffer();
-    }
-
-    // LOG(WARNING);
-}
-
-template class LlamaDecoderSelfAttentionLayer<float>;
-template class LlamaDecoderSelfAttentionLayer<half>;
-
-}  // namespace turbomind
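
The deleted fusedQKV_masked_attention_dispatch reads Q, K and V for each token from a single packed buffer: q starts at offset 0, k at head_num * size_per_head, v at (head_num + kv_head_num) * size_per_head, and the per-token stride is (head_num + 2 * kv_head_num) * size_per_head. The sketch below just reproduces those offsets with illustrative sizes.

#include <cstddef>
#include <cstdio>

// Pointer layout of the packed QKV buffer used by the deleted dispatch:
// one token is [q: head_num, k: kv_head_num, v: kv_head_num] x size_per_head.
struct QkvOffsets {
    size_t q, k, v, stride;
};

QkvOffsets qkv_offsets(size_t head_num, size_t kv_head_num, size_t size_per_head)
{
    QkvOffsets o;
    o.q      = 0;
    o.k      = head_num * size_per_head;
    o.v      = (head_num + kv_head_num) * size_per_head;
    o.stride = (head_num + 2 * kv_head_num) * size_per_head;  // elements per token
    return o;
}

int main()
{
    const QkvOffsets o = qkv_offsets(/*head_num=*/32, /*kv_head_num=*/8, /*size_per_head=*/128);
    std::printf("q=%zu k=%zu v=%zu stride=%zu\n", o.q, o.k, o.v, o.stride);
    return 0;
}
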
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
deleted file mode 100644
index 89afe3f964..0000000000
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
+++ /dev/null
@@ -1,96 +0,0 @@
-/*
- * Copyright (c) OpenMMLab. All rights reserved.
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Modified from
-// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/layers/attention_layers/DecoderSelfAttentionLayer.h
-
-#pragma once
-
-#include "src/turbomind/models/llama/LlamaDenseWeight.h"
-#include "src/turbomind/models/llama/LlamaLinear.h"
-#include "src/turbomind/models/llama/llama_params.h"
-#include "src/turbomind/utils/Tensor.h"
-#include "src/turbomind/utils/nccl_utils.h"
-
-namespace turbomind {
-
-template<typename T>
-class LlamaDecoderSelfAttentionLayer {
-public:
-    void freeBuffer();
-    void allocateBuffer(size_t batch_size, int key_len, int max_memory_len);
-
-    LlamaDecoderSelfAttentionLayer(size_t                      head_num,
-                                   size_t                      kv_head_num,
-                                   size_t                      size_per_head,
-                                   const LlamaAttentionParams& attn_params,
-                                   NcclParam                   tensor_para,
-                                   cudaStream_t                stream,
-                                   cublasMMWrapper*            cublas_wrapper,
-                                   IAllocator*                 allocator,
-                                   bool                        is_free_buffer_after_forward,
-                                   int                         quant_policy):
-        head_num_(head_num),
-        kv_head_num_(kv_head_num),
-        size_per_head_(size_per_head),
-        hidden_units_(head_num * size_per_head),
-        local_head_num_(head_num / tensor_para.world_size_),
-        local_kv_head_num_(kv_head_num_ / tensor_para.world_size_),
-        local_hidden_units_(hidden_units_ / tensor_para.world_size_),
-        params_(attn_params),
-        tensor_para_(tensor_para),
-        stream_(stream),
-        linear_(cublas_wrapper, stream),
-        allocator_(allocator),
-        is_free_buffer_after_forward_(is_free_buffer_after_forward),
-        quant_policy_(quant_policy)
-    {
-    }
-
-    ~LlamaDecoderSelfAttentionLayer()
-    {
-        freeBuffer();
-    }
-
-    void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const LlamaAttentionWeight<T>* weights);
-
-private:
-    const size_t head_num_;
-    const size_t kv_head_num_;
-    const size_t size_per_head_;
-    const size_t hidden_units_;
-    const size_t local_head_num_;
-    const size_t local_kv_head_num_;
-    const size_t local_hidden_units_;
-    const bool   is_free_buffer_after_forward_;
-    const int    quant_policy_;
-
-    const LlamaAttentionParams& params_;
-
-    NcclParam tensor_para_;
-
-    cudaStream_t   stream_;
-    IAllocator*    allocator_;
-    LlamaLinear<T> linear_;
-
-    T* qkv_buf_     = nullptr;
-    T* context_buf_ = nullptr;
-
-    bool is_allocate_buffer_{};
-};
-
-}  // namespace turbomind
diff --git a/src/turbomind/models/llama/LlamaFfnLayer.cc b/src/turbomind/models/llama/LlamaFfnLayer.cc
index f605d8f27b..0d78dc4e80 100644
--- a/src/turbomind/models/llama/LlamaFfnLayer.cc
+++ b/src/turbomind/models/llama/LlamaFfnLayer.cc
@@ -20,6 +20,7 @@
 #include "src/turbomind/models/llama/LlamaFfnLayer.h"
 #include "src/turbomind/kernels/activation_kernels.h"
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
+#include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/nvtx_utils.h"
 // #include 
 
@@ -46,6 +47,7 @@ void LlamaFfnLayer<T>::freeBuffer()
 template<typename T>
 void LlamaFfnLayer<T>::activation(int num_token)
 {
+    NvtxScope scope("activation");
     invokeGenericActivation<SiluActivation>(gating_buf_,
                                             (const T*)nullptr,  // bias
                                             inter_buf_,
@@ -76,6 +78,8 @@ void LlamaFfnLayer<T>::forward(TensorMap*               output_tensors,
      *   \param ffn_output [token_num, hidden_dimension]
      */
 
+    NvtxScope scope("ffn");
+
     const size_t num_token = input_tensors->at("ffn_input").shape[0];
     // LOG(WARNING);
 
@@ -84,24 +88,28 @@ void LlamaFfnLayer::forward(TensorMap*               output_tensors,
     const T* ffn_input_data  = input_tensors->at("ffn_input").getPtr();
     T*       ffn_output_data = output_tensors->at("ffn_output").getPtr();
 
-    PUSH_RANGE("ffn");
-
     if (weights->fused_gating_intermediate.kernel) {
+        NvtxScope scope("fused_silu_ffn");
         linear_.forward(
             gating_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, LlamaLinear::kFusedSiluFfn);
     }
     else {
-        // w1(x)
-        linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating);
-        // w3(x)
-        linear_.forward(inter_buf_, ffn_input_data, num_token, weights->intermediate);
+        {  // w1(x)
+            NvtxScope scope("w1");
+            linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating);
+        }
+        {  // w3(x)
+            NvtxScope scope("w3");
+            linear_.forward(inter_buf_, ffn_input_data, num_token, weights->intermediate);
+        }
         // silu(w1(x)) * w3(x)
         activation(num_token);
     }
 
-    // w2(x)
-    linear_.forward(ffn_output_data, gating_buf_, num_token, weights->output);
-    POP_RANGE;
+    {  // w2(x)
+        NvtxScope scope("w2");
+        linear_.forward(ffn_output_data, gating_buf_, num_token, weights->output);
+    }
 
     if (tensor_para_.world_size_ > 1) {
         NcclGuard nccl_guard(tensor_para_, stream_);
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index 8768e7fd05..72fee5cfbf 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -28,14 +28,16 @@
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
 #include "src/turbomind/models/llama/LlamaWeight.h"
 #include "src/turbomind/models/llama/Request.h"
+#include "src/turbomind/models/llama/SequenceManager.h"
 #include "src/turbomind/models/llama/llama_params.h"
 #include "src/turbomind/models/llama/llama_utils.h"
+#include "src/turbomind/models/llama/unified_decoder.h"
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/cuda_utils.h"
+#include "src/turbomind/utils/logger.h"
 #include 
 #include 
 #include 
-#include 
 
 namespace turbomind {
 
@@ -46,18 +48,14 @@ LlamaV2::LlamaV2(size_t                       head_num,
                     size_t                       inter_size,
                     size_t                       num_layer,
                     size_t                       vocab_size,
-                    const LlamaAttentionParams&  attn_params,
                     float                        norm_eps,
-                    int                          max_batch_size,
-                    int                          max_context_token_num,
-                    int                          session_len,
-                    int                          step_length,
+                    const LlamaAttentionParams&  attn_params,
                     int                          start_id,
                     int                          end_id,
-                    int                          cache_max_entry_count,
-                    int                          cache_chunk_size,
+                    int                          cache_block_seq_len,
                     int                          quant_policy,
                     bool                         use_context_fmha,
+                    const EngineParams&          engine_params,
                     std::shared_ptr shared_state,
                     LlamaWeight*              weights,
                     NcclParam                    tensor_para,
@@ -71,12 +69,14 @@ LlamaV2::LlamaV2(size_t                       head_num,
     inter_size_(inter_size),
     num_layer_(num_layer),
     vocab_size_(vocab_size),
+    attn_params_(attn_params),
     vocab_size_padded_(vocab_size),
     rmsnorm_eps_(norm_eps),
     start_id_(start_id),
     end_id_(end_id),
     hidden_units_(head_num * size_per_head),
     local_head_num_(head_num / tensor_para.world_size_),
+    local_kv_head_num_(kv_head_num / tensor_para.world_size_),
     weights_(weights),
     tensor_para_(tensor_para),
     stream_(stream),
@@ -85,8 +85,6 @@ LlamaV2::LlamaV2(size_t                       head_num,
     is_free_buffer_after_forward_(is_free_buffer_after_forward),
     cuda_device_prop_(cuda_device_prop),
     debug_(isDebug()),
-    step_length_(step_length),
-    batch_(max_batch_size, max_context_token_num, session_len, this),
     shared_state_(shared_state)
 
 {
@@ -96,80 +94,45 @@ LlamaV2::LlamaV2(size_t                       head_num,
     vocab_size_padded_ =
         (vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * tensor_para_.world_size_;
 
-    size_t elem_bits = 0;
-    if (quant_policy & QuantPolicy::kCacheKVInt8) {
-        elem_bits = sizeof(int8_t) * 8;
-        if (use_context_fmha) {
-            TM_LOG_ERROR("use_context_fmha not support int8");
-            assert(0);
-        }
-    }
-    else {
-        elem_bits = sizeof(T) * 8;
-    }
+    batch_ = std::make_unique>(engine_params, cache_block_seq_len, quant_policy, this);
 
-    const size_t local_kv_head_num = kv_head_num / tensor_para.world_size_;
-
-    kv_cache_mgr_ = std::make_unique(num_layer_,
-                                                        local_kv_head_num,
-                                                        size_per_head_,
-                                                        session_len,
-                                                        elem_bits,
-                                                        cache_max_entry_count,
-                                                        cache_chunk_size,
-                                                        tensor_para.rank_,
-                                                        allocator);
-    initialize(attn_params, kv_head_num, use_context_fmha, quant_policy);
-    start();
+    initialize(attn_params, kv_head_num, use_context_fmha, cache_block_seq_len, quant_policy);
+
+    /// TODO: decouple Llama model and batch inference
+    batch_->Start();
 }
 
 template
 LlamaV2::~LlamaV2()
 {
-    shared_state_->request_queue.close();
-    internal_thread_.join();
-
-    delete decoder_;
+    unified_decoder_.reset();
     delete dynamic_decode_layer_;
-    delete context_decoder_;
 }
 
 template
 void LlamaV2::initialize(const LlamaAttentionParams& attn_params,
                             size_t                      kv_head_num,
                             bool                        use_context_fmha,
+                            int                         cache_block_seq_len,
                             int                         quant_policy)
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
 
-    context_decoder_ = new LlamaContextDecoder(head_num_,
-                                                  kv_head_num,
-                                                  size_per_head_,
-                                                  inter_size_,
-                                                  num_layer_,
-                                                  attn_params,
-                                                  rmsnorm_eps_,
-                                                  tensor_para_,
-                                                  stream_,
-                                                  cublas_wrapper_,
-                                                  allocator_,
-                                                  is_free_buffer_after_forward_,
-                                                  use_context_fmha,
-                                                  quant_policy);
-
-    decoder_ = new LlamaDecoder(head_num_,
-                                   kv_head_num,
-                                   size_per_head_,
-                                   inter_size_,
-                                   num_layer_,
-                                   attn_params,
-                                   rmsnorm_eps_,
-                                   tensor_para_,
-                                   stream_,
-                                   cublas_wrapper_,
-                                   allocator_,
-                                   is_free_buffer_after_forward_,
-                                   quant_policy);
+    unified_decoder_.reset(new UnifiedDecoder(head_num_,
+                                                 kv_head_num,
+                                                 size_per_head_,
+                                                 inter_size_,
+                                                 num_layer_,
+                                                 attn_params,
+                                                 rmsnorm_eps_,
+                                                 tensor_para_,
+                                                 stream_,
+                                                 cublas_wrapper_,
+                                                 allocator_,
+                                                 is_free_buffer_after_forward_,
+                                                 use_context_fmha,
+                                                 cache_block_seq_len,
+                                                 quant_policy));
 
     dynamic_decode_layer_ = new DynamicDecodeLayer(vocab_size_,
                                                           vocab_size_padded_,
@@ -184,6 +147,7 @@ void LlamaV2::initialize(const LlamaAttentionParams& attn_params,
 template
 void LlamaV2::embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step)
 {
+    NvtxScope scope("embeddingLookup");
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     // ! This kernel can't be used in context decoding
     invokeEmbeddingLookupPosEncodingPadCount(embeddings,
@@ -202,28 +166,32 @@ void LlamaV2::embeddingLookup(T* embeddings, const int* token_ids_buf, int ba
 }
 
 template
-void LlamaV2::contextDecode(T*         deocder_output,
-                               uintptr_t* k_cache_ptr,
-                               uintptr_t* v_cache_ptr,
-                               T*         context_decoder_input_buf,
-                               T*         context_decoder_output_buf,
-                               const int* input_ids,
-                               const int* input_length,
-                               const int* history_length,
-                               const int* context_length,
-                               size_t     token_num,
-                               size_t     max_input_len,
-                               size_t     max_context_len,
-                               size_t     session_len,
-                               size_t     batch_size)
+void LlamaV2::forwardUnified(T*           out,
+                                T*           decoder_output,
+                                T*           decoder_input,
+                                void**       k_block_ptrs,
+                                void**       v_block_ptrs,
+                                const int*   input_ids,
+                                const int*   cu_block_cnts,
+                                const float* rope_theta,
+                                const bool*  dc_finished,
+                                const int*   pf_input_length,
+                                const int*   pf_context_length,
+                                T**          pf_tmp_k_ptrs,
+                                T**          pf_tmp_v_ptrs,
+                                size_t       token_num,
+                                int          dc_batch_size,
+                                int          dc_step,
+                                int          dc_sum_seq_len,
+                                int          dc_max_seq_len,
+                                int          pf_batch_size,
+                                int          pf_max_input_len,
+                                int          pf_max_context_len,
+                                int          pf_session_len)
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
 
-    if (tensor_para_.rank_ == 0) {
-        TM_LOG_INFO("context decoding start");
-    }
-
-    invokeInputIdsEmbeddingLookupPosEncoding(context_decoder_input_buf,
+    invokeInputIdsEmbeddingLookupPosEncoding(decoder_input,
                                              nullptr,  // processed somewhere else
                                              weights_->pre_decoder_embedding_table,
                                              static_cast(nullptr),
@@ -237,81 +205,38 @@ void LlamaV2::contextDecode(T*         deocder_output,
                                              stream_);
     sync_check_cuda_error();
 
-    const auto dtype = getTensorType();
-    const auto bsz   = batch_size;
-
-    const int max_q_len   = max_input_len;
-    const int max_kv_len  = max_context_len;
-    const int max_seq_len = session_len;
-
-    std::unordered_map decoder_input_tensors{
-        {"decoder_input", {MEMORY_GPU, dtype, {token_num, hidden_units_}, context_decoder_input_buf}},
-        {"output_norm_weight", {MEMORY_GPU, dtype, {hidden_units_}, weights_->output_norm_weight}},
-        {"input_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, input_length}},
-        {"history_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, history_length}},
-        {"context_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, context_length}},
-        {"max_q_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_q_len}},
-        {"max_kv_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_kv_len}},
-        {"max_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_seq_len}},
-    };
-
-    std::unordered_map decoder_output_tensors{
-        {"decoder_output", {MEMORY_GPU, dtype, {token_num, hidden_units_}, context_decoder_output_buf}},
-        {"key_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, k_cache_ptr}},
-        {"value_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, v_cache_ptr}},
-        {"last_token_hidden_units", {MEMORY_GPU, dtype, {bsz, hidden_units_}, deocder_output}}};
-
-    context_decoder_->forward(&decoder_output_tensors, &decoder_input_tensors, &weights_->decoder_layer_weights);
-
-    if (tensor_para_.rank_ == 0) {
-        TM_LOG_INFO("context decoding end");
-    }
-}
-
-template
-void LlamaV2::decoderForward(T*         decoder_output,
-                                uintptr_t* k_cache_ptr,
-                                uintptr_t* v_cache_ptr,
-                                T*         decoder_input,
-                                const int* sequence_length,
-                                const int* total_padding_count,
-                                bool*      finished,
-                                int        step,
-                                int        ite,
-                                size_t     session_len,
-                                size_t     batch_size)
-{
-    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-
-    const int  max_seq_len = session_len;
-    const auto dtype       = getTensorType();
-
-    // max_input_length is not used w/o linear_bias_slopes
-    // sequence_lengths_ will be incremented in dynamic decode
-    std::unordered_map decoder_input_tensors{
-        {"decoder_input", {MEMORY_GPU, dtype, {batch_size, hidden_units_}, decoder_input}},
-        {"sequence_lengths", {MEMORY_GPU, TYPE_INT32, {batch_size}, sequence_length}},
-        {"total_padding_tokens", {MEMORY_GPU, TYPE_INT32, {batch_size}, total_padding_count}},
-        {"max_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_seq_len}},
-        {"finished", {MEMORY_GPU, TYPE_BOOL, {batch_size}, finished}},
-        {"output_norm_weight", {MEMORY_GPU, dtype, {hidden_units_}, weights_->output_norm_weight}},
-        {"step", {MEMORY_CPU, TYPE_INT32, {1}, &step}},
-        {"ite", {MEMORY_CPU, TYPE_INT32, {1}, &ite}},
-    };
-
-    // LOG(ERROR) << key_cache_ << " " << value_cache_;
-    std::unordered_map decoder_output_tensors{
-        {"decoder_output", {MEMORY_GPU, dtype, {batch_size, hidden_units_}, decoder_output}},
-        {"key_cache", {MEMORY_GPU, TYPE_UINT64, {batch_size}, k_cache_ptr}},
-        {"value_cache", {MEMORY_GPU, TYPE_UINT64, {batch_size}, v_cache_ptr}},
-    };
-
-    decoder_->forward(&decoder_output_tensors, &decoder_input_tensors, &weights_->decoder_layer_weights);
+    const auto   dtype = getTensorType();
+    const size_t bsz   = dc_batch_size + pf_batch_size;
+
+    TensorMap inputs{{"decoder_input", {MEMORY_GPU, dtype, {token_num, hidden_units_}, decoder_input}},
+                     {"output_norm_weight", {MEMORY_GPU, dtype, {hidden_units_}, weights_->output_norm_weight}},
+                     {"input_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, pf_input_length}},
+                     {"context_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, pf_context_length}},
+                     {"dc_batch_size", {MEMORY_CPU, TYPE_INT32, {1}, &dc_batch_size}},
+                     {"dc_sum_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &dc_sum_seq_len}},
+                     {"dc_max_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &dc_max_seq_len}},
+                     {"finished", {MEMORY_GPU, TYPE_BOOL, {bsz}, dc_finished}},
+                     {"pf_batch_size", {MEMORY_CPU, TYPE_INT32, {1}, &pf_batch_size}},
+                     {"pf_max_q_len", {MEMORY_CPU, TYPE_INT32, {1}, &pf_max_input_len}},
+                     {"pf_max_k_len", {MEMORY_CPU, TYPE_INT32, {1}, &pf_max_context_len}},
+                     {"session_len", {MEMORY_CPU, TYPE_INT32, {1}, &pf_session_len}},
+                     {"rope_theta", {MEMORY_GPU, TYPE_FP32, {hidden_units_}, rope_theta}},
+                     {"cu_block_counts", {MEMORY_GPU, TYPE_INT32, {bsz}, cu_block_cnts}}};
+
+    TensorMap outputs{{"decoder_output", {MEMORY_GPU, dtype, {token_num, hidden_units_}, decoder_output}},
+                      {"key_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, k_block_ptrs}},
+                      {"value_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, v_block_ptrs}},
+                      {"tmp_k", {MEMORY_GPU, TYPE_UINT64, {bsz}, pf_tmp_k_ptrs}},
+                      {"tmp_v", {MEMORY_GPU, TYPE_UINT64, {bsz}, pf_tmp_v_ptrs}},
+                      {"last_token_hidden_units", {MEMORY_GPU, dtype, {bsz, hidden_units_}, out}}};
+
+    unified_decoder_->forward(&outputs, &inputs, &weights_->decoder_layer_weights);
 }
 
 template
 void LlamaV2::postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size)
 {
+    NvtxScope scope("postDecodeEmbedding");
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     cudaDataType_t data_type = getCudaDataType();
     float          alpha     = 1.f;
@@ -377,6 +302,7 @@ void LlamaV2::dynamicDecode(int*            token_ids,
                                bool*           finished,
                                int*            sequence_length,
                                bool*           should_stop,
+                               curandState_t*  curand_state,
                                TensorMap*      inputs,
                                TensorMap*      outputs,
                                const float*    logits,
@@ -389,6 +315,7 @@ void LlamaV2::dynamicDecode(int*            token_ids,
                                size_t          token_ids_len,
                                size_t          batch_size)
 {
+    NvtxScope scope("dynamicDecode");
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     int local_batch_size = (int)batch_size;
 
@@ -420,7 +347,8 @@ void LlamaV2::dynamicDecode(int*            token_ids,
         {"output_ids", {MEMORY_GPU, TYPE_INT32, {token_ids_len, batch_size, 1U}, token_ids}},
         {"finished", {MEMORY_GPU, TYPE_BOOL, {batch_size}, finished}},
         {"sequence_length", {MEMORY_GPU, TYPE_INT32, {batch_size}, sequence_length}},
-        {"should_stop", {MEMORY_CPU, TYPE_BOOL, {1}, should_stop}}};
+        {"should_stop", {MEMORY_CPU, TYPE_BOOL, {1}, should_stop}},
+        {"curand_state", {MEMORY_GPU, TYPE_VOID, {batch_size}, curand_state}}};
 
     const std::vector optional_outputs{"cum_log_probs", "output_log_probs"};
     for (const auto& key : optional_outputs) {
@@ -432,83 +360,6 @@ void LlamaV2::dynamicDecode(int*            token_ids,
     dynamic_decode_layer_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors);
 }
 
-template
-void LlamaV2::internalThreadEntry(int device_id)
-{
-    TM_LOG_INFO("[internalThreadEntry] %d", (int)tensor_para_.rank_);
-    check_cuda_error(cudaSetDevice(device_id));
-
-    auto& request_queue  = shared_state_->request_queue;
-    auto& infer_requests = shared_state_->infer_requests;
-    auto& stop_requests  = shared_state_->stop_requests;
-
-    while (1) {
-        if (tensor_para_.rank_ == 0) {
-            const int  free_slot_count = batch_.maxSize() - batch_.size() + batch_.finishedCount();
-            const bool is_empty        = free_slot_count == batch_.maxSize();
-
-            request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty);
-
-            // request queue was closed
-            // and there are no unprocessed requests in the queue
-            if (is_empty && infer_requests.empty() && stop_requests.empty()) {
-                // rank 0 sets flag
-                shared_state_->should_stop = true;
-            }
-
-            batch_.verifyRequests(stop_requests, infer_requests);
-        }
-
-        // wait while rank-0 is dequeueing
-        shared_state_->barrier->wait();
-
-        // exit if job is done
-        if (shared_state_->should_stop) {
-            return;
-        }
-
-        bool modified = false;
-
-        if (!(batch_.finishedCount() == 0 && stop_requests.empty() && infer_requests.empty())) {
-            batch_.handleStopRequests(stop_requests);
-            batch_.synchronize();
-            modified = true;
-        }
-
-        const int infer_request_count = infer_requests.size();
-
-        if (!infer_requests.empty()) {
-            batch_.initialize(infer_requests);  // reinitialize when new requests come, possible buffer allocation
-            batch_.contextDecode();
-            modified = true;
-        }
-
-        // wait while shared stop/infer_requests is being used
-        shared_state_->barrier->wait();
-
-        if (batch_.size()) {
-            if (modified) {
-                batch_.initializeGeneration();
-                batch_.initializeSampling(infer_request_count);
-            }
-            for (int i = 0; i < step_length_; ++i) {
-                if (!batch_.generate()) {
-                    break;
-                }
-            }
-            batch_.finish();
-        }
-    }
-}
-
-template
-void LlamaV2::start()
-{
-    int device_id = -1;
-    check_cuda_error(cudaGetDevice(&device_id));
-    internal_thread_ = std::thread(&LlamaV2::internalThreadEntry, this, device_id);
-}
-
 static inline Tensor slice(const Tensor& tensor, int index)
 {
     auto shape = tensor.shape;
@@ -591,15 +442,25 @@ void LlamaV2::forward(std::unordered_map*       outputs,
     bool             has_error = 0;
     if (rank == 0) {
         TM_LOG_INFO("[forward] Enqueue requests");
+
+        std::vector ids;
+        for (const auto& r : requests) {
+            ids.push_back(r->id);
+        }
+
         auto futures = shared_state_->request_queue.enqueue(std::move(requests));
 
+        FT_CHECK_WITH_INFO(ids.size() == futures.size(), "check failed");
+
         TM_LOG_INFO("[forward] Wait for requests to complete ...");
-        for (auto& f : futures) {
-            auto ec = f.get();
+
+        for (int i = 0; i < futures.size(); ++i) {
+            auto ec = futures[i].get();
             error_codes.push_back(ec);
             if (ec) {
                 has_error = true;
             }
+            TM_LOG_INFO("[forward] Request complete for %ld, code %d", (long)ids[i], (int)ec);
         }
     }
 
diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h
index 40633b0a22..19cea4b58e 100644
--- a/src/turbomind/models/llama/LlamaV2.h
+++ b/src/turbomind/models/llama/LlamaV2.h
@@ -24,10 +24,11 @@
 #include "src/turbomind/layers/DynamicDecodeLayer.h"
 #include "src/turbomind/models/llama/Barrier.h"
 #include "src/turbomind/models/llama/LlamaBatch.h"
-#include "src/turbomind/models/llama/LlamaContextDecoder.h"
-#include "src/turbomind/models/llama/LlamaDecoder.h"
 #include "src/turbomind/models/llama/LlamaWeight.h"
 #include "src/turbomind/models/llama/Request.h"
+#include "src/turbomind/models/llama/SequenceManager.h"
+#include "src/turbomind/models/llama/llama_params.h"
+#include "src/turbomind/models/llama/unified_decoder.h"
 #include "src/turbomind/utils/allocator.h"
 #include "src/turbomind/utils/cublasMMWrapper.h"
 #include "src/turbomind/utils/instance_comm.h"
@@ -46,9 +47,7 @@ class LlamaV2 {
         std::vector> stop_requests;
         RequestQueue                          request_queue;
         std::shared_ptr              barrier;
-
-        // rank 0 sets flag to true if there are no more tasks in the request_queue
-        bool should_stop = false;
+        bool                                  abort;
     };
 
     ~LlamaV2();
@@ -59,18 +58,14 @@ class LlamaV2 {
             size_t                       inter_size,
             size_t                       num_layer,
             size_t                       vocab_size,
-            const LlamaAttentionParams&  attn_params,
             float                        norm_eps,
-            int                          max_batch_size,
-            int                          max_context_token_num,
-            int                          session_len,
-            int                          step_length,
+            const LlamaAttentionParams&  attn_params,
             int                          start_id,
             int                          end_id,
-            int                          cache_max_entry_count,
-            int                          cache_chunk_size,
+            int                          cache_block_seq_len,
             int                          quant_policy,
             bool                         use_context_fmha,
+            const EngineParams&          engine_params,
             std::shared_ptr shared_state,
             LlamaWeight*              weights,
             NcclParam                    tensor_para,
@@ -104,39 +99,36 @@ class LlamaV2 {
 private:
     friend class Batch;
 
-    void internalThreadEntry(int device_id);
-
-    void
-    initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, bool use_context_fmha, int quant_policy);
+    void initialize(const LlamaAttentionParams& attn_params,
+                    size_t                      kv_head_num,
+                    bool                        use_context_fmha,
+                    int                         cache_block_seq_len,
+                    int                         quant_policy);
 
     void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step);
 
-    void contextDecode(T*         deocder_output,
-                       uintptr_t* k_cache_ptr,
-                       uintptr_t* v_cache_ptr,
-                       T*         context_decoder_input_buf,
-                       T*         context_decoder_output_buf,
-                       const int* input_ids,
-                       const int* input_length,
-                       const int* history_length,
-                       const int* context_length,
-                       size_t     token_num,
-                       size_t     max_input_len,
-                       size_t     max_context_len,
-                       size_t     session_len,
-                       size_t     batch_size);
-
-    void decoderForward(T*         decoder_output,
-                        uintptr_t* k_cache_ptr,
-                        uintptr_t* v_cache_ptr,
-                        T*         decoder_input,
-                        const int* sequence_length,
-                        const int* total_padding_count,
-                        bool*      finished,
-                        int        step,
-                        int        ite,
-                        size_t     session_len,
-                        size_t     batch_size);
+    void forwardUnified(T*           out,
+                        T*           decoder_output,
+                        T*           decoder_input,
+                        void**       k_block_ptrs,
+                        void**       v_block_ptrs,
+                        const int*   input_ids,
+                        const int*   cu_block_cnts,
+                        const float* rope_theta,
+                        const bool*  dc_finished,
+                        const int*   pf_input_length,
+                        const int*   pf_context_length,
+                        T**          pf_tmp_k_ptrs,
+                        T**          pf_tmp_v_ptrs,
+                        size_t       token_num,
+                        int          dc_batch_size,
+                        int          dc_step,
+                        int          dc_sum_seq_len,
+                        int          dc_max_seq_len,
+                        int          pf_batch_size,
+                        int          pf_max_input_len,
+                        int          pf_max_context_len,
+                        int          pf_session_len);
 
     void postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size);
 
@@ -144,6 +136,7 @@ class LlamaV2 {
                        bool*           finished,
                        int*            sequence_length,
                        bool*           should_stop,
+                       curandState_t*  curand_state,
                        TensorMap*      inputs,
                        TensorMap*      outputs,
                        const float*    logits,
@@ -156,8 +149,6 @@ class LlamaV2 {
                        size_t          token_ids_len,
                        size_t          batch_size);
 
-    void start();
-
 private:
     friend class LlamaBatch;
 
@@ -169,6 +160,8 @@ class LlamaV2 {
     size_t       vocab_size_padded_;
     float        rmsnorm_eps_ = 1e-6f;
 
+    const LlamaAttentionParams attn_params_;
+
     static constexpr bool neox_rotary_style_ = false;
 
     const int    start_id_;
@@ -176,6 +169,7 @@ class LlamaV2 {
     const size_t hidden_units_;
 
     const size_t local_head_num_;
+    const size_t local_kv_head_num_;
     NcclParam    tensor_para_;
 
     cudaStream_t     stream_;
@@ -186,20 +180,14 @@ class LlamaV2 {
 
     const bool debug_{false};
 
-    std::unique_ptr kv_cache_mgr_;
-
-    LlamaWeight*            weights_{};
-    LlamaDecoder*           decoder_{};
-    LlamaContextDecoder*    context_decoder_{};
-    DynamicDecodeLayer* dynamic_decode_layer_{};
-
-    const int                    step_length_;
-    LlamaBatch                batch_;
-    std::shared_ptr shared_state_;
+    LlamaWeight* weights_{};
 
-    std::thread internal_thread_;
+    std::unique_ptr> unified_decoder_;
+    DynamicDecodeLayer*         dynamic_decode_layer_{};
 
-    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
+    std::shared_ptr   shared_state_;
+    ffi_api_lock_ctrl_t            ffi_lock_;
+    std::unique_ptr> batch_;
 };
 
 }  // namespace turbomind
diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc
index e1287f471b..e270d3ba5c 100644
--- a/src/turbomind/models/llama/LlamaWeight.cc
+++ b/src/turbomind/models/llama/LlamaWeight.cc
@@ -109,6 +109,35 @@ void LlamaWeight::loadModel(std::string dir_path)
     }
 }
 
+template
+TensorMap LlamaWeight::getParams()
+{
+    TensorMap output;
+
+    output.insert(
+        "tok_embeddings.weight",
+        Tensor{MEMORY_GPU, getTensorType(), {vocab_size_ * hidden_units_ * sizeof(T)}, pre_decoder_embedding_table});
+
+    output.insert("norm.weight",
+                  Tensor{MEMORY_GPU, getTensorType(), {hidden_units_ * sizeof(T)}, output_norm_weight});
+
+    output.insert(
+        "output.weight",
+        Tensor{
+            MEMORY_GPU, getTensorType(), {hidden_units_ * vocab_size_ * sizeof(T)}, post_decoder_embedding_kernel});
+
+    // transformer layers
+    for (size_t i = 0; i < num_layer_; i++) {
+        std::string prefix = fmtstr("layers.%d", i);
+        TensorMap   layeri = decoder_layer_weights[i]->getParams(prefix);
+        for (auto [name, tensor] : layeri) {
+            output.insert(name, tensor);
+        }
+    }
+
+    return output;
+}
+
 template struct LlamaWeight;
 template struct LlamaWeight;
 
diff --git a/src/turbomind/models/llama/LlamaWeight.h b/src/turbomind/models/llama/LlamaWeight.h
index be7fda2b98..a896a87a09 100644
--- a/src/turbomind/models/llama/LlamaWeight.h
+++ b/src/turbomind/models/llama/LlamaWeight.h
@@ -47,6 +47,8 @@ struct LlamaWeight {
 
     void loadModel(std::string dir_path);
 
+    TensorMap getParams();
+
     std::vector*> decoder_layer_weights;
     const T*                                 pre_decoder_embedding_table{};
     const T*                                 output_norm_weight{};
diff --git a/src/turbomind/models/llama/Request.h b/src/turbomind/models/llama/Request.h
index 0bccf84a57..fb1a5daae9 100644
--- a/src/turbomind/models/llama/Request.h
+++ b/src/turbomind/models/llama/Request.h
@@ -13,10 +13,12 @@
 namespace turbomind {
 
 struct Request {
-    uint64_t id;
-    bool     start_flag;
-    bool     end_flag;
-    bool     stop_flag;
+    uint64_t id;         // sequence id
+    uint64_t unique_id;  // monotonic increasing
+
+    bool start_flag;
+    bool end_flag;
+    bool stop_flag;
 
     // per rank inputs/outputs
     std::vector inputs;
@@ -31,7 +33,8 @@ struct Request {
         kConflict = 2,
         kBusy     = 3,
         kInactive = 4,
-        kFail     = 5
+        kFail     = 5,
+        kTooLong  = 6
     };
     std::promise signal;
 };
@@ -66,11 +69,16 @@ class RequestQueue {
     void dequeue(std::vector>& stop_requests,
                  std::vector>& infer_requests,
                  unsigned                               max_infer_count,
-                 bool                                   blocking)
+                 bool                                   blocking,
+                 bool&                                  abort)
     {
         std::unique_lock lock(mutex_);
         if (blocking) {
-            cv_.wait(lock, [this] { return !(stop_queue_.empty() && infer_queue_.empty() && closed_ == false); });
+            cv_.wait(lock, [this] { return !(stop_queue_.empty() && infer_queue_.empty()) || closed_; });
+            if (closed_) {
+                abort = true;
+                return;
+            }
         }
 
         stop_requests.clear();
@@ -88,8 +96,10 @@ class RequestQueue {
 
     void close()
     {
-        std::lock_guard lock(mutex_);
-        closed_ = true;
+        {
+            std::lock_guard lock(mutex_);
+            closed_ = true;
+        }
         cv_.notify_all();
     }
 
@@ -98,7 +108,7 @@ class RequestQueue {
     std::queue> infer_queue_;
     std::mutex                           mutex_;
     std::condition_variable              cv_;
-    bool                                 closed_ = false;
+    bool                                 closed_{false};
 };
 
 }  // namespace turbomind
diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc
new file mode 100644
index 0000000000..ca97acd37d
--- /dev/null
+++ b/src/turbomind/models/llama/SequenceManager.cc
@@ -0,0 +1,466 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include "src/turbomind/models/llama/SequenceManager.h"
+#include "src/turbomind/models/llama/BlockManager.h"
+#include "src/turbomind/utils/allocator.h"
+#include "src/turbomind/utils/debug_utils.h"
+#include "src/turbomind/utils/logger.h"
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace turbomind {
+
+SequenceManager::SequenceManager(size_t      layer_num,
+                                 size_t      head_num,
+                                 size_t      head_dim,
+                                 size_t      block_seq_len,
+                                 double      block_count,
+                                 int         chunk_size,
+                                 size_t      elem_bits,
+                                 int         rank,
+                                 IAllocator* allocator):
+    block_seq_len_(block_seq_len)
+{
+    constexpr int kBitsPerByte = 8;
+
+    // [2, L, H, block_seq_len, D]
+    size_t block_size = 2UL * layer_num * head_num * block_seq_len * head_dim * elem_bits / kBitsPerByte;
+
+    block_manager_ = std::make_unique(block_size, block_count, chunk_size, allocator);
+
+    val_offset_ = block_size / 2;
+}
+
+const Sequence* SequenceManager::Create(uint64_t id)
+{
+    Sequence sequence{id};
+    auto     it = sequences_.find(id);
+    if (it != sequences_.end()) {
+        if (rank_ == 0) {
+            TM_LOG_WARNING("[SequenceManager][Create] Removing conflicting ID %ld", (long)id);
+        }
+        Erase(it);
+    }
+    it = sequences_.emplace_hint(it, id, std::move(sequence));
+    return &it->second;
+}
+
+const Sequence* SequenceManager::Get(uint64_t id)
+{
+    if (auto it = sequences_.find(id); it != sequences_.end()) {
+        return &it->second;
+    }
+    return nullptr;
+}
+
+bool SequenceManager::Contains(uint64_t id)
+{
+    return sequences_.find(id) != sequences_.end();
+}
+
+void SequenceManager::Erase(std::map::iterator it)
+{
+    auto& seq = it->second;
+    if (seq.status == Sequence::kCached) {
+        const int count = block_manager_->Verify(seq.blocks, seq.block_unique_ids);
+        seq.blocks.resize(count);
+    }
+    else {
+        UpdateAndSetUnlock(seq);
+    }
+    freed_.insert(freed_.end(), seq.blocks.begin(), seq.blocks.end());
+    sequences_.erase(it);
+}
+
+bool SequenceManager::Erase(uint64_t id)
+{
+    if (auto it = sequences_.find(id); it != sequences_.end()) {
+        Erase(it);
+        return true;
+    }
+    return false;
+}
+
+void SequenceManager::VerifyAndLockCached(const Sequences& sequences)
+{
+    BlockIds blocks;
+    for (const auto& p : sequences) {
+        auto& seq = const_cast(*p);
+        if (seq.status != Sequence::kCached) {
+            continue;
+        }
+        FT_CHECK(seq.blocks.size() == seq.block_unique_ids.size());
+        if (need_verify_) {
+            const int count = block_manager_->Verify(seq.blocks, seq.block_unique_ids);
+            seq.blocks.resize(count);
+            seq.block_unique_ids.resize(count);
+        }
+        blocks.insert(blocks.end(), seq.blocks.begin(), seq.blocks.end());
+        seq.cache_len = std::min(seq.cache_len, seq.blocks.size() * block_seq_len_);
+        seq.status    = Sequence::kLocked;
+    }
+    block_manager_->Lock(blocks);
+    need_verify_ = false;
+}
+
+void SequenceManager::CommitUnlockAndFree()
+{
+    if (!unlocked_.empty()) {
+        block_manager_->Unlock(unlocked_);
+        unlocked_.clear();
+    }
+
+    if (!freed_.empty()) {
+        block_manager_->Free(freed_);
+        freed_.clear();
+    }
+}
+
+void SequenceManager::UpdateAndSetUnlock(const Sequence& sequence)
+{
+    FT_CHECK(sequence.status != Sequence::kCached);
+    auto& seq = const_cast(sequence);
+    block_manager_->Touch(seq.blocks);
+    unlocked_.insert(unlocked_.end(), seq.blocks.begin(), seq.blocks.end());
+    seq.status = Sequence::kCached;
+}
+
+namespace {
+
+struct Schedule {
+    int free;
+    int cached;
+
+    int allocate{};
+    int evict{};
+    int preempt{};
+
+    int last;
+
+    int input_count1;
+    int input_count2;
+
+    Sequences        active;
+    std::vector block_counts;
+    Sequences        inactive;
+    Sequences        victims;
+
+    Schedule(Snapshot snapshot, int size, int _input_count1, int _input_count2):
+        free(snapshot.free),
+        cached(snapshot.cached),
+        last(size),
+        use_count_(std::move(snapshot.use_count)),
+        unlocked_(size),
+        it_(size),
+        input_count1(_input_count1),
+        input_count2(_input_count2)
+    {
+    }
+
+    int Unlock(const Sequences& seqs, int vidx)
+    {
+        while (vidx < it_) {
+            const auto& blocks = seqs[--it_]->blocks;
+            int         count  = 0;
+            for (const auto& bid : blocks) {
+                count += static_cast(--use_count_[bid] == 0);
+            }
+            unlocked_[it_] = count;
+        }
+        return unlocked_[vidx];
+    }
+
+private:
+    std::vector use_count_;
+    std::vector unlocked_;
+    int              it_;
+};
+
+template
+std::ostream& operator<<(std::ostream& os, const std::vector& v)
+{
+    os << "[";
+    for (int i = 0; i < v.size(); ++i) {
+        os << (i ? "," : "") << v[i];
+    }
+    os << "]";
+    return os;
+}
+
+std::ostream& operator<<(std::ostream& os, const Schedule& s)
+{
+    os << "free=" << s.free << ", cached=" << s.cached << ", allocate=" << s.allocate << ", evict=" << s.evict
+       << ", preempt=" << s.preempt << ", active=" << s.active << ", victims=" << s.victims
+       << ", block_counts=" << s.block_counts << ", inactive=" << s.inactive;
+    return os;
+}
+
+struct Transaction {
+    int index_;
+    int block_count_;
+    int input_count_;
+
+    int allocate_{};
+    int evict_{};
+    int preempt_{};
+
+    Sequences victims_;
+
+    const Sequences& sequences_;
+    Schedule&        schedule_;
+
+    explicit Transaction(const Sequences& sequences, int index, int block_count, int input_count, Schedule& sched):
+        sequences_(sequences), schedule_(sched), index_(index), block_count_(block_count), input_count_(input_count)
+    {
+    }
+
+    void Process()
+    {
+        if (schedule_.input_count1 > 0) {
+            int count = block_count_;
+
+            int tmp = std::min(schedule_.free, count);
+            count -= tmp;
+            allocate_ += tmp;
+
+            tmp = std::min(schedule_.cached, count);
+            count -= tmp;
+            evict_ += tmp;
+
+            for (int vidx = schedule_.last - 1; count && vidx > index_; --vidx) {
+                if (sequences_[vidx]->status == Sequence::kCached) {
+                    continue;
+                }
+                victims_.push_back(sequences_[vidx]);
+                preempt_ += schedule_.Unlock(sequences_, vidx);
+
+                if (count <= preempt_) {
+                    evict_ += count;
+                    count -= count;
+                    schedule_.last = vidx;  // ! modifying `sched_.last` is part of commit
+                    break;
+                }
+            }
+            if (count == 0) {
+                return Commit();
+            }
+        }
+
+        const_cast(sequences_[index_])->input_length = 0;
+        schedule_.inactive.push_back(sequences_[index_]);
+    }
+
+    void Commit()
+    {
+        // update available resources
+        schedule_.free -= allocate_;
+        FT_CHECK(schedule_.free >= 0);
+        schedule_.cached += preempt_;
+        schedule_.cached -= evict_;
+        FT_CHECK(schedule_.cached >= 0);
+
+        // update scheduled operations
+        schedule_.allocate += allocate_;
+        schedule_.evict += evict_;
+        schedule_.preempt += preempt_;
+        schedule_.victims.insert(schedule_.victims.end(), victims_.begin(), victims_.end());
+
+        // update active sequences
+        schedule_.active.push_back(sequences_[index_]);
+        schedule_.block_counts.push_back(block_count_);
+
+        if (input_count_ > schedule_.input_count2) {
+            input_count_ = schedule_.input_count1;
+        }
+        schedule_.input_count1 -= input_count_;
+        schedule_.input_count2 -= input_count_;
+        const_cast(sequences_[index_])->input_length = input_count_;
+    }
+};
+
+std::ostream& operator<<(std::ostream& os, const Transaction& trans)
+{
+    os << "index=" << trans.index_ << ", block_count=" << trans.block_count_ << ", allocate=" << trans.allocate_
+       << ", evict=" << trans.evict_ << ", preempt=" << trans.preempt_ << ", victims=" << trans.victims_;
+    return os;
+}
+
+}  // namespace
+
+void SequenceManager::SortByPriority(Sequences&                   sequences,
+                                     std::vector&            context_lengths,
+                                     const std::vector& priorities)
+{
+    // sort according to priority
+    std::vector idxs(sequences.size());
+    std::iota(idxs.begin(), idxs.end(), 0);
+    std::sort(idxs.begin(), idxs.end(), [&](int i, int j) {
+        return priorities[i] < priorities[j];  //
+    });
+    Sequences        tmp_sequences(sequences.size());
+    std::vector tmp_lengths(context_lengths.size());
+    for (int i = 0; i < sequences.size(); ++i) {
+        tmp_sequences[i] = sequences[idxs[i]];
+        tmp_lengths[i]   = context_lengths[idxs[i]];
+    }
+    sequences.swap(tmp_sequences);
+    context_lengths.swap(tmp_lengths);
+}
+
+// template
+// void SortByPriority(const std::vector& priorities, Ts&... ranges)
+// {
+//     // sort according to priority
+//     std::vector idxs(priorities.size());
+//     std::iota(idxs.begin(), idxs.end(), 0);
+//     std::sort(idxs.begin(), idxs.end(), [&](int i, int j) {
+//         return priorities[i] < priorities[j];  //
+//     });
+//     auto reorder = [&](auto& src) {
+//         auto dst = src;
+//         for (size_t i = 0; i < idxs.size(); ++i) {
+//             dst[i] = src[idxs[i]];
+//         }
+//         src.swap(dst);
+//     };
+//     (reorder(ranges), ...);
+// }
+
+std::vector SequenceManager::CountRequiredBlocks(const Sequences&   sequences,
+                                                 const std::vector& context_lengths,
+                                                 int                step_length)
+{
+    std::vector required(sequences.size());
+    for (int i = 0; i < sequences.size(); ++i) {
+        int seq_len = context_lengths[i] + step_length;
+        int count   = (seq_len + block_seq_len_ - 1) / block_seq_len_ - static_cast(sequences[i]->blocks.size());
+        required[i] = std::max(0, count);
+    }
+    return required;
+}
+
+void SequenceManager::AssignAndActivate(const Sequences&   sequences,  //
+                                        const std::vector& counts,
+                                        const BlockIds&    blocks,
+                                        const UniqueIds&   unique_ids)
+{
+    FT_CHECK(sequences.size() == counts.size());
+    int first = 0;
+    for (int i = 0; i < sequences.size(); ++i) {
+        auto& s     = const_cast(*sequences[i]);
+        auto  count = counts[i];
+        int   last  = first + count;
+        FT_CHECK(last <= blocks.size());
+        s.blocks.insert(s.blocks.end(), blocks.begin() + first, blocks.begin() + last);
+        s.block_unique_ids.insert(s.block_unique_ids.end(), unique_ids.begin() + first, unique_ids.begin() + last);
+        s.status = Sequence::kActive;
+        first    = last;
+    }
+}
+
+auto SequenceManager::Materialize(Sequences          sequences,
+                                  std::vector        context_lengths,
+                                  const std::vector& priorities,
+                                  int                step_length,
+                                  AdjustInputCount   adjust) -> Outcome
+{
+    ////////////////////////////////////////////////////////////////////////////////
+    /// Schedule the assignment of blocks to sequences
+
+    // process deferred unlock and free operations
+    CommitUnlockAndFree();
+
+    SortByPriority(sequences, context_lengths, priorities);
+
+    // SortByPriority(priorities, sequences, context_lengths);
+
+    // Verify and lock cache sequences to avoid their blocks being evicted unnoticed
+    // the blocks can still be preempted later
+    VerifyAndLockCached(sequences);
+
+    auto [input_count1, input_count2] = adjust(sequences, context_lengths);
+
+    std::vector required = CountRequiredBlocks(sequences, context_lengths, step_length);
+    // dbg(required);
+
+    Schedule schedule(block_manager_->TakeSnapshot(), sequences.size(), input_count1, input_count2);
+
+    // `schedule.last` is decreasing in the loop
+    for (int i = 0; i < schedule.last; ++i) {
+        const int input_length = context_lengths[i] - sequences[i]->cache_len;
+        Transaction{sequences, i, required[i], input_length, schedule}.Process();
+    }
+
+    // mark remaining sequences invalid
+    for (int i = schedule.last; i < sequences.size(); ++i) {
+        schedule.inactive.push_back(sequences[i]);
+    }
+
+    ////////////////////////////////////////////////////////////////////////////////
+    /// Schedule is ready, time to execute it. (locked -> cached -> free -> locked)
+
+    // combine allocate and evict since evicted blocks are reused by allocation
+    schedule.allocate += schedule.evict;
+
+    if (schedule.allocate) {
+        dbg(*block_manager_);
+    }
+
+    Outcome outcome{};
+    outcome.allocation = schedule.allocate;
+    outcome.swap_in    = std::count_if(schedule.active.begin(), schedule.active.end(), [](auto p) {
+        if (p->status != Sequence::kActive) {
+            dbg(*p);
+        }
+        return p->status != Sequence::kActive;  //
+    });
+    outcome.swap_out   = std::count_if(schedule.inactive.begin(), schedule.inactive.end(), [](auto p) {
+        if (p->status == Sequence::kActive) {
+            dbg(*p);
+        }
+        return p->status == Sequence::kActive;  //
+    });
+
+    // release preempted blocks -> cached
+    if (!schedule.victims.empty()) {
+        for (const auto& p : schedule.victims) {
+            UpdateAndSetUnlock(*p);
+        }
+        CommitUnlockAndFree();
+    }
+
+    // evict cached blocks -> free
+    if (schedule.evict) {
+        block_manager_->Evict(schedule.evict);
+        need_verify_ = true;
+    }
+
+    // allocate & assign blocks
+    {
+        BlockIds  block_ids;
+        UniqueIds unique_ids;
+        if (schedule.allocate) {
+            std::tie(block_ids, unique_ids) = block_manager_->Allocate(schedule.allocate);
+        }
+        AssignAndActivate(schedule.active, schedule.block_counts, block_ids, unique_ids);
+    }
+
+    // active -> locked
+    for (const auto& p : schedule.inactive) {
+        if (p->status == Sequence::kActive) {
+            const_cast(p)->status = Sequence::kLocked;
+        }
+    }
+
+    // TM_LOG_ERROR("active: %4d, cached: %4d, free: %4d",
+    //              block_manager_->active_count(),
+    //              block_manager_->cached_count(),
+    //              block_manager_->free_count());
+
+    return outcome;
+}
+
+}  // namespace turbomind
diff --git a/src/turbomind/models/llama/SequenceManager.h b/src/turbomind/models/llama/SequenceManager.h
new file mode 100644
index 0000000000..b3100b12ad
--- /dev/null
+++ b/src/turbomind/models/llama/SequenceManager.h
@@ -0,0 +1,147 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+
+#include "src/turbomind/models/llama/BlockManager.h"
+#include
+
+namespace turbomind {
+
+struct Sequence {
+
+    enum Status
+    {
+        kCached = 0,
+        kLocked,
+        kActive
+    };
+
+    uint64_t id;
+    Status   status = kCached;
+
+    BlockIds  blocks;
+    UniqueIds block_unique_ids;
+
+    int input_length = 0;
+
+    mutable std::vector tokens;  // update by user
+
+    mutable int cache_len = 0;
+
+    // additional data kept round-to-round
+    mutable std::vector random_state;  // update by user
+
+    mutable float rope_theta = 0.f;
+
+    explicit Sequence(uint64_t _id): id(_id) {}
+
+    friend std::ostream& operator<<(std::ostream& os, const Sequence& seq);
+};
+
+using Sequences = std::vector;
+
+inline std::ostream& operator<<(std::ostream& os, const Sequence& seq)
+{
+    os << "id=" << seq.id << ", status=" << seq.status << ", token_count=" << seq.tokens.size()
+       << ", block_count=" << seq.blocks.size() << ", cache_len=" << seq.cache_len
+       << ", random_state_size=" << seq.random_state.size();
+    return os;
+}
+
+class SequenceManager {
+public:
+    explicit SequenceManager(size_t      layer_num,
+                             size_t      head_num,
+                             size_t      head_dim,
+                             size_t      block_seq_len,
+                             double      block_count,
+                             int         chunk_size,
+                             size_t      elem_bits,
+                             int         rank,
+                             IAllocator* allocator);
+
+    SequenceManager(const SequenceManager&)     = delete;
+    SequenceManager(SequenceManager&&) noexcept = default;
+
+    [[nodiscard]] const Sequence* Create(uint64_t id);
+
+    [[nodiscard]] const Sequence* Get(uint64_t id);
+
+    [[nodiscard]] bool Contains(uint64_t id);
+
+    [[nodiscard]] bool Erase(uint64_t id);
+
+    void UpdateAndSetUnlock(const Sequence& seq);
+
+    struct Outcome {
+        int allocation;
+        int swap_in;
+        int swap_out;
+    };
+
+    using AdjustInputCount = std::function(const Sequences&, const std::vector&)>;
+
+    [[nodiscard]] Outcome Materialize(Sequences          sequences,
+                                      std::vector        context_lengths,
+                                      const std::vector& priorities,
+                                      int                step_length,
+                                      AdjustInputCount   adjust);
+
+    [[nodiscard]] void* GetKeyPtr(int block_id)
+    {
+        return block_manager_->block(block_id).data;
+    }
+
+    [[nodiscard]] void* GetValPtr(int block_id)
+    {
+        return (std::byte*)GetKeyPtr(block_id) + val_offset_;
+    }
+
+    int max_block_count() const noexcept
+    {
+        return block_manager_->max_block_count();
+    }
+
+private:
+    void Erase(std::map::iterator it);
+
+    void CommitUnlockAndFree();
+
+    void VerifyAndLockCached(const Sequences& sequences);
+
+    std::vector CountRequiredBlocks(const Sequences&   sequences,  //
+                                    const std::vector& context_lengths,
+                                    int                step_length);
+
+    static void SortByPriority(Sequences&         sequences,  //
+                               std::vector&       context_lengths,
+                               const std::vector& priorities);
+
+    static void AssignAndActivate(const Sequences&   sequences,  //
+                                  const std::vector& counts,
+                                  const BlockIds&    blocks,
+                                  const UniqueIds&   unique_ids);
+
+private:
+    int    block_seq_len_;
+    int    rank_;
+    size_t val_offset_{};
+
+    bool need_verify_{};
+
+    // Use `std::map` to avoid reference invalidation
+    std::map sequences_;
+
+    std::unique_ptr block_manager_;
+
+    BlockIds unlocked_;
+    BlockIds freed_;
+};
+
+inline std::ostream& operator<<(std::ostream& os, const SequenceManager::Outcome& oc)
+{
+    os << "allocation: " << oc.allocation << ", swap-in: " << oc.swap_in << ", swap-out: " << oc.swap_out;
+    return os;
+}
+
+}  // namespace turbomind
diff --git a/src/turbomind/models/llama/copy.h b/src/turbomind/models/llama/copy.h
new file mode 100644
index 0000000000..afc55db24c
--- /dev/null
+++ b/src/turbomind/models/llama/copy.h
@@ -0,0 +1,37 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+
+#include "src/turbomind/models/llama/llama_kernels.h"
+#include "src/turbomind/utils/cuda_utils.h"
+
+namespace turbomind {
+
+class BatchedCopy {
+public:
+    template = 0>
+    T* Add(const T* src, int size, T* dst)
+    {
+        src_.push_back((void*)src);
+        dst_.push_back((void*)dst);
+        size_.push_back(sizeof(T) * size);
+        return dst + size;
+    }
+
+    void Submit(cudaStream_t stream)
+    {
+        invokeBatchedCopy(src_.data(), dst_.data(), size_.data(), size_.size(), stream);
+        sync_check_cuda_error();
+
+        src_.clear();
+        dst_.clear();
+        size_.clear();
+    }
+
+private:
+    std::vector src_;
+    std::vector dst_;
+    std::vector size_;
+};
+
+}  // namespace turbomind
diff --git a/src/turbomind/models/llama/flash_attention2/CMakeLists.txt b/src/turbomind/models/llama/flash_attention2/CMakeLists.txt
index 9f527d7d1a..1a1fe37eaa 100644
--- a/src/turbomind/models/llama/flash_attention2/CMakeLists.txt
+++ b/src/turbomind/models/llama/flash_attention2/CMakeLists.txt
@@ -4,10 +4,10 @@ project(flash_attention2)
 
 add_library(${PROJECT_NAME} STATIC
     flash_api.cpp
-    flash_fwd_hdim32_fp16_sm80.cu
-    flash_fwd_hdim64_fp16_sm80.cu
+    # flash_fwd_hdim32_fp16_sm80.cu
+    # flash_fwd_hdim64_fp16_sm80.cu
     flash_fwd_hdim128_fp16_sm80.cu
-    flash_fwd_hdim256_fp16_sm80.cu
+    # flash_fwd_hdim256_fp16_sm80.cu
 )
 target_include_directories(${PROJECT_NAME} PRIVATE ${CUTLASS_DIR}/include)
 target_link_libraries(${PROJECT_NAME} PRIVATE nvidia::cutlass::cutlass)
diff --git a/src/turbomind/models/llama/flash_attention2/flash_fwd_launch_template.h b/src/turbomind/models/llama/flash_attention2/flash_fwd_launch_template.h
index 4a94da08b2..e108a55f28 100644
--- a/src/turbomind/models/llama/flash_attention2/flash_fwd_launch_template.h
+++ b/src/turbomind/models/llama/flash_attention2/flash_fwd_launch_template.h
@@ -14,7 +14,13 @@ template
 __global__ void flash_fwd_kernel(Flash_fwd_params params)
 {
+
+#if __CUDA_ARCH__ >= 800
     flash::compute_attn(params);
+#else
+    // TODO: support flash attention2 on sm<80
+    assert(false);
+#endif
 }
 
 template
@@ -57,6 +63,7 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream)
     });
 }
 
+#if 0
 template
 void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream)
 {
@@ -94,6 +101,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream)
         }
     });
 }
+#endif
 
 template
 void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream)
@@ -139,6 +147,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream)
     });
 }
 
+#if 0
 template
 void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream)
 {
@@ -168,3 +177,4 @@ void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream)
         //              Is_causal>(params, stream);
     });
 }
+#endif
diff --git a/src/turbomind/models/llama/flash_attention2/static_switch.h b/src/turbomind/models/llama/flash_attention2/static_switch.h
index bf4a9195ea..fd19a0ea61 100644
--- a/src/turbomind/models/llama/flash_attention2/static_switch.h
+++ b/src/turbomind/models/llama/flash_attention2/static_switch.h
@@ -38,6 +38,7 @@
         }                                                                      \
     }()
 
+#if 0
 #define FWD_HEADDIM_SWITCH(HEADDIM, ...)                                       \
     [&] {                                                                      \
         if (HEADDIM <= 32) {                                                   \
@@ -57,3 +58,10 @@
             return __VA_ARGS__();                                              \
         }                                                                      \
     }()
+#else
+#define FWD_HEADDIM_SWITCH(HEADDIM, ...)                                       \
+    [&] {                                                                      \
+        constexpr static int kHeadDim = 128;                                   \
+        return __VA_ARGS__();                                                  \
+    }()
+#endif
diff --git a/src/turbomind/models/llama/llama_decoder_kernels.cu b/src/turbomind/models/llama/llama_decoder_kernels.cu
index 18c7f4deea..1fe1281af7 100644
--- a/src/turbomind/models/llama/llama_decoder_kernels.cu
+++ b/src/turbomind/models/llama/llama_decoder_kernels.cu
@@ -101,6 +101,8 @@ __device__ T blockReduceSum(const cg::thread_block& block, T value)
     return cg::reduce(tile, value, cg::plus{});
 }
 
+// r' = r + x
+// x' = norm(r') * scales
 template
 __global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data,
                                          T* __restrict__ x_data,
diff --git a/src/turbomind/models/llama/llama_kernels.cu b/src/turbomind/models/llama/llama_kernels.cu
index ebbfa7ee26..ff628dcced 100644
--- a/src/turbomind/models/llama/llama_kernels.cu
+++ b/src/turbomind/models/llama/llama_kernels.cu
@@ -1,11 +1,21 @@
 // Copyright (c) OpenMMLab. All rights reserved.
 
 #include "src/turbomind/kernels/decoder_masked_multihead_attention_utils.h"
+#include "src/turbomind/kernels/decoder_multihead_attention/array_ops.h"
+#include "src/turbomind/kernels/gemm_s_f16/common.h"
 #include "src/turbomind/kernels/reduce_kernel_utils.cuh"
 #include "src/turbomind/macro.h"
 #include "src/turbomind/models/llama/llama_kernels.h"
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/cuda_type_utils.cuh"
+#include "src/turbomind/utils/cuda_utils.h"
+#include "src/turbomind/utils/dispatch.h"
+#include "src/turbomind/utils/logger.h"
+#include
+#include
+#include
+#include
+#include
 
 namespace turbomind {
 
@@ -199,392 +209,248 @@ void invokeCreateCausalMasks(
 template void invokeCreateCausalMasks(float* mask, const int*, const int*, int, int, int, cudaStream_t);
 template void invokeCreateCausalMasks(half* mask, const int*, const int*, int, int, int, cudaStream_t);
 
-template
-__global__ void extend_key_cache(T**          k_dst,
-                                 const size_t dst_offset,
-                                 const T*     k_src,
-                                 const int    head_num,
-                                 const int    size_per_head,
-                                 const int*   query_length,
-                                 const int*   history_length,
-                                 const int    max_q_len,
-                                 const int    max_seq_len)
-{
-    const int     batch_id = blockIdx.y;
-    const int     head_id  = blockIdx.z;
-    constexpr int X_ELEMS  = (sizeof(T) == 4) ? 4 : 8;
-
-    const int idx                 = blockIdx.x * blockDim.x + threadIdx.x;
-    int       size_per_head_div_x = size_per_head / X_ELEMS;
-
-    // x dim is now handled by uint4 type
-    const auto key_src = reinterpret_cast(k_src);
-    const auto key_dst = reinterpret_cast(k_dst[batch_id] + dst_offset);
-
-    const auto seq_len  = query_length[batch_id];
-    const auto t_offset = history_length[batch_id];
-
-    const int k_head_size_id = idx % size_per_head_div_x;
-    const int k_seq_len_id   = idx / size_per_head_div_x;
-
-    if (k_seq_len_id < seq_len) {
-        // [B, H, s, D/x] -> [H, D/x, S[t:t+s]]
-
-        const int64_t dst_idx = head_id * size_per_head_div_x * max_seq_len +  // H
-                                k_head_size_id * max_seq_len +                 // D/x
-                                t_offset + k_seq_len_id;                       // s + offset
-
-        const int64_t src_idx = batch_id * head_num * size_per_head_div_x * max_q_len +  // B
-                                head_id * size_per_head_div_x * max_q_len +              // H
-                                k_seq_len_id * size_per_head_div_x +                     // s
-                                k_head_size_id;                                          // D/x
-
-        key_dst[dst_idx] = key_src[src_idx];
-    }
-}
+template
+struct ExtendKvCache {
 
-template
-__global__ void extend_value_cache(T**          v_dst,
-                                   const size_t dst_offset,
-                                   const T*     v_src,
-                                   const int    head_num,
-                                   const int    size_per_head,
-                                   const int*   query_length,
-                                   const int*   history_length,
-                                   const int    max_q_len,
-                                   const int    max_seq_len)
-{
-    const int     batch_id = blockIdx.y;
-    const int     head_id  = blockIdx.z;
-    constexpr int X_ELEMS  = (sizeof(T) == 4) ? 4 : 8;
-
-    const int idx                 = blockIdx.x * blockDim.x + threadIdx.x;
-    int       size_per_head_div_x = size_per_head / X_ELEMS;
-
-    // x dim is now handled by uint4 type
-    const auto val_src = reinterpret_cast(v_src);
-    const auto val_dst = reinterpret_cast(v_dst[batch_id] + dst_offset);
-
-    const auto seq_len  = query_length[batch_id];
-    const auto t_offset = history_length[batch_id];
-
-    const int v_head_size_id = idx % size_per_head_div_x;
-    const int v_seq_len_id   = idx / size_per_head_div_x;
-
-    if (v_seq_len_id < seq_len) {
-        // [B, H, s, D/x] -> [H, S[t:t+s], D/x]
-        const int64_t dst_idx = head_id * size_per_head_div_x * max_seq_len +      // H
-                                (v_seq_len_id + t_offset) * size_per_head_div_x +  // s + offset
-                                v_head_size_id;                                    // D/x
-
-        const int64_t src_idx = batch_id * head_num * size_per_head_div_x * max_q_len +  // B
-                                head_id * size_per_head_div_x * max_q_len +              // H
-                                v_seq_len_id * size_per_head_div_x +                     // s
-                                v_head_size_id;                                          // D/x
-
-        val_dst[dst_idx] = val_src[src_idx];
-    }
-}
+    static constexpr int MaxElemSize = std::max(sizeof(Ti), sizeof(To));
+    static constexpr int X_ELEMS     = 16 / MaxElemSize;
 
-inline __device__ float2 float2div(float a, float2 b)
-{
-    float2 c;
-    c.x = b.x / a;
-    c.y = b.y / a;
-    return c;
-}
+    using Vi = Array;
+    using Vo = Array;
 
-inline __device__ float2 float2sub(float zp, float2 val)
-{
-    float2 ret;
-    ret.x = val.x - zp;
-    ret.y = val.y - zp;
-    return ret;
-}
+    using Transform = ConvertKvCache;
 
-static inline __device__ half4 char4_scale_to_half4(char4 value, const float scale, const float zp)
-{
-    half4 dst;
-    dst.x = __float2half(value.x * scale + zp);
-    dst.y = __float2half(value.y * scale + zp);
-    dst.z = __float2half(value.z * scale + zp);
-    dst.w = __float2half(value.w * scale + zp);
-    return dst;
-}
+    struct Params {
+        To**       k_dst_ptrs;
+        To**       v_dst_ptrs;
+        const Ti*  k_src;
+        const Ti*  v_src;
+        const int* cu_block_counts;
+        const int* query_length;
+        const int* context_length;
+        int        block_length;
+        size_t     dst_layer_offset;
+        int        max_q_len;
+        int        head_num;
+        int        head_dim;
+        Transform  transform_k;
+        Transform  transform_v;
+    };
 
-static inline __device__ uint32_t float4_to_char4(float x, float y,
float z, float w) -{ - uint32_t dst; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 720 - uint32_t a; - asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(a) : "f"(x)); - uint32_t b; - asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(b) : "f"(y)); - uint32_t c; - asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(c) : "f"(z)); - uint32_t d; - asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(d) : "f"(w)); + __device__ void operator()(const Params& params) const + { + const int batch_id = blockIdx.y; - asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;\n" : "=r"(dst) : "r"(d), "r"(c)); - asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, %0;\n" : "+r"(dst) : "r"(b), "r"(a)); -#else - char4 tmp; - tmp.x = x; - tmp.y = y; - tmp.z = z; - tmp.w = w; - dst = reinterpret_cast(tmp); -#endif - return dst; -} + const int query_len = params.query_length[batch_id]; + const int history_len = params.context_length[batch_id] - query_len; + const int cu_block_cnt = params.cu_block_counts[batch_id]; -template -__global__ void extend_value_cache_int8(int8_t** v_dst, - const size_t dst_offset, - const T* v_src, - const int head_num, - const int size_per_head, - const int* query_length, - const int* history_length, - const int max_q_len, - const int max_seq_len, - const float v_scale, - const float v_zp) -{ - const int batch_id = blockIdx.y; - const int head_id = blockIdx.z; - constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8; - - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - int size_per_head_div_x = size_per_head / X_ELEMS; - - // x dim is now handled by uint4 type - const auto val_src = reinterpret_cast(v_src); - const auto val_dst = reinterpret_cast(v_dst[batch_id] + dst_offset); - - const auto seq_len = query_length[batch_id]; - const auto t_offset = history_length[batch_id]; - - const int v_head_size_id = idx % size_per_head_div_x; - const int v_seq_len_id = idx / size_per_head_div_x; - - if (v_seq_len_id < seq_len) { - // [B, H, s, D/x] -> [H, S[t:t+s], D/x] - const int64_t dst_idx = head_id * size_per_head_div_x * max_seq_len + // H - (v_seq_len_id + t_offset) * size_per_head_div_x + // s + offset - v_head_size_id; // D/x - - const int64_t src_idx = batch_id * head_num * size_per_head_div_x * max_q_len + // B - head_id * size_per_head_div_x * max_q_len + // H - v_seq_len_id * size_per_head_div_x + // s - v_head_size_id; // D/x - - // scale to int8 and write - const auto value = val_src[src_idx]; - auto to_ptr = reinterpret_cast(val_dst + dst_idx); - - float2 float2_0 = float2div(v_scale, float2sub(v_zp, mmha::half2_to_float2(value.x))); - float2 float2_1 = float2div(v_scale, float2sub(v_zp, mmha::half2_to_float2(value.y))); - to_ptr[0] = float4_to_char4(float2_0.x, float2_0.y, float2_1.x, float2_1.y); - - float2_0 = float2div(v_scale, float2sub(v_zp, mmha::half2_to_float2(value.z))); - float2_1 = float2div(v_scale, float2sub(v_zp, mmha::half2_to_float2(value.w))); - to_ptr[1] = float4_to_char4(float2_0.x, float2_0.y, float2_1.x, float2_1.y); + const int head_id = blockIdx.z; + + const int size_per_head_div_x = params.head_dim / X_ELEMS; + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int head_size_id = idx % size_per_head_div_x; + const int seq_len_id = idx / size_per_head_div_x; + + const int cache_block_index = (seq_len_id + history_len) / params.block_length; + const int cache_block_offset = (seq_len_id + history_len) % params.block_length; + + const auto k_val_src = params.k_src; + const auto v_val_src = params.v_src; + + const auto k_val_dst = 
(params.k_dst_ptrs + cu_block_cnt)[cache_block_index] + params.dst_layer_offset; + const auto v_val_dst = (params.v_dst_ptrs + cu_block_cnt)[cache_block_index] + params.dst_layer_offset; + + if (seq_len_id < query_len) { + // [B, H, s, D/x] -> [H, S[t:t+s], D/x] + const int64_t dst_idx = head_id * params.block_length * size_per_head_div_x + // H + cache_block_offset * size_per_head_div_x + // s + offset + head_size_id; // D/x + + const int64_t src_idx = batch_id * params.head_num * params.max_q_len * size_per_head_div_x + // B + head_id * params.max_q_len * size_per_head_div_x + // H + seq_len_id * size_per_head_div_x + // s + head_size_id; // D/x + + Vi k_vi; + Vi v_vi; + + Ldg(k_vi, k_val_src + src_idx * X_ELEMS); + Ldg(v_vi, v_val_src + src_idx * X_ELEMS); + + Vo k_vo = params.transform_k(k_vi); + Vo v_vo = params.transform_v(v_vi); + + Store(k_val_dst + dst_idx * X_ELEMS, k_vo); + Store(v_val_dst + dst_idx * X_ELEMS, v_vo); + } } -} +}; + +namespace { + +template +__global__ void KernelWrapper(Params params) +{ + Kernel{}(params); +}; + +} // namespace template -void invokeExtendKVCache(T** k_dst, - T** v_dst, - size_t dst_offset, +void invokeExtendKVCache(void** k_dst_ptrs, + void** v_dst_ptrs, const T* k_src, const T* v_src, - int local_batch_size, + const int* cu_block_counts, const int* query_length, + const int* context_length, + int batch_size, + int block_length, + size_t dst_layer_offset, int max_q_len, - const int* history_length, - int max_seq_len, - int size_per_head, - int local_head_num, - cudaStream_t stream, + int head_dim, + int head_num, int quant, - const float* kv_scale) + const float* kv_params, + cudaStream_t stream) { constexpr int block_sz = 128; - constexpr int x = (sizeof(T) == 4) ? 4 : 8; - - dim3 grid((max_q_len * size_per_head / x + block_sz - 1) / block_sz, local_batch_size, local_head_num); - - if (quant & QuantPolicy::kCacheKVInt8) { - extend_value_cache_int8<<>>(reinterpret_cast(k_dst), - dst_offset, - k_src, - local_head_num, - size_per_head, - query_length, - history_length, - max_q_len, - max_seq_len, - kv_scale[0], - kv_scale[1]); - - extend_value_cache_int8<<>>(reinterpret_cast(v_dst), - dst_offset, - v_src, - local_head_num, - size_per_head, - query_length, - history_length, - max_q_len, - max_seq_len, - kv_scale[2], - kv_scale[3]); - } - else { - extend_value_cache<<>>(k_dst, - dst_offset, - k_src, - local_head_num, - size_per_head, - query_length, - history_length, - max_q_len, - max_seq_len); - - extend_value_cache<<>>(v_dst, - dst_offset, - v_src, - local_head_num, - size_per_head, - query_length, - history_length, - max_q_len, - max_seq_len); - } -} -template void invokeExtendKVCache(float**, - float**, - size_t, - const float*, - const float*, - int, - const int*, - int, - const int*, - int, - int, - int, - cudaStream_t stream, - int, - const float*); - -template void invokeExtendKVCache(half**, - half**, - size_t, - const half*, - const half*, - int, - const int*, - int, - const int*, - int, - int, - int, - cudaStream_t stream, - int, - const float*); - -template -__global__ void transpose_value_cache(T* v_dst, // - const T** v_src, - const size_t src_offset, - const int head_num, - const int head_n_rep, - const int size_per_head, - const int* seq_length, - const int max_kv_len, - const int max_seq_len) -{ - const int batch_id = blockIdx.y; - const int head_id = blockIdx.z; - constexpr int X_ELEMS = (sizeof(T) == 4) ? 
4 : 8; - - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - int size_per_head_div_x = size_per_head / X_ELEMS; - - // x dim is now handled by uint4 type - const auto val_src = reinterpret_cast(v_src[batch_id] + src_offset); - const auto val_dst = reinterpret_cast(v_dst); - - const auto seq_len = seq_length[batch_id]; - - const int v_head_size_id = idx % size_per_head_div_x; - const int v_seq_len_id = idx / size_per_head_div_x; - - if (v_seq_len_id < seq_len) { - // [B, H, s, D/x] <- [B, H, S[:s], D/x] - const int64_t src_idx = head_id / head_n_rep * size_per_head_div_x * max_seq_len + // H - v_seq_len_id * size_per_head_div_x + // s - v_head_size_id; // D/x - - const int64_t dst_idx = batch_id * head_num * size_per_head_div_x * max_kv_len + // B - head_id * size_per_head_div_x * max_kv_len + // H - v_seq_len_id * size_per_head_div_x + // s - v_head_size_id; // D/x - - val_dst[dst_idx] = val_src[src_idx]; - } + auto fn = [&](auto value) { + using Tout = decltype(value); + using Kernel = ExtendKvCache; + + dim3 grid((max_q_len * head_dim / Kernel::X_ELEMS + block_sz - 1) / block_sz, batch_size, head_num); + + typename Kernel::Params params{(Tout**)k_dst_ptrs, + (Tout**)v_dst_ptrs, + k_src, + v_src, + cu_block_counts, + query_length, + context_length, + block_length, + dst_layer_offset, + max_q_len, + head_num, + head_dim, + {kv_params[0], kv_params[1]}, + {kv_params[2], kv_params[3]}}; + + KernelWrapper<<>>(params); + }; + + (quant & QuantPolicy::kCacheKVInt8) ? fn(int8_t{}) : fn(T{}); } -template -__global__ void transpose_value_cache_int8(T* v_dst, // - const int8_t** v_src, - const size_t src_offset, - const int head_num, - const int head_n_rep, - const int size_per_head, - const int* seq_length, - const int max_kv_len, - const int max_seq_len, - const float v_scale, - const float v_zp) -{ - const int batch_id = blockIdx.y; - const int head_id = blockIdx.z; - constexpr int X_ELEMS = (sizeof(T) == 4) ? 
4 : 8; - - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - int size_per_head_div_x = size_per_head / X_ELEMS; - - // x dim is now handled by uint4 type - const auto val_src = reinterpret_cast(v_src[batch_id] + src_offset); - const auto val_dst = reinterpret_cast(v_dst); - - const auto seq_len = seq_length[batch_id]; - - const int v_head_size_id = idx % size_per_head_div_x; - const int v_seq_len_id = idx / size_per_head_div_x; - - if (v_seq_len_id < seq_len) { - // [B, H, s, D/x] <- [B, H, S[:s], D/x] - const int64_t src_idx = head_id / head_n_rep * size_per_head_div_x * max_seq_len + // H - v_seq_len_id * size_per_head_div_x + // s - v_head_size_id; // D/x - - const int64_t dst_idx = batch_id * head_num * size_per_head_div_x * max_kv_len + // B - head_id * size_per_head_div_x * max_kv_len + // H - v_seq_len_id * size_per_head_div_x + // s - v_head_size_id; // D/x - - // int8x8 -> fp16x8 - const auto from_ptr = reinterpret_cast(val_src + src_idx); - auto to_ptr = reinterpret_cast(val_dst + dst_idx); - - to_ptr[0] = char4_scale_to_half4(from_ptr[0], v_scale, v_zp); - to_ptr[1] = char4_scale_to_half4(from_ptr[1], v_scale, v_zp); +template void invokeExtendKVCache(void** k_dst_ptrs, + void** v_dst_ptrs, + const float* k_src, + const float* v_src, + const int* cu_block_counts, + const int* query_length, + const int* history_length, + int batch_size, + int block_length, + size_t dst_layer_offset, + int max_q_len, + int head_dim, + int head_num, + int quant, + const float* kv_scale, + cudaStream_t stream); + +template void invokeExtendKVCache(void** k_dst_ptrs, + void** v_dst_ptrs, + const half* k_src, + const half* v_src, + const int* cu_block_counts, + const int* query_length, + const int* history_length, + int batch_size, + int block_length, + size_t dst_layer_offset, + int max_q_len, + int head_dim, + int head_num, + int quant, + const float* kv_scale, + cudaStream_t stream); + +template +struct TransposeKvCache { + static constexpr int MaxElemSize = std::max(sizeof(Ti), sizeof(To)); + static constexpr int X_ELEMS = 16 / MaxElemSize; + + using Vi = Array; + using Vo = Array; + + using Transform = ConvertKvCache; + + struct Params { + To* k_dst; + To* v_dst; + const Ti** k_src; + const Ti** v_src; + size_t src_offset; + int head_num; + int head_n_rep; + int size_per_head; + const int* seq_length; + int max_kv_len; + int max_seq_len; + Transform transform_k; + Transform transform_v; + // float k_scale; + // float k_zp; + // float v_scale; + // float v_zp; + }; + + __device__ void operator()(const Params& params) const + { + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int size_per_head_div_x = params.size_per_head / X_ELEMS; + + const auto k_src = params.k_src[batch_id] + params.src_offset; + const auto v_src = params.v_src[batch_id] + params.src_offset; + const auto k_dst = params.k_dst; + const auto v_dst = params.v_dst; + + const auto seq_len = params.seq_length[batch_id]; + + const int v_head_size_id = idx % size_per_head_div_x; + const int v_seq_len_id = idx / size_per_head_div_x; + + if (v_seq_len_id < seq_len) { + // [B, H, s, D/x] <- [B, H, S[:s], D/x] + const int64_t src_idx = head_id / params.head_n_rep * size_per_head_div_x * params.max_seq_len + // H + v_seq_len_id * size_per_head_div_x + // s + v_head_size_id; // D/x + + const int64_t dst_idx = batch_id * params.head_num * size_per_head_div_x * params.max_kv_len + // B + head_id * size_per_head_div_x * params.max_kv_len + // H + 
v_seq_len_id * size_per_head_div_x + // s + v_head_size_id; // D/x + + Vi k_vi; + Vi v_vi; + + Ldg(k_vi, k_src + src_idx * X_ELEMS); + Ldg(v_vi, v_src + src_idx * X_ELEMS); + + Vo k_vo = params.transform_k(k_vi); + Vo v_vo = params.transform_v(v_vi); + + Store(k_dst + dst_idx * X_ELEMS, k_vo); + Store(v_dst + dst_idx * X_ELEMS, v_vo); + } } -} +}; template void invokeTransposeKVCache(T* key_cache_trans, @@ -601,59 +467,34 @@ void invokeTransposeKVCache(T* key_cache_trans, int head_n_rep, cudaStream_t stream, int quant, - const float* kv_scale) + const float* kv_params) { constexpr int block_sz = 128; - constexpr int x = (sizeof(T) == 4) ? 4 : 8; - - dim3 grid((max_kv_len * size_per_head / x + block_sz - 1) / block_sz, batch_size, head_num); - - if (quant & QuantPolicy::kCacheKVInt8) { - transpose_value_cache_int8<<>>(key_cache_trans, - reinterpret_cast(key_cache), - src_offset, - head_num, - head_n_rep, - size_per_head, - key_length, - max_kv_len, - max_seq_len, - kv_scale[0], - kv_scale[1]); - - transpose_value_cache_int8<<>>(val_cache_trans, - reinterpret_cast(val_cache), - src_offset, - head_num, - head_n_rep, - size_per_head, - key_length, - max_kv_len, - max_seq_len, - kv_scale[2], - kv_scale[3]); - } - else { - transpose_value_cache<<>>(key_cache_trans, - key_cache, - src_offset, - head_num, - head_n_rep, - size_per_head, - key_length, - max_kv_len, - max_seq_len); - - transpose_value_cache<<>>(val_cache_trans, - val_cache, - src_offset, - head_num, - head_n_rep, - size_per_head, - key_length, - max_kv_len, - max_seq_len); - } + + auto fn = [&](auto value) { + using Tin = decltype(value); + using Kernel = TransposeKvCache; + + dim3 grid((max_kv_len * size_per_head / Kernel::X_ELEMS + block_sz - 1) / block_sz, batch_size, head_num); + + typename Kernel::Params params{key_cache_trans, + val_cache_trans, + (const Tin**)key_cache, + (const Tin**)val_cache, + src_offset, + head_num, + head_n_rep, + size_per_head, + key_length, + max_kv_len, + max_seq_len, + {kv_params[0], kv_params[1]}, + {kv_params[2], kv_params[3]}}; + + KernelWrapper<<>>(params); + }; + + (quant & QuantPolicy::kCacheKVInt8) ? fn(int8_t{}) : fn(T{}); } template void invokeTransposeKVCache(float*, @@ -704,8 +545,10 @@ __global__ void gatherOutput(int* output_ids, continue; } // skip padding for dst - const int dst_idx = src_idx < context_len ? src_idx : src_idx - (max_context_len - context_len); - output_ids[dst_idx] = ids[src_idx * batch_size + batch_id]; + const int dst_idx = src_idx < context_len ? 
src_idx : src_idx - (max_context_len - context_len); + if (dst_idx < max_output_len) { + output_ids[dst_idx] = ids[src_idx * batch_size + batch_id]; + } } } @@ -718,12 +561,292 @@ void invokeGatherOutput(int* output_ids, int batch_size, cudaStream_t stream) { - int block_size = 512; + int block_size = 128; int grid_size = batch_size; gatherOutput<<>>( output_ids, ids, context_length, max_context_len, max_gen_step, max_output_len, batch_size); } +__global__ void updateOutput(int** request_output_ids_ptrs, + int** request_seqlen_ptrs, + const int* output_ids, + const int* sequence_lengths, + const int* request_output_ids_lens, + int max_session_len, + bool token_generated) +{ + const int batch_id = blockIdx.x; + + auto request_output_ids = request_output_ids_ptrs[batch_id]; + auto request_seqlen = request_seqlen_ptrs[batch_id]; + + output_ids += max_session_len * batch_id; + + const int seqlen = sequence_lengths[batch_id] + (int)token_generated; + const int output_len = min(seqlen, request_output_ids_lens[batch_id]); + + for (int i = threadIdx.x; i < output_len; i += blockDim.x) { + request_output_ids[i] = output_ids[i]; + } + + *request_seqlen = seqlen; +} + +void invokeUpdateOutput(int** request_output_ids_ptrs, + int** request_seqlen_ptrs, + const int* output_ids, + const int* sequence_lengths, + const int* request_output_ids_lens, + int max_session_len, + bool token_generated, + int batch_size, + cudaStream_t stream) +{ + constexpr int block_size = 128; + const int grid_size = batch_size; + + updateOutput<<>>(request_output_ids_ptrs, + request_seqlen_ptrs, + output_ids, + sequence_lengths, + request_output_ids_lens, + max_session_len, + token_generated); +} + +template +__global__ void compactOutputIds( + int* cu_output_ids, const int* output_ids, const int* sequence_lengths, int session_len, bool token_generated) +{ + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + const int batch_idx = blockIdx.x; + + int end = (batch_idx + BLOCK_DIM - 1) / BLOCK_DIM * BLOCK_DIM; // align to BLOCK_DIM boundary + int count = 0; + for (int i = threadIdx.x; i < end; i += blockDim.x) { + int x = threadIdx.x < batch_idx ? 
sequence_lengths[threadIdx.x] : 0; + count += BlockReduce(temp_storage).Sum(x); + // https://nvlabs.github.io/cub/classcub_1_1_block_reduce.html + __syncthreads(); + } + + __shared__ int offset; + + if (threadIdx.x == 0) { + offset = count; + } + + __syncthreads(); + + auto dst = cu_output_ids + offset; + + const int seq_len = sequence_lengths[batch_idx]; + + for (int i = threadIdx.x; i < seq_len; i += blockDim.x) { + dst[i] = output_ids[batch_idx * session_len + i]; + } +} + +void invokeCompactOutputIds(int* cu_output_ids, + const int* output_ids, + const int* sequence_lengths, + int max_session_len, + bool token_generated, + int batch_size, + cudaStream_t stream) +{ + constexpr int BLOCK_DIM = 128; + compactOutputIds<<>>( + cu_output_ids, output_ids, sequence_lengths, max_session_len, token_generated); +} + +template +struct IndexedCopyParam { + Array src_ptr; + Array dst_ptr; + Array stride; + Array src_idx; + Array dst_idx; + int max_stride; +}; + +template +__global__ void indexedCopy(IndexedCopyParam param) +{ + const int bi = blockIdx.x; + const int si = param.src_idx[bi]; + const int di = param.dst_idx[bi]; + for (int i = threadIdx.x; i < param.max_stride; i += blockDim.x) { + PRAGMA_UNROLL + for (int k = 0; k < N; ++k) { + if (i < param.stride[k]) { + *((T*)param.dst_ptr[k] + param.stride[k] * di + i) = + *((const T*)param.src_ptr[k] + param.stride[k] * si + i); + } + } + } +} + +template +void invokeIndexedCopyImpl(void** h_src_ptr, + void** h_dst_ptr, + const int* h_elem_sz, + const int* h_src_idx, + const int* h_dst_idx, + int count, + cudaStream_t st) +{ + dispatch( // dispatch for num of copy operations + std::integer_sequence{}, + [&](auto C) { return count <= C; }, + [&](auto C) { + // maximum parameter size: sm<70: 4kB, sm>=70: 32kB + static_assert(sizeof(IndexedCopyParam) <= 4096); + IndexedCopyParam param{}; + std::copy_n(h_src_ptr, N, param.src_ptr.data()); + std::copy_n(h_dst_ptr, N, param.dst_ptr.data()); + std::transform(h_elem_sz, h_elem_sz + N, param.stride.data(), [](int size) { + // Basic alignment check + FT_CHECK_WITH_INFO(size % sizeof(T) == 0, fmtstr("misalignment: %d %% %d", size, (int)sizeof(T))); + return size / sizeof(T); + }); + param.max_stride = *std::max_element(param.stride.begin(), param.stride.end()); + auto copy_idx = [](const int* src, int offset, int n, auto dst) { + return src ? 
(void)std::copy_n(src + offset, n, dst) : std::iota(dst, dst + n, offset); + }; + for (int c = 0; c < count; c += C) { + int batch_size = std::min(count - c, (int)C); + copy_idx(h_src_idx, c, batch_size, param.src_idx.data()); + copy_idx(h_dst_idx, c, batch_size, param.dst_idx.data()); + indexedCopy<<>>(param); + } + }); +} + +void invokeIndexedCopy(void** h_src_ptr, + void** h_dst_ptr, + const int* h_elem_sz, + const int* h_src_idx, + const int* h_dst_idx, + int count, + int n_copys, + cudaStream_t st) +{ + auto success = dispatch(std::integer_sequence{}, [&](auto N) { + if (N == n_copys) { + invokeIndexedCopyImpl(h_src_ptr, h_dst_ptr, h_elem_sz, h_src_idx, h_dst_idx, count, st); + return true; + } + return false; + }); + FT_CHECK(success); +} + +__global__ void padLastTokenIds(int* token_ids, const int* context_length, int max_context_len, int batch_size) +{ + for (int bi = threadIdx.x; bi < batch_size; bi += blockDim.x) { + token_ids[(max_context_len - 1) * batch_size + bi] = token_ids[(context_length[bi] - 1) * batch_size + bi]; + } +} + +void invokePadLastTokenIds( + int* token_ids, const int* context_length, int max_context_len, int batch_size, cudaStream_t stream) +{ + padLastTokenIds<<<1, 512, 0, stream>>>(token_ids, context_length, max_context_len, batch_size); +} + +template +__global__ void getFeatureOfLastToken(T* output, const T* input, const int* cu_seqlens, int dims) +{ + int bi = blockIdx.x; + int ti = cu_seqlens[bi + 1] - 1; + for (int i = threadIdx.x; i < dims; i += blockDim.x) { + output[dims * bi + i] = input[dims * ti + i]; + } +} + +template +void invokeGetFeatureOfLastToken( + T* output, const T* input, const int* cu_seqlens, int dims, int batch_size, cudaStream_t stream) +{ + getFeatureOfLastToken<<>>(output, input, cu_seqlens, dims); +} + +template void invokeGetFeatureOfLastToken(half*, const half*, const int*, int, int, cudaStream_t); +template void invokeGetFeatureOfLastToken(float*, const float*, const int*, int, int, cudaStream_t); + +template +struct BatchedCopyParam { + Array src_ptr; + Array dst_ptr; + Array size; + int count; +}; + +template +__global__ void batchedCopy(BatchedCopyParam param) +{ + const int ti = threadIdx.x + blockIdx.x * blockDim.x; + const int bi = ti / kThrPerCpy; + if (bi >= param.count) { + return; + } + const T* __restrict__ src = param.src_ptr[bi]; + T* __restrict__ dst = param.dst_ptr[bi]; + int size = param.size[bi]; + for (int i = ti % kThrPerCpy; i < size; i += kThrPerCpy) { + dst[i] = src[i]; + } +} + +// MSVC does not like CUDA kernel launch inside nested lambdas +template +struct BatchedCopyLauncher { + int max_size; + int count; + const P* params; + cudaStream_t st; + + template + void operator()(std::integral_constant) const + { + constexpr int threads = 128; + constexpr int items_per_block = threads / S; + const int blocks = (count + items_per_block - 1) / items_per_block; + batchedCopy<<>>(*params); + } +}; + +void invokeBatchedCopy(void** src_ptr, void** dst_ptr, int* size, int count, cudaStream_t st) +{ + dispatch( + std::integer_sequence{}, + [&](auto C) { return count <= C; }, + [&](auto C) { + using T = uint32_t; + BatchedCopyParam params{}; + // TODO: on CUDA 12.1 and sm_70+ this can be 32K + static_assert(sizeof(params) <= 4096); + for (int c = 0; c < count; c += C) { + const int bsz = std::min(count - c, C); + params.count = bsz; + for (int i = 0; i < bsz; ++i) { + params.src_ptr[i] = (T*)src_ptr[c + i]; + params.dst_ptr[i] = (T*)dst_ptr[c + i]; + FT_CHECK(size[c + i] % sizeof(T) == 0); + params.size[i] = size[c 
+ i] / sizeof(T); + } + const int max_size = *std::max_element(params.size.begin(), params.size.end()); + dispatch( + std::integer_sequence{}, + [&](auto S) { return max_size <= S; }, + BatchedCopyLauncher>{max_size, count, ¶ms, st}); + } + }); +} + #define VERSION_SWITCH(VERSION, CONST_NAME, ...) \ [&] { \ if (VERSION == 2) { \ diff --git a/src/turbomind/models/llama/llama_kernels.h b/src/turbomind/models/llama/llama_kernels.h index 06cb24e042..5dedfe71c0 100644 --- a/src/turbomind/models/llama/llama_kernels.h +++ b/src/turbomind/models/llama/llama_kernels.h @@ -34,21 +34,22 @@ void invokeCreateCausalMasks( T* mask, const int* q_lens, const int* k_lens, int max_q_len, int max_k_len, int batch_size, cudaStream_t stream); template -void invokeExtendKVCache(T** k_dst, - T** v_dst, - size_t layer_offset, +void invokeExtendKVCache(void** k_dst_ptrs, + void** v_dst_ptrs, const T* k_src, const T* v_src, - int batch_size, + const int* cu_block_counts, const int* query_length, + const int* context_length, + int batch_size, + int block_length, + size_t dst_layer_offset, int max_q_len, - const int* history_length, - int max_seq_len, - int size_per_head, - int local_head_num, - cudaStream_t stream, + int head_dim, + int head_num, int quant, - const float* kv_scale); + const float* kv_scale, + cudaStream_t stream); template void invokeTransposeKVCache(T* key_cache_trans, @@ -76,6 +77,48 @@ void invokeGatherOutput(int* output_ids, int batch_size, cudaStream_t stream); +void invokeUpdateOutput(int** request_output_ids_ptrs, + int** request_seqlen_ptrs, + const int* output_ids, + const int* sequence_lengths, + const int* request_output_ids_lens, + int max_session_len, + bool token_generated, + int batch_size, + cudaStream_t stream); + +// [aaa, bbbb, cc, ddd] -> [aaabbbbccddd] +void invokeCompactOutputIds(int* cu_output_ids, + const int* output_ids, + const int* sequence_lengths, + int max_session_len, + bool token_generated, + int batch_size, + cudaStream_t stream); + +void invokeIndexedCopy(void** h_src_ptr, + void** h_dst_ptr, + const int* h_elem_sz, + const int* h_src_idx, + const int* h_dst_idx, + int count, + int n_copys, + cudaStream_t st); + +void invokeBatchedCopy(void** src_ptr, void** dst_ptr, int* size, int count, cudaStream_t st); + +// ABCDe ABCDe e +// ABCDEFGHIJk ABCDEFGHIJk +// ABCDEFGHi -> ABCDEFGHi i +// ABCDEFGh ABCDEFGh h +// ABCd ABCd d +void invokePadLastTokenIds( + int* token_ids, const int* context_length, int max_context_len, int batch_size, cudaStream_t stream); + +template +void invokeGetFeatureOfLastToken( + T* output, const T* input, const int* cu_seqlens, int dims, int batch_size, cudaStream_t stream); + void invokeMyCopyInt(int* dst, const int* src, size_t count, cudaStream_t st); template diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h index 8f8c96837b..a8f5a6a1a8 100644 --- a/src/turbomind/models/llama/llama_params.h +++ b/src/turbomind/models/llama/llama_params.h @@ -5,11 +5,29 @@ namespace turbomind { struct LlamaAttentionParams { - int rotray_embedding_dim; + int rotary_embedding_dim; float rotary_embedding_base; int max_position_embeddings; - bool use_dynamic_ntk; - bool use_logn_attn; + float rope_scaling_factor; + // bool use_dynamic_ntk; + bool use_logn_attn; +}; + +struct EngineParams { + // batch params + int max_batch_size; + int session_len; + int step_length; + + // cache params + float cache_max_block_count; + int cache_chunk_size; + + // chunking params + int max_context_token_num; + int 
num_tokens_per_iter; + int extra_tokens_per_iter; + int max_prefill_iters; }; } // namespace turbomind diff --git a/src/turbomind/models/llama/llama_utils.cu b/src/turbomind/models/llama/llama_utils.cu index 7050d2d13f..93de6afd58 100644 --- a/src/turbomind/models/llama/llama_utils.cu +++ b/src/turbomind/models/llama/llama_utils.cu @@ -157,4 +157,13 @@ bool isDebug() return is_debug; } +int64_t& gSequenceIds(int batch_idx) +{ + thread_local std::vector ids{}; + if (batch_idx >= ids.size()) { + ids.resize(batch_idx + 1, -1); + } + return ids.at(batch_idx); +} + } // namespace turbomind diff --git a/src/turbomind/models/llama/llama_utils.h b/src/turbomind/models/llama/llama_utils.h index 05c10be80b..acfe5054ca 100644 --- a/src/turbomind/models/llama/llama_utils.h +++ b/src/turbomind/models/llama/llama_utils.h @@ -2,6 +2,7 @@ #pragma once #include "src/turbomind/utils/Tensor.h" +#include "src/turbomind/utils/nvtx_utils.h" #include #include #include @@ -66,4 +67,18 @@ size_t curandStateGetSize(); bool isDebug(); +struct NvtxScope { + explicit NvtxScope(const std::string& name) + { + PUSH_RANGE(name.c_str()); + } + + ~NvtxScope() + { + POP_RANGE; + } +}; + +int64_t& gSequenceIds(int batch_idx); + } // namespace turbomind diff --git a/src/turbomind/models/llama/test_cache_manager.cc b/src/turbomind/models/llama/test_cache_manager.cc new file mode 100644 index 0000000000..16629565f1 --- /dev/null +++ b/src/turbomind/models/llama/test_cache_manager.cc @@ -0,0 +1,116 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "BlockManager.h" +#include "SequenceManager.h" + +#include "src/turbomind/utils/allocator.h" + +#include "src/turbomind/utils/debug_utils.h" +#include +#include + +using namespace turbomind; + +std::ostream& operator<<(std::ostream& os, const Block* b) +{ + os << "(" << b->id << "," << b->timestamp << ")"; + return os; +} + +TEST_CASE("BlockManager") +{ + Allocator allocator(0); + + BlockManager m(1024, 32, 8, &allocator); + REQUIRE(m.max_block_count() == 32); + REQUIRE(m.free_count() == 32); + + auto blocks1 = m.Allocate(10); + + dbg(blocks1); + + REQUIRE(blocks1.size() == 10); + REQUIRE(m.active_count() == blocks1.size()); + REQUIRE(m.free_count() == 22); + + auto blocks2 = m.Allocate(6); + REQUIRE(blocks2.size() == 6); + REQUIRE(m.active_count() == blocks1.size() + blocks2.size()); + REQUIRE(m.free_count() == 16); + + auto blocks3 = m.Allocate(16); + REQUIRE(blocks3.size() == 16); + REQUIRE(m.active_count() == 32); + REQUIRE(m.free_count() == 0); + + std::copy(blocks3.begin(), blocks3.end(), std::back_inserter(blocks1)); + std::copy(blocks2.begin(), blocks2.end(), std::back_inserter(blocks1)); + + m.Touch(blocks1); + + REQUIRE(m.Unlock(blocks1) == 32); + REQUIRE(m.active_count() == 0); + REQUIRE(m.free_count() == 0); + REQUIRE(m.cached_count() == 32); + + m.Evict(16); + REQUIRE(m.active_count() == 0); + REQUIRE(m.free_count() == 16); + REQUIRE(m.cached_count() == 16); + + auto blocks4 = m.Allocate(14); + REQUIRE(m.active_count() == 14); + REQUIRE(m.free_count() == 2); + REQUIRE(m.cached_count() == 16); +} + +TEST_CASE("SequenceManager basic test") +{ + Allocator allocator(0); + + SequenceManager manager(32, 32, 128, 128, 20, 4, 16, 0, &allocator); + + REQUIRE(manager.max_block_count() == 20); + REQUIRE(manager.Contains(1) == false); + + auto s1 = manager.Create(1); + dbg(*s1); + REQUIRE(manager.Contains(1) == true); + + manager.Erase(1); + REQUIRE(manager.Contains(1) == false); + + s1 = manager.Create(1); + REQUIRE(manager.Contains(1) == true); + + auto outcome = 
manager.Materialize({s1}, {128}, {100}, 1); + dbg(s1->blocks); + REQUIRE(s1->blocks.size() == 2); + + auto s2 = manager.Create(2); + REQUIRE(manager.Contains(2)); + + outcome = manager.Materialize({s1, s2}, {128, 2559}, {2, 1}, 1); + dbg(outcome); + REQUIRE(outcome.allocation == 20); + REQUIRE(outcome.swap_in == 1); + REQUIRE(outcome.swap_out == 1); + + auto s3 = manager.Create(3); + outcome = manager.Materialize({s1, s2, s3}, {127, 2559, 255}, {1, 100, 2}, 1); + dbg(outcome); +} + +TEST_CASE("SequenceManager functional test") +{ + Allocator allocator(0); + SequenceManager manager(32, 32, 128, 128, 20, 4, 16, 0, &allocator); + + auto seq = manager.Create(1); + for (int i = 0; i < 1024; ++i) { + auto outcome = manager.Materialize({seq}, {i}, {0}, 1); + if (outcome.allocation) { + dbg(i, outcome); + } + } +} diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc new file mode 100644 index 0000000000..d9ae5d0be6 --- /dev/null +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -0,0 +1,630 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/layers/attention_layers/GptContextAttentionLayer.cc + +#include "src/turbomind/models/llama/unified_attention_layer.h" +#include "src/turbomind/kernels/bert_preprocess_kernels.h" +#include "src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h" +#include "src/turbomind/kernels/decoder_multihead_attention/kv_cache.h" +#include "src/turbomind/kernels/unfused_attention_kernels.h" +#include "src/turbomind/macro.h" +#include "src/turbomind/models/llama/LlamaNcclGuard.h" +#include "src/turbomind/models/llama/llama_kernels.h" +#include "src/turbomind/models/llama/llama_utils.h" +#include "src/turbomind/utils/Tensor.h" +#include "src/turbomind/utils/cuda_utils.h" +#include "src/turbomind/utils/debug_utils.h" +#include "src/turbomind/utils/logger.h" + +namespace turbomind { + +template +// void UnifiedAttentionLayer::allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t +// max_k_len) +void UnifiedAttentionLayer::allocateBuffer(size_t num_token, + size_t pf_batch_size, + size_t pf_max_q_len, + size_t pf_max_k_len, + size_t dc_batch_size, + size_t dc_max_split_k) +{ + TM_LOG_DEBUG(__PRETTY_FUNCTION__); + + const int local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_; + + // no padding + qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * local_q_kv_head_num * size_per_head_, false); + + // qkv_buf_3_ padding is removed + qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * num_token * local_head_num_ * size_per_head_, false); + + if (pf_batch_size) { + [&](size_t bsz, size_t max_q, size_t max_k) { + // padding is rebuilt for q/k/v_buf_2_ + // [qH + 2kvH, B, S, D] + q_buf_2_ = (T*)allocator_->reMalloc( + q_buf_2_, sizeof(T) * local_q_kv_head_num * bsz * max_q * size_per_head_, false); + k_buf_2_ = q_buf_2_ + local_head_num_ * bsz * max_q * size_per_head_; + v_buf_2_ = k_buf_2_ + local_kv_head_num_ * bsz * max_q * size_per_head_; + + if (use_fmha_) { + FlashAttentionOp flash_attention(bsz, local_head_num_, max_k, max_q, size_per_head_); + if (flash_attention.get_workspace_size() > 0) { + qk_buf_float_ = + (float*)allocator_->reMalloc(qk_buf_float_, flash_attention.get_workspace_size(), false); + } + } + else { + // kv heads are repeated for unfused attention + k_cache_buf_ = (T*)allocator_->reMalloc( + k_cache_buf_, 2 * sizeof(T) * bsz * local_head_num_ * max_k * size_per_head_, false); + v_cache_buf_ = k_cache_buf_ + bsz * local_head_num_ * max_k * size_per_head_; + + qk_buf_ = (T*)allocator_->reMalloc(qk_buf_, sizeof(T) * bsz * local_head_num_ * max_q * max_k, false); + + // qkv_buf_2_ has padding + qkv_buf_2_ = (T*)allocator_->reMalloc( + qkv_buf_2_, sizeof(T) * bsz * max_q * local_head_num_ * size_per_head_, false); + } + }(pf_batch_size, pf_max_q_len, pf_max_k_len); + } + + if (dc_batch_size) { + dc_workspace_ = (float*)allocator_->reMalloc(dc_workspace_, + sizeof(float) * dc_batch_size * local_head_num_ * dc_max_split_k + * (size_per_head_ + 2), + false); + } + + is_allocate_buffer_ = true; +} + +template +void UnifiedAttentionLayer::freeBuffer() +{ + if (is_allocate_buffer_) { + TM_LOG_DEBUG(__PRETTY_FUNCTION__); + + allocator_->free((void**)(&qkv_buf_)); + allocator_->free((void**)(&q_buf_2_)); + allocator_->free((void**)(&qkv_buf_3_)); + + allocator_->free((void**)&qk_buf_float_); + allocator_->free((void**)(&k_cache_buf_)); + allocator_->free((void**)(&qk_buf_)); + allocator_->free((void**)(&qkv_buf_2_)); + + 
allocator_->free((void**)&dc_workspace_); + + is_allocate_buffer_ = false; + } +} + +template +inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMap* inputs, const WeightType* weights) +{ + TM_LOG_DEBUG(__PRETTY_FUNCTION__); + + /** + * input_tensors: + * \param input_query [token_num, hidden_dim] + * \param attention_mask [batch_size, 1, max_q_len, max_kv_len] + * \param padding_offset [token_num], int + * \param input_lengths [batch_size], int + * \param history_lengths [batch_size], int + * \param context_lengths [batch_size], int + * \param cu_seqlens [batch_size+1], int + * \param cu_block_counts [batch_size+1], int + * \param max_seq_len [1], int on cpu + * \param is_final_layer [1], bool on cpu + * \param layer_id [1], int on cpu + * + * output_tensors: + * \param hidden_features [token_num, hidden_dim] + * \param key_cache [batch_size], uint64 + * \param value_cache [batch_size], uint64 + */ + + ///////////////////////////////////////////// + /// parse inputs + const int num_token = inputs->at("input_query").shape[0]; + const int layer_id = inputs->getVal("layer_id"); + const int session_len = inputs->getVal("session_len"); + + int pf_batch_size = 0; + int pf_max_q_len = 0; + int pf_max_k_len = 0; + T* attention_mask{}; + if (inputs->isExist("attention_mask")) { + pf_batch_size = inputs->at("attention_mask").shape[0]; + pf_max_q_len = inputs->at("attention_mask").shape[2]; + pf_max_k_len = inputs->at("attention_mask").shape[3]; + attention_mask = inputs->getPtr("attention_mask"); + } + + const int dc_batch_size = inputs->getVal("dc_batch_size"); + const int dc_sum_seq_len = inputs->getVal("dc_sum_seq_len"); + const int dc_max_seq_len = inputs->getVal("dc_max_seq_len"); + + T* attention_input = inputs->getPtr("input_query"); + int* input_length = inputs->getPtr("input_lengths"); + int* context_length = inputs->getPtr("context_lengths"); + bool* is_finished = inputs->getPtr("finished"); + int* cu_block_count = inputs->getPtr("cu_block_counts"); + int* cu_seqlens = inputs->getPtr("cu_seqlens", nullptr); + int* padding_offset = inputs->getPtr("padding_offset", nullptr); + float* rope_theta = inputs->getPtr("rope_theta", nullptr); + + auto k_cache_ptrs = outputs->getPtr("key_cache"); + auto v_cache_ptrs = outputs->getPtr("value_cache"); + auto tmp_k_ptrs = outputs->getPtr("tmp_k"); + auto tmp_v_ptrs = outputs->getPtr("tmp_v"); + + T* attention_out = outputs->getPtr("hidden_features"); + + ///////////////////////////////////////////// + /// allocate buffers + allocateBuffer(num_token, // + pf_batch_size, + pf_max_q_len, + pf_max_k_len, + dc_batch_size, + kDecodeMaxSplits); + + // [2, L, H, s, D] + const size_t layer_offset = layer_id * local_kv_head_num_ * kv_cache_block_len_ * size_per_head_; + + ////////////////////////////////////////////// + /// qkv gemm + // [token_num, hidden_dim] -> [token_num, 3, local_hidden_dim] + linear_.forward(qkv_buf_, attention_input, num_token, weights->qkv); + + if (pf_batch_size) { + const int offset = dc_batch_size; + const int pf_num_token = num_token - offset; + prefill(qkv_buf_3_ + offset * weights->output.input_dims, + qkv_buf_ + offset * weights->qkv.output_dims, + k_cache_ptrs, + v_cache_ptrs, + attention_mask, + cu_seqlens, + padding_offset, + tmp_k_ptrs + offset, + tmp_v_ptrs + offset, + input_length + offset, + context_length + offset, + cu_block_count + offset, + rope_theta + offset, + pf_batch_size, + pf_num_token, + layer_offset, + pf_max_q_len, + pf_max_k_len, + session_len, + weights); + } + + if (dc_batch_size) { 
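+        // Decode (generation) path: a single new token per sequence. The
+        // decoder attention kernel reads and writes the paged KV blocks
+        // directly (via cu_block_count / k_cache_ptrs / v_cache_ptrs) and
+        // accumulates split-k partial results in dc_workspace_.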
+ decode(qkv_buf_3_, + qkv_buf_, + k_cache_ptrs, + v_cache_ptrs, + cu_block_count, + context_length, + is_finished, + rope_theta, + layer_offset, + dc_batch_size, + dc_sum_seq_len, + dc_max_seq_len, + kDecodeMaxSplits, + weights); + } + + ////////////////////////////////////////////// + /// output gemm -> + linear_.forward(attention_out, qkv_buf_3_, num_token, weights->output); + + if (tensor_para_.world_size_ > 1) { + NcclGuard nccl_guard(tensor_para_, stream_); + ftNcclAllReduceSum(attention_out, attention_out, num_token * hidden_units_, tensor_para_, stream_); + sync_check_cuda_error(); + } + + if (is_free_buffer_after_forward_ == true) { + freeBuffer(); + } + sync_check_cuda_error(); +} + +template +void UnifiedAttentionLayer::prefill(T* output, + const T* qkv, + void** k_cache_ptrs, + void** v_cache_ptrs, + const T* attention_mask, + const int* cu_seqlens, + const int* padding_offset, + T** tmp_k_ptrs, + T** tmp_v_ptrs, + const int* input_length, + const int* context_length, + const int* cu_block_count, + const float* rope_theta, + int pf_batch_size, + int pf_num_token, + size_t layer_offset, + int pf_max_q_len, + int pf_max_k_len, + int pf_session_len, + const WeightType* weights) +{ + ////////////////////////////////////////////// + /// transpose qkv & apply rotary embedding & rebuild padding + /// qkv [B, s, H + 2kvH, D] -> (q [B, H, s, D], k [B, kvH, s, D], v [B, kvH, s, D]) + invokeAddFusedQKVBiasTranspose(q_buf_2_, + k_buf_2_, + v_buf_2_, + (T*)qkv, + weights->qkv.bias, + padding_offset, // padding_offset, + context_length, // used for applying rotary embedding + input_length, + rope_theta, + pf_batch_size, + pf_max_q_len, // seq_len + pf_num_token, + local_head_num_, + local_kv_head_num_, + size_per_head_, + params_.rotary_embedding_dim, + params_.rotary_embedding_base, + params_.max_position_embeddings, + false, // params_.use_dynamic_ntk, + params_.use_logn_attn, + stream_); + sync_check_cuda_error(); + + ////////////////////////////////////////////////////////// + /// insert the k/v computed from inputs into k/v cache + /// transpose kv -> kv cache + // put k/v_buf from shape [B, kvH, s, D] to + // k_buf_2 [B, kvH, s, D] -> key_cache [B, kvH, S[t:t+s], D/x, x] + // v_buf_2 [B, kvH, s, D] -> val_cache [B, kvH, S[t:t+s], D/x, x] + invokeExtendKVCache(k_cache_ptrs, + v_cache_ptrs, + k_buf_2_, + v_buf_2_, + cu_block_count, + input_length, + context_length, + pf_batch_size, + kv_cache_block_len_, + layer_offset, + pf_max_q_len, + size_per_head_, + local_kv_head_num_, + quant_policy_, + weights->past_kv_scale.data(), + stream_); + sync_check_cuda_error(); + + const int kv_cache_elem_bits = quant_policy_ & QuantPolicy::kCacheKVInt8 ? 
8 : sizeof(T) * 8; + + FT_CHECK(weights->past_kv_scale.size() == 4); + ConvertKvCacheBlocksToLinear2((const void**)k_cache_ptrs, + (const void**)v_cache_ptrs, + (T**)tmp_k_ptrs, + (T**)tmp_v_ptrs, + cu_block_count, + context_length, + layer_offset, + kv_cache_block_len_, + pf_session_len, + local_kv_head_num_, + size_per_head_, + pf_batch_size, + quant_policy_, + weights->past_kv_scale.data(), + stream_); + sync_check_cuda_error(); + + if (use_fmha_) { + fusedMultiHeadAttention(output, + q_buf_2_, + tmp_k_ptrs, + tmp_v_ptrs, + 0, + (T*)attention_mask, + (int*)cu_seqlens, + (int*)context_length, + pf_batch_size, + pf_max_q_len, + pf_max_k_len, + pf_session_len); + } + else { + unfusedMultiHeadAttention(output, + q_buf_2_, + tmp_k_ptrs, + tmp_v_ptrs, + 0, + attention_mask, + padding_offset, + context_length, + pf_batch_size, + pf_num_token, + pf_max_q_len, + pf_max_k_len, + pf_session_len, + quant_policy_, + weights->past_kv_scale.data()); + } +} + +template +void UnifiedAttentionLayer::decode(T* output, + const T* qkv, + void** k_cache_ptrs, + void** v_cache_ptrs, + const int* cu_block_count, + const int* context_length, + const bool* is_finished, + const float* rope_theta, + size_t layer_offset, + int batch_size, + int dc_sum_seq_len, + int dc_max_seq_len, + int max_split_k, + const WeightType* weights) +{ + DecoderMultiHeadAttentionParams params{}; + + params.out = output; + params.q = (T*)qkv; + params.k = params.q + local_head_num_ * size_per_head_; + params.v = params.k + local_kv_head_num_ * size_per_head_; + params.stride = (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_; + + params.q_bias = weights->qkv.bias; + params.k_bias = params.q_bias + local_head_num_ * size_per_head_; + params.v_bias = params.k_bias + local_kv_head_num_ * size_per_head_; + + params.batch_size = batch_size; + params.cu_block_cnts = (int*)cu_block_count; + + params.k_cache_block_ptrs = (void**)k_cache_ptrs; + params.v_cache_block_ptrs = (void**)v_cache_ptrs; + params.kv_cache_block_size = kv_cache_block_len_; + + params.finished = is_finished; + params.context_length = context_length; + params.rope_theta = rope_theta; + + params.layer_offset = layer_offset; + + params.num_heads = local_head_num_; + params.num_kv_heads = local_kv_head_num_; + params.size_per_head = size_per_head_; + params.inv_sqrt_dh = 1.f / std::sqrt((float)params.size_per_head); + + params.rotary_embedding_dim = size_per_head_; + params.rotary_embedding_base = params_.rotary_embedding_base; + params.max_position_embeddings = params_.max_position_embeddings; + // params.use_dynamic_ntk = params_.use_dynamic_ntk; + params.use_logn_attn = params_.use_logn_attn; + + params.partial_O = dc_workspace_; + params.partial_M = params.partial_O + batch_size * local_head_num_ * max_split_k * size_per_head_; + params.partial_L = params.partial_M + batch_size * local_head_num_ * max_split_k; + + const float avg_batch_size = dc_max_seq_len ? 
(float)dc_sum_seq_len / dc_max_seq_len : 1; + FT_CHECK(avg_batch_size >= 1.f); + + max_split_k = std::max(1, (int)std::ceil(max_split_k / avg_batch_size)); + + params.max_split_k = max_split_k; + params.max_seq_len = dc_max_seq_len; + + params.arch = arch_; + params.stream = stream_; + + params.quant_policy = quant_policy_; + FT_CHECK(std::size(weights->past_kv_scale) == std::size(params.kv_quant_params)); + std::copy(weights->past_kv_scale.begin(), weights->past_kv_scale.end(), std::begin(params.kv_quant_params)); + + { + NvtxScope scope("decoder_multihead_attention"); + DispatchDecoderMultiheadAttention(params); + } +} + +template +void UnifiedAttentionLayer::fusedMultiHeadAttention(T* output, + const T* query, + T** key_cache_ptrs, + T** val_cache_ptrs, + size_t cache_layer_offset, + T* attention_mask, + int* cu_seqlens, + int* context_lengths, + int batch_size, + int max_q_len, + int max_k_len, + int max_seq_len) +{ + ////////////////////////////////////////////// + // flash attention + // flash attention 2 only support half inputs + using AttentionOp = FlashAttentionOp; + using Layout = typename AttentionOp::AttentionLayout; + Layout layout_q{ + int(local_head_num_ * max_q_len * size_per_head_), int(size_per_head_), int(max_q_len * size_per_head_)}; + Layout layout_k{int(local_head_num_ * max_seq_len * size_per_head_), + int(size_per_head_), + int(max_seq_len * size_per_head_), + false, + cache_layer_offset, + key_cache_ptrs}; + Layout layout_v{int(local_head_num_ * max_seq_len * size_per_head_), + int(size_per_head_), + int(max_seq_len * size_per_head_), + false, + cache_layer_offset, + val_cache_ptrs}; + Layout layout_o{ + int(local_head_num_ * max_q_len * size_per_head_), + int(local_head_num_ * size_per_head_), + int(size_per_head_), + true, + }; + size_t group_size = size_t(local_head_num_ / local_kv_head_num_); + AttentionOp flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_); + typename AttentionOp::Params attn_params{output, + (T*)query, + k_cache_buf_, + v_cache_buf_, + attention_mask, + qk_buf_float_, + cu_seqlens, + nullptr, + nullptr, + context_lengths, + group_size, + layout_q, + layout_k, + layout_v, + layout_o}; + + // + flash_attention(attn_params, stream_); +} + +template +void UnifiedAttentionLayer::unfusedMultiHeadAttention(T* output, + const T* query, + T** key_cache_ptrs, + T** val_cache_ptrs, + size_t cache_layer_offset, + const T* attention_mask, + const int* padding_offset, + const int* context_length, + int batch_size, + int num_token, + int max_q_len, + int max_k_len, + int max_seq_len, + int quant, + const float* kv_scale) +{ + // key_cache [B, kvH, S[:t+s], D/x, x] -> [B, qH, t+s, D] + // val_cache [B, kvH, S[:t+s], D/x, x] -> [B, qH, t+s, D] + invokeTransposeKVCache(k_cache_buf_, + v_cache_buf_, + (const T**)key_cache_ptrs, + (const T**)val_cache_ptrs, + cache_layer_offset, + batch_size, + context_length, // history_len + input_len = context_len + max_k_len, + max_seq_len, + size_per_head_, + local_head_num_, + head_n_rep_, + stream_, + 0, // dequant handled in block->linear conversion + kv_scale); + sync_check_cuda_error(); + + const T qk_scale = static_cast(1.f / sqrtf(size_per_head_ * 1.f)); + + ////////////////////////////////////////////// + /// Q*K batch gemm + /// -> [B, H, s, t + s] + cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_T, + CUBLAS_OP_N, + max_k_len, // m + max_q_len, // n + size_per_head_, // k + k_cache_buf_, // A + size_per_head_, // lda + max_k_len * size_per_head_, // strideA + query, // B + 
size_per_head_, // ldb + max_q_len * size_per_head_, // strideB + qk_buf_, // C + max_k_len, // ldc + max_q_len * max_k_len, // strideC + batch_size * local_head_num_); // batchCount + + ////////////////////////////////////////////// + /// ! masked softmax (kernel asserts k_length <= 4096) + MaskedSoftmaxParam param{}; + param.attention_score = qk_buf_; + param.qk = qk_buf_; + param.attention_mask = attention_mask; + param.batch_size = batch_size; + param.q_length = max_q_len; + param.k_length = max_k_len; + param.num_heads = local_head_num_; + param.qk_scale = qk_scale; + param.linear_bias_slopes = nullptr; + invokeMaskedSoftmax(param, stream_); + sync_check_cuda_error(); + + ////////////////////////////////////////////// + /// softmax(QK)*V batch gemm + // -> [B, H, S, D] + cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_N, + CUBLAS_OP_N, + size_per_head_, // m + max_q_len, // n + max_k_len, // k + v_cache_buf_, // A + size_per_head_, // lda + max_k_len * size_per_head_, // strideA, + qk_buf_, // B + max_k_len, // ldb + max_k_len * max_q_len, // strideB + qkv_buf_2_, // C + size_per_head_, // ldc, + max_q_len * size_per_head_, // strideC + batch_size * local_head_num_); // batchCount + + ////////////////////////////////////////////// + /// transpose -> + invokeTransposeAttentionOutRemovePadding(qkv_buf_2_, + output, + num_token, + batch_size, + max_q_len, + local_head_num_, + size_per_head_, + padding_offset, + nullptr, + 0, + stream_); + sync_check_cuda_error(); +} + +template class UnifiedAttentionLayer; +template class UnifiedAttentionLayer; + +} // namespace turbomind diff --git a/src/turbomind/models/llama/unified_attention_layer.h b/src/turbomind/models/llama/unified_attention_layer.h new file mode 100644 index 0000000000..932b4ecf5c --- /dev/null +++ b/src/turbomind/models/llama/unified_attention_layer.h @@ -0,0 +1,186 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/layers/attention_layers/GptContextAttentionLayer.h + +#pragma once + +#include "src/turbomind/models/llama/LlamaDenseWeight.h" +#include "src/turbomind/models/llama/LlamaLinear.h" +#include "src/turbomind/models/llama/llama_params.h" +#include "src/turbomind/utils/Tensor.h" +#include "src/turbomind/utils/cuda_utils.h" +#include "src/turbomind/utils/nccl_utils.h" + +namespace turbomind { + +template +class UnifiedAttentionLayer { +public: + using WeightType = LlamaAttentionWeight; + static constexpr int kDecodeMaxSplits = 16; + + void freeBuffer(); + void allocateBuffer(size_t num_token, + size_t pf_batch_size, + size_t pf_max_q_len, + size_t pf_max_k_len, + size_t dc_batch_size, + size_t dc_max_split_k); + + UnifiedAttentionLayer(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + LlamaAttentionParams attn_params, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool use_fmha, + int cache_block_seq_len, + int quant_policy): + head_num_(head_num), + size_per_head_(size_per_head), + hidden_units_(head_num * size_per_head), + local_head_num_(head_num / tensor_para.world_size_), + local_kv_head_num_(kv_head_num / tensor_para.world_size_), + head_n_rep_(head_num / kv_head_num), + params_(attn_params), + tensor_para_(tensor_para), + stream_(stream), + cublas_wrapper_(cublas_wrapper), + linear_(cublas_wrapper, stream), + allocator_(allocator), + kv_cache_block_len_(cache_block_seq_len), + is_free_buffer_after_forward_(is_free_buffer_after_forward), + use_fmha_(use_fmha), + quant_policy_(quant_policy) + { + FT_CHECK(head_num % kv_head_num == 0); + arch_ = getSMVersion(); + } + + void forward(TensorMap* outputs, const TensorMap* inputs, const LlamaAttentionWeight* weights); + + void prefill(T* output, + const T* qkv, + void** k_cache_ptrs, + void** v_cache_ptrs, + const T* attention_mask, + const int* cu_seqlens, + const int* padding_offset, + T** tmp_k_ptrs, + T** tmp_v_ptrs, + const int* input_length, + const int* context_length, + const int* cu_block_count, + const float* rope_theta, + int pf_batch_size, + int pf_num_token, + size_t layer_offset, + int pf_max_q_len, + int pf_max_k_len, + int pf_session_len, + const WeightType* weights); + + void decode(T* output, + const T* qkv, + void** k_cache_ptrs, + void** v_cache_ptrs, + const int* cu_block_count, + const int* context_length, + const bool* is_finished, + const float* rope_theta, + size_t layer_offset, + int batch_size, + int dc_sum_seq_len, + int dc_max_seq_len, + int max_split_k, + const WeightType* weights); + + void fusedMultiHeadAttention(T* output, + const T* query, + T** key_cache_ptrs, + T** val_cache_ptrs, + size_t cache_layer_offset, + T* attention_mask, + int* cu_seqlens, + int* context_lengths, + int batch_size, + int max_q_len, + int max_k_len, + int max_seq_len); + + void unfusedMultiHeadAttention(T* output, + const T* query, + T** key_cache_ptrs, + T** val_cache_ptrs, + size_t cache_layer_offset, + const T* attention_mask, + const int* padding_offset, + const int* context_length, + int batch_size, + int num_token, + int max_q_len, + int max_k_len, + int max_seq_len, + int quant_policy, + const float* kv_scale); + +private: + const size_t head_num_; + const size_t size_per_head_; + const size_t hidden_units_; + const size_t local_kv_head_num_; + const size_t local_head_num_; + const size_t head_n_rep_; + const size_t 
kv_cache_block_len_; + const bool is_free_buffer_after_forward_; + + const LlamaAttentionParams params_; + + const bool use_fmha_; + const int quant_policy_; + + NcclParam tensor_para_; + + cudaStream_t stream_; + IAllocator* allocator_; + cublasMMWrapper* cublas_wrapper_; + LlamaLinear linear_; + + int arch_{}; + + T* qkv_buf_{}; + T* q_buf_2_{}; + T* k_buf_2_{}; + T* v_buf_2_{}; + T* k_cache_buf_{}; + T* v_cache_buf_{}; + T* qk_buf_{}; + float* qk_buf_float_{}; + T* qkv_buf_2_{}; + T* qkv_buf_3_{}; + float* dc_workspace_{}; + + bool is_allocate_buffer_ = false; +}; + +} // namespace turbomind diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc new file mode 100644 index 0000000000..20974eeea9 --- /dev/null +++ b/src/turbomind/models/llama/unified_decoder.cc @@ -0,0 +1,257 @@ + +#include "src/turbomind/models/llama/unified_decoder.h" +#include "src/turbomind/kernels/bert_preprocess_kernels.h" +#include "src/turbomind/kernels/gpt_kernels.h" +#include "src/turbomind/models/llama/llama_decoder_kernels.h" +#include "src/turbomind/models/llama/llama_kernels.h" +#include "src/turbomind/models/llama/unified_attention_layer.h" +#include "src/turbomind/utils/cuda_utils.h" + +namespace turbomind { + +template +void UnifiedDecoder::allocateBuffer(size_t num_token, size_t pf_batch_size, size_t pf_max_q_len, size_t pf_max_k_len) +{ + TM_LOG_DEBUG(__PRETTY_FUNCTION__); + + if (pf_batch_size) { + attention_mask_ = + (T*)allocator_->reMalloc(attention_mask_, sizeof(T) * pf_batch_size * pf_max_q_len * pf_max_k_len, false); + padding_offset_ = + (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * pf_batch_size * pf_max_q_len, false); + cu_seqlens_ = (int*)allocator_->reMalloc(cu_seqlens_, sizeof(int) * (pf_batch_size + 1), false); + } +} + +template +void UnifiedDecoder::freeBuffer() +{ + TM_LOG_DEBUG(__PRETTY_FUNCTION__); + + allocator_->free((void**)&padding_offset_); + allocator_->free((void**)&cu_seqlens_); + allocator_->free((void**)&attention_mask_); + allocator_->free((void**)&h_pinned_token_num_ptr_, true); +} + +template +void UnifiedDecoder::initialize(const LlamaAttentionParams& attn_params, + size_t kv_head_num, + bool use_fmha, + int cache_block_seq_len, + int quant_policy) +{ + h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true); + + attn_layer_ = new UnifiedAttentionLayer(head_num_, + kv_head_num, + size_per_head_, + attn_params, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_, + use_fmha, + cache_block_seq_len, + quant_policy); + + ffn_layer_ = new LlamaFfnLayer(head_num_, + size_per_head_, + inter_size_, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_); +} + +template +void UnifiedDecoder::forwardSelfAttn(T* attn_io, + TensorMap* _outputs, + const TensorMap* _inputs, + size_t token_num, + size_t pf_batch_size, + size_t pf_max_q_len, + size_t pf_max_k_len, + size_t dc_batch_size, + int layer_id, + const LlamaAttentionWeight* weight) +{ + TensorMap inputs(*_inputs); + inputs.insert("input_query", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, attn_io}); + inputs.insert("layer_id", {MEMORY_CPU, TYPE_INT32, {1}, &layer_id}); + if (pf_batch_size) { + inputs.insert("attention_mask", + {MEMORY_GPU, dtype_, {pf_batch_size, 1, pf_max_q_len, pf_max_k_len}, attention_mask_}); + const size_t pf_token_num = token_num - dc_batch_size; + inputs.insert("padding_offset", {MEMORY_GPU, TYPE_INT32, 
{pf_token_num}, padding_offset_}); + inputs.insert("cu_seqlens", {MEMORY_GPU, TYPE_INT32, {pf_batch_size + 1}, cu_seqlens_}); + } + + TensorMap outputs(*_outputs); + outputs.insert("hidden_features", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, attn_io}); + + attn_layer_->forward(&outputs, &inputs, weight); +} + +template +UnifiedDecoder::~UnifiedDecoder() +{ + delete attn_layer_; + delete ffn_layer_; + freeBuffer(); +} + +template +void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, const std::vector* weights) +{ + /** + * input tensors: + * \param decoder_input [num_token, hidden_units], float + * \param input_lengths [batch_size], int + * \param history_lengths [batch_size], int + * \param context_legnths [batch_size], int + * \param output_norm_weight [hidden_dims], float + * \param max_q_len [1], int on cpu + * \param max_kv_len [1], int on cpu + * \param max_seq_len [1], int on cpu + * + * output tensors: + * \param decoder_output [num_token, hidden_units], + * \param key_cache [num_layer, batch, local_head_num, size_per_head // x, max_seq_len, x] + * \param value_cache [num_layer, batch, local_head_num, max_seq_len, size_per_head] + * \param last_token_hidden_units [batch_size, hidden_units] + */ + + // Session sess{}; + + const size_t token_num = inputs->at("decoder_input").shape[0]; + + const int pf_max_q_len = inputs->getVal("pf_max_q_len"); + const int pf_max_k_len = inputs->getVal("pf_max_k_len"); + const int pf_batch_size = inputs->getVal("pf_batch_size"); + const int dc_batch_size = inputs->getVal("dc_batch_size"); + + const int* input_length = inputs->getPtr("input_lengths"); + const int* context_length = inputs->getPtr("context_lengths"); + + T* decoder_input_output = inputs->getPtr("decoder_input"); + T* decoder_output = outputs->getPtr("decoder_output"); + + T* last_token_hidden_units = outputs->getPtr("last_token_hidden_units"); + + allocateBuffer(token_num, pf_batch_size, pf_max_q_len, pf_max_k_len); + + const int pf_offset = dc_batch_size; + + if (pf_batch_size) { + FT_CHECK(padding_offset_); + + size_t tmp_token_num{}; + // `cu_seqlens` is exclusive sum of "input_lengths" + invokeGetPaddingOffsetAndCuSeqLens(h_pinned_token_num_ptr_, + &tmp_token_num, // updated token num + padding_offset_, + cu_seqlens_, + input_length + pf_offset, + pf_batch_size, + pf_max_q_len, + stream_); + sync_check_cuda_error(); + + FT_CHECK(tmp_token_num == token_num - dc_batch_size); + + invokeCreateCausalMasks(attention_mask_, + input_length + pf_offset, + context_length + pf_offset, + pf_max_q_len, + pf_max_k_len, + pf_batch_size, + stream_); + sync_check_cuda_error(); + } + + ///////////////////////////////////////////// + /// RMSNorm + invokeRootMeanSquareNorm(decoder_output, + decoder_input_output, + weights->at(0)->self_attn_norm_weights, + rmsnorm_eps_, + token_num, + hidden_units_, + stream_); + sync_check_cuda_error(); + + for (size_t layer = 0; layer < num_layer_; ++layer) { + ///////////////////////////////////////////// + /// self-attention + forwardSelfAttn(decoder_output, + outputs, + inputs, + token_num, + pf_batch_size, + pf_max_q_len, + pf_max_k_len, + dc_batch_size, + layer, + &weights->at(layer)->self_attn_weights); + + invokeFusedAddBiasResidualRMSNorm(decoder_input_output, + decoder_output, + weights->at(layer)->self_attn_weights.output.bias, + weights->at(layer)->ffn_norm_weights, + rmsnorm_eps_, + token_num, + hidden_units_, + stream_); + sync_check_cuda_error(); + + //////////////////////////////////////////// + /// feed-forward network + 
TensorMap ffn_inputs{{"ffn_input", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, decoder_output}}}; + TensorMap ffn_outputs{{"ffn_output", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, decoder_output}}}; + ffn_layer_->forward(&ffn_outputs, &ffn_inputs, &weights->at(layer)->ffn_weights); + + const bool is_last_layer = layer == num_layer_ - 1; + + auto scale_weight = !is_last_layer ? weights->at(layer + 1)->self_attn_norm_weights : + inputs->at("output_norm_weight").getPtr(); + invokeFusedAddBiasResidualRMSNorm(decoder_input_output, + decoder_output, + weights->at(layer)->ffn_weights.output.bias, + scale_weight, + rmsnorm_eps_, + token_num, + hidden_units_, + stream_); + sync_check_cuda_error(); + } + + if (dc_batch_size) { + check_cuda_error(cudaMemcpyAsync(last_token_hidden_units, + decoder_output, + sizeof(T) * dc_batch_size * hidden_units_, + cudaMemcpyDefault, + stream_)); + } + + if (pf_batch_size) { + invokeGetFeatureOfLastToken(last_token_hidden_units + pf_offset * hidden_units_, // + decoder_output + pf_offset * hidden_units_, + cu_seqlens_, + hidden_units_, + pf_batch_size, + stream_); + sync_check_cuda_error(); + } + + if (is_free_buffer_after_forward_) { + freeBuffer(); + } +} + +template class UnifiedDecoder; +template class UnifiedDecoder; + +} // namespace turbomind diff --git a/src/turbomind/models/llama/unified_decoder.h b/src/turbomind/models/llama/unified_decoder.h new file mode 100644 index 0000000000..daac2b4df6 --- /dev/null +++ b/src/turbomind/models/llama/unified_decoder.h @@ -0,0 +1,99 @@ +#pragma once + +#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h" +#include "src/turbomind/models/llama/LlamaFfnLayer.h" +#include "src/turbomind/models/llama/llama_params.h" +#include "src/turbomind/models/llama/unified_attention_layer.h" +#include "src/turbomind/utils/cublasMMWrapper.h" +#include "src/turbomind/utils/nccl_utils.h" + +namespace turbomind { + +template +class UnifiedDecoder { +protected: + void allocateBuffer(size_t num_token, size_t pfill_batch_size, size_t pfill_max_q_len, size_t pfill_max_k_len); + void freeBuffer(); + + void initialize(const LlamaAttentionParams& attn_params, + size_t kv_head_num, + bool use_fmha, + int cache_block_seq_len, + int quant_policy); + + cudaStream_t stream_; + cublasMMWrapper* cublas_wrapper_; + IAllocator* allocator_; + bool is_free_buffer_after_forward_{}; + + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + size_t hidden_units_; + float rmsnorm_eps_; + + NcclParam tensor_para_; + + T* attention_mask_{}; + int* padding_offset_{}; + int* cu_seqlens_{}; // cu for cumulative + + size_t* h_pinned_token_num_ptr_{}; + + UnifiedAttentionLayer* attn_layer_{}; + LlamaFfnLayer* ffn_layer_{}; + + const DataType dtype_; + + using WeightType = LlamaDecoderLayerWeight; + + void forwardSelfAttn(T* attn_io, + TensorMap* _outputs, + const TensorMap* _inputs, + size_t token_num, + size_t pf_batch_size, + size_t pf_max_q_len, + size_t pf_max_k_len, + size_t dc_batch_size, + int layer_id, + const LlamaAttentionWeight* weight); + +public: + UnifiedDecoder(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + const LlamaAttentionParams& attn_params, + float rmsnorm_eps, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool use_fmha, + int cache_block_seq_len, + int quant_policy): + stream_(stream), + cublas_wrapper_(cublas_wrapper), + 
allocator_(allocator), + is_free_buffer_after_forward_(is_free_buffer_after_forward), + head_num_(head_num), + size_per_head_(size_per_head), + inter_size_(inter_size), + hidden_units_(head_num * size_per_head), + num_layer_(num_layer), + rmsnorm_eps_(rmsnorm_eps), + tensor_para_(tensor_para), + dtype_(getTensorType()) + { + initialize(attn_params, kv_head_num, use_fmha, cache_block_seq_len, quant_policy); + } + + ~UnifiedDecoder(); + + void forward(TensorMap* outputs, const TensorMap* inputs, const std::vector* weights); +}; + +} // namespace turbomind diff --git a/src/turbomind/python/bind.cpp b/src/turbomind/python/bind.cpp index b55ed040af..46e8443a86 100644 --- a/src/turbomind/python/bind.cpp +++ b/src/turbomind/python/bind.cpp @@ -282,6 +282,27 @@ PYBIND11_MODULE(_turbomind, m) return new triton::Tensor(self->where, self->type, new_shape, self->data); }, "new_shape"_a) + .def( + "copy_from", + [](triton::Tensor* self, py::object obj) { + py::capsule cap = obj.attr("__dlpack__")(); + DLManagedTensor* dlmt = + static_cast(PyCapsule_GetPointer(cap.ptr(), kDlTensorCapsuleName)); + auto src = DLManagedTensorToTritonTensor(dlmt); + if (self->type == triton::TYPE_FP16 || self->type == triton::TYPE_FP32 + || self->type == triton::TYPE_INT32) { + auto num_element = + std::accumulate(src->shape.begin(), src->shape.end(), 1LL, std::multiplies()); + auto num_bytes = num_element * dlmt->dl_tensor.dtype.bits / 8; + ft::FT_CHECK(self->shape.size() == 1 && num_bytes == self->shape[0]); + cudaMemcpy( + const_cast(self->data), const_cast(src->data), num_bytes, cudaMemcpyDefault); + } + else { + ft::FT_CHECK(0); + } + }, + "tensor"_a) .def( "__dlpack__", [](triton::Tensor* self, long stream) { @@ -340,6 +361,7 @@ PYBIND11_MODULE(_turbomind, m) .def_static( "create_llama_model", [](std::string model_dir, + std::string config, size_t tensor_para_size, size_t pipeline_para_size, int enable_custom_all_reduce, @@ -354,18 +376,19 @@ PYBIND11_MODULE(_turbomind, m) }; if (data_type == "half" || data_type == "fp16" || data_type == "int4") { auto model = std::make_shared>( - tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir); + tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir, config); model->setFfiLock(gil_control); return model; } else { auto model = std::make_shared>( - tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir); + tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir, config); model->setFfiLock(gil_control); return model; } }, "model_dir"_a, + "config"_a = "", "tensor_para_size"_a = 1, "pipeline_para_size"_a = 1, "enable_custom_all_reduce"_a = 0, @@ -406,6 +429,15 @@ PYBIND11_MODULE(_turbomind, m) py::call_guard(), "device_id"_a, "rank"_a) + .def( + "get_params", + [](AbstractTransformerModel* model, int deviceId, int rank) { + TensorMap output = model->getParams(deviceId, rank); + return output; + }, + py::call_guard(), + "device_id"_a, + "rank"_a) .def("__str__", &AbstractTransformerModel::toString) .def("__repr__", &AbstractTransformerModel::toString) .def("get_tensor_para_size", &AbstractTransformerModel::getTensorParaSize) diff --git a/src/turbomind/triton_backend/libfastertransformer.cc b/src/turbomind/triton_backend/libfastertransformer.cc index bef458521c..0e8b927716 100644 --- a/src/turbomind/triton_backend/libfastertransformer.cc +++ b/src/turbomind/triton_backend/libfastertransformer.cc @@ -1715,7 +1715,7 @@ void ModelInstanceState::ReadOutputTensors(size_t output_dtype, batchn_shape, 
output_buffer, - TRITONSERVER_MEMORY_GPU, + TRITONSERVER_MEMORY_CPU, model_instance_device_id_start_); } diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 8a7674a2ab..33711b502d 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -64,36 +64,59 @@ void LlamaTritonModel::handleMissingParams() TM_LOG_WARNING("[LlamaTritonModel] `kv_head_num` is not set, default to `head_num` (%d).", (int)kv_head_num_); } - if (!max_batch_size_) { - max_batch_size_ = 32; - TM_LOG_WARNING("[LlamaTritonModel] `max_batch_size` is not set, default to %d.", (int)max_batch_size_); + if (!attn_params_.max_position_embeddings) { + attn_params_.max_position_embeddings = 2048; + TM_LOG_WARNING("[LlamaTritonModel] `max_position_embeddings` is not set, default to %d.", + (int)attn_params_.max_position_embeddings); } - if (!session_len_) { - session_len_ = 2160; - TM_LOG_WARNING("[LlamaTritonModel] `session_len` is not set, default to %d.", (int)session_len_); + if (!engine_params_.max_batch_size) { + engine_params_.max_batch_size = 64; + TM_LOG_WARNING("[LlamaTritonModel] `max_batch_size` is not set, default to %d.", + (int)engine_params_.max_batch_size); } - if (!max_context_token_num_) { - max_context_token_num_ = (int)std::sqrt(max_batch_size_); + if (!engine_params_.session_len) { + engine_params_.session_len = attn_params_.max_position_embeddings; + TM_LOG_WARNING("[LlamaTritonModel] `session_len` is not set, default to %d.", (int)engine_params_.session_len); + } + + if (!engine_params_.max_context_token_num) { + engine_params_.max_context_token_num = engine_params_.session_len; TM_LOG_WARNING("[LlamaTritonModel] `max_context_token_num` is not set, default to %d.", - (int)max_context_token_num_); + (int)engine_params_.max_context_token_num); + } + + if (engine_params_.max_context_token_num <= engine_params_.max_batch_size) { + engine_params_.max_context_token_num *= engine_params_.session_len; + TM_LOG_WARNING("[LlamaTritonModel] `max_context_token_num` = %d.", (int)engine_params_.max_context_token_num); + } + + if (!engine_params_.step_length) { + engine_params_.step_length = 1; + } + + if (!engine_params_.cache_max_block_count) { + engine_params_.cache_max_block_count = .95f; + TM_LOG_WARNING("[LlamaTritonModel] `cache_max_entry_count` is not set, default to %f.", + engine_params_.cache_max_block_count); } - if (!step_length_) { - step_length_ = 1; - TM_LOG_WARNING("[LlamaTritonModel] `step_length` is not set, default to %d.", (int)step_length_); + if (!cache_block_seq_len_) { + cache_block_seq_len_ = 128; + TM_LOG_WARNING("[LlamaTritonModel] `cache_block_seq_len` is not set, default to %d.", cache_block_seq_len_); } - if (!cache_max_entry_count_) { - cache_max_entry_count_ = 32; - TM_LOG_WARNING("[LlamaTritonModel] `cache_max_entry_count` is not set, default to %d.", - (int)cache_max_entry_count_); + if (!engine_params_.cache_chunk_size) { + engine_params_.cache_chunk_size = engine_params_.cache_max_block_count; + TM_LOG_WARNING("[LlamaTritonModel] `cache_chunk_size` is not set, default to %d.", + (int)engine_params_.cache_chunk_size); } - if (!cache_chunk_size_) { - cache_chunk_size_ = cache_max_entry_count_; - TM_LOG_WARNING("[LlamaTritonModel] `cache_chunk_size` is not set, default to %d.", (int)cache_chunk_size_); + if (!engine_params_.num_tokens_per_iter) { + engine_params_.num_tokens_per_iter = engine_params_.max_context_token_num; + 
TM_LOG_WARNING("[LlamaTritonModel] `num_tokens_per_iter` is not set, default to `max_context_token_num` (%d).", + (int)engine_params_.num_tokens_per_iter); } } @@ -101,52 +124,75 @@ template LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, size_t pipeline_para_size, int enable_custom_all_reduce, - std::string model_dir): + std::string model_dir, + std::string config): tensor_para_size_(tensor_para_size), pipeline_para_size_(pipeline_para_size), shared_weights_(std::vector>>(ft::getDeviceCount())), enable_custom_all_reduce_(enable_custom_all_reduce) { - model_dir_ = model_dir; - const std::string inifile{model_dir + "/config.ini"}; - INIReader reader = INIReader(inifile); - if (reader.ParseError() < 0) { - std::cout << "[ERROR] Can't load '" << inifile << "'\n"; - ft::FT_CHECK(false); + INIReader reader; + FT_CHECK_WITH_INFO((config.empty() ^ model_dir.empty()), "invalid init options"); + + if (!config.empty()) { + std::FILE* tmpf = std::tmpfile(); + std::fputs(config.c_str(), tmpf); + std::rewind(tmpf); + reader = INIReader(tmpf); + if (reader.ParseError() < 0) { + TM_LOG_ERROR("[ERROR] Can't init with config %s", config.c_str()); + ft::FT_CHECK(false); + } + } + + if (!model_dir.empty()) { + model_dir_ = model_dir; + const std::string inifile{model_dir + "/config.ini"}; + reader = INIReader(inifile); + if (reader.ParseError() < 0) { + TM_LOG_ERROR("[ERROR] Can't load %s", inifile.c_str()); + ft::FT_CHECK(false); + } } - model_name_ = reader.Get("llama", "model_name"); - head_num_ = reader.GetInteger("llama", "head_num"); - kv_head_num_ = reader.GetInteger("llama", "kv_head_num", 0); - size_per_head_ = reader.GetInteger("llama", "size_per_head"); - inter_size_ = reader.GetInteger("llama", "inter_size"); - num_layer_ = reader.GetInteger("llama", "num_layer"); - vocab_size_ = reader.GetInteger("llama", "vocab_size"); - norm_eps_ = reader.GetFloat("llama", "norm_eps"); - start_id_ = reader.GetInteger("llama", "start_id"); - end_id_ = reader.GetInteger("llama", "end_id"); - max_batch_size_ = reader.GetInteger("llama", "max_batch_size", 0); - max_context_token_num_ = reader.GetInteger("llama", "max_context_token_num", 0); - session_len_ = reader.GetInteger("llama", "session_len", 0); - step_length_ = reader.GetInteger("llama", "step_length", 0); - cache_max_entry_count_ = reader.GetInteger("llama", "cache_max_entry_count", 0); - use_context_fmha_ = reader.GetInteger("llama", "use_context_fmha", 1); - cache_chunk_size_ = reader.GetInteger("llama", "cache_chunk_size", 0); - attn_bias_ = reader.GetInteger("llama", "attn_bias", 0); - quant_policy_ = reader.GetInteger("llama", "quant_policy", 0); - group_size_ = reader.GetInteger("llama", "group_size", 0); - - attn_params_.rotray_embedding_dim = reader.GetInteger("llama", "rotary_embedding"); + model_name_ = reader.Get("llama", "model_name"); + head_num_ = reader.GetInteger("llama", "head_num"); + kv_head_num_ = reader.GetInteger("llama", "kv_head_num", 0); + size_per_head_ = reader.GetInteger("llama", "size_per_head"); + inter_size_ = reader.GetInteger("llama", "inter_size"); + num_layer_ = reader.GetInteger("llama", "num_layer"); + vocab_size_ = reader.GetInteger("llama", "vocab_size"); + norm_eps_ = reader.GetFloat("llama", "norm_eps"); + start_id_ = reader.GetInteger("llama", "start_id"); + end_id_ = reader.GetInteger("llama", "end_id"); + use_context_fmha_ = reader.GetInteger("llama", "use_context_fmha", 1); + cache_block_seq_len_ = reader.GetInteger("llama", "cache_block_seq_len", 0); + + attn_bias_ = reader.GetInteger("llama", 
"attn_bias", 0); + quant_policy_ = reader.GetInteger("llama", "quant_policy", 0); + group_size_ = reader.GetInteger("llama", "group_size", 0); + + // rotary embedding parameters + attn_params_.rotary_embedding_dim = reader.GetInteger("llama", "rotary_embedding"); attn_params_.rotary_embedding_base = reader.GetFloat("llama", "rope_theta", 10000.0f); + attn_params_.rope_scaling_factor = reader.GetFloat("llama", "rope_scaling_factor", 0.f); attn_params_.max_position_embeddings = reader.GetInteger("llama", "max_position_embeddings", 0); - attn_params_.use_dynamic_ntk = reader.GetInteger("llama", "use_dynamic_ntk", 0); - attn_params_.use_logn_attn = reader.GetInteger("llama", "use_logn_attn", 0); + // attn_params_.use_dynamic_ntk = reader.GetInteger("llama", "use_dynamic_ntk", 0); + attn_params_.use_logn_attn = reader.GetInteger("llama", "use_logn_attn", 0); - handleMissingParams(); + engine_params_.max_batch_size = reader.GetInteger("llama", "max_batch_size", 0); + engine_params_.max_context_token_num = reader.GetInteger("llama", "max_context_token_num", 0); + engine_params_.session_len = reader.GetInteger("llama", "session_len", 0); + engine_params_.step_length = reader.GetInteger("llama", "step_length", 0); - if (max_context_token_num_ <= max_batch_size_) { - max_context_token_num_ *= session_len_; - } + engine_params_.cache_max_block_count = reader.GetFloat("llama", "cache_max_entry_count", 0); + engine_params_.cache_chunk_size = reader.GetInteger("llama", "cache_chunk_size", 0); + + engine_params_.num_tokens_per_iter = reader.GetInteger("llama", "num_tokens_per_iter", 0); + engine_params_.extra_tokens_per_iter = reader.GetInteger("llama", "extra_tokens_per_iter", 0); + engine_params_.max_prefill_iters = reader.GetInteger("llama", "max_prefill_iters", 1); + + handleMissingParams(); shared_state_ = std::make_shared::SharedState>(); shared_state_->barrier = std::make_shared(tensor_para_size); @@ -219,7 +265,7 @@ std::unique_ptr> LlamaTritonModel::createSh ft::NcclParam pipeline_para = nccl_params.second[comms_rank]; ft::FT_CHECK(tensor_para.world_size_ == tensor_para_size_); - ft::FT_CHECK(pipeline_para.world_size_ = pipeline_para_size_); + ft::FT_CHECK(pipeline_para.world_size_ == pipeline_para_size_); auto llama = std::make_unique>(head_num_, kv_head_num_, @@ -227,18 +273,14 @@ std::unique_ptr> LlamaTritonModel::createSh inter_size_, num_layer_, vocab_size_, - attn_params_, norm_eps_, - max_batch_size_, - max_context_token_num_, - session_len_, - step_length_, + attn_params_, start_id_, end_id_, - cache_max_entry_count_, - cache_chunk_size_, + cache_block_seq_len_, quant_policy_, use_context_fmha_, + engine_params_, shared_state_, shared_weights_[device_id].get(), tensor_para, @@ -256,7 +298,7 @@ std::unique_ptr> LlamaTritonModel::createSh std::move(cuda_device_prop_ptr), shared_weights_[device_id], std::move(llama), - session_len_}); + engine_params_.session_len}); } template @@ -307,10 +349,27 @@ void LlamaTritonModel::createSharedWeights(int device_id, int rank) group_size_, tensor_para_size_, tensor_para_rank); - shared_weights_[device_id]->loadModel(model_dir_); + // model inited with model_dir + if (model_dir_ != "") { + shared_weights_[device_id]->loadModel(model_dir_); + } return; } +template +TensorMap LlamaTritonModel::getParams(int deviceId, int rank) +{ + ft::check_cuda_error(cudaSetDevice(deviceId)); + // shared_weight should be created before getParams + ft::FT_CHECK(shared_weights_[deviceId] != nullptr); + ft::TensorMap output = shared_weights_[deviceId]->getParams(); + 
TensorMap result; + for (auto [name, tensor] : output) { + result.emplace(name, triton::Tensor{tensor.where, tensor.type, tensor.shape, tensor.data}); + } + return result; +} + template std::string LlamaTritonModel::toString() { @@ -318,14 +377,16 @@ std::string LlamaTritonModel::toString() ss << "Model: " << "\nhead_num: " << head_num_ << "\nkv_head_num: " << kv_head_num_ << "\nsize_per_head: " << size_per_head_ << "\ninter_size: " << inter_size_ << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_ - << "\nattn_bias: " << attn_bias_ << "\nmax_batch_size: " << max_batch_size_ - << "\nmax_context_token_num: " << max_context_token_num_ << "\nsession_len: " << session_len_ - << "\nstep_length: " << step_length_ << "\ncache_max_entry_count: " << cache_max_entry_count_ - << "\ncache_chunk_size: " << cache_chunk_size_ << "\nuse_context_fmha: " << use_context_fmha_ - << "\nstart_id: " << start_id_ << "\ntensor_para_size: " << tensor_para_size_ - << "\npipeline_para_size: " << pipeline_para_size_ << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ - << "\nmodel_name: " << model_name_ << "\nmodel_dir: " << model_dir_ << "\nquant_policy: " << quant_policy_ - << "\ngroup_size: " << group_size_ << std::endl; + << "\nattn_bias: " << attn_bias_ << "\nmax_batch_size: " << engine_params_.max_batch_size + << "\nmax_context_token_num: " << engine_params_.max_context_token_num + << "\nsession_len: " << engine_params_.session_len << "\nstep_length: " << engine_params_.step_length + << "\ncache_max_entry_count: " << engine_params_.cache_max_block_count + << "\ncache_block_seq_len: " << cache_block_seq_len_ << "\ncache_chunk_size: " << engine_params_.cache_chunk_size + << "\nuse_context_fmha: " << use_context_fmha_ << "\nstart_id: " << start_id_ + << "\ntensor_para_size: " << tensor_para_size_ << "\npipeline_para_size: " << pipeline_para_size_ + << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ << "\nmodel_name: " << model_name_ + << "\nmodel_dir: " << model_dir_ << "\nquant_policy: " << quant_policy_ << "\ngroup_size: " << group_size_ + << std::endl; return ss.str(); } diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.h b/src/turbomind/triton_backend/llama/LlamaTritonModel.h index b7d8f439ca..ff086a9099 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.h +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.h @@ -40,7 +40,8 @@ struct LlamaTritonModel: public AbstractTransformerModel { LlamaTritonModel(size_t tensor_para_size, size_t pipeline_para_size, int enable_custom_all_reduce, - std::string model_dir); + std::string model_dir, + std::string config = ""); ~LlamaTritonModel() = default; @@ -53,6 +54,8 @@ struct LlamaTritonModel: public AbstractTransformerModel { void createSharedWeights(int deviceId, int rank) override; + TensorMap getParams(int deviceId, int rank) override; + void createCustomComms(std::vector>* custom_all_reduce_comms, int world_size) override; @@ -86,15 +89,11 @@ struct LlamaTritonModel: public AbstractTransformerModel { size_t num_layer_; size_t vocab_size_; turbomind::LlamaAttentionParams attn_params_; + turbomind::EngineParams engine_params_; float norm_eps_; - int max_batch_size_; - int max_context_token_num_; - int session_len_; - int step_length_; int start_id_; int end_id_; - int cache_max_entry_count_; - int cache_chunk_size_; + int cache_block_seq_len_; int use_context_fmha_; size_t tensor_para_size_; size_t pipeline_para_size_; diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc 
b/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc index 102b324b8e..b4666bd1e7 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc @@ -64,59 +64,11 @@ std::unordered_map LlamaTritonModelInstance::convert h_total_output_lengths_ = (uint32_t*)std::realloc((void*)h_total_output_lengths_, request_batch_size * sizeof(uint32_t)); - std::unordered_map ft_input_tensors = std::unordered_map{ - {"input_ids", as_GPU_tensor(input_tensors->at("input_ids"), d_input_ids_)}, - // {"input_lengths", as_GPU_tensor(input_tensors->at("input_lengths"), d_input_lengths_)}, - }; - - if (input_tensors->find("bad_words_list") != input_tensors->end()) { - move_tensor_H2D(input_tensors->at("bad_words_list"), d_input_bad_words_, &allocator_); - ft_input_tensors.insert( - {"bad_words_list", as_GPU_tensor(input_tensors->at("bad_words_list"), d_input_bad_words_)}); - } - - if (input_tensors->find("stop_words_list") != input_tensors->end()) { - move_tensor_H2D(input_tensors->at("stop_words_list"), d_input_stop_words_, &allocator_); - ft_input_tensors.insert( - {"stop_words_list", as_GPU_tensor(input_tensors->at("stop_words_list"), d_input_stop_words_)}); - } - - if (input_tensors->count("request_prompt_embedding") && input_tensors->count("request_prompt_lengths") - && input_tensors->count("request_prompt_type")) { - - move_tensor_H2D(input_tensors->at("request_prompt_lengths"), d_request_prompt_lengths_, &allocator_); - ft_input_tensors.insert( - {"request_prompt_lengths", - as_GPU_tensor(input_tensors->at("request_prompt_lengths"), d_request_prompt_lengths_)}); - - move_tensor_H2D(input_tensors->at("request_prompt_embedding"), d_request_prompt_embedding_, &allocator_); - ft_input_tensors.insert( - {"request_prompt_embedding", - as_GPU_tensor(input_tensors->at("request_prompt_embedding"), d_request_prompt_embedding_)}); - } - - if (input_tensors->find("top_p_decay") != input_tensors->end()) { - move_tensor_H2D(input_tensors->at("top_p_decay"), d_top_p_decay_, &allocator_); - ft_input_tensors.insert({"top_p_decay", as_GPU_tensor(input_tensors->at("top_p_decay"), d_top_p_decay_)}); - } - if (input_tensors->find("top_p_min") != input_tensors->end()) { - move_tensor_H2D(input_tensors->at("top_p_min"), d_top_p_min_, &allocator_); - ft_input_tensors.insert({"top_p_min", as_GPU_tensor(input_tensors->at("top_p_min"), d_top_p_min_)}); - } - if (input_tensors->find("top_p_reset_ids") != input_tensors->end()) { - move_tensor_H2D(input_tensors->at("top_p_reset_ids"), d_top_p_reset_ids_, &allocator_); - ft_input_tensors.insert( - {"top_p_reset_ids", as_GPU_tensor(input_tensors->at("top_p_reset_ids"), d_top_p_reset_ids_)}); - } + std::unordered_map ft_input_tensors{}; for (auto t = input_tensors->begin(); t != input_tensors->end(); ++t) { - if (t->first.find("input_ids") == std::string::npos // && t->first.find("input_lengths") == std::string::npos - && t->first.find("output_seq_len") == std::string::npos - && t->first.find("prefix_soft_prompt_embedding") == std::string::npos - && t->first.find("prefix_soft_prompt_lengths") == std::string::npos) { - if (ft_input_tensors.count(t->first) == 0) { - ft_input_tensors.insert({t->first, t->second.convertTritonTensorToFt()}); - } + if (ft_input_tensors.count(t->first) == 0) { + ft_input_tensors.insert({t->first, t->second.convertTritonTensorToFt()}); } } @@ -204,12 +156,12 @@ LlamaTritonModelInstance::forward(std::shared_ptr output_tensors = std::unordered_map{ {"output_ids", - 
ft::Tensor{ft::MEMORY_GPU, + ft::Tensor{ft::MEMORY_CPU, ft::TYPE_UINT32, std::vector{request_batch_size, beam_width, (size_t)instance_->session_len}, d_output_ids_}}, {"sequence_length", - ft::Tensor{ft::MEMORY_GPU, + ft::Tensor{ft::MEMORY_CPU, ft::TYPE_UINT32, std::vector{request_batch_size, beam_width}, d_sequence_lengths_}}}; @@ -267,10 +219,9 @@ void LlamaTritonModelInstance::allocateBuffer(const size_t request_batch_size const size_t session_len, const bool is_return_logits) { - d_output_ids_ = - (int*)(allocator_->reMalloc(d_output_ids_, sizeof(int) * request_batch_size * beam_width * session_len, false)); - d_sequence_lengths_ = - (int*)(allocator_->reMalloc(d_sequence_lengths_, sizeof(int) * request_batch_size * beam_width, false)); + d_output_ids_ = (int*)std::realloc(d_output_ids_, sizeof(int) * request_batch_size * beam_width * session_len); + d_sequence_lengths_ = (int*)std::realloc(d_sequence_lengths_, sizeof(int) * request_batch_size * beam_width); + d_output_log_probs_ = (float*)(allocator_->reMalloc( d_output_log_probs_, sizeof(float) * request_batch_size * beam_width * session_len, false)); d_cum_log_probs_ = @@ -284,8 +235,8 @@ void LlamaTritonModelInstance::allocateBuffer(const size_t request_batch_size template void LlamaTritonModelInstance::freeBuffer() { - allocator_->free((void**)(&d_output_ids_)); - allocator_->free((void**)(&d_sequence_lengths_)); + std::free(d_output_ids_); + std::free(d_sequence_lengths_); allocator_->free((void**)(&d_output_log_probs_)); allocator_->free((void**)(&d_cum_log_probs_)); std::free(h_total_output_lengths_); diff --git a/src/turbomind/triton_backend/transformer_triton_backend.hpp b/src/turbomind/triton_backend/transformer_triton_backend.hpp index 8f1f88f5a6..aee45d080f 100644 --- a/src/turbomind/triton_backend/transformer_triton_backend.hpp +++ b/src/turbomind/triton_backend/transformer_triton_backend.hpp @@ -272,6 +272,7 @@ struct AbstractTransformerModelInstance; struct AbstractTransformerModelInstance { virtual ~AbstractTransformerModelInstance() = default; + virtual std::shared_ptr> forward(std::shared_ptr> input_tensors) = 0; @@ -300,6 +301,8 @@ struct AbstractTransformerModelInstance { void* stream_ctx_ = nullptr; }; +using TensorMap = std::unordered_map; + struct AbstractTransformerModel { static std::shared_ptr createLlamaModel(std::string model_dir); @@ -323,6 +326,8 @@ struct AbstractTransformerModel { virtual void createSharedWeights(int deviceId, int rank) = 0; + virtual TensorMap getParams(int deviceId, int rank) = 0; + virtual std::string toString() = 0; virtual int getTensorParaSize() = 0; virtual int getPipelineParaSize() = 0; diff --git a/src/turbomind/triton_backend/triton_utils.hpp b/src/turbomind/triton_backend/triton_utils.hpp index 1547e10d4e..a87dd7d6f4 100644 --- a/src/turbomind/triton_backend/triton_utils.hpp +++ b/src/turbomind/triton_backend/triton_utils.hpp @@ -52,5 +52,6 @@ ft::Tensor as_GPU_tensor(const triton::Tensor& tensor, T* d_ptr) inline ft::Tensor as_CPU_tensor(const triton::Tensor& tensor) { + ft::FT_CHECK(tensor.where == triton::MEMORY_CPU); return ft::Tensor{ft::MEMORY_CPU, triton::Tensor::convertTritonTypeToFt(tensor.type), tensor.shape, tensor.data}; } diff --git a/src/turbomind/utils/cuda_utils.h b/src/turbomind/utils/cuda_utils.h index be0b85d69a..f066a0c25b 100644 --- a/src/turbomind/utils/cuda_utils.h +++ b/src/turbomind/utils/cuda_utils.h @@ -131,7 +131,7 @@ void check(T result, char const* const func, const char* const file, int const l inline void syncAndCheck(const char* const file, 
int const line) { // When FT_DEBUG_LEVEL=DEBUG, must check error - static char* level_name = std::getenv("FT_DEBUG_LEVEL"); + static char* level_name = std::getenv("TM_DEBUG_LEVEL"); if (level_name != nullptr) { static std::string level = std::string(level_name); if (level == "DEBUG") { diff --git a/src/turbomind/utils/debug_utils.h b/src/turbomind/utils/debug_utils.h new file mode 100644 index 0000000000..f07af38db2 --- /dev/null +++ b/src/turbomind/utils/debug_utils.h @@ -0,0 +1,7 @@ +#pragma once + +#if __has_include("3rdparty/dbg.h") +#include "3rdparty/dbg.h" +#else +#define dbg(...) +#endif diff --git a/src/turbomind/utils/dispatch.h b/src/turbomind/utils/dispatch.h new file mode 100644 index 0000000000..c1c1568e5f --- /dev/null +++ b/src/turbomind/utils/dispatch.h @@ -0,0 +1,35 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#pragma once + +#include + +namespace turbomind { + +namespace detail { + +template +inline constexpr std::integral_constant _Int{}; + +template +bool dispatch_impl(F&& f, P&& p, G g, std::integer_sequence, std::index_sequence) +{ + constexpr int N = sizeof...(Xs); + return (((((P &&) p)(_Int) || (g && Is == N - 1)) && (((F &&) f)(_Int), 1)) || ...); +} + +} // namespace detail + +template +bool dispatch(std::integer_sequence seq, P&& p, F&& f, G g = {}) +{ + return detail::dispatch_impl((F &&) f, (P &&) p, g, seq, std::make_index_sequence{}); +} + +template +bool dispatch(std::integer_sequence seq, F&& f) +{ + return (((F &&) f)(detail::_Int) || ...); +} + +} // namespace turbomind diff --git a/tests/pytorch/kernel/test_mbgmm.py b/tests/pytorch/kernel/test_mbgmm.py new file mode 100644 index 0000000000..54f820f86a --- /dev/null +++ b/tests/pytorch/kernel/test_mbgmm.py @@ -0,0 +1,122 @@ +import pytest +import torch +from torch.nn.utils.rnn import pad_sequence + +from lmdeploy.pytorch.kernels.mbgmm import mbgmm_a, mbgmm_b + + +class TestMBGMM: + + @pytest.fixture + def dtype(self): + yield torch.float16 + + @pytest.fixture + def head_size(self): + yield 32 + + @pytest.fixture + def out_head_size(self): + yield 16 + + @pytest.fixture + def seq_lens(self): + yield torch.tensor([2, 4, 6, 8]).cuda() + + @pytest.fixture + def ranks(self): + yield torch.tensor([2, 4]).cuda() + + @pytest.fixture + def start_loc(self, seq_lens): + yield seq_lens.cumsum(0) - seq_lens + + @pytest.fixture + def input(self, seq_lens, head_size, dtype): + total_len = seq_lens.sum() + yield torch.rand(total_len, head_size, dtype=dtype).cuda() + + @pytest.fixture + def rank_ids(self, seq_lens, ranks): + num_ranks = len(ranks) + num_seqs = len(seq_lens) + ret = torch.randint(0, num_ranks, (num_seqs, )).cuda() + yield ret + + @pytest.fixture + def lora_a(self, ranks, head_size, dtype): + out = [] + for rank in ranks: + w = torch.rand(head_size, rank, dtype=dtype).cuda() + out.append(w) + yield out + + @pytest.fixture + def lora_b(self, ranks, out_head_size, dtype): + out = [] + for rank in ranks: + w = torch.rand(rank, out_head_size, dtype=dtype).cuda() + out.append(w) + yield out + + @pytest.fixture + def page_table(self, ranks): + total_ranks = sum(ranks) + index = torch.randperm(total_ranks) + index = index.split(ranks.tolist()) + yield pad_sequence(index, batch_first=True).cuda() + + @pytest.fixture + def paged_lora_a(self, lora_a, ranks, page_table, head_size, dtype): + num_pages = sum(ranks) + cache = torch.empty(num_pages, head_size, dtype=dtype).cuda() + for index, r, w in zip(page_table, ranks, lora_a): + cache[index[:r]] = w.t() + yield cache + + @pytest.fixture + def 
paged_lora_b(self, lora_b, ranks, page_table, head_size, out_head_size, + dtype): + num_pages = sum(ranks) + cache = torch.empty(num_pages, head_size, dtype=dtype).cuda() + for index, r, w in zip(page_table, ranks, lora_b): + cache[index[:r], :out_head_size] = w + yield cache + + @pytest.fixture + def gt(self, input, start_loc, seq_lens, rank_ids, lora_a, lora_b): + out = [] + for loc, s_len, r_id in zip(start_loc, seq_lens, rank_ids): + inp = input[loc:loc + s_len] + l_a = lora_a[r_id] + l_b = lora_b[r_id] + out.append(inp @ l_a @ l_b) + + yield torch.cat(out) + + def test_mbgmm(self, input, paged_lora_a, paged_lora_b, out_head_size, + start_loc, seq_lens, rank_ids, page_table, ranks, gt): + max_seq_len = max(seq_lens).item() + max_rank = page_table.size(-1) + + xa = mbgmm_a(input, + paged_lora_a, + b_start_loc=start_loc, + b_seq_lens=seq_lens, + b_rank_ids=rank_ids, + rank_page_table=page_table, + ranks=ranks, + max_seq_len=max_seq_len, + max_rank=max_rank) + + output = mbgmm_b(xa, + paged_lora_b[..., :out_head_size], + b_start_loc=start_loc, + b_seq_lens=seq_lens, + b_rank_ids=rank_ids, + rank_page_table=page_table, + ranks=ranks, + max_seq_len=max_seq_len, + max_rank=max_rank) + + torch.testing.assert_close(gt, output) diff --git a/tests/pytorch/kernel/test_mbgmv.py b/tests/pytorch/kernel/test_mbgmv.py new file mode 100644 index 0000000000..08866ba294 --- /dev/null +++ b/tests/pytorch/kernel/test_mbgmv.py @@ -0,0 +1,112 @@ +import pytest +import torch +from torch.nn.utils.rnn import pad_sequence + +from lmdeploy.pytorch.kernels.mbgmv import mbgmv_a, mbgmv_b + + +class TestMBGMV: + + @pytest.fixture + def dtype(self): + yield torch.float16 + + @pytest.fixture + def head_size(self): + yield 64 + + @pytest.fixture + def out_head_size(self): + yield 32 + + @pytest.fixture + def batch_size(self): + yield 8 + + @pytest.fixture + def ranks(self): + yield torch.tensor([2, 4]).cuda() + + @pytest.fixture + def input(self, batch_size, head_size, dtype): + x = torch.rand(batch_size, head_size, dtype=dtype).cuda() + x -= 0.5 + yield x + + @pytest.fixture + def rank_ids(self, batch_size, ranks): + num_ranks = len(ranks) + ret = torch.randint(0, num_ranks, (batch_size, )).cuda() + yield ret + + @pytest.fixture + def lora_a(self, ranks, head_size, dtype): + out = [] + for rank in ranks: + w = torch.rand(head_size, rank, dtype=dtype).cuda() + w -= 0.5 + out.append(w) + yield out + + @pytest.fixture + def lora_b(self, ranks, out_head_size, dtype): + out = [] + for rank in ranks: + w = torch.rand(rank, out_head_size, dtype=dtype).cuda() + w -= 0.5 + out.append(w) + yield out + + @pytest.fixture + def page_table(self, ranks): + total_ranks = sum(ranks) + index = torch.randperm(total_ranks) + index = index.split(ranks.tolist()) + yield pad_sequence(index, batch_first=True).cuda() + + @pytest.fixture + def paged_lora_a(self, lora_a, ranks, page_table, head_size, dtype): + num_pages = sum(ranks) + cache = torch.empty(num_pages, head_size, dtype=dtype).cuda() + for index, r, w in zip(page_table, ranks, lora_a): + cache[index[:r]] = w.t() + yield cache + + @pytest.fixture + def paged_lora_b(self, lora_b, ranks, page_table, head_size, out_head_size, + dtype): + num_pages = sum(ranks) + cache = torch.empty(num_pages, head_size, dtype=dtype).cuda() + for index, r, w in zip(page_table, ranks, lora_b): + cache[index[:r], :out_head_size] = w + yield cache + + @pytest.fixture + def gt(self, input, rank_ids, lora_a, lora_b): + out = [] + for inp, r_id in zip(input, rank_ids): + inp = inp.unsqueeze(0) + l_a = 
lora_a[r_id] + l_b = lora_b[r_id] + out.append(inp @ l_a @ l_b) + + yield torch.cat(out) + + def test_mbgmv(self, input, paged_lora_a, paged_lora_b, out_head_size, + rank_ids, page_table, ranks, gt): + max_rank = page_table.size(-1) + + xa = mbgmv_a(input, + paged_lora_a, + b_rank_ids=rank_ids, + rank_page_table=page_table, + ranks=ranks, + max_rank=max_rank) + + output = mbgmv_b(xa, + paged_lora_b[..., :out_head_size], + b_rank_ids=rank_ids, + rank_page_table=page_table, + ranks=ranks, + max_rank=max_rank) + torch.testing.assert_close(gt, output, atol=1e-3, rtol=1e-5) diff --git a/tests/pytorch/kernel/test_paged_attention.py b/tests/pytorch/kernel/test_paged_attention.py index cf9e046142..d943581a7b 100644 --- a/tests/pytorch/kernel/test_paged_attention.py +++ b/tests/pytorch/kernel/test_paged_attention.py @@ -95,8 +95,8 @@ def num_heads_k(self, request): yield request.param @pytest.fixture - def block_size(self): - yield 16 + def block_size(self, request): + yield request.param @pytest.fixture def seq_lens(self, request): @@ -192,6 +192,7 @@ def conti_gt(self, gt, seq_lens): [([30, 50, 70, 90], [50, 40, 30, 20]), ([1, 1, 1, 1], [50, 40, 30, 20])], indirect=True) + @pytest.mark.parametrize('block_size', [1, 16], indirect=True) def test_paged_attention(self, conti_q, blocked_kv, block_offsets, start_loc, seq_lens, history_lens, block_size, conti_gt): @@ -210,9 +211,7 @@ def test_paged_attention(self, conti_q, blocked_kv, block_offsets, b_start_loc=start_loc, b_seq_len=seq_lens, b_kv_seq_len=kv_seq_lens, - max_input_len=max_seq_len, - BLOCK=block_size) - + max_input_len=max_seq_len) torch.testing.assert_close(out, conti_gt, atol=5e-4, rtol=1e-5) @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(4, 2)], @@ -220,6 +219,7 @@ def test_paged_attention(self, conti_q, blocked_kv, block_offsets, @pytest.mark.parametrize(['seq_lens', 'history_lens'], [([30, 50, 70, 90], [50, 40, 30, 20])], indirect=True) + @pytest.mark.parametrize('block_size', [16], indirect=True) def test_biased_paged_attention(self, conti_q, blocked_kv, block_offsets, start_loc, seq_lens, history_lens, block_size, mask, conti_gt): diff --git a/tests/pytorch/paging/test_block_manager.py b/tests/pytorch/paging/test_block_manager.py index 340d4fcc67..3c4dd9c114 100644 --- a/tests/pytorch/paging/test_block_manager.py +++ b/tests/pytorch/paging/test_block_manager.py @@ -2,47 +2,71 @@ import torch from lmdeploy.pytorch.messages import SchedulerSession -from lmdeploy.pytorch.paging.block_manager import BlockAllocator, BlockManager +from lmdeploy.pytorch.paging.block_manager import (BlockManager, + LogicalAllocator) class TestAllocator: @pytest.fixture - def block_size(self): + def num_gpu_blocks(self): yield 16 @pytest.fixture - def block_num(self): + def num_cpu_blocks(self): yield 4 @pytest.fixture - def device(self): - yield 'cpu' + def allocator(self, num_cpu_blocks, num_gpu_blocks): + yield LogicalAllocator(num_cpu_blocks, num_gpu_blocks) - @pytest.fixture - def allocator(self, block_size, block_num, device): - yield BlockAllocator(block_size, block_num, device) + def test_alloc(self, allocator, num_cpu_blocks, num_gpu_blocks): + + # initialize + num_blocks = num_cpu_blocks + num_gpu_blocks + gpu_allocator = allocator.get_phy_allocator('gpu') + cpu_allocator = allocator.get_phy_allocator('cpu') + assert allocator.get_num_free_blocks() == num_blocks + assert cpu_allocator.get_num_free_blocks() == num_cpu_blocks + assert gpu_allocator.get_num_free_blocks() == num_gpu_blocks - def test_alloc(self, allocator, block_num): - 
assert allocator.get_num_free_blocks() == block_num # test allocate - block = allocator.allocate() - assert allocator.get_num_free_blocks() == block_num - 1 + block_size = 4 + blocks = allocator.allocate(block_size, 'gpu') + assert len(blocks) == block_size + assert allocator.get_num_free_blocks() == num_blocks - block_size + assert gpu_allocator.get_num_free_blocks( + ) == num_gpu_blocks - block_size + # test free - block.ref_count += 1 - allocator.free(block) - assert allocator.get_num_free_blocks() == block_num - 1 - allocator.free(block) - assert allocator.get_num_free_blocks() == block_num + allocator.add_ref_count(blocks, 1) + allocator.free(blocks) + assert allocator.get_num_free_blocks() == num_blocks - block_size + allocator.free(blocks) + assert allocator.get_num_free_blocks() == num_blocks + assert gpu_allocator.get_num_free_blocks() == num_gpu_blocks + assert cpu_allocator.get_num_free_blocks() == num_cpu_blocks + + def test_full(self, allocator, num_cpu_blocks, num_gpu_blocks): + + num_blocks = num_cpu_blocks + num_gpu_blocks + gpu_allocator = allocator.get_phy_allocator('gpu') + cpu_allocator = allocator.get_phy_allocator('cpu') + # no free blocks - blocks = [allocator.allocate() for _ in range(block_num)] + gpu_block_size = num_gpu_blocks + gpu_blocks = allocator.allocate(gpu_block_size, 'gpu') + cpu_block_size = num_cpu_blocks + cpu_blocks = allocator.allocate(cpu_block_size, 'cpu') + assert cpu_allocator.get_num_free_blocks() == 0 + assert gpu_allocator.get_num_free_blocks() == 0 with pytest.raises(MemoryError): - allocator.allocate() - for block in blocks: - allocator.free(block) - # double free - with pytest.raises(ValueError): - allocator.free(blocks[0]) + allocator.allocate(1, 'gpu') + allocator.free(gpu_blocks) + allocator.free(cpu_blocks) + assert allocator.get_num_free_blocks() == num_blocks + assert gpu_allocator.get_num_free_blocks() == num_gpu_blocks + assert cpu_allocator.get_num_free_blocks() == num_cpu_blocks class TestBlockManager: @@ -60,16 +84,16 @@ def num_gpu_blocks(self): yield 4 @pytest.fixture - def block_mgr(self, block_size, num_cpu_blocks, num_gpu_blocks): - yield BlockManager(block_size, num_cpu_blocks, num_gpu_blocks) + def block_mgr(self, num_cpu_blocks, num_gpu_blocks): + yield BlockManager(num_cpu_blocks, num_gpu_blocks) def test_alloc(self, block_mgr, block_size, num_gpu_blocks): - sess = SchedulerSession(0) + sess = SchedulerSession(0, block_size) # test alloc token_ids = torch.tensor([1]) msg = sess.add_sequence(token_ids) - msg.append_tokens(1, block_size) + # msg.append_tokens(1, block_size) assert block_mgr.can_allocate(msg) block_mgr.allocate(msg) block_table = block_mgr.get_block_table(msg) @@ -80,95 +104,100 @@ def test_alloc(self, block_mgr, block_size, num_gpu_blocks): # test free block_mgr.free(msg) block_table = block_mgr.get_block_table(msg) - assert block_table is None + assert block_table is None or len(block_table) == 0 assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks # alloc over limit + token_ids = torch.zeros((num_gpu_blocks * block_size + 1, ), + dtype=torch.int64) msg = sess.add_sequence(token_ids) - msg.append_tokens(num_gpu_blocks * block_size + 1, block_size) assert not block_mgr.can_allocate(msg) def test_append_slot(self, block_mgr, block_size, num_gpu_blocks): - sess = SchedulerSession(0) + sess = SchedulerSession(0, block_size) # test append token_ids = torch.tensor([1]) msg = sess.add_sequence(token_ids) - msg.append_tokens(1, block_size) block_mgr.allocate(msg) block_table = 
block_mgr.get_block_table(msg) + assert len(block_table) == 1 # no new logical block - msg.append_tokens(block_size - 1, block_size) + msg.update_token_ids(torch.tensor([1] * (block_size - 1))) assert block_mgr.can_append_slot(msg) block_mgr.append_slot(msg) + block_table = block_mgr.get_block_table(msg) assert len(block_table) == 1 assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 1 # with new logical block - msg.append_tokens(1, block_size) + msg.update_token_ids(torch.tensor([1])) block_mgr.append_slot(msg) + block_table = block_mgr.get_block_table(msg) assert len(block_table) == 2 assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 2 def test_fork(self, block_mgr, block_size, num_gpu_blocks): - sess = SchedulerSession(0) + sess = SchedulerSession(0, block_size) - token_ids = torch.tensor([1]) + token_ids = torch.tensor([1] * (block_size * 2 + 1)) from_msg = sess.add_sequence(token_ids) - from_msg.append_tokens(block_size + 1, block_size) block_mgr.allocate(from_msg) from_block_table = block_mgr.get_block_table(from_msg) + assert len(from_block_table) == 3 - to_msg = sess.fork_sequence(token_ids, from_msg) - to_msg.append_tokens(1, block_size) + to_msg = sess.fork_sequence(torch.tensor([1]), from_msg) # fork assert block_mgr.can_fork(from_msg) copy_map = block_mgr.fork(from_msg, to_msg) block_table = block_mgr.get_block_table(to_msg) - assert len(block_table) == 2 - assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 3 + assert len(block_table) == 3 + assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 4 assert block_table[0] == from_block_table[0] - assert block_table[0].ref_count == 2 - assert block_table[1] != from_block_table[1] + assert block_table[1] == from_block_table[1] + assert block_table[2] != from_block_table[2] assert len(copy_map) == 1 - assert copy_map[ - from_block_table[1].block_id] == block_table[1].block_id + assert copy_map[from_block_table[2]] == block_table[2] # can not fork - assert block_mgr.can_fork(from_msg) + assert not block_mgr.can_fork(from_msg) def test_swap(self, block_mgr, block_size, num_gpu_blocks): - sess = SchedulerSession(0) + sess = SchedulerSession(0, block_size) - token_ids = torch.tensor([1]) + token_ids = torch.tensor([1] * (block_size + 1)) msg = sess.add_sequence(token_ids) - msg.append_tokens(block_size + 1, block_size) block_mgr.allocate(msg) - block_table = block_mgr.get_block_table(msg) - gpu_block_id = [block.block_id for block in block_table] - assert block_mgr.can_swap_out(msg) - swap_out_map = block_mgr.swap_out(msg) + old_phy_blocks = block_mgr.get_block_table(msg) + success, swap_map = block_mgr.try_swap_out(msg) + new_phy_blocks = block_mgr.get_block_table(msg) + assert success assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks assert block_mgr.get_num_free_cpu_blocks() == num_gpu_blocks - 2 - assert len(swap_out_map) == 2 - for block_id in gpu_block_id: - assert block_id in swap_out_map - for block in block_table: - assert block.device == 'cpu' - - assert block_mgr.can_swap_in(msg) - swap_in_map = block_mgr.swap_in(msg) + assert len(swap_map) == 2 + for block_id in old_phy_blocks: + assert block_id in swap_map + for block_id in new_phy_blocks: + assert block_id - num_gpu_blocks in swap_map.values() + + old_phy_blocks = block_mgr.get_block_table(msg) + success, swap_map = block_mgr.try_swap_in(msg) + new_phy_blocks = block_mgr.get_block_table(msg) assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 2 assert block_mgr.get_num_free_cpu_blocks() == num_gpu_blocks - 
assert len(swap_in_map) == 2 - for block in block_table: - assert block.device == 'gpu' - - swap_out_map = block_mgr.swap_out(msg) + assert len(swap_map) == 2 + for block_id in old_phy_blocks: + assert block_id - num_gpu_blocks in swap_map + for block_id in new_phy_blocks: + assert block_id in swap_map.values() + + success, swap_map = block_mgr.try_swap_out(msg) + assert success + token_ids = torch.tensor([1] * (block_size * 4)) msg_full = sess.add_sequence(token_ids) - msg_full.append_tokens(block_size * 4, block_size) block_mgr.allocate(msg_full) - assert not block_mgr.can_swap_out(msg_full) + success, swap_map = block_mgr.try_swap_out(msg) + assert not success diff --git a/tests/pytorch/paging/test_scheduler.py b/tests/pytorch/paging/test_scheduler.py index 2c32287962..b507fbf6c5 100644 --- a/tests/pytorch/paging/test_scheduler.py +++ b/tests/pytorch/paging/test_scheduler.py @@ -30,7 +30,8 @@ def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks): def scheduler_config(self): yield SchedulerConfig(max_batches=4, max_session_len=128, - max_request_output_len=64) + max_request_output_len=64, + eviction_type='copy') @pytest.fixture def scheduler(self, cache_config, scheduler_config): @@ -52,12 +53,13 @@ def test_schedule_base(self, scheduler, block_size, num_gpu_blocks): assert seq.status == MessageStatus.WAITING assert seq in scheduler.waiting - output = scheduler.schedule() + output = scheduler.schedule(is_prefill=True) + block_tables = scheduler.get_block_tables(output.running) assert seq.status == MessageStatus.RUNNING assert seq in output.running - assert len(output.block_tables) == 1 - assert len(output.block_tables[0]) == num_blocks + assert len(block_tables) == 1 + assert len(block_tables[0]) == num_blocks assert block_manager.get_num_free_gpu_blocks( ) == num_gpu_blocks - num_blocks @@ -80,7 +82,7 @@ def test_update(self, scheduler, block_size, num_gpu_blocks): seq3 = session2.add_sequence(token_ids3) scheduler.add_sequence(seq3) - scheduler.schedule() + scheduler.schedule(is_prefill=True) assert seq1.status == MessageStatus.RUNNING assert seq2.status == MessageStatus.RUNNING assert seq3.status == MessageStatus.WAITING @@ -130,7 +132,7 @@ def test_swap(self, scheduler, block_size, num_gpu_blocks, num_cpu_blocks): token_ids3 = torch.tensor([0] * block_size * 3) seq3 = session.add_sequence(token_ids3) scheduler.add_sequence(seq3) - scheduler.schedule() + scheduler.schedule(is_prefill=True) # seq1: 1 running gpu # seq2: 2 running gpu # seq3: 3 waiting empty @@ -146,7 +148,7 @@ def test_swap(self, scheduler, block_size, num_gpu_blocks, num_cpu_blocks): assert len(scheduler.waiting) == 1 assert len(scheduler.hanging) == 1 - output = scheduler.schedule() + output = scheduler.schedule(is_prefill=True) # seq1: 1 running gpu # seq2: 2 hanging cpu # seq3: 3 waiting gpu @@ -166,7 +168,7 @@ def test_swap(self, scheduler, block_size, num_gpu_blocks, num_cpu_blocks): assert len(scheduler.waiting) == 1 assert len(scheduler.hanging) == 0 - output = scheduler.schedule() + output = scheduler.schedule(is_prefill=True) # seq1: 1 running gpu # seq2: 3 running gpu # seq3: 3 nan @@ -181,8 +183,7 @@ def test_swap(self, scheduler, block_size, num_gpu_blocks, num_cpu_blocks): seq2.update_token_ids(torch.tensor([1] * block_size)) scheduler.update() assert len(scheduler.running) == 2 - - output = scheduler.schedule() + output = scheduler.schedule(is_prefill=False) # seq1: 1 waiting cpu # seq2: 4 running gpu # seq3: 3 nan diff --git a/tests/test_lmdeploy/test_cli.py 
b/tests/test_lmdeploy/test_cli.py new file mode 100644 index 0000000000..a41eab442e --- /dev/null +++ b/tests/test_lmdeploy/test_cli.py @@ -0,0 +1,51 @@ +import inspect + + +def compare_func(class_method, function): + """Compare if a class method has same arguments as a function.""" + + argspec_cls = inspect.getfullargspec(class_method) + argspec_func = inspect.getfullargspec(function) + assert argspec_cls.args[1:] == argspec_func.args + assert argspec_cls.defaults == argspec_func.defaults + assert argspec_cls.annotations == argspec_func.annotations + + +def test_cli(): + + from lmdeploy.cli.cli import CLI + from lmdeploy.serve.turbomind.deploy import main as convert + compare_func(CLI.convert, convert) + + +def test_subcli_chat(): + from lmdeploy.cli.chat import SubCliChat + from lmdeploy.pytorch.chat import main as run_torch_model + from lmdeploy.turbomind.chat import main as run_turbomind_model + + compare_func(SubCliChat.torch, run_torch_model) + compare_func(SubCliChat.turbomind, run_turbomind_model) + + +def test_subcli_lite(): + from lmdeploy.cli.lite import SubCliLite + from lmdeploy.lite.apis.auto_awq import auto_awq + from lmdeploy.lite.apis.calibrate import calibrate + from lmdeploy.lite.apis.kv_qparams import main as run_kv_qparams + + compare_func(SubCliLite.auto_awq, auto_awq) + compare_func(SubCliLite.calibrate, calibrate) + compare_func(SubCliLite.kv_qparams, run_kv_qparams) + + +def test_subcli_serve(): + from lmdeploy.cli.serve import SubCliServe + from lmdeploy.serve.client import main as run_triton_client + from lmdeploy.serve.gradio.app import run as run_gradio + from lmdeploy.serve.openai.api_client import main as run_api_client + from lmdeploy.serve.openai.api_server import main as run_api_server + + compare_func(SubCliServe.gradio, run_gradio) + compare_func(SubCliServe.api_server, run_api_server) + compare_func(SubCliServe.api_client, run_api_client) + compare_func(SubCliServe.triton_client, run_triton_client) diff --git a/tests/test_lmdeploy/test_tokenizer.py b/tests/test_lmdeploy/test_tokenizer.py new file mode 100644 index 0000000000..ff7d8047b2 --- /dev/null +++ b/tests/test_lmdeploy/test_tokenizer.py @@ -0,0 +1,24 @@ +import pytest + +from lmdeploy.tokenizer import HuggingFaceTokenizer + + +@pytest.mark.parametrize('model_path', [ + 'internlm/internlm-chat-7b', 'Qwen/Qwen-7B-Chat', + 'baichuan-inc/Baichuan-7B', 'codellama/CodeLlama-7b-hf', + 'upstage/SOLAR-0-70b-16bit' +]) +@pytest.mark.parametrize( + 'input', ['hi, this is a test 😆😆! ' * 5, '為什麼我還在用繁體字 😆😆 gg! ' * 5]) +def test_tokenizer(model_path, input): + tokenizer = HuggingFaceTokenizer(model_path) + encoded = tokenizer.encode(input) + output = '' + offset = 0 + for i in range(1, len(encoded) + 1): + decoded = tokenizer.decode(encoded[:i], offset) + if decoded.endswith('�'): + continue + output += decoded + offset = i + assert input == output, 'input string should equal to output after enc-dec'
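For reference, the decode loop in test_tokenizer above can be read as a small incremental detokenization routine: decode a growing prefix of the token ids, skip any fragment that still ends in a partial multi-byte character ('�'), and advance the decode offset only once a complete piece of text has been produced. Below is a minimal standalone sketch of that pattern (not part of the patch itself); it assumes only the encode()/decode(token_ids, offset) interface of HuggingFaceTokenizer that the test already exercises.

def incremental_decode(tokenizer, token_ids):
    """Yield printable text fragments for a growing stream of token ids."""
    offset = 0
    for i in range(1, len(token_ids) + 1):
        fragment = tokenizer.decode(token_ids[:i], offset)
        if fragment.endswith('�'):
            # partial multi-byte character: wait for more tokens
            continue
        offset = i
        yield fragment

# usage sketch (assumes a locally cached HF model):
#   tok = HuggingFaceTokenizer('internlm/internlm-chat-7b')
#   ids = tok.encode('hi, this is a test 😆!')
#   assert ''.join(incremental_decode(tok, ids)) == 'hi, this is a test 😆!'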