From 79834543f106ad9806e38ca16884c8a185e5b0c9 Mon Sep 17 00:00:00 2001
From: b4b4o <zwbao@foxmail.com>
Date: Sat, 9 Mar 2024 22:13:58 +0800
Subject: [PATCH] feat: add invokeBatchTopKOnly

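Add invokeBatchTopKOnly, a batched kernel that extracts the top k largest
log-probs per batch entry (k = top_ks[b], bounded by max_top_k, with
1 <= max_top_k <= 1024) without drawing a sample. It follows the two-pass
convention of the other sampling kernels: a first call with
workspace == nullptr reports the required workspace_size, and a second
call with an allocated buffer launches the kernel; the selected ids and
values are left in the workspace buffers.

Illustrative call sequence (a sketch only; d_logits, d_top_ks and
d_skip_decode are hypothetical device pointers, not part of this patch):

    size_t ws_size = 0;
    turbomind::invokeBatchTopKOnly<float>(nullptr, ws_size, d_logits, nullptr,
                                          max_top_k, d_top_ks, vocab_size_padded,
                                          nullptr, stream, batch_size, d_skip_decode);
    void* ws = nullptr;
    cudaMalloc(&ws, ws_size);
    turbomind::invokeBatchTopKOnly<float>(ws, ws_size, d_logits, nullptr,
                                          max_top_k, d_top_ks, vocab_size_padded,
                                          nullptr, stream, batch_size, d_skip_decode);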
---
 .../kernels/sampling_topk_kernels.cu          | 176 ++++++++++++++++++
 src/turbomind/kernels/sampling_topk_kernels.h |  12 ++
 2 files changed, 188 insertions(+)

diff --git a/src/turbomind/kernels/sampling_topk_kernels.cu b/src/turbomind/kernels/sampling_topk_kernels.cu
index 82b208298d..2ab7e1502d 100644
--- a/src/turbomind/kernels/sampling_topk_kernels.cu
+++ b/src/turbomind/kernels/sampling_topk_kernels.cu
@@ -619,4 +619,180 @@ template void invokeTopKTopPSampling(void*          workspace,
                                      const int*     end_ids,
                                      cudaStream_t   stream);
 
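+// Per-batch-entry top-k extraction: one thread block per batch entry copies its row of
+// log_probs into tmp_log_probs, then performs k rounds of block-wide argmax via
+// cub::BlockReduce, writing the winning id/value to topk_tmp_id_buf / topk_tmp_val_buf
+// and masking the winner with -MAX_T_VAL before the next round. `finished` and
+// `end_ids` are accepted for interface parity with the other sampling kernels but are
+// not read by this kernel.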
+template<typename T, int BLOCK_SIZE_>
+__global__ void topk_only(const T* __restrict log_probs,
+                          T*          tmp_log_probs,
+                          int*        topk_tmp_id_buf,
+                          T*          topk_tmp_val_buf,
+                          const bool* finished,
+                          const int   max_top_k,
+                          const int*  top_ks,
+                          const int   vocab_size,
+                          const int*  end_ids,
+                          const bool* skip_decode)
+{
+    typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce;
+    __shared__ typename BlockReduce::TempStorage     temp_storage;
+
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int batch_id = bid;
+    if (skip_decode != nullptr && skip_decode[batch_id]) {
+        return;
+    }
+    const int k                  = (top_ks != nullptr) ? top_ks[batch_id] : max_top_k;
+    const int tmp_log_buf_index  = batch_id * vocab_size;
+    const int tmp_topk_buf_index = batch_id * max_top_k;
+
+    TopK_2<T>  partial;
+    const bool IS_FP16   = std::is_same<T, half>::value;
+    const T    MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
+
+    for (int elem_id = tid; elem_id < vocab_size; elem_id += BLOCK_SIZE_) {
+        int index            = elem_id + tmp_log_buf_index;
+        tmp_log_probs[index] = log_probs[index];
+    }
+
+    for (int ite = 0; ite < k; ite++) {
+        partial.init();
+#pragma unroll
+        for (int elem_id = tid; elem_id < vocab_size; elem_id += BLOCK_SIZE_) {
+            int index = elem_id + tmp_log_buf_index;
+            partial.insert(tmp_log_probs[index], index);
+        }
+        TopK_2<T> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<T>);
+
+        if (tid == 0) {
+            const int index         = tmp_topk_buf_index + ite;
+            topk_tmp_id_buf[index]  = total.p % vocab_size;
+            topk_tmp_val_buf[index] = total.u;
+            tmp_log_probs[total.p]  = -MAX_T_VAL;
+        }
+        __syncthreads();
+    }
+}
+
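+// "case LO ... HI:" range labels are a GNU extension that MSVC does not support, so the
+// MSVC build selects the block size with an if/break chain inside do { } while (0).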
+#ifdef _MSC_VER
+#define ONLY_TOPK_CASE_K(K_MIN, K_MAX, BLOCK_SIZE_)                                                                    \
+    if (K_MIN <= max_top_k && max_top_k <= K_MAX) {                                                                    \
+        topk_only<T, BLOCK_SIZE_><<<batch_size, BLOCK_SIZE_, 0, stream>>>(log_probs,                                   \
+                                                                          temp_log_probs,                              \
+                                                                          topk_tmp_id_buf,                             \
+                                                                          topk_tmp_val_buf,                            \
+                                                                          finished,                                    \
+                                                                          max_top_k,                                   \
+                                                                          top_ks,                                      \
+                                                                          vocab_size,                                  \
+                                                                          end_ids,                                     \
+                                                                          skip_decode);                                \
+        break;                                                                                                         \
+    }
+#else
+#define ONLY_TOPK_CASE_K(K_MIN, K_MAX, BLOCK_SIZE_)                                                                    \
+    case K_MIN ... K_MAX:                                                                                              \
+        topk_only<T, BLOCK_SIZE_><<<batch_size, BLOCK_SIZE_, 0, stream>>>(log_probs,                                   \
+                                                                          temp_log_probs,                              \
+                                                                          topk_tmp_id_buf,                             \
+                                                                          topk_tmp_val_buf,                            \
+                                                                          finished,                                    \
+                                                                          max_top_k,                                   \
+                                                                          top_ks,                                      \
+                                                                          vocab_size,                                  \
+                                                                          end_ids,                                     \
+                                                                          skip_decode);                                \
+        break;
+#endif
+
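+// Host-side launcher for topk_only. When workspace == nullptr it only writes the
+// required workspace_size and returns; otherwise the workspace is partitioned into
+// [temp_log_probs | topk_tmp_id_buf | topk_tmp_val_buf], and the k ids/values selected
+// for batch entry b are left at offset b * max_top_k of the two topk_tmp buffers.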
+template<typename T>
+void invokeBatchTopKOnly(void*        workspace,
+                         size_t&      workspace_size,
+                         const T*     log_probs,
+                         bool*        finished,
+                         const int    max_top_k,
+                         const int*   top_ks,
+                         const int    vocab_size_padded,
+                         const int*   end_ids,
+                         cudaStream_t stream,
+                         const int    batch_size,
+                         const bool*  skip_decode)
+{
+    const int vocab_size              = vocab_size_padded;
+    int       temp_log_probs_buf_size = batch_size * vocab_size;
+    int       topk_tmp_ids_buf_size   = batch_size * max_top_k;
+    int       topk_tmp_val_buf_size   = batch_size * max_top_k;
+
+    temp_log_probs_buf_size = (int)(ceil(temp_log_probs_buf_size / 4.)) * 4;
+    topk_tmp_ids_buf_size   = (int)(ceil(topk_tmp_ids_buf_size / 4.)) * 4;
+    topk_tmp_val_buf_size   = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4;
+
+    if (workspace == nullptr) {
+        workspace_size = sizeof(T) * temp_log_probs_buf_size + sizeof(int) * topk_tmp_ids_buf_size
+                         + sizeof(T) * topk_tmp_val_buf_size;
+        return;
+    }
+
+    T*   temp_log_probs   = (T*)workspace;
+    int* topk_tmp_id_buf  = (int*)(temp_log_probs + temp_log_probs_buf_size);
+    T*   topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size);
+#ifdef _MSC_VER
+    do {
+        // blockDim.x must match the BLOCK_SIZE_ template argument of cub::BlockReduce
+        // and may not exceed the CUDA limit of 1024 threads per block.
+        ONLY_TOPK_CASE_K(1, 16, 512);
+        ONLY_TOPK_CASE_K(17, 32, 1024);
+        ONLY_TOPK_CASE_K(33, 64, 1024);
+        ONLY_TOPK_CASE_K(65, 1024, 1024);
+        throw std::domain_error(fmtstr("top-k only kernel supports 1<=k<=1024, but got k=%d", max_top_k));
+    } while (0);
+#else
+    switch (max_top_k) {
+        // blockDim.x must match the BLOCK_SIZE_ template argument of cub::BlockReduce
+        // and may not exceed the CUDA limit of 1024 threads per block.
+        ONLY_TOPK_CASE_K(1, 16, 512);
+        ONLY_TOPK_CASE_K(17, 32, 1024);
+        ONLY_TOPK_CASE_K(33, 64, 1024);
+        ONLY_TOPK_CASE_K(65, 1024, 1024);
+        default:
+            throw std::domain_error(fmtstr("top-k only kernel supports 1<=k<=1024, but got k=%d", max_top_k));
+    }
+#endif
+}
+
+#undef ONLY_TOPK_CASE_K
+
+template void invokeBatchTopKOnly(void*        workspace,
+                                  size_t&      workspace_size,
+                                  const half*  log_probs,
+                                  bool*        finished,
+                                  const int    max_top_k,
+                                  const int*   top_ks,
+                                  const int    vocab_size_padded,
+                                  const int*   end_ids,
+                                  cudaStream_t stream,
+                                  const int    batch_size,
+                                  const bool*  skip_decode);
+
+template void invokeBatchTopKOnly(void*        workspace,
+                                  size_t&      workspace_size,
+                                  const float* log_probs,
+                                  bool*        finished,
+                                  const int    max_top_k,
+                                  const int*   top_ks,
+                                  const int    vocab_size_padded,
+                                  const int*   end_ids,
+                                  cudaStream_t stream,
+                                  const int    batch_size,
+                                  const bool*  skip_decode);
+
+#ifdef ENABLE_BF16
+template void invokeBatchTopKOnly(void*                workspace,
+                                  size_t&              workspace_size,
+                                  const __nv_bfloat16* log_probs,
+                                  bool*                finished,
+                                  const int            max_top_k,
+                                  const int*           top_ks,
+                                  const int            vocab_size_padded,
+                                  const int*           end_ids,
+                                  cudaStream_t         stream,
+                                  const int            batch_size,
+                                  const bool*          skip_decode);
+#endif
 }  // namespace turbomind
diff --git a/src/turbomind/kernels/sampling_topk_kernels.h b/src/turbomind/kernels/sampling_topk_kernels.h
index a539abf0fa..8530e8f5eb 100644
--- a/src/turbomind/kernels/sampling_topk_kernels.h
+++ b/src/turbomind/kernels/sampling_topk_kernels.h
@@ -95,4 +95,16 @@ void invokeTopKTopPSampling(void*          workspace,
                             const int*     end_ids,
                             cudaStream_t   stream);
 
+template<typename T>
+void invokeBatchTopKOnly(void*        workspace,
+                         size_t&      workspace_size,
+                         const T*     log_probs,
+                         bool*        finished,
+                         const int    max_top_k,
+                         const int*   top_ks,
+                         const int    vocab_size_padded,
+                         const int*   end_ids,
+                         cudaStream_t stream,
+                         const int    batch_size,
+                         const bool*  skip_decode);
 }  // namespace turbomind