From 2e0031a342f16347a5711c10208d963a19b3a137 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Mon, 2 Dec 2024 21:12:24 +0000 Subject: [PATCH 01/31] init Signed-off-by: Sage Moore --- CMakeLists.txt | 3 +- csrc/activation_quant_kernels.cu | 127 +++++++++++++++++++++++++++++++ csrc/core/math.hpp | 18 +++++ csrc/ops.h | 2 + csrc/torch_bindings.cpp | 4 +- vllm/compilation/backends.py | 7 +- vllm/compilation/fusion.py | 51 ++++++++++--- vllm/compilation/pass_manager.py | 2 +- 8 files changed, 198 insertions(+), 16 deletions(-) create mode 100644 csrc/activation_quant_kernels.cu create mode 100644 csrc/core/math.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index c78cdc77a7e42..78331857177a3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -240,7 +240,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/awq/gemm_kernels.cu" "csrc/custom_all_reduce.cu" "csrc/permute_cols.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu") + "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" + "csrc/activation_quant_kernels.cu") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" diff --git a/csrc/activation_quant_kernels.cu b/csrc/activation_quant_kernels.cu new file mode 100644 index 0000000000000..03edb37450a20 --- /dev/null +++ b/csrc/activation_quant_kernels.cu @@ -0,0 +1,127 @@ +#include +#include +#include + +#include +#include "core/math.hpp" +#include "cuda_compat.h" +#include "dispatch_utils.h" + +using FP8_TYPE = c10::Float8_e4m3fn; +C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = + std::numeric_limits::max(); +// using FP8_TYPE = c10::Float8_e4m3fnuz; +namespace vllm { + +template +__device__ __forceinline__ T silu_kernel(const T& x) { + // x * sigmoid(x) + return (T)(((float)x) / (1.0f + expf((float)-x))); +} + +template +__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val, + float const scale) { + float x = 0.0f; + if constexpr (is_scale_inverted) { + x = val * scale; + } + else { + x = val / scale; + } + float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); + return static_cast(r); +} + +// Activation and gating kernel template. +template +__global__ void act_and_mul_quant_kernel( + FP8_TYPE* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const float* scale, + const int d) { + + const int32_t token_idx = blockIdx.x; + const int32_t blocks_per_token = gridDim.y; + + const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t); + + const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token); + const int32_t elems_per_block = + next_multiple_of(elems_per_128bit_load, tgt_elems_per_block); + const int32_t block_start = blockIdx.y * elems_per_block; + int32_t block_end = block_start + elems_per_block; + block_end = block_end > d ? 
d : block_end; + + const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d; + const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d; + FP8_TYPE* __restrict__ out_ptr = out + token_idx * d; + + // 128-bit vectorized code + const int32_t vec_loop_end = + prev_multiple_of(elems_per_128bit_load, block_end); + const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load; + const int32_t vec_start_idx = block_start / elems_per_128bit_load; + + const int4* __restrict__ x_128bit_ptr = reinterpret_cast(x_ptr); + const int4* __restrict__ y_128bit_ptr = reinterpret_cast(y_ptr); + int2* __restrict__ out_128bit_ptr = reinterpret_cast(out_ptr); + + float inverted_scale = 1 / *scale; +#pragma unroll + for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx; + vec_idx += blockDim.x) { + const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]); + const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]); + using scalar_128bit_vec_t = std::array; + using scalar_64bit_vec_t = std::array; + + scalar_64bit_vec_t out_vec; + const auto x_vec = reinterpret_cast(x_128bit); + const auto y_vec = reinterpret_cast(y_128bit); + +#pragma unroll + for (int i = 0; i < elems_per_128bit_load; i++) { + out_vec[i] = scaled_fp8_conversion(ACT_FN(x_vec[i]) * y_vec[i] , inverted_scale); + } + + out_128bit_ptr[vec_idx] = reinterpret_cast(out_vec); + } + + // Scalar cleanup code + if (block_end > vec_loop_end) { + for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end; + idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&x_ptr[idx]); + const scalar_t y = VLLM_LDG(&y_ptr[idx]); + // out_ptr[idx] = ACT_FN(x) * y; + out_ptr[idx] = scaled_fp8_conversion(ACT_FN(x) * y , inverted_scale); + } + } +} +} + + +// Launch activation, gating, and quantize kernel. +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 
1 : 2 : 4); \ + dim3 block(std::min(d, 512)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_quant_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), \ + scale.data_ptr(), \ + d); \ + }); + +void silu_and_mul_quant(torch::Tensor& out, // [..., d] + torch::Tensor& input, + torch::Tensor& scale) // [..., 2 * d] +{ + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); +} \ No newline at end of file diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp new file mode 100644 index 0000000000000..bd5241c5703fc --- /dev/null +++ b/csrc/core/math.hpp @@ -0,0 +1,18 @@ +#pragma once + +template +static inline constexpr auto div_ceil(A a, B b) { + return (a + b - 1) / b; +} + +// Compute the next multiple of a that is greater than or equal to b +template +static inline constexpr auto next_multiple_of(A a, B b) { + return div_ceil(b, a) * a; +} + +// Compute the largest multiple of a that is less than or equal to b +template +static inline constexpr auto prev_multiple_of(A a, B b) { + return (b / a) * a; +} \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index ea001190bc202..cdbf6297e816e 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -78,6 +78,8 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void silu_and_mul(torch::Tensor& out, torch::Tensor& input); +void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); + void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 4e64b9c92773a..6db92eb98a8f9 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -52,9 +52,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Activation ops // Activation function used in SwiGLU. - ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); + ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()"); ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); + ops.def("silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); + ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); // Activation function used in GeGLU with `none` approximation. ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 464bc2af8fd6d..614cce6903258 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -199,6 +199,7 @@ def __init__( self, compilation_configs: CompilationConfig, ): + print("GETTING TO BACKEND") global global_graph_pool if global_graph_pool is None: global_graph_pool = torch.cuda.graph_pool_handle() @@ -253,8 +254,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: from torch._dynamo.utils import lazy_format_graph_code logger.debug("%s", lazy_format_graph_code("before split", self.graph)) - logger.debug("%s", lazy_format_graph_code("after split", - self.split_gm)) + logger.debug("%s", lazy_format_graph_code("after split", self.split_gm)) compilation_counter.num_piecewise_graphs_seen += len( self.piecewise_graphs) @@ -480,8 +480,7 @@ def __call__(self, *args) -> Any: ] assert new_input_addresses == entry.input_addresses, ( "Input addresses for cudagraphs are different during replay." 
- f" Expected {entry.input_addresses}, got {new_input_addresses}" - ) + f" Expected {entry.input_addresses}, got {new_input_addresses}") entry.cudagraph.replay() return entry.output diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 5efa410fab6a0..e5f9f0f6f10e8 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -14,6 +14,31 @@ logger = init_logger(__name__) +def silu_mul_pattern_static(result: torch.Tensor, result_silu_mul: torch.Tensor, + input: torch.Tensor, scale: torch.Tensor): + at1 = auto_functionalized(torch.ops._C.silu_and_mul.default, + result=result_silu_mul, + input=input) + at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, + result=result, + input=at1[1], + scale=scale) + # result + return at2[1] + + +def silu_mul_replacement_static(result: torch.Tensor, + result_silu_mul: torch.Tensor, + input: torch.Tensor, scale: torch.Tensor): + print("REPLACEMENT RUNNING") + at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default, + result=result, + input=input, + scale=scale) + # result, residual + return at[1] + + def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): @@ -171,8 +196,8 @@ def __init__(self, config: CompilationConfig.PassConfig): empty_bf16(1, 5), empty_fp32(1, 1) ] - register_replacement(rms_pattern_static, rms_replacement_static, - inputs, fwd_only, self.patterns) + register_replacement(rms_pattern_static, rms_replacement_static, inputs, + fwd_only, self.patterns) # Fuse fused_add_rms_norm + static_scaled_fp8_quant into # fused_add_rms_norm_static_fp8_quant @@ -192,6 +217,16 @@ def __init__(self, config: CompilationConfig.PassConfig): self.patterns, extra_check=lambda m: self.record_match(m)) + inputs = [ + empty_fp8(5, 4), + empty_bf16(5, 4), + empty_bf16(5, 4), + empty_fp32(1, 1) + ] + register_replacement(silu_mul_pattern_static, + silu_mul_replacement_static, inputs, fwd_only, + self.patterns) + def record_match(self, match: Match) -> bool: # Hijack the extra_check to record the match and # save it for post-processing. @@ -229,17 +264,15 @@ def process_matches(self, graph: torch.fx.Graph): kwargs = match.kwargs kwargs["epsilon"] = 1e-5 # Currently hard-coded in RMSNorm - fused_node = graph.call_function( - auto_functionalized, - (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - ), - kwargs=kwargs) + fused_node = graph.call_function(auto_functionalized, ( + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ), + kwargs=kwargs) graph.inserting_after(fused_node) result_node_new = graph.call_function(operator.getitem, (fused_node, 1)) - residual_node_new = graph.call_function( - operator.getitem, (fused_node, 2)) + residual_node_new = graph.call_function(operator.getitem, + (fused_node, 2)) # Last part of replacement is rebinding the users of nodes in the # match to use the new nodes. 
diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index fb522ae053e97..26defd035688b 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -44,7 +44,7 @@ def configure(self, pass_config: CompilationConfig.PassConfig): if pass_config.enable_reshape: self.passes += [RedundantReshapesPass(pass_config)] - if pass_config.enable_fusion: + if True: self.passes += [FusionPass.instance(pass_config)] self.fix_functionalization = FixFunctionalizationPass(pass_config) From 8a957c766872619f42b86af44624dee4cf8ed65c Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Tue, 3 Dec 2024 16:53:28 +0000 Subject: [PATCH 02/31] remove backend format changes Signed-off-by: Sage Moore --- vllm/compilation/backends.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 614cce6903258..7f0e0f7e494c9 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -199,7 +199,6 @@ def __init__( self, compilation_configs: CompilationConfig, ): - print("GETTING TO BACKEND") global global_graph_pool if global_graph_pool is None: global_graph_pool = torch.cuda.graph_pool_handle() @@ -254,7 +253,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: from torch._dynamo.utils import lazy_format_graph_code logger.debug("%s", lazy_format_graph_code("before split", self.graph)) - logger.debug("%s", lazy_format_graph_code("after split", self.split_gm)) + logger.debug("%s", lazy_format_graph_code("after split", + self.split_gm)) compilation_counter.num_piecewise_graphs_seen += len( self.piecewise_graphs) @@ -480,7 +480,8 @@ def __call__(self, *args) -> Any: ] assert new_input_addresses == entry.input_addresses, ( "Input addresses for cudagraphs are different during replay." 
- f" Expected {entry.input_addresses}, got {new_input_addresses}") + f" Expected {entry.input_addresses}, got {new_input_addresses}" + ) entry.cudagraph.replay() return entry.output From 2913716b7492c22b276ee934f68ae638ad2be5b2 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Tue, 3 Dec 2024 16:56:37 +0000 Subject: [PATCH 03/31] format Signed-off-by: Sage Moore --- csrc/activation_quant_kernels.cu | 52 +++++++++++++++----------------- csrc/ops.h | 3 +- csrc/torch_bindings.cpp | 3 +- vllm/compilation/backends.py | 6 ++-- 4 files changed, 30 insertions(+), 34 deletions(-) diff --git a/csrc/activation_quant_kernels.cu b/csrc/activation_quant_kernels.cu index 03edb37450a20..7786f54f80c6d 100644 --- a/csrc/activation_quant_kernels.cu +++ b/csrc/activation_quant_kernels.cu @@ -25,8 +25,7 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val, float x = 0.0f; if constexpr (is_scale_inverted) { x = val * scale; - } - else { + } else { x = val / scale; } float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); @@ -38,9 +37,7 @@ template __global__ void act_and_mul_quant_kernel( FP8_TYPE* __restrict__ out, // [..., d] const scalar_t* __restrict__ input, // [..., 2, d] - const float* scale, - const int d) { - + const float* scale, const int d) { const int32_t token_idx = blockIdx.x; const int32_t blocks_per_token = gridDim.y; @@ -82,7 +79,8 @@ __global__ void act_and_mul_quant_kernel( #pragma unroll for (int i = 0; i < elems_per_128bit_load; i++) { - out_vec[i] = scaled_fp8_conversion(ACT_FN(x_vec[i]) * y_vec[i] , inverted_scale); + out_vec[i] = scaled_fp8_conversion(ACT_FN(x_vec[i]) * y_vec[i], + inverted_scale); } out_128bit_ptr[vec_idx] = reinterpret_cast(out_vec); @@ -95,33 +93,31 @@ __global__ void act_and_mul_quant_kernel( const scalar_t x = VLLM_LDG(&x_ptr[idx]); const scalar_t y = VLLM_LDG(&y_ptr[idx]); // out_ptr[idx] = ACT_FN(x) * y; - out_ptr[idx] = scaled_fp8_conversion(ACT_FN(x) * y , inverted_scale); + out_ptr[idx] = scaled_fp8_conversion(ACT_FN(x) * y, inverted_scale); } } } -} - +} // namespace vllm // Launch activation, gating, and quantize kernel. -#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ - int d = input.size(-1) / 2; \ - int64_t num_tokens = input.numel() / input.size(-1); \ - dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4); \ - dim3 block(std::min(d, 512)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_quant_kernel> \ - <<>>(out.data_ptr(), \ - input.data_ptr(), \ - scale.data_ptr(), \ - d); \ - }); - -void silu_and_mul_quant(torch::Tensor& out, // [..., d] - torch::Tensor& input, - torch::Tensor& scale) // [..., 2 * d] +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 
1 : 2 : 4); \ + dim3 block(std::min(d, 512)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_quant_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), \ + scale.data_ptr(), d); \ + }); + +void silu_and_mul_quant(torch::Tensor& out, // [..., d] + torch::Tensor& input, + torch::Tensor& scale) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); } \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index cdbf6297e816e..801c8d753757c 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -78,7 +78,8 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void silu_and_mul(torch::Tensor& out, torch::Tensor& input); -void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); +void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& scale); void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 6db92eb98a8f9..4d892194286fa 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -55,7 +55,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()"); ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); - ops.def("silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); + ops.def( + "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); // Activation function used in GeGLU with `none` approximation. ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 7f0e0f7e494c9..c14f033db3df3 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -253,8 +253,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: from torch._dynamo.utils import lazy_format_graph_code logger.debug("%s", lazy_format_graph_code("before split", self.graph)) - logger.debug("%s", lazy_format_graph_code("after split", - self.split_gm)) + logger.debug("%s", lazy_format_graph_code("after split", self.split_gm)) compilation_counter.num_piecewise_graphs_seen += len( self.piecewise_graphs) @@ -480,8 +479,7 @@ def __call__(self, *args) -> Any: ] assert new_input_addresses == entry.input_addresses, ( "Input addresses for cudagraphs are different during replay." 
- f" Expected {entry.input_addresses}, got {new_input_addresses}" - ) + f" Expected {entry.input_addresses}, got {new_input_addresses}") entry.cudagraph.replay() return entry.output From 11c6faef3551c0b71e88e178550024ba28f3cf7c Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Tue, 3 Dec 2024 17:18:18 +0000 Subject: [PATCH 04/31] move activation_quant_kernels to the quantization dir Signed-off-by: Sage Moore --- CMakeLists.txt | 2 +- .../activation_kernels.cu} | 25 +++++++------------ vllm/compilation/fusion.py | 3 --- 3 files changed, 10 insertions(+), 20 deletions(-) rename csrc/{activation_quant_kernels.cu => quantization/activation_kernels.cu} (87%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 78331857177a3..986ac9ee95306 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -241,7 +241,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/custom_all_reduce.cu" "csrc/permute_cols.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" - "csrc/activation_quant_kernels.cu") + "csrc/quantization/activation_kernels.cu") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" diff --git a/csrc/activation_quant_kernels.cu b/csrc/quantization/activation_kernels.cu similarity index 87% rename from csrc/activation_quant_kernels.cu rename to csrc/quantization/activation_kernels.cu index 7786f54f80c6d..4245a55b2f1de 100644 --- a/csrc/activation_quant_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -3,9 +3,9 @@ #include #include -#include "core/math.hpp" -#include "cuda_compat.h" -#include "dispatch_utils.h" +#include "../core/math.hpp" +#include "../cuda_compat.h" +#include "../dispatch_utils.h" using FP8_TYPE = c10::Float8_e4m3fn; C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = @@ -19,15 +19,9 @@ __device__ __forceinline__ T silu_kernel(const T& x) { return (T)(((float)x) / (1.0f + expf((float)-x))); } -template -__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val, - float const scale) { - float x = 0.0f; - if constexpr (is_scale_inverted) { - x = val * scale; - } else { - x = val / scale; - } +__device__ __forceinline__ FP8_TYPE +scaled_fp8_conversion(float const val, float const inverted_scale) { + float x = val * inverted_scale; float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); return static_cast(r); } @@ -79,8 +73,8 @@ __global__ void act_and_mul_quant_kernel( #pragma unroll for (int i = 0; i < elems_per_128bit_load; i++) { - out_vec[i] = scaled_fp8_conversion(ACT_FN(x_vec[i]) * y_vec[i], - inverted_scale); + out_vec[i] = + scaled_fp8_conversion(ACT_FN(x_vec[i]) * y_vec[i], inverted_scale); } out_128bit_ptr[vec_idx] = reinterpret_cast(out_vec); @@ -92,8 +86,7 @@ __global__ void act_and_mul_quant_kernel( idx += blockDim.x) { const scalar_t x = VLLM_LDG(&x_ptr[idx]); const scalar_t y = VLLM_LDG(&y_ptr[idx]); - // out_ptr[idx] = ACT_FN(x) * y; - out_ptr[idx] = scaled_fp8_conversion(ACT_FN(x) * y, inverted_scale); + out_ptr[idx] = scaled_fp8_conversion(ACT_FN(x) * y, inverted_scale); } } } diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index e5f9f0f6f10e8..c0266af0fd3c1 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -23,19 +23,16 @@ def silu_mul_pattern_static(result: torch.Tensor, result_silu_mul: torch.Tensor, result=result, input=at1[1], scale=scale) - # result return at2[1] def silu_mul_replacement_static(result: torch.Tensor, result_silu_mul: torch.Tensor, input: torch.Tensor, scale: torch.Tensor): - print("REPLACEMENT RUNNING") at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default, result=result, 
input=input, scale=scale) - # result, residual return at[1] From 2dfecb5b200d99737d4ca3aae049cbf2c4f154d5 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Wed, 4 Dec 2024 15:24:52 +0000 Subject: [PATCH 05/31] added replacement unit test Signed-off-by: Sage Moore --- tests/compile/test_silu_mul_quant_fusion.py | 74 +++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 tests/compile/test_silu_mul_quant_fusion.py diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py new file mode 100644 index 0000000000000..9ac7e51686b5e --- /dev/null +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -0,0 +1,74 @@ +import pytest +import torch +from compressed_tensors.quantization import FP8_DTYPE + +import vllm.envs as envs +from vllm.compilation.fusion import (FusionPass, find_auto_fn, + find_auto_fn_maybe) +from vllm.compilation.reshapes import RedundantReshapesPass +from vllm.config import CompilationConfig +from vllm.model_executor.layers.activation import SiluAndMul +from vllm._custom_ops import scaled_fp8_quant + +from .backend import TestBackend + + +class TestModel(torch.nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.silu_and_mul = SiluAndMul() + self.scale = torch.rand(1, dtype=torch.float32) + + def forward(self, x): + y = self.silu_and_mul(x) + x2 = scaled_fp8_quant(y, self.scale) + return x2 + + +@pytest.mark.parametrize("hidden_size", [64]) +@pytest.mark.parametrize("num_tokens", [256]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", + reason="Only test on CUDA") +def test_fusion_silu_and_mul_quant(hidden_size, num_tokens): + torch.set_default_device("cuda") + torch.set_default_dtype(torch.float16) + + # Reshape pass is needed for the fusion pass to work + config = CompilationConfig.PassConfig(enable_fusion=True, + enable_reshape=True) + reshape_pass = RedundantReshapesPass(config) + fusion_pass = FusionPass.instance(config) + + backend = TestBackend(reshape_pass, fusion_pass) + model = TestModel() + + # First dimension dynamic + x = torch.rand(num_tokens, hidden_size) + torch._dynamo.mark_dynamic(x, 0) + + result = model(x) + + model2 = torch.compile(model, backend=backend) + result2 = model2(x) + + # Check that it gives the same answer + torch.testing.assert_close(result[0].to(dtype=torch.float16), + result2[0].to(dtype=torch.float16), + atol=1e-3, + rtol=1e-3) + + # Check substitution worked + pre_nodes = backend.graph_pre_pass.nodes + post_nodes = backend.graph_post_pass.nodes + + silu_and_mul_quant = torch.ops._C.silu_and_mul_quant.default + fp8_quant = torch.ops._C.static_scaled_fp8_quant.default + + # In pre-nodes, fp8 quant should be present and fused kernels should not + assert find_auto_fn_maybe(pre_nodes, silu_and_mul_quant) is None + find_auto_fn(pre_nodes, fp8_quant) + + # In post-nodes, fused kernels should be present and fp8 quant should not + find_auto_fn(post_nodes, silu_and_mul_quant) + assert find_auto_fn_maybe(post_nodes, fp8_quant) is None From 702fa46cd40988d380fb7f97d8aa4a8ebffd5038 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 5 Dec 2024 22:31:30 +0000 Subject: [PATCH 06/31] added kernel unit test Signed-off-by: Sage Moore --- csrc/quantization/activation_kernels.cu | 2 + tests/compile/test_silu_mul_quant_fusion.py | 1 - tests/kernels/test_fused_quant_activation.py | 73 ++++++++++++++++++++ 3 files changed, 75 insertions(+), 1 deletion(-) create mode 100644 tests/kernels/test_fused_quant_activation.py diff --git 
a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 4245a55b2f1de..5ca0df1343ed3 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -112,5 +112,7 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d] torch::Tensor& input, torch::Tensor& scale) // [..., 2 * d] { + TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); } \ No newline at end of file diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 9ac7e51686b5e..6d890487faa5f 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -1,6 +1,5 @@ import pytest import torch -from compressed_tensors.quantization import FP8_DTYPE import vllm.envs as envs from vllm.compilation.fusion import (FusionPass, find_auto_fn, diff --git a/tests/kernels/test_fused_quant_activation.py b/tests/kernels/test_fused_quant_activation.py new file mode 100644 index 0000000000000..d661bcaa7953d --- /dev/null +++ b/tests/kernels/test_fused_quant_activation.py @@ -0,0 +1,73 @@ +from typing import Optional, Tuple, Union + +import pytest +import torch + +import vllm._custom_ops as ops +from vllm.model_executor.layers.activation import SiluAndMul + +DTYPES = [torch.bfloat16, torch.float16] +QUANT_DTYPES = [torch.float8_e4m3fn] +NUM_TOKENS = [32, 64, 128, 2048, 4096] # Arbitrary values for testing +HIDDEN_SIZES = [32, 64, 128, 2048, 4096] # Arbitrary values for testing +SEEDS = [0] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + + +def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor, + scale: torch.Tensor) -> torch.Tensor: + # Norm + silu_and_mul_out = silu_and_mul.forward_native(x) + # Quant + out, scales = ops.scaled_fp8_quant(silu_and_mul_out, scale) + return out + + +def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + out_shape = (x.shape[0], x.shape[1] // 2) + out = torch.empty(out_shape, + dtype=torch.torch.float8_e4m3fn, + device=x.device) + torch.ops._C.silu_and_mul_quant(out, x, scale) + return out + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_silu_and_mul( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + quant_dtype: torch.dtype, + seed: int, + device: str, +) -> None: + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + + layer = SiluAndMul() + + # Make inputs + scale = (torch.randn((1), device=device, dtype=torch.float32)) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + + ref_out = ref_impl(layer, x, scale) + ops_out = ops_impl(x, scale) + + # print(ref_out) + # print("@@@@@@@@@@@@@@@@@@@@@@@@") + # print(ops_out) + assert ref_out.dtype == quant_dtype + assert ops_out.dtype == quant_dtype + assert ref_out.shape == ops_out.shape + assert torch.allclose(ref_out.to(dtype=torch.float32), + ops_out.to(dtype=torch.float32)) From 583ff4c4571b43546d8526bb5d55b0a36809944b Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 6 Dec 2024 14:48:00 +0000 Subject: [PATCH 07/31] misc cleanup 
Signed-off-by: Sage Moore --- tests/compile/test_silu_mul_quant_fusion.py | 4 ++-- tests/kernels/test_fused_quant_activation.py | 7 +------ 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 6d890487faa5f..e98d8a6cd2c12 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -25,11 +25,11 @@ def forward(self, x): return x2 -@pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("num_tokens", [256]) +@pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA") -def test_fusion_silu_and_mul_quant(hidden_size, num_tokens): +def test_fusion_silu_and_mul_quant(num_tokens, hidden_size): torch.set_default_device("cuda") torch.set_default_dtype(torch.float16) diff --git a/tests/kernels/test_fused_quant_activation.py b/tests/kernels/test_fused_quant_activation.py index d661bcaa7953d..c8eb5c3a00d0f 100644 --- a/tests/kernels/test_fused_quant_activation.py +++ b/tests/kernels/test_fused_quant_activation.py @@ -1,5 +1,3 @@ -from typing import Optional, Tuple, Union - import pytest import torch @@ -8,7 +6,7 @@ DTYPES = [torch.bfloat16, torch.float16] QUANT_DTYPES = [torch.float8_e4m3fn] -NUM_TOKENS = [32, 64, 128, 2048, 4096] # Arbitrary values for testing +NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing HIDDEN_SIZES = [32, 64, 128, 2048, 4096] # Arbitrary values for testing SEEDS = [0] CUDA_DEVICES = [ @@ -63,9 +61,6 @@ def test_silu_and_mul( ref_out = ref_impl(layer, x, scale) ops_out = ops_impl(x, scale) - # print(ref_out) - # print("@@@@@@@@@@@@@@@@@@@@@@@@") - # print(ops_out) assert ref_out.dtype == quant_dtype assert ops_out.dtype == quant_dtype assert ref_out.shape == ops_out.shape From e5680f71e1ca7498a8c804d3c0c60f5e776a0952 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 6 Dec 2024 15:03:34 +0000 Subject: [PATCH 08/31] move activation quant fusion to its own pass Signed-off-by: Sage Moore --- vllm/compilation/activation_quant_fusion.py | 106 ++++++++++++++++++++ vllm/compilation/fusion.py | 32 ------ vllm/compilation/pass_manager.py | 2 + 3 files changed, 108 insertions(+), 32 deletions(-) create mode 100644 vllm/compilation/activation_quant_fusion.py diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py new file mode 100644 index 0000000000000..71aa0a306fecd --- /dev/null +++ b/vllm/compilation/activation_quant_fusion.py @@ -0,0 +1,106 @@ +import operator +from typing import Iterable, Optional + +import torch +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only, + register_replacement) + +from vllm.config import CompilationConfig +from vllm.logger import init_logger + +from .vllm_inductor_pass import VllmInductorPass, is_func + +logger = init_logger(__name__) + + +def silu_mul_pattern_static(result: torch.Tensor, result_silu_mul: torch.Tensor, + input: torch.Tensor, scale: torch.Tensor): + at1 = auto_functionalized(torch.ops._C.silu_and_mul.default, + result=result_silu_mul, + input=input) + at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, + result=result, + input=at1[1], + scale=scale) + return at2[1] + + +def silu_mul_replacement_static(result: torch.Tensor, + result_silu_mul: torch.Tensor, + input: torch.Tensor, scale: torch.Tensor): + at = 
auto_functionalized(torch.ops._C.silu_and_mul_quant.default, + result=result, + input=input, + scale=scale) + return at[1] + + +def empty_bf16(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") + + +def empty_fp8(*args, **kwargs): + fp8 = torch.float8_e4m3fn + return torch.empty(*args, **kwargs, dtype=fp8, device="cuda") + + +def empty_fp32(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") + + +class ActivationQuantFusionPass(VllmInductorPass): + """ + This pass fuses a pre-defined set of custom ops into fused ops. + It uses the torch pattern matcher to find the patterns and replace them. + It also manually processes multi-output matches, as those are broken in + the torch pattern matcher. + + Because patterns can only be registered once, the pass is a singleton. + This will be addressed in a future version of PyTorch: + https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 + """ + + _instance: 'Optional[ActivationQuantFusionPass]' = None + + @classmethod + def instance(cls, config: CompilationConfig.PassConfig): + """ + Get the singleton instance of the ActivationQuantFusionPass. + If the instance exists, the config is updated but + initialization is not repeated. + """ + if cls._instance is None: + cls._instance = ActivationQuantFusionPass(config) + else: + cls._instance.config = config + return cls._instance + + def __init__(self, config: CompilationConfig.PassConfig): + assert self.__class__._instance is None, \ + "ActivationQuantFusionPass singleton instance already exists" + super().__init__(config) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="activation_quant_fusion_pass") + + inputs = [ + empty_fp8(5, 4), # Quant output + empty_bf16(5, 4), # Silu_and_mul output + empty_bf16(5, 4), # Input + empty_fp32(1, 1) # Scale + ] + register_replacement(silu_mul_pattern_static, + silu_mul_replacement_static, inputs, fwd_only, + self.patterns) + + def __call__(self, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before_act_quant_fusion") + + count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns in ActivationQuantFusionPass", count) + self.dump_graph(graph, "after_pattern_match") + + self.dump_graph(graph, "after_act_quant_fusion") + self.end_and_log() diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index c0266af0fd3c1..ead3bfef8ba5c 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -14,28 +14,6 @@ logger = init_logger(__name__) -def silu_mul_pattern_static(result: torch.Tensor, result_silu_mul: torch.Tensor, - input: torch.Tensor, scale: torch.Tensor): - at1 = auto_functionalized(torch.ops._C.silu_and_mul.default, - result=result_silu_mul, - input=input) - at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, - result=result, - input=at1[1], - scale=scale) - return at2[1] - - -def silu_mul_replacement_static(result: torch.Tensor, - result_silu_mul: torch.Tensor, - input: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default, - result=result, - input=input, - scale=scale) - return at[1] - - def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): @@ -214,16 +192,6 @@ def __init__(self, config: CompilationConfig.PassConfig): self.patterns, extra_check=lambda m: self.record_match(m)) - inputs = [ - empty_fp8(5, 4), - empty_bf16(5, 4), - 
empty_bf16(5, 4), - empty_fp32(1, 1) - ] - register_replacement(silu_mul_pattern_static, - silu_mul_replacement_static, inputs, fwd_only, - self.patterns) - def record_match(self, match: Match) -> bool: # Hijack the extra_check to record the match and # save it for post-processing. diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 26defd035688b..78702f21764cc 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -7,6 +7,7 @@ from .fix_functionalization import FixFunctionalizationPass from .fusion import FusionPass +from .activation_quant_fusion import ActivationQuantFusionPass from .inductor_pass import InductorPass from .reshapes import RedundantReshapesPass @@ -46,6 +47,7 @@ def configure(self, pass_config: CompilationConfig.PassConfig): if True: self.passes += [FusionPass.instance(pass_config)] + self.passes += [ActivationQuantFusionPass.instance(pass_config)] self.fix_functionalization = FixFunctionalizationPass(pass_config) From 4b775c40712a32114581b4822bf2622c50a1ab06 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 6 Dec 2024 15:05:20 +0000 Subject: [PATCH 09/31] update test Signed-off-by: Sage Moore --- tests/compile/test_silu_mul_quant_fusion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index e98d8a6cd2c12..69b54f75d9644 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -2,8 +2,8 @@ import torch import vllm.envs as envs -from vllm.compilation.fusion import (FusionPass, find_auto_fn, - find_auto_fn_maybe) +from vllm.compilation.fusion import (find_auto_fn, find_auto_fn_maybe) +from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.reshapes import RedundantReshapesPass from vllm.config import CompilationConfig from vllm.model_executor.layers.activation import SiluAndMul @@ -37,7 +37,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size): config = CompilationConfig.PassConfig(enable_fusion=True, enable_reshape=True) reshape_pass = RedundantReshapesPass(config) - fusion_pass = FusionPass.instance(config) + fusion_pass = ActivationQuantFusionPass.instance(config) backend = TestBackend(reshape_pass, fusion_pass) model = TestModel() From d5ff8659a4b10156d990399a8393a9730c3e6b29 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 6 Dec 2024 15:08:39 +0000 Subject: [PATCH 10/31] format Signed-off-by: Sage Moore --- csrc/quantization/activation_kernels.cu | 3 ++- tests/compile/test_silu_mul_quant_fusion.py | 4 ++-- vllm/compilation/activation_quant_fusion.py | 5 ++--- vllm/compilation/pass_manager.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 5ca0df1343ed3..ce5d2c346c535 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -113,6 +113,7 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d] torch::Tensor& scale) // [..., 2 * d] { TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16); + TORCH_CHECK(input.dtype() == torch::kFloat16 || + input.dtype() == torch::kBFloat16); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); } \ No newline at end of file diff --git a/tests/compile/test_silu_mul_quant_fusion.py 
b/tests/compile/test_silu_mul_quant_fusion.py index 69b54f75d9644..7a6fb8725420c 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -2,12 +2,12 @@ import torch import vllm.envs as envs -from vllm.compilation.fusion import (find_auto_fn, find_auto_fn_maybe) +from vllm._custom_ops import scaled_fp8_quant from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass +from vllm.compilation.fusion import find_auto_fn, find_auto_fn_maybe from vllm.compilation.reshapes import RedundantReshapesPass from vllm.config import CompilationConfig from vllm.model_executor.layers.activation import SiluAndMul -from vllm._custom_ops import scaled_fp8_quant from .backend import TestBackend diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index 71aa0a306fecd..16ce2a8576e49 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -1,5 +1,4 @@ -import operator -from typing import Iterable, Optional +from typing import Optional import torch from torch._higher_order_ops.auto_functionalize import auto_functionalized @@ -9,7 +8,7 @@ from vllm.config import CompilationConfig from vllm.logger import init_logger -from .vllm_inductor_pass import VllmInductorPass, is_func +from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 78702f21764cc..eb3bc7b0e3133 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -5,9 +5,9 @@ from vllm.config import CompilationConfig from vllm.logger import init_logger +from .activation_quant_fusion import ActivationQuantFusionPass from .fix_functionalization import FixFunctionalizationPass from .fusion import FusionPass -from .activation_quant_fusion import ActivationQuantFusionPass from .inductor_pass import InductorPass from .reshapes import RedundantReshapesPass From c970dec21df3e682dd9448ff1bb12805276ec82d Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 6 Dec 2024 15:58:24 +0000 Subject: [PATCH 11/31] format Signed-off-by: Sage Moore --- benchmarks/benchmark_latency.py | 26 ++++++++++++++------------ vllm/compilation/fusion.py | 16 +++++++++------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 0a14aedd5feba..23c683c6c37b5 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -5,7 +5,7 @@ import time from pathlib import Path from typing import List, Optional - +from vllm.config import CompilationConfig import numpy as np import torch from tqdm import tqdm @@ -21,6 +21,11 @@ def main(args: argparse.Namespace): engine_args = EngineArgs.from_cli_args(args) + config = CompilationConfig( + level=3, + custom_ops=["+silu_and_mul"], + ) + engine_args.compilation_config = config # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. llm = LLM(**dataclasses.asdict(engine_args)) @@ -70,8 +75,7 @@ def run_to_completion(profile_dir: Optional[str] = None): profile_dir = args.profile_result_dir if not profile_dir: profile_dir = Path( - "." 
- ) / "vllm_benchmark_result" / f"latency_result_{time.time()}" + ".") / "vllm_benchmark_result" / f"latency_result_{time.time()}" print(f"Profiling (results will be saved to '{profile_dir}')...") run_to_completion(profile_dir=profile_dir) return @@ -118,21 +122,19 @@ def run_to_completion(profile_dir: Optional[str] = None): type=int, default=30, help='Number of iterations to run.') - parser.add_argument( - '--profile', - action='store_true', - help='profile the generation process of a single batch') + parser.add_argument('--profile', + action='store_true', + help='profile the generation process of a single batch') parser.add_argument( '--profile-result-dir', type=str, default=None, help=('path to save the pytorch profiler output. Can be visualized ' 'with ui.perfetto.dev or Tensorboard.')) - parser.add_argument( - '--output-json', - type=str, - default=None, - help='Path to save the latency results in JSON format.') + parser.add_argument('--output-json', + type=str, + default=None, + help='Path to save the latency results in JSON format.') parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index ead3bfef8ba5c..e4ed2d4edb8f4 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -171,8 +171,8 @@ def __init__(self, config: CompilationConfig.PassConfig): empty_bf16(1, 5), empty_fp32(1, 1) ] - register_replacement(rms_pattern_static, rms_replacement_static, inputs, - fwd_only, self.patterns) + register_replacement(rms_pattern_static, rms_replacement_static, + inputs, fwd_only, self.patterns) # Fuse fused_add_rms_norm + static_scaled_fp8_quant into # fused_add_rms_norm_static_fp8_quant @@ -229,15 +229,17 @@ def process_matches(self, graph: torch.fx.Graph): kwargs = match.kwargs kwargs["epsilon"] = 1e-5 # Currently hard-coded in RMSNorm - fused_node = graph.call_function(auto_functionalized, ( - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ), - kwargs=kwargs) + fused_node = graph.call_function( + auto_functionalized, ( + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + ), + kwargs=kwargs) graph.inserting_after(fused_node) result_node_new = graph.call_function(operator.getitem, (fused_node, 1)) - residual_node_new = graph.call_function(operator.getitem, - (fused_node, 2)) + residual_node_new = graph.call_function( + operator.getitem, (fused_node, 2)) # Last part of replacement is rebinding the users of nodes in the # match to use the new nodes. From 596c4453723842c35d072eddacc7009e12885b43 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 6 Dec 2024 15:59:39 +0000 Subject: [PATCH 12/31] format Signed-off-by: Sage Moore --- benchmarks/benchmark_latency.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 23c683c6c37b5..50dea6c17c45b 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -5,7 +5,7 @@ import time from pathlib import Path from typing import List, Optional -from vllm.config import CompilationConfig + import numpy as np import torch from tqdm import tqdm @@ -21,11 +21,6 @@ def main(args: argparse.Namespace): engine_args = EngineArgs.from_cli_args(args) - config = CompilationConfig( - level=3, - custom_ops=["+silu_and_mul"], - ) - engine_args.compilation_config = config # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. 
llm = LLM(**dataclasses.asdict(engine_args)) From 7ab3e1855a7bbf5a0c071bb33aa119aecf3ef98d Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 6 Dec 2024 16:03:15 +0000 Subject: [PATCH 13/31] format Signed-off-by: Sage Moore --- vllm/compilation/fusion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index e4ed2d4edb8f4..dd1a1736605ec 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -230,9 +230,9 @@ def process_matches(self, graph: torch.fx.Graph): kwargs["epsilon"] = 1e-5 # Currently hard-coded in RMSNorm fused_node = graph.call_function( - auto_functionalized, ( - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - ), + auto_functionalized, + (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + ), kwargs=kwargs) graph.inserting_after(fused_node) From d347431834e6b9923a3d89696226f4930f2b8966 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 6 Dec 2024 16:03:41 +0000 Subject: [PATCH 14/31] format Signed-off-by: Sage Moore --- vllm/compilation/fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index dd1a1736605ec..5efa410fab6a0 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -231,7 +231,7 @@ def process_matches(self, graph: torch.fx.Graph): fused_node = graph.call_function( auto_functionalized, - (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ), kwargs=kwargs) From 553d99ccf15b6399229f773ea1738b868148d80c Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 6 Dec 2024 16:05:21 +0000 Subject: [PATCH 15/31] format Signed-off-by: Sage Moore --- vllm/compilation/backends.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index c14f033db3df3..464bc2af8fd6d 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -253,7 +253,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: from torch._dynamo.utils import lazy_format_graph_code logger.debug("%s", lazy_format_graph_code("before split", self.graph)) - logger.debug("%s", lazy_format_graph_code("after split", self.split_gm)) + logger.debug("%s", lazy_format_graph_code("after split", + self.split_gm)) compilation_counter.num_piecewise_graphs_seen += len( self.piecewise_graphs) @@ -479,7 +480,8 @@ def __call__(self, *args) -> Any: ] assert new_input_addresses == entry.input_addresses, ( "Input addresses for cudagraphs are different during replay." 
- f" Expected {entry.input_addresses}, got {new_input_addresses}") + f" Expected {entry.input_addresses}, got {new_input_addresses}" + ) entry.cudagraph.replay() return entry.output From 774559dd563fbfea2c6f0d53a2306f0aeedd34ed Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 6 Dec 2024 20:12:08 +0000 Subject: [PATCH 16/31] format Signed-off-by: Sage Moore --- benchmarks/benchmark_latency.py | 19 +++++++++++-------- vllm/compilation/activation_quant_fusion.py | 3 ++- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 50dea6c17c45b..0a14aedd5feba 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -70,7 +70,8 @@ def run_to_completion(profile_dir: Optional[str] = None): profile_dir = args.profile_result_dir if not profile_dir: profile_dir = Path( - ".") / "vllm_benchmark_result" / f"latency_result_{time.time()}" + "." + ) / "vllm_benchmark_result" / f"latency_result_{time.time()}" print(f"Profiling (results will be saved to '{profile_dir}')...") run_to_completion(profile_dir=profile_dir) return @@ -117,19 +118,21 @@ def run_to_completion(profile_dir: Optional[str] = None): type=int, default=30, help='Number of iterations to run.') - parser.add_argument('--profile', - action='store_true', - help='profile the generation process of a single batch') + parser.add_argument( + '--profile', + action='store_true', + help='profile the generation process of a single batch') parser.add_argument( '--profile-result-dir', type=str, default=None, help=('path to save the pytorch profiler output. Can be visualized ' 'with ui.perfetto.dev or Tensorboard.')) - parser.add_argument('--output-json', - type=str, - default=None, - help='Path to save the latency results in JSON format.') + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the latency results in JSON format.') parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index 16ce2a8576e49..146f77546a7c3 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -98,7 +98,8 @@ def __call__(self, graph: torch.fx.Graph): self.dump_graph(graph, "before_act_quant_fusion") count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns in ActivationQuantFusionPass", count) + logger.debug("Replaced %s patterns in ActivationQuantFusionPass", + count) self.dump_graph(graph, "after_pattern_match") self.dump_graph(graph, "after_act_quant_fusion") From e2fda7f92bbba73b216829f667c4e5f28c0c5ab8 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 6 Dec 2024 20:15:57 +0000 Subject: [PATCH 17/31] format Signed-off-by: Sage Moore --- vllm/compilation/activation_quant_fusion.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index 146f77546a7c3..c073d0fb7f767 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -13,8 +13,9 @@ logger = init_logger(__name__) -def silu_mul_pattern_static(result: torch.Tensor, result_silu_mul: torch.Tensor, - input: torch.Tensor, scale: torch.Tensor): +def silu_mul_pattern_static(result: torch.Tensor, + result_silu_mul: torch.Tensor, input: torch.Tensor, + scale: torch.Tensor): at1 = auto_functionalized(torch.ops._C.silu_and_mul.default, 
result=result_silu_mul, input=input) From 6915fa2f5e4001fbda44a0ef296777e8df7dc63c Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Mon, 9 Dec 2024 18:12:41 +0000 Subject: [PATCH 18/31] minor comment fix Signed-off-by: Sage Moore --- vllm/compilation/activation_quant_fusion.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index c073d0fb7f767..8372de4118c7e 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -53,8 +53,6 @@ class ActivationQuantFusionPass(VllmInductorPass): """ This pass fuses a pre-defined set of custom ops into fused ops. It uses the torch pattern matcher to find the patterns and replace them. - It also manually processes multi-output matches, as those are broken in - the torch pattern matcher. Because patterns can only be registered once, the pass is a singleton. This will be addressed in a future version of PyTorch: From 6d4b8d0f6f5be63ffe22884ea45593d5c54b9a5a Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Mon, 9 Dec 2024 18:27:28 +0000 Subject: [PATCH 19/31] minor updates Signed-off-by: Sage Moore --- tests/kernels/test_fused_quant_activation.py | 2 -- vllm/compilation/pass_manager.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/kernels/test_fused_quant_activation.py b/tests/kernels/test_fused_quant_activation.py index c8eb5c3a00d0f..0ee67e888613e 100644 --- a/tests/kernels/test_fused_quant_activation.py +++ b/tests/kernels/test_fused_quant_activation.py @@ -16,9 +16,7 @@ def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: - # Norm silu_and_mul_out = silu_and_mul.forward_native(x) - # Quant out, scales = ops.scaled_fp8_quant(silu_and_mul_out, scale) return out diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index eb3bc7b0e3133..fe48a58c19bae 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -45,7 +45,7 @@ def configure(self, pass_config: CompilationConfig.PassConfig): if pass_config.enable_reshape: self.passes += [RedundantReshapesPass(pass_config)] - if True: + if pass_config.enable_fusion: self.passes += [FusionPass.instance(pass_config)] self.passes += [ActivationQuantFusionPass.instance(pass_config)] From 6b631b05f89131413513c1c2fa808e4e052bf4c2 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 12 Dec 2024 22:30:17 +0000 Subject: [PATCH 20/31] fix fix-functionalization Signed-off-by: Sage Moore --- vllm/compilation/fix_functionalization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 3584cc3608caf..0194513e3a2d1 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -69,12 +69,12 @@ def __call__(self, graph: torch.fx.Graph): self.defunctionalize(graph, node, mutated_args) elif at_target == torch.ops._C.silu_and_mul.default: - mutated_args = {1: 'out'} + mutated_args = {1: 'result'} # Because we have an 'out', need to specify args directly self.defunctionalize(graph, node, mutated_args, - args=('out', 'input')) + args=('result', 'input')) else: continue # skip the count From 5b78d809aab912bea35cb917ff574c316635e48a Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 13 Dec 2024 14:57:26 +0000 Subject: [PATCH 21/31] add opcheck test for fused op Signed-off-by: Sage Moore --- 
tests/kernels/test_fused_quant_activation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/kernels/test_fused_quant_activation.py b/tests/kernels/test_fused_quant_activation.py index 0ee67e888613e..7494f971a14f3 100644 --- a/tests/kernels/test_fused_quant_activation.py +++ b/tests/kernels/test_fused_quant_activation.py @@ -1,5 +1,6 @@ import pytest import torch +from tests.kernels.utils import opcheck import vllm._custom_ops as ops from vllm.model_executor.layers.activation import SiluAndMul @@ -64,3 +65,4 @@ def test_silu_and_mul( assert ref_out.shape == ops_out.shape assert torch.allclose(ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)) + opcheck(torch.ops._C.silu_and_mul_quant, (ops_out, x, scale)) From 391eea54248c13bf5bf981bd0bdc4da9734f4731 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 13 Dec 2024 17:28:22 +0000 Subject: [PATCH 22/31] fix fix_functionalization tests Signed-off-by: Sage Moore --- csrc/quantization/activation_kernels.cu | 2 +- tests/compile/test_functionalization.py | 30 ++++++++++++++++++++--- vllm/compilation/fix_functionalization.py | 12 +++------ 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index ce5d2c346c535..2fd3b3c7f8530 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -32,7 +32,7 @@ __global__ void act_and_mul_quant_kernel( FP8_TYPE* __restrict__ out, // [..., d] const scalar_t* __restrict__ input, // [..., 2, d] const float* scale, const int d) { - const int32_t token_idx = blockIdx.x; + const int64_t token_idx = blockIdx.x; const int32_t blocks_per_token = gridDim.y; const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t); diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 5036189077be2..301a1bb053051 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -6,6 +6,7 @@ from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.fusion import (FusionPass, find_auto_fn, find_auto_fn_maybe) +from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.reshapes import RedundantReshapesPass from vllm.compilation.vllm_inductor_pass import is_func from vllm.config import CompilationConfig @@ -15,7 +16,6 @@ OPS_IN_MODEL = [ torch.ops._C.rotary_embedding.default, torch.ops._C.fused_add_rms_norm.default, - torch.ops._C.silu_and_mul.default, ] RMS_OP = torch.ops._C.rms_norm.default @@ -27,6 +27,13 @@ ], } +SILU_MUL_OP = torch.ops._C.silu_and_mul.default + +SILU_MUL_QUANT_OPS = { + "static_fp8": [ + torch.ops._C.silu_and_mul_quant.default, + ], +} prompts = [ "Hello, my name is", "The president of the United States is", @@ -45,10 +52,17 @@ def test_fix_functionalization(model: str, do_fusion: bool): config = CompilationConfig.PassConfig(enable_fusion=do_fusion, enable_reshape=True) + # compilation_config = CompilationConfig(level=3, + # custom_ops=["+silu_and_mul"]) reshape_pass = RedundantReshapesPass(config) fusion_pass = FusionPass.instance(config) + act_quant_fusion_pass = ActivationQuantFusionPass.instance(config) - passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass] + passes = [ + reshape_pass, + fusion_pass, + act_quant_fusion_pass, + ] if do_fusion else [reshape_pass] func_pass = FixFunctionalizationPass(config) backend_func = TestBackend(*passes, func_pass) backend_no_func = 
TestBackend(*passes) @@ -71,15 +85,23 @@ def test_fix_functionalization(model: str, do_fusion: bool): model_runner.model = torch.compile(orig_model, fullgraph=True, backend=backend_no_func) + gen_no_func = llm.generate(prompts, sampling_params) + # print( + # backend_func.graph_pre_pass.python_code(root_module="self", + # verbose=True).src) for output_func, output_no_func in zip(gen_func, gen_no_func): assert output_func.outputs[0].text == output_no_func.outputs[0].text # OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion, # and replaced by fused quantized ops in RMS_QUANT_OPS. - ops = OPS_IN_MODEL + (RMS_QUANT_OPS["static_fp8"] - if do_fusion else [RMS_OP]) + rms_ops = RMS_QUANT_OPS["static_fp8"] if do_fusion else [RMS_OP] + silu_mul_ops = SILU_MUL_QUANT_OPS["static_fp8"] if do_fusion else [ + SILU_MUL_OP + ] + + ops = OPS_IN_MODEL + rms_ops + silu_mul_ops for op in ops: find_auto_fn(backend_no_func.graph_post_pass.nodes, op) diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 0194513e3a2d1..fdf5461b15827 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -63,18 +63,12 @@ def __call__(self, graph: torch.fx.Graph): elif at_target in [ torch.ops._C.rms_norm.default, - torch.ops._C.rms_norm_static_fp8_quant.default + torch.ops._C.rms_norm_static_fp8_quant.default, + torch.ops._C.silu_and_mul.default, + torch.ops._C.silu_and_mul_quant.default, ]: mutated_args = {1: 'result'} self.defunctionalize(graph, node, mutated_args) - - elif at_target == torch.ops._C.silu_and_mul.default: - mutated_args = {1: 'result'} - # Because we have an 'out', need to specify args directly - self.defunctionalize(graph, - node, - mutated_args, - args=('result', 'input')) else: continue # skip the count From 0d79c17a11a34352fab269010bf765aef3b2f122 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 13 Dec 2024 18:56:00 +0000 Subject: [PATCH 23/31] fix fix_functionalization again Signed-off-by: Sage Moore --- vllm/compilation/fix_functionalization.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index fdf5461b15827..420481cec86d4 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -64,11 +64,24 @@ def __call__(self, graph: torch.fx.Graph): elif at_target in [ torch.ops._C.rms_norm.default, torch.ops._C.rms_norm_static_fp8_quant.default, - torch.ops._C.silu_and_mul.default, - torch.ops._C.silu_and_mul_quant.default, ]: mutated_args = {1: 'result'} self.defunctionalize(graph, node, mutated_args) + # For some reason we need to specify the args for both + # silu_and_mul and silu_and_mul_quant. The kwargs + # pathway gets the wrong answer. 
+ elif at_target == torch.ops._C.silu_and_mul.default: + mutated_args = {1: 'result'} + self.defunctionalize(graph, + node, + mutated_args, + args=('result', 'input')) + elif at_target == torch.ops._C.silu_and_mul_quant.default: + mutated_args = {1: 'result'} + self.defunctionalize(graph, + node, + mutated_args, + args=('result', 'input', 'scale')) else: continue # skip the count From 3198f640d16393fcdfbc1ca9bfd530eaac5ba284 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 13 Dec 2024 22:37:58 +0000 Subject: [PATCH 24/31] format Signed-off-by: Sage Moore --- tests/compile/test_functionalization.py | 2 +- tests/kernels/test_fused_quant_activation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 6cbabf99bf9ae..7424f7c9c3639 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -3,11 +3,11 @@ import vllm.envs as envs from vllm import LLM, SamplingParams +from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym) from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func -from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.reshapes import RedundantReshapesPass from vllm.config import CompilationConfig diff --git a/tests/kernels/test_fused_quant_activation.py b/tests/kernels/test_fused_quant_activation.py index 7494f971a14f3..d0fee1c1aa014 100644 --- a/tests/kernels/test_fused_quant_activation.py +++ b/tests/kernels/test_fused_quant_activation.py @@ -1,8 +1,8 @@ import pytest import torch -from tests.kernels.utils import opcheck import vllm._custom_ops as ops +from tests.kernels.utils import opcheck from vllm.model_executor.layers.activation import SiluAndMul DTYPES = [torch.bfloat16, torch.float16] From 58111a958d4262d12ac30926cba9a072fed1fe8c Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Sat, 14 Dec 2024 19:36:15 +0000 Subject: [PATCH 25/31] fixup includes Signed-off-by: Sage Moore --- csrc/quantization/activation_kernels.cu | 5 +---- vllm/compilation/activation_quant_fusion.py | 1 - 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 2fd3b3c7f8530..91aa966ce7739 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -3,14 +3,11 @@ #include #include +#include "fp8/common.cuh" #include "../core/math.hpp" #include "../cuda_compat.h" #include "../dispatch_utils.h" -using FP8_TYPE = c10::Float8_e4m3fn; -C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = - std::numeric_limits::max(); -// using FP8_TYPE = c10::Float8_e4m3fnuz; namespace vllm { template diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index 8372de4118c7e..9aa50749fc544 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -99,7 +99,6 @@ def __call__(self, graph: torch.fx.Graph): count = self.patterns.apply(graph) logger.debug("Replaced %s patterns in ActivationQuantFusionPass", count) - self.dump_graph(graph, "after_pattern_match") self.dump_graph(graph, "after_act_quant_fusion") self.end_and_log() From 9a18085cedb4cdf97683489467d7fc82cac054d1 Mon Sep 17 
00:00:00 2001 From: Sage Moore Date: Mon, 16 Dec 2024 15:38:46 +0000 Subject: [PATCH 26/31] refactor math.hpp Signed-off-by: Sage Moore --- csrc/core/math.hpp | 20 ++++++++++++-------- csrc/quantization/activation_kernels.cu | 4 ++-- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp index bd5241c5703fc..6e79c94e52518 100644 --- a/csrc/core/math.hpp +++ b/csrc/core/math.hpp @@ -5,14 +5,18 @@ static inline constexpr auto div_ceil(A a, B b) { return (a + b - 1) / b; } -// Compute the next multiple of a that is greater than or equal to b -template -static inline constexpr auto next_multiple_of(A a, B b) { - return div_ceil(b, a) * a; +// Round a down to the next multiple of b. The caller is responsible for making +// sure that b is non-zero +template +inline constexpr T round_to_previous_multiple_of(T a, T b) +{ + return a % b == 0 ? a : (a / b) * b; } -// Compute the largest multiple of a that is less than or equal to b -template -static inline constexpr auto prev_multiple_of(A a, B b) { - return (b / a) * a; +// Round a up to the next multiple of b. The caller is responsible for making +// sure that b is non-zero +template +inline constexpr T round_to_next_multiple_of(T a, T b) +{ + return a % b == 0 ? a : ((a / b) + 1) * b; } \ No newline at end of file diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 91aa966ce7739..024331fa4e64e 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -36,7 +36,7 @@ __global__ void act_and_mul_quant_kernel( const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token); const int32_t elems_per_block = - next_multiple_of(elems_per_128bit_load, tgt_elems_per_block); + round_to_next_multiple_of(tgt_elems_per_block, elems_per_128bit_load); const int32_t block_start = blockIdx.y * elems_per_block; int32_t block_end = block_start + elems_per_block; block_end = block_end > d ? 
d : block_end; @@ -47,7 +47,7 @@ __global__ void act_and_mul_quant_kernel( // 128-bit vectorized code const int32_t vec_loop_end = - prev_multiple_of(elems_per_128bit_load, block_end); + round_to_previous_multiple_of(elems_per_128bit_load, block_end); const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load; const int32_t vec_start_idx = block_start / elems_per_128bit_load; From e051b24b1c54d2f82e507bcd50dfad384aade86c Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Wed, 18 Dec 2024 15:11:12 +0000 Subject: [PATCH 27/31] fix amd build Signed-off-by: Sage Moore --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3bef0ed0e0d5e..6a5604594d761 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -198,6 +198,7 @@ set(VLLM_EXT_SRC "csrc/quantization/fp8/common.cu" "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" "csrc/quantization/gguf/gguf_kernel.cu" + "csrc/quantization/activation_kernels.cu" "csrc/cuda_utils_kernels.cu" "csrc/prepare_inputs/advance_step.cu" "csrc/torch_bindings.cpp") @@ -241,8 +242,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/awq/gemm_kernels.cu" "csrc/custom_all_reduce.cu" "csrc/permute_cols.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" - "csrc/quantization/activation_kernels.cu") + "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" From 8514b0e5fbfcdf335ddab873605812f406bed1d9 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 19 Dec 2024 17:29:16 +0000 Subject: [PATCH 28/31] review comments and format Signed-off-by: Sage Moore --- csrc/quantization/activation_kernels.cu | 14 +++++++++----- tests/compile/test_functionalization.py | 5 ----- tests/kernels/test_fused_quant_activation.py | 2 +- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 024331fa4e64e..72210468bd557 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -29,11 +29,12 @@ __global__ void act_and_mul_quant_kernel( FP8_TYPE* __restrict__ out, // [..., d] const scalar_t* __restrict__ input, // [..., 2, d] const float* scale, const int d) { - const int64_t token_idx = blockIdx.x; const int32_t blocks_per_token = gridDim.y; const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t); + // We don't expect the hidden dimension to exceed 32 bits so int32 should + // be safe here. const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token); const int32_t elems_per_block = round_to_next_multiple_of(tgt_elems_per_block, elems_per_128bit_load); @@ -41,6 +42,9 @@ __global__ void act_and_mul_quant_kernel( int32_t block_end = block_start + elems_per_block; block_end = block_end > d ? 
d : block_end; + // token_idx is 64 bit to prevent 32 bit overflow when the number of tokens + // is very large + const int64_t token_idx = blockIdx.x; const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d; const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d; FP8_TYPE* __restrict__ out_ptr = out + token_idx * d; @@ -105,12 +109,12 @@ __global__ void act_and_mul_quant_kernel( scale.data_ptr(), d); \ }); -void silu_and_mul_quant(torch::Tensor& out, // [..., d] - torch::Tensor& input, - torch::Tensor& scale) // [..., 2 * d] -{ +void silu_and_mul_quant(torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., 2 * d] + torch::Tensor& scale) { TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16); + TORCH_CHECK(input.size(-1) % 2 == 0); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); } \ No newline at end of file diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 7424f7c9c3639..992108f7873e7 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -49,8 +49,6 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, config = CompilationConfig.PassConfig(enable_fusion=do_fusion, enable_reshape=True) - # compilation_config = CompilationConfig(level=3, - # custom_ops=["+silu_and_mul"]) reshape_pass = RedundantReshapesPass(config) fusion_pass = FusionPass.instance(config) act_quant_fusion_pass = ActivationQuantFusionPass.instance(config) @@ -85,9 +83,6 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, gen_no_func = llm.generate(prompts, sampling_params) - # print( - # backend_func.graph_pre_pass.python_code(root_module="self", - # verbose=True).src) for output_func, output_no_func in zip(gen_func, gen_no_func): assert output_func.outputs[0].text == output_no_func.outputs[0].text diff --git a/tests/kernels/test_fused_quant_activation.py b/tests/kernels/test_fused_quant_activation.py index d0fee1c1aa014..fdb86c13adbb7 100644 --- a/tests/kernels/test_fused_quant_activation.py +++ b/tests/kernels/test_fused_quant_activation.py @@ -8,7 +8,7 @@ DTYPES = [torch.bfloat16, torch.float16] QUANT_DTYPES = [torch.float8_e4m3fn] NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing -HIDDEN_SIZES = [32, 64, 128, 2048, 4096] # Arbitrary values for testing +HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing SEEDS = [0] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) From ec1290af35b2f34fedebbe68cae088843be4aa38 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 19 Dec 2024 21:01:10 +0000 Subject: [PATCH 29/31] fix amd build Signed-off-by: Sage Moore --- csrc/quantization/activation_kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 72210468bd557..8f28580961666 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -20,7 +20,7 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val, float const inverted_scale) { float x = val * inverted_scale; float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); - return static_cast(r); + return static_cast(r); } // Activation and gating kernel template. 
@@ -117,4 +117,4 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d] input.dtype() == torch::kBFloat16); TORCH_CHECK(input.size(-1) % 2 == 0); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); -} \ No newline at end of file +} From 008b7250f5607e64a5088f35671914088159d75a Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 20 Dec 2024 19:22:13 +0000 Subject: [PATCH 30/31] review comments and format Signed-off-by: Sage Moore --- csrc/quantization/activation_kernels.cu | 13 +++---------- tests/compile/test_functionalization.py | 14 +++++--------- 2 files changed, 8 insertions(+), 19 deletions(-) diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 8f28580961666..9edebd9c3caae 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -16,13 +16,6 @@ __device__ __forceinline__ T silu_kernel(const T& x) { return (T)(((float)x) / (1.0f + expf((float)-x))); } -__device__ __forceinline__ FP8_TYPE -scaled_fp8_conversion(float const val, float const inverted_scale) { - float x = val * inverted_scale; - float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); - return static_cast(r); -} - // Activation and gating kernel template. template __global__ void act_and_mul_quant_kernel( @@ -74,8 +67,8 @@ __global__ void act_and_mul_quant_kernel( #pragma unroll for (int i = 0; i < elems_per_128bit_load; i++) { - out_vec[i] = - scaled_fp8_conversion(ACT_FN(x_vec[i]) * y_vec[i], inverted_scale); + out_vec[i] = scaled_fp8_conversion(ACT_FN(x_vec[i]) * y_vec[i], + inverted_scale); } out_128bit_ptr[vec_idx] = reinterpret_cast(out_vec); @@ -87,7 +80,7 @@ __global__ void act_and_mul_quant_kernel( idx += blockDim.x) { const scalar_t x = VLLM_LDG(&x_ptr[idx]); const scalar_t y = VLLM_LDG(&y_ptr[idx]); - out_ptr[idx] = scaled_fp8_conversion(ACT_FN(x) * y, inverted_scale); + out_ptr[idx] = scaled_fp8_conversion(ACT_FN(x) * y, inverted_scale); } } } diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 992108f7873e7..871e388ec2212 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -22,11 +22,7 @@ SILU_MUL_OP = torch.ops._C.silu_and_mul.default -SILU_MUL_QUANT_OPS = { - "static_fp8": [ - torch.ops._C.silu_and_mul_quant.default, - ], -} +SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default prompts = [ "Hello, my name is", "The president of the United States is", @@ -90,10 +86,10 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, # and replaced by fused quantized ops in RMS_QUANT_OPS. 
rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)] ] if do_fusion else [RMS_OP] - silu_mul_ops = SILU_MUL_QUANT_OPS[ - "static_fp8"] if do_fusion and quant_key == kFp8StaticTensorSym else [ - SILU_MUL_OP - ] + silu_mul_ops = SILU_MUL_QUANT_OP if do_fusion and \ + quant_key == kFp8StaticTensorSym else [ + SILU_MUL_OP + ] ops = OPS_IN_MODEL + rms_ops + silu_mul_ops From 4a0ac7e73a486b02812508715a579e486e8ddd54 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 20 Dec 2024 19:28:29 +0000 Subject: [PATCH 31/31] minor test fix Signed-off-by: Sage Moore --- tests/compile/test_functionalization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 871e388ec2212..1c8fe53baa0fb 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -86,7 +86,7 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, # and replaced by fused quantized ops in RMS_QUANT_OPS. rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)] ] if do_fusion else [RMS_OP] - silu_mul_ops = SILU_MUL_QUANT_OP if do_fusion and \ + silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \ quant_key == kFp8StaticTensorSym else [ SILU_MUL_OP ]
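
For reference, the following is a minimal host-side C++ sketch of what the fused silu_and_mul_quant kernel in this series computes per token: SiLU on the first half of the input row, multiplied element-wise by the second half, scaled by the reciprocal of the static scale, and clamped to the e4m3 range. Plain float stands in for the FP8 storage type, and the values in main() are made up for illustration; the real kernel casts the clamped result to c10::Float8_e4m3fn and uses 128-bit vectorized loads. This is roughly the reference that the allclose/opcheck test in tests/kernels/test_fused_quant_activation.py compares the custom op against.

#include <cmath>
#include <cstdio>

// SiLU as in the kernel: x * sigmoid(x).
static float silu(float x) { return x / (1.0f + std::exp(-x)); }

// Scale-and-clamp step with a reciprocal ("inverted") scale; 448.0f is the
// largest finite value of the e4m3fn format. The CUDA kernel casts the result
// to c10::Float8_e4m3fn, which plain float stands in for here.
static float scaled_fp8_ref(float val, float inverted_scale) {
  const float kFp8E4M3Max = 448.0f;
  float x = val * inverted_scale;
  return std::fmax(-kFp8E4M3Max, std::fmin(x, kFp8E4M3Max));
}

// Per-token reference: input is [gate | up] of width 2 * d, output has width d.
static void silu_and_mul_quant_ref(float* out, const float* input, float scale,
                                   int d) {
  const float inverted_scale = 1.0f / scale;
  for (int i = 0; i < d; ++i) {
    out[i] = scaled_fp8_ref(silu(input[i]) * input[d + i], inverted_scale);
  }
}

int main() {
  const float input[4] = {1.0f, -2.0f, 0.5f, 3.0f};  // d = 2: gate | up
  float out[2];
  silu_and_mul_quant_ref(out, input, /*scale=*/0.05f, /*d=*/2);
  std::printf("%.4f %.4f\n", out[0], out[1]);  // ~7.31, ~-14.30
  return 0;
}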
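
The math.hpp refactor in PATCH 26 replaces next_multiple_of/prev_multiple_of with round_to_next_multiple_of/round_to_previous_multiple_of, which round their first argument to a multiple of their second. Below is a self-contained sketch of those helpers, copied from the diff, with the partition numbers (d = 4096 split across 3 blocks per token, 8 half-precision elements per 128-bit load) chosen only for illustration of how the kernel sizes its per-block chunk.

#include <cstdio>

template <typename A, typename B>
static inline constexpr auto div_ceil(A a, B b) {
  return (a + b - 1) / b;
}

// Round a down to a multiple of b (b must be non-zero), as in csrc/core/math.hpp.
template <typename T>
inline constexpr T round_to_previous_multiple_of(T a, T b) {
  return a % b == 0 ? a : (a / b) * b;
}

// Round a up to a multiple of b (b must be non-zero), as in csrc/core/math.hpp.
template <typename T>
inline constexpr T round_to_next_multiple_of(T a, T b) {
  return a % b == 0 ? a : ((a / b) + 1) * b;
}

int main() {
  // Target chunk per block when d = 4096 is split across 3 blocks per token.
  constexpr int tgt_elems_per_block = div_ceil(4096, 3);  // 1366
  // Widen the chunk so it is a whole number of 8-element (128-bit) loads.
  constexpr int elems_per_block =
      round_to_next_multiple_of(tgt_elems_per_block, 8);  // 1368
  static_assert(tgt_elems_per_block == 1366, "div_ceil");
  static_assert(elems_per_block == 1368, "round up to a multiple of 8");
  static_assert(round_to_previous_multiple_of(1366, 8) == 1360, "round down");
  std::printf("%d %d\n", tgt_elems_per_block, elems_per_block);
  return 0;
}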
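
PATCH 28 widens token_idx to int64_t so that the element offset token_idx * 2 * d cannot overflow 32-bit arithmetic. A toy calculation showing when that matters; the 200k-token, 8192-hidden-size shape is hypothetical, not taken from the patches.

#include <cstdint>
#include <cstdio>

int main() {
  // Offset into the input buffer used by the kernel: token_idx * 2 * d.
  const int64_t num_tokens = 200000;  // hypothetical number of tokens
  const int64_t d = 8192;             // hypothetical hidden size
  const int64_t max_offset = (num_tokens - 1) * 2 * d;
  // 3,276,783,616 elements, which exceeds INT32_MAX (2,147,483,647), so
  // computing this offset in 32-bit signed arithmetic would overflow.
  std::printf("max offset = %lld, INT32_MAX = %d\n",
              static_cast<long long>(max_offset), INT32_MAX);
  return 0;
}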