From 6b00f6239012b2bd6f1450a44107fe8665906451 Mon Sep 17 00:00:00 2001 From: Chen Xin Date: Wed, 22 Nov 2023 19:05:38 +0800 Subject: [PATCH] Support loading hf model directly (#685) * turbomind support export model params * fix overflow * support turbomind.from_pretrained * fix tp * support AutoModel * support load kv qparams * update auto_awq * udpate docstring * export lmdeploy version * update doc * remove download_hf_repo * LmdeployForCausalLM -> LmdeployForCausalLM * refactor turbomind.py * update comment * add bfloat16 convert back * support gradio run_locl load hf * support resuful api server load hf * add docs * support loading previous quantized model * adapt pr 690 * udpate docs * not export turbomind config when quantize a model * check model_name when can not get it from config.json * update readme * remove model_name in auto_awq * update * update * udpate * fix build * absolute import --- .gitignore | 1 + README.md | 52 +-- README_zh-CN.md | 52 +-- docs/en/load_hf.md | 71 ++++ docs/zh_cn/load_hf.md | 72 ++++ lmdeploy/cli/serve.py | 17 +- lmdeploy/lite/apis/auto_awq.py | 10 + lmdeploy/lite/apis/kv_qparams.py | 42 ++- lmdeploy/lite/utils/export_turbomind.py | 70 ++++ lmdeploy/serve/async_engine.py | 14 +- lmdeploy/serve/gradio/app.py | 2 +- lmdeploy/serve/gradio/turbomind_coupled.py | 6 +- lmdeploy/turbomind/chat.py | 53 +-- lmdeploy/turbomind/deploy/converter.py | 2 +- .../turbomind/deploy/source_model/base.py | 1 + .../turbomind/deploy/target_model/base.py | 42 ++- lmdeploy/turbomind/hf_repo/config.json | 11 + .../hf_repo/configuration_lmdeploy.py | 36 ++ .../turbomind/hf_repo/modeling_lmdeploy.py | 226 ++++++++++++ lmdeploy/turbomind/turbomind.py | 324 +++++++++++++++--- lmdeploy/turbomind/utils.py | 120 +++++++ .../models/llama/LlamaDecoderLayerWeight.cc | 73 +++- .../models/llama/LlamaDecoderLayerWeight.h | 3 + src/turbomind/models/llama/LlamaWeight.cc | 29 ++ src/turbomind/models/llama/LlamaWeight.h | 2 + src/turbomind/python/bind.cpp | 36 +- .../triton_backend/llama/LlamaTritonModel.cc | 52 ++- .../triton_backend/llama/LlamaTritonModel.h | 5 +- .../transformer_triton_backend.hpp | 4 + 29 files changed, 1196 insertions(+), 232 deletions(-) create mode 100644 docs/en/load_hf.md create mode 100644 docs/zh_cn/load_hf.md create mode 100644 lmdeploy/lite/utils/export_turbomind.py create mode 100644 lmdeploy/turbomind/hf_repo/config.json create mode 100644 lmdeploy/turbomind/hf_repo/configuration_lmdeploy.py create mode 100644 lmdeploy/turbomind/hf_repo/modeling_lmdeploy.py create mode 100644 lmdeploy/turbomind/utils.py diff --git a/.gitignore b/.gitignore index 79a716bd9d..6a6104d0ae 100644 --- a/.gitignore +++ b/.gitignore @@ -58,6 +58,7 @@ work_dir*/ *.bin *config.json *generate_config.json +!lmdeploy/turbomind/hf_repo/config.json # Pytorch *.pth diff --git a/README.md b/README.md index 81ef52d177..7da9778b40 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ ______________________________________________________________________ ## News 🎉 +- \[2023/11\] Turbomind supports loading hf model directly. Click [here](./docs/en/load_hf.md) for details. 
- \[2023/11\] TurboMind major upgrades, including: Paged Attention, faster attention kernels without sequence length limitation, 2x faster KV8 kernels, Split-K decoding (Flash Decoding), and W4A16 inference for sm_75
- \[2023/09\] TurboMind supports Qwen-14B
- \[2023/09\] TurboMind supports InternLM-20B
@@ -114,30 +115,18 @@ pip install lmdeploy

### Deploy InternLM

-#### Get InternLM model
+To use the TurboMind inference engine, you first need to convert the model into the TurboMind format. Currently, we support online conversion and offline conversion. With online conversion, TurboMind can load the Huggingface model directly, while with offline conversion, you should save the converted model before using it.

-```shell
-# 1. Download InternLM model
-
-# Make sure you have git-lfs installed (https://git-lfs.com)
-git lfs install
-git clone https://huggingface.co/internlm/internlm-chat-7b-v1_1 /path/to/internlm-chat-7b
-
-# if you want to clone without large files – just their pointers
-# prepend your git clone with the following env var:
-GIT_LFS_SKIP_SMUDGE=1
-
-# 2. Convert InternLM model to turbomind's format, which will be in "./workspace" by default
-lmdeploy convert internlm-chat-7b /path/to/internlm-chat-7b
-
-```
+The following uses [internlm/internlm-chat-7b-v1_1](https://huggingface.co/internlm/internlm-chat-7b-v1_1) as an example to show how to use TurboMind with online conversion. You can refer to [load_hf.md](docs/en/load_hf.md) for other methods.

#### Inference by TurboMind

```shell
-lmdeploy chat turbomind ./workspace
+lmdeploy chat turbomind internlm/internlm-chat-7b-v1_1 --model-name internlm-chat-7b
```

+> **Note**<br />
The internlm/internlm-chat-7b-v1_1 model will be downloaded under the `.cache` folder. You can also use a local path here.
+
> **Note**<br />
> When inferring with FP16 precision, the InternLM-7B model requires at least 15.7G of GPU memory overhead on TurboMind.
> It is recommended to use NVIDIA cards such as 3090, V100, A100, etc. @@ -152,7 +141,7 @@ lmdeploy chat turbomind ./workspace # install lmdeploy with extra dependencies pip install lmdeploy[serve] -lmdeploy serve gradio ./workspace +lmdeploy serve gradio internlm/internlm-chat-7b-v1_1 --model-name internlm-chat-7b ``` ![](https://github.com/InternLM/lmdeploy/assets/67539920/08d1e6f2-3767-44d5-8654-c85767cec2ab) @@ -165,13 +154,13 @@ Launch inference server by: # install lmdeploy with extra dependencies pip install lmdeploy[serve] -lmdeploy serve api_server ./workspace --instance_num 32 --tp 1 +lmdeploy serve api_server internlm/internlm-chat-7b-v1_1 --model-name internlm-chat-7b --instance_num 32 --tp 1 ``` Then, you can communicate with it by command line, ```shell -# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 +# api_server_url is what printed in api_server.py, e.g. http://localhost:23333 lmdeploy serve api_client api_server_url ``` @@ -186,29 +175,6 @@ lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port Refer to [restful_api.md](docs/en/restful_api.md) for more details. -#### Serving with Triton Inference Server - -Launch inference server by: - -```shell -bash workspace/service_docker_up.sh -``` - -Then, you can communicate with the inference server by command line, - -```shell -python3 -m pip install tritonclient[grpc] -lmdeploy serve triton_client {server_ip_addresss}:33337 -``` - -or webui, - -```shell -lmdeploy serve gradio {server_ip_addresss}:33337 -``` - -For the deployment of other supported models, such as LLaMA, LLaMA-2, vicuna and so on, you can find the guide from [here](docs/en/serving.md) - ### Inference with PyTorch For detailed instructions on Inference pytorch models, see [here](docs/en/pytorch.md). diff --git a/README_zh-CN.md b/README_zh-CN.md index a126351c21..e9b3734e41 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -20,6 +20,7 @@ ______________________________________________________________________ ## 更新 🎉 +- \[2023/11\] Turbomind 支持直接读取 Huggingface 模型。点击[这里](./docs/en/load_hf.md)查看使用方法 - \[2023/11\] TurboMind 重磅升级。包括:Paged Attention、更快的且不受序列最大长度限制的 attention kernel、2+倍快的 KV8 kernels、Split-K decoding (Flash Decoding) 和 支持 sm_75 架构的 W4A16 - \[2023/09\] TurboMind 支持 Qwen-14B - \[2023/09\] TurboMind 支持 InternLM-20B 模型 @@ -114,30 +115,18 @@ pip install lmdeploy ### 部署 InternLM -#### 获取 InternLM 模型 +使用 TurboMind 推理模型需要先将模型转化为 TurboMind 的格式,目前支持在线转换和离线转换两种形式。在线转换可以直接加载 Huggingface 模型,离线转换需需要先保存模型再加载。 -```shell -# 1. 下载 InternLM 模型 - -# Make sure you have git-lfs installed (https://git-lfs.com) -git lfs install -git clone https://huggingface.co/internlm/internlm-chat-7b-v1_1 /path/to/internlm-chat-7b - -# if you want to clone without large files – just their pointers -# prepend your git clone with the following env var: -GIT_LFS_SKIP_SMUDGE=1 - -# 2. 转换为 trubomind 要求的格式。默认存放路径为 ./workspace -lmdeploy convert internlm-chat-7b /path/to/internlm-chat-7b - -``` +下面以 [internlm/internlm-chat-7b-v1_1](https://huggingface.co/internlm/internlm-chat-7b-v1_1) 为例,展示在线转换的使用方式。其他方式可参考[load_hf.md](docs/zh_cn/load_hf.md) #### 使用 turbomind 推理 ```shell -lmdeploy chat turbomind ./workspace +lmdeploy chat turbomind internlm/internlm-chat-7b-v1_1 --model-name internlm-chat-7b ``` +> **Note**
internlm/internlm-chat-7b-v1_1 会自动下载到 `.cache` 文件夹，这里也可以传下载好的路径。
+
> **Note**<br />
> turbomind 在使用 FP16 精度推理 InternLM-7B 模型时,显存开销至少需要 15.7G。建议使用 3090, V100,A100等型号的显卡。
> 关闭显卡的 ECC 可以腾出 10% 显存,执行 `sudo nvidia-smi --ecc-config=0` 重启系统生效。 @@ -151,7 +140,7 @@ lmdeploy chat turbomind ./workspace # 安装lmdeploy额外依赖 pip install lmdeploy[serve] -lmdeploy serve gradio ./workspace +lmdeploy serve gradio internlm/internlm-chat-7b-v1_1 --model-name internlm-chat-7b ``` ![](https://github.com/InternLM/lmdeploy/assets/67539920/08d1e6f2-3767-44d5-8654-c85767cec2ab) @@ -164,13 +153,13 @@ lmdeploy serve gradio ./workspace # 安装lmdeploy额外依赖 pip install lmdeploy[serve] -lmdeploy serve api_server ./workspace --server_name 0.0.0.0 --server_port ${server_port} --instance_num 32 --tp 1 +lmdeploy serve api_server internlm/internlm-chat-7b-v1_1 --model-name internlm-chat-7b --instance_num 32 --tp 1 ``` 你可以通过命令行方式与推理服务进行对话: ```shell -# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 +# api_server_url is what printed in api_server.py, e.g. http://localhost:23333 lmdeploy serve api_client api_server_url ``` @@ -185,29 +174,6 @@ lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port 更多详情可以查阅 [restful_api.md](docs/zh_cn/restful_api.md)。 -#### 通过容器部署推理服务 - -使用下面的命令启动推理服务: - -```shell -bash workspace/service_docker_up.sh -``` - -你可以通过命令行方式与推理服务进行对话: - -```shell -python3 -m pip install tritonclient[grpc] -lmdeploy serve triton_client {server_ip_addresss}:33337 -``` - -也可以通过 WebUI 方式来对话: - -```shell -lmdeploy serve gradio {server_ip_addresss}:33337 -``` - -其他模型的部署方式,比如 LLaMA,LLaMA-2,vicuna等等,请参考[这里](docs/zh_cn/serving.md) - ### 基于 PyTorch 的推理 你必须确保环境中有安装 deepspeed: diff --git a/docs/en/load_hf.md b/docs/en/load_hf.md new file mode 100644 index 0000000000..ddf6fe8bfd --- /dev/null +++ b/docs/en/load_hf.md @@ -0,0 +1,71 @@ +# Load huggingface model directly + +Starting from v0.1.0, Turbomind adds the ability to pre-process the model parameters on-the-fly while loading them from huggingface style models. + +## Supported model type + +Currently, Turbomind support loading three types of model: + +1. A lmdeploy-quantized model hosted on huggingface.co, such as [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit), etc. +2. Other LM models on huggingface.co like Qwen/Qwen-7B-Chat +3. A model converted by `lmdeploy convert`, legacy format + +## Usage + +### 1) A lmdeploy-quantized model + +For models quantized by `lmdeploy.lite` such as [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit), etc. + +``` +repo_id=internlm/internlm-chat-20b-4bit +model_name=internlm-chat-20b +# or +# repo_id=/path/to/downloaded_model + +# Inference by TurboMind +lmdeploy chat turbomind $repo_id --model-name $model_name + +# Serving with gradio +lmdeploy serve gradio $repo_id --model-name $model_name + +# Serving with Restful API +lmdeploy serve api_server $repo_id --model-name $model_name --instance_num 32 --tp 1 +``` + +### 2) Other LM models + +For other LM models such as Qwen/Qwen-7B-Chat or baichuan-inc/Baichuan2-7B-Chat. LMDeploy supported models can be viewed through `lmdeploy list`. 
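Beyond the CLI commands below, the same models can also be driven from Python through the `TurboMind.from_pretrained` entry point added in this PR. The snippet below is only a minimal sketch pieced together from the updated `lmdeploy/turbomind/chat.py` and `modeling_lmdeploy.py`; the repo id, `model_name`, prompt, and generation arguments are placeholders, and default arguments of `get_prompt`/`decode` are assumed.

```python
# Minimal sketch (not part of this PR's docs): loading a HF model from Python
# via the new TurboMind.from_pretrained API. Repo id / model_name mirror the
# CLI example below; generation arguments are illustrative.
from lmdeploy import turbomind as tm

tm_model = tm.TurboMind.from_pretrained('Qwen/Qwen-7B-Chat',
                                        model_name='qwen-7b')
tokenizer = tm_model.tokenizer
generator = tm_model.create_instance()

# build the chat-template prompt and tokenize it
prompt = tm_model.model.get_prompt('Hello, who are you?')
input_ids = tokenizer.encode(prompt)

# stream tokens for a single-turn session
for outputs in generator.stream_infer(session_id=1,
                                      input_ids=[input_ids],
                                      request_output_len=512,
                                      sequence_start=True,
                                      sequence_end=True):
    res, tokens = outputs[0]

print(tokenizer.decode(res.tolist()))
```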
+ +``` +repo_id=Qwen/Qwen-7B-Chat +model_name=qwen-7b +# or +# repo_id=/path/to/Qwen-7B-Chat/local_path + +# Inference by TurboMind +lmdeploy chat turbomind $repo_id --model-name $model_name + +# Serving with gradio +lmdeploy serve gradio $repo_id --model-name $model_name + +# Serving with Restful API +lmdeploy serve api_server $repo_id --model-name $model_name --instance_num 32 --tp 1 +``` + +### 3) A model converted by `lmdeploy convert` + +The usage is like previous + +``` +# Convert a model +lmdeploy convert /path/to/model ./workspace --model-name MODEL_NAME + +# Inference by TurboMind +lmdeploy chat turbomind ./workspace + +# Serving with gradio +lmdeploy serve gradio ./workspace + +# Serving with Restful API +lmdeploy serve api_server ./workspace --instance_num 32 --tp 1 +``` diff --git a/docs/zh_cn/load_hf.md b/docs/zh_cn/load_hf.md new file mode 100644 index 0000000000..63c08fe2d9 --- /dev/null +++ b/docs/zh_cn/load_hf.md @@ -0,0 +1,72 @@ +# 直接读取 huggingface 模型 + +从 v0.1.0 开始,Turbomid 添加了直接读取 Huggingface 格式权重的能力。 + +## 支持的类型 + +目前,TurboMind 支持加载三种类型的模型: + +1. 在 huggingface.co 上面通过 lmdeploy 量化的模型,如 [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit) +2. huggingface.co 上面其他 LM 模型,如Qwen/Qwen-7B-Chat +3. 通过 `lmdeploy convert` 命令转换好的模型,兼容旧格式 + +## 使用方式 + +### 1) 通过 lmdeploy 量化的模型 + +对于通过 `lmdeploy.lite` 量化的模型,TurboMind 可以直接加载,比如 [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit). + +``` +repo_id=internlm/internlm-chat-20b-4bit +model_name=internlm-chat-20b + +# or +# repo_id=/path/to/downloaded_model + +# Inference by TurboMind +lmdeploy chat turbomind $repo_id --model-name $model_name + +# Serving with gradio +lmdeploy serve gradio $repo_id --model-name $model_name + +# Serving with Restful API +lmdeploy serve api_server $repo_id --model-name $model_name --instance_num 32 --tp 1 +``` + +### 2) 其他的 LM 模型 + +其他 LM 模型比如 Qwen/Qwen-7B-Chat, baichuan-inc/Baichuan2-7B-Chat。LMDeploy 模型支持情况可通过 `lmdeploy list` 查看。 + +``` +repo_id=Qwen/Qwen-7B-Chat +model_name=qwen-7b +# or +# repo_id=/path/to/Qwen-7B-Chat/local_path + +# Inference by TurboMind +lmdeploy chat turbomind $repo_id --model-name $model_name + +# Serving with gradio +lmdeploy serve gradio $repo_id --model-name $model_name + +# Serving with Restful API +lmdeploy serve api_server $repo_id --model-name $model_name --instance_num 32 --tp 1 +``` + +### 3) 通过 `lmdeploy convert` 命令转换好的模型 + +使用方式与之前相同 + +``` +# Convert a model +lmdeploy convert /path/to/model ./workspace --model-name MODEL_NAME + +# Inference by TurboMind +lmdeploy chat turbomind ./workspace + +# Serving with gradio +lmdeploy serve gradio ./workspace + +# Serving with Restful API +lmdeploy serve api_server ./workspace --instance_num 32 --tp 1 +``` diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 33580cdfe1..30185376f5 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -11,7 +11,7 @@ def gradio(self, server_port: int = 6006, batch_size: int = 32, tp: int = 1, - restful_api: bool = False): + **kwargs): """Serve LLMs with web ui using gradio. 
Example 1: @@ -21,7 +21,6 @@ def gradio(self, lmdeploy serve gradio http://0.0.0.0:23333 --server_name 0.0.0.0 --server_port 6006 - --restful_api True Example 3: lmdeploy serve gradio ${triton_server_ip_addresss}:33337 @@ -30,13 +29,12 @@ def gradio(self, model_path_or_server (str): the path of the deployed model or the tritonserver URL or restful api URL. The former is for directly running service with gradio. The latter is for running with - tritonserver by default. If the input URL is restful api. - Please enable another flag `restful_api`. + tritonserver by default. server_name (str): the ip address of gradio server server_port (int): the port of gradio server batch_size (int): batch size for running Turbomind directly tp (int): tensor parallel for Turbomind - restful_api (bool): a flag for model_path_or_server + kwargs (dict): extra params to init """ from lmdeploy.serve.gradio.app import run run(model_path_or_server, @@ -44,7 +42,7 @@ def gradio(self, server_port=server_port, batch_size=batch_size, tp=tp, - restful_api=restful_api) + **kwargs) def api_server(self, model_path: str, @@ -55,7 +53,8 @@ def api_server(self, allow_origins: List[str] = ['*'], allow_credentials: bool = True, allow_methods: List[str] = ['*'], - allow_headers: List[str] = ['*']): + allow_headers: List[str] = ['*'], + **kwargs): """Serve LLMs with restful api using fastapi. Args: @@ -68,6 +67,7 @@ def api_server(self, allow_credentials (bool): whether to allow credentials for CORS allow_methods (List[str]): a list of allowed HTTP methods for CORS allow_headers (List[str]): a list of allowed HTTP headers for CORS + kwargs (dict) extra params to init api server """ from lmdeploy.serve.openai.api_server import main as run_api_server @@ -79,7 +79,8 @@ def api_server(self, allow_origins=allow_origins, allow_credentials=allow_credentials, allow_methods=allow_methods, - allow_headers=allow_headers) + allow_headers=allow_headers, + **kwargs) def api_client(self, restful_api_url: str, session_id: int = 0): """Interact with restful api server in terminal. diff --git a/lmdeploy/lite/apis/auto_awq.py b/lmdeploy/lite/apis/auto_awq.py index e470bd0733..4a4f8ea983 100644 --- a/lmdeploy/lite/apis/auto_awq.py +++ b/lmdeploy/lite/apis/auto_awq.py @@ -10,6 +10,8 @@ quant_weights, smooth_layers) from lmdeploy.lite.utils import collect_target_modules, load_hf_from_pretrained +# from lmdeploy.lite.utils.export_turbomind import export_turbomind_config + LAYER_TYPE_MAP = { 'InternLMForCausalLM': 'InternLMDecoderLayer', 'QWenLMHeadModel': 'QWenBlock', @@ -33,6 +35,9 @@ def auto_awq(model: str, w_group_size: int = 128, device: str = 'cuda'): + assert model != work_dir, '$WORK_DIR and $HF_MODEL should be different' + model_path = model # noqa + # Load tokenizer and configuration tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, @@ -61,6 +66,11 @@ def auto_awq(model: str, model.save_pretrained(work_dir, max_shard_size='2GB') tokenizer.save_pretrained(work_dir) + # export_turbomind_config(model_name, + # model_path, + # work_dir, + # group_size=w_group_size) + if __name__ == '__main__': import fire diff --git a/lmdeploy/lite/apis/kv_qparams.py b/lmdeploy/lite/apis/kv_qparams.py index f31fee0299..873bc5b047 100644 --- a/lmdeploy/lite/apis/kv_qparams.py +++ b/lmdeploy/lite/apis/kv_qparams.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import os from pathlib import Path from typing import Union @@ -6,11 +7,28 @@ import torch +def _export_weight(into: str, + kv_qparams: np.array, + out_path: str, + tm_params: dict = None): + """Save kv_qparams to disk or copy to tm_params.""" + if tm_params is None: + print(into) + kv_qparams.tofile(out_path) + else: + name = os.path.basename(out_path) + src = torch.from_numpy(kv_qparams) + for tm_tensor in tm_params[name]: + tm_tensor.copy_from(src) + tm_params.pop(name) + + def _export_sym(key_stats: dict, value_stats: dict, bits: int, out_dir: Union[str, Path], - tp: int = 1) -> None: + tp: int = 1, + tm_params: dict = None) -> None: """Export symmetric quantization parameters to specified directory.""" keys_absmax = key_stats['absmax'] values_absmax = value_stats['absmax'] @@ -31,15 +49,16 @@ def _export_sym(key_stats: dict, kv_qparams = np.array([k_s, v_s], dtype=np.float32) out_path = out_dir / f'layers.{layer_idx}.past_kv_scale.{i}.weight' # noqa: E501 - kv_qparams.tofile(out_path) - print(f'Layer {layer_idx} MP {i} qparam: {k_s} \t{v_s}') + info = f'Layer {layer_idx} MP {i} qparam: {k_s} \t{v_s}' + _export_weight(info, kv_qparams, out_path, tm_params) def _export_asym(key_stats: dict, value_stats: dict, bits: int, out_dir: Union[str, Path], - tp: int = 1) -> None: + tp: int = 1, + tm_params: dict = None) -> None: """Export asymmetric quantization parameters to specified directory.""" keys_min = key_stats['min'] values_min = value_stats['min'] @@ -81,16 +100,17 @@ def _export_asym(key_stats: dict, kv_qparams = np.array([k_scale, k_zp, v_scale, v_zp], dtype=np.float32) out_path = out_dir / f'layers.{layer_idx}.past_kv_scale.{i}.weight' - kv_qparams.tofile(out_path) - print(f'Layer {layer_idx} MP {i} qparam: ' - f'\t{k_scale} \t{k_zp} \t{v_scale} \t{v_zp}') + info = f'Layer {layer_idx} MP {i} qparam: ' \ + f'\t{k_scale} \t{k_zp} \t{v_scale} \t{v_zp}' + _export_weight(info, kv_qparams, out_path, tm_params) def main(work_dir: str, turbomind_dir: str, kv_bits: int = 8, kv_sym: bool = False, - num_tp: int = 1) -> None: + num_tp: int = 1, + tm_params: dict = None) -> None: """Main function to export key and value stats. Args: @@ -102,6 +122,7 @@ def main(work_dir: str, kv_sym (bool, optional): Whether to use symmetric quantizaiton. Defaults to False. num_tp (int, optional): Number of tensor parallelism. Defaults to 1. + tm_params (dict): turbomind model weights. """ work_dir = Path(work_dir) @@ -113,9 +134,10 @@ def main(work_dir: str, value_stats = torch.load(work_dir / 'value_stats.pth') if kv_sym: - _export_sym(key_stats, value_stats, kv_bits, tm_dir, num_tp) + _export_sym(key_stats, value_stats, kv_bits, tm_dir, num_tp, tm_params) else: - _export_asym(key_stats, value_stats, kv_bits, tm_dir, num_tp) + _export_asym(key_stats, value_stats, kv_bits, tm_dir, num_tp, + tm_params) if __name__ == '__main__': diff --git a/lmdeploy/lite/utils/export_turbomind.py b/lmdeploy/lite/utils/export_turbomind.py new file mode 100644 index 0000000000..393a980041 --- /dev/null +++ b/lmdeploy/lite/utils/export_turbomind.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import json +import os +import shutil + +from huggingface_hub import snapshot_download + +from lmdeploy.turbomind.utils import get_hf_config_content + + +def export_turbomind_config(model_name: str, + model_path: str, + work_dir: str, + model_format: str = 'awq', + group_size: int = 128, + tp: int = 1): + """Export hf lmdeploy model and config.json.""" + import lmdeploy + from lmdeploy.model import MODELS + from lmdeploy.turbomind.deploy.converter import get_model_format + from lmdeploy.turbomind.deploy.source_model.base import INPUT_MODELS + from lmdeploy.turbomind.deploy.target_model.base import ( + OUTPUT_MODELS, TurbomindModelConfig) + + assert model_name in MODELS.module_dict.keys(), \ + f"'{model_name}' is not supported. " \ + f'The supported models are: {MODELS.module_dict.keys()}' + + if not os.path.exists(model_path): + model_path = snapshot_download(model_path, local_files_only=True) + + lmdeploy_dir = os.path.split(lmdeploy.__file__)[0] + hf_repo = os.path.join(lmdeploy_dir, 'turbomind', 'hf_repo') + files = os.listdir(hf_repo) + for file in files: + src = os.path.join(hf_repo, file) + dst = os.path.join(work_dir, file) + shutil.copy(src, dst) + + cfg = TurbomindModelConfig.from_dict({}, allow_none=True) + cfg.model_name = model_name + cfg.tensor_para_size = tp + cfg.rotary_embedding = cfg.size_per_head + cfg.group_size = group_size + cfg.weight_type = 'int4' + output_format = 'w4' + + inferred_model_format = get_model_format(model_name, model_format) + input_model = INPUT_MODELS.get(inferred_model_format)( + model_path=model_path, tokenizer_path=work_dir, ckpt_path=work_dir) + output_model = OUTPUT_MODELS.get(output_format)(input_model=input_model, + cfg=cfg, + to_file=False, + out_dir='') + + old_data = get_hf_config_content(model_path) + config = output_model.cfg.__dict__ + config_file = os.path.join(work_dir, 'config.json') + with open(config_file) as f: + data = json.load(f) + for k, v in old_data.items(): + if k in data: + data[f'__{k}'] = v + else: + data[k] = v + data['turbomind'] = config + from lmdeploy.version import __version__ + data['lmdeploy_version'] = __version__ + with open(config_file, 'w') as f: + f.write(json.dumps(data, indent=2) + '\n') diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 406e504af8..eb1c317c1e 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import asyncio import dataclasses -import os.path as osp import random from contextlib import contextmanager from typing import List, Literal, Optional @@ -28,15 +27,10 @@ class AsyncEngine: def __init__(self, model_path, instance_num=32, tp=1, **kwargs) -> None: from lmdeploy import turbomind as tm - from lmdeploy.tokenizer import Tokenizer - tokenizer_model_path = osp.join(model_path, 'triton_models', - 'tokenizer') - tokenizer = Tokenizer(tokenizer_model_path) - self.tm_model = tm.TurboMind(model_path, - eos_id=tokenizer.eos_token_id, - tp=tp, - **kwargs) - self.tokenizer = tokenizer + self.tm_model = tm.TurboMind.from_pretrained(model_path, + tp=tp, + **kwargs) + self.tokenizer = self.tm_model.tokenizer self.generators = [ self.tm_model.create_instance() for i in range(instance_num) ] diff --git a/lmdeploy/serve/gradio/app.py b/lmdeploy/serve/gradio/app.py index 5b1668224d..cf8815ad0f 100644 --- a/lmdeploy/serve/gradio/app.py +++ b/lmdeploy/serve/gradio/app.py @@ -32,7 +32,7 @@ def run(model_path_or_server: str, else: from lmdeploy.serve.gradio.turbomind_coupled import run_local run_local(model_path_or_server, server_name, server_port, batch_size, - tp) + tp, **kwargs) if __name__ == '__main__': diff --git a/lmdeploy/serve/gradio/turbomind_coupled.py b/lmdeploy/serve/gradio/turbomind_coupled.py index d3f686089c..dfb38bf89f 100644 --- a/lmdeploy/serve/gradio/turbomind_coupled.py +++ b/lmdeploy/serve/gradio/turbomind_coupled.py @@ -118,7 +118,8 @@ def run_local(model_path: str, server_name: str = 'localhost', server_port: int = 6006, batch_size: int = 4, - tp: int = 1): + tp: int = 1, + **kwargs): """chat with AI assistant through web ui. Args: @@ -130,7 +131,8 @@ def run_local(model_path: str, """ InterFace.async_engine = AsyncEngine(model_path=model_path, instance_num=batch_size, - tp=tp) + tp=tp, + **kwargs) with gr.Blocks(css=CSS, theme=THEME) as demo: state_chatbot = gr.State([]) diff --git a/lmdeploy/turbomind/chat.py b/lmdeploy/turbomind/chat.py index bd32a9cc41..c0d2c8f4ed 100644 --- a/lmdeploy/turbomind/chat.py +++ b/lmdeploy/turbomind/chat.py @@ -1,22 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import dataclasses import os -import os.path as osp import random -os.environ['TM_LOG_LEVEL'] = 'ERROR' - +from lmdeploy.turbomind.utils import get_gen_param -@dataclasses.dataclass -class GenParam: - top_p: float - top_k: float - temperature: float - repetition_penalty: float - sequence_start: bool = False - sequence_end: bool = False - step: int = 0 - request_output_len: int = 512 +os.environ['TM_LOG_LEVEL'] = 'ERROR' def input_prompt(model_name): @@ -40,30 +29,6 @@ def valid_str(string, coding='utf-8'): return ret -def get_gen_param(cap, - sampling_param, - nth_round, - step, - request_output_len=512, - **kwargs): - """return parameters used by token generation.""" - gen_param = GenParam(**dataclasses.asdict(sampling_param), - request_output_len=request_output_len) - # Fix me later. 
turbomind.py doesn't support None top_k - if gen_param.top_k is None: - gen_param.top_k = 40 - - if cap == 'chat': - gen_param.sequence_start = (nth_round == 1) - gen_param.sequence_end = False - gen_param.step = step - else: - gen_param.sequence_start = True - gen_param.sequence_end = True - gen_param.step = 0 - return gen_param - - def main(model_path, session_id: int = 1, cap: str = 'chat', @@ -84,15 +49,11 @@ def main(model_path, **kwarg (dict): other arguments for initializing model's chat template """ from lmdeploy import turbomind as tm - from lmdeploy.tokenizer import Tokenizer - - tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer') - tokenizer = Tokenizer(tokenizer_model_path) - tm_model = tm.TurboMind(model_path, - eos_id=tokenizer.eos_token_id, - tp=tp, - capability=cap, - **kwargs) + tm_model = tm.TurboMind.from_pretrained(model_path, + tp=tp, + capability=cap, + **kwargs) + tokenizer = tm_model.tokenizer generator = tm_model.create_instance() nth_round = 1 diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py index 4876002020..4e6a03203e 100644 --- a/lmdeploy/turbomind/deploy/converter.py +++ b/lmdeploy/turbomind/deploy/converter.py @@ -203,7 +203,7 @@ def main(model_name: str, if inferred_model_format.find('awq') != -1: cfg.weight_type = 'int4' output_format = 'w4' - assert group_size > 0, 'group_size should > 0' + assert group_size > 0, f'group_size: {group_size} should > 0' # convert print('model_name ', model_name) diff --git a/lmdeploy/turbomind/deploy/source_model/base.py b/lmdeploy/turbomind/deploy/source_model/base.py index 89f18033e9..c335b4c10b 100644 --- a/lmdeploy/turbomind/deploy/source_model/base.py +++ b/lmdeploy/turbomind/deploy/source_model/base.py @@ -64,6 +64,7 @@ def clean_up(self, last: bool) -> None: for key in self.params: layer_id = re.findall(self.attn_layer_patten, key) if len(layer_id) == 0: + # tok, norm, output to_remove.append(key) else: layer_id = int(layer_id[0]) diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py index c90ff9610f..29aaa124e0 100644 --- a/lmdeploy/turbomind/deploy/target_model/base.py +++ b/lmdeploy/turbomind/deploy/target_model/base.py @@ -18,6 +18,9 @@ def tprint(*args, **kwargs): + to_file = kwargs.pop('to_file', False) + if not to_file: + return from io import StringIO s = StringIO() print(*args, **kwargs, file=s, end='') @@ -90,10 +93,13 @@ def __init__(self, out_dir: str = ''): super().__init__() self.input_model = input_model - self.cfg = self.get_config(cfg) + self.cfg = cfg + if not cfg.valid: + self.cfg = self.get_config(cfg) assert self.cfg.valid self.to_file = to_file self.out_dir = out_dir + self.tm_params = {} @abstractmethod def get_config(self, cfg: TurbomindModelConfig) -> TurbomindModelConfig: @@ -136,6 +142,27 @@ def export_weight(self, param: torch.Tensor, name: str) -> None: tprint(name, param.shape) param.contiguous().cpu().numpy().tofile( osp.join(self.out_dir, name)) + elif len(self.tm_params) > 0: + tm_params = self.tm_params + weight_type = self.cfg.weight_type + assert weight_type in ['fp16', 'fp32', 'int4'] + + # currently, the tensor type should in + # [torch.float, torch.half, torch.int32] + torch_tensor = param.cuda().contiguous() + assert torch_tensor.dtype in [ + torch.int32, torch.float, torch.half, torch.bfloat16 + ] + if torch_tensor.dtype != torch.int32: + if weight_type in ['fp16', 'int4']: + torch_tensor = torch_tensor.half() + else: + torch_tensor = torch_tensor.float() + 
for tm_tensor in tm_params[name]: + tm_tensor.copy_from(torch_tensor) + tm_params.pop(name) + else: + tprint('skip export', name, param.shape) def save_split(self, tensor: torch.Tensor, @@ -145,8 +172,10 @@ def save_split(self, """save split.""" tp = self.cfg.tensor_para_size if split_dim is not None: - tprint(f'*** splitting {name}, shape={tensor.shape}, ' - f'split_dim={split_dim}, tp={tp}') + tprint( + f'*** splitting {name}, shape={tensor.shape}, ' + f'split_dim={split_dim}, tp={tp}', + to_file=self.to_file) assert tensor.shape[split_dim] % tp == 0 split_size = tensor.shape[split_dim] // tp splits = torch.split(tensor, split_size, dim=split_dim) @@ -154,7 +183,8 @@ def save_split(self, prefix, ext = osp.splitext(name) self.export_weight(split, f'{prefix}.{i}{ext}') elif copy: - tprint(f'### copying {name}, shape={tensor.shape}') + tprint(f'### copying {name}, shape={tensor.shape}', + to_file=self.to_file) copies = [tensor] * tp for i, copy in enumerate(copies): prefix, ext = osp.splitext(name) @@ -166,7 +196,9 @@ def export(self) -> None: """Export to turbomind model format.""" num_layer = self.cfg.num_layer from tqdm import tqdm - pbar = tqdm(total=num_layer, desc='Convert to turbomind format') + pbar = tqdm(total=num_layer, + desc='Convert to turbomind format', + leave=self.to_file) self.export_config() for bin in self.input_model.bins(): self.export_misc(bin) diff --git a/lmdeploy/turbomind/hf_repo/config.json b/lmdeploy/turbomind/hf_repo/config.json new file mode 100644 index 0000000000..9778905e33 --- /dev/null +++ b/lmdeploy/turbomind/hf_repo/config.json @@ -0,0 +1,11 @@ +{ + "architectures": [ + "LMDeployForCausalLM" + ], + "auto_map": { + "AutoConfig": "configuration_lmdeploy.LMDeployConfig", + "AutoModel": "modeling_lmdeploy.LMDeployForCausalLM", + "AutoModelForCausalLM": "modeling_lmdeploy.LMDeployForCausalLM" + }, + "turbomind": {} +} diff --git a/lmdeploy/turbomind/hf_repo/configuration_lmdeploy.py b/lmdeploy/turbomind/hf_repo/configuration_lmdeploy.py new file mode 100644 index 0000000000..880ad66e81 --- /dev/null +++ b/lmdeploy/turbomind/hf_repo/configuration_lmdeploy.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import copy + +from transformers import PretrainedConfig + +from lmdeploy.turbomind.deploy.target_model.base import TurbomindModelConfig +from lmdeploy.version import __version__ as lm_version + + +class LMDeployConfig(PretrainedConfig): + """Lmdeploy config.""" + + def __init__(self, turbomind: dict = None, **kwargs): + default_tm_cfg = copy.deepcopy( + TurbomindModelConfig.from_dict({}, allow_none=True).__dict__) + if turbomind is not None: + default_tm_cfg.update(turbomind) + self.turbomind = default_tm_cfg + self.lmdeploy_version = lm_version + super().__init__(**kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) + config, kwargs = super().from_pretrained(pretrained_model_name_or_path, + return_unused_kwargs=True, + **kwargs) + for k, v in kwargs.items(): + if k in config.turbomind.keys(): + config.turbomind[k] = v + if 'tp' in kwargs: + config.turbomind['tensor_para_size'] = kwargs['tp'] + if return_unused_kwargs: + return config, kwargs + else: + return config diff --git a/lmdeploy/turbomind/hf_repo/modeling_lmdeploy.py b/lmdeploy/turbomind/hf_repo/modeling_lmdeploy.py new file mode 100644 index 0000000000..ffb9b05613 --- /dev/null +++ b/lmdeploy/turbomind/hf_repo/modeling_lmdeploy.py @@ -0,0 +1,226 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import dataclasses +import os +from contextlib import contextmanager +from dataclasses import dataclass, field +from itertools import count +from queue import Queue +from typing import List, Optional, Tuple, Union + +from huggingface_hub import snapshot_download +from transformers import PretrainedConfig +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from lmdeploy.turbomind import TurboMind +from lmdeploy.turbomind.utils import get_gen_param + +from .configuration_lmdeploy import LMDeployConfig + +logger = logging.get_logger(__name__) + + +@dataclass +class Session: + _count = count() + _session_id: int = None + _message: List[Tuple[str, str]] = field(default_factory=list) + _step: int = 0 + _nth_round: int = 0 + _error: int = 0 + + def __init__(self): + self._session_id = next(Session._count) + self._message = [] + self._step = 0 + self._nth_round = 0 + + @property + def session_id(self): + return self._session_id + + @property + def message(self): + return self._message + + @property + def step(self): + return self._step + + @property + def nth_round(self): + return self._nth_round + + @property + def error(self): + return self._error + + +class LMDeployForCausalLM(PreTrainedModel): + config_class = LMDeployConfig + + def __init__(self, + config: LMDeployConfig, + *inputs, + model_path: str = None, + **kwargs): + super().__init__(config) + self.tm_model = TurboMind.from_pretrained(model_path, **kwargs) + que = Queue() + for _ in range(config.turbomind['max_batch_size']): + que.put(self.tm_model.create_instance()) + self.que = que + + @classmethod + def from_pretrained(cls, + pretrained_model_name_or_path, + *model_args, + config: Optional[Union[PretrainedConfig, str, + os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = 'main', + **kwargs): + """Instantiate a LM model with turbomind backend.""" + + resume_download = kwargs.pop('resume_download', True) + proxies = kwargs.pop('proxies', None) + + if 
os.path.isdir(pretrained_model_name_or_path): + local_folder = pretrained_model_name_or_path + else: + local_folder = snapshot_download( + pretrained_model_name_or_path, + revision=revision, + cache_dir=cache_dir, + proxies=proxies, + resume_download=resume_download, + force_download=force_download, + token=token, + local_files_only=local_files_only, + ) + + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else local_folder + kwargs.pop('return_unused_kwargs') + config, model_kwargs = cls.config_class.from_pretrained( + config_path, return_unused_kwargs=True, **kwargs) + else: + model_kwargs = kwargs + + model = cls(config, + *model_args, + model_path=local_folder, + **model_kwargs) + + generation_config = model.tm_model.model.sampling_param + for k, v in dataclasses.asdict(generation_config).items(): + if hasattr(model.generation_config, k): + base_value = getattr(model.generation_config, k) + setattr(generation_config, k, base_value) + if k in kwargs: + setattr(generation_config, k, v) + model.generation_config = generation_config + + return model + + @contextmanager + def managed_generator(self, session: Session): + generator = self.que.get() + try: + yield generator + except: # noqa E722 + for _ in generator.stream_infer(session.session_id, [0], + request_output_len=0, + sequence_start=False, + sequence_end=False, + stop=True): + pass + session._error = 1 + finally: + self.que.put(generator) + + def generate( + self, + input_ids: List[int], + session: Session, + **kwargs, + ): + """Generates sequences of token ids for models with a language modeling + head. + + Args: + input_ids (List(int)): list of input token ids + session (Session) session information + kwargs (dict): hoc parametrization of generation + """ + with self.managed_generator(session) as generator: + for outputs in generator.stream_infer( + session_id=session.session_id, + input_ids=[input_ids], + **kwargs, + ): + res, tokens = outputs[0] + yield res, tokens + + def chat( + self, + query: str, + session: Optional[Session] = None, + cap: str = 'chat', + request_output_len: int = 512, + stream_output: bool = False, + ignore_eos=False, + random_seed: Optional[int] = None, + **kwargs, + ) -> Tuple[str, Session]: + """chat.""" + + if session is None: + session = Session() + assert session._error == 0, 'An error occurred before, ' \ + 'please start a new session.' 
+ + session._message.append([query, '']) + + prompt = self.tm_model.model.get_prompt(query, session.nth_round == 0) + input_ids = self.tm_model.tokenizer.encode(prompt) + + if len( + input_ids + ) + session.step + request_output_len >= self.tm_model.session_len: + logger.error( + f'session_length exceeded {self.tm_model.session_len}') + session._error = 1 + yield '', session + else: + gen_param = get_gen_param(cap, self.generation_config, + session.nth_round + 1, session.step, + request_output_len, **kwargs) + gen_kwargs = dataclasses.asdict(gen_param) + gen_kwargs.update( + random_seed=random_seed if session.nth_round == 0 else None, + stream_output=stream_output, + ignore_eos=ignore_eos, + **kwargs) + + _step = session._step + _nth_round = session._nth_round + response_size = 0 + + for res, tokens in self.generate(input_ids, + session=session, + **gen_kwargs): + response = self.tm_model.tokenizer.decode(res.tolist(), + offset=response_size) + if response.endswith('�'): + continue + response_size = tokens + + session._message[-1][-1] += response + session._nth_round = _nth_round + 1 + session._step = _step + response_size + + yield response, session diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index 6dfcf6d383..8668dd803a 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -1,15 +1,20 @@ # Copyright (c) OpenMMLab. All rights reserved. import asyncio +import copy +import io +import json +import logging import os.path as osp import sys from configparser import ConfigParser from contextlib import contextmanager from queue import Queue from threading import Thread -from typing import Iterable, List +from typing import Iterable, List, Optional import numpy as np import torch +from huggingface_hub import snapshot_download from torch.nn.utils.rnn import pad_sequence import lmdeploy @@ -17,19 +22,27 @@ from lmdeploy.tokenizer import Tokenizer from lmdeploy.utils import get_logger +from .deploy.converter import get_model_format, supported_formats +from .deploy.source_model.base import INPUT_MODELS +from .deploy.target_model.base import OUTPUT_MODELS, TurbomindModelConfig +from .utils import (ModelSource, check_tm_model_input, create_hf_download_args, + get_hf_config_content, get_model_source) + # TODO: find another way import _turbomind lmdeploy_dir = osp.split(lmdeploy.__file__)[0] sys.path.append(osp.join(lmdeploy_dir, 'lib')) import _turbomind as _tm # noqa: E402 +logger = logging.getLogger(__name__) + def _stop_words(stop_words: List[str], tokenizer: Tokenizer): """return list of stop-words to numpy.ndarray.""" if stop_words is None: return None assert isinstance(stop_words, List) and \ - all(isinstance(elem, str) for elem in stop_words), \ - f'stop_words must be a list but got {type(stop_words)}' + all(isinstance(elem, str) for elem in stop_words), \ + f'stop_words must be a list but got {type(stop_words)}' stop_words = [ tokenizer.encode(stop_word, False)[-1] for stop_word in stop_words ] @@ -76,77 +89,289 @@ class TurboMind: Args: model_path (str): the path of turbomind's model - eos_id (int): eos token id + model_source (int): model source + model_name (str): needed when model_path is a hf model and not + managed by lmdeploy + model_format (str): needed when model_path is a hf model and not + managed by lmdeploy + group_size (int): needed when model_path is a hf model and not + managed by lmdeploy tp (int): tensor parallel """ def __init__(self, model_path: str, - eos_id: int = 2, - tp: int = 1, + model_source: ModelSource 
= ModelSource.WORKSPACE, + model_name: Optional[str] = None, + model_format: Optional[str] = None, + group_size: Optional[int] = None, + tp: Optional[int] = None, **kwargs): - self.eos_id = eos_id - - # TODO: support mpi - node_id = 0 - node_num = 1 - - # read meta from model path - assert ((tp & (tp - 1) == 0) and tp != 0), 'tp should be 2^n' - self.gpu_count = tp - data_type = 'fp16' - ini_path = osp.join(model_path, 'triton_models/weights/config.ini') - with open(ini_path, 'r') as f: - parser = ConfigParser() - parser.read_file(f) - section_name = '' - if 'turbomind' in parser: - section_name = 'turbomind' - elif 'llama' in parser: - section_name = 'llama' - - if len(section_name) > 0: - tp_cfg = parser.getint(section_name, 'tensor_para_size') - if tp_cfg != 1 and tp_cfg != tp: - get_logger('turbomind').info( - f'found tp={tp_cfg} in config.ini.') - self.gpu_count = tp_cfg - self.model_name = parser.get(section_name, 'model_name') - data_type = parser.get(section_name, 'weight_type') + if tp is not None: + assert ((tp & (tp - 1) == 0) and tp != 0), 'tp should be 2^n' + self.gpu_count = tp if tp is not None else 1 + + if model_source == ModelSource.WORKSPACE: + tokenizer_model_path = osp.join(model_path, 'triton_models', + 'tokenizer') + self.tokenizer = Tokenizer(tokenizer_model_path) + self.model_comm = self._from_workspace(model_path) + else: + self.tokenizer = Tokenizer(model_path) + self.model_comm = self._from_hf(model_source=model_source, + model_path=model_path, + model_name=model_name, + model_format=model_format, + group_size=group_size, + tp=tp, + **kwargs) + + self.eos_id = self.tokenizer.eos_token_id self.model: BaseModel = MODELS.get(self.model_name)(**kwargs) self.session_len = self.model.session_len - tokenizer_model_path = osp.join(model_path, 'triton_models', - 'tokenizer') - tokenizer = Tokenizer(tokenizer_model_path) - self.stop_words = _stop_words(self.model.stop_words, tokenizer) + self.stop_words = _stop_words(self.model.stop_words, self.tokenizer) - # params - self.node_id = node_id - self.node_num = node_num - self.world_size = self.node_num * self.gpu_count + def _create_weight(self, model_comm): + """Allocate weight buffer, load params if from_workspace.""" - # create model - weight_dir = osp.join(model_path, 'triton_models', 'weights') - model_comm = _tm.AbstractTransformerModel.create_llama_model( - weight_dir, tensor_para_size=self.gpu_count, data_type=data_type) - self.model_comm = model_comm + # TODO: support mpi + self.node_id = 0 + self.node_num = 1 self.nccl_params = model_comm.create_nccl_params(self.node_id) torch.cuda.synchronize() # create weight - def _create_weight(device_id): + def _create_weight_func(device_id): with cuda_ctx(device_id): rank = self.node_id * self.gpu_count + device_id model_comm.create_shared_weights(device_id, rank) threads = [] for device_id in range(self.gpu_count): - t = Thread(target=_create_weight, args=(device_id, )) + t = Thread(target=_create_weight_func, args=(device_id, )) t.start() threads.append(t) for t in threads: t.join() + def _load_kv_qparams(self, model_path, tm_params, **kwargs): + """Load kv qparams when loading from hf.""" + if self.config.quant_policy: + logger.warning('loading kv_cache quant scale') + from lmdeploy.lite.apis.kv_qparams import main as kv_loader + kv_sym = kwargs.get('kv_sym', False) + kv_bits = kwargs.get('kv_bits', 8) + tp = self.config.tensor_para_size + kv_loader(model_path, model_path, kv_bits, kv_sym, tp, tm_params) + else: + for key in list(tm_params.keys()): + if 'past_kv_scale' in 
key: + tm_params.pop(key) + + def _get_model_params(self, model_comm, tm_params): + """Get turbomind model params when loading from hf.""" + + def _get_params(device_id, que): + with cuda_ctx(device_id): + rank = self.node_id * self.gpu_count + device_id + out = model_comm.get_params(device_id, rank) + que.put(out) + + que = Queue() + threads = [] + for device_id in range(self.gpu_count): + t = Thread(target=_get_params, args=(device_id, que)) + t.start() + threads.append(t) + for t in threads: + t.join() + + for _ in range(self.gpu_count): + tensor_map = que.get() + for k, v in tensor_map.items(): + if k not in tm_params: + tm_params[k] = [] + tm_params[k].append(v) + + def _from_hf(self, + model_source: ModelSource, + model_path: str, + model_name: Optional[str] = None, + model_format: Optional[str] = None, + group_size: Optional[int] = None, + tp: Optional[int] = None, + **kwargs): + """Load model which is in hf format.""" + # get model_name, group_size if is lmdeploy managed. + if model_source == ModelSource.HF_LMDEPLOY: + config = get_hf_config_content(model_path, local_files_only=True) + tm_config = config['turbomind'] + tm_config.update(kwargs) + var_shoud_be_none = dict(model_name=model_name, + model_format=model_format, + group_size=group_size) + for key, value in var_shoud_be_none.items(): + assert value is None, f'{key} should be None when model is '\ + f'from {model_source}' + model_name = tm_config['model_name'] + group_size = tm_config['group_size'] + if tm_config['weight_type'] == 'int4': + model_format = 'awq' + else: + assert model_name is not None, 'please supply model_name when ' \ + f'model is form {model_source}' + if osp.exists(osp.join(model_path, 'outputs_stats.pth')): + model_format = 'awq' if model_format is None else model_format + group_size = 128 if group_size is None else group_size + tm_config = kwargs + + assert model_name in MODELS.module_dict.keys(), \ + f"'{model_name}' is not supported. 
" \ + f'The supported models are: {MODELS.module_dict.keys()}' + assert model_format in supported_formats, 'the model format ' \ + f'should be in {supported_formats}' + + data_type = 'fp16' + output_format = 'fp16' + inferred_model_format = get_model_format(model_name, model_format) + cfg = TurbomindModelConfig.from_dict(tm_config, allow_none=True) + + # overwrite with input params + cfg.model_name = model_name + cfg.tensor_para_size = 1 if tp is None else tp + cfg.rotary_embedding = cfg.size_per_head + cfg.group_size = group_size + if inferred_model_format.find('awq') != -1: + cfg.weight_type = 'int4' + output_format = 'w4' + data_type = 'int4' + assert group_size > 0, f'group_size: {group_size} should > 0' + + self.config = cfg + self.model_name = model_name + self.data_type = data_type + + input_model = INPUT_MODELS.get(inferred_model_format)( + model_path=model_path, tokenizer_path=model_path, ckpt_path=None) + + output_model = OUTPUT_MODELS.get(output_format)( + input_model=input_model, cfg=cfg, to_file=False, out_dir='') + + config = copy.deepcopy(output_model.cfg.__dict__) + logger.warning(f'model_config:\n{json.dumps(config, indent=2)}') + parser = ConfigParser() + parser['llama'] = config + with io.StringIO() as ss: + parser.write(ss) + ss.seek(0) + config = ss.read() + + model_comm = _tm.AbstractTransformerModel.create_llama_model( + model_dir='', + config=config, + tensor_para_size=self.gpu_count, + data_type=data_type) + + # create empty weight + self._create_weight(model_comm) + + # copy hf model weight to turbomind weight + tm_params = output_model.tm_params + self._get_model_params(model_comm, tm_params) + logger.warning(f'get {len(tm_params)} model params') + output_model.export() + + # load kv qparams + self._load_kv_qparams(model_path, tm_params, **kwargs) + assert len(tm_params) == 0, f'missing {tm_params.keys()}' + + return model_comm + + def _from_workspace(self, model_path: str): + """Load model which is converted by `lmdeploy convert`""" + ini_path = osp.join(model_path, 'triton_models', 'weights', + 'config.ini') + with open(ini_path, 'r') as f: + parser = ConfigParser() + parser.read_file(f) + section_name = 'llama' + tp_cfg = parser.getint(section_name, 'tensor_para_size') + + if tp_cfg != 1 and tp_cfg != self.gpu_count: + get_logger('turbomind').info( + f'found tp={tp_cfg} in config.ini.') + self.gpu_count = tp_cfg + self.model_name = parser.get(section_name, 'model_name') + self.data_type = parser.get(section_name, 'weight_type') + cfg = parser._sections[section_name] + cfg = TurbomindModelConfig.from_dict(cfg) + self.config = cfg + + # create model + weight_dir = osp.join(model_path, 'triton_models', 'weights') + model_comm = _tm.AbstractTransformerModel.create_llama_model( + weight_dir, + tensor_para_size=self.gpu_count, + data_type=self.data_type) + + # create weight and load params + self._create_weight(model_comm) + return model_comm + + @classmethod + def from_pretrained(cls, + pretrained_model_name_or_path: str, + model_name: Optional[str] = None, + model_format: Optional[str] = None, + group_size: Optional[int] = None, + tp: Optional[int] = None, + **kwargs): + """LMDeploy's turbomind inference engine. 
+ + Args: + pretrained_model_name_or_path (str): + It could be one of the following options: + - i) A local directory path of a turbomind model which is + converted by `lmdeploy convert` command or download from + ii) and iii) + - ii) The model_id of a lmdeploy-quantized model hosted + inside a model repo on huggingface.co, such as + "InternLM/internlm-chat-20b-4bit", + "lmdeploy/llama2-chat-70b-4bit", etc. + - iii) The model_id of a model hosted inside a model repo + on huggingface.co, such as "InternLM/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" + and so on. + model_name (str): needed when pretrained_model_name_or_path is c) + model_format (str): model format + group_size (int): group size + tp (int): tensor parallel size + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update configuration when initialize the engine. + """ + model_source = get_model_source(pretrained_model_name_or_path) + if model_source == ModelSource.WORKSPACE: + local_path = pretrained_model_name_or_path + else: + check_tm_model_input(pretrained_model_name_or_path, + model_name=model_name, + **kwargs) + if not osp.exists(pretrained_model_name_or_path): + download_kwargs = create_hf_download_args(**kwargs) + local_path = snapshot_download(pretrained_model_name_or_path, + **download_kwargs) + else: + local_path = pretrained_model_name_or_path + + logger.warning(f'model_source: {model_source}') + return cls(model_source=model_source, + model_path=local_path, + model_name=model_name, + model_format=model_format, + group_size=group_size, + tp=tp, + **kwargs) + def create_instance(self, cuda_stream_id=0): """Create a turbomind instance. @@ -336,6 +561,7 @@ def _broadcast_np(data, dtype, shape=(batch_size, )): tm_inputs = _np_dict_to_tm_dict(inputs) # start forward thread + self.que = Queue() self._forward_thread(tm_inputs) seq_start = input_lengths + input_lengths.new_tensor(step) diff --git a/lmdeploy/turbomind/utils.py b/lmdeploy/turbomind/utils.py new file mode 100644 index 0000000000..20540c1df3 --- /dev/null +++ b/lmdeploy/turbomind/utils.py @@ -0,0 +1,120 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import dataclasses +import json +import logging +import os + +from huggingface_hub import hf_hub_download +from transformers.utils import ExplicitEnum + +logger = logging.getLogger(__name__) + + +class ModelSource(ExplicitEnum): + """Turbomind model source.""" + WORKSPACE = 'workspace' + HF_MODEL = 'hf_model' + HF_LMDEPLOY = 'hf_lmdeploy' + + +def create_hf_download_args(**kwargs) -> dict: + download_kwargs = { + 'revision': None, + 'cache_dir': None, + 'proxies': None, + 'resume_download': True, + 'force_download': False, + 'token': None, + 'local_files_only': False + } + for k in download_kwargs.keys(): + if k in kwargs: + download_kwargs[k] = kwargs[k] + return download_kwargs + + +def get_hf_config_path(pretrained_model_name_or_path, **kwargs) -> str: + """Get local hf config local file path.""" + if os.path.exists(pretrained_model_name_or_path): + config_path = os.path.join(pretrained_model_name_or_path, + 'config.json') + else: + download_kwargs = create_hf_download_args(**kwargs) + config_path = hf_hub_download(pretrained_model_name_or_path, + 'config.json', **download_kwargs) + return config_path + + +def get_hf_config_content(pretrained_model_name_or_path, **kwargs) -> dict: + """Get config content of a hf model.""" + config_path = get_hf_config_path(pretrained_model_name_or_path, **kwargs) + with open(config_path, 'r') as f: + config = json.load(f) + return config + + +def get_model_source(pretrained_model_name_or_path: str, + **kwargs) -> ModelSource: + """Get model source.""" + triton_model_path = os.path.join(pretrained_model_name_or_path, + 'triton_models') + if os.path.exists(triton_model_path): + return ModelSource.WORKSPACE + config = get_hf_config_content(pretrained_model_name_or_path, **kwargs) + model_source = ModelSource.HF_LMDEPLOY if 'turbomind' in config \ + else ModelSource.HF_MODEL + return model_source + + +def check_tm_model_input(pretrained_model_name_or_path, **kwargs): + """Check if single input pretrained_model_name_or_path is enough to use.""" + if kwargs.get('model_name', None): + return + + model_source = get_model_source(pretrained_model_name_or_path, **kwargs) + if model_source == ModelSource.WORKSPACE: + return + + config = get_hf_config_content(pretrained_model_name_or_path, **kwargs) + if 'turbomind' in config and config['turbomind']['model_name'] != '': + return + + assert (0), '\nCan not get model name from input model, '\ + 'please supply model name with arg --model-name,' \ + 'you can list supported models by `lmdeploy list`' + + +@dataclasses.dataclass +class GenParam: + top_p: float + top_k: float + temperature: float + repetition_penalty: float + sequence_start: bool = False + sequence_end: bool = False + step: int = 0 + request_output_len: int = 512 + + +def get_gen_param(cap, + sampling_param, + nth_round, + step, + request_output_len=512, + **kwargs): + """return parameters used by token generation.""" + gen_param = GenParam(**dataclasses.asdict(sampling_param), + request_output_len=request_output_len) + # Fix me later. 
turbomind.py doesn't support None top_k + if gen_param.top_k is None: + gen_param.top_k = 40 + + if cap == 'chat': + gen_param.sequence_start = (nth_round == 1) + gen_param.sequence_end = False + gen_param.step = step + else: + gen_param.sequence_start = True + gen_param.sequence_end = True + gen_param.step = 0 + return gen_param diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc index 275c89ddff..ab3eb783c4 100644 --- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc @@ -110,6 +110,47 @@ void mallocWeights(LlamaDenseWeight& weights, bool bias) } } +template +std::string concat(FirstArg&& first, Args&&... args) +{ + std::stringstream stream; + stream << first; + ((stream << "." << args), ...); + return stream.str(); +} + +template +void getWeightTensor(LlamaDenseWeight& weights, bool bias, const std::string& prefix, TensorMap& output) +{ + auto get_name = [=](const std::string& name) { return concat(prefix, name); }; + + if (bias) { + output.insert(get_name("bias"), + Tensor{MEMORY_GPU, getTensorType(), {weights.output_dims * sizeof(T)}, weights.bias}); + } + const size_t bit_size = getBitSize(weights.type); + if (bit_size >= 16) { + output.insert(get_name("weight"), + Tensor{MEMORY_GPU, + getTensorType(), + {weights.input_dims * weights.output_dims * sizeof(T)}, + weights.kernel}); + } + else { // int8, int4 + const int factor = sizeof(float) * 8 / bit_size; + output.insert(get_name("qweight"), + Tensor{MEMORY_GPU, + TYPE_INT32, + {weights.input_dims * weights.output_dims * sizeof(int) / factor}, + weights.kernel}); + output.insert(get_name("scales_zeros"), + Tensor{MEMORY_GPU, + getTensorType(), + {weights.input_dims / weights.group_size * weights.output_dims * 2 * sizeof(T)}, + weights.scales_and_zeros}); + } +} + template void loadWeights(LlamaDenseWeight& w, std::string prefix, @@ -226,6 +267,7 @@ void LlamaDecoderLayerWeight::mallocWeights() turbomind::mallocWeights(self_attn_weights.qkv, attn_bias_); turbomind::mallocWeights(self_attn_weights.output, attn_bias_); + self_attn_weights.past_kv_scale = {1.f, 0.f, 1.f, 0.f}; if (weight_type_ == WeightType::kINT4) { turbomind::mallocWeights(ffn_weights.fused_gating_intermediate, false); @@ -294,16 +336,43 @@ void LlamaDecoderLayerWeight::loadModel(std::string dir_path, FtCudaDataType loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type, tensor_para_size_, 0); // load kv_cache quant scale - // if file not exist, get empty vector std::string scale_path = dir_path + ".past_kv_scale." 
+ rank_spec + ".weight"; std::ifstream in(scale_path, std::ios::in); if (in.is_open()) { in.close(); self_attn_weights.past_kv_scale = loadArrayFromBin({4}, scale_path); } +} + +template +TensorMap LlamaDecoderLayerWeight::getParams(std::string prefix) +{ + TensorMap output; + + output.insert(concat(prefix, "attention_norm.weight"), + Tensor{MEMORY_GPU, getTensorType(), {hidden_units_ * sizeof(T)}, self_attn_norm_weights}); + + output.insert(concat(prefix, "ffn_norm.weight"), + Tensor{MEMORY_GPU, getTensorType(), {hidden_units_ * sizeof(T)}, ffn_norm_weights}); + + auto get_prefix = [=](std::string_view name) { return concat(prefix, name, tensor_para_rank_); }; + + getWeightTensor(self_attn_weights.qkv, attn_bias_, get_prefix("attention.w_qkv"), output); + + getWeightTensor(self_attn_weights.output, attn_bias_, get_prefix("attention.wo"), output); + + if (weight_type_ == WeightType::kINT4) { + getWeightTensor(ffn_weights.fused_gating_intermediate, false, get_prefix("feed_forward.w13"), output); + } else { - self_attn_weights.past_kv_scale = {1.f, 0.f, 1.f, 0.f}; + getWeightTensor(ffn_weights.gating, false, get_prefix("feed_forward.w1"), output); + getWeightTensor(ffn_weights.intermediate, false, get_prefix("feed_forward.w3"), output); } + getWeightTensor(ffn_weights.output, false, get_prefix("feed_forward.w2"), output); + output.insert(concat(prefix, "past_kv_scale", tensor_para_rank_, "weight"), + Tensor{MEMORY_CPU, TYPE_FP32, {4 * sizeof(float)}, self_attn_weights.past_kv_scale.data()}); + + return output; } template struct LlamaDecoderLayerWeight; diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h index 2141f72e7f..169a3aa9e6 100644 --- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h @@ -21,6 +21,7 @@ #pragma once #include "src/turbomind/models/llama/LlamaDenseWeight.h" +#include "src/turbomind/utils/Tensor.h" namespace turbomind { @@ -43,6 +44,8 @@ struct LlamaDecoderLayerWeight { void loadModel(std::string dir_path, FtCudaDataType model_file_type); + TensorMap getParams(std::string prefix); + T* self_attn_norm_weights{}; T* ffn_norm_weights{}; LlamaAttentionWeight self_attn_weights{}; diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc index e1287f471b..e270d3ba5c 100644 --- a/src/turbomind/models/llama/LlamaWeight.cc +++ b/src/turbomind/models/llama/LlamaWeight.cc @@ -109,6 +109,35 @@ void LlamaWeight::loadModel(std::string dir_path) } } +template +TensorMap LlamaWeight::getParams() +{ + TensorMap output; + + output.insert( + "tok_embeddings.weight", + Tensor{MEMORY_GPU, getTensorType(), {vocab_size_ * hidden_units_ * sizeof(T)}, pre_decoder_embedding_table}); + + output.insert("norm.weight", + Tensor{MEMORY_GPU, getTensorType(), {hidden_units_ * sizeof(T)}, output_norm_weight}); + + output.insert( + "output.weight", + Tensor{ + MEMORY_GPU, getTensorType(), {hidden_units_ * vocab_size_ * sizeof(T)}, post_decoder_embedding_kernel}); + + // transformer layers + for (size_t i = 0; i < num_layer_; i++) { + std::string prefix = fmtstr("layers.%d", i); + TensorMap layeri = decoder_layer_weights[i]->getParams(prefix); + for (auto [name, tensor] : layeri) { + output.insert(name, tensor); + } + } + + return output; +} + template struct LlamaWeight; template struct LlamaWeight; diff --git a/src/turbomind/models/llama/LlamaWeight.h b/src/turbomind/models/llama/LlamaWeight.h index be7fda2b98..a896a87a09 100644 --- 
--- a/src/turbomind/models/llama/LlamaWeight.h
+++ b/src/turbomind/models/llama/LlamaWeight.h
@@ -47,6 +47,8 @@ struct LlamaWeight {
 
     void loadModel(std::string dir_path);
 
+    TensorMap getParams();
+
     std::vector<LlamaDecoderLayerWeight<T>*> decoder_layer_weights;
     const T* pre_decoder_embedding_table{};
     const T* output_norm_weight{};
diff --git a/src/turbomind/python/bind.cpp b/src/turbomind/python/bind.cpp
index b55ed040af..46e8443a86 100644
--- a/src/turbomind/python/bind.cpp
+++ b/src/turbomind/python/bind.cpp
@@ -282,6 +282,27 @@ PYBIND11_MODULE(_turbomind, m)
             return new triton::Tensor(self->where, self->type, new_shape, self->data);
         },
         "new_shape"_a)
+        .def(
+            "copy_from",
+            [](triton::Tensor* self, py::object obj) {
+                py::capsule cap = obj.attr("__dlpack__")();
+                DLManagedTensor* dlmt =
+                    static_cast<DLManagedTensor*>(PyCapsule_GetPointer(cap.ptr(), kDlTensorCapsuleName));
+                auto src = DLManagedTensorToTritonTensor(dlmt);
+                if (self->type == triton::TYPE_FP16 || self->type == triton::TYPE_FP32
+                    || self->type == triton::TYPE_INT32) {
+                    auto num_element =
+                        std::accumulate(src->shape.begin(), src->shape.end(), 1LL, std::multiplies<int64_t>());
+                    auto num_bytes = num_element * dlmt->dl_tensor.dtype.bits / 8;
+                    ft::FT_CHECK(self->shape.size() == 1 && num_bytes == self->shape[0]);
+                    cudaMemcpy(
+                        const_cast<void*>(self->data), const_cast<void*>(src->data), num_bytes, cudaMemcpyDefault);
+                }
+                else {
+                    ft::FT_CHECK(0);
+                }
+            },
+            "tensor"_a)
         .def(
             "__dlpack__",
             [](triton::Tensor* self, long stream) {
@@ -340,6 +361,7 @@ PYBIND11_MODULE(_turbomind, m)
         .def_static(
             "create_llama_model",
             [](std::string model_dir,
+               std::string config,
                size_t tensor_para_size,
                size_t pipeline_para_size,
                int enable_custom_all_reduce,
@@ -354,18 +376,19 @@ PYBIND11_MODULE(_turbomind, m)
                };
                if (data_type == "half" || data_type == "fp16" || data_type == "int4") {
                    auto model = std::make_shared<LlamaTritonModel<half>>(
-                       tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                       tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir, config);
                    model->setFfiLock(gil_control);
                    return model;
                }
                else {
                    auto model = std::make_shared<LlamaTritonModel<float>>(
-                       tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                       tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir, config);
                    model->setFfiLock(gil_control);
                    return model;
                }
            },
            "model_dir"_a,
+           "config"_a = "",
            "tensor_para_size"_a = 1,
            "pipeline_para_size"_a = 1,
            "enable_custom_all_reduce"_a = 0,
@@ -406,6 +429,15 @@ PYBIND11_MODULE(_turbomind, m)
            py::call_guard<py::gil_scoped_release>(),
            "device_id"_a,
            "rank"_a)
+        .def(
+            "get_params",
+            [](AbstractTransformerModel* model, int deviceId, int rank) {
+                TensorMap output = model->getParams(deviceId, rank);
+                return output;
+            },
+            py::call_guard<py::gil_scoped_release>(),
+            "device_id"_a,
+            "rank"_a)
         .def("__str__", &AbstractTransformerModel::toString)
         .def("__repr__", &AbstractTransformerModel::toString)
         .def("get_tensor_para_size", &AbstractTransformerModel::getTensorParaSize)
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
index fb54346ac0..580f6f0e7b 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -111,18 +111,35 @@ template<typename T>
 LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
                                       size_t pipeline_para_size,
                                       int enable_custom_all_reduce,
-                                      std::string model_dir):
+                                      std::string model_dir,
+                                      std::string config):
     tensor_para_size_(tensor_para_size),
     pipeline_para_size_(pipeline_para_size),
     shared_weights_(std::vector<std::shared_ptr<ft::LlamaWeight<T>>>(ft::getDeviceCount())),
     enable_custom_all_reduce_(enable_custom_all_reduce)
 {
-    model_dir_ = model_dir;
-    const std::string inifile{model_dir + "/config.ini"};
-    INIReader reader = INIReader(inifile);
-    if (reader.ParseError() < 0) {
-        std::cout << "[ERROR] Can't load '" << inifile << "'\n";
-        ft::FT_CHECK(false);
+    INIReader reader;
+    FT_CHECK_WITH_INFO((config.empty() ^ model_dir.empty()), "invalid init options");
+
+    if (!config.empty()) {
+        std::FILE* tmpf = std::tmpfile();
+        std::fputs(config.c_str(), tmpf);
+        std::rewind(tmpf);
+        reader = INIReader(tmpf);
+        if (reader.ParseError() < 0) {
+            TM_LOG_ERROR("[ERROR] Can't init with config %s", config.c_str());
+            ft::FT_CHECK(false);
+        }
+    }
+
+    if (!model_dir.empty()) {
+        model_dir_ = model_dir;
+        const std::string inifile{model_dir + "/config.ini"};
+        reader = INIReader(inifile);
+        if (reader.ParseError() < 0) {
+            TM_LOG_ERROR("[ERROR] Can't load %s", inifile.c_str());
+            ft::FT_CHECK(false);
+        }
     }
 
     model_name_ = reader.Get("llama", "model_name");
@@ -154,7 +171,7 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
     attn_params_.rope_scaling_factor = reader.GetFloat("llama", "rope_scaling_factor", 0.f);
     attn_params_.max_position_embeddings = reader.GetInteger("llama", "max_position_embeddings", 0);
     // attn_params_.use_dynamic_ntk = reader.GetInteger("llama", "use_dynamic_ntk", 0);
-    attn_params_.use_logn_attn = reader.GetInteger("llama", "use_logn_attn", 0);
+    attn_params_.use_logn_attn = reader.GetInteger("llama", "use_logn_attn", 0);
 
     handleMissingParams();
 
@@ -322,10 +339,27 @@ void LlamaTritonModel<T>::createSharedWeights(int device_id, int rank)
                                                          group_size_,
                                                          tensor_para_size_,
                                                          tensor_para_rank);
-    shared_weights_[device_id]->loadModel(model_dir_);
+    // model inited with model_dir
+    if (model_dir_ != "") {
+        shared_weights_[device_id]->loadModel(model_dir_);
+    }
     return;
 }
 
+template<typename T>
+TensorMap LlamaTritonModel<T>::getParams(int deviceId, int rank)
+{
+    ft::check_cuda_error(cudaSetDevice(deviceId));
+    // shared_weight should be created before getParams
+    ft::FT_CHECK(shared_weights_[deviceId] != nullptr);
+    ft::TensorMap output = shared_weights_[deviceId]->getParams();
+    TensorMap result;
+    for (auto [name, tensor] : output) {
+        result.emplace(name, triton::Tensor{tensor.where, tensor.type, tensor.shape, tensor.data});
+    }
+    return result;
+}
+
 template<typename T>
 std::string LlamaTritonModel<T>::toString()
 {
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.h b/src/turbomind/triton_backend/llama/LlamaTritonModel.h
index 0e2b89bff8..cdc56f2214 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.h
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.h
@@ -40,7 +40,8 @@ struct LlamaTritonModel: public AbstractTransformerModel {
     LlamaTritonModel(size_t tensor_para_size,
                      size_t pipeline_para_size,
                      int enable_custom_all_reduce,
-                     std::string model_dir);
+                     std::string model_dir,
+                     std::string config = "");
 
     ~LlamaTritonModel() = default;
 
@@ -53,6 +54,8 @@ struct LlamaTritonModel: public AbstractTransformerModel {
 
     void createSharedWeights(int deviceId, int rank) override;
 
+    TensorMap getParams(int deviceId, int rank) override;
+
     void createCustomComms(std::vector<std::shared_ptr<ft::AbstractCustomComm>>* custom_all_reduce_comms,
                            int world_size) override;
diff --git a/src/turbomind/triton_backend/transformer_triton_backend.hpp b/src/turbomind/triton_backend/transformer_triton_backend.hpp
index 483651b8db..aee45d080f 100644
--- a/src/turbomind/triton_backend/transformer_triton_backend.hpp
+++ b/src/turbomind/triton_backend/transformer_triton_backend.hpp
@@ -301,6 +301,8 @@ struct AbstractTransformerModelInstance {
     void* stream_ctx_ = nullptr;
 };
 
+using TensorMap = std::unordered_map<std::string, triton::Tensor>;
+
 struct AbstractTransformerModel {
     static std::shared_ptr<AbstractTransformerModel> createLlamaModel(std::string model_dir);
 
@@ -324,6 +326,8 @@ struct AbstractTransformerModel {
 
     virtual void createSharedWeights(int deviceId, int rank) = 0;
 
+    virtual TensorMap getParams(int deviceId, int rank) = 0;
+
     virtual std::string toString() = 0;
     virtual int getTensorParaSize() = 0;
     virtual int getPipelineParaSize() = 0;
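Taken together, the C++ changes open a path for pushing externally loaded weights into TurboMind without a `config.ini` on disk: `create_llama_model` now also accepts the config as a string, `get_params` returns a name-to-`triton::Tensor` map backed by the allocated device buffers, and `Tensor.copy_from` copies any DLPack-capable array into such a buffer. A rough Python sketch of that flow follows; the `create_shared_weights` binding name, the `data_type` keyword, and the prepared `tm_state_dict` are assumptions for illustration, not guaranteed by this patch.

```python
# Hypothetical driver for the new weight-upload path. Tensors in
# `tm_state_dict` must already be laid out and named the way TurboMind
# expects (e.g. 'layers.0.attention.w_qkv.0.weight') and their byte size
# must match the 1-D destination buffers returned by get_params().
import _turbomind as _tm  # pybind11 module built from bind.cpp; import path
                          # depends on the install layout

config_ini = open('config.ini').read()  # exported turbomind config as text
model = _tm.AbstractTransformerModel.create_llama_model(
    model_dir='', config=config_ini, tensor_para_size=1, data_type='half')

model.create_shared_weights(0, 0)               # assumed binding name
params = model.get_params(device_id=0, rank=0)  # name -> triton Tensor view

tm_state_dict = {}  # placeholder: name -> DLPack-capable GPU array
for name, dst in params.items():
    src = tm_state_dict.get(name)
    if src is not None:
        dst.copy_from(src)  # memcpy into TurboMind's weight buffer
```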