Skip to content

Commit

Permalink
fix: fix bug of ia3
Browse files Browse the repository at this point in the history
  • Loading branch information
byshiue committed Jan 6, 2023
1 parent f0b5b86 commit 6ca2d58
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 73 deletions.
88 changes: 72 additions & 16 deletions src/fastertransformer/kernels/activation_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ __global__ void generic_activation(T* out,
const int int8_mode,
const float* __restrict activation_in,
const float* __restrict activation_out,
const int* __restrict padding_offset,
const int seq_len,
int m,
int n)
{
Expand Down Expand Up @@ -219,7 +221,10 @@ __global__ void generic_activation(T* out,
}

if (with_ia3) {
const int task = ia3_tasks[id / n];
const int word_id = id / n;
const int offset = padding_offset == nullptr ? 0 : padding_offset[word_id];
const int batch_id = (word_id + offset) / seq_len;
const int task = ia3_tasks[batch_id];
val = val * ia3_weights[task * n + (id % n)];
}

Expand All @@ -246,6 +251,8 @@ void invokeGenericActivation(T* out,
const int int8_mode,
const float* activation_in,
const float* activation_out,
const int* padding_offset,
const int seq_len,
cudaStream_t stream)
{
using PT = typename packed_type<T>::type;
Expand All @@ -270,6 +277,8 @@ void invokeGenericActivation(T* out,
int8_mode,
activation_in,
activation_out,
padding_offset,
seq_len,
m,
n / packed_elems);
}
Expand All @@ -286,7 +295,9 @@ void invokeGenericActivation(T* out,
const int int8_mode, \
const float* activation_in, \
const float* activation_out, \
cudaStream_t stream)
const int* padding_offset, \
const int seq_len, \
cudaStream_t stream);

INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, float, float);
INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, half, half);
Expand Down Expand Up @@ -317,8 +328,13 @@ INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, float, __nv_bfloat16);
#undef INSTANCIATE_GENERIC_ACTIVATION

template<typename T2, int N>
__global__ void
addBiasGeluV2(T2* out, const T2* __restrict bias, const int* ia3_tasks, const T2* ia3_weights, const int size)
__global__ void addBiasGeluV2(T2* out,
const T2* __restrict bias,
const int* ia3_tasks,
const T2* ia3_weights,
const int size,
const int* padding_offset,
const int seq_len)
{
const bool with_ia3 = ia3_tasks != nullptr;
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < size; id += blockDim.x * gridDim.x) {
Expand All @@ -329,16 +345,24 @@ addBiasGeluV2(T2* out, const T2* __restrict bias, const int* ia3_tasks, const T2
}
val = GeluActivation<T2>::apply(val);
if (with_ia3) {
const int task = ia3_tasks[id / N];
val = val * ia3_weights[task * N + (id % N)];
const int word_id = id / N;
const int offset = padding_offset == nullptr ? 0 : padding_offset[word_id];
const int batch_id = (word_id + offset) / seq_len;
const int task = ia3_tasks[batch_id];
val = val * ia3_weights[task * N + (id % N)];
}
out[id] = val;
}
}

