From ef8281a4f66bb734b6164664e2fc6c56ae085058 Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Fri, 10 Nov 2023 04:13:01 +0000 Subject: [PATCH] add debug log --- .../kernels/sampling_topk_kernels.cu | 27 ++++++++++++++++++- src/turbomind/models/llama/LlamaBatch.cc | 2 ++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/turbomind/kernels/sampling_topk_kernels.cu b/src/turbomind/kernels/sampling_topk_kernels.cu index 82b208298d..927a0bc981 100644 --- a/src/turbomind/kernels/sampling_topk_kernels.cu +++ b/src/turbomind/kernels/sampling_topk_kernels.cu @@ -193,6 +193,14 @@ __global__ void topk_stage1(const T* __restrict log_probs, TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); if (tid == 0) { + if (total.p < 0) { + printf("[topk_stage1] k=%d, ite=%d, total.p=%d, total.u=%f, blockIdx=%d\n", + (int)k, + (int)ite, + (int)total.p, + (float)total.u, + (int)blockIdx.x); + } const int index = tmp_topk_buf_index + ite; topk_tmp_id_buf[index] = total.p; topk_tmp_val_buf[index] = total.u; @@ -262,6 +270,14 @@ __global__ void topk_stage2_sampling(const int* __restrict topk_tmp_id_buf, TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); if (tid == 0) { + if (total.p < 0) { + printf("[topk_stage2] k=%d, ite=%d, total.p=%d, total.u=%f, blockIdx=%d\n", + (int)k, + (int)ite, + (int)total.p, + (float)total.u, + (int)blockIdx.x); + } if (ite == 0) { s_max = total.u; } @@ -285,7 +301,16 @@ __global__ void topk_stage2_sampling(const int* __restrict topk_tmp_id_buf, float exp_logit = s_val2[i]; rand_num = rand_num - exp_logit; if (rand_num <= 0.0f || i == k - 1) { - ids[batch_id] = topk_tmp_id_buf[batch_id * stride + s_id[i]] % vocab_size; + int id = topk_tmp_id_buf[batch_id * stride + s_id[i]] % vocab_size; + ids[batch_id] = id; + if (id < 0) { + printf("[topk_stage2] k=%d, i=%d, id=%d, exp_logit=%f, blockIdx=%d\n", + (int)k, + (int)i, + (int)id, + (float)exp_logit, + (int)blockIdx.x); + } if (cum_log_probs != nullptr || output_log_probs != nullptr) { float log_prob = logf(exp_logit); if (cum_log_probs != nullptr) { diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index 096cfcb4f1..79fde27001 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -542,6 +542,8 @@ bool LlamaBatch::generate() session_len_, batch_size_); + CheckValues(decoder_output_buf_, batch_size_ * llama_->hidden_units_, "decoderForward", stream_); + // CheckBatchConsistency(decoder_input_buf_, // // llama_->hidden_units_, // batch_size_,