Skip to content

Commit

Permalink
feat: add invokeBatchTopKOnly
Browse files Browse the repository at this point in the history
  • Loading branch information
b4b4o authored and zhyncs committed Mar 11, 2024
1 parent 45d7521 commit 7983454
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 0 deletions.
176 changes: 176 additions & 0 deletions src/turbomind/kernels/sampling_topk_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -619,4 +619,180 @@ template void invokeTopKTopPSampling(void* workspace,
const int* end_ids,
cudaStream_t stream);

// Selects the per-batch top-k logits by repeated block-wide argmax.
// One thread block handles one batch entry; BLOCK_SIZE_ threads cooperate
// through cub::BlockReduce. Each of the k rounds records the current maximum
// and then masks it out of tmp_log_probs so the next round finds the
// next-largest value.
//
//   log_probs        [batch_size, vocab_size] input logits (read-only).
//   tmp_log_probs    [batch_size, vocab_size] scratch copy, destructively
//                    masked round by round.
//   topk_tmp_id_buf  [batch_size, max_top_k]  output vocab ids.
//   topk_tmp_val_buf [batch_size, max_top_k]  output logit values.
//   finished, end_ids: accepted for signature parity with the other sampling
//                    kernels but not read by this kernel.
//   top_ks           optional per-batch k; falls back to max_top_k when null.
//   skip_decode      optional per-batch flag; a true entry skips that batch.
//
// NOTE(review): assumes top_ks[i] <= max_top_k — a larger value would write
// past that batch's slice of topk_tmp_*_buf. TODO confirm with callers.
// NOTE(review): when k < max_top_k the tail of the per-batch output slice is
// left uninitialized — presumably consumers only read the first k entries.
template<typename T, int BLOCK_SIZE_>
__global__ void topk_only(const T* __restrict log_probs,
                          T* tmp_log_probs,
                          int* topk_tmp_id_buf,
                          T* topk_tmp_val_buf,
                          const bool* finished,
                          const int max_top_k,
                          const int* top_ks,
                          const int vocab_size,
                          const int* end_ids,
                          const bool* skip_decode)
{
    typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    const int tid = threadIdx.x;
    const int bid = blockIdx.x;

    const int batch_id = bid;
    if (skip_decode != nullptr && skip_decode[batch_id]) {
        return;
    }
    // Effective k for this batch entry.
    const int k = (top_ks != nullptr) ? top_ks[batch_id] : max_top_k;
    const int tmp_log_buf_index = batch_id * vocab_size;
    const int tmp_topk_buf_index = batch_id * max_top_k;

    TopK_2<T> partial;
    const bool IS_FP16 = std::is_same<T, half>::value;
    // Sentinel used to mask out already-selected entries.
    const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;

    // Copy this batch's logits into scratch so the input stays untouched.
    // Each thread later re-reads only the indices it wrote here (same
    // tid/BLOCK_SIZE_ stride), so no barrier is needed before round 0.
    for (int elem_id = tid; elem_id < vocab_size; elem_id += BLOCK_SIZE_) {
        int index = elem_id + tmp_log_buf_index;
        tmp_log_probs[index] = log_probs[index];
    }

    for (int ite = 0; ite < k; ite++) {
        partial.init();
#pragma unroll
        for (int elem_id = tid; elem_id < vocab_size; elem_id += BLOCK_SIZE_) {
            int index = elem_id + tmp_log_buf_index;
            partial.insert(tmp_log_probs[index], index);
        }
        // Block-wide argmax; only thread 0 receives the valid result.
        TopK_2<T> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<T>);

        if (tid == 0) {
            const int index = tmp_topk_buf_index + ite;
            // total.p is a global index (batch offset included), hence the
            // modulo to recover the vocab id.
            topk_tmp_id_buf[index] = total.p % vocab_size;
            topk_tmp_val_buf[index] = total.u;
            // Mask the winner so the next round picks the runner-up.
            tmp_log_probs[total.p] = -MAX_T_VAL;
        }
        // Make the mask write visible to all threads, and make temp_storage
        // safe to reuse in the next round.
        __syncthreads();
    }
}

