Support TP for w4a16 (#262)
lzhangzz authored Aug 18, 2023
1 parent 4a60b45 commit 89f3d32
Showing 1 changed file with 25 additions and 17 deletions.
42 changes: 25 additions & 17 deletions lmdeploy/serve/turbomind/deploy.py
@@ -157,18 +157,15 @@ def save_bin(param: torch.Tensor, name):
         if key == 'w_qkv' and ext == 'bias':
             attn_bias = True
         copy = False
-        if key in ['w1', 'w3', 'w13']:
+        if key in ['w1', 'w3', 'w13', 'w_qkv']:
             split_dim = -1
             # TODO: move parameter extraction outside of the loop
             if key == 'w1':
                 inter_size = max(inter_size, param_data.shape[-1])
             elif key == 'w13':
                 inter_size = max(inter_size, param_data.shape[-1] // 2)
-
-        elif key == 'w_qkv':
-            split_dim = -2
         elif key in ['w2', 'wo']:
-            if ext in ['scales', 'zeros', 'bias']:
+            if ext in ['bias']:
                 copy = True
             else:
                 split_dim = 0
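
For context, the split_dim / copy flags decide how the fused tensors are later sharded across tensor-parallel ranks: with this change w_qkv is sliced along its last (output) dimension like w1/w3/w13 instead of along dim -2. A minimal sketch of that downstream behavior, using a hypothetical helper (the real splitting code lives elsewhere in deploy.py):

import torch

def split_for_tp(param: torch.Tensor, tp: int, split_dim: int, copy: bool):
    """Illustrative only: one tensor per rank, either replicated or sliced."""
    if copy:
        return [param] * tp                      # e.g. biases of w2 / wo
    return list(param.chunk(tp, dim=split_dim))  # e.g. split_dim = -1 for w_qkv
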
@@ -243,7 +240,10 @@ def merge_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp: int,
     def reshape(x):
         return x.view(x.size(0), tp, -1) if dim == 2 else x.view(tp, -1)

-    return torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1)
+    qkv = torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1)
+
+    # (input_dim, head_num + 2 * kv_head_num)
+    return qkv.view(q.size(0), -1)


 def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
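
With the extra view, the fused QKV weight is laid out rank-major along its last dimension, so the generic split_dim = -1 path hands each tensor-parallel rank a contiguous [q_i | k_i | v_i] block. A small self-check under assumed toy shapes (the dim == 2 case, all sizes divisible by tp), not taken from the repo:

import torch

tp = 2
in_dim, q_out, kv_out = 4, 8, 4   # toy sizes

q = torch.randn(in_dim, q_out)
k = torch.randn(in_dim, kv_out)
v = torch.randn(in_dim, kv_out)

def reshape(x):
    return x.view(x.size(0), tp, -1)

qkv = torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1)
qkv = qkv.view(q.size(0), -1)

# Splitting along the last dim (split_dim = -1) now yields per-rank
# slices that match the per-rank slices of q, k and v.
for rank, part in enumerate(qkv.chunk(tp, dim=-1)):
    expected = torch.cat((q.chunk(tp, -1)[rank],
                          k.chunk(tp, -1)[rank],
                          v.chunk(tp, -1)[rank]), dim=-1)
    assert torch.equal(part, expected)
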
@@ -594,16 +594,16 @@ def get_tensor(name):
     sys.path.append(osp.join(lmdeploy_dir, 'lib'))
     import _turbomind as _tm  # noqa: E402

-    def transpose_qk(src: torch.Tensor):
+    def transpose_qk_s4(src: torch.Tensor):
         assert src.is_contiguous()
         dst = torch.zeros_like(src)
         _tm.transpose_qk_s4_k_m8(src, dst,
                                  src.size(-1) * 8, src.size(0), group_size)
         return dst

-    def fuse_w1_w3(w1_qw: torch.Tensor, w1_qz: torch.Tensor,
-                   w1_s: torch.Tensor, w3_qw: torch.Tensor,
-                   w3_qz: torch.Tensor, w3_s: torch.Tensor):
+    def fuse_w1_w3_s4(w1_qw: torch.Tensor, w1_qz: torch.Tensor,
+                      w1_s: torch.Tensor, w3_qw: torch.Tensor,
+                      w3_qz: torch.Tensor, w3_s: torch.Tensor):

         def fuse(a: torch.Tensor, b: torch.Tensor):
             ab = torch.cat((a, b)).contiguous()
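
The _s4 suffix marks these helpers as operating on 4-bit packed weights, and the src.size(-1) * 8 argument implies that each element along the last dimension packs eight 4-bit values. The actual bit layout lives in the _turbomind CUDA kernels; the following is only an assumed illustration of such packing (the nibble order is a guess):

import torch

def unpack_s4(packed: torch.Tensor) -> torch.Tensor:
    """Unpack an int32 tensor of shape (m, k // 8) into nibbles of shape (m, k)."""
    shifts = torch.arange(0, 32, 4, dtype=torch.int32)   # 8 nibbles per int32
    nibbles = (packed.unsqueeze(-1) >> shifts) & 0xF      # (m, k // 8, 8)
    return nibbles.flatten(-2)                            # (m, k)
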
@@ -625,12 +625,16 @@ def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
         assert qz.is_contiguous()
         assert s.is_contiguous()
         _qw = torch.zeros_like(qw)
-        _sz = torch.zeros_like(s, dtype=torch.int32)
+        _sz = torch.zeros_like(s, dtype=torch.int32)  # half2
         _ws = torch.zeros_like(s)
         _tm.convert_s4_k_m8(_qw, _sz, _ws, qw, s, qz,
                             qw.size(-1) * 8, qw.size(0), group_size)
         return _qw, _sz

+    def tp_m_s4(x: torch.Tensor, tp: int):
+        return x.view(x.size(0) // 32, tp, -1, 128).permute(0, 2, 3,
+                                                            1).contiguous()
+
     attn_bias = False

     for i in range(num_layer):
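
A hedged reading of the new tp_m_s4 helper (the 32 and 128 tile sizes come straight from the diff; the intent is inferred): it re-tiles the packed w4 weight so that the tensor-parallel rank index ends up as the trailing dimension, which lets the existing split_dim = -1 path hand each rank its own tiles. A toy self-check of that property, with made-up sizes:

import torch

def tp_m_s4(x: torch.Tensor, tp: int):
    return x.view(x.size(0) // 32, tp, -1, 128).permute(0, 2, 3,
                                                        1).contiguous()

tp = 2
x = torch.arange(96 * 256).view(96, 256)   # toy "packed" weight

y = tp_m_s4(x, tp)
# Chunking the re-tiled tensor along its last dim recovers the per-rank
# tiles of the original row-major layout.
for rank, part in enumerate(y.chunk(tp, dim=-1)):
    tile = x.view(x.size(0) // 32, tp, -1, 128)[:, rank]
    assert torch.equal(part.squeeze(-1), tile)
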
@@ -661,10 +665,10 @@ def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
         except:  # noqa: E722
             pass

-        q_qw = transpose_qk(q_qw)
-        k_qw = transpose_qk(k_qw)
-        q_qz = transpose_qk(q_qz)
-        k_qz = transpose_qk(k_qz)
+        q_qw = transpose_qk_s4(q_qw)
+        k_qw = transpose_qk_s4(k_qw)
+        q_qz = transpose_qk_s4(q_qz)
+        k_qz = transpose_qk_s4(k_qz)
         q_s = permute(q_s)
         k_s = permute(k_s)

@@ -674,6 +678,8 @@ def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,

         qkv_qw, qkv_sz = convert_s4(qkv_qw, qkv_qz, qkv_s, group_size)

+        qkv_qw = tp_m_s4(qkv_qw, tp)
+
         model_params[f'layers.{i}.attention.w_qkv.qweight'] = qkv_qw
         model_params[f'layers.{i}.attention.w_qkv.scales_zeros'] = qkv_sz

@@ -702,12 +708,14 @@ def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
         w2_s = get_tensor(f'model.layers.{i}.mlp.down_proj.scales')
         w3_s = get_tensor(f'model.layers.{i}.mlp.up_proj.scales')

-        w13_qw, w13_qz, w13_s = fuse_w1_w3(w1_qw, w1_qz, w1_s, w3_qw, w3_qz,
-                                           w3_s)
+        w13_qw, w13_qz, w13_s = fuse_w1_w3_s4(w1_qw, w1_qz, w1_s, w3_qw, w3_qz,
+                                              w3_s)

         w13_qw, w13_sz = convert_s4(w13_qw, w13_qz, w13_s, group_size)
         w2_qw, w2_sz = convert_s4(w2_qw, w2_qz, w2_s, group_size)

+        w13_qw = tp_m_s4(w13_qw, tp)
+
         model_params[f'layers.{i}.feed_forward.w13.qweight'] = w13_qw
         model_params[f'layers.{i}.feed_forward.w13.scales_zeros'] = w13_sz