template<typename T2, int N, int ELEMENT_PER_ROUND>
__global__ void
addBiasGeluV3(T2* out, const T2* __restrict bias, const int* ia3_tasks, const T2* ia3_weights, const int size)
__global__ void addBiasGeluV3(T2* out,
const T2* __restrict bias,
const int* ia3_tasks,
const T2* ia3_weights,
const int size,
const int* padding_offset,
const int seq_len)
{
const bool with_ia3 = ia3_tasks != nullptr;
T2 buffer[ELEMENT_PER_ROUND];
Expand All @@ -359,8 +383,11 @@ addBiasGeluV3(T2* out, const T2* __restrict bias, const int* ia3_tasks, const T2
}
buffer[i] = GeluActivation<T2>::apply(buffer[i]);
if (with_ia3) {
const int task = ia3_tasks[(id + i) / N];
buffer[i] = buffer[i] * ia3_weights[task * N + ((id + i) % N)];
const int word_id = (id + i) / N;
const int offset = padding_offset == nullptr ? 0 : padding_offset[word_id];
const int batch_id = (word_id + offset) / seq_len;
const int task = ia3_tasks[batch_id];
buffer[i] = buffer[i] * ia3_weights[task * N + ((id + i) % N)];
}
out[id + i] = buffer[i];
}
Expand All @@ -371,18 +398,35 @@ addBiasGeluV3(T2* out, const T2* __restrict bias, const int* ia3_tasks, const T2
case HALF_N: \
if (ELEMENT_PER_ROUND > 1) { \
grid.x = grid.x / ELEMENT_PER_ROUND; \
addBiasGeluV3<T2, HALF_N, ELEMENT_PER_ROUND> \
<<<grid, block, 0, stream>>>((T2*)out, (const T2*)bias, ia3_tasks, (T2*)ia3_weights, m * half_n); \
addBiasGeluV3<T2, HALF_N, ELEMENT_PER_ROUND><<<grid, block, 0, stream>>>((T2*)out, \
(const T2*)bias, \
ia3_tasks, \
(T2*)ia3_weights, \
m * half_n, \
padding_offset, \
seq_len); \
} \
else { \
addBiasGeluV2<T2, HALF_N> \
<<<grid, block, 0, stream>>>((T2*)out, (const T2*)bias, ia3_tasks, (T2*)ia3_weights, m * half_n); \
addBiasGeluV2<T2, HALF_N><<<grid, block, 0, stream>>>((T2*)out, \
(const T2*)bias, \
ia3_tasks, \
(T2*)ia3_weights, \
m * half_n, \
padding_offset, \
seq_len); \
} \
break;

template<typename T>
void invokeAddBiasGeluV2(
T* out, const T* bias, const int* ia3_tasks, const T* ia3_weights, const int m, const int n, cudaStream_t stream)
void invokeAddBiasGeluV2(T* out,
const T* bias,
const int* ia3_tasks,
const T* ia3_weights,
const int* padding_offset,
const int seq_len,
const int m,
const int n,
cudaStream_t stream)
{
if (n % 2 == 0 && sizeof(T) == 2) {
const int half_n = n / 2;
Expand Down Expand Up @@ -415,6 +459,8 @@ void invokeAddBiasGeluV2(
0,
(float*)nullptr,
(float*)nullptr,
padding_offset,
seq_len,
stream);
break;
}
Expand Down Expand Up @@ -443,6 +489,8 @@ void invokeAddBiasGeluV2(
0,
(float*)nullptr,
(float*)nullptr,
padding_offset,
seq_len,
stream);
break;
}
Expand All @@ -460,6 +508,8 @@ void invokeAddBiasGeluV2(
0,
(float*)nullptr,
(float*)nullptr,
padding_offset,
seq_len,
stream);
}
}
Expand All @@ -470,13 +520,17 @@ template void invokeAddBiasGeluV2(float* out,
const float* bias,
const int* ia3_tasks,
const float* ia3_weights,
const int* padding_offset,
const int seq_len,
const int m,
const int n,
cudaStream_t stream);
template void invokeAddBiasGeluV2(half* out,
const half* bias,
const int* ia3_tasks,
const half* ia3_weights,
const int* padding_offset,
const int seq_len,
const int m,
const int n,
cudaStream_t stream);
Expand All @@ -485,6 +539,8 @@ template void invokeAddBiasGeluV2(__nv_bfloat16* out,
const __nv_bfloat16* bias,
const int* ia3_tasks,
const __nv_bfloat16* ia3_weights,
const int* padding_offset,
const int seq_len,
const int m,
const int n,
cudaStream_t stream);
Expand Down
48 changes: 47 additions & 1 deletion src/fastertransformer/kernels/activation_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,57 @@ void invokeGenericActivation(T* out,
const int int8_mode,
const float* activation_in,
const float* activation_out,
const int* padding_offset,
const int seq_len,
cudaStream_t stream);

