diff --git a/src/turbomind/kernels/gpt_kernels.h b/src/turbomind/kernels/gpt_kernels.h
index 4e1dc49be..6a35e6b76 100644
--- a/src/turbomind/kernels/gpt_kernels.h
+++ b/src/turbomind/kernels/gpt_kernels.h
@@ -21,6 +21,7 @@
 #include <unordered_map>
 
 #include "src/turbomind/utils/Tensor.h"
+#include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/memory_utils.h"
 
 namespace turbomind {
@@ -131,14 +132,20 @@ void invokeFindContextDups(int* shared_contexts,
                            cudaStream_t stream = 0);
 
 template<typename T>
-void handleOptArg(TensorMap* input_tensors, const std::string& arg_name, T* d_ptr, T default_value, size_t size)
+void handleOptArg(TensorMap*         input_tensors,
+                  const std::string& arg_name,
+                  T*                 d_ptr,
+                  T                  default_value,
+                  size_t             size,
+                  cudaStream_t       stream = {})
 {
     if (input_tensors->isExist(arg_name)) {
         FT_CHECK(input_tensors->at(arg_name).size() == size);
-        cudaH2Dcpy(d_ptr, input_tensors->at(arg_name).getPtr<const T>(), size);
+        check_cuda_error(cudaMemcpyAsync(
+            d_ptr, input_tensors->at(arg_name).getPtr<const T>(), sizeof(T) * size, cudaMemcpyDefault, stream));
     }
     else {
-        deviceFill(d_ptr, size, default_value);
+        deviceFill(d_ptr, size, default_value, stream);
     }
 }
 
diff --git a/src/turbomind/kernels/sampling_penalty_kernels.cu b/src/turbomind/kernels/sampling_penalty_kernels.cu
index 4877bdb1a..5e5b22f94 100644
--- a/src/turbomind/kernels/sampling_penalty_kernels.cu
+++ b/src/turbomind/kernels/sampling_penalty_kernels.cu
@@ -445,11 +445,18 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
     dim3 block(min(step, 1024));
     dim3 grid(local_batch_size);
     size_t smem_size = step * (sizeof(float) + sizeof(int));
+
     if (penalty_type == RepetitionPenaltyType::Additive) {
+        check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive>,
+                                              cudaFuncAttributeMaxDynamicSharedMemorySize,
+                                              smem_size));
         batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, smem_size, stream>>>(
             logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
     }
     else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
+        check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative>,
+                                              cudaFuncAttributeMaxDynamicSharedMemorySize,
+                                              smem_size));
         batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative><<<grid, block, smem_size, stream>>>(
             logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
     }
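Note on the sampling_penalty change above: a kernel launch may request at most 48 KB of dynamic shared memory by default, and batchApplyRepetitionPenalty asks for step * 8 bytes, so long contexts (roughly step > 6K) would fail to launch without the cudaFuncSetAttribute opt-in added here; the opt-in itself fails on GPUs whose per-block shared-memory capacity is smaller than the request, which is presumably why it is wrapped in check_cuda_error. The sketch below is not part of the patch; scratch_kernel and the sizes are made up purely to show the opt-in pattern in isolation.

#include <cstdio>
#include <cuda_runtime.h>

// Toy kernel that uses dynamically sized shared memory.
__global__ void scratch_kernel(float* out, int n)
{
    extern __shared__ float scratch[];  // sized at launch time
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        scratch[i] = static_cast<float>(i);
    }
    __syncthreads();
    if (threadIdx.x == 0) {
        out[blockIdx.x] = scratch[n - 1];
    }
}

int main()
{
    const int    n         = 24 * 1024;          // 24K floats
    const size_t smem_size = n * sizeof(float);  // 96 KB, above the 48 KB default limit

    // Without this opt-in, launches requesting more than 48 KB of dynamic
    // shared memory fail with an invalid-value error.
    cudaError_t err =
        cudaFuncSetAttribute(scratch_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem_size);
    if (err != cudaSuccess) {
        std::printf("opt-in failed: %s\n", cudaGetErrorString(err));
        return 1;
    }

    float* out{};
    cudaMalloc(&out, sizeof(float));
    scratch_kernel<<<1, 128, smem_size>>>(out, n);
    std::printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaFree(out);
    return 0;
}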
", " : "") << ptr[i]; + } + ss << "]"; + return ss.str(); +} + template void TopKSamplingLayer::setup(const size_t batch_size, const size_t beam_width, TensorMap* runtime_args) { @@ -168,6 +184,11 @@ void TopKSamplingLayer::setup(const size_t batch_size, const size_t beam_widt cudaAutoCpy(runtime_top_p_buf_, runtime_top_p.getPtr(), batch_size, stream_); } + // if (isDebug()) { + TM_LOG_INFO("[TopKSamplingLayer] runtime_top_k: %s", format(runtime_top_k).c_str()); + TM_LOG_INFO("[TopKSamplingLayer] runtime_top_p: %s", format(runtime_top_p).c_str()); + // } + dim3 block(std::min((int)batch_size, 256)); dim3 grid(div_up((int)batch_size, (int)block.x)); // support top_k up to 1024. @@ -182,6 +203,7 @@ void TopKSamplingLayer::setup(const size_t batch_size, const size_t beam_widt cudaAutoCpy(skip_decode_, skip_decode_buf_, batch_size, stream_); uint* runtime_top_ks = new uint[batch_size]; cudaAutoCpy(runtime_top_ks, runtime_top_k_buf_, batch_size, stream_); + check_cuda_error(cudaStreamSynchronize(stream_)); runtime_max_top_k_ = static_cast(*std::max_element(runtime_top_ks, runtime_top_ks + batch_size)); delete[] runtime_top_ks; } diff --git a/src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu b/src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu index 8e7e97314..ef0570852 100644 --- a/src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu +++ b/src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu @@ -249,6 +249,7 @@ void TopPSamplingLayer::setup(const size_t batch_size, const size_t beam_widt cudaAutoCpy(skip_decode_, skip_decode_buf_, batch_size, stream_); float* runtime_top_ps = new float[batch_size]; cudaAutoCpy(runtime_top_ps, runtime_top_p_buf_, batch_size, stream_); + check_cuda_error(cudaStreamSynchronize(stream_)); runtime_max_top_p_ = *std::max_element(runtime_top_ps, runtime_top_ps + batch_size); delete[] runtime_top_ps; } diff --git a/src/turbomind/models/llama/Barrier.h b/src/turbomind/models/llama/Barrier.h index 6eb0df958..e34c42e6c 100644 --- a/src/turbomind/models/llama/Barrier.h +++ b/src/turbomind/models/llama/Barrier.h @@ -2,6 +2,7 @@ #pragma once +#include "src/turbomind/utils/cuda_utils.h" #include "src/turbomind/utils/logger.h" #ifndef _MSC_VER #include diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index 5d8d7d041..096cfcb4f 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -385,8 +385,7 @@ void LlamaBatch::initializeSampling(int infer_request_count) } } - handleOptArg(&inputs_, "end_id", end_ids_buf_, llama_->end_id_, batch_size_); - cudaStreamSynchronize(0); + handleOptArg(&inputs_, "end_id", end_ids_buf_, llama_->end_id_, batch_size_, stream_); } template @@ -502,6 +501,35 @@ bool LlamaBatch::generate() batch_size_, step_ - 1); + // insert +inf for half + // T x = 999999.f; + // cudaMemcpyAsync(decoder_input_buf_, &x, sizeof(x), cudaMemcpyDefault, stream_); + + // CheckValues(decoder_input_buf_, batch_size_ * llama_->hidden_units_, "embedding_lookup", stream_); + + // if (compare_mode == kCmpWrite) { + // if (rank_ == 0) { + // Compare(decoder_input_buf_, llama_->hidden_units_, Concat("decoder_input", step_), compare_mode, + // stream_); + // } + // } + // else { + // for (int i = 0; i < batch_size_; ++i) { + // Compare(decoder_input_buf_ + i * llama_->hidden_units_, + // llama_->hidden_units_, + // Concat("decoder_input", step_), + // compare_mode, + // stream_, + // Concat("", rank_, i)); + // } + // } + // 
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 5d8d7d041..096cfcb4f 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -385,8 +385,7 @@ void LlamaBatch<T>::initializeSampling(int infer_request_count)
         }
     }
 
-    handleOptArg(&inputs_, "end_id", end_ids_buf_, llama_->end_id_, batch_size_);
-    cudaStreamSynchronize(0);
+    handleOptArg(&inputs_, "end_id", end_ids_buf_, llama_->end_id_, batch_size_, stream_);
 }
 
 template<typename T>
@@ -502,6 +501,35 @@ bool LlamaBatch<T>::generate()
                                    batch_size_,
                                    step_ - 1);
 
+        // insert +inf for half
+        // T x = 999999.f;
+        // cudaMemcpyAsync(decoder_input_buf_, &x, sizeof(x), cudaMemcpyDefault, stream_);
+
+        // CheckValues(decoder_input_buf_, batch_size_ * llama_->hidden_units_, "embedding_lookup", stream_);
+
+        // if (compare_mode == kCmpWrite) {
+        //     if (rank_ == 0) {
+        //         Compare(decoder_input_buf_, llama_->hidden_units_, Concat("decoder_input", step_), compare_mode,
+        //         stream_);
+        //     }
+        // }
+        // else {
+        //     for (int i = 0; i < batch_size_; ++i) {
+        //         Compare(decoder_input_buf_ + i * llama_->hidden_units_,
+        //                 llama_->hidden_units_,
+        //                 Concat("decoder_input", step_),
+        //                 compare_mode,
+        //                 stream_,
+        //                 Concat("", rank_, i));
+        //     }
+        // }
+        // CheckBatchConsistency(decoder_input_buf_,  //
+        //                       llama_->hidden_units_,
+        //                       batch_size_,
+        //                       Concat("decoder_input", step_),
+        //                       rank_,
+        //                       stream_);
+
         llama_->decoderForward(decoder_output_buf_,
                                k_cache_ptr_buf_,
                                v_cache_ptr_buf_,
@@ -514,11 +542,37 @@ bool LlamaBatch<T>::generate()
                                session_len_,
                                batch_size_);
 
+        // CheckBatchConsistency(decoder_input_buf_,  //
+        //                       llama_->hidden_units_,
+        //                       batch_size_,
+        //                       Concat("decoder_output", step_),
+        //                       rank_,
+        //                       stream_);
+
+        // if (compare_mode == kCmpWrite) {
+        //     if (rank_ == 0) {
+        //         Compare(decoder_output_buf_, llama_->hidden_units_, Concat("decoder_output", step_), compare_mode,
+        //         stream_);
+        //     }
+        // }
+        // else {
+        //     for (int i = 0; i < batch_size_; ++i) {
+        //         Compare(decoder_output_buf_ + i * llama_->hidden_units_,
+        //                 llama_->hidden_units_,
+        //                 Concat("decoder_output", step_),
+        //                 compare_mode,
+        //                 stream_,
+        //                 Concat("", rank_, i));
+        //     }
+        // }
+
         llama_->postDecodeEmbedding(logits_buf_,  //
                                     local_logits_buf_,
                                     decoder_output_buf_,
                                     batch_size_);
 
+        // CheckValues(logits_buf_, batch_size_ * llama_->vocab_size_padded_, "post_decode_embedding", stream_);
+
         // stop-words & bad-words require the matched tokens to be contiguous, so item size > 1 is
         // not supported yet.
         bool should_stop{};
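Note on the LlamaBatch change above: passing the batch's stream_ to handleOptArg makes the end_id host-to-device copy stream-ordered with the kernels that later read end_ids_buf_, so the blanket cudaStreamSynchronize(0) on the default stream can be dropped. A minimal sketch of the idea follows; consume and setup_end_ids are hypothetical names, not from the codebase.

#include <cuda_runtime.h>

// Stand-in for the sampling kernels that read end_ids later in the step.
__global__ void consume(const int* end_ids, int n)
{
    if (threadIdx.x < n) {
        (void)end_ids[threadIdx.x];
    }
}

// Stream-ordered setup: the H2D copy and every kernel that reads `end_ids`
// go to the same stream, so the copy is guaranteed to finish before the
// kernel runs and no host-side synchronization is required.
void setup_end_ids(int* end_ids, const int* host_end_ids, int n, cudaStream_t stream)
{
    cudaMemcpyAsync(end_ids, host_end_ids, sizeof(int) * n, cudaMemcpyHostToDevice, stream);
    consume<<<1, 256, 0, stream>>>(end_ids, n);  // ordered after the copy on `stream`
}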
diff --git a/src/turbomind/models/llama/LlamaContextDecoder.cc b/src/turbomind/models/llama/LlamaContextDecoder.cc
index f914063a7..7edf087a8 100644
--- a/src/turbomind/models/llama/LlamaContextDecoder.cc
+++ b/src/turbomind/models/llama/LlamaContextDecoder.cc
@@ -25,6 +25,7 @@
 #include "src/turbomind/models/llama/LlamaContextDecoder.h"
 #include "src/turbomind/models/llama/llama_decoder_kernels.h"
 #include "src/turbomind/models/llama/llama_kernels.h"
+#include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/Tensor.h"
 
 namespace turbomind {
@@ -244,11 +245,15 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
                              stream_);
     sync_check_cuda_error();
 
+    // CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_norm", 0), stream_);
+
     for (size_t layer = 0; layer < num_layer_; ++layer) {
         /////////////////////////////////////////////
         /// self-attention
         forwardSelfAttn(sess, decoder_output, input_tensors, layer, false);
 
+        // CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_self_attn", layer), stream_);
+
         invokeFusedAddBiasResidualRMSNorm(decoder_input_output,
                                           decoder_output,
                                           decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
@@ -259,6 +264,8 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
                                           stream_);
         sync_check_cuda_error();
 
+        // CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_norm1", layer), stream_);
+
         ////////////////////////////////////////////
         /// feed-forward network
         TensorMap ffn_inputs{{"ffn_input", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, decoder_output}}};
@@ -266,6 +273,8 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
             {"ffn_output", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, decoder_output}}};
         silu_ffn_layer_->forward(&ffn_outputs, &ffn_inputs, &decoder_layer_weights->at(layer)->ffn_weights);
 
+        // CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_ffn", layer), stream_);
+
         auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
                                                      input_tensors->at("output_norm_weight").getPtr<T>();
         invokeFusedAddBiasResidualRMSNorm(decoder_input_output,  //
@@ -277,6 +286,8 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
                                           hidden_units_,
                                           stream_);
         sync_check_cuda_error();
+
+        // CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_norm2", layer), stream_);
     }
 
     if (is_free_buffer_after_forward_) {
diff --git a/src/turbomind/models/llama/LlamaDecoder.cc b/src/turbomind/models/llama/LlamaDecoder.cc
index 73e95b135..7b934f3cb 100644
--- a/src/turbomind/models/llama/LlamaDecoder.cc
+++ b/src/turbomind/models/llama/LlamaDecoder.cc
@@ -195,6 +195,8 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>* ou
     T* decoder_input  = input_tensors->at("decoder_input").getPtr<T>();
     T* decoder_output = output_tensors->at("decoder_output").getPtr<T>();
 
+    // int step = input_tensors->at("step").getVal<int>();
+
     ////////////////////////////////////////////
     /// RMSNorm
     invokeRootMeanSquareNorm(decoder_output,
@@ -206,10 +208,28 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>* ou
                              stream_);
     sync_check_cuda_error();
 
+    // CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_norm", 0), stream_);
+
+    // CheckBatchConsistency(decoder_output,
+    //                       hidden_units_,
+    //                       sess.batch_size,
+    //                       Concat("decode_norm", step, 0),
+    //                       tensor_para_.rank_,
+    //                       stream_);
+
     for (size_t layer = 0; layer < num_layer_; ++layer) {
         // output: self_attn_output_, k_cache, v_cache = self_attn(decoder_normed_input_)
         forwardSelfAttn(sess, decoder_output, input_tensors, layer);
 
+        // CheckBatchConsistency(decoder_output,
+        //                       hidden_units_,
+        //                       sess.batch_size,
+        //                       Concat("decode_self_attn", step, layer),
+        //                       tensor_para_.rank_,
+        //                       stream_);
+
+        // CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_self_attn", layer), stream_);
+
         invokeFusedAddBiasResidualRMSNorm(decoder_input,
                                           decoder_output,
                                           decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
@@ -220,9 +240,13 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>* ou
                                           stream_);
         sync_check_cuda_error();
 
+        // CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_norm1", layer), stream_);
+
         // decoder_layer_output_ = ffn(decoder_normed_input_)
         forwardFfn(sess, decoder_output, layer);
 
+        // CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_ffn", layer), stream_);
+
         auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
                                                      input_tensors->at("output_norm_weight").getPtr<T>();
         invokeFusedAddBiasResidualRMSNorm(decoder_input,  //
@@ -234,6 +258,8 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>* ou
                                           hidden_units_,
                                           stream_);
         sync_check_cuda_error();
+
+        // CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_norm2", layer), stream_);
     }
 
     if (is_free_buffer_after_forward_) {
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
index 103b32e88..4eb7937ba 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
@@ -25,6 +25,7 @@
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/logger.h"
+#include "src/turbomind/utils/nccl_utils.h"
 #include "src/turbomind/utils/nvtx_utils.h"
 #include
 // #include
@@ -236,10 +237,24 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
 
     allocateBuffer(batch_size, step, max_seq_len);
 
+    // CheckBatchConsistency((T*)input_query_data,
+    //                       hidden_units_,
+    //                       batch_size,
+    //                       Concat("before_qkv_gemm", step, layer_id),
+    //                       tensor_para_.rank_,
+    //                       stream_);
+
     PUSH_RANGE("qkv_gemm");
     linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv);
     POP_RANGE;
 
+    // CheckBatchConsistency(qkv_buf_,
+    //                       (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_,
+    //                       batch_size,
+    //                       Concat("after_qkv_gemm", step, layer_id),
+    //                       tensor_para_.rank_,
+    //                       stream_);
+
     const auto kv_cache_layer_offset = layer_id * local_kv_head_num_ * max_seq_len * size_per_head_;
     const int  memory_len            = max_seq_len;
 
@@ -287,15 +302,38 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
                               stream_);
     sync_check_cuda_error();
 
+    // CheckBatchConsistency((T*)context_buf_,
+    //                       local_hidden_units_,
+    //                       batch_size,
+    //                       Concat("before_o_gemm", step, layer_id),
+    //                       tensor_para_.rank_,
+    //                       stream_);
+
     linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output);
 
+    // CheckBatchConsistency(hidden_features_data,
+    //                       hidden_units_,
+    //                       batch_size,
+    //                       Concat("after_o_gemm", step, layer_id),
+    //                       tensor_para_.rank_,
+    //                       stream_);
+
     if (tensor_para_.world_size_ > 1) {
         NcclGuard nccl_guard(tensor_para_, stream_);
         ftNcclAllReduceSum(
             hidden_features_data, hidden_features_data, batch_size * hidden_units_, tensor_para_, stream_);
         sync_check_cuda_error();
+        // ftNcclStreamSynchronize(tensor_para_, {}, stream_);
+        // sync_check_cuda_error();
     }
 
+    // CheckBatchConsistency(hidden_features_data,
+    //                       hidden_units_,
+    //                       batch_size,
+    //                       Concat("self_attn_allreduce", step, layer_id),
+    //                       tensor_para_.rank_,
+    //                       stream_);
+
     if (is_free_buffer_after_forward_) {
         freeBuffer();
     }
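Note on the self-attention layer above: the extra ftNcclStreamSynchronize after the all-reduce stays commented out because ftNcclAllReduceSum is a thin wrapper around ncclAllReduce enqueued on the compute stream, so later kernels on that stream already observe the reduced tensor without any host-side wait. A minimal sketch of such a stream-ordered, in-place all-reduce follows; allreduce_sum is illustrative only, not a turbomind function.

#include <cuda_runtime.h>
#include <nccl.h>

// In-place sum all-reduce of `n` floats, enqueued on `stream`. Completion is
// ordered on the stream: kernels launched on `stream` afterwards see the
// reduced data, so no cudaStreamSynchronize is needed for correctness.
void allreduce_sum(float* buf, size_t n, ncclComm_t comm, cudaStream_t stream)
{
    ncclAllReduce(buf, buf, n, ncclFloat, ncclSum, comm, stream);
}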
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index 8768e7fd0..97cf0e0a3 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -438,6 +438,12 @@ void LlamaV2<T>::internalThreadEntry(int device_id)
     TM_LOG_INFO("[internalThreadEntry] %d", (int)tensor_para_.rank_);
     check_cuda_error(cudaSetDevice(device_id));
 
+    model_instance_barrier() = shared_state_->barrier.get();
+
+    // initialize global counters
+    // CheckValues((T*)0, 0, {}, 0);
+    shared_state_->barrier->wait();
+
     auto& request_queue  = shared_state_->request_queue;
     auto& infer_requests = shared_state_->infer_requests;
     auto& stop_requests  = shared_state_->stop_requests;
diff --git a/src/turbomind/models/llama/llama_utils.cu b/src/turbomind/models/llama/llama_utils.cu
index 7050d2d13..4d409cde3 100644
--- a/src/turbomind/models/llama/llama_utils.cu
+++ b/src/turbomind/models/llama/llama_utils.cu
@@ -56,7 +56,7 @@ void CheckNan(const T* ptr, size_t size, std::string key, cudaStream_t stream)
 }
 
 template<typename T>
-void CmpRead(T* ptr, size_t size, std::string key, cudaStream_t stream)
+void CmpRead(T* ptr, size_t size, std::string key, cudaStream_t stream, std::string msg)
 {
     // wait for b
     check_cuda_error(cudaStreamSynchronize(stream));
@@ -88,7 +88,7 @@ void CmpRead(T* ptr, size_t size, std::string key, cudaStream_t stream)
     auto transform_iter = thrust::make_transform_iterator(zip_iter, abs_diff{});
     // sum(abs(a - b))
     auto asum = thrust::reduce(thrust::device, transform_iter, transform_iter + size);
-    std::cerr << key << ": " << asum << " " << asum / size << "\n";
+    std::cerr << key << msg << ": " << asum << " " << asum / size << "\n";
 }
 
 template<typename T>
@@ -106,11 +106,11 @@ void CmpWrite(T* ptr, size_t size, std::string key, cudaStream_t stream)
 }
 
 template<typename T>
-void Compare(T* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream)
+void Compare(T* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream, std::string msg)
 {
     // std::cerr << "Comparing " << key << "\n";
     if (mode == kCmpRead) {
-        CmpRead(ptr, size, key, stream);
+        CmpRead(ptr, size, key, stream, msg);
    }
     else if (mode == kCmpWrite) {
         CmpWrite(ptr, size, key, stream);
@@ -120,9 +120,9 @@ void Compare(T* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t st
     }
 }
 
-template void Compare(int* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream);
-template void Compare(float* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream);
-template void Compare(half* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream);
+template void Compare(int* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream, std::string msg);
+template void Compare(float* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream, std::string msg);
+template void Compare(half* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream, std::string msg);
 
 template void CheckNan(const float* ptr, size_t size, std::string key, cudaStream_t stream);
 template void CheckNan(const half* ptr, size_t size, std::string key, cudaStream_t stream);
@@ -157,4 +157,110 @@ bool isDebug()
     return is_debug;
 }
 
+template<int kWarpCount, class T>
+inline __device__ T blockSum(T val, int warp_id, int lane_id)
+{
+    __shared__ T smem_red[32];
+
+    for (int mask = 32 >> 1; mask >= 1; mask >>= 1) {
+        val += __shfl_xor_sync((uint32_t)-1, val, mask);
+    }
+    if (lane_id == 0) {
+        smem_red[warp_id] = val;
+    }
+
+    __syncthreads();
+
+    val = lane_id < kWarpCount ? smem_red[lane_id] : T{};
+
+    for (int mask = kWarpCount >> 1; mask >= 1; mask >>= 1) {
+        val += __shfl_xor_sync((uint32_t)-1, val, mask);
+    }
+    val = __shfl_sync((uint32_t)-1, val, 0);
+
+    return val;
+}
+
+template<int kWarpCount, class T>
+__global__ void CountInfNan(const T* data, uint32_t* g_inf_count, uint32_t* g_nan_count, size_t count)
+{
+    uint32_t inf_count = 0;
+    uint32_t nan_count = 0;
+    for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
+        if constexpr (std::is_same_v<T, float>) {
+            inf_count += __isinf(data[i]);
+            nan_count += __isnan(data[i]);
+        }
+        else if constexpr (std::is_same_v<T, half>) {
+            inf_count += __hisinf(data[i]);
+            nan_count += __hisnan(data[i]);
+        }
+    }
+    inf_count = blockSum<kWarpCount>(inf_count, threadIdx.x / 32, threadIdx.x % 32);
+    nan_count = blockSum<kWarpCount>(nan_count, threadIdx.x / 32, threadIdx.x % 32);
+    if (threadIdx.x == 0) {
+        if (inf_count) {
+            atomicAdd(g_inf_count, inf_count);
+        }
+        if (nan_count) {
+            atomicAdd(g_nan_count, nan_count);
+        }
+    }
+}
+
+namespace {
+
+struct Info {
+    char data[256];
+};
+
+}  // namespace
+
+__global__ void ReportInfNan(uint32_t* g_inf_count, uint32_t* g_nan_count, Info info)
+{
+    auto inf_count = *g_inf_count;
+    auto nan_count = *g_nan_count;
+    if (inf_count || nan_count) {
+        printf("[TM][ERROR] [%s] Inf=%u, NaN=%u\n", info.data, inf_count, nan_count);
+    }
+    // reset the counters for later use
+    *g_inf_count = 0;
+    *g_nan_count = 0;
+}
+
+template<class T>
+void CheckValues(const T* data, int count, const std::string& msg, cudaStream_t stream)
+{
+    thread_local uint32_t* counters = [] {
+        uint32_t* ptr{};
+        cudaMalloc(&ptr, sizeof(uint32_t) * 2);
+        cudaMemset(ptr, 0, sizeof(uint32_t) * 2);
+        cudaDeviceSynchronize();
+        return ptr;
+    }();
+
+    if (data == nullptr && count == 0) {
+        return;
+    }
+
+    const auto g_inf_count = counters;
+    const auto g_nan_count = g_inf_count + 1;
+
+    FT_CHECK(msg.size() < sizeof(Info) - 1);
+
+    CountInfNan<4><<<256, 128, 0, stream>>>(data, g_inf_count, g_nan_count, count);
+
+    Info info;
+    strncpy(info.data, msg.c_str(), sizeof(info) - 1);
+
+    ReportInfNan<<<1, 1, 0, stream>>>(g_inf_count, g_nan_count, info);
+}
+
+template void CheckValues(const half* data, int count, const std::string& msg, cudaStream_t stream);
+template void CheckValues(const float* data, int count, const std::string& msg, cudaStream_t stream);
+
+Barrier*& model_instance_barrier()
+{
+    thread_local Barrier* p{};
+    return p;
+}
+
 }  // namespace turbomind
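Note on the llama_utils.cu additions above: CheckValues counts Inf/NaN entirely on the given stream (CountInfNan accumulates per-block counts into two device counters, ReportInfNan prints and resets them), so it can be dropped after any suspect kernel without forcing a host synchronization. For one-off debugging outside this helper, a single thrust transform-reduce gives the same count at the cost of blocking the host; the sketch below is illustrative only (float buffers, not part of the patch).

#include <cuda_runtime.h>
#include <thrust/functional.h>
#include <thrust/system/cuda/execution_policy.h>
#include <thrust/transform_reduce.h>

// 1 for every element that is Inf or NaN, 0 otherwise.
struct is_not_finite {
    __device__ unsigned operator()(float x) const { return isfinite(x) ? 0u : 1u; }
};

// Synchronous alternative for ad-hoc checks: counts non-finite elements of a
// device buffer with a single transform-reduce on the given stream.
unsigned count_non_finite(const float* data, size_t n, cudaStream_t stream)
{
    return thrust::transform_reduce(
        thrust::cuda::par.on(stream), data, data + n, is_not_finite{}, 0u, thrust::plus<unsigned>());
}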
diff --git a/src/turbomind/models/llama/llama_utils.h b/src/turbomind/models/llama/llama_utils.h
index 05c10be80..40a4688e0 100644
--- a/src/turbomind/models/llama/llama_utils.h
+++ b/src/turbomind/models/llama/llama_utils.h
@@ -1,6 +1,7 @@
 // Copyright (c) OpenMMLab. All rights reserved.
 
 #pragma once
+#include "src/turbomind/models/llama/Barrier.h"
 #include "src/turbomind/utils/Tensor.h"
 #include
 #include
@@ -29,7 +30,7 @@ enum CmpMode
 extern CmpMode compare_mode;
 
 template<typename T>
-void Compare(T* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream);
+void Compare(T* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream, std::string msg = {});
 
 template<typename T>
 void CheckNan(const T* ptr, size_t size, std::string key, cudaStream_t stream);
@@ -66,4 +67,31 @@ size_t curandStateGetSize();
 
 bool isDebug();
 
+template<class T>
+void CheckValues(const T* data, int count, const std::string& msg, cudaStream_t stream);
+
+Barrier*& model_instance_barrier();
+
+template<class T>
+inline void CheckBatchConsistency(T* ptr, size_t size, int batch_size, std::string key, int rank, cudaStream_t stream)
+{
+    if (compare_mode == kCmpNone) {
+        return;
+    }
+    model_instance_barrier()->wait();
+    if (compare_mode == kCmpWrite) {
+        if (rank == 0) {
+            Compare(ptr, size, key, compare_mode, stream);
+        }
+    }
+    else {
+        if (rank == 0) {
+            for (int i = 0; i < batch_size; ++i) {
+                Compare(ptr + i * size, size, key, compare_mode, stream, Concat("", rank, i));
+            }
+        }
+    }
+    model_instance_barrier()->wait();
+}
+
 }  // namespace turbomind
diff --git a/src/turbomind/utils/memory_utils.cu b/src/turbomind/utils/memory_utils.cu
index 93547f364..4344bda9c 100644
--- a/src/turbomind/utils/memory_utils.cu
+++ b/src/turbomind/utils/memory_utils.cu
@@ -93,10 +93,9 @@ template void deviceFree(__nv_fp8_e4m3*& ptr);
 template<typename T>
 void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream)
 {
-    T* arr = new T[size];
-    std::fill(arr, arr + size, value);
-    check_cuda_error(cudaMemcpyAsync(devptr, arr, sizeof(T) * size, cudaMemcpyHostToDevice, stream));
-    delete[] arr;
+    std::unique_ptr<T[]> arr(new T[size]);
+    std::fill(arr.get(), arr.get() + size, value);
+    check_cuda_error(cudaMemcpyAsync(devptr, arr.get(), sizeof(T) * size, cudaMemcpyDefault, stream));
 }
 
 template void deviceFill(float* devptr, size_t size, float value, cudaStream_t stream);
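Note on the memory_utils.cu change above: deviceFill still stages the fill through a pageable host buffer (now owned by a unique_ptr instead of a raw new/delete pair), and the copy kind becomes cudaMemcpyDefault, which lets the runtime infer the direction on systems with unified virtual addressing. An alternative that avoids the host buffer and the copy altogether is a small device-side fill kernel; the sketch below is illustrative only, fillKernel / deviceFillKernel are not part of turbomind.

#include <algorithm>
#include <cuda_runtime.h>

// Grid-stride fill: every element of `ptr` is set to `value` on the device,
// with no host staging buffer and purely stream-ordered execution.
template<typename T>
__global__ void fillKernel(T* ptr, size_t size, T value)
{
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += (size_t)blockDim.x * gridDim.x) {
        ptr[i] = value;
    }
}

template<typename T>
void deviceFillKernel(T* devptr, size_t size, T value, cudaStream_t stream)
{
    const int block = 256;
    const int grid  = (int)std::min<size_t>((size + block - 1) / block, 4096);
    fillKernel<<<grid, block, 0, stream>>>(devptr, size, value);
}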