From cc0abbca09f55c141358e81829cdc4678f3bfb37 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Tue, 15 Oct 2024 15:52:31 -0500 Subject: [PATCH 01/13] fix int8 gpu segv --- .../compressed_tensors/int8_quant_kernels.cu | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index aec9fa002f96e..42a766482077a 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -94,12 +94,12 @@ namespace vllm { template __global__ void static_scaled_int8_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type const* scale_ptr, const int hidden_size) { + scale_type const* scale_ptr, const size_t hidden_size) { int const tid = threadIdx.x; int const token_idx = blockIdx.x; scale_type const scale = *scale_ptr; - for (int i = tid; i < hidden_size; i += blockDim.x) { + for (size_t i = tid; i < hidden_size; i += blockDim.x) { out[token_idx * hidden_size + i] = float_to_int8_rn( static_cast(input[token_idx * hidden_size + i]) / scale); } @@ -109,13 +109,13 @@ template __global__ void static_scaled_int8_azp_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, scale_type const* scale_ptr, azp_type const* azp_ptr, - const int hidden_size) { + const size_t hidden_size) { int const tid = threadIdx.x; int const token_idx = blockIdx.x; scale_type const scale = *scale_ptr; azp_type const azp = *azp_ptr; - for (int i = tid; i < hidden_size; i += blockDim.x) { + for (size_t i = tid; i < hidden_size; i += blockDim.x) { auto const val = static_cast(input[token_idx * hidden_size + i]); auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp); out[token_idx * hidden_size + i] = quant_val; @@ -125,13 +125,13 @@ __global__ void static_scaled_int8_azp_quant_kernel( template __global__ void dynamic_scaled_int8_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type* scale, const int hidden_size) { + scale_type* scale, const size_t hidden_size) { int const tid = threadIdx.x; int const token_idx = blockIdx.x; float absmax_val = 0.0f; float const zero = 0.0f; - for (int i = tid; i < hidden_size; i += blockDim.x) { + for (size_t i = tid; i < hidden_size; i += blockDim.x) { float val = static_cast(input[token_idx * hidden_size + i]); val = val > zero ? val : -val; absmax_val = val > absmax_val ? 
val : absmax_val; @@ -149,7 +149,7 @@ __global__ void dynamic_scaled_int8_quant_kernel( __syncthreads(); float const tmp_scale = 127.0f / block_absmax_val; - for (int i = tid; i < hidden_size; i += blockDim.x) { + for (size_t i = tid; i < hidden_size; i += blockDim.x) { out[token_idx * hidden_size + i] = float_to_int8_rn( static_cast(input[token_idx * hidden_size + i]) * tmp_scale); } @@ -158,13 +158,13 @@ __global__ void dynamic_scaled_int8_quant_kernel( template __global__ void dynamic_scaled_int8_azp_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type* scale, azp_type* azp, const int hidden_size) { + scale_type* scale, azp_type* azp, const size_t hidden_size) { int const token_idx = blockIdx.x; // Scan for the min and max value for this token float max_val = std::numeric_limits::min(); float min_val = std::numeric_limits::max(); - for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + for (size_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { auto val = static_cast(input[token_idx * hidden_size + i]); max_val = std::max(max_val, val); min_val = std::min(min_val, val); @@ -199,7 +199,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( azp_type const azp_val = azp_sh; // Quantize the values - for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + for (size_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { auto const val = static_cast(input[token_idx * hidden_size + i]); auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val); From 6cd8e384543a176249a20f2459bc923634737ddf Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Tue, 15 Oct 2024 18:14:42 -0500 Subject: [PATCH 02/13] use uint64_t and ensure iterator can fit in 32-bit register --- .../compressed_tensors/int8_quant_kernels.cu | 62 +++++++++++-------- vllm/_custom_ops.py | 18 +++++- 2 files changed, 51 insertions(+), 29 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 42a766482077a..b04613c6b7a6e 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -94,14 +94,16 @@ namespace vllm { template __global__ void static_scaled_int8_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type const* scale_ptr, const size_t hidden_size) { + scale_type const* scale_ptr, const int64_t hidden_size) { int const tid = threadIdx.x; - int const token_idx = blockIdx.x; + int64_t const token_idx = blockIdx.x; scale_type const scale = *scale_ptr; - for (size_t i = tid; i < hidden_size; i += blockDim.x) { - out[token_idx * hidden_size + i] = float_to_int8_rn( - static_cast(input[token_idx * hidden_size + i]) / scale); + out += token_idx * hidden_size; + input += token_idx * hidden_size; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[i] = float_to_int8_rn(static_cast(input[i]) / scale); } } @@ -109,30 +111,36 @@ template __global__ void static_scaled_int8_azp_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, scale_type const* scale_ptr, azp_type const* azp_ptr, - const size_t hidden_size) { + const int64_t hidden_size) { int const tid = threadIdx.x; - int const token_idx = blockIdx.x; + int64_t const token_idx = blockIdx.x; scale_type const scale = *scale_ptr; azp_type const azp = *azp_ptr; - for (size_t i = tid; i < hidden_size; i += blockDim.x) { - auto const val = static_cast(input[token_idx * 
hidden_size + i]); + out += token_idx * hidden_size; + input += token_idx * hidden_size; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + auto const val = static_cast(input[i]); auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp); - out[token_idx * hidden_size + i] = quant_val; + out[i] = quant_val; } } template __global__ void dynamic_scaled_int8_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type* scale, const size_t hidden_size) { + scale_type* scale, const int64_t hidden_size) { int const tid = threadIdx.x; - int const token_idx = blockIdx.x; + int64_t const token_idx = blockIdx.x; float absmax_val = 0.0f; float const zero = 0.0f; - for (size_t i = tid; i < hidden_size; i += blockDim.x) { - float val = static_cast(input[token_idx * hidden_size + i]); + out += token_idx * hidden_size; + input += token_idx * hidden_size; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + float val = static_cast(input[i]); val = val > zero ? val : -val; absmax_val = val > absmax_val ? val : absmax_val; } @@ -149,23 +157,25 @@ __global__ void dynamic_scaled_int8_quant_kernel( __syncthreads(); float const tmp_scale = 127.0f / block_absmax_val; - for (size_t i = tid; i < hidden_size; i += blockDim.x) { - out[token_idx * hidden_size + i] = float_to_int8_rn( - static_cast(input[token_idx * hidden_size + i]) * tmp_scale); + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[i] = float_to_int8_rn(static_cast(input[i]) * tmp_scale); } } template __global__ void dynamic_scaled_int8_azp_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type* scale, azp_type* azp, const size_t hidden_size) { - int const token_idx = blockIdx.x; + scale_type* scale, azp_type* azp, const int64_t hidden_size) { + int64_t const token_idx = blockIdx.x; + + out += token_idx * hidden_size; + input += token_idx * hidden_size; // Scan for the min and max value for this token float max_val = std::numeric_limits::min(); float min_val = std::numeric_limits::max(); - for (size_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { - auto val = static_cast(input[token_idx * hidden_size + i]); + for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + auto val = static_cast(input[i]); max_val = std::max(max_val, val); min_val = std::min(min_val, val); } @@ -199,11 +209,11 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( azp_type const azp_val = azp_sh; // Quantize the values - for (size_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { - auto const val = static_cast(input[token_idx * hidden_size + i]); + for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + auto const val = static_cast(input[i]); auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val); - out[token_idx * hidden_size + i] = quant_val; + out[i] = quant_val; } } @@ -218,7 +228,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] TORCH_CHECK(scale.numel() == 1); TORCH_CHECK(!azp || azp->numel() == 1); - int const hidden_size = input.size(-1); + int64_t const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); dim3 const block(std::min(hidden_size, 1024)); @@ -249,7 +259,7 @@ void dynamic_scaled_int8_quant( TORCH_CHECK(scales.is_contiguous()); TORCH_CHECK(!azp || azp->is_contiguous()); - int const hidden_size = input.size(-1); + int64_t const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; dim3 
const grid(num_tokens); dim3 const block(std::min(hidden_size, 1024)); diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3a23692285efe..8fa8f6fb484a2 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -507,10 +507,11 @@ def cutlass_scaled_mm(a: torch.Tensor, m = a.shape[0] n = b.shape[1] - out = torch.empty((m, n), dtype=out_dtype, device=a.device) + # out = torch.empty((m, n), dtype=out_dtype, device=a.device) - torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + # torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + out = torch.zeros((m, n), dtype=out_dtype, device=a.device) return out @@ -740,13 +741,21 @@ def scaled_int8_quant( Returns: Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. """ - output = torch.empty_like(input, dtype=torch.int8) + # output = torch.empty_like(input, dtype=torch.int8) + output = torch.zeros_like(input, dtype=torch.int8) + input_scales = torch.zeros((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.zeros_like(input_scales, + dtype=torch.int32) if scale is not None: + print(f"call static_scaled_fp8_quant") # static-per-tensor quantization. assert symmetric == ( azp is None), "azp must only be provided for asymmetric quantization." torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) + print(f"done with static_scaled_fp8_quant") return output, scale, None # dynamic-per-token quantization. @@ -755,6 +764,9 @@ def scaled_int8_quant( dtype=torch.float32) input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + print(f"call dynamic_scaled_int8_quant output = {output.shape}, " + f"input = {input.shape}, input_scales = {input_scales.shape}," + f"input_azp = {'None' if input_azp is None else input_azp.shape}") torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) return output, input_scales, input_azp From ed62d623871ad36d24902965bd9b745b6db7fba0 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Tue, 15 Oct 2024 18:15:56 -0500 Subject: [PATCH 03/13] revert custom_ops --- vllm/_custom_ops.py | 182 ++++++++++++++++---------------------------- 1 file changed, 67 insertions(+), 115 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 8fa8f6fb484a2..4d71381184de5 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,9 +1,8 @@ import contextlib import functools -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch -import torch.library import vllm.envs as envs from vllm._core_ext import ScalarType @@ -26,16 +25,6 @@ import vllm._moe_C # noqa: F401 supports_moe_ops = True -if TYPE_CHECKING: - - def register_fake(fn): - return lambda name: fn -else: - try: - from torch.library import register_fake - except ImportError: - from torch.library import impl_abstract as register_fake - def hint_on_error(fn): @@ -43,15 +32,6 @@ def hint_on_error(fn): def wrapper(*args, **kwargs): try: return fn(*args, **kwargs) - - except NotImplementedError as e: - msg = ( - "Error in calling custom op %s: %s\n" - "Not implemented or built, mostly likely because the current current device " - "does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set " - "incorrectly while building)") - logger.error(msg, fn.__name__, e) - raise NotImplementedError(msg % (fn.__name__, e)) from e except AttributeError as e: msg = ( "Error in calling 
custom op %s: %s\n" @@ -277,7 +257,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, if hasattr(torch.ops._C, "gptq_gemm"): - @register_fake("_C::gptq_gemm") + @torch.library.register_fake("_C::gptq_gemm") def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, @@ -312,7 +292,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): - @register_fake("_C::gptq_marlin_24_gemm") + @torch.library.register_fake("_C::gptq_marlin_24_gemm") def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, @@ -320,7 +300,7 @@ def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, size_n: int, size_k: int) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - @register_fake("_C::gptq_marlin_gemm") + @torch.library.register_fake("_C::gptq_marlin_gemm") def _gptq_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, @@ -337,12 +317,12 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor, use_fp32_reduce: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - @register_fake("_C::ggml_dequantize") + @torch.library.register_fake("_C::ggml_dequantize") def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int, n: int) -> torch.Tensor: return torch.empty((m, n), dtype=torch.float16, device=W.device) - @register_fake("_C::ggml_mul_mat_vec_a8") + @torch.library.register_fake("_C::ggml_mul_mat_vec_a8") def _ggml_mul_mat_vec_a8_fake( W: torch.Tensor, X: torch.Tensor, @@ -351,7 +331,7 @@ def _ggml_mul_mat_vec_a8_fake( ) -> torch.Tensor: return torch.empty((1, row), dtype=torch.float16, device=W.device) - @register_fake("_C::ggml_mul_mat_a8") + @torch.library.register_fake("_C::ggml_mul_mat_a8") def _ggml_mul_mat_a8_fake( W: torch.Tensor, X: torch.Tensor, @@ -361,7 +341,7 @@ def _ggml_mul_mat_a8_fake( batch = X.size(0) return torch.empty((batch, row), dtype=torch.float16, device=W.device) - @register_fake("_C::marlin_qqq_gemm") + @torch.library.register_fake("_C::marlin_qqq_gemm") def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, s_tok: torch.Tensor, s_ch: torch.Tensor, s_group: torch.Tensor, workspace: torch.Tensor, @@ -371,7 +351,7 @@ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, dtype=torch.float16, device=a.device) - @register_fake("_C::marlin_gemm") + @torch.library.register_fake("_C::marlin_gemm") def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, size_n: int, @@ -380,7 +360,7 @@ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, dtype=torch.float16, device=a.device) - @register_fake("_C::awq_dequantize") + @torch.library.register_fake("_C::awq_dequantize") def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, split_k_iters: int, thx: int, thy: int) -> torch.Tensor: @@ -391,7 +371,7 @@ def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, dtype=scales.dtype, device=scales.device) - @register_fake("_C::awq_gemm") + @torch.library.register_fake("_C::awq_gemm") def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: @@ -400,7 +380,7 @@ def _awq_gemm_fake(input: torch.Tensor, 
qweight: torch.Tensor, dtype=input.dtype, device=input.device).sum(0) - @register_fake("_C::aqlm_gemm") + @torch.library.register_fake("_C::aqlm_gemm") def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor, codebook_partition_sizes: List[int], @@ -416,7 +396,7 @@ def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, output_sizes.append(-1) return flat_output.reshape(tuple(output_sizes)) - @register_fake("_C::aqlm_dequant") + @torch.library.register_fake("_C::aqlm_dequant") def _aqlm_dequant_fake( codes: torch.Tensor, codebooks: torch.Tensor, codebook_partition_sizes: List[int]) -> torch.Tensor: @@ -426,14 +406,14 @@ def _aqlm_dequant_fake( dtype=codebooks.dtype, device=codebooks.device) - @register_fake("_C::fp8_marlin_gemm") + @torch.library.register_fake("_C::fp8_marlin_gemm") def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, num_bits: int, size_m: int, size_n: int, size_k: int) -> torch.Tensor: return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - @register_fake("_C::machete_gemm") + @torch.library.register_fake("_C::machete_gemm") def machete_gemm_fake( a: torch.Tensor, # Should be the tensor returned by machete_prepack_B @@ -451,42 +431,41 @@ def machete_gemm_fake( n = b_q.size(1) return torch.empty((m, n), device=a.device, dtype=a.dtype) - @register_fake("_C::machete_prepack_B") + @torch.library.register_fake("_C::machete_prepack_B") def machete_prepack_B_fake(b_q_weight: torch.Tensor, b_type: ScalarType) -> torch.Tensor: return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) - @register_fake("_C::causal_conv1d_fwd") + @torch.library.register_fake("_C::causal_conv1d_fwd") def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, bias_: Optional[torch.Tensor], - conv_states: Optional[torch.Tensor], - cu_seq_len: Optional[torch.Tensor], - cache_indices: Optional[torch.Tensor], - has_initial_state: Optional[torch.Tensor], + seq_idx_: Optional[torch.Tensor], + initial_states_: Optional[torch.Tensor], + final_states_out_: Optional[torch.Tensor], silu_activation: bool) -> torch.Tensor: return torch.empty_like(x) - @register_fake("_C::causal_conv1d_update") + @torch.library.register_fake("_C::causal_conv1d_update") def causal_conv1d_update_fake( x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, bias_: Optional[torch.Tensor], silu_activation: bool, - cache_seqlens: Optional[torch.Tensor], conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: return torch.empty_like(x) - @register_fake("_C::selective_scan_fwd") - def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor, - A: torch.Tensor, B: torch.Tensor, - C: torch.Tensor, D_: Optional[torch.Tensor], - z_: Optional[torch.Tensor], - delta_bias_: Optional[torch.Tensor], - delta_softplus: bool, - cu_seq_len: Optional[torch.Tensor], - cache_indices: Optional[torch.Tensor], - has_initial_state: Optional[torch.Tensor], - ssm_states: Optional[torch.Tensor]) -> None: - return None + @torch.library.register_fake("_C::selective_scan_fwd") + def selective_scan_fwd_fake( + u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, + B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor], + z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, index_: Optional[torch.Tensor], + x: Optional[torch.Tensor]) -> List[torch.Tensor]: + a = torch.empty_like(u) + if z_ is not None: + c = torch.empty_like(z_) + return [a, c] + 
else: + return [a] # cutlass @@ -507,11 +486,10 @@ def cutlass_scaled_mm(a: torch.Tensor, m = a.shape[0] n = b.shape[1] - # out = torch.empty((m, n), dtype=out_dtype, device=a.device) + out = torch.empty((m, n), dtype=out_dtype, device=a.device) - # torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - out = torch.zeros((m, n), dtype=out_dtype, device=a.device) return out @@ -580,20 +558,6 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, return output -def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, - size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: - num_experts = b_q_weight.shape[0] - assert size_k % 16 == 0 - output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), - device=b_q_weight.device, - dtype=b_q_weight.dtype) - for e in range(num_experts): - output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k, - size_n, num_bits) - return output - - def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, @@ -651,7 +615,7 @@ def machete_prepack_B(b_q_weight: torch.Tensor, if hasattr(torch.ops._C, "permute_cols"): - @register_fake("_C::permute_cols") + @torch.library.register_fake("_C::permute_cols") def _permute_cols_fake(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: return torch.empty_like(a) @@ -741,21 +705,13 @@ def scaled_int8_quant( Returns: Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. """ - # output = torch.empty_like(input, dtype=torch.int8) - output = torch.zeros_like(input, dtype=torch.int8) - input_scales = torch.zeros((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.zeros_like(input_scales, - dtype=torch.int32) + output = torch.empty_like(input, dtype=torch.int8) if scale is not None: - print(f"call static_scaled_fp8_quant") # static-per-tensor quantization. assert symmetric == ( azp is None), "azp must only be provided for asymmetric quantization." torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) - print(f"done with static_scaled_fp8_quant") return output, scale, None # dynamic-per-token quantization. 
@@ -764,9 +720,6 @@ def scaled_int8_quant( dtype=torch.float32) input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - print(f"call dynamic_scaled_int8_quant output = {output.shape}, " - f"input = {input.shape}, input_scales = {input_scales.shape}," - f"input_azp = {'None' if input_azp is None else input_azp.shape}") torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) return output, input_scales, input_azp @@ -808,37 +761,37 @@ def ggml_mul_mat_a8( # mamba def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, bias_: Optional[torch.Tensor], - conv_states: Optional[torch.Tensor], - query_start_loc: Optional[torch.Tensor], - cache_indices: Optional[torch.Tensor], - has_initial_state: Optional[torch.Tensor], + seq_idx_: Optional[torch.Tensor], + initial_states_: Optional[torch.Tensor], + final_states_out_: Optional[torch.Tensor], silu_activation: bool) -> torch.Tensor: - return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states, - query_start_loc, cache_indices, - has_initial_state, silu_activation) + return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_, + initial_states_, final_states_out_, + silu_activation) def causal_conv1d_update( - x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, - bias_: Optional[torch.Tensor], silu_activation: bool, - cache_seqlens: Optional[torch.Tensor], - conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias_: Optional[torch.Tensor], + silu_activation: bool, + conv_state_indices: Optional[torch.Tensor], +) -> torch.Tensor: return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, - silu_activation, cache_seqlens, + silu_activation, conv_state_indices) -def selective_scan_fwd( - u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, - C: torch.Tensor, D_: Optional[torch.Tensor], - z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor], - delta_softplus: bool, query_start_loc: Optional[torch.Tensor], - cache_indices: Optional[torch.Tensor], - has_initial_state: Optional[torch.Tensor], ssm_states: torch.Tensor): - torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_, - delta_softplus, query_start_loc, - cache_indices, has_initial_state, - ssm_states) +def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, + B: torch.Tensor, C: torch.Tensor, + D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], + delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, index_: Optional[torch.Tensor], + x: Optional[torch.Tensor]) -> List[torch.Tensor]: + return torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, + delta_bias_, delta_softplus, index_, + x) # moe @@ -860,17 +813,16 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): - @register_fake("_moe_C::marlin_gemm_moe") + @torch.library.register_fake("_moe_C::marlin_gemm_moe") def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, sorted_ids: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, b_scales: torch.Tensor, - b_zero_points: torch.Tensor, g_idx: torch.Tensor, - perm: torch.Tensor, workspace: torch.Tensor, - b_q_type: ScalarType, size_m: int, size_n: int, - size_k: int, is_k_full: bool, num_experts: int, - topk: int, moe_block_size: int, - replicate_input: bool, + g_idx: torch.Tensor, perm: torch.Tensor, + workspace: torch.Tensor, 
b_q_type: ScalarType, + size_m: int, size_n: int, size_k: int, + is_k_full: bool, num_experts: int, topk: int, + moe_block_size: int, replicate_input: bool, apply_weights: bool) -> torch.Tensor: return torch.empty((size_m, topk, size_n), dtype=a.dtype, From 9026085c09fc572f7e94e8047e2110b09ff810c7 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Tue, 15 Oct 2024 18:17:48 -0500 Subject: [PATCH 04/13] main --- vllm/_custom_ops.py | 164 +++++++++++++++++++++++++++----------------- 1 file changed, 100 insertions(+), 64 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 4d71381184de5..3a23692285efe 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,8 +1,9 @@ import contextlib import functools -from typing import List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch +import torch.library import vllm.envs as envs from vllm._core_ext import ScalarType @@ -25,6 +26,16 @@ import vllm._moe_C # noqa: F401 supports_moe_ops = True +if TYPE_CHECKING: + + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake + def hint_on_error(fn): @@ -32,6 +43,15 @@ def hint_on_error(fn): def wrapper(*args, **kwargs): try: return fn(*args, **kwargs) + + except NotImplementedError as e: + msg = ( + "Error in calling custom op %s: %s\n" + "Not implemented or built, mostly likely because the current current device " + "does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set " + "incorrectly while building)") + logger.error(msg, fn.__name__, e) + raise NotImplementedError(msg % (fn.__name__, e)) from e except AttributeError as e: msg = ( "Error in calling custom op %s: %s\n" @@ -257,7 +277,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, if hasattr(torch.ops._C, "gptq_gemm"): - @torch.library.register_fake("_C::gptq_gemm") + @register_fake("_C::gptq_gemm") def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, @@ -292,7 +312,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): - @torch.library.register_fake("_C::gptq_marlin_24_gemm") + @register_fake("_C::gptq_marlin_24_gemm") def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, @@ -300,7 +320,7 @@ def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, size_n: int, size_k: int) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - @torch.library.register_fake("_C::gptq_marlin_gemm") + @register_fake("_C::gptq_marlin_gemm") def _gptq_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, @@ -317,12 +337,12 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor, use_fp32_reduce: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - @torch.library.register_fake("_C::ggml_dequantize") + @register_fake("_C::ggml_dequantize") def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int, n: int) -> torch.Tensor: return torch.empty((m, n), dtype=torch.float16, device=W.device) - @torch.library.register_fake("_C::ggml_mul_mat_vec_a8") + @register_fake("_C::ggml_mul_mat_vec_a8") def _ggml_mul_mat_vec_a8_fake( W: torch.Tensor, X: torch.Tensor, @@ -331,7 +351,7 @@ def 
_ggml_mul_mat_vec_a8_fake( ) -> torch.Tensor: return torch.empty((1, row), dtype=torch.float16, device=W.device) - @torch.library.register_fake("_C::ggml_mul_mat_a8") + @register_fake("_C::ggml_mul_mat_a8") def _ggml_mul_mat_a8_fake( W: torch.Tensor, X: torch.Tensor, @@ -341,7 +361,7 @@ def _ggml_mul_mat_a8_fake( batch = X.size(0) return torch.empty((batch, row), dtype=torch.float16, device=W.device) - @torch.library.register_fake("_C::marlin_qqq_gemm") + @register_fake("_C::marlin_qqq_gemm") def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, s_tok: torch.Tensor, s_ch: torch.Tensor, s_group: torch.Tensor, workspace: torch.Tensor, @@ -351,7 +371,7 @@ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, dtype=torch.float16, device=a.device) - @torch.library.register_fake("_C::marlin_gemm") + @register_fake("_C::marlin_gemm") def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, size_n: int, @@ -360,7 +380,7 @@ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, dtype=torch.float16, device=a.device) - @torch.library.register_fake("_C::awq_dequantize") + @register_fake("_C::awq_dequantize") def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, split_k_iters: int, thx: int, thy: int) -> torch.Tensor: @@ -371,7 +391,7 @@ def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, dtype=scales.dtype, device=scales.device) - @torch.library.register_fake("_C::awq_gemm") + @register_fake("_C::awq_gemm") def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: @@ -380,7 +400,7 @@ def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor, dtype=input.dtype, device=input.device).sum(0) - @torch.library.register_fake("_C::aqlm_gemm") + @register_fake("_C::aqlm_gemm") def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor, codebook_partition_sizes: List[int], @@ -396,7 +416,7 @@ def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, output_sizes.append(-1) return flat_output.reshape(tuple(output_sizes)) - @torch.library.register_fake("_C::aqlm_dequant") + @register_fake("_C::aqlm_dequant") def _aqlm_dequant_fake( codes: torch.Tensor, codebooks: torch.Tensor, codebook_partition_sizes: List[int]) -> torch.Tensor: @@ -406,14 +426,14 @@ def _aqlm_dequant_fake( dtype=codebooks.dtype, device=codebooks.device) - @torch.library.register_fake("_C::fp8_marlin_gemm") + @register_fake("_C::fp8_marlin_gemm") def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, num_bits: int, size_m: int, size_n: int, size_k: int) -> torch.Tensor: return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - @torch.library.register_fake("_C::machete_gemm") + @register_fake("_C::machete_gemm") def machete_gemm_fake( a: torch.Tensor, # Should be the tensor returned by machete_prepack_B @@ -431,41 +451,42 @@ def machete_gemm_fake( n = b_q.size(1) return torch.empty((m, n), device=a.device, dtype=a.dtype) - @torch.library.register_fake("_C::machete_prepack_B") + @register_fake("_C::machete_prepack_B") def machete_prepack_B_fake(b_q_weight: torch.Tensor, b_type: ScalarType) -> torch.Tensor: return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) - @torch.library.register_fake("_C::causal_conv1d_fwd") + 
@register_fake("_C::causal_conv1d_fwd") def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, bias_: Optional[torch.Tensor], - seq_idx_: Optional[torch.Tensor], - initial_states_: Optional[torch.Tensor], - final_states_out_: Optional[torch.Tensor], + conv_states: Optional[torch.Tensor], + cu_seq_len: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], silu_activation: bool) -> torch.Tensor: return torch.empty_like(x) - @torch.library.register_fake("_C::causal_conv1d_update") + @register_fake("_C::causal_conv1d_update") def causal_conv1d_update_fake( x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, bias_: Optional[torch.Tensor], silu_activation: bool, + cache_seqlens: Optional[torch.Tensor], conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: return torch.empty_like(x) - @torch.library.register_fake("_C::selective_scan_fwd") - def selective_scan_fwd_fake( - u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, - B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor], - z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor], - delta_softplus: bool, index_: Optional[torch.Tensor], - x: Optional[torch.Tensor]) -> List[torch.Tensor]: - a = torch.empty_like(u) - if z_ is not None: - c = torch.empty_like(z_) - return [a, c] - else: - return [a] + @register_fake("_C::selective_scan_fwd") + def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor, + A: torch.Tensor, B: torch.Tensor, + C: torch.Tensor, D_: Optional[torch.Tensor], + z_: Optional[torch.Tensor], + delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, + cu_seq_len: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], + ssm_states: Optional[torch.Tensor]) -> None: + return None # cutlass @@ -558,6 +579,20 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, return output +def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + num_experts = b_q_weight.shape[0] + assert size_k % 16 == 0 + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype) + for e in range(num_experts): + output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k, + size_n, num_bits) + return output + + def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, @@ -615,7 +650,7 @@ def machete_prepack_B(b_q_weight: torch.Tensor, if hasattr(torch.ops._C, "permute_cols"): - @torch.library.register_fake("_C::permute_cols") + @register_fake("_C::permute_cols") def _permute_cols_fake(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: return torch.empty_like(a) @@ -761,37 +796,37 @@ def ggml_mul_mat_a8( # mamba def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, bias_: Optional[torch.Tensor], - seq_idx_: Optional[torch.Tensor], - initial_states_: Optional[torch.Tensor], - final_states_out_: Optional[torch.Tensor], + conv_states: Optional[torch.Tensor], + query_start_loc: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], silu_activation: bool) -> torch.Tensor: - return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_, - initial_states_, final_states_out_, - silu_activation) + return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states, + query_start_loc, cache_indices, + has_initial_state, 
silu_activation) def causal_conv1d_update( - x: torch.Tensor, - conv_state: torch.Tensor, - weight: torch.Tensor, - bias_: Optional[torch.Tensor], - silu_activation: bool, - conv_state_indices: Optional[torch.Tensor], -) -> torch.Tensor: + x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, + bias_: Optional[torch.Tensor], silu_activation: bool, + cache_seqlens: Optional[torch.Tensor], + conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, - silu_activation, + silu_activation, cache_seqlens, conv_state_indices) -def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, - B: torch.Tensor, C: torch.Tensor, - D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], - delta_bias_: Optional[torch.Tensor], - delta_softplus: bool, index_: Optional[torch.Tensor], - x: Optional[torch.Tensor]) -> List[torch.Tensor]: - return torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, - delta_bias_, delta_softplus, index_, - x) +def selective_scan_fwd( + u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, + C: torch.Tensor, D_: Optional[torch.Tensor], + z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, query_start_loc: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], ssm_states: torch.Tensor): + torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_, + delta_softplus, query_start_loc, + cache_indices, has_initial_state, + ssm_states) # moe @@ -813,16 +848,17 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): - @torch.library.register_fake("_moe_C::marlin_gemm_moe") + @register_fake("_moe_C::marlin_gemm_moe") def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, sorted_ids: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, b_scales: torch.Tensor, - g_idx: torch.Tensor, perm: torch.Tensor, - workspace: torch.Tensor, b_q_type: ScalarType, - size_m: int, size_n: int, size_k: int, - is_k_full: bool, num_experts: int, topk: int, - moe_block_size: int, replicate_input: bool, + b_zero_points: torch.Tensor, g_idx: torch.Tensor, + perm: torch.Tensor, workspace: torch.Tensor, + b_q_type: ScalarType, size_m: int, size_n: int, + size_k: int, is_k_full: bool, num_experts: int, + topk: int, moe_block_size: int, + replicate_input: bool, apply_weights: bool) -> torch.Tensor: return torch.empty((size_m, topk, size_n), dtype=a.dtype, From e207343cf8e42b94975897e74bc68b2c9f18759e Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Tue, 15 Oct 2024 18:35:08 -0500 Subject: [PATCH 05/13] add for upstream pr check --- csrc/quantization/compressed_tensors/int8_quant_kernels.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index b04613c6b7a6e..98410b4e90df3 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -1,5 +1,6 @@ #include #include +#include #include #include "../../dispatch_utils.h" From fcddad82b000322573f410b668c62e01882d92e1 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Tue, 15 Oct 2024 18:38:00 -0500 Subject: [PATCH 06/13] Make literal into long in std::min --- csrc/quantization/compressed_tensors/int8_quant_kernels.cu | 5 ++--- 1 file 
changed, 2 insertions(+), 3 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 98410b4e90df3..8ba1cb9cd0c35 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -1,6 +1,5 @@ #include #include -#include #include #include "../../dispatch_utils.h" @@ -232,7 +231,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] int64_t const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); - dim3 const block(std::min(hidden_size, 1024)); + dim3 const block(std::min(hidden_size, 1024L)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { @@ -263,7 +262,7 @@ void dynamic_scaled_int8_quant( int64_t const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); - dim3 const block(std::min(hidden_size, 1024)); + dim3 const block(std::min(hidden_size, 1024L)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { From d02a21b0f08c7696fe49eb8849b2ecd9c40574ff Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Tue, 15 Oct 2024 22:31:55 -0500 Subject: [PATCH 07/13] change hidden_size back to int --- .../quantization/compressed_tensors/int8_quant_kernels.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 8ba1cb9cd0c35..895d2389c0070 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -94,7 +94,7 @@ namespace vllm { template __global__ void static_scaled_int8_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type const* scale_ptr, const int64_t hidden_size) { + scale_type const* scale_ptr, const int hidden_size) { int const tid = threadIdx.x; int64_t const token_idx = blockIdx.x; scale_type const scale = *scale_ptr; @@ -111,7 +111,7 @@ template __global__ void static_scaled_int8_azp_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, scale_type const* scale_ptr, azp_type const* azp_ptr, - const int64_t hidden_size) { + const int hidden_size) { int const tid = threadIdx.x; int64_t const token_idx = blockIdx.x; scale_type const scale = *scale_ptr; @@ -130,7 +130,7 @@ __global__ void static_scaled_int8_azp_quant_kernel( template __global__ void dynamic_scaled_int8_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type* scale, const int64_t hidden_size) { + scale_type* scale, const int hidden_size) { int const tid = threadIdx.x; int64_t const token_idx = blockIdx.x; float absmax_val = 0.0f; @@ -165,7 +165,7 @@ __global__ void dynamic_scaled_int8_quant_kernel( template __global__ void dynamic_scaled_int8_azp_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type* scale, azp_type* azp, const int64_t hidden_size) { + scale_type* scale, azp_type* azp, const int hidden_size) { int64_t const token_idx = blockIdx.x; out += token_idx * hidden_size; From b9df1dac2294b52c19553056826e401b343fda16 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Tue, 
15 Oct 2024 22:37:36 -0500 Subject: [PATCH 08/13] change hidden_size back to int --- csrc/quantization/compressed_tensors/int8_quant_kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 895d2389c0070..819971fe8e850 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -259,10 +259,10 @@ void dynamic_scaled_int8_quant( TORCH_CHECK(scales.is_contiguous()); TORCH_CHECK(!azp || azp->is_contiguous()); - int64_t const hidden_size = input.size(-1); + int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); - dim3 const block(std::min(hidden_size, 1024L)); + dim3 const block(std::min(hidden_size, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { From c50ce429650d35382120940efecf5eefc8a1bb59 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Tue, 15 Oct 2024 22:39:45 -0500 Subject: [PATCH 09/13] change hidden_size back to int --- csrc/quantization/compressed_tensors/int8_quant_kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 819971fe8e850..a2ee027be6c5b 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -228,10 +228,10 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] TORCH_CHECK(scale.numel() == 1); TORCH_CHECK(!azp || azp->numel() == 1); - int64_t const hidden_size = input.size(-1); + int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); - dim3 const block(std::min(hidden_size, 1024L)); + dim3 const block(std::min(hidden_size, 102L)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { From f547c43e4028035f11ccf5324e36199a4f82ee1b Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Tue, 15 Oct 2024 22:40:16 -0500 Subject: [PATCH 10/13] change hidden_size back to int --- csrc/quantization/compressed_tensors/int8_quant_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index a2ee027be6c5b..07fd00c2d471b 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -231,7 +231,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); - dim3 const block(std::min(hidden_size, 102L)); + dim3 const block(std::min(hidden_size, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { From 1e91784cbbbf051dbd487dc1414aefcd4af3e4ac Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Wed, 16 Oct 2024 12:44:47 -0500 Subject: [PATCH 11/13] add comment to clarify use of 64-bit math --- 
.../compressed_tensors/int8_quant_kernels.cu | 6 +++++- vllm/_custom_ops.py | 17 +++++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 07fd00c2d471b..f5113ae534dc1 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -96,9 +96,10 @@ __global__ void static_scaled_int8_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, scale_type const* scale_ptr, const int hidden_size) { int const tid = threadIdx.x; - int64_t const token_idx = blockIdx.x; + int64_t const token_idx = blockIdx.x; scale_type const scale = *scale_ptr; + // Must be performed using 64-bit math to avoid integer overflow. out += token_idx * hidden_size; input += token_idx * hidden_size; @@ -117,6 +118,7 @@ __global__ void static_scaled_int8_azp_quant_kernel( scale_type const scale = *scale_ptr; azp_type const azp = *azp_ptr; + // Must be performed using 64-bit math to avoid integer overflow. out += token_idx * hidden_size; input += token_idx * hidden_size; @@ -136,6 +138,7 @@ __global__ void dynamic_scaled_int8_quant_kernel( float absmax_val = 0.0f; float const zero = 0.0f; + // Must be performed using 64-bit math to avoid integer overflow. out += token_idx * hidden_size; input += token_idx * hidden_size; @@ -168,6 +171,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( scale_type* scale, azp_type* azp, const int hidden_size) { int64_t const token_idx = blockIdx.x; + // Must be performed using 64-bit math to avoid integer overflow. out += token_idx * hidden_size; input += token_idx * hidden_size; diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3a23692285efe..3ec5830369d11 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -509,8 +509,10 @@ def cutlass_scaled_mm(a: torch.Tensor, n = b.shape[1] out = torch.empty((m, n), dtype=out_dtype, device=a.device) - torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + # torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + # torch.ops._rocm_C.hip_scaled_mm(out, a, b, scale_a, scale_b, bias) + # out = torch.zeros((m, n), dtype=out_dtype, device=a.device) return out @@ -740,13 +742,21 @@ def scaled_int8_quant( Returns: Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. """ - output = torch.empty_like(input, dtype=torch.int8) + # output = torch.empty_like(input, dtype=torch.int8) + output = torch.zeros_like(input, dtype=torch.int8) + input_scales = torch.zeros((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.zeros_like(input_scales, + dtype=torch.int32) if scale is not None: + print(f"call static_scaled_fp8_quant") # static-per-tensor quantization. assert symmetric == ( azp is None), "azp must only be provided for asymmetric quantization." torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) + print(f"done with static_scaled_fp8_quant") return output, scale, None # dynamic-per-token quantization. 
@@ -755,6 +765,9 @@ def scaled_int8_quant( dtype=torch.float32) input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + print(f"call dynamic_scaled_int8_quant output = {output.shape}, " + f"input = {input.shape}, input_scales = {input_scales.shape}," + f"input_azp = {'None' if input_azp is None else input_azp.shape}") torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) return output, input_scales, input_azp From e63904bb1b7d0f435f78ac738b8ac10cff584585 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Wed, 16 Oct 2024 12:45:26 -0500 Subject: [PATCH 12/13] use current custom ops --- vllm/_custom_ops.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3ec5830369d11..3a23692285efe 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -509,10 +509,8 @@ def cutlass_scaled_mm(a: torch.Tensor, n = b.shape[1] out = torch.empty((m, n), dtype=out_dtype, device=a.device) - # torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - # torch.ops._rocm_C.hip_scaled_mm(out, a, b, scale_a, scale_b, bias) + torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - # out = torch.zeros((m, n), dtype=out_dtype, device=a.device) return out @@ -742,21 +740,13 @@ def scaled_int8_quant( Returns: Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. """ - # output = torch.empty_like(input, dtype=torch.int8) - output = torch.zeros_like(input, dtype=torch.int8) - input_scales = torch.zeros((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.zeros_like(input_scales, - dtype=torch.int32) + output = torch.empty_like(input, dtype=torch.int8) if scale is not None: - print(f"call static_scaled_fp8_quant") # static-per-tensor quantization. assert symmetric == ( azp is None), "azp must only be provided for asymmetric quantization." torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) - print(f"done with static_scaled_fp8_quant") return output, scale, None # dynamic-per-token quantization. 
@@ -765,9 +755,6 @@ def scaled_int8_quant( dtype=torch.float32) input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - print(f"call dynamic_scaled_int8_quant output = {output.shape}, " - f"input = {input.shape}, input_scales = {input_scales.shape}," - f"input_azp = {'None' if input_azp is None else input_azp.shape}") torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) return output, input_scales, input_azp From 6253fe99e8b9db3556d06e22d41a5c599e0dbb94 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Wed, 16 Oct 2024 12:50:23 -0500 Subject: [PATCH 13/13] clang format --- csrc/quantization/compressed_tensors/int8_quant_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index f5113ae534dc1..e9987535bd3ea 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -96,7 +96,7 @@ __global__ void static_scaled_int8_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, scale_type const* scale_ptr, const int hidden_size) { int const tid = threadIdx.x; - int64_t const token_idx = blockIdx.x; + int64_t const token_idx = blockIdx.x; scale_type const scale = *scale_ptr; // Must be performed using 64-bit math to avoid integer overflow.
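A minimal sketch of the overflow this series fixes and of the indexing pattern the later patches settle on. This is not the vLLM kernel itself: the kernel names, the plain float element type, the by-value scale, and the truncating cast below are illustrative stand-ins for the templated scalar_t kernels and the float_to_int8_rn round-to-nearest conversion in int8_quant_kernels.cu; only the addressing is the point.

#include <cstdint>

__global__ void quant_rows_overflowing(const float* __restrict__ in,
                                       int8_t* __restrict__ out,
                                       float scale, int hidden_size) {
  int const token_idx = blockIdx.x;  // one block per token (row)
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    // int * int is evaluated in 32 bits: once token_idx * hidden_size passes
    // INT_MAX the offset wraps negative and the load/store go out of bounds,
    // which is the out-of-bounds access behind the "int8 gpu segv".
    out[token_idx * hidden_size + i] =
        static_cast<int8_t>(in[token_idx * hidden_size + i] / scale);
  }
}

__global__ void quant_rows_fixed(const float* __restrict__ in,
                                 int8_t* __restrict__ out,
                                 float scale, int hidden_size) {
  // Promote only the per-token row offset to 64 bits and advance the pointers
  // once; the loop index stays a plain int so the iterator still fits in a
  // 32-bit register, matching the intent of patch 02's subject line.
  int64_t const token_idx = blockIdx.x;
  in += token_idx * hidden_size;   // 64-bit multiply, no wraparound
  out += token_idx * hidden_size;
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    out[i] = static_cast<int8_t>(in[i] / scale);
  }
}

The same typing concern drives the launch-configuration churn in patches 06-10: once hidden_size is declared int64_t on the host, std::min(hidden_size, 1024) fails template argument deduction because the two arguments differ in type, so the series briefly switches to a 1024L literal and then settles on keeping hidden_size as a plain int in the launch code, confining the 64-bit math to the per-token pointer offsets that patch 11 annotates with the "Must be performed using 64-bit math" comments.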