diff --git a/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py index a68287b52..acfa44a42 100644 --- a/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py +++ b/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py @@ -19,8 +19,15 @@ def apply_rotary_pos_emb( query_states_reshaped = query_states.reshape(1, bs, head, dim) key_states_reshaped = key_states.reshape(1, bs, num_kv_heads, dim) if not (hasattr(context, 'cos') or hasattr(context, 'sin')): - cos = cos[position_ids_1d].view(1, bs, 1, -1) - sin = sin[position_ids_1d].view(1, bs, 1, -1) + if len(cos.shape) == 3 and len(sin.shape) == 3: + cos = cos[:, position_ids_1d].view(1, bs, 1, -1) + sin = sin[:, position_ids_1d].view(1, bs, 1, -1) + elif len(cos.shape) == 2 and len(sin.shape) == 2: + cos = cos[position_ids_1d].view(1, bs, 1, -1) + sin = sin[position_ids_1d].view(1, bs, 1, -1) + else: + raise RuntimeError("Cannot handle cos/sin shape dims!") + if context: setattr(context, 'cos', cos) setattr(context, 'sin', sin)