Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition & check for anormal values (#627) #648

Merged
merged 1 commit into from
Nov 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/turbomind/kernels/gpt_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <unordered_map>

#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/memory_utils.h"

namespace turbomind {
Expand Down Expand Up @@ -131,14 +132,20 @@ void invokeFindContextDups(int* shared_contexts,
cudaStream_t stream = 0);

template<typename T>
void handleOptArg(TensorMap* input_tensors, const std::string& arg_name, T* d_ptr, T default_value, size_t size)
void handleOptArg(TensorMap* input_tensors,
const std::string& arg_name,
T* d_ptr,
T default_value,
size_t size,
cudaStream_t stream = {})
{
if (input_tensors->isExist(arg_name)) {
FT_CHECK(input_tensors->at(arg_name).size() == size);
cudaH2Dcpy(d_ptr, input_tensors->at(arg_name).getPtr<const T>(), size);
check_cuda_error(cudaMemcpyAsync(
d_ptr, input_tensors->at(arg_name).getPtr<const T>(), sizeof(T) * size, cudaMemcpyDefault, stream));
}
else {
deviceFill(d_ptr, size, default_value);
deviceFill(d_ptr, size, default_value, stream);
}
}

Expand Down
7 changes: 7 additions & 0 deletions src/turbomind/kernels/sampling_penalty_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -445,11 +445,18 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
dim3 block(min(step, 1024));
dim3 grid(local_batch_size);
size_t smem_size = step * (sizeof(float) + sizeof(int));

if (penalty_type == RepetitionPenaltyType::Additive) {
check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size));
batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, smem_size, stream>>>(
logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
}
else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size));
batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative><<<grid, block, smem_size, stream>>>(
logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
}
Expand Down
22 changes: 22 additions & 0 deletions src/turbomind/layers/sampling_layers/TopKSamplingLayer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "src/turbomind/kernels/sampling_topp_kernels.h"
#include "src/turbomind/layers/sampling_layers/TopKSamplingLayer.h"
#include "src/turbomind/macro.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/memory_utils.h"

Expand Down Expand Up @@ -131,6 +133,20 @@ void TopKSamplingLayer<T>::freeBuffer()
is_allocate_buffer_ = false;
}

template<typename T>
inline static std::string format(const Tensor& t)
{
std::stringstream ss;
const int size = t.size();
const T* ptr = t.getPtr<T>();
ss << "[";
for (int i = 0; i < size; ++i) {
ss << (i ? ", " : "") << ptr[i];
}
ss << "]";
return ss.str();
}

template<typename T>
void TopKSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_width, TensorMap* runtime_args)
{
Expand Down Expand Up @@ -168,6 +184,11 @@ void TopKSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt
cudaAutoCpy(runtime_top_p_buf_, runtime_top_p.getPtr<float>(), batch_size, stream_);
}

// if (isDebug()) {
TM_LOG_INFO("[TopKSamplingLayer] runtime_top_k: %s", format<int>(runtime_top_k).c_str());
TM_LOG_INFO("[TopKSamplingLayer] runtime_top_p: %s", format<float>(runtime_top_p).c_str());
// }

dim3 block(std::min((int)batch_size, 256));
dim3 grid(div_up((int)batch_size, (int)block.x));
// support top_k up to 1024.
Expand All @@ -182,6 +203,7 @@ void TopKSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt
cudaAutoCpy(skip_decode_, skip_decode_buf_, batch_size, stream_);
uint* runtime_top_ks = new uint[batch_size];
cudaAutoCpy(runtime_top_ks, runtime_top_k_buf_, batch_size, stream_);
check_cuda_error(cudaStreamSynchronize(stream_));
runtime_max_top_k_ = static_cast<int>(*std::max_element(runtime_top_ks, runtime_top_ks + batch_size));
delete[] runtime_top_ks;
}
Expand Down
1 change: 1 addition & 0 deletions src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ void TopPSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt
cudaAutoCpy(skip_decode_, skip_decode_buf_, batch_size, stream_);
float* runtime_top_ps = new float[batch_size];
cudaAutoCpy(runtime_top_ps, runtime_top_p_buf_, batch_size, stream_);
check_cuda_error(cudaStreamSynchronize(stream_));
runtime_max_top_p_ = *std::max_element(runtime_top_ps, runtime_top_ps + batch_size);
delete[] runtime_top_ps;
}
Expand Down
1 change: 1 addition & 0 deletions src/turbomind/models/llama/Barrier.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#pragma once

#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#ifndef _MSC_VER
#include <pthread.h>
Expand Down
58 changes: 56 additions & 2 deletions src/turbomind/models/llama/LlamaBatch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,7 @@ void LlamaBatch<T>::initializeSampling(int infer_request_count)
}
}

