Skip to content

Commit

Permalink
[BugFix] [Kernel] Fix GPU SEGV occurring in int8 kernels (#9391)
Browse files Browse the repository at this point in the history
  • Loading branch information
rasmith authored Oct 17, 2024
1 parent c3fab5f commit 92d86da
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions csrc/quantization/compressed_tensors/int8_quant_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,15 @@ __global__ void static_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
int64_t const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr;

// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;

for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] = float_to_int8_rn(
static_cast<float>(input[token_idx * hidden_size + i]) / scale);
out[i] = float_to_int8_rn(static_cast<float>(input[i]) / scale);
}
}

Expand All @@ -111,14 +114,18 @@ __global__ void static_scaled_int8_azp_quant_kernel(
scale_type const* scale_ptr, azp_type const* azp_ptr,
const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
int64_t const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr;
azp_type const azp = *azp_ptr;

// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;

for (int i = tid; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
auto const val = static_cast<float>(input[i]);
auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
out[token_idx * hidden_size + i] = quant_val;
out[i] = quant_val;
}
}

Expand All @@ -127,12 +134,16 @@ __global__ void dynamic_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
int64_t const token_idx = blockIdx.x;
float absmax_val = 0.0f;
float const zero = 0.0f;

// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;

for (int i = tid; i < hidden_size; i += blockDim.x) {
float val = static_cast<float>(input[token_idx * hidden_size + i]);
float val = static_cast<float>(input[i]);
val = val > zero ? val : -val;
absmax_val = val > absmax_val ? val : absmax_val;
}
Expand All @@ -150,22 +161,25 @@ __global__ void dynamic_scaled_int8_quant_kernel(

float const tmp_scale = 127.0f / block_absmax_val;
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] = float_to_int8_rn(
static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
out[i] = float_to_int8_rn(static_cast<float>(input[i]) * tmp_scale);
}
}

template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void dynamic_scaled_int8_azp_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, azp_type* azp, const int hidden_size) {
int const token_idx = blockIdx.x;
int64_t const token_idx = blockIdx.x;

// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;

// Scan for the min and max value for this token
float max_val = std::numeric_limits<float>::min();
float min_val = std::numeric_limits<float>::max();
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto val = static_cast<float>(input[token_idx * hidden_size + i]);
auto val = static_cast<float>(input[i]);
max_val = std::max(max_val, val);
min_val = std::min(min_val, val);
}
Expand Down Expand Up @@ -200,10 +214,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(

// Quantize the values
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
auto const val = static_cast<float>(input[i]);
auto const quant_val =
int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
out[token_idx * hidden_size + i] = quant_val;
out[i] = quant_val;
}
}

Expand Down

0 comments on commit 92d86da

Please sign in to comment.