diff --git a/src/turbomind/kernels/sampling_penalty_kernels.cu b/src/turbomind/kernels/sampling_penalty_kernels.cu
index dd1288c56..df2730029 100644
--- a/src/turbomind/kernels/sampling_penalty_kernels.cu
+++ b/src/turbomind/kernels/sampling_penalty_kernels.cu
@@ -371,13 +371,15 @@ __global__ void batchApplyRepetitionPenalty(T* logits,
                                             const int*   output_ids,
                                             const int    batch_size,
                                             const int    vocab_size,
+                                            const int*   prompt_lengths,
                                             const int*   input_lengths,
                                             const int    max_input_length,
                                             const int    step)
 {
-    const int   batch_idx    = blockIdx.x;
-    const float penalty      = penalties[batch_idx];
-    const int   input_length = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length;
+    const int   batch_idx     = blockIdx.x;
+    const float penalty       = penalties[batch_idx];
+    const int   input_length  = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length;
+    const int   prompt_length = prompt_lengths != nullptr ? prompt_lengths[batch_idx] : 0;
 
     penalty_workspace += batch_idx * step * 2;
     float* penalty_logits = (float*)penalty_workspace;
@@ -388,6 +390,10 @@ __global__ void batchApplyRepetitionPenalty(T* logits,
     // Phase 1. Find indices to penalize and keep the penalized values.
     // A vocab id can appear multiple times but should be penalized once.
     for (int index = threadIdx.x; index < step; index += blockDim.x) {
+        // skip prompt
+        if (index < prompt_length) {
+            continue;
+        }
         // Skip the padding tokens in input sequences.
         if (index >= input_length && index < max_input_length) {
             continue;
@@ -414,6 +420,10 @@ __global__ void batchApplyRepetitionPenalty(T* logits,
 
     // Phase 2. Replace a logit value by the penalized one.
     for (int index = threadIdx.x; index < step; index += blockDim.x) {
+        // skip prompt
+        if (index < prompt_length) {
+            continue;
+        }
         // Skip the padding tokens in input sequences.
         if (index >= input_length && index < max_input_length) {
             continue;
@@ -430,6 +440,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
                                        const int   batch_size,
                                        const int   local_batch_size,
                                        const int   vocab_size,
+                                       const int*  prompt_lengths,
                                        const int*  input_lengths,
                                        const int   max_input_length,
                                        const int   step,
@@ -451,6 +462,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
                                                               output_ids,
                                                               batch_size,
                                                               vocab_size,
+                                                              prompt_lengths,
                                                               input_lengths,
                                                               max_input_length,
                                                               step);
@@ -463,6 +475,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
                                                               output_ids,
                                                               batch_size,
                                                               vocab_size,
+                                                              prompt_lengths,
                                                               input_lengths,
                                                               max_input_length,
                                                               step);
@@ -479,6 +492,7 @@ template void invokeBatchApplyRepetitionPenalty(float* logits,
                                                 const int    batch_size,
                                                 const int    local_batch_size,
                                                 const int    vocab_size,
+                                                const int*   prompt_lengths,
                                                 const int*   input_lengths,
                                                 const int    max_input_length,
                                                 const int    step,
@@ -492,6 +506,7 @@ template void invokeBatchApplyRepetitionPenalty(half* logits,
                                                 const int    batch_size,
                                                 const int    local_batch_size,
                                                 const int    vocab_size,
+                                                const int*   prompt_lengths,
                                                 const int*   input_lengths,
                                                 const int    max_input_length,
                                                 const int    step,
diff --git a/src/turbomind/kernels/sampling_penalty_kernels.h b/src/turbomind/kernels/sampling_penalty_kernels.h
index e12698cdf..62f35f110 100644
--- a/src/turbomind/kernels/sampling_penalty_kernels.h
+++ b/src/turbomind/kernels/sampling_penalty_kernels.h
@@ -45,6 +45,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
                                        const int   batch_size,
                                        const int   local_batch_size,
                                        const int   vocab_size,
+                                       const int*  prompt_lengths,
                                        const int*  input_lengths,
                                        const int   max_input_length,
                                        const int   step,
diff --git a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc
index 3e23cbd61..429ac7db0 100644
--- a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc
+++ b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc
@@ -39,6 +39,8 @@ void BaseSamplingLayer<T>::allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p)
         allocator_->reMalloc(runtime_logits_buf_, sizeof(T) * batch_size * vocab_size_padded_, false));
     skip_decode_buf_ =
         reinterpret_cast<bool*>(allocator_->reMalloc(skip_decode_buf_, sizeof(bool) * batch_size, false));
+    prompt_lengths_buf_ =
+        reinterpret_cast<int*>(allocator_->reMalloc(prompt_lengths_buf_, sizeof(int) * batch_size, false));
 
     // host buffers.
     temperature_ = (float*)std::realloc((void*)temperature_, batch_size * sizeof(float));
@@ -59,6 +61,7 @@ void BaseSamplingLayer<T>::freeBuffer()
     allocator_->free((void**)(&temperature_buf_));
     allocator_->free((void**)(&repetition_penalty_buf_));
     allocator_->free((void**)(&min_lengths_buf_));
+    allocator_->free((void**)(&prompt_lengths_buf_));
     allocator_->free((void**)(&runtime_logits_buf_));
     allocator_->free((void**)(&skip_decode_buf_));
     std::free(temperature_);
@@ -164,6 +167,14 @@ void BaseSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_width, TensorMap* runtime_args)
         repetition_penalty_type_ = RepetitionPenaltyType::None;
     }
 
+    if (runtime_args->isExist("prompt_length")) {
+        Tensor prompt_lengths = runtime_args->at("prompt_length");
+        cudaAutoCpy(prompt_lengths_buf_, prompt_lengths.getPtr<int>(), batch_size, stream_);
+    }
+    else {
+        deviceFill(prompt_lengths_buf_, batch_size, 0, stream_);
+    }
+
     // min_length
     if (runtime_args->isExist("min_length")) {
         Tensor min_lengths = runtime_args->at("min_length");
@@ -304,6 +315,7 @@ void BaseSamplingLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_tensors)
                 batch_size,
                 local_batch_size,
                 vocab_size_padded_,
+                prompt_lengths_buf_ + ite * local_batch_size,
                 input_tensors->at("input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {}, nullptr}).getPtr<int>(),
                 max_input_length,
                 step,
diff --git a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h
index 83d2c40f2..6645fffa7 100644
--- a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h
+++ b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h
@@ -43,6 +43,7 @@ class BaseSamplingLayer: public DynamicDecodeBaseLayer {
    int*  min_lengths_buf_     = nullptr;
    bool* skip_decode_buf_     = nullptr;
    T*    runtime_logits_buf_  = nullptr;
+   int*  prompt_lengths_buf_  = nullptr;
 
    float* temperature_        = nullptr;
    float* repetition_penalty_ = nullptr;
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index c6bfc1502..1bc3e6833 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -1066,11 +1066,9 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
         }
     }
 
-    // MinLengthPenalty
-    if (inputs.isExist("min_length")) {
-        inputs.insert({"prompt_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_prompt_length}});
-        inputs.insert({"context_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_context_length}});
-    }
+    // MinLengthPenalty & RepetitionPenalty
+    inputs.insert({"context_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_context_length}});
+    inputs.insert({"prompt_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_prompt_length}});
 
     // init for eos
     std::fill_n(h_end_ids_buf_, batch_size, model_->end_id_);
diff --git a/tests/csrc/unittests/test_penalty_kernels.cu b/tests/csrc/unittests/test_penalty_kernels.cu
index 301b79aa9..e774f30fc 100644
--- a/tests/csrc/unittests/test_penalty_kernels.cu
+++ b/tests/csrc/unittests/test_penalty_kernels.cu
@@ -509,6 +509,7 @@ public:
                                            batch_size_,
                                            batch_size_,
                                            vocab_size_padded_,
+                                           nullptr,
                                            d_input_lengths_,
                                            max_input_length_,
                                            step_,
@@ -568,6 +569,7 @@ public:
                                            batch_size_,
                                            batch_size_,
                                            vocab_size_padded_,
+                                           nullptr,
                                            d_input_lengths_,
                                            max_input_length_,
                                            step_,
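Review note: the behavioral core of this patch is the new prompt_length guard in batchApplyRepetitionPenalty. Positions [0, prompt_length) of the token history — the (system) prompt — are now excluded from the repetition penalty, while the existing guard that skips right-padding between input_length and max_input_length is kept. Below is a minimal, self-contained CUDA sketch of that two-phase kernel written for this review; it is not TurboMind source. The ...Sketch name, the tiny driver in main, and the concrete sizes are hypothetical, and it assumes the same conventions as the kernel above: output_ids in [step, batch] layout, logits in [batch, vocab] layout, and a CTRL-style penalty (negative logits are multiplied by the penalty, positive ones divided).

// repetition_penalty_prompt_skip_sketch.cu -- illustrative only, not TurboMind source.
// Build: nvcc repetition_penalty_prompt_skip_sketch.cu -o sketch && ./sketch
#include <cstdio>
#include <cuda_runtime.h>

// One thread block per batch entry. Positions [0, prompt_length) hold the (system)
// prompt and are skipped, so only user-input and generated tokens are penalized.
__global__ void batchApplyRepetitionPenaltySketch(float*       logits,          // [batch, vocab_size]
                                                  float*       workspace,       // [batch, 2 * step]
                                                  const float* penalties,       // [batch]
                                                  const int*   output_ids,      // [step, batch]
                                                  const int*   prompt_lengths,  // [batch], may be null
                                                  const int*   input_lengths,   // [batch], may be null
                                                  int          batch_size,
                                                  int          vocab_size,
                                                  int          max_input_length,
                                                  int          step)
{
    const int   batch_idx     = blockIdx.x;
    const float penalty       = penalties[batch_idx];
    const int   input_length  = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length;
    const int   prompt_length = prompt_lengths != nullptr ? prompt_lengths[batch_idx] : 0;

    float* penalty_logits  = workspace + batch_idx * step * 2;
    int*   penalty_indices = (int*)(penalty_logits + step);
    logits += batch_idx * vocab_size;

    // Phase 1: read only *original* logits and record one penalized value per position,
    // so a token that repeats in the history ends up penalized exactly once.
    for (int index = threadIdx.x; index < step; index += blockDim.x) {
        if (index < prompt_length) {
            continue;  // the new guard: prompt tokens no longer attract the penalty
        }
        if (index >= input_length && index < max_input_length) {
            continue;  // skip right-padding between the input and generated tokens
        }
        const int   token_id   = output_ids[index * batch_size + batch_idx];
        const float logit      = logits[token_id];
        penalty_logits[index]  = logit < 0.0f ? logit * penalty : logit / penalty;
        penalty_indices[index] = token_id;
    }
    __syncthreads();

    // Phase 2: write the penalized values back; duplicate positions write identical values.
    for (int index = threadIdx.x; index < step; index += blockDim.x) {
        if (index < prompt_length || (index >= input_length && index < max_input_length)) {
            continue;
        }
        logits[penalty_indices[index]] = penalty_logits[index];
    }
}

int main()
{
    // One sequence, vocab of 8. History = [5, 2, 5, 6]; the first two tokens are prompt.
    const int   batch = 1, vocab = 8, step = 4, max_input_length = 4;
    const int   h_ids[step]     = {5, 2, 5, 6};
    const int   h_prompt_length = 2, h_input_length = 4;
    const float h_penalty       = 2.0f;
    float       h_logits[vocab];
    for (int i = 0; i < vocab; ++i) {
        h_logits[i] = 1.0f;
    }

    float *d_logits, *d_workspace, *d_penalty;
    int   *d_ids, *d_prompt_length, *d_input_length;
    cudaMalloc(&d_logits, sizeof(h_logits));
    cudaMalloc(&d_workspace, batch * 2 * step * sizeof(float));
    cudaMalloc(&d_penalty, sizeof(float));
    cudaMalloc(&d_ids, sizeof(h_ids));
    cudaMalloc(&d_prompt_length, sizeof(int));
    cudaMalloc(&d_input_length, sizeof(int));
    cudaMemcpy(d_logits, h_logits, sizeof(h_logits), cudaMemcpyHostToDevice);
    cudaMemcpy(d_penalty, &h_penalty, sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_ids, h_ids, sizeof(h_ids), cudaMemcpyHostToDevice);
    cudaMemcpy(d_prompt_length, &h_prompt_length, sizeof(int), cudaMemcpyHostToDevice);
    cudaMemcpy(d_input_length, &h_input_length, sizeof(int), cudaMemcpyHostToDevice);

    batchApplyRepetitionPenaltySketch<<<batch, 32>>>(
        d_logits, d_workspace, d_penalty, d_ids, d_prompt_length, d_input_length,
        batch, vocab, max_input_length, step);
    cudaMemcpy(h_logits, d_logits, sizeof(h_logits), cudaMemcpyDeviceToHost);

    // Token 2 occurs only in the prompt, so it is left alone; tokens 5 and 6 are halved.
    printf("logit[2] = %.2f, logit[5] = %.2f, logit[6] = %.2f\n",
           h_logits[2], h_logits[5], h_logits[6]);  // expect 1.00, 0.50, 0.50
    return 0;
}

With prompt_length = 2 the driver leaves the prompt-only token 2 untouched and halves the logits of tokens 5 and 6; with prompt_length = 0, or a null prompt_lengths pointer (the pre-patch behavior), token 2 would be penalized as well. The two-phase structure is what makes the kernel correct for repeated tokens: phase 1 reads only original logits, so duplicate positions of the same token record identical penalized values, and phase 2's concurrent writes are therefore idempotent rather than compounding the penalty.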