// Dispatches a topk_only launch for max_top_k in the range [K_MIN, K_MAX]
// with BLOCK_SIZE_ threads per block (grid = one block per batch entry).
// GCC/Clang builds use `case K_MIN ... K_MAX:` (a GNU extension) inside a
// switch; MSVC has no case ranges, so that variant is an if-chain meant to
// run inside a `do { ... } while (0)` — the trailing `break` exits the
// do/while after the first matching range.
// (Comments cannot go inside the macro bodies: line splicing happens before
// comment removal, so a `//` before a `\` would swallow the next line.)
#ifdef _MSC_VER
#define ONLY_TOPK_CASE_K(K_MIN, K_MAX, BLOCK_SIZE_)                                                                    \
    if (K_MIN <= max_top_k && max_top_k <= K_MAX) {                                                                    \
        topk_only<T, BLOCK_SIZE_><<<batch_size, BLOCK_SIZE_, 0, stream>>>(log_probs,                                   \
                                                                          temp_log_probs,                             \
                                                                          topk_tmp_id_buf,                            \
                                                                          topk_tmp_val_buf,                           \
                                                                          finished,                                   \
                                                                          max_top_k,                                  \
                                                                          top_ks,                                     \
                                                                          vocab_size,                                 \
                                                                          end_ids,                                    \
                                                                          skip_decode);                               \
        break;                                                                                                         \
    }
#else
#define ONLY_TOPK_CASE_K(K_MIN, K_MAX, BLOCK_SIZE_)                                                                    \
    case K_MIN ... K_MAX:                                                                                              \
        topk_only<T, BLOCK_SIZE_><<<batch_size, BLOCK_SIZE_, 0, stream>>>(log_probs,                                   \
                                                                          temp_log_probs,                             \
                                                                          topk_tmp_id_buf,                            \
                                                                          topk_tmp_val_buf,                           \
                                                                          finished,                                   \
                                                                          max_top_k,                                  \
                                                                          top_ks,                                     \
                                                                          vocab_size,                                 \
                                                                          end_ids,                                    \
                                                                          skip_decode);                               \
        break;
#endif

// Extracts the per-batch top-k logits (ids and values) without sampling.
//
// Two-phase API: call once with workspace == nullptr to get the required
// byte count in workspace_size, then call again with an allocated buffer
// to launch the kernel on `stream`.
//
// Workspace layout (each region padded up to a multiple of 4 elements):
//   [temp_log_probs: T x batch*vocab][ids: int x batch*k][vals: T x batch*k]
//
// `finished` and `end_ids` are forwarded to the kernel but currently unused
// by it. Throws std::domain_error when max_top_k is outside [1, 1024].
template<typename T>
void invokeBatchTopKOnly(void* workspace,
                         size_t& workspace_size,
                         const T* log_probs,
                         bool* finished,
                         const int max_top_k,
                         const int* top_ks,
                         const int vocab_size_padded,
                         const int* end_ids,
                         cudaStream_t stream,
                         const int batch_size,
                         const bool* skip_decode)
{
    const int vocab_size = vocab_size_padded;

    // Compute element counts in size_t with integer round-up: the previous
    // int + floating-point ceil() version could overflow for large
    // batch_size * vocab_size and did the padding via doubles.
    auto round_up4 = [](size_t x) { return (x + 3) / 4 * 4; };
    const size_t temp_log_probs_buf_size = round_up4((size_t)batch_size * vocab_size);
    const size_t topk_tmp_ids_buf_size   = round_up4((size_t)batch_size * max_top_k);
    const size_t topk_tmp_val_buf_size   = round_up4((size_t)batch_size * max_top_k);

    if (workspace == nullptr) {
        // Query mode: report the required workspace size only.
        workspace_size = sizeof(T) * temp_log_probs_buf_size + sizeof(int) * topk_tmp_ids_buf_size
                         + sizeof(T) * topk_tmp_val_buf_size;
        return;
    }

    // Carve the three regions out of the caller-provided workspace.
    T*   temp_log_probs   = (T*)workspace;
    int* topk_tmp_id_buf  = (int*)(temp_log_probs + temp_log_probs_buf_size);
    T*   topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size);

    // BUGFIX: the previous dispatch used 256 * 8 = 2048 threads per block for
    // k > 16, which exceeds the CUDA hardware limit of 1024 threads per block
    // and made every such launch fail with an invalid-configuration error
    // that was never checked. Clamp all configurations to 1024.
