From 67a66e2d8619b949e0d447e73a7ae1ead61b8f64 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Tue, 31 Oct 2023 17:09:50 +0800
Subject: [PATCH] Fix race condition & check for abnormal values (#627)

* check Inf and NaN

* fix `vocab_size`

* config smem size for `batchApplyRepetitionPenalty`

* `handleOptArg` with stream support

* sync before using async data

* remove debug code

* fix msvc build

* fix msvc build
---
 src/turbomind/kernels/gpt_kernels.h           |  13 +-
 .../kernels/sampling_penalty_kernels.cu       |   7 +
 .../sampling_layers/TopKSamplingLayer.cu      |  22 ++++
 .../sampling_layers/TopPSamplingLayer.cu      |   1 +
 src/turbomind/models/llama/Barrier.h          |   1 +
 src/turbomind/models/llama/LlamaBatch.cc      |  58 ++++++++-
 .../models/llama/LlamaContextDecoder.cc       |  11 ++
 src/turbomind/models/llama/LlamaDecoder.cc    |  26 ++++
 .../llama/LlamaDecoderSelfAttentionLayer.cc   |  38 ++++++
 src/turbomind/models/llama/LlamaV2.cc         |   6 +
 src/turbomind/models/llama/llama_utils.cu     | 120 +++++++++++++++++-
 src/turbomind/models/llama/llama_utils.h      |  30 ++++-
 src/turbomind/utils/memory_utils.cu           |   7 +-
 13 files changed, 323 insertions(+), 17 deletions(-)

diff --git a/src/turbomind/kernels/gpt_kernels.h b/src/turbomind/kernels/gpt_kernels.h
index 4e1dc49be8..6a35e6b764 100644
--- a/src/turbomind/kernels/gpt_kernels.h
+++ b/src/turbomind/kernels/gpt_kernels.h
@@ -21,6 +21,7 @@
 #include <unordered_map>
 
 #include "src/turbomind/utils/Tensor.h"
+#include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/memory_utils.h"
 
 namespace turbomind {
@@ -131,14 +132,20 @@ void invokeFindContextDups(int*         shared_contexts,
                            cudaStream_t stream = 0);
 
 template<typename T>
-void handleOptArg(TensorMap* input_tensors, const std::string& arg_name, T* d_ptr, T default_value, size_t size)
+void handleOptArg(TensorMap*         input_tensors,
+                  const std::string& arg_name,
+                  T*                 d_ptr,
+                  T                  default_value,
+                  size_t             size,
+                  cudaStream_t       stream = {})
 {
     if (input_tensors->isExist(arg_name)) {
         FT_CHECK(input_tensors->at(arg_name).size() == size);
-        cudaH2Dcpy(d_ptr, input_tensors->at(arg_name).getPtr<const T>(), size);
+        check_cuda_error(cudaMemcpyAsync(
+            d_ptr, input_tensors->at(arg_name).getPtr<const T>(), sizeof(T) * size, cudaMemcpyDefault, stream));
     }
     else {
-        deviceFill(d_ptr, size, default_value);
+        deviceFill(d_ptr, size, default_value, stream);
     }
 }
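
Note on the new `stream` parameter: the copy is now enqueued asynchronously on the
caller's stream instead of blocking on the default stream, so a caller that needs the
values on the host (or on a different stream) has to synchronize explicitly. A minimal
usage sketch, with illustrative buffer/variable names (not code from this patch):

    // d_end_ids, end_id, batch_size and stream are hypothetical names.
    handleOptArg(&inputs, "end_id", d_end_ids, /*default_value=*/end_id, batch_size, stream);
    // Work enqueued later on the same stream sees the values in order; only a host
    // read (or use on another stream) needs an explicit synchronization:
    check_cuda_error(cudaStreamSynchronize(stream));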
 
diff --git a/src/turbomind/kernels/sampling_penalty_kernels.cu b/src/turbomind/kernels/sampling_penalty_kernels.cu
index 4877bdb1a0..5e5b22f940 100644
--- a/src/turbomind/kernels/sampling_penalty_kernels.cu
+++ b/src/turbomind/kernels/sampling_penalty_kernels.cu
@@ -445,11 +445,18 @@ void invokeBatchApplyRepetitionPenalty(T*                    logits,
     dim3   block(min(step, 1024));
     dim3   grid(local_batch_size);
     size_t smem_size = step * (sizeof(float) + sizeof(int));
+
     if (penalty_type == RepetitionPenaltyType::Additive) {
+        check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive>,
+                                              cudaFuncAttributeMaxDynamicSharedMemorySize,
+                                              smem_size));
         batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, smem_size, stream>>>(
             logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
     }
     else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
+        check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative>,
+                                              cudaFuncAttributeMaxDynamicSharedMemorySize,
+                                              smem_size));
         batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative><<<grid, block, smem_size, stream>>>(
             logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
     }
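
The `cudaFuncSetAttribute` calls matter because `smem_size = step * (sizeof(float) +
sizeof(int))` grows with the sequence length and can exceed the 48 KB of dynamic shared
memory a kernel gets by default; the attribute opts the specific instantiation into the
larger per-block carve-out. The same pattern in isolation (a sketch, not code from this
patch):

    #include <cuda_runtime.h>

    __global__ void copy_via_smem(const float* in, float* out, int n)
    {
        extern __shared__ float smem[];              // dynamic shared memory
        for (int i = threadIdx.x; i < n; i += blockDim.x)
            smem[i] = in[i];
        __syncthreads();
        for (int i = threadIdx.x; i < n; i += blockDim.x)
            out[i] = smem[i];
    }

    void launch_copy(const float* in, float* out, int n, cudaStream_t stream)
    {
        size_t smem_size = sizeof(float) * n;        // may exceed the 48 KB default
        cudaFuncSetAttribute(copy_via_smem, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem_size);
        copy_via_smem<<<1, 256, smem_size, stream>>>(in, out, n);
    }
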
diff --git a/src/turbomind/layers/sampling_layers/TopKSamplingLayer.cu b/src/turbomind/layers/sampling_layers/TopKSamplingLayer.cu
index 614b1a68ce..30b4cb4237 100644
--- a/src/turbomind/layers/sampling_layers/TopKSamplingLayer.cu
+++ b/src/turbomind/layers/sampling_layers/TopKSamplingLayer.cu
@@ -21,6 +21,8 @@
 #include "src/turbomind/kernels/sampling_topp_kernels.h"
 #include "src/turbomind/layers/sampling_layers/TopKSamplingLayer.h"
 #include "src/turbomind/macro.h"
+#include "src/turbomind/models/llama/llama_utils.h"
+#include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/logger.h"
 #include "src/turbomind/utils/memory_utils.h"
 
@@ -131,6 +133,20 @@ void TopKSamplingLayer<T>::freeBuffer()
     is_allocate_buffer_ = false;
 }
 
+template<typename T>
+inline static std::string format(const Tensor& t)
+{
+    std::stringstream ss;
+    const int         size = t.size();
+    const T*          ptr  = t.getPtr<T>();
+    ss << "[";
+    for (int i = 0; i < size; ++i) {
+        ss << (i ? ", " : "") << ptr[i];
+    }
+    ss << "]";
+    return ss.str();
+}
+
 template<typename T>
 void TopKSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_width, TensorMap* runtime_args)
 {
@@ -168,6 +184,11 @@ void TopKSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt
         cudaAutoCpy(runtime_top_p_buf_, runtime_top_p.getPtr<float>(), batch_size, stream_);
     }
 
+    // if (isDebug()) {
+    TM_LOG_INFO("[TopKSamplingLayer] runtime_top_k: %s", format<int>(runtime_top_k).c_str());
+    TM_LOG_INFO("[TopKSamplingLayer] runtime_top_p: %s", format<float>(runtime_top_p).c_str());
+    // }
+
     dim3 block(std::min((int)batch_size, 256));
     dim3 grid(div_up((int)batch_size, (int)block.x));
     // support top_k up to 1024.
@@ -182,6 +203,7 @@ void TopKSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt
     cudaAutoCpy(skip_decode_, skip_decode_buf_, batch_size, stream_);
     uint* runtime_top_ks = new uint[batch_size];
     cudaAutoCpy(runtime_top_ks, runtime_top_k_buf_, batch_size, stream_);
+    check_cuda_error(cudaStreamSynchronize(stream_));
     runtime_max_top_k_ = static_cast<int>(*std::max_element(runtime_top_ks, runtime_top_ks + batch_size));
     delete[] runtime_top_ks;
 }
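
The added `cudaStreamSynchronize(stream_)` is the race-condition fix from the title:
`cudaAutoCpy` enqueues an asynchronous device-to-host copy on `stream_`, so without the
sync the host-side `std::max_element` can read `runtime_top_ks` before the copy has
landed. The general pattern (a sketch with illustrative names):

    // Async D2H copy: the host may only dereference the destination after the
    // stream has been synchronized (or an event recorded after the copy has completed).
    std::vector<uint32_t> h_top_k(batch_size);
    check_cuda_error(cudaMemcpyAsync(h_top_k.data(), d_top_k, sizeof(uint32_t) * batch_size,
                                     cudaMemcpyDeviceToHost, stream));
    check_cuda_error(cudaStreamSynchronize(stream));  // without this, max_element races the copy
    const uint32_t max_top_k = *std::max_element(h_top_k.begin(), h_top_k.end());
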
diff --git a/src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu b/src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu
index 8e7e97314f..ef05708526 100644
--- a/src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu
+++ b/src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu
@@ -249,6 +249,7 @@ void TopPSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt
     cudaAutoCpy(skip_decode_, skip_decode_buf_, batch_size, stream_);
     float* runtime_top_ps = new float[batch_size];
     cudaAutoCpy(runtime_top_ps, runtime_top_p_buf_, batch_size, stream_);
+    check_cuda_error(cudaStreamSynchronize(stream_));
     runtime_max_top_p_ = *std::max_element(runtime_top_ps, runtime_top_ps + batch_size);
     delete[] runtime_top_ps;
 }
diff --git a/src/turbomind/models/llama/Barrier.h b/src/turbomind/models/llama/Barrier.h
index 6eb0df9585..e34c42e6ce 100644
--- a/src/turbomind/models/llama/Barrier.h
+++ b/src/turbomind/models/llama/Barrier.h
@@ -2,6 +2,7 @@
 
 #pragma once
 
+#include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/logger.h"
 #ifndef _MSC_VER
 #include <pthread.h>
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 5d8d7d0411..096cfcb4f1 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -385,8 +385,7 @@ void LlamaBatch<T>::initializeSampling(int infer_request_count)
         }
     }
 
-    handleOptArg(&inputs_, "end_id", end_ids_buf_, llama_->end_id_, batch_size_);
-    cudaStreamSynchronize(0);
+    handleOptArg(&inputs_, "end_id", end_ids_buf_, llama_->end_id_, batch_size_, stream_);
 }
 
 template<typename T>
@@ -502,6 +501,35 @@ bool LlamaBatch<T>::generate()
                             batch_size_,
                             step_ - 1);
 
+    // insert +inf for half
+    // T x = 999999.f;
+    // cudaMemcpyAsync(decoder_input_buf_, &x, sizeof(x), cudaMemcpyDefault, stream_);
+
+    // CheckValues(decoder_input_buf_, batch_size_ * llama_->hidden_units_, "embedding_lookup", stream_);
+
+    // if (compare_mode == kCmpWrite) {
+    //     if (rank_ == 0) {
+    //         Compare(decoder_input_buf_, llama_->hidden_units_, Concat("decoder_input", step_), compare_mode,
+    //         stream_);
+    //     }
+    // }
+    // else {
+    //     for (int i = 0; i < batch_size_; ++i) {
+    //         Compare(decoder_input_buf_ + i * llama_->hidden_units_,
+    //                 llama_->hidden_units_,
+    //                 Concat("decoder_input", step_),
+    //                 compare_mode,
+    //                 stream_,
+    //                 Concat("", rank_, i));
+    //     }
+    // }
+    // CheckBatchConsistency(decoder_input_buf_,  //
+    //                       llama_->hidden_units_,
+    //                       batch_size_,
+    //                       Concat("decoder_input", step_),
+    //                       rank_,
+    //                       stream_);
+
     llama_->decoderForward(decoder_output_buf_,
                            k_cache_ptr_buf_,
                            v_cache_ptr_buf_,
@@ -514,11 +542,37 @@ bool LlamaBatch<T>::generate()
                            session_len_,
                            batch_size_);
 
+    // CheckBatchConsistency(decoder_input_buf_,  //
+    //                       llama_->hidden_units_,
+    //                       batch_size_,
+    //                       Concat("decoder_output", step_),
+    //                       rank_,
+    //                       stream_);
+
+    // if (compare_mode == kCmpWrite) {
+    //     if (rank_ == 0) {
+    //         Compare(decoder_output_buf_, llama_->hidden_units_, Concat("decoder_output", step_), compare_mode,
+    //         stream_);
+    //     }
+    // }
+    // else {
+    //     for (int i = 0; i < batch_size_; ++i) {
+    //         Compare(decoder_output_buf_ + i * llama_->hidden_units_,
+    //                 llama_->hidden_units_,
+    //                 Concat("decoder_output", step_),
+    //                 compare_mode,
+    //                 stream_,
+    //                 Concat("", rank_, i));
+    //     }
+    // }
+
     llama_->postDecodeEmbedding(logits_buf_,  //
                                 local_logits_buf_,
                                 decoder_output_buf_,
                                 batch_size_);
 
+    // CheckValues(logits_buf_, batch_size_ * llama_->vocab_size_padded_, "post_decode_embedding", stream_);
+
     // stop-words & bad-words require the matched tokens to be contiguous, so item size > 1 is
     // not supported yet.
     bool should_stop{};
diff --git a/src/turbomind/models/llama/LlamaContextDecoder.cc b/src/turbomind/models/llama/LlamaContextDecoder.cc
index f914063a70..7edf087a8b 100644
--- a/src/turbomind/models/llama/LlamaContextDecoder.cc
+++ b/src/turbomind/models/llama/LlamaContextDecoder.cc
@@ -25,6 +25,7 @@
 #include "src/turbomind/models/llama/LlamaContextDecoder.h"
 #include "src/turbomind/models/llama/llama_decoder_kernels.h"
 #include "src/turbomind/models/llama/llama_kernels.h"
+#include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/Tensor.h"
 
 namespace turbomind {
@@ -244,11 +245,15 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
                              stream_);
     sync_check_cuda_error();
 
+    // CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_norm", 0), stream_);
+
     for (size_t layer = 0; layer < num_layer_; ++layer) {
         /////////////////////////////////////////////
         /// self-attention
         forwardSelfAttn(sess, decoder_output, input_tensors, layer, false);
 
+        // CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_self_attn", layer), stream_);
+
         invokeFusedAddBiasResidualRMSNorm(decoder_input_output,
                                           decoder_output,
                                           decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
@@ -259,6 +264,8 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
                                           stream_);
         sync_check_cuda_error();
 
+        // CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_norm1", layer), stream_);
+
         ////////////////////////////////////////////
         /// feed-forward network
         TensorMap ffn_inputs{{"ffn_input", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, decoder_output}}};
@@ -266,6 +273,8 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
             {"ffn_output", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, decoder_output}}};
         silu_ffn_layer_->forward(&ffn_outputs, &ffn_inputs, &decoder_layer_weights->at(layer)->ffn_weights);
 
+        // CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_ffn", layer), stream_);
+
         auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
                                                      input_tensors->at("output_norm_weight").getPtr<T>();
         invokeFusedAddBiasResidualRMSNorm(decoder_input_output,  //
@@ -277,6 +286,8 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
                                           hidden_units_,
                                           stream_);
         sync_check_cuda_error();
+
+        // CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_norm2", layer), stream_);
     }
 
     if (is_free_buffer_after_forward_) {
diff --git a/src/turbomind/models/llama/LlamaDecoder.cc b/src/turbomind/models/llama/LlamaDecoder.cc
index 73e95b1353..7b934f3cb2 100644
--- a/src/turbomind/models/llama/LlamaDecoder.cc
+++ b/src/turbomind/models/llama/LlamaDecoder.cc
@@ -195,6 +195,8 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>*        ou
     T* decoder_input  = input_tensors->at("decoder_input").getPtr<T>();
     T* decoder_output = output_tensors->at("decoder_output").getPtr<T>();
 
+    // int step = input_tensors->at("step").getVal<int>();
+
     ////////////////////////////////////////////
     /// RMSNorm
     invokeRootMeanSquareNorm(decoder_output,
@@ -206,10 +208,28 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>*        ou
                              stream_);
     sync_check_cuda_error();
 
+    // CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_norm", 0), stream_);
+
+    // CheckBatchConsistency(decoder_output,
+    //                       hidden_units_,
+    //                       sess.batch_size,
+    //                       Concat("decode_norm", step, 0),
+    //                       tensor_para_.rank_,
+    //                       stream_);
+
     for (size_t layer = 0; layer < num_layer_; ++layer) {
         // output: self_attn_output_, k_cache, v_cache = self_attn(decoder_normed_input_)
         forwardSelfAttn(sess, decoder_output, input_tensors, layer);
 
+        // CheckBatchConsistency(decoder_output,
+        //                       hidden_units_,
+        //                       sess.batch_size,
+        //                       Concat("decode_self_attn", step, layer),
+        //                       tensor_para_.rank_,
+        //                       stream_);
+
+        // CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_self_attn", layer), stream_);
+
         invokeFusedAddBiasResidualRMSNorm(decoder_input,
                                           decoder_output,
                                           decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
@@ -220,9 +240,13 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>*        ou
                                           stream_);
         sync_check_cuda_error();
 
+        // CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_norm1", layer), stream_);
+
         // decoder_layer_output_ = ffn(decoder_normed_input_)
         forwardFfn(sess, decoder_output, layer);
 
+        // CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_ffn", layer), stream_);
+
         auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
                                                      input_tensors->at("output_norm_weight").getPtr<T>();
         invokeFusedAddBiasResidualRMSNorm(decoder_input,  //
@@ -234,6 +258,8 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>*        ou
                                           hidden_units_,
                                           stream_);
         sync_check_cuda_error();
+
+        // CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_norm2", layer), stream_);
     }
 
     if (is_free_buffer_after_forward_) {
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
index 103b32e88f..4eb7937ba6 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
@@ -25,6 +25,7 @@
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/logger.h"
+#include "src/turbomind/utils/nccl_utils.h"
 #include "src/turbomind/utils/nvtx_utils.h"
 #include <string>
 // #include <glog/logging.h>
@@ -236,10 +237,24 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
 
     allocateBuffer(batch_size, step, max_seq_len);
 
+    // CheckBatchConsistency((T*)input_query_data,
+    //                       hidden_units_,
+    //                       batch_size,
+    //                       Concat("before_qkv_gemm", step, layer_id),
+    //                       tensor_para_.rank_,
+    //                       stream_);
+
     PUSH_RANGE("qkv_gemm");
     linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv);
     POP_RANGE;
 
+    // CheckBatchConsistency(qkv_buf_,
+    //                       (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_,
+    //                       batch_size,
+    //                       Concat("after_qkv_gemm", step, layer_id),
+    //                       tensor_para_.rank_,
+    //                       stream_);
+
     const auto kv_cache_layer_offset = layer_id * local_kv_head_num_ * max_seq_len * size_per_head_;
     const int  memory_len            = max_seq_len;
 
@@ -287,15 +302,38 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
         stream_);
     sync_check_cuda_error();
 
+    // CheckBatchConsistency((T*)context_buf_,
+    //                       local_hidden_units_,
+    //                       batch_size,
+    //                       Concat("before_o_gemm", step, layer_id),
+    //                       tensor_para_.rank_,
+    //                       stream_);
+
     linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output);
 
+    // CheckBatchConsistency(hidden_features_data,
+    //                       hidden_units_,
+    //                       batch_size,
+    //                       Concat("after_o_gemm", step, layer_id),
+    //                       tensor_para_.rank_,
+    //                       stream_);
+
     if (tensor_para_.world_size_ > 1) {
         NcclGuard nccl_guard(tensor_para_, stream_);
         ftNcclAllReduceSum(
             hidden_features_data, hidden_features_data, batch_size * hidden_units_, tensor_para_, stream_);
         sync_check_cuda_error();
+        // ftNcclStreamSynchronize(tensor_para_, {}, stream_);
+        // sync_check_cuda_error();
     }
 
+    // CheckBatchConsistency(hidden_features_data,
+    //                       hidden_units_,
+    //                       batch_size,
+    //                       Concat("self_attn_allreduce", step, layer_id),
+    //                       tensor_para_.rank_,
+    //                       stream_);
+
     if (is_free_buffer_after_forward_) {
         freeBuffer();
     }
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index 8768e7fd05..97cf0e0a3d 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -438,6 +438,12 @@ void LlamaV2<T>::internalThreadEntry(int device_id)
     TM_LOG_INFO("[internalThreadEntry] %d", (int)tensor_para_.rank_);
     check_cuda_error(cudaSetDevice(device_id));
 
+    model_instance_barrier() = shared_state_->barrier.get();
+
+    // initialize global counters
+    // CheckValues((T*)0, 0, {}, 0);
+    shared_state_->barrier->wait();
+
     auto& request_queue  = shared_state_->request_queue;
     auto& infer_requests = shared_state_->infer_requests;
     auto& stop_requests  = shared_state_->stop_requests;
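
Each tensor-parallel worker thread stores the shared barrier in the thread_local slot
returned by `model_instance_barrier()` (added in llama_utils.cu below), so debug helpers
such as `CheckBatchConsistency` can rendezvous all ranks without passing the barrier
through every call site; the extra `wait()` ensures every rank has finished this
per-thread setup before requests are served. The pattern in isolation, using
std::barrier as a stand-in for turbomind::Barrier (a sketch, names illustrative):

    #include <barrier>

    std::barrier<>*& instance_barrier()
    {
        thread_local std::barrier<>* p{};       // one slot per worker thread
        return p;
    }

    void worker_entry(std::barrier<>* shared)
    {
        instance_barrier() = shared;            // register once per thread
        instance_barrier()->arrive_and_wait();  // all ranks ready before serving
        // ... code deep in the call stack can now synchronize ranks with
        //     instance_barrier()->arrive_and_wait(); and no extra plumbing ...
    }
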
diff --git a/src/turbomind/models/llama/llama_utils.cu b/src/turbomind/models/llama/llama_utils.cu
index 7050d2d13f..4d409cde35 100644
--- a/src/turbomind/models/llama/llama_utils.cu
+++ b/src/turbomind/models/llama/llama_utils.cu
@@ -56,7 +56,7 @@ void CheckNan(const T* ptr, size_t size, std::string key, cudaStream_t stream)
 }
 
 template<typename T>
-void CmpRead(T* ptr, size_t size, std::string key, cudaStream_t stream)
+void CmpRead(T* ptr, size_t size, std::string key, cudaStream_t stream, std::string msg)
 {
     // wait for b
     check_cuda_error(cudaStreamSynchronize(stream));
@@ -88,7 +88,7 @@ void CmpRead(T* ptr, size_t size, std::string key, cudaStream_t stream)
     auto                  transform_iter = thrust::make_transform_iterator(zip_iter, abs_diff<T>{});
     // sum(abs(a - b))
     auto asum = thrust::reduce(thrust::device, transform_iter, transform_iter + size);
-    std::cerr << key << ": " << asum << " " << asum / size << "\n";
+    std::cerr << key << msg << ": " << asum << " " << asum / size << "\n";
 }
 
 template<typename T>
@@ -106,11 +106,11 @@ void CmpWrite(T* ptr, size_t size, std::string key, cudaStream_t stream)
 }
 
 template<typename T>
-void Compare(T* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream)
+void Compare(T* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream, std::string msg)
 {
     // std::cerr << "Comparing " << key << "\n";
     if (mode == kCmpRead) {
-        CmpRead(ptr, size, key, stream);
+        CmpRead(ptr, size, key, stream, msg);
     }
     else if (mode == kCmpWrite) {
         CmpWrite(ptr, size, key, stream);
@@ -120,9 +120,9 @@ void Compare(T* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t st
     }
 }
 
-template void Compare(int* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream);
-template void Compare(float* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream);
-template void Compare(half* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream);
+template void Compare(int* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream, std::string msg);
+template void Compare(float* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream, std::string msg);
+template void Compare(half* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream, std::string msg);
 
 template void CheckNan(const float* ptr, size_t size, std::string key, cudaStream_t stream);
 template void CheckNan(const half* ptr, size_t size, std::string key, cudaStream_t stream);
@@ -157,4 +157,110 @@ bool isDebug()
     return is_debug;
 }
 
+template<int kWarpCount, typename T>
+inline __device__ T blockSum(T val, int warp_id, int lane_id)
+{
+    __shared__ T smem_red[32];
+
+    for (int mask = 32 >> 1; mask >= 1; mask >>= 1) {
+        val += __shfl_xor_sync((uint32_t)-1, val, mask);
+    }
+    if (lane_id == 0) {
+        smem_red[warp_id] = val;
+    }
+
+    __syncthreads();
+
+    val = lane_id < kWarpCount ? smem_red[lane_id] : T{};
+
+    for (int mask = kWarpCount >> 1; mask >= 1; mask >>= 1) {
+        val += __shfl_xor_sync((uint32_t)-1, val, mask);
+    }
+    val = __shfl_sync((uint32_t)-1, val, 0);
+
+    return val;
+}
+
+template<int kWarpCount, typename T>
+__global__ void CountInfNan(const T* data, uint32_t* g_inf_count, uint32_t* g_nan_count, size_t count)
+{
+    uint32_t inf_count = 0;
+    uint32_t nan_count = 0;
+    for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
+        if constexpr (std::is_same_v<T, float>) {
+            inf_count += __isinf(data[i]);
+            nan_count += __isnan(data[i]);
+        }
+        else if constexpr (std::is_same_v<T, half>) {
+            inf_count += __hisinf(data[i]);
+            nan_count += __hisnan(data[i]);
+        }
+    }
+    inf_count = blockSum<kWarpCount>(inf_count, threadIdx.x / 32, threadIdx.x % 32);
+    nan_count = blockSum<kWarpCount>(nan_count, threadIdx.x / 32, threadIdx.x % 32);
+    if (threadIdx.x == 0) {
+        if (inf_count) {
+            atomicAdd(g_inf_count, inf_count);
+        }
+        if (nan_count) {
+            atomicAdd(g_nan_count, nan_count);
+        }
+    }
+}
+namespace {
+struct Info {
+    char data[256];
+};
+
+}  // namespace
+
+__global__ void ReportInfNan(uint32_t* g_inf_count, uint32_t* g_nan_count, Info info)
+{
+    auto inf_count = *g_inf_count;
+    auto nan_count = *g_nan_count;
+    if (inf_count || nan_count) {
+        printf("[TM][ERROR] [%s] Inf=%u, NaN=%u\n", info.data, inf_count, nan_count);
+    }
+    // reset the counters for later use
+    *g_inf_count = 0;
+    *g_nan_count = 0;
+}
+
+template<typename T>
+void CheckValues(const T* data, int count, const std::string& msg, cudaStream_t stream)
+{
+    thread_local uint32_t* counters = [] {
+        uint32_t* ptr{};
+        cudaMalloc(&ptr, sizeof(uint32_t) * 2);
+        cudaMemset(ptr, 0, sizeof(uint32_t) * 2);
+        cudaDeviceSynchronize();
+        return ptr;
+    }();
+
+    if (data == nullptr && count == 0) {
+        return;
+    }
+
+    const auto g_inf_count = counters;
+    const auto g_nan_count = g_inf_count + 1;
+
+    FT_CHECK(msg.size() < sizeof(Info) - 1);
+
+    CountInfNan<4><<<256, 128, 0, stream>>>(data, g_inf_count, g_nan_count, count);
+
+    Info info;
+    strncpy(info.data, msg.c_str(), sizeof(info) - 1);
+
+    ReportInfNan<<<1, 1, 0, stream>>>(g_inf_count, g_nan_count, info);
+}
+
+template void CheckValues(const half* data, int count, const std::string& msg, cudaStream_t stream);
+template void CheckValues(const float* data, int count, const std::string& msg, cudaStream_t stream);
+
+Barrier*& model_instance_barrier()
+{
+    thread_local Barrier* p{};
+    return p;
+}
+
 }  // namespace turbomind
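
`CheckValues` is fully asynchronous: `CountInfNan` accumulates Inf/NaN counts into a
lazily allocated, thread_local pair of device counters, and `ReportInfNan` prints and
resets them, all on the caller's stream, so enabling the check does not add host
synchronization. Usage follows the commented-out call sites in LlamaBatch.cc (a sketch,
buffer names illustrative):

    // Check a logits buffer right after it is produced, on the same stream:
    CheckValues(logits_buf, batch_size * vocab_size_padded, "post_decode_embedding", stream);

    // A warm-up call with (nullptr, 0) only allocates the thread_local counters,
    // matching the (commented-out) barrier-guarded call in LlamaV2::internalThreadEntry:
    CheckValues(static_cast<const half*>(nullptr), 0, {}, stream);
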
diff --git a/src/turbomind/models/llama/llama_utils.h b/src/turbomind/models/llama/llama_utils.h
index 05c10be80b..40a4688e01 100644
--- a/src/turbomind/models/llama/llama_utils.h
+++ b/src/turbomind/models/llama/llama_utils.h
@@ -1,6 +1,7 @@
 // Copyright (c) OpenMMLab. All rights reserved.
 
 #pragma once
+#include "src/turbomind/models/llama/Barrier.h"
 #include "src/turbomind/utils/Tensor.h"
 #include <cuda_runtime.h>
 #include <sstream>
@@ -29,7 +30,7 @@ enum CmpMode
 extern CmpMode compare_mode;
 
 template<typename T>
-void Compare(T* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream);
+void Compare(T* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream, std::string msg = {});
 
 template<typename T>
 void CheckNan(const T* ptr, size_t size, std::string key, cudaStream_t stream);
@@ -66,4 +67,31 @@ size_t curandStateGetSize();
 
 bool isDebug();
 
+template<typename T>
+void CheckValues(const T* data, int count, const std::string& msg, cudaStream_t stream);
+
+Barrier*& model_instance_barrier();
+
+template<typename T>
+inline void CheckBatchConsistency(T* ptr, size_t size, int batch_size, std::string key, int rank, cudaStream_t stream)
+{
+    if (compare_mode == kCmpNone) {
+        return;
+    }
+    model_instance_barrier()->wait();
+    if (compare_mode == kCmpWrite) {
+        if (rank == 0) {
+            Compare(ptr, size, key, compare_mode, stream);
+        }
+    }
+    else {
+        if (rank == 0) {
+            for (int i = 0; i < batch_size; ++i) {
+                Compare(ptr + i * size, size, key, compare_mode, stream, Concat("", rank, i));
+            }
+        }
+    }
+    model_instance_barrier()->wait();
+}
+
 }  // namespace turbomind
diff --git a/src/turbomind/utils/memory_utils.cu b/src/turbomind/utils/memory_utils.cu
index 93547f364f..4344bda9cf 100644
--- a/src/turbomind/utils/memory_utils.cu
+++ b/src/turbomind/utils/memory_utils.cu
@@ -93,10 +93,9 @@ template void deviceFree(__nv_fp8_e4m3*& ptr);
 template<typename T>
 void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream)
 {
-    T* arr = new T[size];
-    std::fill(arr, arr + size, value);
-    check_cuda_error(cudaMemcpyAsync(devptr, arr, sizeof(T) * size, cudaMemcpyHostToDevice, stream));
-    delete[] arr;
+    std::unique_ptr<T[]> arr(new T[size]);
+    std::fill(arr.get(), arr.get() + size, value);
+    check_cuda_error(cudaMemcpyAsync(devptr, arr.get(), sizeof(T) * size, cudaMemcpyDefault, stream));
 }
 
 template void deviceFill(float* devptr, size_t size, float value, cudaStream_t stream);
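
A note on the rewritten `deviceFill`: the staging array now goes out of scope as soon as
the function returns, while the copy is only enqueued. For a pageable host source this
is safe, since `cudaMemcpyAsync` returns only after the source has been copied into the
driver's staging buffer (per the CUDA runtime documentation); with pinned memory the
call returns immediately and the buffer would have to outlive the copy. Sketch
(`d_ptr`, `n`, `value` and `stream` are illustrative names):

    // Pageable source: the runtime stages `arr` before cudaMemcpyAsync returns,
    // so freeing it at scope exit is fine.
    {
        std::unique_ptr<float[]> arr(new float[n]);
        std::fill(arr.get(), arr.get() + n, value);
        check_cuda_error(cudaMemcpyAsync(d_ptr, arr.get(), sizeof(float) * n, cudaMemcpyDefault, stream));
    }   // arr freed here; the staged copy still completes

    // Pinned source (hypothetical variant): the call returns immediately, so the
    // buffer must stay alive until the copy finishes, e.g.
    //     cudaStreamSynchronize(stream);  // before cudaFreeHost(pinned_ptr)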