Skip to content

Commit 020610e

Browse files
committed
fix dynamic Cfp8 computation error
1 parent 6e064a6 commit 020610e

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -869,7 +869,7 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
869869
local_max = __hmax(local_max, __habs(out_vec2[i]));
870870
}
871871
#pragma unroll
872-
for (int m_offset = 16; m_offset > 1; m_offset /= 2) {
872+
for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
873873
local_max = __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
874874
}
875875

tests/layers/test_append_attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def init_tensor(self):
379379
)
380380
self.max_enc_len_this_time = paddle.to_tensor([self.max_enc_len_this_time], "int32", place=paddle.CPUPlace())
381381
self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())
382-
self.seq_lens_this_time = self.seq_lens_encoder
382+
self.seq_lens_this_time = copy.deepcopy(self.seq_lens_encoder)
383383

384384
self.decoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
385385
self.decoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
@@ -640,7 +640,7 @@ def test_all(self):
640640
)
641641
# encoder
642642
# self.seq_lens_encoder,self.seq_lens_decoder,self.max_enc_len_this_time,self.max_dec_len_this_time=get_encoder_decoder_len(self.batch_size,self.seq_len)
643-
self.seq_lens_this_time = self.seq_lens_encoder
643+
self.seq_lens_this_time = copy.deepcopy(self.seq_lens_encoder)
644644
if self.use_mask_offset:
645645
print("encoder mask_offset: ", self.mask_offset)
646646
self.cmp_append_attention(attn_mask=self.attention_mask)

0 commit comments

Comments (0)