Fix race condition & check for abnormal values (#627) (#648)
* check Inf and NaN

* fix `vocab_size`

* config smem size for `batchApplyRepetitionPenalty`

* `handleOptArg` with stream support

* sync before using async data

* remove debug code

* fix msvc build

* fix msvc build

Co-authored-by: Li Zhang <[email protected]>
lvhan028 and lzhangzz authored Nov 5, 2023
1 parent 96f1b8e commit 1c68bcd
Showing 13 changed files with 323 additions and 17 deletions.
13 changes: 10 additions & 3 deletions src/turbomind/kernels/gpt_kernels.h
@@ -21,6 +21,7 @@
#include <unordered_map>

#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/memory_utils.h"

namespace turbomind {
@@ -131,14 +132,20 @@ void invokeFindContextDups(int* shared_contexts,
cudaStream_t stream = 0);

 template<typename T>
-void handleOptArg(TensorMap* input_tensors, const std::string& arg_name, T* d_ptr, T default_value, size_t size)
+void handleOptArg(TensorMap* input_tensors,
+                  const std::string& arg_name,
+                  T* d_ptr,
+                  T default_value,
+                  size_t size,
+                  cudaStream_t stream = {})
 {
     if (input_tensors->isExist(arg_name)) {
         FT_CHECK(input_tensors->at(arg_name).size() == size);
-        cudaH2Dcpy(d_ptr, input_tensors->at(arg_name).getPtr<const T>(), size);
+        check_cuda_error(cudaMemcpyAsync(
+            d_ptr, input_tensors->at(arg_name).getPtr<const T>(), sizeof(T) * size, cudaMemcpyDefault, stream));
     }
     else {
-        deviceFill(d_ptr, size, default_value);
+        deviceFill(d_ptr, size, default_value, stream);
     }
 }

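Note on this hunk: the old helper copied with cudaH2Dcpy and filled defaults without a stream, so callers had to block on the legacy default stream (the cudaStreamSynchronize(0) removed from LlamaBatch.cc below) to be sure the buffer was ready. Passing the caller's stream keeps the copy or fill ordered with the kernels that later consume the buffer. A minimal sketch of the stream-ordered pattern, with illustrative names rather than the repo's API:

    // opt_arg_sketch.cu -- illustrative only; fillOptArg/fillKernel are not the repo's helpers
    #include <cuda_runtime.h>

    template<typename T>
    __global__ void fillKernel(T* dst, size_t n, T value)
    {
        size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
        if (i < n) dst[i] = value;
    }

    // Either copy user-supplied host values or fill a default, always on the
    // caller's stream, so kernels enqueued on that stream afterwards are
    // guaranteed to see the finished buffer without any global sync.
    template<typename T>
    void fillOptArg(T* d_ptr, const T* h_src, T default_value, size_t n, cudaStream_t stream)
    {
        if (h_src) {
            // with pageable host memory the copy may still block the host,
            // but the device-side stream ordering is what matters here
            cudaMemcpyAsync(d_ptr, h_src, sizeof(T) * n, cudaMemcpyHostToDevice, stream);
        }
        else {
            const int block = 256;
            const int grid  = (int)((n + block - 1) / block);
            fillKernel<<<grid, block, 0, stream>>>(d_ptr, n, default_value);
        }
        // no cudaStreamSynchronize(0) needed by the caller
    }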
7 changes: 7 additions & 0 deletions src/turbomind/kernels/sampling_penalty_kernels.cu
@@ -445,11 +445,18 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
dim3 block(min(step, 1024));
dim3 grid(local_batch_size);
size_t smem_size = step * (sizeof(float) + sizeof(int));

if (penalty_type == RepetitionPenaltyType::Additive) {
check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size));
batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, smem_size, stream>>>(
logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
}
else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size));
batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative><<<grid, block, smem_size, stream>>>(
logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
}
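Why the cudaFuncSetAttribute calls: the kernel requests step * (sizeof(float) + sizeof(int)) = 8 * step bytes of dynamic shared memory, and once step exceeds roughly 6K tokens that is more than the 48 KB a kernel may use by default, so the launch would fail for long sequences unless the kernel opts in to a larger limit first. A standalone sketch of the opt-in pattern (the kernel body is illustrative, not the repetition-penalty kernel):

    // smem_optin.cu -- opting in to more than 48 KB of dynamic shared memory
    #include <cuda_runtime.h>
    #include <cstdio>

    __global__ void bigSmemKernel(float* out, int n)
    {
        extern __shared__ float buf[];                    // dynamic shared memory
        for (int i = threadIdx.x; i < n; i += blockDim.x) buf[i] = (float)i;
        __syncthreads();
        for (int i = threadIdx.x; i < n; i += blockDim.x) out[i] = buf[i];
    }

    int main()
    {
        const int    n         = 16000;                   // 64000 bytes > 48 KB default
        const size_t smem_size = n * sizeof(float);
        float*       d_out;
        cudaMalloc(&d_out, smem_size);

        // Without this opt-in the launch below fails on devices that default to
        // 48 KB of dynamic smem; the larger limit itself requires an architecture
        // that supports it (e.g. Volta or newer).
        cudaFuncSetAttribute(bigSmemKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem_size);

        bigSmemKernel<<<1, 256, smem_size, 0>>>(d_out, n);
        printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
        cudaDeviceSynchronize();
        cudaFree(d_out);
        return 0;
    }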
22 changes: 22 additions & 0 deletions src/turbomind/layers/sampling_layers/TopKSamplingLayer.cu
@@ -21,6 +21,8 @@
#include "src/turbomind/kernels/sampling_topp_kernels.h"
#include "src/turbomind/layers/sampling_layers/TopKSamplingLayer.h"
#include "src/turbomind/macro.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/memory_utils.h"

@@ -131,6 +133,20 @@ void TopKSamplingLayer<T>::freeBuffer()
is_allocate_buffer_ = false;
}

template<typename T>
inline static std::string format(const Tensor& t)
{
std::stringstream ss;
const int size = t.size();
const T* ptr = t.getPtr<T>();
ss << "[";
for (int i = 0; i < size; ++i) {
ss << (i ? ", " : "") << ptr[i];
}
ss << "]";
return ss.str();
}

template<typename T>
void TopKSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_width, TensorMap* runtime_args)
{
@@ -168,6 +184,11 @@ void TopKSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt
cudaAutoCpy(runtime_top_p_buf_, runtime_top_p.getPtr<float>(), batch_size, stream_);
}

// if (isDebug()) {
TM_LOG_INFO("[TopKSamplingLayer] runtime_top_k: %s", format<int>(runtime_top_k).c_str());
TM_LOG_INFO("[TopKSamplingLayer] runtime_top_p: %s", format<float>(runtime_top_p).c_str());
// }

dim3 block(std::min((int)batch_size, 256));
dim3 grid(div_up((int)batch_size, (int)block.x));
// support top_k up to 1024.
@@ -182,6 +203,7 @@ void TopKSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt
cudaAutoCpy(skip_decode_, skip_decode_buf_, batch_size, stream_);
uint* runtime_top_ks = new uint[batch_size];
cudaAutoCpy(runtime_top_ks, runtime_top_k_buf_, batch_size, stream_);
check_cuda_error(cudaStreamSynchronize(stream_));
runtime_max_top_k_ = static_cast<int>(*std::max_element(runtime_top_ks, runtime_top_ks + batch_size));
delete[] runtime_top_ks;
}
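This one-line sync is the heart of the race named in the commit title: cudaAutoCpy enqueues the device-to-host copy of runtime_top_k_buf_ asynchronously on stream_, and the std::max_element right after it runs on the host, so without draining the stream the host could read stale or partially written values (TopPSamplingLayer.cu below gets the identical fix for runtime_top_p_buf_). The pattern, distilled into a self-contained sketch with made-up buffer names:

    // sync_before_read.cu -- a host read of async-copied data needs a stream sync
    #include <cuda_runtime.h>
    #include <algorithm>
    #include <cstdio>

    int main()
    {
        const int    n = 8;
        cudaStream_t stream;
        cudaStreamCreate(&stream);

        unsigned *d_top_k, *h_top_k;
        cudaMalloc(&d_top_k, n * sizeof(unsigned));
        cudaMallocHost(&h_top_k, n * sizeof(unsigned));   // pinned, so the D2H copy is truly async

        // stand-in for the setup kernels that write the per-request top-k values
        cudaMemsetAsync(d_top_k, 0x01, n * sizeof(unsigned), stream);

        cudaMemcpyAsync(h_top_k, d_top_k, n * sizeof(unsigned), cudaMemcpyDeviceToHost, stream);

        // Removing this line reintroduces the race: the copy may still be in
        // flight on `stream` while the host computes the max below.
        cudaStreamSynchronize(stream);

        unsigned max_top_k = *std::max_element(h_top_k, h_top_k + n);
        printf("max top_k = %u\n", max_top_k);

        cudaFreeHost(h_top_k);
        cudaFree(d_top_k);
        cudaStreamDestroy(stream);
        return 0;
    }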
1 change: 1 addition & 0 deletions src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu
@@ -249,6 +249,7 @@ void TopPSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt
cudaAutoCpy(skip_decode_, skip_decode_buf_, batch_size, stream_);
float* runtime_top_ps = new float[batch_size];
cudaAutoCpy(runtime_top_ps, runtime_top_p_buf_, batch_size, stream_);
check_cuda_error(cudaStreamSynchronize(stream_));
runtime_max_top_p_ = *std::max_element(runtime_top_ps, runtime_top_ps + batch_size);
delete[] runtime_top_ps;
}
1 change: 1 addition & 0 deletions src/turbomind/models/llama/Barrier.h
@@ -2,6 +2,7 @@

#pragma once

#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#ifndef _MSC_VER
#include <pthread.h>
58 changes: 56 additions & 2 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -385,8 +385,7 @@ void LlamaBatch<T>::initializeSampling(int infer_request_count)
}
}

-    handleOptArg(&inputs_, "end_id", end_ids_buf_, llama_->end_id_, batch_size_);
-    cudaStreamSynchronize(0);
+    handleOptArg(&inputs_, "end_id", end_ids_buf_, llama_->end_id_, batch_size_, stream_);
}

template<typename T>
@@ -502,6 +501,35 @@ bool LlamaBatch<T>::generate()
batch_size_,
step_ - 1);

// insert +inf for half
// T x = 999999.f;
// cudaMemcpyAsync(decoder_input_buf_, &x, sizeof(x), cudaMemcpyDefault, stream_);

// CheckValues(decoder_input_buf_, batch_size_ * llama_->hidden_units_, "embedding_lookup", stream_);

// if (compare_mode == kCmpWrite) {
// if (rank_ == 0) {
// Compare(decoder_input_buf_, llama_->hidden_units_, Concat("decoder_input", step_), compare_mode,
// stream_);
// }
// }
// else {
// for (int i = 0; i < batch_size_; ++i) {
// Compare(decoder_input_buf_ + i * llama_->hidden_units_,
// llama_->hidden_units_,
// Concat("decoder_input", step_),
// compare_mode,
// stream_,
// Concat("", rank_, i));
// }
// }
// CheckBatchConsistency(decoder_input_buf_, //
// llama_->hidden_units_,
// batch_size_,
// Concat("decoder_input", step_),
// rank_,
// stream_);

llama_->decoderForward(decoder_output_buf_,
k_cache_ptr_buf_,
v_cache_ptr_buf_,
@@ -514,11 +542,37 @@ bool LlamaBatch<T>::generate()
session_len_,
batch_size_);

// CheckBatchConsistency(decoder_input_buf_, //
// llama_->hidden_units_,
// batch_size_,
// Concat("decoder_output", step_),
// rank_,
// stream_);

// if (compare_mode == kCmpWrite) {
// if (rank_ == 0) {
// Compare(decoder_output_buf_, llama_->hidden_units_, Concat("decoder_output", step_), compare_mode,
// stream_);
// }
// }
// else {
// for (int i = 0; i < batch_size_; ++i) {
// Compare(decoder_output_buf_ + i * llama_->hidden_units_,
// llama_->hidden_units_,
// Concat("decoder_output", step_),
// compare_mode,
// stream_,
// Concat("", rank_, i));
// }
// }

llama_->postDecodeEmbedding(logits_buf_, //
local_logits_buf_,
decoder_output_buf_,
batch_size_);

// CheckValues(logits_buf_, batch_size_ * llama_->vocab_size_padded_, "post_decode_embedding", stream_);

// stop-words & bad-words require the matched tokens to be contiguous, so item size > 1 is
// not supported yet.
bool should_stop{};
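The commented-out CheckValues / Compare / CheckBatchConsistency probes are what the "check Inf and NaN" and "remove debug code" bullets refer to: they bracket the embedding lookup, the decoder and the logits so a debug build can flag non-finite activations as soon as they appear, while release builds pay nothing because the calls stay commented. The helper itself lives in llama_utils and is not part of this diff; the following is only a rough sketch of what such a scan can look like, with assumed names and float input (the real buffers are often half precision):

    // check_values_sketch.cu -- illustrative non-finite scan, not the repo's CheckValues
    #include <cuda_runtime.h>
    #include <cstdio>

    __global__ void countNonFinite(const float* data, size_t n, unsigned long long* bad)
    {
        size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
        if (i < n && !isfinite(data[i])) {
            atomicAdd(bad, 1ULL);
        }
    }

    // Scans n values on `stream` and logs when any NaN/Inf is present.
    // The blocking sync makes this a debug tool, not something for the hot path.
    void checkValuesSketch(const float* d_data, size_t n, const char* tag, cudaStream_t stream)
    {
        unsigned long long  h_bad = 0;
        unsigned long long* d_bad = nullptr;
        cudaMalloc(&d_bad, sizeof(*d_bad));
        cudaMemsetAsync(d_bad, 0, sizeof(*d_bad), stream);

        const int block = 256;
        const int grid  = (int)((n + block - 1) / block);
        countNonFinite<<<grid, block, 0, stream>>>(d_data, n, d_bad);

        cudaMemcpyAsync(&h_bad, d_bad, sizeof(h_bad), cudaMemcpyDeviceToHost, stream);
        cudaStreamSynchronize(stream);            // same rule: sync before the host reads
        if (h_bad != 0) {
            printf("[%s] %llu non-finite values detected\n", tag, h_bad);
        }
        cudaFree(d_bad);
    }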
11 changes: 11 additions & 0 deletions src/turbomind/models/llama/LlamaContextDecoder.cc
@@ -25,6 +25,7 @@
#include "src/turbomind/models/llama/LlamaContextDecoder.h"
#include "src/turbomind/models/llama/llama_decoder_kernels.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/Tensor.h"

namespace turbomind {
@@ -244,11 +245,15 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
stream_);
sync_check_cuda_error();

// CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_norm", 0), stream_);

for (size_t layer = 0; layer < num_layer_; ++layer) {
/////////////////////////////////////////////
/// self-attention
forwardSelfAttn(sess, decoder_output, input_tensors, layer, false);

// CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_self_attn", layer), stream_);

invokeFusedAddBiasResidualRMSNorm(decoder_input_output,
decoder_output,
decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
@@ -259,13 +264,17 @@
stream_);
sync_check_cuda_error();

// CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_norm1", layer), stream_);

////////////////////////////////////////////
/// feed-forward network
TensorMap ffn_inputs{{"ffn_input", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, decoder_output}}};
TensorMap ffn_outputs{
{"ffn_output", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, decoder_output}}};
silu_ffn_layer_->forward(&ffn_outputs, &ffn_inputs, &decoder_layer_weights->at(layer)->ffn_weights);

// CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_ffn", layer), stream_);

auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
input_tensors->at("output_norm_weight").getPtr<T>();
invokeFusedAddBiasResidualRMSNorm(decoder_input_output, //
Expand All @@ -277,6 +286,8 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
hidden_units_,
stream_);
sync_check_cuda_error();

// CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_norm2", layer), stream_);
}

if (is_free_buffer_after_forward_) {
26 changes: 26 additions & 0 deletions src/turbomind/models/llama/LlamaDecoder.cc
@@ -195,6 +195,8 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>* ou
T* decoder_input = input_tensors->at("decoder_input").getPtr<T>();
T* decoder_output = output_tensors->at("decoder_output").getPtr<T>();

// int step = input_tensors->at("step").getVal<int>();

////////////////////////////////////////////
/// RMSNorm
invokeRootMeanSquareNorm(decoder_output,
@@ -206,10 +208,28 @@
stream_);
sync_check_cuda_error();

// CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_norm", 0), stream_);

// CheckBatchConsistency(decoder_output,
// hidden_units_,
// sess.batch_size,
// Concat("decode_norm", step, 0),
// tensor_para_.rank_,
// stream_);

for (size_t layer = 0; layer < num_layer_; ++layer) {
// output: self_attn_output_, k_cache, v_cache = self_attn(decoder_normed_input_)
forwardSelfAttn(sess, decoder_output, input_tensors, layer);

// CheckBatchConsistency(decoder_output,
// hidden_units_,
// sess.batch_size,
// Concat("decode_self_attn", step, layer),
// tensor_para_.rank_,
// stream_);

// CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_self_attn", layer), stream_);

invokeFusedAddBiasResidualRMSNorm(decoder_input,
decoder_output,
decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
@@ -220,9 +240,13 @@
stream_);
sync_check_cuda_error();

// CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_norm1", layer), stream_);

// decoder_layer_output_ = ffn(decoder_normed_input_)
forwardFfn(sess, decoder_output, layer);

// CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_ffn", layer), stream_);

auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
input_tensors->at("output_norm_weight").getPtr<T>();
invokeFusedAddBiasResidualRMSNorm(decoder_input, //
@@ -234,6 +258,8 @@
hidden_units_,
stream_);
sync_check_cuda_error();

// CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_norm2", layer), stream_);
}

if (is_free_buffer_after_forward_) {
38 changes: 38 additions & 0 deletions src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
@@ -25,6 +25,7 @@
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/nccl_utils.h"
#include "src/turbomind/utils/nvtx_utils.h"
#include <string>
// #include <glog/logging.h>
@@ -236,10 +237,24 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o

allocateBuffer(batch_size, step, max_seq_len);

// CheckBatchConsistency((T*)input_query_data,
// hidden_units_,
// batch_size,
// Concat("before_qkv_gemm", step, layer_id),
// tensor_para_.rank_,
// stream_);

PUSH_RANGE("qkv_gemm");
linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv);
POP_RANGE;

// CheckBatchConsistency(qkv_buf_,
// (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_,
// batch_size,
// Concat("after_qkv_gemm", step, layer_id),
// tensor_para_.rank_,
// stream_);

const auto kv_cache_layer_offset = layer_id * local_kv_head_num_ * max_seq_len * size_per_head_;
const int memory_len = max_seq_len;

@@ -287,15 +302,38 @@
stream_);
sync_check_cuda_error();

// CheckBatchConsistency((T*)context_buf_,
// local_hidden_units_,
// batch_size,
// Concat("before_o_gemm", step, layer_id),
// tensor_para_.rank_,
// stream_);

linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output);

// CheckBatchConsistency(hidden_features_data,
// hidden_units_,
// batch_size,
// Concat("after_o_gemm", step, layer_id),
// tensor_para_.rank_,
// stream_);

if (tensor_para_.world_size_ > 1) {
NcclGuard nccl_guard(tensor_para_, stream_);
ftNcclAllReduceSum(
hidden_features_data, hidden_features_data, batch_size * hidden_units_, tensor_para_, stream_);
sync_check_cuda_error();
// ftNcclStreamSynchronize(tensor_para_, {}, stream_);
// sync_check_cuda_error();
}

// CheckBatchConsistency(hidden_features_data,
// hidden_units_,
// batch_size,
// Concat("self_attn_allreduce", step, layer_id),
// tensor_para_.rank_,
// stream_);

if (is_free_buffer_after_forward_) {
freeBuffer();
}
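The CheckBatchConsistency comments sprinkled through this layer are the other half of the debugging story: under tensor parallelism, replicated activations such as the attention input and the all-reduced output should be identical wherever they are expected to be identical, and any divergence is a symptom of the kind of race this commit fixes. The helper's exact semantics are not visible in this diff; the sketch below only illustrates the simplest variant of the idea, under the assumption that the batch was fed identical debug requests so every row should match row 0:

    // batch_consistency_sketch.cu -- illustrative probe, not the repo's CheckBatchConsistency
    #include <cuda_runtime.h>
    #include <cstdio>
    #include <cstring>
    #include <vector>

    // Returns true if every row of a [batch, dim] device buffer equals row 0.
    bool rowsMatchRow0(const float* d_buf, int batch, int dim, cudaStream_t stream)
    {
        std::vector<float> h((size_t)batch * dim);
        cudaMemcpyAsync(h.data(), d_buf, h.size() * sizeof(float), cudaMemcpyDeviceToHost, stream);
        cudaStreamSynchronize(stream);            // again: sync before the host touches the data
        for (int b = 1; b < batch; ++b) {
            if (std::memcmp(h.data(), h.data() + (size_t)b * dim, (size_t)dim * sizeof(float)) != 0) {
                printf("row %d differs from row 0\n", b);
                return false;
            }
        }
        return true;
    }

In the tree these probes stay commented out, so normal builds are unaffected.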
6 changes: 6 additions & 0 deletions src/turbomind/models/llama/LlamaV2.cc
@@ -438,6 +438,12 @@ void LlamaV2<T>::internalThreadEntry(int device_id)
TM_LOG_INFO("[internalThreadEntry] %d", (int)tensor_para_.rank_);
check_cuda_error(cudaSetDevice(device_id));

model_instance_barrier() = shared_state_->barrier.get();

// initialize global counters
// CheckValues((T*)0, 0, {}, 0);
shared_state_->barrier->wait();

auto& request_queue = shared_state_->request_queue;
auto& infer_requests = shared_state_->infer_requests;
auto& stop_requests = shared_state_->stop_requests;
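internalThreadEntry now publishes the shared barrier through model_instance_barrier() and has every tensor-parallel rank wait on it once before entering the request loop, so any one-time global setup (the commented-out counter initialization for the debug checks, for instance) is finished and visible on all ranks before the first request is served. A self-contained sketch of that init-then-barrier shape, using std::barrier instead of the repo's pthread-based Barrier; all names here are illustrative:

    // rank_barrier_sketch.cpp -- C++20
    #include <barrier>
    #include <cstdio>
    #include <thread>
    #include <vector>

    int main()
    {
        const int    world_size = 4;
        std::barrier sync_point(world_size);

        auto rank_entry = [&](int rank) {
            // per-rank setup: bind the device, allocate buffers, ...
            if (rank == 0) {
                // one-time global initialization, analogous to the commented
                // "initialize global counters" step in internalThreadEntry
                printf("rank 0: global init done\n");
            }
            sync_point.arrive_and_wait();   // no rank enters the serving loop early
            printf("rank %d: entering request loop\n", rank);
        };

        std::vector<std::thread> ranks;
        for (int r = 0; r < world_size; ++r) ranks.emplace_back(rank_entry, r);
        for (auto& t : ranks) t.join();
        return 0;
    }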