Skip to content

Commit

Permalink
repetition penalty output ids
Browse files Browse the repository at this point in the history
  • Loading branch information
irexyc committed Jan 28, 2024
1 parent 955fd24 commit 62f4224
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 8 deletions.
21 changes: 18 additions & 3 deletions src/turbomind/kernels/sampling_penalty_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -371,13 +371,15 @@ __global__ void batchApplyRepetitionPenalty(T* logits,
const int* output_ids,
const int batch_size,
const int vocab_size,
const int* prompt_lengths,
const int* input_lengths,
const int max_input_length,
const int step)
{
const int batch_idx = blockIdx.x;
const float penalty = penalties[batch_idx];
const int input_length = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length;
const int batch_idx = blockIdx.x;
const float penalty = penalties[batch_idx];
const int input_length = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length;
const int prompt_length = prompt_lengths != nullptr ? prompt_lengths[batch_idx] : 0;

penalty_workspace += batch_idx * step * 2;
float* penalty_logits = (float*)penalty_workspace;
Expand All @@ -388,6 +390,10 @@ __global__ void batchApplyRepetitionPenalty(T* logits,
// Phase 1. Find indices to penalize and keep the penalized values.
// A vocab id can appear multiple times but should be penalized once.
for (int index = threadIdx.x; index < step; index += blockDim.x) {
// Skip the prompt tokens (indices below prompt_length).
if (index < prompt_length) {
continue;
}
// Skip the padding tokens in input sequences.
if (index >= input_length && index < max_input_length) {
continue;
Expand All @@ -414,6 +420,10 @@ __global__ void batchApplyRepetitionPenalty(T* logits,

// Phase 2. Replace a logit value by the penalized one.
for (int index = threadIdx.x; index < step; index += blockDim.x) {
// Skip the prompt tokens (indices below prompt_length).
if (index < prompt_length) {
continue;
}
// Skip the padding tokens in input sequences.
if (index >= input_length && index < max_input_length) {
continue;
Expand All @@ -430,6 +440,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
const int batch_size,
const int local_batch_size,
const int vocab_size,
const int* prompt_lengths,
const int* input_lengths,
const int max_input_length,
const int step,
Expand All @@ -451,6 +462,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
output_ids,
batch_size,
vocab_size,
prompt_lengths,
input_lengths,
max_input_length,
step);
Expand All @@ -463,6 +475,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
output_ids,
batch_size,
vocab_size,
prompt_lengths,
input_lengths,
max_input_length,
step);
Expand All @@ -479,6 +492,7 @@ template void invokeBatchApplyRepetitionPenalty(float* logits,
const int batch_size,
const int local_batch_size,
const int vocab_size,
const int* prompt_lengths,
const int* input_lengths,
const int max_input_length,
const int step,
Expand All @@ -492,6 +506,7 @@ template void invokeBatchApplyRepetitionPenalty(half* logits,
const int batch_size,
const int local_batch_size,
const int vocab_size,
const int* prompt_lengths,
const int* input_lengths,
const int max_input_length,
const int step,
Expand Down
1 change: 1 addition & 0 deletions src/turbomind/kernels/sampling_penalty_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
const int batch_size,
const int local_batch_size,
const int vocab_size,
const int* prompt_lengths,
const int* input_lengths,
const int max_input_length,
const int step,
Expand Down
12 changes: 12 additions & 0 deletions src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ void BaseSamplingLayer<T>::allocateBuffer(size_t batch_size, Tensor top_k, Tenso
allocator_->reMalloc(runtime_logits_buf_, sizeof(T) * batch_size * vocab_size_padded_, false));
skip_decode_buf_ =
reinterpret_cast<bool*>(allocator_->reMalloc(skip_decode_buf_, sizeof(bool) * batch_size, false));
prompt_lengths_buf_ =
reinterpret_cast<int*>(allocator_->reMalloc(prompt_lengths_buf_, sizeof(int) * batch_size, false));

// host buffers.
temperature_ = (float*)std::realloc((void*)temperature_, batch_size * sizeof(float));
Expand All @@ -59,6 +61,7 @@ void BaseSamplingLayer<T>::freeBuffer()
allocator_->free((void**)(&temperature_buf_));
allocator_->free((void**)(&repetition_penalty_buf_));
allocator_->free((void**)(&min_lengths_buf_));
allocator_->free((void**)(&prompt_lengths_buf_));
allocator_->free((void**)(&runtime_logits_buf_));
allocator_->free((void**)(&skip_decode_buf_));
std::free(temperature_);
Expand Down Expand Up @@ -164,6 +167,14 @@ void BaseSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt
repetition_penalty_type_ = RepetitionPenaltyType::None;
}

if (runtime_args->isExist("prompt_length")) {
Tensor prompt_lengths = runtime_args->at("prompt_length");
cudaAutoCpy(prompt_lengths_buf_, prompt_lengths.getPtr<int>(), batch_size, stream_);
}
else {
deviceFill(prompt_lengths_buf_, batch_size, 0, stream_);
}

// min_length
if (runtime_args->isExist("min_length")) {
Tensor min_lengths = runtime_args->at("min_length");
Expand Down Expand Up @@ -304,6 +315,7 @@ void BaseSamplingLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_t
batch_size,
local_batch_size,
vocab_size_padded_,
prompt_lengths_buf_ + ite * local_batch_size,
input_tensors->at("input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {}, nullptr}).getPtr<int>(),
max_input_length,
step,
Expand Down
1 change: 1 addition & 0 deletions src/turbomind/layers/sampling_layers/BaseSamplingLayer.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class BaseSamplingLayer: public DynamicDecodeBaseLayer {
int* min_lengths_buf_ = nullptr;
bool* skip_decode_buf_ = nullptr;
T* runtime_logits_buf_ = nullptr;
int* prompt_lengths_buf_ = nullptr;

float* temperature_ = nullptr;
float* repetition_penalty_ = nullptr;
Expand Down
8 changes: 3 additions & 5 deletions src/turbomind/models/llama/LlamaBatch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1066,11 +1066,9 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
}
}

// MinLengthPenalty
if (inputs.isExist("min_length")) {
inputs.insert({"prompt_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_prompt_length}});
inputs.insert({"context_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_context_length}});
}
// MinLengthPenalty & RepetitionPenalty
inputs.insert({"context_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_context_length}});
inputs.insert({"prompt_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_prompt_length}});

// init for eos
std::fill_n(h_end_ids_buf_, batch_size, model_->end_id_);
Expand Down
2 changes: 2 additions & 0 deletions tests/csrc/unittests/test_penalty_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ public:
batch_size_,
batch_size_,
vocab_size_padded_,
nullptr,
d_input_lengths_,
max_input_length_,
step_,
Expand Down Expand Up @@ -568,6 +569,7 @@ public:
batch_size_,
batch_size_,
vocab_size_padded_,
nullptr,
d_input_lengths_,
max_input_length_,
step_,
Expand Down

0 comments on commit 62f4224

Please sign in to comment.