#ifdef _MSC_VER
    do {
        ONLY_TOPK_CASE_K(1, 16, 1024);
        ONLY_TOPK_CASE_K(17, 32, 1024);
        ONLY_TOPK_CASE_K(33, 64, 1024);
        ONLY_TOPK_CASE_K(65, 1024, 1024);
        throw std::domain_error(fmtstr("only top-k kernel supports 1<=k<=1024 but got k=%d", max_top_k));
    } while (0);
#else
    switch (max_top_k) {
        ONLY_TOPK_CASE_K(1, 16, 1024);
        ONLY_TOPK_CASE_K(17, 32, 1024);
        ONLY_TOPK_CASE_K(33, 64, 1024);
        ONLY_TOPK_CASE_K(65, 1024, 1024);
        default:
            throw std::domain_error(fmtstr("only top-k kernel supports 1<=k<=1024 but got k=%d", max_top_k));
    }
#endif
}

#undef ONLY_TOPK_CASE_K

// Explicit instantiations for the logit types the project supports
// (half and float always; bfloat16 only when built with ENABLE_BF16).
template void invokeBatchTopKOnly(void* workspace,
                                  size_t& workspace_size,
                                  const half* log_probs,
                                  bool* finished,
                                  const int max_top_k,
                                  const int* top_ks,
                                  const int vocab_size_padded,
                                  const int* end_ids,
                                  cudaStream_t stream,
                                  const int batch_size,
                                  const bool* skip_decode);

template void invokeBatchTopKOnly(void* workspace,
                                  size_t& workspace_size,
                                  const float* log_probs,
                                  bool* finished,
                                  const int max_top_k,
                                  const int* top_ks,
                                  const int vocab_size_padded,
                                  const int* end_ids,
                                  cudaStream_t stream,
                                  const int batch_size,
                                  const bool* skip_decode);

#ifdef ENABLE_BF16
template void invokeBatchTopKOnly(void* workspace,
                                  size_t& workspace_size,
                                  const __nv_bfloat16* log_probs,
                                  bool* finished,
                                  const int max_top_k,
                                  const int* top_ks,
                                  const int vocab_size_padded,
                                  const int* end_ids,
                                  cudaStream_t stream,
                                  const int batch_size,
                                  const bool* skip_decode);
#endif
} // namespace turbomind
12 changes: 12 additions & 0 deletions src/turbomind/kernels/sampling_topk_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,16 @@ void invokeTopKTopPSampling(void* workspace,
const int* end_ids,
cudaStream_t stream);

// Extracts the per-batch top-k logits (ids and values) without sampling.
// Two-phase API: call with workspace == nullptr to receive the required
// byte count in workspace_size, then call again with an allocated buffer
// to run asynchronously on `stream`.
//   top_ks       optional per-batch k (must not exceed max_top_k);
//                nullptr means every batch uses max_top_k.
//   skip_decode  optional per-batch flag; true entries are skipped.
//   finished, end_ids: currently unused by the implementation; kept for
//                signature parity with the other sampling entry points.
template<typename T>
void invokeBatchTopKOnly(void* workspace,
                         size_t& workspace_size,
                         const T* log_probs,
                         bool* finished,
                         const int max_top_k,
                         const int* top_ks,
                         const int vocab_size_padded,
                         const int* end_ids,
                         cudaStream_t stream,
                         const int batch_size,
                         const bool* skip_decode);
} // namespace turbomind

0 comments on commit 7983454

Please sign in to comment.