From f0b5b8631806aedfbe0d844eb9a32202002dd463 Mon Sep 17 00:00:00 2001
From: byshiue
Date: Sun, 1 Jan 2023 14:55:12 +0800
Subject: [PATCH] fix: fix bug of t5 beam search (#410)

1. fix bugs of cum_log_probs and output_log_probs of t5 beam search
2. fix bug of beam search for large beam width
3. fix bug of misalignment of beam search penalty kernel
---
 .../kernels/beam_search_penalty_kernels.cu    |  9 +-
 .../kernels/beam_search_topk_kernels.cu       | 44 ++++++---
 .../kernels/beam_search_topk_kernels.h        | 15 +++-
 .../kernels/decoding_kernels.cu               | 66 ++++++++++----
 .../kernels/decoding_kernels.h                |  5 ++
 .../online_softmax_beamsearch_kernels.cu      | 89 +++++++++++--------
 .../layers/DynamicDecodeLayer.cc              | 12 +--
 .../beam_search_layers/BaseBeamSearchLayer.cu |  3 +-
 .../beam_search_layers/BeamSearchLayer.cu     |  6 +-
 .../OnlineBeamSearchLayer.cu                  | 16 ++--
 .../models/bart/BartDecoding.cc               | 52 ++++++++---
 src/fastertransformer/models/t5/T5Decoding.cc | 64 ++++++++-----
 src/fastertransformer/th_op/t5/T5DecodingOp.h |  4 +-
 13 files changed, 257 insertions(+), 128 deletions(-)

diff --git a/src/fastertransformer/kernels/beam_search_penalty_kernels.cu b/src/fastertransformer/kernels/beam_search_penalty_kernels.cu
index 4cf8727b2..2e29178f9 100644
--- a/src/fastertransformer/kernels/beam_search_penalty_kernels.cu
+++ b/src/fastertransformer/kernels/beam_search_penalty_kernels.cu
@@ -108,9 +108,10 @@ __global__ void apply_repetition_penalty(T* logits,
     logits += bbid * vocab_size_padded;

     extern __shared__ char sbuf[];
-    T*        penalty_logits  = reinterpret_cast<T*>(sbuf);
-    int*      penalty_indices = reinterpret_cast<int*>(penalty_logits + step);
-    const int input_length    = (input_lengths != nullptr) ? input_lengths[bbid] : max_input_length;
+    T* penalty_logits = reinterpret_cast<T*>(sbuf);
+    // prevent misalignment when sizeof(T) = 2
+    int*      penalty_indices = reinterpret_cast<int*>(sbuf + (sizeof(T) * step + 31) / 32 * 32);
+    const int input_length    = (input_lengths != nullptr) ? input_lengths[bbid] : max_input_length;
     if (tid == 0) {
         T   repet_penalty = static_cast<T>(repetition_penalty);
         int prev_id       = current_ids[bbid];
@@ -181,7 +182,7 @@ void invokeAddBiasApplyPenalties(int step,
     }

     if (repetition_penalty != 1.0f) {
-        size_t smem_size = (sizeof(T) + sizeof(int)) * step;
+        size_t smem_size = (sizeof(T) * step + 31) / 32 * 32 + sizeof(int) * step;
         dim3   block(256);
         dim3   grid(beam_width * local_batch_size);
         apply_repetition_penalty<<<grid, block, smem_size, stream>>>(
diff --git a/src/fastertransformer/kernels/beam_search_topk_kernels.cu b/src/fastertransformer/kernels/beam_search_topk_kernels.cu
index b6fe9416f..a5fc4613e 100644
--- a/src/fastertransformer/kernels/beam_search_topk_kernels.cu
+++ b/src/fastertransformer/kernels/beam_search_topk_kernels.cu
@@ -598,6 +598,8 @@ void invokeTopkBeamSearch(void* workspace,
     FT_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
     // log_probs: (batch, beam, vocab) cumulative log_probs of beams ending with a token.
     const int vocab_size = vocab_size_padded_;
+    // Beam size should be less than or equal to vocab size.
+    assert(beam_width <= vocab_size);
     // Beam search needs the sequence lengths of beams to apply length penalty.
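The alignment fix above matters when T is half precision: step half-precision logits occupy 2 * step bytes, so an int array placed immediately after them can start on a 2-byte boundary and break 4-byte int accesses. Rounding the int array's offset up to a 32-byte boundary, and requesting that same rounded amount of dynamic shared memory at launch, avoids this. A minimal host-side sketch of the layout arithmetic, with illustrative names:

    #include <cstddef>

    // Mirror of the kernel's shared-memory layout: [step x T | padding | step x int].
    // The (x + 31) / 32 * 32 rounding matches the offset used for penalty_indices above.
    template<typename T>
    size_t penalty_smem_bytes(int step)
    {
        const size_t logits_bytes = sizeof(T) * step;               // penalty_logits
        const size_t indices_off  = (logits_bytes + 31) / 32 * 32;  // aligned start of penalty_indices
        return indices_off + sizeof(int) * step;                    // dynamic smem to request at launch
    }

    // Example: step = 5 with T = half gives 10 bytes of logits, indices starting at byte 32,
    // and 52 bytes in total, whereas the old (sizeof(T) + sizeof(int)) * step requested only 30 bytes.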
assert(length_penalty == 0.0f || sequence_lengths != nullptr); const int max_block_per_beam = 8; @@ -789,27 +791,45 @@ __global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const int batch_size, const int beam_width) { - const int bid = blockIdx.x; - int unfinished_idx = 0; - for (int beam_idx = beam_hyps.num_beams[bid]; beam_idx < beam_width; beam_idx++) { + const int bid = blockIdx.x; + const int tgt_start_idx = beam_hyps.num_beams[bid]; + if (beam_hyps.is_done[bid]) { + return; + } + for (int i = 0; i < beam_width; i++) { if (threadIdx.x == 0) { - int bbid = bid * beam_width + unfinished_idx; + const int src_beam_idx = bid * beam_width + i; + const int tgt_beam_idx = bid * beam_width * 2 + i + tgt_start_idx; - const int length = beam_hyps.sequence_lengths_src[bbid]; - int prev_id = beam_hyps.parent_ids_src[length * batch_size * beam_width + bbid]; + const int length = beam_hyps.sequence_lengths_src[src_beam_idx]; + + beam_hyps.output_ids_tgt[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + length] = + beam_hyps.output_ids_src[length * batch_size * beam_width + bid * beam_width + src_beam_idx]; + if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr) { + beam_hyps.log_probs[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + length] = + beam_hyps.log_probs_src[length * batch_size * beam_width + bid * beam_width + src_beam_idx]; + } + int prev_id = beam_hyps.parent_ids_src[length * batch_size * beam_width + src_beam_idx]; + // printf("[INFO] i = %d, cum_log_probs: %f \n", i, cum_log_probs[src_beam_idx]); for (int j = length - 1; j >= 0; j--) { - beam_hyps.output_ids_tgt[(bid * beam_width + beam_idx) * (beam_hyps.max_seq_len - 1) + j] = + // output_ids_tgt need to use max_seq_len + 1 because its shape is + // [bs, beam_width, max_seq_len + 1] + beam_hyps.output_ids_tgt[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + j] = beam_hyps.output_ids_src[j * batch_size * beam_width + bid * beam_width + prev_id]; + if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr) { + beam_hyps.log_probs[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + j] = + beam_hyps.log_probs_src[j * batch_size * beam_width + bid * beam_width + prev_id]; + } prev_id = beam_hyps.parent_ids_src[j * batch_size * beam_width + bid * beam_width + prev_id]; } - beam_hyps.sequence_lengths_tgt[bid * beam_width + beam_idx] = length; + beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = length; - beam_hyps.num_beams[bid]++; + beam_hyps.normed_scores[tgt_beam_idx] = apply_length_penalty( + cum_log_probs[src_beam_idx], finished[src_beam_idx] ? 
length + 1 : length, beam_hyps.length_penalty); + beam_hyps.cum_log_probs[tgt_beam_idx] = cum_log_probs[src_beam_idx]; - beam_hyps.normed_scores[bid * beam_width + beam_idx] = - apply_length_penalty(cum_log_probs[bbid], length, beam_hyps.length_penalty); + beam_hyps.num_beams[bid]++; } - unfinished_idx++; } } diff --git a/src/fastertransformer/kernels/beam_search_topk_kernels.h b/src/fastertransformer/kernels/beam_search_topk_kernels.h index 80e2180fa..209d77cd1 100644 --- a/src/fastertransformer/kernels/beam_search_topk_kernels.h +++ b/src/fastertransformer/kernels/beam_search_topk_kernels.h @@ -31,15 +31,19 @@ namespace fastertransformer { struct BeamHypotheses { int* output_ids_tgt = nullptr; int* sequence_lengths_tgt = nullptr; + float* cum_log_probs = nullptr; // cum_log float* normed_scores = nullptr; // cum_log / (length**length_penalty) + float* log_probs = nullptr; // log probs of each generated token float* min_normed_scores = nullptr; // record the min normed scores for each batch int* num_beams = nullptr; // the number of finished beams we collect + bool* is_done = nullptr; // Used to set inputs - const int* output_ids_src; - const int* parent_ids_src; - const int* sequence_lengths_src; - const int* end_ids; + const int* output_ids_src; + const int* parent_ids_src; + const int* sequence_lengths_src; + const int* end_ids; + const float* log_probs_src; // some variables for kernels int step; @@ -48,6 +52,9 @@ struct BeamHypotheses { int local_batch_size; int max_seq_len; float length_penalty; + + bool early_stopping = true; + bool is_return_normed_score = true; // return normed_cum_log_probs or cum_log_probs }; template diff --git a/src/fastertransformer/kernels/decoding_kernels.cu b/src/fastertransformer/kernels/decoding_kernels.cu index bed253672..934dcf651 100644 --- a/src/fastertransformer/kernels/decoding_kernels.cu +++ b/src/fastertransformer/kernels/decoding_kernels.cu @@ -697,17 +697,27 @@ template void invokePlusScalar(int* buf, const int val, const int size, cudaStre __global__ void finalize(int* output_ids, int* sequence_lengths, + float* cum_log_probs, + float* output_log_probs, const int* topk_output_ids, const int* topk_sequence_lengths, const float* scores, + const float* topk_cum_log_probs, + const float* topk_log_probs, + const int* num_beams, const int beam_width, const int max_seq_len) { // output_ids: [bs, beam_width, max_seq_len] // sequence_lengths: [bs, beam_width] - // topk_output_ids: [bs, beam_width, max_seq_len + 1] - // topk_sequence_lengths: [bs, beam_width] - // scores: [bs, beam_width] + // cum_log_probs: [bs, beam_width] + // output_log_probs: [bs, beam_width, max_seq_len] + // topk_output_ids: [bs, 2 * beam_width, max_seq_len + 1] + // topk_sequence_lengths: [bs, 2 * beam_width] + // scores: [bs, 2 * beam_width] + // topk_cum_log_probs: [bs, 2 * beam_width] + // topk_log_probs: [bs, 2 * beam_width, max_seq_len + 1] + // num_beams: [bs] // This kernel do a sorting for scores first, and then put the topk_output_ids // into output_ids by the rank of scores. 
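As the comment above describes, finalize() repeatedly extracts the highest normed score among the up-to-2 * beam_width collected hypotheses and records its index in rank[], then gathers the selected candidates into the output buffers. A host-side sketch of that selection step, with illustrative names:

    #include <algorithm>
    #include <numeric>
    #include <vector>

    // Rank the collected hypotheses of one batch entry by normed score and keep the
    // best beam_width of them; rank[i] plays the same role as the kernel's rank[] array.
    std::vector<int> rank_candidates(const std::vector<float>& scores,  // [num_beams] normed scores
                                     int num_beams,                     // candidates collected (<= 2 * beam_width)
                                     int beam_width)
    {
        std::vector<int> rank(num_beams);
        std::iota(rank.begin(), rank.end(), 0);
        std::stable_sort(rank.begin(), rank.end(),
                         [&](int a, int b) { return scores[a] > scores[b]; });
        rank.resize(std::min(beam_width, num_beams));
        return rank;
    }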
@@ -716,19 +726,17 @@ __global__ void finalize(int* output_ids, extern __shared__ char array[]; int* rank = (int*)(array); float* s_scores = (float*)(rank + beam_width); - __shared__ float s_max_score; - - if (threadIdx.x < beam_width) { - s_scores[threadIdx.x] = scores[blockIdx.x * beam_width + threadIdx.x]; + if (threadIdx.x < num_beams[blockIdx.x]) { + s_scores[threadIdx.x] = scores[blockIdx.x * beam_width * 2 + threadIdx.x]; } __syncthreads(); for (int i = 0; i < beam_width; i++) { - float score = threadIdx.x < beam_width ? s_scores[threadIdx.x] : -FLT_MAX; + float score = threadIdx.x < num_beams[blockIdx.x] ? s_scores[threadIdx.x] : -FLT_MAX; float max_score = blockReduceMax(score); if (threadIdx.x == 0) { - for (int j = 0; j < beam_width; j++) { + for (int j = 0; j < beam_width * 2; j++) { if (s_scores[j] == max_score) { rank[i] = j; s_scores[j] = -FLT_MAX; @@ -741,32 +749,58 @@ __global__ void finalize(int* output_ids, if (threadIdx.x < beam_width) { sequence_lengths[blockIdx.x * beam_width + threadIdx.x] = - topk_sequence_lengths[blockIdx.x * beam_width + rank[threadIdx.x]]; + topk_sequence_lengths[blockIdx.x * beam_width * 2 + rank[threadIdx.x]]; + if (cum_log_probs != nullptr) { + cum_log_probs[blockIdx.x * beam_width + threadIdx.x] = + topk_cum_log_probs[blockIdx.x * beam_width * 2 + rank[threadIdx.x]]; + } } for (int beam_idx = 0; beam_idx < beam_width; beam_idx++) { // start from step 1 to skip the start token - for (int i = threadIdx.x + 1; i < sequence_lengths[blockIdx.x * beam_width + beam_idx]; i += blockDim.x) { - output_ids[blockIdx.x * beam_width * max_seq_len + beam_idx * max_seq_len + (i - 1)] = - topk_output_ids[blockIdx.x * beam_width * (max_seq_len + 1) + rank[beam_idx] * (max_seq_len + 1) + i]; + for (int i = threadIdx.x; i < sequence_lengths[blockIdx.x * beam_width + beam_idx]; i += blockDim.x) { + output_ids[blockIdx.x * beam_width * max_seq_len + beam_idx * max_seq_len + i] = + topk_output_ids[blockIdx.x * (beam_width * 2) * (max_seq_len + 1) + rank[beam_idx] * (max_seq_len + 1) + + (i + 1)]; + if (output_log_probs != nullptr) { + output_log_probs[blockIdx.x * beam_width * max_seq_len + beam_idx * max_seq_len + i] = + topk_log_probs[blockIdx.x * (beam_width * 2) * (max_seq_len + 1) + rank[beam_idx] * (max_seq_len + 1) + + (i + 1)]; + } } } } void invokeFinalize(int* output_ids, int* sequence_lengths, + float* cum_log_probs, + float* output_log_probs, const int* topk_output_ids, const int* topk_sequence_lengths, const float* scores, + const float* topk_cum_log_probs, + const float* topk_log_probs, + const int* num_beams, const int beam_width, const int max_seq_len, const int batch_size, cudaStream_t stream) { - dim3 block(beam_width); + dim3 block(beam_width * 2); block.x = (block.x + 31) / 32 * 32; FT_CHECK(block.x < 1024); - finalize<<>>( - output_ids, sequence_lengths, topk_output_ids, topk_sequence_lengths, scores, beam_width, max_seq_len); + finalize<<>>( + output_ids, + sequence_lengths, + cum_log_probs, + output_log_probs, + topk_output_ids, + topk_sequence_lengths, + scores, + topk_cum_log_probs, + topk_log_probs, + num_beams, + beam_width, + max_seq_len); } } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/decoding_kernels.h b/src/fastertransformer/kernels/decoding_kernels.h index 9c71e2eaf..f40f9d7b1 100644 --- a/src/fastertransformer/kernels/decoding_kernels.h +++ b/src/fastertransformer/kernels/decoding_kernels.h @@ -154,9 +154,14 @@ void invokePlusScalar(T* buf, const T val, const int size, cudaStream_t stream); void 
invokeFinalize(int* output_ids, int* sequence_lengths, + float* cum_log_probs, + float* output_log_probs, const int* topk_output_ids, const int* topk_sequence_lengths, const float* scores, + const float* topk_cum_log_probs, + const float* topk_log_probs, + const int* num_beams, const int beam_width, const int max_seq_len, const int batch_size, diff --git a/src/fastertransformer/kernels/online_softmax_beamsearch_kernels.cu b/src/fastertransformer/kernels/online_softmax_beamsearch_kernels.cu index c599bff11..78c6f4a15 100644 --- a/src/fastertransformer/kernels/online_softmax_beamsearch_kernels.cu +++ b/src/fastertransformer/kernels/online_softmax_beamsearch_kernels.cu @@ -122,11 +122,13 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(const int* __shared__ typename BlockReduce::TempStorage temp_storage; __shared__ int selected_beams; - __shared__ bool is_stop; + __shared__ float old_cum_log_probs[MAX_K]; if (thread_id == 0) { selected_beams = 0; - is_stop = false; + } + if (thread_id < K) { + old_cum_log_probs[thread_id] = v[vector_id * K + thread_id]; } __syncthreads(); if (beam_hyps.num_beams != nullptr) { @@ -182,32 +184,36 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(const int* if (normed_score < beam_hyps.min_normed_scores[global_batch_idx]) { // end the tracing and exist this for loop selected_beams = K; - is_stop = true; break; } else { // find the beam index which's score = min_normed_score, erase it. for (int j = 0; j < K; j++) { - if (beam_hyps.normed_scores[global_batch_idx * K + j] + if (beam_hyps.normed_scores[global_batch_idx * (K * 2) + j] == beam_hyps.min_normed_scores[global_batch_idx]) { beam_idx = j; beam_hyps.num_beams[global_batch_idx]--; - beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; - beam_hyps.normed_scores[global_batch_idx * K + j] = normed_score; + beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; + beam_hyps.normed_scores[global_batch_idx * (K * 2) + j] = normed_score; for (int l = 0; l < K; l++) { beam_hyps.min_normed_scores[global_batch_idx] = min(beam_hyps.min_normed_scores[global_batch_idx], - beam_hyps.normed_scores[global_batch_idx * K + l]); + beam_hyps.normed_scores[global_batch_idx * (K * 2) + l]); } break; } } } } - const int tgt_id_offset = ((vector_id + beam_hyps.ite * beam_hyps.local_batch_size) * K + beam_idx) - * (beam_hyps.max_seq_len); + const int tgt_id_offset = + ((vector_id + beam_hyps.ite * beam_hyps.local_batch_size) * (K * 2) + beam_idx) + * (beam_hyps.max_seq_len); beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = beam_hyps.end_ids[vector_id]; + if (beam_hyps.log_probs != nullptr) { + beam_hyps.log_probs[tgt_id_offset + beam_hyps.step] = + (float)y[total.p[i]] - old_cum_log_probs[(x[total.p[i]] / vocab_size) % K]; + } int prev_id = (x[total.p[i]] / vocab_size) % K; for (int j = beam_hyps.step - 1; j >= 0; j--) { @@ -215,21 +221,26 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(const int* + beam_hyps.ite * beam_hyps.local_batch_size * K + vector_id * K + prev_id; beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx]; - prev_id = beam_hyps.parent_ids_src[src_idx]; + if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr) { + beam_hyps.log_probs[tgt_id_offset + j] = beam_hyps.log_probs_src[src_idx]; + } + prev_id = beam_hyps.parent_ids_src[src_idx]; } - const int tgt_beam_idx = global_batch_idx * K + beam_idx; + const int tgt_beam_idx = global_batch_idx * (K * 2) + beam_idx; 
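The bookkeeping in batch_topk_kernel hinges on how a flat candidate id encodes its parent beam, and on recovering a per-token log probability from cumulative scores. A small sketch of both, under the id layout used above (id = beam * vocab_size + token):

    // A candidate id produced by the top-k stage encodes (beam, token) as
    // id = beam * vocab_size + token, so the parent beam is (id / vocab_size) % K.
    inline int parent_beam(int candidate_id, int vocab_size, int K)
    {
        return (candidate_id / vocab_size) % K;
    }

    // The log probability of the newly appended token is the candidate's cumulative
    // log-prob minus its parent's cumulative log-prob before this step; old_cum_log_probs
    // caches the latter because v[] is overwritten as beams are selected.
    inline float step_log_prob(float candidate_cum_log_prob, float parent_prev_cum_log_prob)
    {
        return candidate_cum_log_prob - parent_prev_cum_log_prob;
    }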
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step; beam_hyps.normed_scores[tgt_beam_idx] = normed_score; beam_hyps.min_normed_scores[global_batch_idx] = min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]); beam_hyps.num_beams[global_batch_idx]++; + beam_hyps.cum_log_probs[tgt_beam_idx] = (float)y[total.p[i]]; } } - else if (i < 2 * K) { + else if ((beam_hyps.num_beams != nullptr && i < 2 * K) || (beam_hyps.num_beams == nullptr && i < K)) { z[selected_beams] = x[total.p[i]]; if (output_log_probs != nullptr) { - output_log_probs[vector_id * K + i] = (float)y[total.p[i]] - v[(z[i] / vocab_size) % K]; + output_log_probs[vector_id * K + selected_beams] = + (float)y[total.p[i]] - old_cum_log_probs[(z[selected_beams] / vocab_size) % K]; } v[selected_beams] = (float)y[total.p[i]]; selected_beams++; @@ -240,6 +251,14 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(const int* } } } + if (threadIdx.x == 0 && beam_hyps.num_beams != nullptr) { + if (beam_hyps.num_beams[blockIdx.x] < K) { + beam_hyps.is_done[blockIdx.x] = false; + } + else if (beam_hyps.early_stopping) { + beam_hyps.is_done[blockIdx.x] = true; + } + } } struct __align__(8) MD @@ -418,7 +437,7 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ #endif if (thread_id == 0) { - for (int i = 0; i < K; i++) { + for (int i = 0; i < 2 * K; i++) { reinterpret_cast(buf_s)[i] = total.topk.p[i] + vector_id * V; // faster transformer needs absolute id buf_s[MAX_K + i] = total.topk.u[i]; } @@ -426,9 +445,8 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ buf_s[2 * MAX_K + 1] = total.md.m; } __syncthreads(); - if (threadIdx.x < PACKED_TOP_KMD_SIZE) { - t[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + threadIdx.x] = - buf_s[threadIdx.x]; + for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE; elem_id += THREADBLOCK_SIZE) { + t[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] = buf_s[elem_id]; } } @@ -468,7 +486,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta if (threadIdx.x < parts_per_beam) { float* b_s = buf_s + thread_id * PACKED_TOP_KMD_SIZE; - for (int i = 0; i < K; i++) { + for (int i = 0; i < 2 * K; i++) { partial.topk.p[i] = reinterpret_cast(b_s)[i]; partial.topk.u[i] = b_s[MAX_K + i]; } @@ -480,14 +498,14 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta TopKMD total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_md_op); if (thread_id == 0) { - z += vector_id * K; - v += vector_id * K; + z += vector_id * 2 * K; + v += vector_id * 2 * K; c += vector_id; float d_total_log = logf(total.md.d); for (int i = 0; i < MAX_K; ++i) { float val = (float)total.topk.u[i] - total.md.m - d_total_log; - if (i < K) { + if (i < 2 * K) { z[i] = total.topk.p[i]; v[i] = (float)val + (float)c[0]; } @@ -552,11 +570,11 @@ void topK_softMax_kernelLauncher(const T* log_probs, // const int block_sz = SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE; assert(temp_storage_size % 2 == 0); - assert(temp_storage_size >= 2 * batch_size * beam_width * beam_width); + assert(temp_storage_size >= 2 * batch_size * beam_width * beam_width * 2); // Beam search needs the sequence lengths of beams to apply length penalty. assert(length_penalty == 0.0f || sequence_lengths != nullptr); - const int topk_buf_offset = ceil(batch_size * beam_width * beam_width / 4.) * 4; + const int topk_buf_offset = ceil(batch_size * beam_width * beam_width * 2 / 4.) 
* 4; int* topk_tmp_id_buf = reinterpret_cast(temp_storage); T* topk_tmp_val_buf = reinterpret_cast(topk_tmp_id_buf + topk_buf_offset); float* tmp_buffer = reinterpret_cast(topk_tmp_val_buf + topk_buf_offset); @@ -569,16 +587,18 @@ void topK_softMax_kernelLauncher(const T* log_probs, voc_parts = std::min(128, voc_parts); // we implement up to 128 } dim3 grid(batch_size * beam_width, voc_parts); - cudaFuncSetAttribute(beam_online_softmax_topk_stage1_kernel, + cudaFuncSetAttribute(beam_online_softmax_topk_stage1_kernel, cudaFuncAttributePreferredSharedMemoryCarveout, cudaSharedmemCarveoutMaxL1); - beam_online_softmax_topk_stage1_kernel + beam_online_softmax_topk_stage1_kernel <<>>(log_probs, bias, finished, tmp_buffer, vocab_size, beam_width, end_ids); + sync_check_cuda_error(); #endif if (beam_width > 1) { #ifdef DO_SPLIT_SMALL_TOP_K_SOFTMAX - beam_online_softmax_topk_stage2_kernelLauncher( + beam_online_softmax_topk_stage2_kernelLauncher( tmp_buffer, cum_log_probs, topk_tmp_id_buf, topk_tmp_val_buf, batch_size, beam_width, voc_parts, stream); + sync_check_cuda_error(); #else beam_online_softmax_topk_kernel <<>>(log_probs, @@ -606,11 +626,12 @@ void topK_softMax_kernelLauncher(const T* log_probs, finished, sequence_lengths, *beam_hyps, - beam_width * beam_width, + beam_width * beam_width * 2, beam_width, vocab_size, length_penalty, diversity_rate); + sync_check_cuda_error(); #endif } else { @@ -627,7 +648,7 @@ void topK_softMax_kernelLauncher(const T* log_probs, } #define CASE_K(K, MAX_K) \ - case K: \ + case K ... MAX_K: \ topK_softMax_kernelLauncher(log_probs, \ bias, \ finished, \ @@ -667,13 +688,11 @@ void invokeTopkSoftMax(const T* log_probs, cudaStream_t stream) { switch (beam_width) { - CASE_K(1, 1); - CASE_K(2, 2); - CASE_K(3, 3); - CASE_K(4, 4); - CASE_K(8, 8); - CASE_K(16, 16); - CASE_K(32, 32); + CASE_K(1, 4); + CASE_K(5, 8); + CASE_K(9, 16); + CASE_K(17, 32); + CASE_K(33, 64); default: throw std::runtime_error(fmtstr("Topk kernel of beam search does not support beam_width=%d", beam_width)); } diff --git a/src/fastertransformer/layers/DynamicDecodeLayer.cc b/src/fastertransformer/layers/DynamicDecodeLayer.cc index 39086b900..71799ce0d 100644 --- a/src/fastertransformer/layers/DynamicDecodeLayer.cc +++ b/src/fastertransformer/layers/DynamicDecodeLayer.cc @@ -343,14 +343,7 @@ void DynamicDecodeLayer::forward(TensorMap* output_tensors, TensorMap* input_ } if (output_tensors->isExist("output_log_probs")) { - size_t step_offset = - (step - input_tensors->at("max_input_length").getVal()) * batch_size * beam_width; - dynamic_decode_output_tensors.insert( - {"output_log_probs", - Tensor{MEMORY_GPU, - TYPE_FP32, - {dynamic_decode_batch_size * beam_width}, - output_tensors->at("output_log_probs").getPtrWithOffset(step_offset + dynamic_id_offset)}}); + dynamic_decode_output_tensors.insert({"output_log_probs", output_tensors->at("output_log_probs")}); } dynamic_decode_input_tensors.insert({"src_cache_indirection", input_tensors->at("src_cache_indirection")}); @@ -362,13 +355,14 @@ void DynamicDecodeLayer::forward(TensorMap* output_tensors, TensorMap* input_ FT_CHECK_WITH_INFO(dynamic_decode_output_tensors.isExist("cum_log_probs"), "cum_log_probs should be provided in beam search."); - if (beam_width < 16 + if (true || beam_width < 16 || (output_tensors->isExist("beam_hyps") && input_tensors->getVal("beam_search_diversity_rate", 0.0f) != 0.0f)) { // only online_beamsearch_decode_ support beam_search_diversity_rate when beam_hyps is used 
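Routing every beam-search request to online_beamsearch_decode_ above is possible because the CASE_K dispatch earlier in this patch now covers beam widths 1 through 64 via case ranges. A standalone illustration of that dispatch pattern, with illustrative names:

    #include <stdexcept>

    // Each case range shares one kernel instantiation padded to its upper bound (MAX_K),
    // mirroring CASE_K(1, 4) ... CASE_K(33, 64) above. Case ranges ("lo ... hi") are a
    // GNU extension accepted by GCC, Clang and NVCC.
    template<int MAX_K>
    void launch_topk_softmax_sketch(int beam_width) { (void)beam_width; /* stand-in for topK_softMax_kernelLauncher */ }

    void dispatch_topk_softmax(int beam_width)
    {
        switch (beam_width) {
            case 1 ... 4:   launch_topk_softmax_sketch<4>(beam_width); break;
            case 5 ... 8:   launch_topk_softmax_sketch<8>(beam_width); break;
            case 9 ... 16:  launch_topk_softmax_sketch<16>(beam_width); break;
            case 17 ... 32: launch_topk_softmax_sketch<32>(beam_width); break;
            case 33 ... 64: launch_topk_softmax_sketch<64>(beam_width); break;
            default:
                throw std::runtime_error("beam search does not support this beam_width");
        }
    }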
online_beamsearch_decode_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors); } else { + FT_CHECK(false); // deprecate this module beamsearch_decode_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors); } } // end of dynamic_ite diff --git a/src/fastertransformer/layers/beam_search_layers/BaseBeamSearchLayer.cu b/src/fastertransformer/layers/beam_search_layers/BaseBeamSearchLayer.cu index f6fc8111b..7a5b4fcc0 100644 --- a/src/fastertransformer/layers/beam_search_layers/BaseBeamSearchLayer.cu +++ b/src/fastertransformer/layers/beam_search_layers/BaseBeamSearchLayer.cu @@ -201,7 +201,7 @@ void BaseBeamSearchLayer::forward(TensorMap* output_tensors, TensorMap* input // parent_ids [max_seq_len, batch_size * beam_width] // sequence_length [local_batch_size * beam_width], optional // tgt_cache_indirection [local_batch_size, beam_width, max_seq_len] - // output_log_probs [local_batch_size * beam_width], optional + // output_log_probs [max_seq_len, batch_size, beam_width], optional // beam_hyps, optional FT_CHECK(input_tensors->size() >= 7); @@ -220,6 +220,7 @@ void BaseBeamSearchLayer::forward(TensorMap* output_tensors, TensorMap* input input_tensors->isExist("repetition_penalty") ? input_tensors->at("repetition_penalty").getVal() : 1.0f; const T* embedding_bias = input_tensors->isExist("embedding_bias") ? input_tensors->at("embedding_bias").getPtr() : nullptr; + invokeAddBiasApplyPenalties( step, input_tensors->at("logits").getPtr(), diff --git a/src/fastertransformer/layers/beam_search_layers/BeamSearchLayer.cu b/src/fastertransformer/layers/beam_search_layers/BeamSearchLayer.cu index 98c351cc8..f96e7ec3c 100644 --- a/src/fastertransformer/layers/beam_search_layers/BeamSearchLayer.cu +++ b/src/fastertransformer/layers/beam_search_layers/BeamSearchLayer.cu @@ -191,7 +191,7 @@ void BeamSearchLayer::invokeSoftMax(TensorMap* output_tensors, TensorMap* inp // parent_ids [max_seq_len, batch_size * beam_width] // sequence_length [local_batch_size * beam_width] // tgt_cache_indirection [local_batch_size, beam_width, max_seq_len] - // output_log_probs [local_batch_size * beam_width], optional + // output_log_probs [max_seq_len, batch_size * beam_width], optional // beam_hyps, optional FT_CHECK(input_tensors->size() >= 7); @@ -252,11 +252,9 @@ void BeamSearchLayer::invokeSoftMax(TensorMap* output_tensors, TensorMap* inp stream_); sync_check_cuda_error(); - float* output_log_probs = - output_tensors->isExist("output_log_probs") ? 
output_tensors->at("output_log_probs").getPtr() : nullptr; invokeUpdateStates(float_log_prob_buf_, output_tensors->at("cum_log_probs").getPtr(), - output_log_probs, + output_tensors->getPtrWithOffset("output_log_probs", id_offset, nullptr), output_tensors->at("finished").getPtr(), output_tensors->at("parent_ids").getPtrWithOffset(id_offset), output_tensors->at("sequence_length").getPtr(), diff --git a/src/fastertransformer/layers/beam_search_layers/OnlineBeamSearchLayer.cu b/src/fastertransformer/layers/beam_search_layers/OnlineBeamSearchLayer.cu index 523f9a12e..8ab8af74e 100644 --- a/src/fastertransformer/layers/beam_search_layers/OnlineBeamSearchLayer.cu +++ b/src/fastertransformer/layers/beam_search_layers/OnlineBeamSearchLayer.cu @@ -108,7 +108,7 @@ void OnlineBeamSearchLayer::invokeSoftMax(TensorMap* output_tensors, TensorMa // parent_ids [max_seq_len, batch_size * beam_width] // sequence_length [local_batch_size * beam_width] // tgt_cache_indirection [local_batch_size, beam_width, max_seq_len] - // output_log_probs [local_batch_size * beam_width] + // output_log_probs [max_seq_len, batch_size, beam_width] FT_CHECK(input_tensors->size() >= 7); FT_CHECK(output_tensors->size() >= 6); @@ -124,8 +124,6 @@ void OnlineBeamSearchLayer::invokeSoftMax(TensorMap* output_tensors, TensorMa const float length_penalty = input_tensors->isExist("len_penalty") ? input_tensors->at("len_penalty").getVal() : 0.0f; - float* output_log_probs = - output_tensors->isExist("output_log_probs") ? output_tensors->at("output_log_probs").getPtr() : nullptr; const int id_offset = step * batch_size * beam_width + local_batch_size * ite * beam_width; BeamHypotheses beam_hyps; @@ -139,6 +137,7 @@ void OnlineBeamSearchLayer::invokeSoftMax(TensorMap* output_tensors, TensorMa beam_hyps.output_ids_src = output_tensors->at("output_ids").getPtr(); beam_hyps.parent_ids_src = output_tensors->at("parent_ids").getPtr(); beam_hyps.sequence_lengths_src = output_tensors->at("sequence_length").getPtr(); + beam_hyps.log_probs_src = output_tensors->getPtr("output_log_probs", nullptr); beam_hyps.length_penalty = length_penalty; beam_hyps.end_ids = input_tensors->at("end_id").getPtr(); } @@ -148,7 +147,7 @@ void OnlineBeamSearchLayer::invokeSoftMax(TensorMap* output_tensors, TensorMa output_tensors->at("finished").getPtr(), output_tensors->at("sequence_length").getPtr(), output_tensors->at("cum_log_probs").getPtr(), - output_log_probs, + output_tensors->getPtrWithOffset("output_log_probs", id_offset, nullptr), output_tensors->at("output_ids").getPtrWithOffset(id_offset), topk_softmax_workspace_, topk_softmax_workspace_size_, @@ -186,12 +185,15 @@ template void OnlineBeamSearchLayer::allocateBuffer(size_t batch_size, size_t beam_width) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); + // we need to check 2 * beam_width candidates each time + // 64 is the max beam width we support now. topk_softmax_workspace_size_ = - (size_t)(ceil(batch_size * beam_width * beam_width / 4.) * 4 * 2 - + ceil(batch_size * beam_width * SMALL_TOP_K_SOFTMAX_MAX_VOC_PARTS * (2 * MAX_K + 2) / 4.) * 4); + (size_t)(ceil(batch_size * 64 * (64 * 2) / 4.) * 4 * 2 + + ceil(batch_size * (64 * 2) * SMALL_TOP_K_SOFTMAX_MAX_VOC_PARTS * (2 * (MAX_K * 2) + 2) / 4.) 
+ * 4); topk_softmax_workspace_ = reinterpret_cast( - allocator_->reMalloc(topk_softmax_workspace_, sizeof(float) * topk_softmax_workspace_size_, false)); + allocator_->reMalloc(topk_softmax_workspace_, sizeof(float) * topk_softmax_workspace_size_, true)); is_allocate_buffer_ = true; } diff --git a/src/fastertransformer/models/bart/BartDecoding.cc b/src/fastertransformer/models/bart/BartDecoding.cc index 24793a459..129b193d3 100644 --- a/src/fastertransformer/models/bart/BartDecoding.cc +++ b/src/fastertransformer/models/bart/BartDecoding.cc @@ -122,20 +122,29 @@ void BartDecoding::allocateBuffer( parent_ids_buf_ = (int*)(allocator_->reMalloc(parent_ids_buf_, sizeof(int) * batchxbeam * (max_seq_len + 1), false)); output_ids_transpose_buf_ = - (int*)(allocator_->reMalloc(output_ids_transpose_buf_, sizeof(int) * batchxbeam * max_seq_len, false)); + (int*)(allocator_->reMalloc(output_ids_transpose_buf_, sizeof(int) * batchxbeam * (max_seq_len + 1), false)); output_log_probs_buf_ = - (float*)(allocator_->reMalloc(output_log_probs_buf_, sizeof(float) * batchxbeam * max_seq_len, false)); + (float*)(allocator_->reMalloc(output_log_probs_buf_, sizeof(float) * batchxbeam * (max_seq_len + 1), false)); if (using_beam_hyps) { + // Let beam_hyps_ can record at most 2*beam_width because we + // may find beam_width finished candidates during generation, + // and may compare them with unfinifhsed another beam_width candidates + // during finalization. beam_hyps_.output_ids_tgt = (int*)allocator_->reMalloc( - beam_hyps_.output_ids_tgt, sizeof(int) * batch_size * beam_width * (max_seq_len + 1), true); - beam_hyps_.sequence_lengths_tgt = - (int*)allocator_->reMalloc(beam_hyps_.sequence_lengths_tgt, sizeof(int) * batch_size * beam_width, true); + beam_hyps_.output_ids_tgt, sizeof(int) * batch_size * beam_width * 2 * (max_seq_len + 1), true); + beam_hyps_.sequence_lengths_tgt = (int*)allocator_->reMalloc( + beam_hyps_.sequence_lengths_tgt, sizeof(int) * batch_size * beam_width * 2, true); + beam_hyps_.cum_log_probs = + (float*)allocator_->reMalloc(beam_hyps_.cum_log_probs, sizeof(float) * batch_size * beam_width * 2, true); beam_hyps_.normed_scores = - (float*)allocator_->reMalloc(beam_hyps_.normed_scores, sizeof(float) * batch_size * beam_width, true); + (float*)allocator_->reMalloc(beam_hyps_.normed_scores, sizeof(float) * batch_size * beam_width * 2, true); + beam_hyps_.log_probs = (float*)allocator_->reMalloc( + beam_hyps_.log_probs, sizeof(float) * batch_size * beam_width * 2 * (max_seq_len + 1), true); beam_hyps_.min_normed_scores = (float*)allocator_->reMalloc(beam_hyps_.min_normed_scores, sizeof(float) * batch_size, true); beam_hyps_.num_beams = (int*)allocator_->reMalloc(beam_hyps_.num_beams, sizeof(int) * batch_size, true); + beam_hyps_.is_done = (bool*)allocator_->reMalloc(beam_hyps_.is_done, sizeof(bool) * batch_size, true); } is_allocate_buffer_ = true; } @@ -183,9 +192,12 @@ void BartDecoding::freeBuffer() if (using_beam_hyps) { allocator_->free((void**)(&beam_hyps_.output_ids_tgt)); allocator_->free((void**)(&beam_hyps_.sequence_lengths_tgt)); + allocator_->free((void**)(&beam_hyps_.cum_log_probs)); allocator_->free((void**)(&beam_hyps_.normed_scores)); + allocator_->free((void**)(&beam_hyps_.log_probs)); allocator_->free((void**)(&beam_hyps_.min_normed_scores)); allocator_->free((void**)(&beam_hyps_.num_beams)); + allocator_->free((void**)(&beam_hyps_.is_done)); } is_allocate_buffer_ = false; } @@ -816,6 +828,7 @@ void BartDecoding::forward(TensorMap* output_tensors, 
beam_hyps_.sequence_lengths_src = sequence_lengths; beam_hyps_.parent_ids_src = parent_ids_buf_; beam_hyps_.output_ids_src = output_ids_buf_; + beam_hyps_.log_probs_src = output_log_probs_buf_; beam_hyps_.max_seq_len = max_seq_len; beam_hyps_.length_penalty = input_tensors->at("len_penalty").getVal(); @@ -824,9 +837,14 @@ void BartDecoding::forward(TensorMap* output_tensors, invokeFinalize(output_tensors->at("output_ids").getPtr(), output_tensors->at("sequence_length").getPtr(), + output_tensors->getPtr("cum_log_probs", nullptr), + output_tensors->getPtr("output_log_probs", nullptr), beam_hyps_.output_ids_tgt, beam_hyps_.sequence_lengths_tgt, beam_hyps_.normed_scores, + beam_hyps_.cum_log_probs, + beam_hyps_.log_probs, + beam_hyps_.num_beams, beam_width, max_seq_len, batch_size, @@ -872,12 +890,22 @@ void BartDecoding::forward(TensorMap* output_tensors, stream_); } - // Return the cumulative log probability if requested. - if (output_tensors->isExist("cum_log_probs")) { - Tensor cum_log_probs = output_tensors->at("cum_log_probs"); - FT_CHECK_WITH_INFO(cum_log_probs.size() == batch_size * beam_width, - "The shape of cum_log_probs does not match with batch_size x beam_width."); - cudaD2Dcpy(cum_log_probs.getPtr(), cum_log_probs_, batch_size * beam_width); + // Return the cumulative log probability and log probability if requested. + if (!using_beam_hyps) { + if (output_tensors->isExist("output_log_probs")) { + invokeTransposeAxis01(output_tensors->at("output_log_probs").getPtr(), + output_log_probs_buf_, + max_seq_len, + batch_size * beam_width, + 1, + stream_); + } + if (output_tensors->isExist("cum_log_probs")) { + Tensor cum_log_probs = output_tensors->at("cum_log_probs"); + FT_CHECK_WITH_INFO(cum_log_probs.size() == batch_size * beam_width, + "The shape of cum_log_probs does not match with batch_size x beam_width."); + cudaD2Dcpy(cum_log_probs.getPtr(), cum_log_probs_, batch_size * beam_width); + } } } diff --git a/src/fastertransformer/models/t5/T5Decoding.cc b/src/fastertransformer/models/t5/T5Decoding.cc index c68ee570b..e81ba998a 100644 --- a/src/fastertransformer/models/t5/T5Decoding.cc +++ b/src/fastertransformer/models/t5/T5Decoding.cc @@ -124,20 +124,29 @@ void T5Decoding::allocateBuffer( parent_ids_buf_ = (int*)(allocator_->reMalloc(parent_ids_buf_, sizeof(int) * batchxbeam * (max_seq_len + 1), false)); output_ids_transpose_buf_ = - (int*)(allocator_->reMalloc(output_ids_transpose_buf_, sizeof(int) * batchxbeam * max_seq_len, false)); + (int*)(allocator_->reMalloc(output_ids_transpose_buf_, sizeof(int) * batchxbeam * (max_seq_len + 1), false)); output_log_probs_buf_ = - (float*)(allocator_->reMalloc(output_log_probs_buf_, sizeof(float) * batchxbeam * max_seq_len, false)); + (float*)(allocator_->reMalloc(output_log_probs_buf_, sizeof(float) * batchxbeam * (max_seq_len + 1), false)); if (using_beam_hyps) { + // Let beam_hyps_ can record at most 2*beam_width because we + // may find beam_width finished candidates during generation, + // and may compare them with unfinifhsed another beam_width candidates + // during finalization. 
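The comment above (the same rationale as in BartDecoding) motivates sizing every per-hypothesis buffer for 2 * beam_width entries: beam_width finished candidates collected during generation plus beam_width unfinished candidates inserted at finalization. A sketch of the element counts implied by the allocations that follow, with illustrative names:

    #include <cstddef>

    // Element counts for the beam_hyps_ buffers when keeping up to 2 * beam_width
    // hypotheses per batch entry.
    struct BeamHypsSizes {
        size_t output_ids;     // batch * 2 * beam_width * (max_seq_len + 1) ints
        size_t log_probs;      // batch * 2 * beam_width * (max_seq_len + 1) floats
        size_t per_candidate;  // batch * 2 * beam_width (lengths, cum_log_probs, normed_scores)
        size_t per_batch;      // batch (min_normed_scores, num_beams, is_done)
    };

    BeamHypsSizes beam_hyps_sizes(size_t batch_size, size_t beam_width, size_t max_seq_len)
    {
        const size_t candidates = batch_size * 2 * beam_width;
        return {candidates * (max_seq_len + 1), candidates * (max_seq_len + 1), candidates, batch_size};
    }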
beam_hyps_.output_ids_tgt = (int*)allocator_->reMalloc( - beam_hyps_.output_ids_tgt, sizeof(int) * batch_size * beam_width * (max_seq_len + 1), true); - beam_hyps_.sequence_lengths_tgt = - (int*)allocator_->reMalloc(beam_hyps_.sequence_lengths_tgt, sizeof(int) * batch_size * beam_width, true); + beam_hyps_.output_ids_tgt, sizeof(int) * batch_size * beam_width * 2 * (max_seq_len + 1), true); + beam_hyps_.sequence_lengths_tgt = (int*)allocator_->reMalloc( + beam_hyps_.sequence_lengths_tgt, sizeof(int) * batch_size * beam_width * 2, true); + beam_hyps_.cum_log_probs = + (float*)allocator_->reMalloc(beam_hyps_.cum_log_probs, sizeof(float) * batch_size * beam_width * 2, true); beam_hyps_.normed_scores = - (float*)allocator_->reMalloc(beam_hyps_.normed_scores, sizeof(float) * batch_size * beam_width, true); + (float*)allocator_->reMalloc(beam_hyps_.normed_scores, sizeof(float) * batch_size * beam_width * 2, true); + beam_hyps_.log_probs = (float*)allocator_->reMalloc( + beam_hyps_.log_probs, sizeof(float) * batch_size * beam_width * 2 * (max_seq_len + 1), true); beam_hyps_.min_normed_scores = (float*)allocator_->reMalloc(beam_hyps_.min_normed_scores, sizeof(float) * batch_size, true); beam_hyps_.num_beams = (int*)allocator_->reMalloc(beam_hyps_.num_beams, sizeof(int) * batch_size, true); + beam_hyps_.is_done = (bool*)allocator_->reMalloc(beam_hyps_.is_done, sizeof(bool) * batch_size, true); } is_allocate_buffer_ = true; } @@ -185,9 +194,12 @@ void T5Decoding::freeBuffer() if (using_beam_hyps) { allocator_->free((void**)(&beam_hyps_.output_ids_tgt)); allocator_->free((void**)(&beam_hyps_.sequence_lengths_tgt)); + allocator_->free((void**)(&beam_hyps_.cum_log_probs)); allocator_->free((void**)(&beam_hyps_.normed_scores)); + allocator_->free((void**)(&beam_hyps_.log_probs)); allocator_->free((void**)(&beam_hyps_.min_normed_scores)); allocator_->free((void**)(&beam_hyps_.num_beams)); + allocator_->free((void**)(&beam_hyps_.is_done)); } is_allocate_buffer_ = false; } @@ -854,17 +866,23 @@ void T5Decoding::forward(TensorMap* output_tensors, beam_hyps_.sequence_lengths_src = sequence_lengths; beam_hyps_.parent_ids_src = parent_ids_buf_; beam_hyps_.output_ids_src = output_ids_buf_; + beam_hyps_.log_probs_src = output_log_probs_buf_; beam_hyps_.max_seq_len = max_seq_len; beam_hyps_.length_penalty = input_tensors->at("len_penalty").getVal(); invokeInsertUnfinishedPath(beam_hyps_, finished_buf_, cum_log_probs_, batch_size, beam_width, stream_); sync_check_cuda_error(); - invokeFinalize(output_tensors->at("output_ids").getPtr(), - output_tensors->at("sequence_length").getPtr(), + invokeFinalize(output_tensors->getPtr("output_ids"), + output_tensors->getPtr("sequence_length"), + output_tensors->getPtr("cum_log_probs", nullptr), + output_tensors->getPtr("output_log_probs", nullptr), beam_hyps_.output_ids_tgt, beam_hyps_.sequence_lengths_tgt, beam_hyps_.normed_scores, + beam_hyps_.cum_log_probs, + beam_hyps_.log_probs, + beam_hyps_.num_beams, beam_width, max_seq_len, batch_size, @@ -901,21 +919,23 @@ void T5Decoding::forward(TensorMap* output_tensors, 1, stream_); } - if (output_tensors->isExist("output_log_probs")) { - invokeTransposeAxis01(output_tensors->at("output_log_probs").getPtr(), - output_log_probs_buf_, - max_seq_len, - batch_size * beam_width, - 1, - stream_); - } - // Return the cumulative log probability if requested. 
- if (output_tensors->isExist("cum_log_probs")) { - Tensor cum_log_probs = output_tensors->at("cum_log_probs"); - FT_CHECK_WITH_INFO(cum_log_probs.size() == batch_size * beam_width, - "The shape of cum_log_probs does not match with batch_size x beam_width."); - cudaD2Dcpy(cum_log_probs.getPtr(), cum_log_probs_, batch_size * beam_width); + // Return the cumulative log probability and log probability if requested. + if (!using_beam_hyps) { + if (output_tensors->isExist("output_log_probs")) { + invokeTransposeAxis01(output_tensors->at("output_log_probs").getPtr(), + output_log_probs_buf_, + max_seq_len, + batch_size * beam_width, + 1, + stream_); + } + if (output_tensors->isExist("cum_log_probs")) { + Tensor cum_log_probs = output_tensors->at("cum_log_probs"); + FT_CHECK_WITH_INFO(cum_log_probs.size() == batch_size * beam_width, + "The shape of cum_log_probs does not match with batch_size x beam_width."); + cudaD2Dcpy(cum_log_probs.getPtr(), cum_log_probs_, batch_size * beam_width); + } } } diff --git a/src/fastertransformer/th_op/t5/T5DecodingOp.h b/src/fastertransformer/th_op/t5/T5DecodingOp.h index a58bcd9b6..8ada5bbde 100644 --- a/src/fastertransformer/th_op/t5/T5DecodingOp.h +++ b/src/fastertransformer/th_op/t5/T5DecodingOp.h @@ -366,8 +366,8 @@ class FTT5Decoding: public IFTT5Decoding { get_ptr(sequence_length)}}}); if (is_return_output_log_probs) { - auto output_log_probs = torch::empty({(long int)(batch_size * beam_width * max_seq_len)}, - torch::dtype(torch::kFloat).device(torch::kCUDA).requires_grad(false)); + auto output_log_probs = torch::empty({batch_size, beam_width, max_seq_len}, + torch::dtype(torch::kFloat).device(torch::kCUDA).requires_grad(false)); output_tensors.insert({"output_log_probs", ft::Tensor{ft::MEMORY_GPU, ft::TYPE_FP32,
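With the T5DecodingOp change above, output_log_probs is returned as a 3-D tensor rather than a flat buffer, so callers can index per batch entry, beam, and step without manual offset arithmetic. A sketch of allocating and reading it with the libtorch C++ API (sizes are illustrative):

    #include <torch/torch.h>

    // output_log_probs now has an explicit [batch_size, beam_width, max_seq_len] shape.
    torch::Tensor make_output_log_probs(int64_t batch_size, int64_t beam_width, int64_t max_seq_len)
    {
        return torch::empty({batch_size, beam_width, max_seq_len},
                            torch::dtype(torch::kFloat).device(torch::kCUDA).requires_grad(false));
    }

    // Example: per-step log-probs of the top-ranked beam for batch entry 0.
    // auto best_beam = output_log_probs.index({0, 0});  // shape [max_seq_len]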