handleOptArg(&inputs_, "end_id", end_ids_buf_, llama_->end_id_, batch_size_);
cudaStreamSynchronize(0);
handleOptArg(&inputs_, "end_id", end_ids_buf_, llama_->end_id_, batch_size_, stream_);
}

template<typename T>
Expand Down Expand Up @@ -502,6 +501,35 @@ bool LlamaBatch<T>::generate()
batch_size_,
step_ - 1);

// insert +inf for half
// T x = 999999.f;
// cudaMemcpyAsync(decoder_input_buf_, &x, sizeof(x), cudaMemcpyDefault, stream_);

// CheckValues(decoder_input_buf_, batch_size_ * llama_->hidden_units_, "embedding_lookup", stream_);

// if (compare_mode == kCmpWrite) {
// if (rank_ == 0) {
// Compare(decoder_input_buf_, llama_->hidden_units_, Concat("decoder_input", step_), compare_mode,
// stream_);
// }
// }
// else {
// for (int i = 0; i < batch_size_; ++i) {
// Compare(decoder_input_buf_ + i * llama_->hidden_units_,
// llama_->hidden_units_,
// Concat("decoder_input", step_),
// compare_mode,
// stream_,
// Concat("", rank_, i));
// }
// }
// CheckBatchConsistency(decoder_input_buf_, //
// llama_->hidden_units_,
// batch_size_,
// Concat("decoder_input", step_),
// rank_,
// stream_);

llama_->decoderForward(decoder_output_buf_,
k_cache_ptr_buf_,
v_cache_ptr_buf_,
Expand All @@ -514,11 +542,37 @@ bool LlamaBatch<T>::generate()
session_len_,
batch_size_);

// CheckBatchConsistency(decoder_input_buf_, //
// llama_->hidden_units_,
// batch_size_,
// Concat("decoder_output", step_),
// rank_,
// stream_);

// if (compare_mode == kCmpWrite) {
// if (rank_ == 0) {
// Compare(decoder_output_buf_, llama_->hidden_units_, Concat("decoder_output", step_), compare_mode,
// stream_);
// }
// }
// else {
// for (int i = 0; i < batch_size_; ++i) {
// Compare(decoder_output_buf_ + i * llama_->hidden_units_,
// llama_->hidden_units_,
// Concat("decoder_output", step_),
// compare_mode,
// stream_,
// Concat("", rank_, i));
// }
// }

llama_->postDecodeEmbedding(logits_buf_, //
local_logits_buf_,
decoder_output_buf_,
batch_size_);

// CheckValues(logits_buf_, batch_size_ * llama_->vocab_size_padded_, "post_decode_embedding", stream_);

// stop-words & bad-words require the matched tokens to be contiguous, so item size > 1 is
// not supported yet.
bool should_stop{};
Expand Down
11 changes: 11 additions & 0 deletions src/turbomind/models/llama/LlamaContextDecoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "src/turbomind/models/llama/LlamaContextDecoder.h"
#include "src/turbomind/models/llama/llama_decoder_kernels.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/Tensor.h"

namespace turbomind {
Expand Down Expand Up @@ -244,11 +245,15 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
stream_);
sync_check_cuda_error();

// CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_norm", 0), stream_);

for (size_t layer = 0; layer < num_layer_; ++layer) {
/////////////////////////////////////////////
/// self-attention
forwardSelfAttn(sess, decoder_output, input_tensors, layer, false);

// CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_self_attn", layer), stream_);

invokeFusedAddBiasResidualRMSNorm(decoder_input_output,
decoder_output,
decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
Expand All @@ -259,13 +264,17 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
stream_);
sync_check_cuda_error();

// CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_norm1", layer), stream_);

////////////////////////////////////////////
/// feed-forward network
TensorMap ffn_inputs{{"ffn_input", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, decoder_output}}};
TensorMap ffn_outputs{
{"ffn_output", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, decoder_output}}};
silu_ffn_layer_->forward(&ffn_outputs, &ffn_inputs, &decoder_layer_weights->at(layer)->ffn_weights);

// CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_ffn", layer), stream_);

auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
input_tensors->at("output_norm_weight").getPtr<T>();
invokeFusedAddBiasResidualRMSNorm(decoder_input_output, //
Expand All @@ -277,6 +286,8 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
hidden_units_,
stream_);
sync_check_cuda_error();

// CheckValues(decoder_output, sess.token_num * hidden_units_, Concat("prefill_norm2", layer), stream_);
}

