Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] [Kernel] Fix GPU SEGV occurring in int8 kernels #9391

Merged
Merged 14 commits on Oct 17, 2024
20 changes: 10 additions & 10 deletions csrc/quantization/compressed_tensors/int8_quant_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,12 @@ namespace vllm {
template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, const int hidden_size) {
scale_type const* scale_ptr, const size_t hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you change these to size_t as well? For consistency.

scale_type const scale = *scale_ptr;

for (int i = tid; i < hidden_size; i += blockDim.x) {
for (size_t i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] = float_to_int8_rn(
static_cast<float>(input[token_idx * hidden_size + i]) / scale);
}
Expand All @@ -109,13 +109,13 @@ template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void static_scaled_int8_azp_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, azp_type const* azp_ptr,
const int hidden_size) {
const size_t hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr;
azp_type const azp = *azp_ptr;

for (int i = tid; i < hidden_size; i += blockDim.x) {
for (size_t i = tid; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
out[token_idx * hidden_size + i] = quant_val;
Expand All @@ -125,13 +125,13 @@ __global__ void static_scaled_int8_azp_quant_kernel(
template <typename scalar_t, typename scale_type>
__global__ void dynamic_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, const int hidden_size) {
scale_type* scale, const size_t hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
float absmax_val = 0.0f;
float const zero = 0.0f;

for (int i = tid; i < hidden_size; i += blockDim.x) {
for (size_t i = tid; i < hidden_size; i += blockDim.x) {
float val = static_cast<float>(input[token_idx * hidden_size + i]);
val = val > zero ? val : -val;
absmax_val = val > absmax_val ? val : absmax_val;
Expand All @@ -149,7 +149,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
__syncthreads();

float const tmp_scale = 127.0f / block_absmax_val;
for (int i = tid; i < hidden_size; i += blockDim.x) {
for (size_t i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] = float_to_int8_rn(
static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
}
Expand All @@ -158,13 +158,13 @@ __global__ void dynamic_scaled_int8_quant_kernel(
template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void dynamic_scaled_int8_azp_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, azp_type* azp, const int hidden_size) {
scale_type* scale, azp_type* azp, const size_t hidden_size) {
int const token_idx = blockIdx.x;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here


// Scan for the min and max value for this token
float max_val = std::numeric_limits<float>::min();
float min_val = std::numeric_limits<float>::max();
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
for (size_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto val = static_cast<float>(input[token_idx * hidden_size + i]);
max_val = std::max(max_val, val);
min_val = std::min(min_val, val);
Expand Down Expand Up @@ -199,7 +199,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
azp_type const azp_val = azp_sh;

// Quantize the values
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
for (size_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
auto const quant_val =
int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
Expand Down