Skip to content

Commit

Permalink
fix: fix bug of t5 beam search (#410)
Browse files Browse the repository at this point in the history
1. fix bugs of cum_log_probs and output_log_probs of t5 beam search
2. fix bug of beam search for large beam width
3. fix bug of misaligment of beam search penalty kernel
  • Loading branch information
byshiue authored Jan 1, 2023
1 parent 83776ae commit f0b5b86
Show file tree
Hide file tree
Showing 13 changed files with 257 additions and 128 deletions.
9 changes: 5 additions & 4 deletions src/fastertransformer/kernels/beam_search_penalty_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,10 @@ __global__ void apply_repetition_penalty(T* logits,

logits += bbid * vocab_size_padded;
extern __shared__ char sbuf[];
T* penalty_logits = reinterpret_cast<T*>(sbuf);
int* penalty_indices = reinterpret_cast<int*>(penalty_logits + step);
const int input_length = (input_lengths != nullptr) ? input_lengths[bbid] : max_input_length;
T* penalty_logits = reinterpret_cast<T*>(sbuf);
// prevent misaligment when sizeof(T) = 2
int* penalty_indices = reinterpret_cast<int*>(sbuf + (sizeof(T) * step + 31) / 32 * 32);
const int input_length = (input_lengths != nullptr) ? input_lengths[bbid] : max_input_length;
if (tid == 0) {
T repet_penalty = static_cast<T>(repetition_penalty);
int prev_id = current_ids[bbid];
Expand Down Expand Up @@ -181,7 +182,7 @@ void invokeAddBiasApplyPenalties(int step,
}

if (repetition_penalty != 1.0f) {
size_t smem_size = (sizeof(T) + sizeof(int)) * step;
size_t smem_size = (sizeof(T) * step + 31 / 32 * 32) + sizeof(int) * step;
dim3 block(256);
dim3 grid(beam_width * local_batch_size);
apply_repetition_penalty<<<grid, block, smem_size, stream>>>(
Expand Down
44 changes: 32 additions & 12 deletions src/fastertransformer/kernels/beam_search_topk_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,8 @@ void invokeTopkBeamSearch(void* workspace,
FT_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
// log_probs: (batch, beam, vocab) cumulative log_probs of beams ending with a token.
const int vocab_size = vocab_size_padded_;
// Beam size should be less than or equal to vocab size.
assert(beam_width <= vocab_size);
// Beam search needs the sequence lengths of beams to apply length penalty.
assert(length_penalty == 0.0f || sequence_lengths != nullptr);
const int max_block_per_beam = 8;
Expand Down Expand Up @@ -789,27 +791,45 @@ __global__ void insertUnfinishedPath(BeamHypotheses beam_hyps,
const int batch_size,
const int beam_width)
{
const int bid = blockIdx.x;
int unfinished_idx = 0;
for (int beam_idx = beam_hyps.num_beams[bid]; beam_idx < beam_width; beam_idx++) {
const int bid = blockIdx.x;
const int tgt_start_idx = beam_hyps.num_beams[bid];
if (beam_hyps.is_done[bid]) {
return;
}
for (int i = 0; i < beam_width; i++) {
if (threadIdx.x == 0) {
int bbid = bid * beam_width + unfinished_idx;
const int src_beam_idx = bid * beam_width + i;
const int tgt_beam_idx = bid * beam_width * 2 + i + tgt_start_idx;

const int length = beam_hyps.sequence_lengths_src[bbid];
int prev_id = beam_hyps.parent_ids_src[length * batch_size * beam_width + bbid];
const int length = beam_hyps.sequence_lengths_src[src_beam_idx];

beam_hyps.output_ids_tgt[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + length] =
beam_hyps.output_ids_src[length * batch_size * beam_width + bid * beam_width + src_beam_idx];
if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr) {
beam_hyps.log_probs[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + length] =
beam_hyps.log_probs_src[length * batch_size * beam_width + bid * beam_width + src_beam_idx];
}
int prev_id = beam_hyps.parent_ids_src[length * batch_size * beam_width + src_beam_idx];
// printf("[INFO] i = %d, cum_log_probs: %f \n", i, cum_log_probs[src_beam_idx]);
for (int j = length - 1; j >= 0; j--) {
beam_hyps.output_ids_tgt[(bid * beam_width + beam_idx) * (beam_hyps.max_seq_len - 1) + j] =
// output_ids_tgt need to use max_seq_len + 1 because its shape is
// [bs, beam_width, max_seq_len + 1]
beam_hyps.output_ids_tgt[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + j] =
beam_hyps.output_ids_src[j * batch_size * beam_width + bid * beam_width + prev_id];
if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr) {
beam_hyps.log_probs[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + j] =
beam_hyps.log_probs_src[j * batch_size * beam_width + bid * beam_width + prev_id];
}
prev_id = beam_hyps.parent_ids_src[j * batch_size * beam_width + bid * beam_width + prev_id];
}
beam_hyps.sequence_lengths_tgt[bid * beam_width + beam_idx] = length;
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = length;

beam_hyps.num_beams[bid]++;
beam_hyps.normed_scores[tgt_beam_idx] = apply_length_penalty(
cum_log_probs[src_beam_idx], finished[src_beam_idx] ? length + 1 : length, beam_hyps.length_penalty);
beam_hyps.cum_log_probs[tgt_beam_idx] = cum_log_probs[src_beam_idx];

beam_hyps.normed_scores[bid * beam_width + beam_idx] =
apply_length_penalty(cum_log_probs[bbid], length, beam_hyps.length_penalty);
beam_hyps.num_beams[bid]++;
}
unfinished_idx++;
}
}

Expand Down
15 changes: 11 additions & 4 deletions src/fastertransformer/kernels/beam_search_topk_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,19 @@ namespace fastertransformer {
struct BeamHypotheses {
int* output_ids_tgt = nullptr;
int* sequence_lengths_tgt = nullptr;
float* cum_log_probs = nullptr; // cum_log
float* normed_scores = nullptr; // cum_log / (length**length_penalty)
float* log_probs = nullptr; // log probs of each generated token
float* min_normed_scores = nullptr; // record the min normed scores for each batch
int* num_beams = nullptr; // the number of finished beams we collect
bool* is_done = nullptr;

// Used to set inputs
const int* output_ids_src;
const int* parent_ids_src;
const int* sequence_lengths_src;
const int* end_ids;
const int* output_ids_src;
const int* parent_ids_src;
const int* sequence_lengths_src;
const int* end_ids;
const float* log_probs_src;

// some variables for kernels
int step;
Expand All @@ -48,6 +52,9 @@ struct BeamHypotheses {
int local_batch_size;
int max_seq_len;
float length_penalty;

bool early_stopping = true;
bool is_return_normed_score = true; // return normed_cum_log_probs or cum_log_probs
};

template<typename T>
Expand Down
66 changes: 50 additions & 16 deletions src/fastertransformer/kernels/decoding_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -697,17 +697,27 @@ template void invokePlusScalar(int* buf, const int val, const int size, cudaStre

__global__ void finalize(int* output_ids,
int* sequence_lengths,
float* cum_log_probs,
float* output_log_probs,
const int* topk_output_ids,
const int* topk_sequence_lengths,
const float* scores,
const float* topk_cum_log_probs,
const float* topk_log_probs,
const int* num_beams,
const int beam_width,
const int max_seq_len)
{
// output_ids: [bs, beam_width, max_seq_len]
// sequence_lengths: [bs, beam_width]
// topk_output_ids: [bs, beam_width, max_seq_len + 1]
// topk_sequence_lengths: [bs, beam_width]
// scores: [bs, beam_width]
// cum_log_probs: [bs, beam_width]
// output_log_probs: [bs, beam_width, max_seq_len]
// topk_output_ids: [bs, 2 * beam_width, max_seq_len + 1]
// topk_sequence_lengths: [bs, 2 * beam_width]
// scores: [bs, 2 * beam_width]
// topk_cum_log_probs: [bs, 2 * beam_width]
// topk_log_probs: [bs, 2 * beam_width, max_seq_len + 1]
// num_beams: [bs]

// This kernel do a sorting for scores first, and then put the topk_output_ids
// into output_ids by the rank of scores.
Expand All @@ -716,19 +726,17 @@ __global__ void finalize(int* output_ids,
extern __shared__ char array[];
int* rank = (int*)(array);
float* s_scores = (float*)(rank + beam_width);
__shared__ float s_max_score;

if (threadIdx.x < beam_width) {
s_scores[threadIdx.x] = scores[blockIdx.x * beam_width + threadIdx.x];
if (threadIdx.x < num_beams[blockIdx.x]) {
s_scores[threadIdx.x] = scores[blockIdx.x * beam_width * 2 + threadIdx.x];
}
__syncthreads();

for (int i = 0; i < beam_width; i++) {
float score = threadIdx.x < beam_width ? s_scores[threadIdx.x] : -FLT_MAX;
float score = threadIdx.x < num_beams[blockIdx.x] ? s_scores[threadIdx.x] : -FLT_MAX;
float max_score = blockReduceMax<float>(score);

if (threadIdx.x == 0) {
for (int j = 0; j < beam_width; j++) {
for (int j = 0; j < beam_width * 2; j++) {
if (s_scores[j] == max_score) {
rank[i] = j;
s_scores[j] = -FLT_MAX;
Expand All @@ -741,32 +749,58 @@ __global__ void finalize(int* output_ids,

if (threadIdx.x < beam_width) {
sequence_lengths[blockIdx.x * beam_width + threadIdx.x] =
topk_sequence_lengths[blockIdx.x * beam_width + rank[threadIdx.x]];
topk_sequence_lengths[blockIdx.x * beam_width * 2 + rank[threadIdx.x]];
if (cum_log_probs != nullptr) {
cum_log_probs[blockIdx.x * beam_width + threadIdx.x] =
topk_cum_log_probs[blockIdx.x * beam_width * 2 + rank[threadIdx.x]];
}
}
for (int beam_idx = 0; beam_idx < beam_width; beam_idx++) {
// start from step 1 to skip the start token
for (int i = threadIdx.x + 1; i < sequence_lengths[blockIdx.x * beam_width + beam_idx]; i += blockDim.x) {
output_ids[blockIdx.x * beam_width * max_seq_len + beam_idx * max_seq_len + (i - 1)] =
topk_output_ids[blockIdx.x * beam_width * (max_seq_len + 1) + rank[beam_idx] * (max_seq_len + 1) + i];
for (int i = threadIdx.x; i < sequence_lengths[blockIdx.x * beam_width + beam_idx]; i += blockDim.x) {
output_ids[blockIdx.x * beam_width * max_seq_len + beam_idx * max_seq_len + i] =
topk_output_ids[blockIdx.x * (beam_width * 2) * (max_seq_len + 1) + rank[beam_idx] * (max_seq_len + 1)
+ (i + 1)];
if (output_log_probs != nullptr) {
output_log_probs[blockIdx.x * beam_width * max_seq_len + beam_idx * max_seq_len + i] =
topk_log_probs[blockIdx.x * (beam_width * 2) * (max_seq_len + 1) + rank[beam_idx] * (max_seq_len + 1)
+ (i + 1)];
}
}
}
}

void invokeFinalize(int* output_ids,
int* sequence_lengths,
float* cum_log_probs,
float* output_log_probs,
const int* topk_output_ids,
const int* topk_sequence_lengths,
const float* scores,
const float* topk_cum_log_probs,
const float* topk_log_probs,
const int* num_beams,
const int beam_width,
const int max_seq_len,
const int batch_size,
cudaStream_t stream)
{
dim3 block(beam_width);
dim3 block(beam_width * 2);
block.x = (block.x + 31) / 32 * 32;
FT_CHECK(block.x < 1024);
finalize<<<batch_size, block, beam_width * sizeof(int) + beam_width * sizeof(float), stream>>>(
output_ids, sequence_lengths, topk_output_ids, topk_sequence_lengths, scores, beam_width, max_seq_len);
finalize<<<batch_size, block, beam_width * sizeof(int) + (beam_width * 2) * sizeof(float), stream>>>(
output_ids,
sequence_lengths,
cum_log_probs,
output_log_probs,
topk_output_ids,
topk_sequence_lengths,
scores,
topk_cum_log_probs,
topk_log_probs,
num_beams,
beam_width,
max_seq_len);
}

} // namespace fastertransformer
5 changes: 5 additions & 0 deletions src/fastertransformer/kernels/decoding_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,14 @@ void invokePlusScalar(T* buf, const T val, const int size, cudaStream_t stream);

void invokeFinalize(int* output_ids,
int* sequence_lengths,
float* cum_log_probs,
float* output_log_probs,
const int* topk_output_ids,
const int* topk_sequence_lengths,
const float* scores,
const float* topk_cum_log_probs,
const float* topk_log_probs,
const int* num_beams,
const int beam_width,
const int max_seq_len,
const int batch_size,
Expand Down
Loading

0 comments on commit f0b5b86

Please sign in to comment.