if (is_free_buffer_after_forward_) {
Expand Down
26 changes: 26 additions & 0 deletions src/turbomind/models/llama/LlamaDecoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>* ou
T* decoder_input = input_tensors->at("decoder_input").getPtr<T>();
T* decoder_output = output_tensors->at("decoder_output").getPtr<T>();

// int step = input_tensors->at("step").getVal<int>();

////////////////////////////////////////////
/// RMSNorm
invokeRootMeanSquareNorm(decoder_output,
Expand All @@ -206,10 +208,28 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>* ou
stream_);
sync_check_cuda_error();

// CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_norm", 0), stream_);

// CheckBatchConsistency(decoder_output,
// hidden_units_,
// sess.batch_size,
// Concat("decode_norm", step, 0),
// tensor_para_.rank_,
// stream_);

for (size_t layer = 0; layer < num_layer_; ++layer) {
// output: self_attn_output_, k_cache, v_cache = self_attn(decoder_normed_input_)
forwardSelfAttn(sess, decoder_output, input_tensors, layer);

// CheckBatchConsistency(decoder_output,
// hidden_units_,
// sess.batch_size,
// Concat("decode_self_attn", step, layer),
// tensor_para_.rank_,
// stream_);

// CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_self_attn", layer), stream_);

invokeFusedAddBiasResidualRMSNorm(decoder_input,
decoder_output,
decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
Expand All @@ -220,9 +240,13 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>* ou
stream_);
sync_check_cuda_error();

// CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_norm1", layer), stream_);

// decoder_layer_output_ = ffn(decoder_normed_input_)
forwardFfn(sess, decoder_output, layer);

// CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_ffn", layer), stream_);

auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
input_tensors->at("output_norm_weight").getPtr<T>();
invokeFusedAddBiasResidualRMSNorm(decoder_input, //
Expand All @@ -234,6 +258,8 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>* ou
hidden_units_,
stream_);
sync_check_cuda_error();

// CheckValues(decoder_output, sess.batch_size * hidden_units_, Concat("decode_norm2", layer), stream_);
}

if (is_free_buffer_after_forward_) {
Expand Down
38 changes: 38 additions & 0 deletions src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/nccl_utils.h"
#include "src/turbomind/utils/nvtx_utils.h"
#include <string>
// #include <glog/logging.h>
Expand Down Expand Up @@ -236,10 +237,24 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o

allocateBuffer(batch_size, step, max_seq_len);

// CheckBatchConsistency((T*)input_query_data,
// hidden_units_,
// batch_size,
// Concat("before_qkv_gemm", step, layer_id),
// tensor_para_.rank_,
// stream_);

PUSH_RANGE("qkv_gemm");
linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv);
POP_RANGE;

// CheckBatchConsistency(qkv_buf_,
// (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_,
// batch_size,
// Concat("after_qkv_gemm", step, layer_id),
// tensor_para_.rank_,
// stream_);

const auto kv_cache_layer_offset = layer_id * local_kv_head_num_ * max_seq_len * size_per_head_;
const int memory_len = max_seq_len;

Expand Down Expand Up @@ -287,15 +302,38 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
stream_);
sync_check_cuda_error();

// CheckBatchConsistency((T*)context_buf_,
// local_hidden_units_,
// batch_size,
// Concat("before_o_gemm", step, layer_id),
// tensor_para_.rank_,
// stream_);

linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output);

// CheckBatchConsistency(hidden_features_data,
// hidden_units_,
// batch_size,
// Concat("after_o_gemm", step, layer_id),
// tensor_para_.rank_,
// stream_);

if (tensor_para_.world_size_ > 1) {
NcclGuard nccl_guard(tensor_para_, stream_);
ftNcclAllReduceSum(
hidden_features_data, hidden_features_data, batch_size * hidden_units_, tensor_para_, stream_);
sync_check_cuda_error();
// ftNcclStreamSynchronize(tensor_para_, {}, stream_);
// sync_check_cuda_error();
}

// CheckBatchConsistency(hidden_features_data,
// hidden_units_,
// batch_size,
// Concat("self_attn_allreduce", step, layer_id),
// tensor_para_.rank_,
// stream_);

if (is_free_buffer_after_forward_) {
freeBuffer();
}
Expand Down
6 changes: 6 additions & 0 deletions src/turbomind/models/llama/LlamaV2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,12 @@ void LlamaV2<T>::internalThreadEntry(int device_id)
TM_LOG_INFO("[internalThreadEntry] %d", (int)tensor_para_.rank_);
check_cuda_error(cudaSetDevice(device_id));

model_instance_barrier() = shared_state_->barrier.get();

// initialize global counters
// CheckValues((T*)0, 0, {}, 0);
shared_state_->barrier->wait();

auto& request_queue = shared_state_->request_queue;
auto& infer_requests = shared_state_->infer_requests;
auto& stop_requests = shared_state_->stop_requests;
Expand Down
Loading
Loading