add debug log
lzhangzz committed Nov 10, 2023
1 parent 1c68bcd commit ef8281a
Showing 2 changed files with 28 additions and 1 deletion.
27 changes: 26 additions & 1 deletion src/turbomind/kernels/sampling_topk_kernels.cu
@@ -193,6 +193,14 @@ __global__ void topk_stage1(const T* __restrict log_probs,
         TopK_2<T> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<T>);
 
         if (tid == 0) {
+            if (total.p < 0) {
+                printf("[topk_stage1] k=%d, ite=%d, total.p=%d, total.u=%f, blockIdx=%d\n",
+                       (int)k,
+                       (int)ite,
+                       (int)total.p,
+                       (float)total.u,
+                       (int)blockIdx.x);
+            }
             const int index = tmp_topk_buf_index + ite;
             topk_tmp_id_buf[index] = total.p;
             topk_tmp_val_buf[index] = total.u;
@@ -262,6 +270,14 @@ __global__ void topk_stage2_sampling(const int* __restrict topk_tmp_id_buf,
         TopK_2<float> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<float>);
 
         if (tid == 0) {
+            if (total.p < 0) {
+                printf("[topk_stage2] k=%d, ite=%d, total.p=%d, total.u=%f, blockIdx=%d\n",
+                       (int)k,
+                       (int)ite,
+                       (int)total.p,
+                       (float)total.u,
+                       (int)blockIdx.x);
+            }
             if (ite == 0) {
                 s_max = total.u;
             }
@@ -285,7 +301,16 @@
             float exp_logit = s_val2[i];
             rand_num = rand_num - exp_logit;
             if (rand_num <= 0.0f || i == k - 1) {
-                ids[batch_id] = topk_tmp_id_buf[batch_id * stride + s_id[i]] % vocab_size;
+                int id = topk_tmp_id_buf[batch_id * stride + s_id[i]] % vocab_size;
+                ids[batch_id] = id;
+                if (id < 0) {
+                    printf("[topk_stage2] k=%d, i=%d, id=%d, exp_logit=%f, blockIdx=%d\n",
+                           (int)k,
+                           (int)i,
+                           (int)id,
+                           (float)exp_logit,
+                           (int)blockIdx.x);
+                }
                 if (cum_log_probs != nullptr || output_log_probs != nullptr) {
                     float log_prob = logf(exp_logit);
                     if (cum_log_probs != nullptr) {
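All three new printf guards in the kernel above flag the same failure mode: total.p is still negative after the block-wide reduction, meaning no thread contributed a valid candidate index, and because the C++ % operator keeps the sign of its left operand, a negative entry in topk_tmp_id_buf stays negative after % vocab_size, which is what the final guard on id < 0 catches. For reference, below is a paraphrased sketch of the TopK_2 helper as it appears in FasterTransformer-derived top-k kernels; the .p/.u fields and reduce_topk_op_2 are confirmed by the diff, while init(), insert(), and the -1/-FLT_MAX sentinels are assumptions.

```cuda
#include <cfloat>

// Paraphrased sketch (assumption): the TopK_2 helper used by the sampling kernels.
// The diff confirms only the .p/.u fields and reduce_topk_op_2; the rest is inferred
// from FasterTransformer-style top-k kernels.
template<typename T>
struct TopK_2 {
    int p = -1;        // index of the best candidate seen so far; -1 means "none"
    T   u = -FLT_MAX;  // value of the best candidate seen so far

    __device__ __forceinline__ void init()
    {
        p = -1;
        u = -FLT_MAX;
    }

    __device__ __forceinline__ void insert(T elem, int elem_id)
    {
        if (elem > u) {
            u = elem;
            p = elem_id;
        }
    }
};

// The block reduction keeps whichever partial result holds the larger value, so a
// final total.p of -1 means every partial was still in its initialized state.
template<typename T>
__device__ __forceinline__ TopK_2<T> reduce_topk_op_2(const TopK_2<T>& a, const TopK_2<T>& b)
{
    return a.u > b.u ? a : b;
}
```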
2 changes: 2 additions & 0 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -542,6 +542,8 @@ bool LlamaBatch<T>::generate()
                            session_len_,
                            batch_size_);
 
+    CheckValues(decoder_output_buf_, batch_size_ * llama_->hidden_units_, "decoderForward", stream_);
+
     // CheckBatchConsistency(decoder_input_buf_,  //
     //                       llama_->hidden_units_,
     //                       batch_size_,
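The host-side change adds a single CheckValues call on the decoder output right after decoderForward. CheckValues itself is defined elsewhere in TurboMind and is not shown in this diff; the sketch below is purely hypothetical, assuming its purpose is to copy the buffer back to the host and report non-finite values at a named call site (shown for float for brevity, the real buffer may hold half-precision data).

```cpp
// Hypothetical sketch only -- the real CheckValues lives elsewhere in TurboMind and
// may differ. Assumed intent: flag NaN/Inf in a device buffer at a named call site.
#include <cmath>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

void CheckValuesSketch(const float* d_buf, size_t count, const char* where, cudaStream_t stream)
{
    std::vector<float> h_buf(count);
    cudaMemcpyAsync(h_buf.data(), d_buf, sizeof(float) * count, cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);  // make sure the copy has finished before scanning

    size_t bad = 0;
    for (size_t i = 0; i < count; ++i) {
        if (!std::isfinite(h_buf[i])) {
            ++bad;
        }
    }
    if (bad > 0) {
        std::printf("[%s] %zu of %zu values are NaN/Inf\n", where, bad, count);
    }
}
```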
