[Feature] Blazing fast W4A16 inference (#202)
* add w4a16

* fix `deploy.py`

* add doc

* add w4a16 kernels

* fuse w1/w3 & bugfixes

* fix typo

* python

* guard sm75/80 features

* add missing header

* refactor

* qkvo bias

* update cost model

* fix lint

* update `deploy.py`
lzhangzz authored Aug 14, 2023
1 parent d3dbe17 commit c3290ca
Showing 27 changed files with 2,804 additions and 134 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -304,6 +304,7 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:llama_fmha>
$<TARGET_OBJECTS:Llama>
$<TARGET_OBJECTS:LlamaTritonBackend>
$<TARGET_OBJECTS:gemm_s4_f16>
$<TARGET_OBJECTS:TopKSamplingLayer>
$<TARGET_OBJECTS:TopPSamplingLayer>
$<TARGET_OBJECTS:TransformerTritonBackend>
13 changes: 13 additions & 0 deletions docs/en/serving.md
@@ -34,6 +34,19 @@ bash workspace/service_docker_up.sh

</details>

<details open>
<summary><b>7B with INT4 weight only quantization</b></summary>

```shell
python3 -m lmdeploy.serve.turbomind.deploy llama2 /path/to/llama-2-7b-chat-hf \
--model_format awq \
--group_size 128 \
--quant_path /path/to/awq-quant-weight.pt
bash workspace/service_docker_up.sh
```

</details>
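To sanity-check an AWQ checkpoint before deploying it, here is a minimal sketch (assuming the file passed via `--quant_path` is a plain `torch.save` state dict containing AWQ's `qweight`/`qzeros`/`scales` tensors, as `deploy_awq` in `deploy.py` expects; the path is a placeholder):

```python
import torch

# List the quantized tensors and their shapes/dtypes.
state_dict = torch.load('/path/to/awq-quant-weight.pt', map_location='cpu')
for name, tensor in state_dict.items():
    if name.endswith(('qweight', 'qzeros', 'scales')):
        print(name, tuple(tensor.shape), tensor.dtype)
```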

## Serving [LLaMA](https://github.com/facebookresearch/llama)

Weights for the LLaMA models can be obtained by filling out [this form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
260 changes: 251 additions & 9 deletions lmdeploy/serve/turbomind/deploy.py
@@ -5,16 +5,18 @@
import os.path as osp
import re
import shutil
import sys
from pathlib import Path

import fire
import safetensors
import torch
from sentencepiece import SentencePieceProcessor

import lmdeploy
from lmdeploy.model import MODELS

supported_formats = ['llama', 'hf']
supported_formats = ['llama', 'hf', 'awq']


def get_package_root_path():
@@ -107,7 +109,9 @@ def export(model_name: str,
tokenizer_path: str,
out_dir: str,
tp: int,
size_per_head: int = 128):
size_per_head: int = 128,
group_size: int = 0,
weight_type: str = 'fp16'):
"""Export deploying information to a config file.
Args:
@@ -127,9 +131,10 @@ def save_bin(param: torch.Tensor, name):
print(name, param.shape)
if param.dtype in [torch.float, torch.bfloat16]:
param = param.half()
param.contiguous().numpy().tofile(osp.join(out_dir, name))
param.contiguous().cpu().numpy().tofile(osp.join(out_dir, name))

attn_bias = False
inter_size = 0

# reverse the splitting axes since the weights are transposed above
for param_name, param_data in model_params.items():
@@ -141,10 +146,14 @@ def save_bin(param: torch.Tensor, name):
if key == 'w_qkv' and ext == 'bias':
attn_bias = True
copy = False
if key in ['w1', 'w3']:
if key in ['w1', 'w3', 'w13']:
split_dim = -1
# TODO: move parameter extraction outside of the loop
if key == 'w1':
inter_size = param_data.shape[-1]
inter_size = max(inter_size, param_data.shape[-1])
elif key == 'w13':
inter_size = max(inter_size, param_data.shape[-1] // 2)

elif key == 'w_qkv':
split_dim = -2
elif key in ['w2', 'wo']:
@@ -170,6 +179,8 @@ def save_bin(param: torch.Tensor, name):
else:
save_bin(param_data, param_name)

assert inter_size > 0

# export config and save it to {out_dir}/config.ini
model = MODELS.get(model_name)()
vocab_size, bos_id, eos_id = tokenizer_info(tokenizer_path)
@@ -188,7 +199,8 @@ def save_bin(param: torch.Tensor, name):
attn_bias=int(attn_bias),
start_id=bos_id,
end_id=eos_id,
weight_type='fp16',
weight_type=weight_type,
group_size=group_size,
# parameters for turbomind
max_batch_size=32,
max_context_token_num=4,
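The quantization settings written here by `export` can be verified afterwards from the generated `config.ini`; a hedged sketch follows (the `./workspace/triton_models/weights` location and the single-section layout are assumptions based on the default workspace structure):

```python
import configparser

# Print the weight type and AWQ group size recorded by `export`.
cfg = configparser.ConfigParser()
cfg.read('./workspace/triton_models/weights/config.ini')
section = cfg.sections()[0]  # avoid hard-coding the section name
print(cfg.get(section, 'weight_type'))  # 'int4' on the AWQ path, 'fp16' otherwise
print(cfg.get(section, 'group_size'))
```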
@@ -329,7 +341,7 @@ def get_param(_name, _size):

def permute(x: torch.Tensor):
SIZE_PER_HEAD = 128
if x.shape[-1] > 1: # qweights
if x.shape[-1] > 1:
dim = x.shape[-1]
n_heads = dim // SIZE_PER_HEAD
return x.view(-1, n_heads, 2,
@@ -491,6 +503,228 @@ def get_tensor_transposed(name: str):
tokenizer_path, triton_models_path, tp)


def deploy_awq(model_name: str, model_path: str, tokenizer_path: str,
triton_models_path: str, tp: int, quant_path: str,
group_size: int):
"""Deploy a model with huggingface transformers' format.
Args:
model_name (str): the name of the to-be-deployed model
model_path (str): the path of the directory where the model weight
files are
tokenizer_path (str): the path of the tokenizer model
triton_models_path (str): the path of the exported triton models
tp (int): the number of GPUs used for tensor parallelism
quant_path (str): path of the quantized model, which can be None
group_size (int): a parameter used in AWQ to quantize fp16 weights
to 4 bits
"""
if tokenizer_path is None:
tokenizer_path = osp.join(model_path, 'tokenizer.model')
if osp.exists(tokenizer_path):
shutil.copy(tokenizer_path,
osp.join(triton_models_path, 'tokenizer/tokenizer.model'))
for _file in os.listdir(model_path):
if _file.endswith('.json') or _file.endswith('.py'):
json_path = osp.join(model_path, _file)
shutil.copy(json_path,
osp.join(triton_models_path, 'tokenizer', _file))
with get_package_root_path() as root_path:
shutil.copy(osp.join(root_path, 'turbomind/tokenizer.py'),
osp.join(triton_models_path, 'tokenizer'))
else:
print(f'tokenizer model {tokenizer_path} does not exist')
exit(-1)

# read model arguments from params.json
try:
params_path = osp.join(model_path, 'config.json')
with open(params_path) as f:
model_arg = json.load(f)
num_layer = model_arg['num_hidden_layers']
norm_eps = model_arg['rms_norm_eps']
if 'num_key_value_heads' in model_arg:
kv_head_num = model_arg['num_key_value_heads']
else:
kv_head_num = model_arg['num_attention_heads']
except Exception as e:
print(f'get "num_hidden_layers" and "rms_norm_eps" from '
f'{params_path} failed: {e}')
return False

# convert weights from hf to turbomind
if quant_path is None:
_files = [
osp.join(model_path, file) for file in os.listdir(model_path)
if file.endswith('.bin')
]
_files = sorted(_files)
else:
_files = [quant_path]

model_params = {}

_params = {}
for _file in _files:
_tmp = torch.load(_file, map_location='cpu')
_params.update(_tmp)

def get_tensor(name):
"""return tensor according its name."""
return _params[name].cuda().contiguous()

# import _turbomind as _tm
# TODO: find another way to import _turbomind
lmdeploy_dir = osp.split(lmdeploy.__file__)[0]
sys.path.append(osp.join(lmdeploy_dir, 'lib'))
import _turbomind as _tm # noqa: E402

def transpose_qk(src: torch.Tensor):
assert src.is_contiguous()
dst = torch.zeros_like(src)
_tm.transpose_qk_s4_k_m8(src, dst,
src.size(-1) * 8, src.size(0), group_size)
return dst

def fuse_w1_w3(w1_qw: torch.Tensor, w1_qz: torch.Tensor,
w1_s: torch.Tensor, w3_qw: torch.Tensor,
w3_qz: torch.Tensor, w3_s: torch.Tensor):

def fuse(a: torch.Tensor, b: torch.Tensor):
ab = torch.cat((a, b)).contiguous()
_ab = torch.zeros_like(ab)
_tm.fuse_w1_w3_s4_k_m8(ab, _ab, a.size(-1) * 8, a.size(0))
return _ab.view(a.size(0), -1)

w13_qw = fuse(w1_qw, w3_qw)
w13_qz = fuse(w1_qz, w3_qz)

w13_s = torch.cat((w1_s, w3_s)).view(2, w1_s.size(0), -1)
w13_s = w13_s.permute(1, 2, 0).contiguous().view(w1_s.size(0), -1)

return w13_qw, w13_qz, w13_s

def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
group_size: int):
assert qw.is_contiguous()
assert qz.is_contiguous()
assert s.is_contiguous()
_qw = torch.zeros_like(qw)
_sz = torch.zeros_like(s, dtype=torch.int32)
_ws = torch.zeros_like(s)
_tm.convert_s4_k_m8(_qw, _sz, _ws, qw, s, qz,
qw.size(-1) * 8, qw.size(0), group_size)
return _qw, _sz

attn_bias = False

for i in range(num_layer):
print(i)

# attention weights
q_qw = get_tensor(f'model.layers.{i}.self_attn.q_proj.qweight')
k_qw = get_tensor(f'model.layers.{i}.self_attn.k_proj.qweight')
v_qw = get_tensor(f'model.layers.{i}.self_attn.v_proj.qweight')
o_qw = get_tensor(f'model.layers.{i}.self_attn.o_proj.qweight')

q_qz = get_tensor(f'model.layers.{i}.self_attn.q_proj.qzeros')
k_qz = get_tensor(f'model.layers.{i}.self_attn.k_proj.qzeros')
v_qz = get_tensor(f'model.layers.{i}.self_attn.v_proj.qzeros')
o_qz = get_tensor(f'model.layers.{i}.self_attn.o_proj.qzeros')

q_s = get_tensor(f'model.layers.{i}.self_attn.q_proj.scales')
k_s = get_tensor(f'model.layers.{i}.self_attn.k_proj.scales')
v_s = get_tensor(f'model.layers.{i}.self_attn.v_proj.scales')
o_s = get_tensor(f'model.layers.{i}.self_attn.o_proj.scales')

try:
q_b = get_tensor(f'model.layers.{i}.self_attn.q_proj.bias')
k_b = get_tensor(f'model.layers.{i}.self_attn.k_proj.bias')
v_b = get_tensor(f'model.layers.{i}.self_attn.v_proj.bias')
o_b = get_tensor(f'model.layers.{i}.self_attn.o_proj.bias')
attn_bias = True
except: # noqa: E722
pass

q_qw = transpose_qk(q_qw)
k_qw = transpose_qk(k_qw)
q_qz = transpose_qk(q_qz)
k_qz = transpose_qk(k_qz)
q_s = permute(q_s)
k_s = permute(k_s)

qkv_qw = merge_qkv(q_qw, k_qw, v_qw, tp, dim=2)
qkv_qz = merge_qkv(q_qz, k_qz, v_qz, tp, dim=2)
qkv_s = merge_qkv(q_s, k_s, v_s, tp, dim=2)

qkv_qw, qkv_sz = convert_s4(qkv_qw, qkv_qz, qkv_s, group_size)

model_params[f'layers.{i}.attention.w_qkv.qweight'] = qkv_qw
model_params[f'layers.{i}.attention.w_qkv.scales_zeros'] = qkv_sz

o_qw, o_sz = convert_s4(o_qw, o_qz, o_s, group_size)

model_params[f'layers.{i}.attention.wo.qweight'] = o_qw
model_params[f'layers.{i}.attention.wo.scales_zeros'] = o_sz

if attn_bias:
q_b = permute(q_b)
k_b = permute(k_b)
qkv_b = merge_qkv(q_b, k_b, v_b, tp, dim=1)
model_params[f'layers.{i}.attention.w_qkv.bias'] = qkv_b
model_params[f'layers.{i}.attention.wo.bias'] = o_b

# ffn weights
w1_qw = get_tensor(f'model.layers.{i}.mlp.gate_proj.qweight')
w2_qw = get_tensor(f'model.layers.{i}.mlp.down_proj.qweight')
w3_qw = get_tensor(f'model.layers.{i}.mlp.up_proj.qweight')

w1_qz = get_tensor(f'model.layers.{i}.mlp.gate_proj.qzeros')
w2_qz = get_tensor(f'model.layers.{i}.mlp.down_proj.qzeros')
w3_qz = get_tensor(f'model.layers.{i}.mlp.up_proj.qzeros')

w1_s = get_tensor(f'model.layers.{i}.mlp.gate_proj.scales')
w2_s = get_tensor(f'model.layers.{i}.mlp.down_proj.scales')
w3_s = get_tensor(f'model.layers.{i}.mlp.up_proj.scales')

w13_qw, w13_qz, w13_s = fuse_w1_w3(w1_qw, w1_qz, w1_s, w3_qw, w3_qz,
w3_s)

w13_qw, w13_sz = convert_s4(w13_qw, w13_qz, w13_s, group_size)
w2_qw, w2_sz = convert_s4(w2_qw, w2_qz, w2_s, group_size)

model_params[f'layers.{i}.feed_forward.w13.qweight'] = w13_qw
model_params[f'layers.{i}.feed_forward.w13.scales_zeros'] = w13_sz

model_params[f'layers.{i}.feed_forward.w2.qweight'] = w2_qw
model_params[f'layers.{i}.feed_forward.w2.scales_zeros'] = w2_sz

# norm weights
attn_norm = get_tensor(f'model.layers.{i}.input_layernorm.weight')
ffn_norm = get_tensor(
f'model.layers.{i}.post_attention_layernorm.weight')

model_params[f'layers.{i}.attention_norm.weight'] = attn_norm
model_params[f'layers.{i}.ffn_norm.weight'] = ffn_norm

other = [('tok_embeddings.weight', 'model.embed_tokens.weight'),
('norm.weight', 'model.norm.weight'),
('output.weight', 'lm_head.weight')]
for ft, hf in other:
model_params[ft] = get_tensor(hf)

return export(model_name,
num_layer,
norm_eps,
kv_head_num,
model_params,
tokenizer_path,
triton_models_path,
tp,
weight_type='int4',
group_size=group_size)
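For orientation, an illustrative sketch of the 4-bit layout these helpers operate on: each int32 element of an AWQ `qweight`/`qzeros` tensor packs eight 4-bit values, which is why the kernel calls above pass `size(-1) * 8` as the logical width. Sequential nibble order is a simplifying assumption here; the real lane interleaving and layout conversion are handled by the `_turbomind` kernels (`transpose_qk_s4_k_m8`, `fuse_w1_w3_s4_k_m8`, `convert_s4_k_m8`).

```python
import torch

def unpack_s4(packed: torch.Tensor) -> torch.Tensor:
    """Unpack an [m, n/8] int32 tensor into [m, n] values in 0..15."""
    shifts = torch.arange(0, 32, 4, dtype=torch.int32)  # 8 nibbles per int32
    nibbles = (packed.unsqueeze(-1) >> shifts) & 0xF    # -> [m, n/8, 8]
    return nibbles.flatten(-2)                          # -> [m, n]

qw = torch.randint(-2**31, 2**31 - 1, (4096, 4096 // 8), dtype=torch.int32)
print(unpack_s4(qw).shape)  # torch.Size([4096, 4096])
```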


def pack_model_repository(workspace_path: str):
"""package the model repository.
Expand Down Expand Up @@ -521,7 +755,9 @@ def main(model_name: str,
model_format: str = 'hf',
tokenizer_path: str = None,
dst_path: str = './workspace',
tp: int = 1):
tp: int = 1,
quant_path: str = None,
group_size: int = 0):
"""deploy llama family models via turbomind.
Args:
@@ -533,6 +769,9 @@ def main(model_name: str,
tokenizer_path (str): the path of tokenizer model
dst_path (str): the destination path that saves outputs
tp (int): the number of GPUs used for tensor parallelism
quant_path (str): path of the quantized model, which can be None
group_size (int): a parameter used in AWQ to quantize fp16 weights
to 4 bits
"""
assert model_name in MODELS.module_dict.keys(), \
f"'{model_name}' is not supported. " \
@@ -558,9 +797,12 @@ def main(model_name: str,
if model_format == 'llama':
res = deploy_llama(model_name, model_path, tokenizer_path,
triton_models_path, tp)
else:
elif model_format == 'hf':
res = deploy_hf(model_name, model_path, tokenizer_path,
triton_models_path, tp)
elif model_format == 'awq':
res = deploy_awq(model_name, model_path, tokenizer_path,
triton_models_path, tp, quant_path, group_size)

# update `tensor_para_size` in `triton_models/interactive/config.pbtxt`
with open(osp.join(triton_models_path, 'interactive/config.pbtxt'),
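Since the CLI entry point is a thin `fire` wrapper around `main`, the AWQ path can also be driven programmatically; a hedged sketch (argument order inferred from the CLI usage in `docs/en/serving.md`, all paths are placeholders):

```python
from lmdeploy.serve.turbomind.deploy import main as deploy

# Convert an AWQ-quantized Llama-2 7B checkpoint into a turbomind workspace.
deploy('llama2', '/path/to/llama-2-7b-chat-hf',
       model_format='awq',
       quant_path='/path/to/awq-quant-weight.pt',
       group_size=128,
       dst_path='./workspace',
       tp=1)
```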
2 changes: 2 additions & 0 deletions src/turbomind/kernels/CMakeLists.txt
@@ -69,3 +69,5 @@ set_property(TARGET sampling_penalty_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOL
add_library(custom_ar_kernels STATIC custom_ar_kernels.cu)
set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

add_subdirectory(gemm_s_f16)
7 changes: 7 additions & 0 deletions src/turbomind/kernels/gemm_s_f16/CMakeLists.txt
@@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.

add_library(gemm_s4_f16 STATIC gemm_s4_f16.cu format.cu)
target_compile_options(gemm_s4_f16 PRIVATE
--generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr)
set_property(TARGET gemm_s4_f16 PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET gemm_s4_f16 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)