From e8254ab30fdc6220b44e7bd47490f3ab08b1a08e Mon Sep 17 00:00:00 2001 From: zhyncs Date: Sat, 9 Mar 2024 22:13:58 +0800 Subject: [PATCH] feat: add invokeBatchTopKOnly --- .../kernels/sampling_topk_kernels.cu | 176 ++++++++++++++++++ src/turbomind/kernels/sampling_topk_kernels.h | 12 ++ 2 files changed, 188 insertions(+) diff --git a/src/turbomind/kernels/sampling_topk_kernels.cu b/src/turbomind/kernels/sampling_topk_kernels.cu index 82b208298..2ab7e1502 100644 --- a/src/turbomind/kernels/sampling_topk_kernels.cu +++ b/src/turbomind/kernels/sampling_topk_kernels.cu @@ -619,4 +619,180 @@ template void invokeTopKTopPSampling(void* workspace, const int* end_ids, cudaStream_t stream); +template +__global__ void topk_only(const T* __restrict log_probs, + T* tmp_log_probs, + int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + const bool* finished, + const int max_top_k, + const int* top_ks, + const int vocab_size, + const int* end_ids, + const bool* skip_decode) +{ + typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + const int tid = threadIdx.x; + const int bid = blockIdx.x; + + const int batch_id = bid; + if (skip_decode != nullptr && skip_decode[batch_id]) { + return; + } + const int k = (top_ks != nullptr) ? top_ks[batch_id] : max_top_k; + const int tmp_log_buf_index = batch_id * vocab_size; + const int tmp_topk_buf_index = batch_id * max_top_k; + + TopK_2 partial; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + + for (int elem_id = tid; elem_id < vocab_size; elem_id += BLOCK_SIZE_) { + int index = elem_id + tmp_log_buf_index; + tmp_log_probs[index] = log_probs[index]; + } + + for (int ite = 0; ite < k; ite++) { + partial.init(); +#pragma unroll + for (int elem_id = tid; elem_id < vocab_size; elem_id += BLOCK_SIZE_) { + int index = elem_id + tmp_log_buf_index; + partial.insert(tmp_log_probs[index], index); + } + TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); + + if (tid == 0) { + const int index = tmp_topk_buf_index + ite; + topk_tmp_id_buf[index] = total.p % vocab_size; + topk_tmp_val_buf[index] = total.u; + tmp_log_probs[total.p] = -MAX_T_VAL; + } + __syncthreads(); + } +} + +#ifdef _MSC_VER +#define ONLY_TOPK_CASE_K(K_MIN, K_MAX, BLOCK_SIZE_) \ + if (K_MIN <= max_top_k && max_top_k <= K_MAX) { \ + topk_only<<>>(log_probs, \ + temp_log_probs, \ + topk_tmp_id_buf, \ + topk_tmp_val_buf, \ + finished, \ + max_top_k, \ + top_ks, \ + vocab_size, \ + end_ids, \ + skip_decode); \ + break; \ + } +#else +#define ONLY_TOPK_CASE_K(K_MIN, K_MAX, BLOCK_SIZE_) \ + case K_MIN ... K_MAX: \ + topk_only<<>>(log_probs, \ + temp_log_probs, \ + topk_tmp_id_buf, \ + topk_tmp_val_buf, \ + finished, \ + max_top_k, \ + top_ks, \ + vocab_size, \ + end_ids, \ + skip_decode); \ + break; +#endif + +template +void invokeBatchTopKOnly(void* workspace, + size_t& workspace_size, + const T* log_probs, + bool* finished, + const int max_top_k, + const int* top_ks, + const int vocab_size_padded, + const int* end_ids, + cudaStream_t stream, + const int batch_size, + const bool* skip_decode) +{ + + const int vocab_size = vocab_size_padded; + int temp_log_probs_buf_size = batch_size * vocab_size; + int topk_tmp_ids_buf_size = batch_size * max_top_k; + int topk_tmp_val_buf_size = batch_size * max_top_k; + + temp_log_probs_buf_size = (int)(ceil(temp_log_probs_buf_size / 4.)) * 4; + topk_tmp_ids_buf_size = (int)(ceil(topk_tmp_ids_buf_size / 4.)) * 4; + topk_tmp_val_buf_size = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4; + + if (workspace == nullptr) { + workspace_size = sizeof(T) * temp_log_probs_buf_size + sizeof(int) * topk_tmp_ids_buf_size + + sizeof(T) * topk_tmp_val_buf_size; + return; + } + + T* temp_log_probs = (T*)workspace; + int* topk_tmp_id_buf = (int*)(temp_log_probs + temp_log_probs_buf_size); + T* topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size); +#ifdef _MSC_VER + do { + ONLY_TOPK_CASE_K(1, 16, 128 * 8); + ONLY_TOPK_CASE_K(17, 32, 256 * 8); + ONLY_TOPK_CASE_K(33, 64, 256 * 8); + ONLY_TOPK_CASE_K(65, 1024, 256 * 8); + throw std::domain_error(fmtstr("only top-k kernel supports 1<=k<=1024 but got k=%d", max_top_k)); + } while (0); +#else + switch (max_top_k) { + ONLY_TOPK_CASE_K(1, 16, 128 * 8); + ONLY_TOPK_CASE_K(17, 32, 256 * 8); + ONLY_TOPK_CASE_K(33, 64, 256 * 8); + ONLY_TOPK_CASE_K(65, 1024, 256 * 8); + default: + throw std::domain_error(fmtstr("only top-k kernel supports 1<=k<=1024 but got k=%d", max_top_k)); + } +#endif +} + +#undef ONLY_TOPK_CASE_K + +template void invokeBatchTopKOnly(void* workspace, + size_t& workspace_size, + const half* log_probs, + bool* finished, + const int max_top_k, + const int* top_ks, + const int vocab_size_padded, + const int* end_ids, + cudaStream_t stream, + const int batch_size, + const bool* skip_decode); + +template void invokeBatchTopKOnly(void* workspace, + size_t& workspace_size, + const float* log_probs, + bool* finished, + const int max_top_k, + const int* top_ks, + const int vocab_size_padded, + const int* end_ids, + cudaStream_t stream, + const int batch_size, + const bool* skip_decode); + +#ifdef ENABLE_BF16 +template void invokeBatchTopKOnly(void* workspace, + size_t& workspace_size, + const __nv_bfloat16* log_probs, + bool* finished, + const int max_top_k, + const int* top_ks, + const int vocab_size_padded, + const int* end_ids, + cudaStream_t stream, + const int batch_size, + const bool* skip_decode); +#endif } // namespace turbomind diff --git a/src/turbomind/kernels/sampling_topk_kernels.h b/src/turbomind/kernels/sampling_topk_kernels.h index a539abf0f..8530e8f5e 100644 --- a/src/turbomind/kernels/sampling_topk_kernels.h +++ b/src/turbomind/kernels/sampling_topk_kernels.h @@ -95,4 +95,16 @@ void invokeTopKTopPSampling(void* workspace, const int* end_ids, cudaStream_t stream); +template +void invokeBatchTopKOnly(void* workspace, + size_t& workspace_size, + const T* log_probs, + bool* finished, + const int max_top_k, + const int* top_ks, + const int vocab_size_padded, + const int* end_ids, + cudaStream_t stream, + const int batch_size, + const bool* skip_decode); } // namespace turbomind