repetition penalty for long context
irexyc committed Jan 28, 2024
1 parent 3f1c691 commit 955fd24
Showing 5 changed files with 47 additions and 25 deletions.
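
Editor's context, inferred from the diff below: the old kernel staged its per-token penalty scratch (one float and one int per generated token) in dynamic shared memory sized by `step`, opting in via `cudaFuncSetAttribute`. Shared memory per block is bounded (48 KB by default, and roughly 64 to 228 KB even with the opt-in, depending on the GPU), so long contexts eventually fail to launch; the commit moves the scratch into a caller-provided global-memory workspace instead. A minimal sketch of the arithmetic, assuming the common 48 KB default cap:

```cpp
// Editor's sketch, not part of the commit: the retired shared-memory request.
#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t kDefaultSmemCap = 48 * 1024;  // typical per-block default
    for (int step : {2048, 6144, 8192, 32768}) {
        // Exactly the old smem_size formula from the diff below.
        const std::size_t smem_size = step * (sizeof(float) + sizeof(int));
        std::printf("step=%5d  smem=%7zu bytes  %s\n", step, smem_size,
                    smem_size > kDefaultSmemCap ? "needs opt-in, may fail" : "fits");
    }
    return 0;
}
```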
54 changes: 32 additions & 22 deletions src/turbomind/kernels/sampling_penalty_kernels.cu
@@ -367,18 +367,21 @@ template void invokeApplyRepetitionPenalty(half* logits,
 template<typename T, RepetitionPenaltyType penalty_type>
 __global__ void batchApplyRepetitionPenalty(T* logits,
                                             const float* penalties,
+                                            int* penalty_workspace,
                                             const int* output_ids,
                                             const int batch_size,
                                             const int vocab_size,
                                             const int* input_lengths,
                                             const int max_input_length,
                                             const int step)
 {
-    extern __shared__ float penalty_logits[];
-    int* penalty_indices = (int*)(penalty_logits + step);
-    const int batch_idx = blockIdx.x;
-    const float penalty = penalties[batch_idx];
-    const int input_length = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length;
+    const int batch_idx = blockIdx.x;
+    const float penalty = penalties[batch_idx];
+    const int input_length = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length;
+
+    penalty_workspace += batch_idx * step * 2;
+    float* penalty_logits = (float*)penalty_workspace;
+    int* penalty_indices = (int*)(penalty_workspace + step);
 
     logits += batch_idx * vocab_size;
 
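Editor's note on the three lines added above: each sequence owns a contiguous slice of `2 * step` ints in the workspace; the first `step` entries are viewed as floats holding penalized logit values, the next `step` hold the corresponding token ids. A hypothetical helper (not in the commit) that mirrors the kernel's pointer arithmetic:

```cpp
#include <cstddef>

// Hypothetical helper mirroring "penalty_workspace += batch_idx * step * 2".
struct PenaltySlice {
    float* logits;   // step penalized logit values (float view of the ints)
    int*   indices;  // step token ids gathered from output_ids
};

inline PenaltySlice penaltySliceFor(int* penalty_workspace, int batch_idx, int step)
{
    int* base = penalty_workspace + static_cast<std::ptrdiff_t>(batch_idx) * step * 2;
    return {reinterpret_cast<float*>(base), base + step};
}
```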
@@ -409,10 +412,6 @@ __global__ void batchApplyRepetitionPenalty(T* logits,
         }
     }
 
-    if (blockDim.x > 32) {
-        __syncthreads();
-    }
-
     // Phase 2. Replace a logit value by the penalized one.
     for (int index = threadIdx.x; index < step; index += blockDim.x) {
         // Skip the padding tokens in input sequences.
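Editor's reading of the deleted barrier, an assumption the commit itself does not spell out: both phases stride the token range by `blockDim.x` starting at `threadIdx.x`, so a thread only ever reads back workspace slots it wrote itself, and no cross-thread visibility is needed between the phases. A CPU analogue of that ownership argument:

```cpp
#include <cassert>
#include <vector>

// Each "thread" writes and then re-reads the same strided slots; the assert
// never observes another thread's data, so no barrier is needed in between.
void twoPhaseOwnership(int tid, int nthreads, int step, std::vector<int>& scratch)
{
    for (int i = tid; i < step; i += nthreads)  // phase 1: fill own slots
        scratch[i] = tid;
    for (int i = tid; i < step; i += nthreads)  // phase 2: read own slots back
        assert(scratch[i] == tid);
}

int main()
{
    std::vector<int> scratch(100);
    for (int tid = 0; tid < 4; ++tid)
        twoPhaseOwnership(tid, 4, 100, scratch);
    return 0;
}
```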
@@ -426,6 +425,7 @@ __global__ void batchApplyRepetitionPenalty(T* logits,
 template<typename T>
 void invokeBatchApplyRepetitionPenalty(T* logits,
                                        const float* penalties,
+                                       int* penalty_workspace,
                                        const int* output_ids,
                                        const int batch_size,
                                        const int local_batch_size,
@@ -442,22 +442,30 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
     // output_ids [step, batch_size] : output token ids (with offset ite * local_batch_size).
     // input_lengths [local_batch_size], input lengths (optional).
     // Padding tokens at [input_length, max_input_length) of input will not be penalized.
-    dim3   block(min(step, 1024));
-    dim3   grid(local_batch_size);
-    size_t smem_size = step * (sizeof(float) + sizeof(int));
+    dim3 block(min(step, 1024));
+    dim3 grid(local_batch_size);
     if (penalty_type == RepetitionPenaltyType::Additive) {
-        check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive>,
-                                              cudaFuncAttributeMaxDynamicSharedMemorySize,
-                                              smem_size));
-        batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, smem_size, stream>>>(
-            logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
+        batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, 0, stream>>>(logits,
+                                                                                                    penalties,
+                                                                                                    penalty_workspace,
+                                                                                                    output_ids,
+                                                                                                    batch_size,
+                                                                                                    vocab_size,
+                                                                                                    input_lengths,
+                                                                                                    max_input_length,
+                                                                                                    step);
     }
     else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
-        check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative>,
-                                              cudaFuncAttributeMaxDynamicSharedMemorySize,
-                                              smem_size));
-        batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative><<<grid, block, smem_size, stream>>>(
-            logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
+        batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative>
+            <<<grid, block, 0, stream>>>(logits,
+                                         penalties,
+                                         penalty_workspace,
+                                         output_ids,
+                                         batch_size,
+                                         vocab_size,
+                                         input_lengths,
+                                         max_input_length,
+                                         step);
     }
     else if (penalty_type == RepetitionPenaltyType::None) {
         // do nothing
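Editor's sketch of driving the reworked entry point. The parameters past the visible prefix (vocab size through stream) and the `turbomind` namespace are assumed from context elsewhere in the repository, and `cudaMallocAsync` is just one way to provision the workspace; treat this as illustration, not the commit's API:

```cpp
#include <cuda_runtime.h>
#include "src/turbomind/kernels/sampling_penalty_kernels.h"

void applyRepetitionPenalty(float*       d_logits,
                            const float* d_penalties,
                            const int*   d_output_ids,
                            const int*   d_input_lengths,
                            int          batch_size,
                            int          vocab_size,
                            int          max_input_length,
                            int          step,
                            cudaStream_t stream)
{
    // step ints (token ids) + step floats (penalized logits) per sequence.
    int* d_workspace = nullptr;
    cudaMallocAsync((void**)&d_workspace, (sizeof(int) + sizeof(float)) * batch_size * step, stream);

    turbomind::invokeBatchApplyRepetitionPenalty(d_logits,
                                                 d_penalties,
                                                 d_workspace,
                                                 d_output_ids,
                                                 batch_size,
                                                 batch_size,  // local_batch_size
                                                 vocab_size,
                                                 d_input_lengths,
                                                 max_input_length,
                                                 step,
                                                 turbomind::RepetitionPenaltyType::Multiplicative,
                                                 stream);

    cudaFreeAsync(d_workspace, stream);
}
```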
Expand All @@ -466,6 +474,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits,

 template void invokeBatchApplyRepetitionPenalty(float* logits,
                                                 const float* penalties,
+                                                int* penalty_workspace,
                                                 const int* output_ids,
                                                 const int batch_size,
                                                 const int local_batch_size,
@@ -478,6 +487,7 @@ template void invokeBatchApplyRepetitionPenalty(float* logits,

 template void invokeBatchApplyRepetitionPenalty(half* logits,
                                                 const float* penalties,
+                                                int* penalty_workspace,
                                                 const int* output_ids,
                                                 const int batch_size,
                                                 const int local_batch_size,
1 change: 1 addition & 0 deletions src/turbomind/kernels/sampling_penalty_kernels.h
@@ -40,6 +40,7 @@ void invokeApplyRepetitionPenalty(T* logits,
 template<typename T>
 void invokeBatchApplyRepetitionPenalty(T* logits,
                                        const float* penalties,
+                                       int* penalty_workspace,
                                        const int* output_ids,
                                        const int batch_size,
                                        const int local_batch_size,
4 changes: 4 additions & 0 deletions src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc
@@ -55,6 +55,7 @@ void BaseSamplingLayer<T>::freeBuffer()
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     if (is_allocate_buffer_) {
+        allocator_->free((void**)(&repetition_penalty_workspace_));
         allocator_->free((void**)(&temperature_buf_));
         allocator_->free((void**)(&repetition_penalty_buf_));
         allocator_->free((void**)(&min_lengths_buf_));
@@ -293,9 +294,12 @@ void BaseSamplingLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_t
     if (step > 1 && repetition_penalty_type_ != RepetitionPenaltyType::None) {
         float default_value = getDefaultPenaltyValue(repetition_penalty_type_);
         if (!ALL_OF(repetition_penalty_ + ite * local_batch_size, local_batch_size, float, default_value)) {
+            repetition_penalty_workspace_ = reinterpret_cast<int*>(allocator_->reMalloc(
+                repetition_penalty_workspace_, batch_size * step * (sizeof(int) + sizeof(float)), false));
             invokeBatchApplyRepetitionPenalty(
                 logits,
                 repetition_penalty_buf_ + ite * local_batch_size,
+                repetition_penalty_workspace_ + ite * local_batch_size,
                 output_tensors->at("output_ids").getPtrWithOffset<int>(ite * local_batch_size),
                 batch_size,
                 local_batch_size,
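Editor's note on the two allocation lines added above: `step` grows by one per decoding iteration, so the workspace is re-grown lazily right before the kernel call, and only when some sequence actually carries a non-default penalty. The byte count, factored into a hypothetical helper:

```cpp
#include <cstddef>

// Hypothetical helper matching the reMalloc size in the diff above:
// step token ids (int) plus step penalized logits (float), per sequence.
inline std::size_t penaltyWorkspaceBytes(std::size_t batch_size, std::size_t step)
{
    return batch_size * step * (sizeof(int) + sizeof(float));
}
```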
2 changes: 2 additions & 0 deletions src/turbomind/layers/sampling_layers/BaseSamplingLayer.h
@@ -33,6 +33,8 @@ class BaseSamplingLayer: public DynamicDecodeBaseLayer {
     size_t vocab_size_;
     size_t vocab_size_padded_;
 
+    int* repetition_penalty_workspace_;
+
     size_t sampling_workspace_size_;
     void* sampling_workspace_ = nullptr;
 
11 changes: 8 additions & 3 deletions tests/csrc/unittests/test_penalty_kernels.cu
@@ -18,10 +18,10 @@
 #include <iostream> // snprintf
 #include <math.h>   // expf, log
 #include <stdexcept>
-#include <stdlib.h> // rand
-#include <string>   // std::string
+#include <stdlib.h>      // rand
+#include <string>        // std::string
 #include <unordered_map>
-#include <vector>   // std::vector
+#include <vector>        // std::vector
 
 #include <cublasLt.h>
 #include <cublas_v2.h>
@@ -386,6 +386,7 @@ protected:
     T* d_bias_;
     int* d_output_ids_;
     int* d_input_lengths_;
+    int* d_penalty_workspace_;
 
     float* d_repetition_penalties_;
 
@@ -410,6 +411,8 @@ protected:
         d_bias_ = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * vocab_size_padded_));
         d_output_ids_ = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * sequence_length_ * batch_size_));
         d_input_lengths_ = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batch_size_));
+        d_penalty_workspace_ =
+            reinterpret_cast<int*>(allocator->malloc((sizeof(int) + sizeof(float)) * batch_size_ * step_));
 
         cudaAutoCpy(d_logits_, h_logits_, batch_size_ * vocab_size_padded_, stream);
         cudaAutoCpy(d_bias_, h_bias_, vocab_size_padded_, stream);
@@ -501,6 +504,7 @@ public:
         else {
             invokeBatchApplyRepetitionPenalty(d_logits_,
                                               d_repetition_penalties_,
+                                              d_penalty_workspace_,
                                               d_output_ids_,
                                               batch_size_,
                                               batch_size_,
Expand Down Expand Up @@ -559,6 +563,7 @@ public:
         cudaAutoCpy(d_logits_batch, h_logits_, batch_size_ * vocab_size_padded_, stream);
         invokeBatchApplyRepetitionPenalty(d_logits_batch,
                                           d_repetition_penalties_,
+                                          d_penalty_workspace_,
                                           d_output_ids_,
                                           batch_size_,
                                           batch_size_,
