diff --git a/src/turbomind/kernels/sampling_penalty_kernels.cu b/src/turbomind/kernels/sampling_penalty_kernels.cu
index 28bf43aac9..dd1288c56f 100644
--- a/src/turbomind/kernels/sampling_penalty_kernels.cu
+++ b/src/turbomind/kernels/sampling_penalty_kernels.cu
@@ -367,6 +367,7 @@ template void invokeApplyRepetitionPenalty(half* logits,
 template<typename T, RepetitionPenaltyType penalty_type>
 __global__ void batchApplyRepetitionPenalty(T* logits,
                                             const float* penalties,
+                                            int* penalty_workspace,
                                             const int* output_ids,
                                             const int batch_size,
                                             const int vocab_size,
@@ -374,11 +375,13 @@ __global__ void batchApplyRepetitionPenalty(T* logits,
                                             const int max_input_length,
                                             const int step)
 {
-    extern __shared__ float penalty_logits[];
-    int*                    penalty_indices = (int*)(penalty_logits + step);
-    const int               batch_idx       = blockIdx.x;
-    const float             penalty         = penalties[batch_idx];
-    const int               input_length = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length;
+    const int   batch_idx    = blockIdx.x;
+    const float penalty      = penalties[batch_idx];
+    const int   input_length = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length;
+
+    penalty_workspace += batch_idx * step * 2;
+    float* penalty_logits  = (float*)penalty_workspace;
+    int*   penalty_indices = (int*)(penalty_workspace + step);
 
     logits += batch_idx * vocab_size;
 
@@ -409,10 +412,6 @@ __global__ void batchApplyRepetitionPenalty(T* logits,
         }
     }
 
-    if (blockDim.x > 32) {
-        __syncthreads();
-    }
-
     // Phase 2. Replace a logit value by the penalized one.
     for (int index = threadIdx.x; index < step; index += blockDim.x) {
         // Skip the padding tokens in input sequences.
@@ -426,6 +425,7 @@ __global__ void batchApplyRepetitionPenalty(T* logits,
 template<typename T>
 void invokeBatchApplyRepetitionPenalty(T* logits,
                                        const float* penalties,
+                                       int* penalty_workspace,
                                        const int* output_ids,
                                        const int batch_size,
                                        const int local_batch_size,
@@ -442,22 +442,30 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
     // output_ids [step, batch_size] : output token ids (with offset ite * local_batch_size).
     // input_lengths [local_batch_size], input lengths (optional).
     // Padding tokens at [input_length, max_input_length) of input will not be penalized.
-    dim3   block(min(step, 1024));
-    dim3   grid(local_batch_size);
-    size_t smem_size = step * (sizeof(float) + sizeof(int));
+    dim3 block(min(step, 1024));
+    dim3 grid(local_batch_size);
     if (penalty_type == RepetitionPenaltyType::Additive) {
-        check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive>,
-                                              cudaFuncAttributeMaxDynamicSharedMemorySize,
-                                              smem_size));
-        batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, smem_size, stream>>>(
-            logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
+        batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, 0, stream>>>(logits,
+                                                                                                    penalties,
+                                                                                                    penalty_workspace,
+                                                                                                    output_ids,
+                                                                                                    batch_size,
+                                                                                                    vocab_size,
+                                                                                                    input_lengths,
+                                                                                                    max_input_length,
+                                                                                                    step);
     }
     else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
-        check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative>,
-                                              cudaFuncAttributeMaxDynamicSharedMemorySize,
-                                              smem_size));
-        batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative><<<grid, block, smem_size, stream>>>(
-            logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
+        batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative>
+            <<<grid, block, 0, stream>>>(logits,
+                                         penalties,
+                                         penalty_workspace,
+                                         output_ids,
+                                         batch_size,
+                                         vocab_size,
+                                         input_lengths,
+                                         max_input_length,
+                                         step);
     }
     else if (penalty_type == RepetitionPenaltyType::None) {
         // do nothing
@@ -466,6 +474,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
 template void invokeBatchApplyRepetitionPenalty(float* logits,
                                                 const float* penalties,
+                                                int* penalty_workspace,
                                                 const int* output_ids,
                                                 const int batch_size,
                                                 const int local_batch_size,
@@ -478,6 +487,7 @@ template void invokeBatchApplyRepetitionPenalty(float* logits,
 template void invokeBatchApplyRepetitionPenalty(half* logits,
                                                 const float* penalties,
+                                                int* penalty_workspace,
                                                 const int* output_ids,
                                                 const int batch_size,
                                                 const int local_batch_size,
diff --git a/src/turbomind/kernels/sampling_penalty_kernels.h b/src/turbomind/kernels/sampling_penalty_kernels.h
index 3c54cc15bf..e12698cdf7 100644
--- a/src/turbomind/kernels/sampling_penalty_kernels.h
+++ b/src/turbomind/kernels/sampling_penalty_kernels.h
@@ -40,6 +40,7 @@ void invokeApplyRepetitionPenalty(T* logits,
 template<typename T>
 void invokeBatchApplyRepetitionPenalty(T* logits,
                                        const float* penalties,
+                                       int* penalty_workspace,
                                        const int* output_ids,
                                        const int batch_size,
                                        const int local_batch_size,
diff --git a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc
index 91b6809f3f..3e23cbd616 100644
--- a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc
+++ b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc
@@ -55,6 +55,7 @@ void BaseSamplingLayer<T>::freeBuffer()
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     if (is_allocate_buffer_) {
+        allocator_->free((void**)(&repetition_penalty_workspace_));
         allocator_->free((void**)(&temperature_buf_));
         allocator_->free((void**)(&repetition_penalty_buf_));
         allocator_->free((void**)(&min_lengths_buf_));
@@ -293,9 +294,12 @@ void BaseSamplingLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_t
     if (step > 1 && repetition_penalty_type_ != RepetitionPenaltyType::None) {
         float default_value = getDefaultPenaltyValue(repetition_penalty_type_);
         if (!ALL_OF(repetition_penalty_ + ite * local_batch_size, local_batch_size, float, default_value)) {
+            repetition_penalty_workspace_ = reinterpret_cast<int*>(allocator_->reMalloc(
+                repetition_penalty_workspace_, batch_size * step * (sizeof(int) + sizeof(float)), false));
             invokeBatchApplyRepetitionPenalty(
                 logits,
                 repetition_penalty_buf_ + ite * local_batch_size,
+                repetition_penalty_workspace_ + ite * local_batch_size,
                 output_tensors->at("output_ids").getPtrWithOffset<int>(ite * local_batch_size),
                 batch_size,
                 local_batch_size,
diff --git a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h
index 68cf79c871..83d2c40f24 100644
--- a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h
+++ b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h
@@ -33,6 +33,8 @@ class BaseSamplingLayer: public DynamicDecodeBaseLayer {
     size_t vocab_size_;
     size_t vocab_size_padded_;
 
+    int* repetition_penalty_workspace_;
+
     size_t sampling_workspace_size_;
     void* sampling_workspace_ = nullptr;
 
diff --git a/tests/csrc/unittests/test_penalty_kernels.cu b/tests/csrc/unittests/test_penalty_kernels.cu
index 86d23f44e6..301b79aa9f 100644
--- a/tests/csrc/unittests/test_penalty_kernels.cu
+++ b/tests/csrc/unittests/test_penalty_kernels.cu
@@ -18,10 +18,10 @@
 #include    // snprintf
 #include    // expf, log
 #include 
-#include    // rand
-#include    // std::string
+#include    // rand
+#include    // std::string
 #include 
-#include    // std::vector
+#include    // std::vector
 
 #include 
 #include 
@@ -386,6 +386,7 @@ protected:
     T* d_bias_;
     int* d_output_ids_;
     int* d_input_lengths_;
+    int* d_penalty_workspace_;
 
     float* d_repetition_penalties_;
 
@@ -410,6 +411,8 @@ protected:
         d_bias_ = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * vocab_size_padded_));
         d_output_ids_ = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * sequence_length_ * batch_size_));
         d_input_lengths_ = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batch_size_));
+        d_penalty_workspace_ =
+            reinterpret_cast<int*>(allocator->malloc((sizeof(int) + sizeof(float)) * batch_size_ * step_));
 
         cudaAutoCpy(d_logits_, h_logits_, batch_size_ * vocab_size_padded_, stream);
         cudaAutoCpy(d_bias_, h_bias_, vocab_size_padded_, stream);
@@ -501,6 +504,7 @@ public:
         else {
             invokeBatchApplyRepetitionPenalty(d_logits_,
                                               d_repetition_penalties_,
+                                              d_penalty_workspace_,
                                               d_output_ids_,
                                               batch_size_,
                                               batch_size_,
@@ -559,6 +563,7 @@ public:
         cudaAutoCpy(d_logits_batch, h_logits_, batch_size_ * vocab_size_padded_, stream);
         invokeBatchApplyRepetitionPenalty(d_logits_batch,
                                           d_repetition_penalties_,
+                                          d_penalty_workspace_,
                                           d_output_ids_,
                                           batch_size_,
                                           batch_size_,
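Note on the workspace layout (not part of the patch): the hunks above replace the kernel's dynamic shared-memory scratch buffers (and the related cudaFuncSetAttribute calls and __syncthreads) with a caller-provided penalty_workspace that the launcher sizes as batch_size * step * (sizeof(int) + sizeof(float)) bytes and the kernel slices per block at batch_idx * step * 2 ints, floats first, then indices. The sketch below only mirrors that indexing so it can be read next to the diff; the kernel name touchWorkspace and the sample sizes are illustrative assumptions, not code from the patch.

// Minimal standalone sketch of the workspace partitioning assumed by the patched kernel.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void touchWorkspace(int* penalty_workspace, int step)
{
    // Same per-block partitioning as the patched batchApplyRepetitionPenalty:
    // `step` floats for penalized logits, then `step` ints for their vocab indices.
    penalty_workspace += blockIdx.x * step * 2;
    float* penalty_logits  = (float*)penalty_workspace;
    int*   penalty_indices = (int*)(penalty_workspace + step);
    for (int i = threadIdx.x; i < step; i += blockDim.x) {
        penalty_logits[i]  = 0.f;
        penalty_indices[i] = -1;
    }
}

int main()
{
    const int batch_size = 4;     // illustrative values only
    const int step       = 16384; // a long generation step; the removed shared-memory buffer scaled with this

    int*         d_workspace = nullptr;
    const size_t bytes       = (size_t)batch_size * step * (sizeof(int) + sizeof(float));
    cudaMalloc(&d_workspace, bytes);

    // One block per batch entry, as in the launcher above (grid = local_batch_size).
    touchWorkspace<<<batch_size, 256>>>(d_workspace, step);
    cudaDeviceSynchronize();

    printf("workspace bytes: %zu\n", bytes);
    cudaFree(d_workspace);
    return 0;
}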