We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0545c9f commit 51ec61cCopy full SHA for 51ec61c
lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py
@@ -19,8 +19,15 @@ def apply_rotary_pos_emb(
19
query_states_reshaped = query_states.reshape(1, bs, head, dim)
20
key_states_reshaped = key_states.reshape(1, bs, num_kv_heads, dim)
21
if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
22
- cos = cos[position_ids_1d].view(1, bs, 1, -1)
23
- sin = sin[position_ids_1d].view(1, bs, 1, -1)
+ if len(cos.shape) == 3 and len(sin.shape) == 3:
+ cos = cos[:, position_ids_1d].view(1, bs, 1, -1)
24
+ sin = sin[:, position_ids_1d].view(1, bs, 1, -1)
25
+ elif len(cos.shape) == 2 and len(sin.shape) == 2:
26
+ cos = cos[position_ids_1d].view(1, bs, 1, -1)
27
+ sin = sin[position_ids_1d].view(1, bs, 1, -1)
28
+ else:
29
+ raise RuntimeError("Cannot handle cos/sin shape dims!")
30
+
31
if context:
32
setattr(context, 'cos', cos)
33
setattr(context, 'sin', sin)
0 commit comments