// Backward-compatible overload of invokeGenericActivation for callers that do
// not use padding removal: forwards to the padding-aware variant with a null
// padding_offset and seq_len = 0, so every row maps straight to its batch index.
template<template<typename T> class Activation, typename T, typename BT>
void invokeGenericActivation(T* out,
                             const BT* bias,
                             const T* gated_weights,
                             const BT* gated_bias,
                             const int* ia3_tasks,
                             const T* ia3_weights,
                             const int m,
                             const int n,
                             const int int8_mode,
                             const float* activation_in,
                             const float* activation_out,
                             cudaStream_t stream)
{
    // The explicit (const int*) cast documents the forwarded padding_offset slot.
    invokeGenericActivation<Activation, T, BT>(out,
                                               bias,
                                               gated_weights,
                                               gated_bias,
                                               ia3_tasks,
                                               ia3_weights,
                                               m,
                                               n,
                                               int8_mode,
                                               activation_in,
                                               activation_out,
                                               /*padding_offset=*/(const int*)nullptr,
                                               /*seq_len=*/0,
                                               stream);
}

// Padding-aware IA3 bias+GELU launcher declaration. padding_offset (when
// non-null) is indexed per compact word and added back to recover the padded
// position before dividing by seq_len to pick the batch's IA3 task; pass
// nullptr / 0 when the input has no padding removed.
// NOTE(review): exact offset semantics are assumed from the matching kernel
// hunks in this diff — confirm against the kernel definitions.
template<typename T>
void invokeAddBiasGeluV2(T* out,
                         const T* bias,
                         const int* ia3_tasks,
                         const T* ia3_weights,
                         const int* padding_offset,
                         const int seq_len,
                         const int m,
                         const int n,
                         cudaStream_t stream);

template<typename T>
void invokeAddBiasGeluV2(
T* out, const T* bias, const int* ia3_tasks, const T* ia3_weights, const int m, const int n, cudaStream_t stream);
T* out, const T* bias, const int* ia3_tasks, const T* ia3_weights, const int m, const int n, cudaStream_t stream)
{
invokeAddBiasGeluV2(out, bias, ia3_tasks, ia3_weights, nullptr, 0, m, n, stream);
}

template<typename T>
void invokeSigmoid(T* data, const int size, const float scale, cudaStream_t stream);
Expand Down
1 change: 0 additions & 1 deletion src/fastertransformer/kernels/beam_search_topk_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,6 @@ __global__ void insertUnfinishedPath(BeamHypotheses beam_hyps,
beam_hyps.log_probs_src[length * batch_size * beam_width + bid * beam_width + src_beam_idx];
}
int prev_id = beam_hyps.parent_ids_src[length * batch_size * beam_width + src_beam_idx];
// printf("[INFO] i = %d, cum_log_probs: %f \n", i, cum_log_probs[src_beam_idx]);
for (int j = length - 1; j >= 0; j--) {
// output_ids_tgt need to use max_seq_len + 1 because its shape is
// [bs, beam_width, max_seq_len + 1]
Expand Down
4 changes: 2 additions & 2 deletions src/fastertransformer/kernels/unfused_attention_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ __global__ void add_QKV_bias_rebuild_padding_ia3(const T* Q,
const int n = head_num * size_per_head;

const bool use_ia3 = ia3_tasks != nullptr;
const int ia3_task = use_ia3 ? ia3_tasks[bid] : 0;
const int ia3_task = use_ia3 ? ia3_tasks[tgt_batch_id] : 0;
const bool use_ia3_key = use_ia3 && (ia3_key_weights != nullptr);
const bool use_ia3_value = use_ia3 && (ia3_value_weights != nullptr);
for (int idx = threadIdx.x; idx < n; idx += blockDim.x) {
Expand Down Expand Up @@ -956,7 +956,7 @@ __global__ void rebuild_padding_ia3(const T* Q,
const int n = head_num * size_per_head;

const bool use_ia3 = ia3_tasks != nullptr;
const int ia3_task = use_ia3 ? ia3_tasks[bid] : 0;
const int ia3_task = use_ia3 ? ia3_tasks[tgt_batch_id] : 0;
const bool use_ia3_key = use_ia3 && (ia3_key_weights != nullptr);
const bool use_ia3_value = use_ia3 && (ia3_value_weights != nullptr);
for (int idx = threadIdx.x; idx < n; idx += blockDim.x) {
Expand Down
Loading

0 comments on commit 6ca2d58

Please sign in to comment.