[BugFix] [Kernel] Fix GPU SEGV occurring in int8 kernels #9391
Merged: tlrmchlsmth merged 14 commits from rasmith:ransmith_int8_segv_fix into vllm-project:main on Oct 17, 2024.
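For context (inferred from the diff and commit messages rather than stated outright in the thread): each kernel assigns one token per block and computes the flat element offset as `token_idx * hidden_size + i` in `int` arithmetic, so once `num_tokens * hidden_size` exceeds `INT32_MAX` the product overflows and the kernel dereferences memory far outside the tensor, producing the SEGV. A minimal host-side sketch of the failure, with invented shapes:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical shapes, chosen so num_tokens * hidden_size > INT32_MAX.
  int token_idx = 69999;    // plays the role of blockIdx.x
  int hidden_size = 32768;
  int i = 0;                // the thread's element index within the token

  // All-int math, as in the kernels before this fix:
  // 69999 * 32768 = 2,293,727,232 > INT32_MAX, so the product overflows
  // (undefined behavior; in practice it wraps negative), and indexing a
  // buffer with it lands far outside the allocation.
  int bad_offset = token_idx * hidden_size + i;

  // Widening one operand first keeps the whole computation in 64 bits.
  int64_t good_offset = static_cast<int64_t>(token_idx) * hidden_size + i;

  printf("32-bit offset: %d\n", bad_offset);                // wrapped, negative
  printf("64-bit offset: %lld\n", (long long)good_offset);  // 2293727232
}
```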
Commits (14; the diff below shows changes from 1 commit):
- cc0abbc fix int8 gpu segv (rasmith)
- 6cd8e38 use uint64_t and ensure iterator can fit in 32-bit register (rasmith)
- ed62d62 revert custom_ops (rasmith)
- 9026085 main (rasmith)
- e207343 add <algorithm> for upstream pr check (rasmith)
- fcddad8 Make literal into long in std::min (rasmith)
- d02a21b change hidden_size back to int (rasmith)
- b9df1da change hidden_size back to int (rasmith)
- c50ce42 change hidden_size back to int (rasmith)
- f547c43 change hidden_size back to int (rasmith)
- 1e91784 add comment to clarify use of 64-bit math (rasmith)
- e63904b use current custom ops (rasmith)
- 6253fe9 clang format (rasmith)
- fa5fd71 Merge branch 'vllm-project:main' into ransmith_int8_segv_fix (rasmith)
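Note the back-and-forth in the log: the first commit (whose diff is shown below) widens `hidden_size` and the loop indices to `size_t`, while the later commits ("use uint64_t and ensure iterator can fit in 32-bit register", the four "change hidden_size back to int" commits, and "add comment to clarify use of 64-bit math") converge on keeping the loop counter 32-bit and widening only the address math, since 64-bit induction variables cost extra registers on the GPU. A sketch of that final shape, assuming the merged code follows the commit messages (this is illustrative, not copied from the PR):

```cuda
template <typename scalar_t>
__global__ void quant_kernel_shape(scalar_t const* __restrict__ input,
                                   int8_t* __restrict__ out, float scale,
                                   int const hidden_size) {
  // 64-bit math only where the big product happens; computed once per block.
  int64_t const token_offset = static_cast<int64_t>(blockIdx.x) * hidden_size;

  // The iterator stays a 32-bit int: hidden_size itself fits in an int, so
  // i cannot overflow, and the compiler can keep it in a single register.
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    // The real kernels use a saturating float_to_int8_rn; a plain cast is
    // used here only for brevity.
    out[token_offset + i] = static_cast<int8_t>(
        static_cast<float>(input[token_offset + i]) / scale);
  }
}
```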
Diff (this commit changes `hidden_size` and the strided loop indices from `int` to `size_t`):

```diff
@@ -94,12 +94,12 @@ namespace vllm {
 template <typename scalar_t, typename scale_type>
 __global__ void static_scaled_int8_quant_kernel(
     scalar_t const* __restrict__ input, int8_t* __restrict__ out,
-    scale_type const* scale_ptr, const int hidden_size) {
+    scale_type const* scale_ptr, const size_t hidden_size) {
   int const tid = threadIdx.x;
   int const token_idx = blockIdx.x;
   scale_type const scale = *scale_ptr;

-  for (int i = tid; i < hidden_size; i += blockDim.x) {
+  for (size_t i = tid; i < hidden_size; i += blockDim.x) {
     out[token_idx * hidden_size + i] = float_to_int8_rn(
         static_cast<float>(input[token_idx * hidden_size + i]) / scale);
   }
@@ -109,13 +109,13 @@ template <typename scalar_t, typename scale_type, typename azp_type>
 __global__ void static_scaled_int8_azp_quant_kernel(
     scalar_t const* __restrict__ input, int8_t* __restrict__ out,
     scale_type const* scale_ptr, azp_type const* azp_ptr,
-    const int hidden_size) {
+    const size_t hidden_size) {
   int const tid = threadIdx.x;
   int const token_idx = blockIdx.x;
   scale_type const scale = *scale_ptr;
   azp_type const azp = *azp_ptr;

-  for (int i = tid; i < hidden_size; i += blockDim.x) {
+  for (size_t i = tid; i < hidden_size; i += blockDim.x) {
     auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
     auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
     out[token_idx * hidden_size + i] = quant_val;
@@ -125,13 +125,13 @@ __global__ void static_scaled_int8_azp_quant_kernel(
 template <typename scalar_t, typename scale_type>
 __global__ void dynamic_scaled_int8_quant_kernel(
     scalar_t const* __restrict__ input, int8_t* __restrict__ out,
-    scale_type* scale, const int hidden_size) {
+    scale_type* scale, const size_t hidden_size) {
   int const tid = threadIdx.x;
   int const token_idx = blockIdx.x;
   float absmax_val = 0.0f;
   float const zero = 0.0f;

-  for (int i = tid; i < hidden_size; i += blockDim.x) {
+  for (size_t i = tid; i < hidden_size; i += blockDim.x) {
     float val = static_cast<float>(input[token_idx * hidden_size + i]);
     val = val > zero ? val : -val;
     absmax_val = val > absmax_val ? val : absmax_val;
@@ -149,7 +149,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
   __syncthreads();

   float const tmp_scale = 127.0f / block_absmax_val;
-  for (int i = tid; i < hidden_size; i += blockDim.x) {
+  for (size_t i = tid; i < hidden_size; i += blockDim.x) {
     out[token_idx * hidden_size + i] = float_to_int8_rn(
         static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
   }
@@ -158,13 +158,13 @@ __global__ void dynamic_scaled_int8_quant_kernel(
 template <typename scalar_t, typename scale_type, typename azp_type>
 __global__ void dynamic_scaled_int8_azp_quant_kernel(
     scalar_t const* __restrict__ input, int8_t* __restrict__ out,
-    scale_type* scale, azp_type* azp, const int hidden_size) {
+    scale_type* scale, azp_type* azp, const size_t hidden_size) {
   int const token_idx = blockIdx.x;

   // Scan for the min and max value for this token
   float max_val = std::numeric_limits<float>::min();
   float min_val = std::numeric_limits<float>::max();
-  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+  for (size_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
     auto val = static_cast<float>(input[token_idx * hidden_size + i]);
     max_val = std::max(max_val, val);
     min_val = std::min(min_val, val);
@@ -199,7 +199,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
   azp_type const azp_val = azp_sh;

   // Quantize the values
-  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+  for (size_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
     auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
     auto const quant_val =
         int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
```
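On the call side, these kernels launch one block per token with up to 1024 threads striding across `hidden_size`. A `size_t` parameter would also explain two of the commits above: once the argument types diverge, `std::min(hidden_size, 1024)` no longer deduces a common type (hence "Make literal into long in std::min" and "add <algorithm> for upstream pr check"). A hedged launcher sketch; the function name and raw-pointer interface are invented, since vLLM actually dispatches through torch custom ops:

```cuda
#include <algorithm>  // std::min, needed once the launch math uses it

void launch_static_quant_sketch(float const* d_input, int8_t* d_out,
                                float const* d_scale, int num_tokens,
                                size_t hidden_size, cudaStream_t stream) {
  dim3 const grid(num_tokens);  // one block per token: blockIdx.x in-kernel
  // Both std::min arguments must have the same type now that hidden_size is
  // size_t; mixing it with a plain int literal fails template deduction.
  dim3 const block(static_cast<unsigned>(std::min(hidden_size, size_t{1024})));
  vllm::static_scaled_int8_quant_kernel<float, float>
      <<<grid, block, 0, stream>>>(d_input, d_out, d_scale, hidden_size);
}
```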
Review comments:
- "Could you change these to `size_t` as well? For consistency."
- On the `dynamic_scaled_int8_azp_quant_kernel` hunk: "Same here"
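Finally, a note on the quantization arithmetic the azp kernels carry out: a value is divided by the scale, rounded to the nearest int32, shifted by the zero point, then saturated to int8. A small CPU-side model of that expression; the rounding and clamping behavior of `float_to_int32_rn` / `int32_to_int8` is assumed from their names, and the example numbers are invented:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Assumed behavior of the device helpers referenced in the diff.
static int32_t float_to_int32_rn(float x) {
  return static_cast<int32_t>(std::nearbyintf(x));  // round to nearest
}
static int8_t int32_to_int8(int32_t x) {
  return static_cast<int8_t>(std::min(127, std::max(-128, x)));  // saturate
}

int main() {
  float const scale = 0.05f;  // invented example values
  int32_t const azp = -10;    // asymmetric zero point
  float const val = 1.234f;

  // Same expression as the azp kernels in the diff above.
  int8_t const q = int32_to_int8(float_to_int32_rn(val / scale) + azp);
  printf("quantized: %d\n", q);  // round(1.234 / 0.05) - 10 = 25 - 10 = 15
}
```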