use half/bf16 lm_head output
irexyc committed Mar 5, 2025 · 1 parent e8c8e7a · commit 4fc60d2
Showing 18 changed files with 238 additions and 160 deletions.
4 changes: 2 additions & 2 deletions src/turbomind/engine/model_request.cc
@@ -124,7 +124,7 @@ auto ModelRequest::Forward(InputParam param, std::function<void()> cb) -> Output

     if (param.gen_cfg.output_logits) {
         const int len = param.gen_cfg.output_logits == GenerationConfig::kAll ? max_in_out_len : max_out_len;
-        add(outputs_, "logits", TYPE_FP32, MEMORY_CPU, len, vocab_size_);
+        add(outputs_, "logits", data_type_, MEMORY_CPU, len, vocab_size_);
     }

     if (param.gen_cfg.output_last_hidden_state) {
@@ -133,7 +133,7 @@ auto ModelRequest::Forward(InputParam param, std::function<void()> cb) -> Output
     }

     if (param.gen_cfg.output_logprobs) {
-        add(outputs_, "logprob_vals", TYPE_FP32, MEMORY_CPU, max_out_len, kMaxLogProb);
+        add(outputs_, "logprob_vals", data_type_, MEMORY_CPU, max_out_len, kMaxLogProb);
         add(outputs_, "logprob_indexes", TYPE_INT32, MEMORY_CPU, max_out_len, kMaxLogProb);
         add(outputs_, "logprob_nums", TYPE_INT32, MEMORY_CPU, max_out_len);
     }
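Note: with this change the "logits" and "logprob_vals" output tensors are allocated in the model's dtype (data_type_, i.e. half or bf16) rather than always TYPE_FP32. A minimal client-side sketch of the consequence, assuming a caller that still wants float logits (widen_half_logits is illustrative, not part of this commit):

#include <cuda_fp16.h>

#include <cstddef>
#include <vector>

// Hypothetical helper: widen fp16 logits returned by the engine to fp32.
// Lossless, since every fp16 value is exactly representable in fp32.
std::vector<float> widen_half_logits(const half* logits, std::size_t n)
{
    std::vector<float> out(n);
    for (std::size_t i = 0; i < n; ++i) {
        out[i] = __half2float(logits[i]);
    }
    return out;
}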
62 changes: 20 additions & 42 deletions src/turbomind/kernels/ban_bad_words.cu
@@ -15,6 +15,7 @@
  */

 #include "src/turbomind/kernels/ban_bad_words.h"
+#include "src/turbomind/kernels/reduce_kernel_utils.cuh"
 #include "src/turbomind/utils/cuda_utils.h"

 namespace turbomind {
@@ -80,7 +81,7 @@ __global__ void ban_bad_words(T* logits,
         int banned_token = base_bad_words[item_end - 1];
         if (0 < banned_token && banned_token < vocab_size_padded) {
             logits[batch_idx * beam_width * vocab_size_padded + beam_idx * vocab_size_padded + banned_token] =
-                static_cast<T>(-INFINITY);
+                -getMaxValue<T>();
         }
     }
 }
@@ -119,48 +120,25 @@ void invokeBanBadWords(T* logits,
     sync_check_cuda_error();
 }

-#if 0
-template void invokeBanBadWords(half* logits,
-                                const int* output_ids_buf,
-                                const int* parent_ids_buf,
-                                int batch_size,
-                                int local_batch_size,
-                                int beam_width,
-                                const int* bad_words,
-                                bool share_words,
-                                size_t bad_words_len,
-                                int id_offset,
-                                int vocab_size_padded,
-                                size_t step,
-                                cudaStream_t stream);
+#define INSTANTIATE_INVOKE_BAN_BAD_WORDS(T)                                                                            \
+    template void invokeBanBadWords<T>(T * logits,                                                                     \
+                                       const int* output_ids_buf,                                                      \
+                                       const int* parent_ids_buf,                                                      \
+                                       int batch_size,                                                                 \
+                                       int local_batch_size,                                                           \
+                                       int beam_width,                                                                 \
+                                       const int* bad_words,                                                           \
+                                       bool share_words,                                                               \
+                                       size_t bad_words_len,                                                           \
+                                       int id_offset,                                                                  \
+                                       int vocab_size_padded,                                                          \
+                                       size_t step,                                                                    \
+                                       cudaStream_t stream);

+INSTANTIATE_INVOKE_BAN_BAD_WORDS(float);
+INSTANTIATE_INVOKE_BAN_BAD_WORDS(half);
 #ifdef ENABLE_BF16
-template void invokeBanBadWords(__nv_bfloat16* logits,
-                                const int* output_ids_buf,
-                                const int* parent_ids_buf,
-                                int batch_size,
-                                int local_batch_size,
-                                int beam_width,
-                                const int* bad_words,
-                                bool share_words,
-                                size_t bad_words_len,
-                                int id_offset,
-                                int vocab_size_padded,
-                                size_t step,
-                                cudaStream_t stream);
+INSTANTIATE_INVOKE_BAN_BAD_WORDS(__nv_bfloat16);
 #endif
-#endif
-template void invokeBanBadWords(float* logits,
-                                const int* output_ids_buf,
-                                const int* parent_ids_buf,
-                                int batch_size,
-                                int local_batch_size,
-                                int beam_width,
-                                const int* bad_words,
-                                bool share_words,
-                                size_t bad_words_len,
-                                int id_offset,
-                                int vocab_size_padded,
-                                size_t step,
-                                cudaStream_t stream);

 } // namespace turbomind
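Note: the kernel change above replaces static_cast<T>(-INFINITY) with -getMaxValue<T>() (defined in reduce_kernel_utils.cuh below). A finite sentinel behaves like -inf under softmax but cannot trigger the IEEE rules inf - inf == NaN and 0 * inf == NaN in later logit-processing steps. A minimal sketch of the masking pattern, assuming only getMaxValue from this commit (banToken itself is illustrative):

#include "src/turbomind/kernels/reduce_kernel_utils.cuh"

namespace turbomind {

// Illustrative kernel: ban a single token by writing the most negative finite
// value of T; softmax then assigns it probability 0 without producing
// infinities that later arithmetic could turn into NaN.
template<typename T>
__global__ void banToken(T* logits, int token)
{
    logits[token] = -getMaxValue<T>();
}

}  // namespace turbomind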
5 changes: 5 additions & 0 deletions src/turbomind/kernels/gpt_kernels.cu
@@ -239,6 +239,11 @@ invokeTransposeAxis01(int* out, int* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream)
 template void
 invokeTransposeAxis01(uint16_t* out, uint16_t* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);

+#ifdef ENABLE_BF16
+template void invokeTransposeAxis01(
+    __nv_bfloat16* out, __nv_bfloat16* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);
+#endif
+
 template<typename T>
 __global__ void transposeAxis01(T* out, T* in, const int* in_skipping_dim1, const int dim0, const int dim1)
 {
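Note: invokeTransposeAxis01 is a template defined in gpt_kernels.cu, so every element type used by callers needs an explicit instantiation in that translation unit. A hedged sketch of the failure the new bf16 instantiation prevents (the caller below is illustrative, and the declaration is assumed to live in gpt_kernels.h):

#include <cuda_bf16.h>

#include "src/turbomind/kernels/gpt_kernels.h"

// Without the explicit __nv_bfloat16 instantiation above, this compiles but
// fails at link time with an undefined reference to
// invokeTransposeAxis01<__nv_bfloat16>.
void transpose_bf16(__nv_bfloat16* dst, __nv_bfloat16* src, int d0, int d1, int d2, cudaStream_t stream)
{
    turbomind::invokeTransposeAxis01(dst, src, d0, d1, d2, stream);
}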
27 changes: 25 additions & 2 deletions src/turbomind/kernels/reduce_kernel_utils.cuh
@@ -53,6 +53,29 @@ struct BytesToType<16> {
     using type = float4;
 };

+template<typename T>
+__device__ inline T getMaxValue();
+
+template<>
+__device__ inline float getMaxValue<float>()
+{
+    return FLT_MAX;
+}
+
+template<>
+__device__ inline half getMaxValue<half>()
+{
+    return CUDART_MAX_NORMAL_FP16;
+}
+
+#ifdef ENABLE_BF16
+template<>
+__device__ inline __nv_bfloat16 getMaxValue<__nv_bfloat16>()
+{
+    return CUDART_MAX_NORMAL_BF16;
+}
+#endif
+
 template<int Bytes>
 __device__ inline void copy(const void* local, void* data)
 {
@@ -319,7 +342,7 @@ __device__ __forceinline__ TopK<T, MAX_K> reduce_topk_op(const TopK<T, MAX_K>& a, const TopK<T, MAX_K>& b)
 template<typename T>
 struct TopK_2 {
     int p = -1;
-    T   u = -((std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX);
+    T   u = -getMaxValue<T>();

     __device__ __forceinline__ void insert(T elem, int elem_id)
     {
@@ -331,7 +354,7 @@ struct TopK_2 {

     __device__ __forceinline__ void init()
     {
-        u = -((std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX);
+        u = -getMaxValue<T>();
         p = -1;
     }
 };
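Note: with getMaxValue<T>(), TopK_2 no longer special-cases half; init() seeds u with the most negative finite value of the actual logit dtype, so the first insert() always wins. A device-side usage sketch under that assumption (threadArgmax is illustrative, not from this commit):

#include "src/turbomind/kernels/reduce_kernel_utils.cuh"

namespace turbomind {

// Illustrative kernel: per-thread running top-1 over a strided slice of half
// logits; a block-wide merge via reduce_topk_op would normally follow.
__global__ void threadArgmax(const half* logits, int vocab_size, int* out)
{
    TopK_2<half> top;
    top.init();  // u = -CUDART_MAX_NORMAL_FP16, p = -1
    for (int i = threadIdx.x; i < vocab_size; i += blockDim.x) {
        top.insert(logits[i], i);  // keeps the larger of u and logits[i]
    }
    out[threadIdx.x] = top.p;  // index of this thread's maximum
}

}  // namespace turbomind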
8 changes: 6 additions & 2 deletions src/turbomind/kernels/sampling_kernels.cu
@@ -19,7 +19,7 @@ __global__ void sampling(const T* logits,
                          curandState_t* curandstate,
                          int* output_ids,
                          int* sequence_length,
-                         float* sampled_logprobs,
+                         T* sampled_logprobs,
                          uint32_t* sampled_indexes,
                          uint32_t* sampled_nums)
 {
@@ -92,11 +92,15 @@ void invokeSampling(SamplingParams& params, cudaStream_t stream)
                                                  params.curandstate,
                                                  params.output_ids,
                                                  params.sequence_length,
-                                                 params.sampled_logprobs,
+                                                 (T*)params.sampled_logprobs,
                                                  params.sampled_indexes,
                                                  params.sampled_nums);
 }

 template void invokeSampling<float>(SamplingParams& params, cudaStream_t stream);
+template void invokeSampling<half>(SamplingParams& params, cudaStream_t stream);
+#ifdef ENABLE_BF16
+template void invokeSampling<nv_bfloat16>(SamplingParams& params, cudaStream_t stream);
+#endif

 } // namespace turbomind
2 changes: 1 addition & 1 deletion src/turbomind/kernels/sampling_kernels.h
@@ -31,7 +31,7 @@ struct SamplingParams {
     size_t batch_size;
     int* output_ids;
     int* sequence_length;
-    float* sampled_logprobs;
+    void* sampled_logprobs;
     uint32_t* sampled_indexes;
     uint32_t* sampled_nums;
 };
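Note: sampled_logprobs is type-erased to void* in the shared params struct; invokeSampling<T> casts it back, as in the .cu change above. A hedged sketch of the runtime dispatch this enables (launch_sampling and the header path for DataType/TYPE_* are assumptions, not part of this commit):

#include "src/turbomind/kernels/sampling_kernels.h"
#include "src/turbomind/utils/Tensor.h"  // assumed location of DataType / TYPE_*

// Illustrative dispatcher: select the concrete element type once at runtime.
void launch_sampling(turbomind::SamplingParams& params, turbomind::DataType dtype, cudaStream_t stream)
{
    using namespace turbomind;
    switch (dtype) {
        case TYPE_FP32:
            invokeSampling<float>(params, stream);
            break;
        case TYPE_FP16:
            invokeSampling<half>(params, stream);
            break;
#ifdef ENABLE_BF16
        case TYPE_BF16:
            invokeSampling<nv_bfloat16>(params, stream);
            break;
#endif
        default:
            break;  // other dtypes are not sampled here
    }
}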
127 changes: 83 additions & 44 deletions src/turbomind/kernels/sampling_penalty_kernels.cu
@@ -19,6 +19,7 @@

 #include "src/turbomind/kernels/core/array_ops.h"
 #include "src/turbomind/kernels/core/common.h"
+#include "src/turbomind/kernels/reduce_kernel_utils.cuh"
 #include "src/turbomind/kernels/sampling_penalty_kernels.h"

 namespace turbomind {
@@ -224,9 +225,9 @@ template void invokeBatchApplyTemperaturePenalty(half* logits,
                                                  cudaStream_t stream);
 #endif

-template<int vec_size>
-__global__ void batchApplyTemperaturePenalty_v2(float* logits,
-                                                const float* bias,
+template<typename T, int vec_size>
+__global__ void batchApplyTemperaturePenalty_v2(T* logits,
+                                                const T* bias,
                                                 const float* temperatures,
                                                 const int batch_size,
                                                 const int vocab_size,
@@ -250,23 +251,45 @@ __global__ void batchApplyTemperaturePenalty_v2(float* logits,
     const int step = gridDim.x * blockDim.x * vec_size;

     for (int i = vi * vec_size; i < vocab_size_padded; i += step) {
-        Array<float, vec_size> vec;
-        Load(vec, logits + i);
+        Array<T, vec_size> vec;
+        // load
+        if constexpr (sizeof(vec) >= sizeof(uint)) {
+            Load(vec, logits + i);
+        }
+        else {
+            PRAGMA_UNROLL
+            for (int j = 0; j < vec_size; ++j) {
+                vec[j] = logits[i + j];
+            }
+        }
+
+        // process
         PRAGMA_UNROLL
         for (int c = 0; c < vec_size; ++c) {
             if (i + c < vocab_size) {
                 vec[c] *= scale;
             }
             else {
-                vec[c] = -FLT_MAX;
+                vec[c] = -getMaxValue<T>();
             }
         }
-        Store(logits + i, vec);
+
+        // store
+        if constexpr (sizeof(vec) >= sizeof(uint)) {
+            Store(logits + i, vec);
+        }
+        else {
+            PRAGMA_UNROLL
+            for (int j = 0; j < vec_size; ++j) {
+                logits[i + j] = vec[j];
+            }
+        }
     }
 }

-void invokeBatchApplyTemperaturePenalty_v2(float* logits,
-                                           const float* bias,
+template<typename T>
+void invokeBatchApplyTemperaturePenalty_v2(T* logits,
+                                           const T* bias,
                                            const float* temperatures,
                                            const int batch_size,
                                            const int vocab_size,
@@ -278,7 +301,7 @@ void invokeBatchApplyTemperaturePenalty_v2(float* logits,
     constexpr int threads = 256;
     const int blocks_per_tok = (vocab_size_padded + threads * vec_size - 1) / (threads * vec_size);
     const dim3 blocks(blocks_per_tok, batch_size);
-    batchApplyTemperaturePenalty_v2<vec_size.value><<<blocks, threads, 0, stream>>>(  //
+    batchApplyTemperaturePenalty_v2<T, vec_size.value><<<blocks, threads, 0, stream>>>(  //
         logits,
         bias,
         temperatures,
@@ -298,6 +321,21 @@
     }
 }

+#define INSTANTIATE_INVOKE_BATCH_APPLY_TEMPERATURE_PENALTY_V2(T)                                                      \
+    template void invokeBatchApplyTemperaturePenalty_v2(T* logits,                                                    \
+                                                        const T* bias,                                                \
+                                                        const float* temperatures,                                    \
+                                                        const int batch_size,                                         \
+                                                        const int vocab_size,                                         \
+                                                        const int vocab_size_padded,                                  \
+                                                        cudaStream_t stream);
+
+INSTANTIATE_INVOKE_BATCH_APPLY_TEMPERATURE_PENALTY_V2(float);
+INSTANTIATE_INVOKE_BATCH_APPLY_TEMPERATURE_PENALTY_V2(half);
+#ifdef ENABLE_BF16
+INSTANTIATE_INVOKE_BATCH_APPLY_TEMPERATURE_PENALTY_V2(__nv_bfloat16);
+#endif

 template<typename T, RepetitionPenaltyType penalty_type>
 __global__ void applyRepetitionPenalty(T* logits,
                                        const float penalty,
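Note: the load/store branching added above exists because the vectorized Load/Store paths issue transactions of at least one 4-byte word, so a length-1 vector of a 2-byte type (half/bf16) must fall back to elementwise access. A compile-time sketch of the boundary, assuming turbomind's Array<T, N> is a plain T[N] aggregate (header path assumed):

#include <cuda_fp16.h>

#include "src/turbomind/kernels/core/array.h"  // assumed header for Array<T, N>

// The same comparison the kernel makes: sizeof(vec) >= sizeof(uint) selects
// the vectorized path, anything smaller takes the scalar loop.
static_assert(sizeof(turbomind::Array<float, 1>) >= sizeof(unsigned int), "float x1: vectorized");
static_assert(sizeof(turbomind::Array<half, 2>) >= sizeof(unsigned int), "half x2: vectorized");
static_assert(sizeof(turbomind::Array<half, 1>) < sizeof(unsigned int), "half x1: scalar fallback");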
@@ -551,32 +589,26 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
     }
 }

-template void invokeBatchApplyRepetitionPenalty(float* logits,
-                                                const float* penalties,
-                                                int* penalty_workspace,
-                                                const int* output_ids,
-                                                const int batch_size,
-                                                const int local_batch_size,
-                                                const int vocab_size,
-                                                const int* input_lengths,
-                                                const int max_input_length,
-                                                const int step,
-                                                RepetitionPenaltyType penalty_type,
-                                                cudaStream_t stream);
-#if 0
-template void invokeBatchApplyRepetitionPenalty(half* logits,
-                                                const float* penalties,
-                                                int* penalty_workspace,
-                                                const int* output_ids,
-                                                const int batch_size,
-                                                const int local_batch_size,
-                                                const int vocab_size,
-                                                const int* input_lengths,
-                                                const int max_input_length,
-                                                const int step,
-                                                RepetitionPenaltyType penalty_type,
-                                                cudaStream_t stream);
-#endif
+#define INSTANTIATE_INVOKE_BATCH_APPLY_REPETITION_PENALTY(T)                                                           \
+    template void invokeBatchApplyRepetitionPenalty(T* logits,                                                         \
+                                                    const float* penalties,                                            \
+                                                    int* penalty_workspace,                                            \
+                                                    const int* output_ids,                                             \
+                                                    const int batch_size,                                              \
+                                                    const int local_batch_size,                                        \
+                                                    const int vocab_size,                                              \
+                                                    const int* input_lengths,                                          \
+                                                    const int max_input_length,                                        \
+                                                    const int step,                                                    \
+                                                    RepetitionPenaltyType penalty_type,                                \
+                                                    cudaStream_t stream);
+
+INSTANTIATE_INVOKE_BATCH_APPLY_REPETITION_PENALTY(float);
+INSTANTIATE_INVOKE_BATCH_APPLY_REPETITION_PENALTY(half);
+#ifdef ENABLE_BF16
+INSTANTIATE_INVOKE_BATCH_APPLY_REPETITION_PENALTY(__nv_bfloat16);
+#endif

 template<typename T>
 __global__ void batchApplyMinLengthPenalty(T* __restrict__ logits,
                                            const int* __restrict__ min_lengths,
@@ -592,7 +624,7 @@ __global__ void batchApplyMinLengthPenalty(T* __restrict__ logits,
     if (bid < batch_size) {
         int end_id = end_ids[bid * end_ids_size + eid];
         if (end_id > 0 && sequence_lengths[bid] + 1 < min_lengths[bid]) {
-            T mask_val = (std::is_same<T, half>::value) ? -65504.0f : -FLT_MAX;
+            T mask_val = -getMaxValue<T>();
             logits[bid * vocab_size_padded + end_id] = mask_val;
         }
     }
@@ -614,13 +646,20 @@ void invokeMinLengthPenalty(T* logits,
         logits, min_lengths, sequnece_lengths, vocab_size_padded, batch_size, end_ids, end_ids_size);
 }

-template void invokeMinLengthPenalty(float* logits,
-                                     const int* min_lengths,
-                                     const int* sequnece_lengths,
-                                     const int vocab_size_padded,
-                                     const int batch_size,
-                                     const int* end_ids,
-                                     const int end_ids_size,
-                                     cudaStream_t stream);
+#define INSTANTIATE_INVOKE_MIN_LENGTH_PENALTY(T)                                                                       \
+    template void invokeMinLengthPenalty(T* logits,                                                                    \
+                                         const int* min_lengths,                                                       \
+                                         const int* sequnece_lengths,                                                  \
+                                         const int vocab_size_padded,                                                  \
+                                         const int batch_size,                                                         \
+                                         const int* end_ids,                                                           \
+                                         const int end_ids_size,                                                       \
+                                         cudaStream_t stream);
+
+INSTANTIATE_INVOKE_MIN_LENGTH_PENALTY(float);
+INSTANTIATE_INVOKE_MIN_LENGTH_PENALTY(half);
+#ifdef ENABLE_BF16
+INSTANTIATE_INVOKE_MIN_LENGTH_PENALTY(__nv_bfloat16);
+#endif

} // namespace turbomind
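Note: a quick numeric sanity check on the finite sentinel that replaced the hard-coded -65504.0f here (65504 is the largest normal fp16 value). A standalone sketch:

#include <cmath>
#include <cstdio>

int main()
{
    // softmax evaluates exp(logit - row_max); a logit masked to -65504
    // underflows to exactly 0 for any realistic row max.
    std::printf("%g\n", std::exp(-65504.0f - 10.0f));  // prints 0
    return 0;
}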