Skip to content

Commit

Permalink
[Bugfix][Kernel] Promote another index to int64_t (vllm-project#6838)
Browse files Browse the repository at this point in the history
  • Loading branch information
tlrmchlsmth authored and cadedaniel committed Jul 27, 2024
1 parent 0a5ca54 commit 877808a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion csrc/quantization/fp8/common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
const scalar_t* __restrict__ input,
int64_t num_elems) {
__shared__ float cache[1024];
int i = blockDim.x * blockIdx.x + threadIdx.x;
int64_t i = blockDim.x * blockIdx.x + threadIdx.x;

// First store maximum for all values processes by
// the current thread in cache[threadIdx.x]
Expand Down

0 comments on commit 877808a

Please sign in to comment.