From 34b3e55117b461ca070f4ed02d7c6d03a757d186 Mon Sep 17 00:00:00 2001 From: yaofengchen Date: Thu, 24 Oct 2024 07:28:01 +0000 Subject: [PATCH] update code --- dlinfer/vendor/ascend/torch_npu_ops.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dlinfer/vendor/ascend/torch_npu_ops.py b/dlinfer/vendor/ascend/torch_npu_ops.py index 98b000dd..70d7eeec 100644 --- a/dlinfer/vendor/ascend/torch_npu_ops.py +++ b/dlinfer/vendor/ascend/torch_npu_ops.py @@ -109,7 +109,6 @@ def prefill_attention( scale_value = ( softmax_scale if softmax_scale else 1.0 / math.sqrt(query.shape[-1]) ) - assert SocVersion.is_Ascend910B() or SocVersion.is_Ascend310P() if SocVersion.is_Ascend910B(): attn_output[:] = torch.ops.npu.npu_fusion_attention( query, @@ -123,6 +122,7 @@ def prefill_attention( actual_seq_kvlen=seq_kvlen_list, )[0] elif SocVersion.is_Ascend310P(): + assert num_q_heads == num_kv_heads, f"Ascend310P only support mha models." batch = q_start_loc.size(0) for i in range(batch): start = q_start_loc[i] @@ -152,6 +152,10 @@ def prefill_attention( input_layout="BSH", num_key_value_heads=num_kv_heads, ) + else: + raise ValueError( + f"dlinfer doesn't support {SocVersion.device_name} device currently." + ) else: # For now, the value of attn_mask is None only in vit seq_len_list = None if q_seq_len is None else q_seq_len.tolist()