diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu
index 024331fa4e64e..72210468bd557 100644
--- a/csrc/quantization/activation_kernels.cu
+++ b/csrc/quantization/activation_kernels.cu
@@ -29,11 +29,12 @@ __global__ void act_and_mul_quant_kernel(
     FP8_TYPE* __restrict__ out,          // [..., d]
     const scalar_t* __restrict__ input,  // [..., 2, d]
     const float* scale, const int d) {
-  const int64_t token_idx = blockIdx.x;
   const int32_t blocks_per_token = gridDim.y;
 
   const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t);
 
+  // We don't expect the hidden dimension to exceed 32 bits so int32 should
+  // be safe here.
   const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token);
   const int32_t elems_per_block =
       round_to_next_multiple_of(tgt_elems_per_block, elems_per_128bit_load);
@@ -41,6 +42,9 @@ __global__ void act_and_mul_quant_kernel(
   int32_t block_end = block_start + elems_per_block;
   block_end = block_end > d ? d : block_end;
 
+  // token_idx is 64 bit to prevent 32 bit overflow when the number of tokens
+  // is very large
+  const int64_t token_idx = blockIdx.x;
   const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d;
   const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d;
   FP8_TYPE* __restrict__ out_ptr = out + token_idx * d;
@@ -105,12 +109,12 @@ __global__ void act_and_mul_quant_kernel(
             scale.data_ptr(), d);                                   \
       });
 
-void silu_and_mul_quant(torch::Tensor& out,    // [..., d]
-                        torch::Tensor& input,
-                        torch::Tensor& scale)  // [..., 2 * d]
-{
+void silu_and_mul_quant(torch::Tensor& out,    // [..., d]
+                        torch::Tensor& input,  // [..., 2 * d]
+                        torch::Tensor& scale) {
   TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
   TORCH_CHECK(input.dtype() == torch::kFloat16 ||
               input.dtype() == torch::kBFloat16);
+  TORCH_CHECK(input.size(-1) % 2 == 0);
 
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
 }
\ No newline at end of file
diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py
index 7424f7c9c3639..992108f7873e7 100644
--- a/tests/compile/test_functionalization.py
+++ b/tests/compile/test_functionalization.py
@@ -49,8 +49,6 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
     config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
                                           enable_reshape=True)
-    # compilation_config = CompilationConfig(level=3,
-    #                                        custom_ops=["+silu_and_mul"])
     reshape_pass = RedundantReshapesPass(config)
     fusion_pass = FusionPass.instance(config)
     act_quant_fusion_pass = ActivationQuantFusionPass.instance(config)
@@ -85,9 +83,6 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
 
         gen_no_func = llm.generate(prompts, sampling_params)
 
-        # print(
-        #     backend_func.graph_pre_pass.python_code(root_module="self",
-        #                                             verbose=True).src)
         for output_func, output_no_func in zip(gen_func, gen_no_func):
             assert output_func.outputs[0].text == output_no_func.outputs[0].text
 
diff --git a/tests/kernels/test_fused_quant_activation.py b/tests/kernels/test_fused_quant_activation.py
index d0fee1c1aa014..fdb86c13adbb7 100644
--- a/tests/kernels/test_fused_quant_activation.py
+++ b/tests/kernels/test_fused_quant_activation.py
@@ -8,7 +8,7 @@
 DTYPES = [torch.bfloat16, torch.float16]
 QUANT_DTYPES = [torch.float8_e4m3fn]
 NUM_TOKENS = [1, 17, 86, 1234, 3045]  # Arbitrary values for testing
-HIDDEN_SIZES = [32, 64, 128, 2048, 4096]  # Arbitrary values for testing
+HIDDEN_SIZES = [16, 48, 128, 1562, 4096]  # Arbitrary values for testing
 SEEDS = [0]
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
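
Note on the int64_t token_idx change: the kernel indexes the input as token_idx * 2 * d, so with a 32-bit token index the byte-element offset can wrap once num_tokens * 2 * d exceeds INT32_MAX. The following standalone host-side sketch (not part of the patch) illustrates the arithmetic under an assumed hidden size of d = 4096; the constant names are hypothetical and only serve the example.

// Standalone sketch (not part of the patch): why the token offset must be
// computed in 64-bit arithmetic. Assumes a hypothetical hidden size d = 4096.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t d = 4096;  // assumed hidden dimension
  // Largest token index for which token_idx * 2 * d still fits in int32.
  const int64_t last_safe_token = INT32_MAX / (2 * d);  // 262143 for d = 4096
  // One token beyond that, a 32-bit offset would wrap; computing the offset
  // with a 64-bit token_idx (as the patched kernel does) stays in range.
  const int64_t offset = (last_safe_token + 1) * 2 * d;  // 2147483648 > INT32_MAX
  std::printf("last token index with a safe 32-bit offset (d=%lld): %lld\n",
              (long long)d, (long long)last_safe_token);
  std::printf("64-bit offset one token later: %lld\n", (long long)offset);
  return 0;
}

The HIDDEN_SIZES change in the test file follows the same theme: sizes such as 16, 48, and 1562 are not all multiples of the 128-bit vectorized load width, so they exercise the kernel's scalar tail path as well as the vectorized one.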