From aa13b896b386bae82ca92e08fd2d83444cc7c72e Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 26 Apr 2024 21:53:11 +0000 Subject: [PATCH 01/18] Hopper Int and FP8 kernel for per tensor or per row/col --- .gitmodules | 3 + CMakeLists.txt | 5 +- csrc/ops.h | 7 + csrc/pybind.cpp | 1 + csrc/quantization/cutlass/scaled_mm_dq.cu | 285 ++++++++++++++++++++++ csrc/third_party/cutlass | 1 + cutlass_fp8_fused_dq_and_scales.py | 129 ++++++++++ cutlass_int8_fused_dq_and_scales.py | 124 ++++++++++ tests/kernels/test_cutlass.py | 112 +++++++++ vllm/_custom_ops.py | 9 + 10 files changed, 675 insertions(+), 1 deletion(-) create mode 100644 .gitmodules create mode 100644 csrc/quantization/cutlass/scaled_mm_dq.cu create mode 160000 csrc/third_party/cutlass create mode 100644 cutlass_fp8_fused_dq_and_scales.py create mode 100644 cutlass_int8_fused_dq_and_scales.py create mode 100644 tests/kernels/test_cutlass.py diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000..a8108241542e9 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "csrc/third_party/cutlass"] + path = csrc/third_party/cutlass + url = https://github.com/nvidia/cutlass diff --git a/CMakeLists.txt b/CMakeLists.txt index 1c7dfe0c048b0..a8e52cf7b91be 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -165,6 +165,7 @@ set(VLLM_EXT_SRC "csrc/pos_encoding_kernels.cu" "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" + "csrc/quantization/cutlass/scaled_mm_dq.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/fp8/common.cu" @@ -188,7 +189,9 @@ define_gpu_extension_target( LANGUAGE ${VLLM_GPU_LANG} SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} - ARCHITECTURES ${VLLM_GPU_ARCHES} + # ARCHITECTURES ${VLLM_GPU_ARCHES} + ARCHITECTURES 90a + INCLUDE_DIRECTORIES csrc/third_party/cutlass/include;csrc/third_party/cutlass/tools/util/include WITH_SOABI) # diff --git a/csrc/ops.h b/csrc/ops.h index 9541adcb3de88..678c80d7d0511 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -144,6 +144,13 @@ torch::Tensor gptq_marlin_repack( int64_t size_k, int64_t size_n, int64_t num_bits); + +int cutlass_scaled_mm_dq( + torch::Tensor& out, + torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); #endif void squeezellm_gemm( diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 173e0b1732e13..ca18ca7eb5eea 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -71,6 +71,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); #endif + ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, "CUTLASS quantized w8a8 GEMM, supporting symmetric quantized per-channel or per-tensor weights and symmetric quantized per-token or per-tensor activations. Inputs are either int8, int8 or fp8_e4m3fn, fp8_e4m3fn. 
Output must be bfloat16"); ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); diff --git a/csrc/quantization/cutlass/scaled_mm_dq.cu b/csrc/quantization/cutlass/scaled_mm_dq.cu new file mode 100644 index 0000000000000..ff45417990c46 --- /dev/null +++ b/csrc/quantization/cutlass/scaled_mm_dq.cu @@ -0,0 +1,285 @@ +#include + +#include +#include +#include + +// clang-format will break include orders +// clang-format off +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +// clang-format on + +///////////////////////////////////////// +// Begin automatically generated section +// clang-format off + +using namespace cute; + +namespace int8_kernel +{ + +using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor< + cute::Shape<_128, _128, _128>, cutlass::epilogue::collective::EpilogueTileAuto, + cutlass::bfloat16_t, cutlass::bfloat16_t, + cutlass::epilogue::TmaWarpSpecialized +>; + +using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + +using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, + cute::Stride, cute::Int<0>, cute::Int<0>> +>; + +using ScaleBDescriptor = cutlass::epilogue::collective::detail::RowBroadcastDescriptor; + +using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast< + ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape, + typename ScaleBDescriptor::Element, cute::Stride, cute::Int<1>, cute::Int<0>> +>; + +using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT< + Compute0, + ScaleB, + Accum>; + +using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, cutlass::bfloat16_t, float, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT< + Compute1, + ScaleA, + EVTCompute0>; + +using ElementD = cutlass::bfloat16_t; +using StrideD = cute::Stride, cute::Int<0>>; +using ElementC = void; +using StrideC = StrideD; + + + +using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + int32_t, float, + ElementC, StrideC, 4, + ElementD, StrideD, 4, + cutlass::epilogue::TmaWarpSpecialized, + EVTCompute1 + >::CollectiveOp; + +using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + int8_t, cutlass::layout::RowMajor, 16, + int8_t, cutlass::layout::ColumnMajor, 16, + int32_t, + cute::Shape, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + +// Gemm operator cutlass3x_sm90_tensorop_i64x128x32gemm_s8_s8_s32_bf16_bf16_128x128x128_2x1x1_0_tnt_align16_warpspecialized_pingpong_epi_tma +using 
cutlass3x_sm90_tensorop_i64x128x32gemm_s8_s8_s32_bf16_bf16_128x128x128_2x1x1_0_tnt_align16_warpspecialized_pingpong_epi_tma_base = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::PersistentScheduler +>; + +// Define named type +struct GemmKernel : + public cutlass3x_sm90_tensorop_i64x128x32gemm_s8_s8_s32_bf16_bf16_128x128x128_2x1x1_0_tnt_align16_warpspecialized_pingpong_epi_tma_base { }; + +} // namespace int8_kernel + +namespace fp8_kernel +{ + +using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor< + cute::Shape<_256, _128, _128>, cutlass::epilogue::collective::EpilogueTileAuto, + cutlass::bfloat16_t, cutlass::bfloat16_t, + cutlass::epilogue::TmaWarpSpecializedCooperative +>; + +using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + +using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, + cute::Stride, cute::Int<0>, cute::Int<0>> +>; + +using ScaleBDescriptor = cutlass::epilogue::collective::detail::RowBroadcastDescriptor; + +using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast< + ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape, + typename ScaleBDescriptor::Element, cute::Stride, cute::Int<1>, cute::Int<0>> +>; + +using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT< + Compute0, + ScaleB, + Accum>; + +using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, cutlass::bfloat16_t, float, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT< + Compute1, + ScaleA, + EVTCompute0>; + +using ElementD = cutlass::bfloat16_t; +using StrideD = cute::Stride, cute::Int<0>>; +using ElementC = void; +using StrideC = StrideD; + + + +using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + ElementC, StrideC, 1, + ElementD, StrideD, 1, + cutlass::epilogue::TmaWarpSpecializedCooperative, + EVTCompute1 + >::CollectiveOp; + +using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16, + cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, 16, + float, + cute::Shape, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + +// Gemm operator cutlass3x_sm90_tensorop_s64x128x32gemm_e4m3_e4m3_f32_bf16_bf16_256x128x128_1x2x1_0_tnt_align16_warpspecialized_cooperative_epi_tma +using cutlass3x_sm90_tensorop_s64x128x32gemm_e4m3_e4m3_f32_bf16_bf16_256x128x128_1x2x1_0_tnt_align16_warpspecialized_cooperative_epi_tma_base = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::PersistentScheduler +>; + +// Define named type +struct GemmKernel : + public cutlass3x_sm90_tensorop_s64x128x32gemm_e4m3_e4m3_f32_bf16_bf16_256x128x128_1x2x1_0_tnt_align16_warpspecialized_cooperative_epi_tma_base { }; + +} // namespace fp8_kernel + +// clang-format on +// End automatically generated section 
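+//
+// Note on the generated section above: both the int8 and fp8 epilogue
+// visitor trees compute D = scale_a * (scale_b * accum). ScaleA broadcasts
+// one float per row of the output (per-token activation scales) and ScaleB
+// broadcasts one float per column (per-channel weight scales); passing a
+// null pointer with a literal value instead selects a single per-tensor
+// scale for that operand, as done in the dispatcher below.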
+///////////////////////////////////////// + +using StrideA = cute::Stride, cute::Int<0>>; +using StrideB = cute::Stride, cute::Int<0>>; + +template +int entry_cutlass_scaled_mm_dq(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + + StrideA a_stride = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1}); + StrideB b_stride = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1}); + StrideC c_stride = cutlass::make_cute_packed_stride(StrideC{}, {m, n, 1}); + + typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, + b_stride}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + + typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + prob_shape, mainloop_args, epilogue_args}; + + typename ScaleA::Arguments a_args = + a_scales.numel() == 1 + ? typename ScaleA::Arguments{nullptr, a_scales.item(), {}} + : typename ScaleA::Arguments{a_scales.data_ptr(), {}, {}}; + + typename ScaleB::Arguments b_args = + b_scales.numel() == 1 + ? typename ScaleB::Arguments{nullptr, b_scales.item(), {}} + : typename ScaleB::Arguments{b_scales.data_ptr(), {}, {}}; + + args.epilogue.thread = {a_args, {b_args}}; + + // Launch the CUTLASS GEMM kernel. + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + Gemm gemm_op; + cutlass::Status status = gemm_op.run(args); + + // Return a cudaError_t if the CUTLASS GEMM operator returned an error code. + if (status != cutlass::Status::kSuccess) { + return cudaErrorUnknown; + } + + // Return success, if no errors were encountered. + return cudaSuccess; +} + +int cutlass_scaled_mm_dq(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + if (a.dtype() == torch::kInt8) { + + return entry_cutlass_scaled_mm_dq< + int8_kernel::GemmKernel, int8_kernel::ScaleA, int8_kernel::ScaleB, + int8_kernel::StrideC, int8_t, cutlass::bfloat16_t>(out, a, b, a_scales, + b_scales); + } else { + + return entry_cutlass_scaled_mm_dq< + fp8_kernel::GemmKernel, fp8_kernel::ScaleA, fp8_kernel::ScaleB, + fp8_kernel::StrideC, cutlass::float_e4m3_t, cutlass::bfloat16_t>( + out, a, b, a_scales, b_scales); + } +} diff --git a/csrc/third_party/cutlass b/csrc/third_party/cutlass new file mode 160000 index 0000000000000..5c447dd84f8ae --- /dev/null +++ b/csrc/third_party/cutlass @@ -0,0 +1 @@ +Subproject commit 5c447dd84f8ae0e1d48ff9a2eae26ce8c4958101 diff --git a/cutlass_fp8_fused_dq_and_scales.py b/cutlass_fp8_fused_dq_and_scales.py new file mode 100644 index 0000000000000..a72b70545df34 --- /dev/null +++ b/cutlass_fp8_fused_dq_and_scales.py @@ -0,0 +1,129 @@ +import torch +import cutlass +from cutlass.epilogue import relu +from cutlass import Tensor as FakeTensor +from cutlass.utils.profiler import CUDAEventProfiler + +# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to +# omit this information. 
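+#
+# This standalone script prototypes the fused dequantization epilogue through
+# the CUTLASS Python interface: the traced epilogue visitor computes
+# D = scale_a * (scale_b * accum) on an fp8 GEMM with fp32 accumulation, and
+# the result is checked against a torch.matmul reference further down.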
+print_module = True + +# The Epilogue Visitor feature currently only works for SM80 and 90 +from cutlass.backend.utils.device import device_cc + +if device_cc() not in [86, 80, 90]: + import sys + + sys.exit() + +m = 512 +n = 512 +k = 512 + +type_A = torch.float8_e4m3fn +type_B = torch.float8_e4m3fn +type_C = torch.bfloat16 +type_D = torch.bfloat16 + + +def to_fp8(tensor): + # Assuming input tensor is float32 + # Scale tensor to range of FP8 E4M3 by clamping exponent and truncating mantissa + max_exp = 2**4 - 1 # Maximum exponent for E4M3 + max_mantissa = 2**3 - 1 # Maximum mantissa for E4M3 + base = 2**max_exp + # Scale the mantissa + scaled = torch.clamp(tensor, -base, base) + # Quantize the mantissa + quantized = torch.round(scaled * max_mantissa) / max_mantissa + return quantized.to(dtype=torch.float8_e4m3fn) + + +torch.manual_seed(2023) +tensor_A = to_fp8(torch.rand(size=(m, k), device="cuda")) +tensor_B = to_fp8(torch.rand(size=(n, k), device="cuda").t()) +tensor_D = torch.zeros(size=(m, n), dtype=type_C, device="cuda") +tensor_C = torch.zeros(size=(m, n), dtype=type_C, device="cuda") + +tensor_scale_a = torch.rand(size=(m, 1), device="cuda") +tensor_scale_b = torch.rand(size=(1, n), device="cuda") + +plan = cutlass.op.Gemm( + element_A=type_A, + element_B=type_B, + element_C=type_C, + element_D=type_D, + layout_A=cutlass.LayoutType.RowMajor, + layout_B=cutlass.LayoutType.ColumnMajor, + layout_C=cutlass.LayoutType.RowMajor, + element_accumulator=torch.float32, + kernel_cc=90, +) + + +# Define epilogue visitor +def example_epilogue(accum, scale_a, scale_b): + D = scale_a * (scale_b * accum) + return D + + +# Construct inputs and outputs +epilogue_tensors = { + "accum": FakeTensor( + element=torch.float32, + shape=(m, n), + layout_tag=cutlass.LayoutType.RowMajor, + ), + "D": tensor_D, + "scale_a": tensor_scale_a, + "scale_b": tensor_scale_b, +} + +# Trace the epilogue visitor +epilogue_visitor = cutlass.epilogue.trace(example_epilogue, epilogue_tensors) + +visitor_args = {"scale_a": tensor_scale_a, "scale_b": tensor_scale_b, "D": tensor_D} + +plan.epilogue_visitor = epilogue_visitor +plan.run( + tensor_A, + tensor_B, + tensor_C, + tensor_D, + visitor_args=visitor_args, + print_module=print_module, +) + + +class TorchReference(torch.nn.Module): + def forward(self, A, B, C, scale_a, scale_b): + accum = torch.matmul(A.to(dtype=torch.float32), B.to(dtype=torch.float32)) + return example_epilogue(accum.to(dtype=torch.float32), scale_a, scale_b).to( + type_D + ) + + +torch_reference = TorchReference() +tensor_D_ref = torch_reference( + tensor_A, tensor_B, tensor_C, tensor_scale_a, tensor_scale_b +) + +print(tensor_D) +print(tensor_D_ref) +assert torch.allclose(tensor_D, tensor_D_ref, 1e-1) + +warmup_iterations = 10 +profile_iterations = 50 +# Profile CUTLASS fused kernel +duration = CUDAEventProfiler( + plan, + warmup_iterations, + profile_iterations, + tensor_A, + tensor_B, + tensor_C, + tensor_D, + visitor_args=visitor_args, +)() + +print(f"CUTLASS duration: {duration:.2f} ms") diff --git a/cutlass_int8_fused_dq_and_scales.py b/cutlass_int8_fused_dq_and_scales.py new file mode 100644 index 0000000000000..09ed7c177b491 --- /dev/null +++ b/cutlass_int8_fused_dq_and_scales.py @@ -0,0 +1,124 @@ +import torch +import cutlass +from cutlass.epilogue import relu +from cutlass import Tensor as FakeTensor +from cutlass.utils.profiler import CUDAEventProfiler + +# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to +# omit this information. 
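+#
+# Same prototype as the fp8 script, but with int8 inputs and int32
+# accumulation; the scale_a * (scale_b * accum) epilogue and the torch
+# reference check are otherwise identical.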
+print_module = True + +# The Epilogue Visitor feature currently only works for SM80 and 90 +from cutlass.backend.utils.device import device_cc + +if device_cc() not in [86, 80, 90]: + import sys + + sys.exit() + +m = 512 +n = 512 +k = 512 + +type_A = torch.int8 +type_B = torch.int8 +type_C = torch.bfloat16 +type_D = torch.bfloat16 + + +def to_int8(tensor): + min = -127 # use 127 for symmetry + max = 127 + scaled = torch.clamp(tensor, min, max) + quantized = torch.round(scaled) + return quantized.to(dtype=torch.int8) + + +torch.manual_seed(2023) +tensor_A = to_int8(torch.rand(size=(m, k), device="cuda") * 10) +tensor_B = to_int8(torch.rand(size=(n, k), device="cuda").t() * 10) +tensor_D = torch.zeros(size=(m, n), dtype=type_C, device="cuda") +tensor_C = torch.zeros(size=(m, n), dtype=type_C, device="cuda") + +tensor_scale_a = torch.rand(size=(m, 1), device="cuda") +tensor_scale_b = torch.rand(size=(1, n), device="cuda") + +plan = cutlass.op.Gemm( + element_A=type_A, + element_B=type_B, + element_C=type_C, + element_D=type_D, + layout_A=cutlass.LayoutType.RowMajor, + layout_B=cutlass.LayoutType.ColumnMajor, + layout_C=cutlass.LayoutType.RowMajor, + element_accumulator=torch.int32, + kernel_cc=90, +) + + +# Define epilogue visitor +def example_epilogue(accum, scale_a, scale_b): + D = scale_a * (scale_b * accum) + return D + + +# Construct inputs and outputs +epilogue_tensors = { + "accum": FakeTensor( + element=torch.int32, + shape=(m, n), + layout_tag=cutlass.LayoutType.RowMajor, + ), + "D": tensor_D, + "scale_a": tensor_scale_a, + "scale_b": tensor_scale_b, +} + +# Trace the epilogue visitor +epilogue_visitor = cutlass.epilogue.trace(example_epilogue, epilogue_tensors) + +visitor_args = {"scale_a": tensor_scale_a, "scale_b": tensor_scale_b, "D": tensor_D} + +plan.epilogue_visitor = epilogue_visitor +plan.run( + tensor_A, + tensor_B, + tensor_C, + tensor_D, + visitor_args=visitor_args, + print_module=print_module, +) + + +class TorchReference(torch.nn.Module): + def forward(self, A, B, C, scale_a, scale_b): + accum = torch.matmul(A.to(dtype=torch.float32), B.to(dtype=torch.float32)) + return example_epilogue(accum.to(dtype=torch.float32), scale_a, scale_b).to( + type_D + ) + + +torch_reference = TorchReference() +tensor_D_ref = torch_reference( + tensor_A, tensor_B, tensor_C, tensor_scale_a, tensor_scale_b +) + +print(tensor_D) +print(tensor_D_ref) +assert torch.allclose(tensor_D, tensor_D_ref, 1e-1) + +warmup_iterations = 10 +profile_iterations = 50 +# Profile CUTLASS fused kernel +duration = CUDAEventProfiler( + plan, + warmup_iterations, + profile_iterations, + tensor_A, + tensor_B, + tensor_C, + tensor_D, + visitor_args=visitor_args, +)() + +print(f"CUTLASS duration: {duration:.2f} ms") diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py new file mode 100644 index 0000000000000..69f5d3c9f5a29 --- /dev/null +++ b/tests/kernels/test_cutlass.py @@ -0,0 +1,112 @@ +"""Tests for cutlass kernels + +Run `pytest tests/kernels/test_cutlass.py`. 
+""" +import pytest +import torch + +from vllm import _custom_ops as ops + +def to_fp8(tensor): + # Assuming input tensor is float32 + # Scale tensor to range of FP8 E4M3 by clamping exponent and truncating mantissa + max_exp = 2**4 - 1 # Maximum exponent for E4M3 + max_mantissa = 2**3 - 1 # Maximum mantissa for E4M3 + base = 2**max_exp + # Scale the mantissa + scaled = torch.clamp(tensor, -base, base) + # Quantize the mantissa + quantized = torch.round(scaled * max_mantissa) / max_mantissa + return quantized.to(dtype=torch.float8_e4m3fn) + +def to_int8(tensor): + return torch.round(torch.clamp(tensor, -128, 127)).to(dtype=torch.int8) + +def cutlass_fp8_gemm_per_row_and_col_scales( + m: int, + n: int, + k: int, +): + # Test for a cutlass kernel with per-token activation quantization + # and per-output channel weight quantization. + a = to_fp8(torch.randn((m, k), device='cuda')) + b = to_fp8(torch.randn((n, k), device='cuda').t()) + + scale_a = torch.randn((m,1), device='cuda', dtype=torch.float32) / 10 + scale_b = torch.randn((1,n), device='cuda', dtype=torch.float32) / 10 + + out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b) + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) + + assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1) + +def cutlass_int8_gemm_per_row_and_col_scales( + m: int, + n: int, + k: int, +): + # Test for a cutlass kernel with per-token activation quantization + # and per-output channel weight quantization. + a = to_int8(torch.randn((m, k), device='cuda') * 5) + b = to_int8(torch.randn((n, k), device='cuda').t() * 5) + + scale_a = torch.randn((m,1), device='cuda', dtype=torch.float32) / 10 + scale_b = torch.randn((1,n), device='cuda', dtype=torch.float32) / 10 + + out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b) + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) + + assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1) + +@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 1024]) +def test_cutlass_fp8_gemm( + m: int, n: int, k: int, +): + cutlass_fp8_gemm_per_row_and_col_scales(m,n,k) + + +@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) +@pytest.mark.parametrize("k", [511]) +@pytest.mark.skip(reason="Illegal instruction at k=511") +def test_cutlass_bad_size( + m: int, n: int, k: int, +): + cutlass_int8_gemm_per_row_and_col_scales(m,n,k) + cutlass_fp8_gemm_per_row_and_col_scales(m,n,k) + +@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 1024]) +def test_cutlass_int8_gemm( + m: int, + n: int, + k: int, +): + cutlass_int8_gemm_per_row_and_col_scales(m,n,k) + +@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 1024]) +def test_cutlass_int8_gemm_per_tensor_scales( + m: int, + n: int, + k: int, +): + # Test for a cutlass kernel with per-token activation quantization + # and per-output channel weight quantization. 
+ a = to_int8(torch.randn((m, k), device='cuda') * 5) + b = to_int8(torch.randn((n, k), device='cuda').t() * 5) + + scale_a = torch.randn((1,1), dtype=torch.float32) / 10 + scale_b = torch.randn((1,1), dtype=torch.float32) / 10 + + out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b) + baseline = torch.mm(scale_a.to(device='cuda') * a.to(dtype=torch.float32), + scale_b.to(device='cuda') * b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) + + assert torch.allclose(out, baseline, rtol=1e-4, atol=1e-1) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 42dedfdf76c4f..206530e4e48d2 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -152,6 +152,15 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, size_n, size_k) +# cutlass +def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, a_scales: torch.Tensor, + b_scales: torch.Tensor) -> torch.Tensor: + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m,n), dtype=torch.bfloat16, device="cuda") + vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales) + return out + # aqlm def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, From d2fef83ed0a491b6be7e4688b1141ecb3542e0ef Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 2 May 2024 21:38:12 +0000 Subject: [PATCH 02/18] 2.x kernel for Ampere and Lovelace --- CMakeLists.txt | 20 +- csrc/ops.h | 16 +- csrc/pybind.cpp | 4 +- csrc/quantization/cutlass/common.hpp | 47 +++ .../cutlass_visitor_2x_broadcast_epilogue.hpp | 332 ++++++++++++++++++ .../quantization/cutlass/scaled_mm_dq_sm8x.cu | 230 ++++++++++++ .../{scaled_mm_dq.cu => scaled_mm_dq_sm90.cu} | 42 +-- cutlass_int8_fused_dq_and_scales.py | 6 +- tests/kernels/test_cutlass.py | 98 +++--- vllm/_custom_ops.py | 27 +- 10 files changed, 739 insertions(+), 83 deletions(-) create mode 100644 csrc/quantization/cutlass/common.hpp create mode 100644 csrc/quantization/cutlass/cutlass_visitor_2x_broadcast_epilogue.hpp create mode 100644 csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu rename csrc/quantization/cutlass/{scaled_mm_dq.cu => scaled_mm_dq_sm90.cu} (92%) diff --git a/CMakeLists.txt b/CMakeLists.txt index a8e52cf7b91be..874893b2dcbd4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,7 +16,17 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") # Supported NVIDIA architectures. -set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") + +# Workaround for now until: +# https://github.com/pytorch/pytorch/commit/6e99f739235980e8d47e8fe6246c7466f2ce2f58 +# lands +if ($ENV{TORCH_CUDA_ARCH_LIST} MATCHES "9.0a") + set(CMAKE_CUDA_FLAGS "-gencode arch=compute_90a,code=sm_90a ${CMAKE_CUDA_FLAGS}") + string(REPLACE "9.0a" "" TORCH_CUDA_ARCH_LIST $ENV{TORCH_CUDA_ARCH_LIST}) +endif() + +# Supported NVIDIA architectures. +set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0;9.0a") # Supported AMD GPU architectures. 
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100") @@ -165,7 +175,6 @@ set(VLLM_EXT_SRC "csrc/pos_encoding_kernels.cu" "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" - "csrc/quantization/cutlass/scaled_mm_dq.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/fp8/common.cu" @@ -180,7 +189,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/marlin/marlin_cuda_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" - "csrc/custom_all_reduce.cu") + "csrc/custom_all_reduce.cu" + "csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu" + "csrc/quantization/cutlass/scaled_mm_dq_sm90.cu") endif() define_gpu_extension_target( @@ -189,8 +200,7 @@ define_gpu_extension_target( LANGUAGE ${VLLM_GPU_LANG} SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} - # ARCHITECTURES ${VLLM_GPU_ARCHES} - ARCHITECTURES 90a + ARCHITECTURES ${VLLM_GPU_ARCHES} INCLUDE_DIRECTORIES csrc/third_party/cutlass/include;csrc/third_party/cutlass/tools/util/include WITH_SOABI) diff --git a/csrc/ops.h b/csrc/ops.h index 678c80d7d0511..196b8fca0d6d9 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -145,7 +145,21 @@ torch::Tensor gptq_marlin_repack( int64_t size_n, int64_t num_bits); -int cutlass_scaled_mm_dq( +int cutlass_scaled_mm_dq_sm80( + torch::Tensor& out, + torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); + +int cutlass_scaled_mm_dq_sm89( + torch::Tensor& out, + torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); + +int cutlass_scaled_mm_dq_sm90( torch::Tensor& out, torch::Tensor const &a, torch::Tensor const &b, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index ca18ca7eb5eea..1b831d636f4d6 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -71,7 +71,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); #endif - ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, "CUTLASS quantized w8a8 GEMM, supporting symmetric quantized per-channel or per-tensor weights and symmetric quantized per-token or per-tensor activations. Inputs are either int8, int8 or fp8_e4m3fn, fp8_e4m3fn. Output must be bfloat16"); + ops.def("cutlass_scaled_mm_dq_sm80", &cutlass_scaled_mm_dq_sm80, "CUTLASS quantized w8a8 GEMM, supporting symmetric quantized per-channel or per-tensor weights and symmetric quantized per-token or per-tensor activations. Inputs are either int8, int8 or fp8_e4m3fn, fp8_e4m3fn. Output must be bfloat16"); + ops.def("cutlass_scaled_mm_dq_sm89", &cutlass_scaled_mm_dq_sm89, "CUTLASS quantized w8a8 GEMM, supporting symmetric quantized per-channel or per-tensor weights and symmetric quantized per-token or per-tensor activations. Inputs are either int8, int8 or fp8_e4m3fn, fp8_e4m3fn. Output must be bfloat16"); + ops.def("cutlass_scaled_mm_dq_sm90", &cutlass_scaled_mm_dq_sm90, "CUTLASS quantized w8a8 GEMM, supporting symmetric quantized per-channel or per-tensor weights and symmetric quantized per-token or per-tensor activations. Inputs are either int8, int8 or fp8_e4m3fn, fp8_e4m3fn. 
Output must be bfloat16"); ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); diff --git a/csrc/quantization/cutlass/common.hpp b/csrc/quantization/cutlass/common.hpp new file mode 100644 index 0000000000000..39b72351f7030 --- /dev/null +++ b/csrc/quantization/cutlass/common.hpp @@ -0,0 +1,47 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +// Taken from cutlass/examples/common/helper.h + +/** + * Panic wrapper for unwinding CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } diff --git a/csrc/quantization/cutlass/cutlass_visitor_2x_broadcast_epilogue.hpp b/csrc/quantization/cutlass/cutlass_visitor_2x_broadcast_epilogue.hpp new file mode 100644 index 0000000000000..d2c8a3324f766 --- /dev/null +++ b/csrc/quantization/cutlass/cutlass_visitor_2x_broadcast_epilogue.hpp @@ -0,0 +1,332 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/visitor_load.hpp from +// https://github.com/NVIDIA/cutlass It's beem modified to support either +// row/column or scalar broadcasting, like is already supported in CUTLASS 3.x. +// Important because this saves us a factor 4x on the number of kernels +// compiled. +// +#pragma once + +// clang-format off + +#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" +#include "cute/tensor.hpp" + +// clang-format on + +namespace cutlass::epilogue::threadblock { + +using namespace cute; +using namespace detail; + +using X = Underscore; + +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrScalarBroadcast { + + struct Arguments { + Element const* ptr_row = nullptr; + Element null_default = Element(0); + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v 
= filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->ptr_row) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load(dst_v(i), (void const*)&src_v(i), guard); + } + } else { + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = params_ptr->null_default; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + dst_v(i) = filled_vec; + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + class ThreadMap, + class Element, + class StrideMNL = Stride<_1,_0,_0> +> +struct VisitorColOrScalarBroadcast { + + struct Arguments { + Element const* ptr_col = nullptr; + Element null_default = Element(0); + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage { }; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gCol, + RTensor&& tC_rCol, + CTensor&& tC_cCol, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gCol(cute::forward(tC_gCol)), + tC_rCol(cute::forward(tC_rCol)), + tC_cCol(cute::forward(tC_cCol)), + m(get<0>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gCol; + RTensor tC_rCol; + CTensor tC_cCol; + Params const* params_ptr; + int m; + + // This function is modified from VisitorColBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rCol); + + if (params_ptr->ptr_col) { + Tensor pred = make_tensor(shape(tC_gCol)); + CUTLASS_PRAGMA_UNROLL + for 
(int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tC_cCol(i)) < m; + } + copy_if(pred, tC_gCol, tC_rCol); + } else { + auto dst_v = filter(tC_rCol); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(dst_v); ++i) { + dst_v(i) = params_ptr->null_default; + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Array frg_col; + frg_col.fill(tC_rCol(row_idx,iter_idx)); + return frg_col; + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mCol = make_tensor( + make_gmem_ptr(params_ptr->ptr_col), + problem_shape, + params_ptr->dCol); + + // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER + Tensor tC_gCol = group_modes<1,4>( + ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + Tensor tC_rCol = make_tensor_like(tC_gCol); + + // Generate the pred tensor + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tC_cCol = group_modes<1,4>( + ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + + return Callbacks< + decltype(tC_gCol), decltype(tC_rCol), + decltype(tC_cCol), ProblemShape>( + cute::move(tC_gCol), + cute::move(tC_rCol), + cute::move(tC_cCol), + problem_shape, + params_ptr + ); + } +}; + +} diff --git a/csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu b/csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu new file mode 100644 index 0000000000000..76bda2f3106e2 --- /dev/null +++ b/csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu @@ -0,0 +1,230 @@ +#include + +#include +#include + +// clang-format will break include orders +// clang-format off +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/device_memory.h" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm_coord.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" +#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" + +#include "cutlass_visitor_2x_broadcast_epilogue.hpp" +#include "common.hpp" +// clang-format on + +///////////////////////////////////////// + +template +struct sm8x_gemm +{ + +using Operator = typename std::conditional, + cutlass::arch::OpMultiplyAddSaturate, cutlass::arch::OpMultiplyAdd>::type; + +using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + float, + 4, + 1 /* epilogue stages */ +>; + + +using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + +using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< + OutputTileThreadMap, float, + cute::Stride, cute::Int<0>, cute::Int<0>> +>; + +using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< + OutputTileThreadMap, float, + cute::Stride, cute::Int<1>, cute::Int<0>> +>; + +using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT< + Compute0, + ScaleB, + Accum>; + +using Compute1 = 
cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementOut, float, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT< + Compute1, + ScaleA, + EVTCompute0>; + +using D = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, cutlass::bfloat16_t, cutlass::FloatRoundStyle::round_to_nearest, + cute::Stride, cute::Int<0>> +>; + +using EVTD = cutlass::epilogue::threadblock::Sm80EVT< + D, + EVTCompute1>; + + +// Gemm operator cutlass_tensorop_f32_i16832gemm_s8_256x128_64x3_tn_align16 +using cutlass_tensorop_f32_i16832gemm_s8_256x128_64x3_tn_align16_base = + typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementIn, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 16, + ElementIn, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 16, + float, cutlass::layout::RowMajor, 4, + ElementAcc, + float, + cutlass::arch::OpClassTensorOp, + Arch, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + EVTD, + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, + 3, + Operator, + 1 /* epilogue stages */ +>::GemmKernel; + +using Op = cutlass::gemm::device::GemmUniversalAdapter< + cutlass_tensorop_f32_i16832gemm_s8_256x128_64x3_tn_align16_base>; +}; + +///////////////////////////////////////// + +template +void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + cutlass::gemm::GemmCoord problem_size{m, n, k}; + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideC = cute::Stride, cute::Int<0>>; + StrideC c_stride = cutlass::make_cute_packed_stride(StrideC{}, {m, n, 1}); + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto c_ptr = static_cast(out.data_ptr()); + + auto a_scales_ptr = a_scales.data_ptr(); + auto b_scales_ptr = b_scales.data_ptr(); + + using ScaleAArgs = typename Gemm::ScaleA::Arguments; + ScaleAArgs a_args = a_scales.numel() == 1 + ? ScaleAArgs{nullptr, a_scales.item(), {}} + : ScaleAArgs{a_scales.data_ptr(), {}, {}}; + + using ScaleBArgs = typename Gemm::ScaleB::Arguments; + ScaleBArgs b_args = b_scales.numel() == 1 + ? ScaleBArgs{nullptr, b_scales.item(), {}} + : ScaleBArgs{b_scales.data_ptr(), {}, {}}; + + typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args}; + + typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args, + evt0_compute_args}; + typename Gemm::D::Arguments d_args{c_ptr, c_stride}; + + typename Gemm::EVTD::Arguments epilogue_args{ + evt1_compute_args, + d_args, + }; + + typename Gemm::Op::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGemm, // universal mode + problem_size, // problem size + 1, // batch count + epilogue_args, + a_ptr, + b_ptr, + nullptr, + nullptr, + 0, + 0, + 0, + 0, + lda, + ldb, + ldc, + ldc}; + + // Launch the CUTLASS GEMM kernel. 
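+  // The stream-K threadblock swizzle used above may require device scratch
+  // space, so query the workspace size and allocate it before running.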
+ typename Gemm::Op gemm_op; + size_t workspace_size = gemm_op.get_workspace_size(args); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm_op.can_implement(args)); + cutlass::Status status = gemm_op(args, workspace.get()); + CUTLASS_CHECK(status); +} + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) +void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + assert(a.dtype() == torch::kInt8); + assert(b.dtype() == torch::kInt8); + assert(a_scales.dtype() == torch::kFloat32); + assert(b_scales.dtype() == torch::kFloat32); + assert(out.dtype() == torch::kBFloat16); + + return cutlass_scaled_mm_dq_dispatcher, int8_t, cutlass::bfloat16_t>(out, a, b, a_scales, b_scales); +} +#endif + +#if defined(CUTLASS_ARCH_MMA_SM89_SUPPORTED) +void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + if (a.dtype() == torch::kInt8) { + assert(b.dtype() == torch::kInt8); + assert(a_scales.dtype() == torch::kFloat32); + assert(b_scales.dtype() == torch::kFloat32); + assert(out.dtype() == torch::kBFloat16); + + return cutlass_scaled_mm_dq_dispatcher< + sm8x_gemm, int8_t, + cutlass::bfloat16_t>(out, a, b, a_scales, b_scales); + } else { + assert(a.dtype() == torch::kFloat8_e4m3fn); + assert(b.dtype() == torch::kFloat8_e4m3fn); + assert(a_scales.dtype() == torch::kFloat32); + assert(b_scales.dtype() == torch::kFloat32); + assert(out.dtype() == torch::kBFloat16); + + return cutlass_scaled_mm_dq_dispatcher< + sm8x_gemm, + cutlass::float_e4m3_t, cutlass::bfloat16_t>(out, a, b, a_scales, + b_scales); + } +} +#endif diff --git a/csrc/quantization/cutlass/scaled_mm_dq.cu b/csrc/quantization/cutlass/scaled_mm_dq_sm90.cu similarity index 92% rename from csrc/quantization/cutlass/scaled_mm_dq.cu rename to csrc/quantization/cutlass/scaled_mm_dq_sm90.cu index ff45417990c46..7e81089a3928e 100644 --- a/csrc/quantization/cutlass/scaled_mm_dq.cu +++ b/csrc/quantization/cutlass/scaled_mm_dq_sm90.cu @@ -18,6 +18,8 @@ #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" + +#include "common.hpp" // clang-format on ///////////////////////////////////////// @@ -212,11 +214,11 @@ using StrideA = cute::Stride, cute::Int<0>>; using StrideB = cute::Stride, cute::Int<0>>; template -int entry_cutlass_scaled_mm_dq(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { + typename StrideC, typename ElementIn, typename ElementOut> +void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { int32_t m = a.size(0); int32_t n = b.size(1); @@ -228,12 +230,12 @@ int entry_cutlass_scaled_mm_dq(torch::Tensor &out, torch::Tensor const &a, typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, b_stride}; - auto c_ptr = static_cast(out.data_ptr()); + auto c_ptr = static_cast(out.data_ptr()); typename GemmKernel::EpilogueArguments epilogue_args{ {}, 
c_ptr, c_stride, c_ptr, c_stride}; @@ -255,31 +257,29 @@ int entry_cutlass_scaled_mm_dq(torch::Tensor &out, torch::Tensor const &a, // Launch the CUTLASS GEMM kernel. using Gemm = cutlass::gemm::device::GemmUniversalAdapter; Gemm gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); cutlass::Status status = gemm_op.run(args); - - // Return a cudaError_t if the CUTLASS GEMM operator returned an error code. - if (status != cutlass::Status::kSuccess) { - return cudaErrorUnknown; - } - - // Return success, if no errors were encountered. - return cudaSuccess; + CUTLASS_CHECK(status); } -int cutlass_scaled_mm_dq(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { if (a.dtype() == torch::kInt8) { - return entry_cutlass_scaled_mm_dq< + return cutlass_scaled_mm_dq_dispatcher< int8_kernel::GemmKernel, int8_kernel::ScaleA, int8_kernel::ScaleB, int8_kernel::StrideC, int8_t, cutlass::bfloat16_t>(out, a, b, a_scales, b_scales); } else { - return entry_cutlass_scaled_mm_dq< + return cutlass_scaled_mm_dq_dispatcher< fp8_kernel::GemmKernel, fp8_kernel::ScaleA, fp8_kernel::ScaleB, fp8_kernel::StrideC, cutlass::float_e4m3_t, cutlass::bfloat16_t>( out, a, b, a_scales, b_scales); } } +#endif + diff --git a/cutlass_int8_fused_dq_and_scales.py b/cutlass_int8_fused_dq_and_scales.py index 09ed7c177b491..3fcaaffd76ffc 100644 --- a/cutlass_int8_fused_dq_and_scales.py +++ b/cutlass_int8_fused_dq_and_scales.py @@ -22,7 +22,7 @@ type_A = torch.int8 type_B = torch.int8 -type_C = torch.bfloat16 +type_C = torch.float32 type_D = torch.bfloat16 @@ -37,7 +37,7 @@ def to_int8(tensor): torch.manual_seed(2023) tensor_A = to_int8(torch.rand(size=(m, k), device="cuda") * 10) tensor_B = to_int8(torch.rand(size=(n, k), device="cuda").t() * 10) -tensor_D = torch.zeros(size=(m, n), dtype=type_C, device="cuda") +tensor_D = torch.zeros(size=(m, n), dtype=type_D, device="cuda") tensor_C = torch.zeros(size=(m, n), dtype=type_C, device="cuda") tensor_scale_a = torch.rand(size=(m, 1), device="cuda") @@ -52,7 +52,7 @@ def to_int8(tensor): layout_B=cutlass.LayoutType.ColumnMajor, layout_C=cutlass.LayoutType.RowMajor, element_accumulator=torch.int32, - kernel_cc=90, + kernel_cc=80, ) diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 69f5d3c9f5a29..bfefdb4cef6b7 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -7,6 +7,9 @@ from vllm import _custom_ops as ops +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] + def to_fp8(tensor): # Assuming input tensor is float32 # Scale tensor to range of FP8 E4M3 by clamping exponent and truncating mantissa @@ -22,18 +25,23 @@ def to_fp8(tensor): def to_int8(tensor): return torch.round(torch.clamp(tensor, -128, 127)).to(dtype=torch.int8) -def cutlass_fp8_gemm_per_row_and_col_scales( +def cutlass_fp8_gemm_helper( m: int, n: int, k: int, + per_token_act_quant: bool, + per_out_channel_weight_quant: bool, ): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. 
a = to_fp8(torch.randn((m, k), device='cuda')) b = to_fp8(torch.randn((n, k), device='cuda').t()) + + m_a_scales = m if per_token_act_quant else 1 + n_b_scales = n if per_out_channel_weight_quant else 1 - scale_a = torch.randn((m,1), device='cuda', dtype=torch.float32) / 10 - scale_b = torch.randn((1,n), device='cuda', dtype=torch.float32) / 10 + scale_a = torch.randn((m_a_scales,1), device='cuda', dtype=torch.float32) / 10 + scale_b = torch.randn((1,n_b_scales), device='cuda', dtype=torch.float32) / 10 out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b) baseline = torch.mm(scale_a * a.to(dtype=torch.float32), @@ -41,72 +49,68 @@ def cutlass_fp8_gemm_per_row_and_col_scales( assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1) -def cutlass_int8_gemm_per_row_and_col_scales( +def cutlass_int8_gemm_helper( m: int, n: int, k: int, + per_token_act_quant: bool, + per_out_channel_weight_quant: bool, ): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. a = to_int8(torch.randn((m, k), device='cuda') * 5) b = to_int8(torch.randn((n, k), device='cuda').t() * 5) - scale_a = torch.randn((m,1), device='cuda', dtype=torch.float32) / 10 - scale_b = torch.randn((1,n), device='cuda', dtype=torch.float32) / 10 + m_a_scales = m if per_token_act_quant else 1 + n_b_scales = n if per_out_channel_weight_quant else 1 + + scale_a = torch.randn((m_a_scales,1), device='cuda', dtype=torch.float32) / 10 + scale_b = torch.randn((1,n_b_scales), device='cuda', dtype=torch.float32) / 10 out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b) baseline = torch.mm(scale_a * a.to(dtype=torch.float32), scale_b * b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) - assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1) + assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) @pytest.mark.parametrize("m", [512, 222, 33, 1]) @pytest.mark.parametrize("n", [2048, 256, 1024]) -@pytest.mark.parametrize("k", [128, 1024]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif(capability < 89, reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm( - m: int, n: int, k: int, + m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool ): - cutlass_fp8_gemm_per_row_and_col_scales(m,n,k) - + cutlass_fp8_gemm_helper(m,n,k, per_act_token, per_out_ch) @pytest.mark.parametrize("m", [512, 222, 33, 1]) @pytest.mark.parametrize("n", [2048, 256, 1024]) -@pytest.mark.parametrize("k", [511]) -@pytest.mark.skip(reason="Illegal instruction at k=511") -def test_cutlass_bad_size( - m: int, n: int, k: int, -): - cutlass_int8_gemm_per_row_and_col_scales(m,n,k) - cutlass_fp8_gemm_per_row_and_col_scales(m,n,k) - -@pytest.mark.parametrize("m", [512, 222, 33, 1]) -@pytest.mark.parametrize("n", [2048, 256, 1024]) -@pytest.mark.parametrize("k", [128, 1024]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) def test_cutlass_int8_gemm( - m: int, - n: int, - k: int, -): - cutlass_int8_gemm_per_row_and_col_scales(m,n,k) - -@pytest.mark.parametrize("m", [512, 222, 33, 1]) -@pytest.mark.parametrize("n", [2048, 256, 1024]) -@pytest.mark.parametrize("k", [128, 1024]) -def test_cutlass_int8_gemm_per_tensor_scales( - m: int, - n: int, - k: int, + m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool ): - # Test for a cutlass 
kernel with per-token activation quantization - # and per-output channel weight quantization. - a = to_int8(torch.randn((m, k), device='cuda') * 5) - b = to_int8(torch.randn((n, k), device='cuda').t() * 5) - - scale_a = torch.randn((1,1), dtype=torch.float32) / 10 - scale_b = torch.randn((1,1), dtype=torch.float32) / 10 - - out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b) - baseline = torch.mm(scale_a.to(device='cuda') * a.to(dtype=torch.float32), - scale_b.to(device='cuda') * b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) + cutlass_int8_gemm_helper(m,n,k, per_act_token, per_out_ch) + +# For the following two tests: +# N and K correspond to the size of the weight matrix and likely to be multiples of a large power of two. +# In any case, the kernel will have a naive fallback when N and K are not divisible by 16 +# But M is the number of tokens and the kernel must handle any M thrown at it. +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif(capability < 89, reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool): + for nk in range(16, 256, 16): + for m in range(1, 128): + cutlass_fp8_gemm_helper(m,nk,nk, per_act_token, per_out_ch) + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool): + for nk in range(16, 256, 16): + for m in range(1, 128): + cutlass_int8_gemm_helper(m,nk,nk, per_act_token, per_out_ch) + - assert torch.allclose(out, baseline, rtol=1e-4, atol=1e-1) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 206530e4e48d2..10ea9126a63b0 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2,6 +2,9 @@ import torch +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] + try: from vllm._C import cache_ops as vllm_cache_ops from vllm._C import ops as vllm_ops @@ -155,11 +158,25 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, # cutlass def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, a_scales: torch.Tensor, b_scales: torch.Tensor) -> torch.Tensor: - m = a.shape[0] - n = b.shape[1] - out = torch.empty((m,n), dtype=torch.bfloat16, device="cuda") - vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales) - return out + + shape_fallback = b.shape[0] % 16 != 0 or b.shape[1] % 16 != 0 + + if capability < 800 or shape_fallback: + a_bf16 = a.to(dtype=torch.bfloat16) + b_bf16 = b.to(dtype=torch.bfloat16) + + return (b_scales * (a_scales * torch.mm(a_bf16, b_bf16))).to(dtype=torch.bfloat16) + else: + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m,n), dtype=torch.bfloat16, device="cuda") + if capability >= 900: + vllm_ops.cutlass_scaled_mm_dq_sm90(out, a, b, a_scales, b_scales) + elif capability >= 890: + vllm_ops.cutlass_scaled_mm_dq_sm89(out, a, b, a_scales, b_scales) + else: + vllm_ops.cutlass_scaled_mm_dq_sm80(out, a, b, a_scales, b_scales) + return out # aqlm From 24358c422d7a573ffd8e3afc28ee6af902695c07 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 8 May 2024 17:20:33 +0000 Subject: [PATCH 03/18] Refactor, cleanup --- .gitmodules | 3 - CMakeLists.txt | 24 +- csrc/ops.h | 15 +- csrc/pybind.cpp | 4 +- .../cutlass/scaled_mm_dq_entry.cu | 42 +++ .../quantization/cutlass/scaled_mm_dq_sm8x.cu | 199 +++++------ .../quantization/cutlass/scaled_mm_dq_sm90.cu | 336 +++++++----------- 
csrc/third_party/cutlass | 1 - cutlass_fp8_fused_dq_and_scales.py | 129 ------- cutlass_int8_fused_dq_and_scales.py | 124 ------- tests/kernels/test_cutlass.py | 99 ++++-- vllm/_custom_ops.py | 336 ++++++++++++------ 12 files changed, 570 insertions(+), 742 deletions(-) delete mode 100644 .gitmodules create mode 100644 csrc/quantization/cutlass/scaled_mm_dq_entry.cu delete mode 160000 csrc/third_party/cutlass delete mode 100644 cutlass_fp8_fused_dq_and_scales.py delete mode 100644 cutlass_int8_fused_dq_and_scales.py diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index a8108241542e9..0000000000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "csrc/third_party/cutlass"] - path = csrc/third_party/cutlass - url = https://github.com/nvidia/cutlass diff --git a/CMakeLists.txt b/CMakeLists.txt index 874893b2dcbd4..0ff0c3b9d50d2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,18 +15,9 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) # set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") -# Supported NVIDIA architectures. - -# Workaround for now until: -# https://github.com/pytorch/pytorch/commit/6e99f739235980e8d47e8fe6246c7466f2ce2f58 -# lands -if ($ENV{TORCH_CUDA_ARCH_LIST} MATCHES "9.0a") - set(CMAKE_CUDA_FLAGS "-gencode arch=compute_90a,code=sm_90a ${CMAKE_CUDA_FLAGS}") - string(REPLACE "9.0a" "" TORCH_CUDA_ARCH_LIST $ENV{TORCH_CUDA_ARCH_LIST}) -endif() - # Supported NVIDIA architectures. set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0;9.0a") +set(CMAKE_CUDA_FLAGS "-gencode arch=compute_90a,code=sm_90a ${CMAKE_CUDA_FLAGS}") # Supported AMD GPU architectures. set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100") @@ -183,6 +174,16 @@ set(VLLM_EXT_SRC "csrc/pybind.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") + include(FetchContent) + SET(CUTLASS_ENABLE_HEADERS_ONLY=ON) + FetchContent_Declare( + cutlass + GIT_REPOSITORY https://github.com/nvidia/cutlass.git + # CUTLASS 3.5.0 + GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc + ) + FetchContent_MakeAvailable(cutlass) + list(APPEND VLLM_EXT_SRC "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" @@ -190,6 +191,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/custom_all_reduce.cu" + "csrc/quantization/cutlass/scaled_mm_dq_entry.cu" "csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu" "csrc/quantization/cutlass/scaled_mm_dq_sm90.cu") endif() @@ -201,7 +203,7 @@ define_gpu_extension_target( SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} - INCLUDE_DIRECTORIES csrc/third_party/cutlass/include;csrc/third_party/cutlass/tools/util/include + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} WITH_SOABI) # diff --git a/csrc/ops.h b/csrc/ops.h index 196b8fca0d6d9..d8659610f3223 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -145,26 +145,13 @@ torch::Tensor gptq_marlin_repack( int64_t size_n, int64_t num_bits); -int cutlass_scaled_mm_dq_sm80( +int cutlass_scaled_mm_dq( torch::Tensor& out, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, torch::Tensor const &b_scales); -int cutlass_scaled_mm_dq_sm89( - torch::Tensor& out, - torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales); - -int cutlass_scaled_mm_dq_sm90( - torch::Tensor& out, - torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const 
&a_scales, - torch::Tensor const &b_scales); #endif void squeezellm_gemm( diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 1b831d636f4d6..bb55c8448cd9e 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -70,10 +70,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ"); ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); + ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column quantization."); #endif - ops.def("cutlass_scaled_mm_dq_sm80", &cutlass_scaled_mm_dq_sm80, "CUTLASS quantized w8a8 GEMM, supporting symmetric quantized per-channel or per-tensor weights and symmetric quantized per-token or per-tensor activations. Inputs are either int8, int8 or fp8_e4m3fn, fp8_e4m3fn. Output must be bfloat16"); - ops.def("cutlass_scaled_mm_dq_sm89", &cutlass_scaled_mm_dq_sm89, "CUTLASS quantized w8a8 GEMM, supporting symmetric quantized per-channel or per-tensor weights and symmetric quantized per-token or per-tensor activations. Inputs are either int8, int8 or fp8_e4m3fn, fp8_e4m3fn. Output must be bfloat16"); - ops.def("cutlass_scaled_mm_dq_sm90", &cutlass_scaled_mm_dq_sm90, "CUTLASS quantized w8a8 GEMM, supporting symmetric quantized per-channel or per-tensor weights and symmetric quantized per-token or per-tensor activations. Inputs are either int8, int8 or fp8_e4m3fn, fp8_e4m3fn. Output must be bfloat16"); ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); diff --git a/csrc/quantization/cutlass/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass/scaled_mm_dq_entry.cu new file mode 100644 index 0000000000000..6139acef44f3f --- /dev/null +++ b/csrc/quantization/cutlass/scaled_mm_dq_entry.cu @@ -0,0 +1,42 @@ +#include + +void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); +void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); +void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); + +void cutlass_scaled_mm_dq(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + // TODO(tms): Hack. cudaGetDeviceProperties is very slow. 
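The entry point added here caches the capability query in a function-local static (cudaGetDeviceProperties is slow) and then branches on the SM version. The same idea, sketched in Python purely for illustration; these helper names are hypothetical and not part of vLLM's API:

import functools
import torch

@functools.lru_cache(maxsize=None)
def _device_capability() -> int:
    # Query once per process and reuse the result, mirroring the static
    # cache in scaled_mm_dq_entry.cu that avoids repeated
    # cudaGetDeviceProperties calls.
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor

def _pick_scaled_mm_dq_kernel() -> str:
    cc = _device_capability()
    if cc >= 90:    # Hopper
        return "cutlass_scaled_mm_dq_sm90"
    if cc == 89:    # Ada Lovelace
        return "cutlass_scaled_mm_dq_sm89"
    if cc >= 80:    # Ampere
        return "cutlass_scaled_mm_dq_sm80"
    raise RuntimeError("Unsupported GPU architecture")
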
+ static std::optional maybe_version_num = std::nullopt; + if (!maybe_version_num.has_value()) { + int device; + cudaGetDevice(&device); + cudaDeviceProp properties; + cudaGetDeviceProperties(&properties, device); + maybe_version_num = properties.major * 10 + properties.minor; + } + int32_t version_num = maybe_version_num.value(); + + if (version_num >= 90) /* H100 */ { + // TODO: This kernel only works for sm90a + // -- figure out how to detect 90a vs 90 + + cutlass_scaled_mm_dq_sm90(out, a, b, a_scales, b_scales); + } else if (version_num == 89) /* Ada Lovelace */ { + cutlass_scaled_mm_dq_sm89(out, a, b, a_scales, b_scales); + } else if (version_num >= 80) /* Ampere */ { + cutlass_scaled_mm_dq_sm80(out, a, b, a_scales, b_scales); + } else { + throw std::runtime_error("Unsupported GPU architecture"); + } +} diff --git a/csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu b/csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu index 76bda2f3106e2..4bca3ca90197a 100644 --- a/csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu +++ b/csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu @@ -9,8 +9,6 @@ #include "cute/atom/mma_atom.hpp" #include "cutlass/numeric_types.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/packed_stride.hpp" #include "cutlass/util/device_memory.h" #include "cutlass/cutlass.h" @@ -25,99 +23,89 @@ #include "cutlass_visitor_2x_broadcast_epilogue.hpp" #include "common.hpp" -// clang-format on +// clang-format on ///////////////////////////////////////// -template -struct sm8x_gemm -{ - -using Operator = typename std::conditional, - cutlass::arch::OpMultiplyAddSaturate, cutlass::arch::OpMultiplyAdd>::type; - -using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< - cutlass::gemm::GemmShape<256, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, - float, - 4, - 1 /* epilogue stages */ ->; - - -using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; - -using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< - OutputTileThreadMap, float, - cute::Stride, cute::Int<0>, cute::Int<0>> ->; - -using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< - OutputTileThreadMap, float, - cute::Stride, cute::Int<1>, cute::Int<0>> ->; - -using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest ->; - -using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT< - Compute0, - ScaleB, - Accum>; - -using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementOut, float, - cutlass::FloatRoundStyle::round_to_nearest ->; - -using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT< - Compute1, - ScaleA, - EVTCompute0>; - -using D = cutlass::epilogue::threadblock::VisitorAuxStore< - OutputTileThreadMap, cutlass::bfloat16_t, cutlass::FloatRoundStyle::round_to_nearest, - cute::Stride, cute::Int<0>> ->; - -using EVTD = cutlass::epilogue::threadblock::Sm80EVT< - D, - EVTCompute1>; - - -// Gemm operator cutlass_tensorop_f32_i16832gemm_s8_256x128_64x3_tn_align16 -using cutlass_tensorop_f32_i16832gemm_s8_256x128_64x3_tn_align16_base = - typename cutlass::gemm::kernel::DefaultGemmWithVisitor< - ElementIn, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 16, - ElementIn, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 16, - float, cutlass::layout::RowMajor, 4, - ElementAcc, - float, - cutlass::arch::OpClassTensorOp, - Arch, - cutlass::gemm::GemmShape<256, 128, 64>, - 
cutlass::gemm::GemmShape<64, 64, 64>, - cutlass::gemm::GemmShape<16, 8, 32>, - EVTD, - cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, - 3, - Operator, - 1 /* epilogue stages */ ->::GemmKernel; - -using Op = cutlass::gemm::device::GemmUniversalAdapter< - cutlass_tensorop_f32_i16832gemm_s8_256x128_64x3_tn_align16_base>; +template +struct sm8x_gemm { + using ElementAB = ElementAB_; + using ElementD = ElementD_; + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + + using Operator = + typename std::conditional, + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type; + + using OutputTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, float, 4, + 1 /* epilogue stages */ + >; + + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< + OutputTileThreadMap, float, + cute::Stride, cute::Int<0>, cute::Int<0>>>; + + using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< + OutputTileThreadMap, float, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute1 = + cutlass::epilogue::threadblock::Sm80EVT; + + using D = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, cutlass::bfloat16_t, + cutlass::FloatRoundStyle::round_to_nearest, + cute::Stride, cute::Int<0>>>; + + using EVTD = cutlass::epilogue::threadblock::Sm80EVT; + + // Gemm operator cutlass_tensorop_f32_i16832gemm_s8_256x128_64x3_tn_align16 + using cutlass_tensorop_f32_i16832gemm_s8_256x128_64x3_tn_align16_base = + typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementAB, cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, 16, ElementAB, + cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 16, + float, cutlass::layout::RowMajor, 4, ElementAcc, float, + cutlass::arch::OpClassTensorOp, Arch, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, EVTD, + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, 3, Operator, + 1 /* epilogue stages */ + >::GemmKernel; + + using Op = cutlass::gemm::device::GemmUniversalAdapter< + cutlass_tensorop_f32_i16832gemm_s8_256x128_64x3_tn_align16_base>; }; ///////////////////////////////////////// -template +template void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, torch::Tensor const &b_scales) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; int32_t m = a.size(0); int32_t n = b.size(1); @@ -129,11 +117,11 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, int64_t ldc = out.stride(0); using StrideC = cute::Stride, cute::Int<0>>; - StrideC c_stride = cutlass::make_cute_packed_stride(StrideC{}, {m, n, 1}); + StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}}; - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - auto c_ptr = static_cast(out.data_ptr()); + 
auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto c_ptr = static_cast(out.data_ptr()); auto a_scales_ptr = a_scales.data_ptr(); auto b_scales_ptr = b_scales.data_ptr(); @@ -160,9 +148,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, }; typename Gemm::Op::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGemm, // universal mode - problem_size, // problem size - 1, // batch count + cutlass::gemm::GemmUniversalMode::kGemm, // universal mode + problem_size, // problem size + 1, // batch count epilogue_args, a_ptr, b_ptr, @@ -187,24 +175,25 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, CUTLASS_CHECK(status); } -#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { assert(a.dtype() == torch::kInt8); assert(b.dtype() == torch::kInt8); assert(a_scales.dtype() == torch::kFloat32); assert(b_scales.dtype() == torch::kFloat32); assert(out.dtype() == torch::kBFloat16); - return cutlass_scaled_mm_dq_dispatcher, int8_t, cutlass::bfloat16_t>(out, a, b, a_scales, b_scales); + return cutlass_scaled_mm_dq_dispatcher< + sm8x_gemm>( + out, a, b, a_scales, b_scales); } -#endif -#if defined(CUTLASS_ARCH_MMA_SM89_SUPPORTED) void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { if (a.dtype() == torch::kInt8) { assert(b.dtype() == torch::kInt8); assert(a_scales.dtype() == torch::kFloat32); @@ -212,8 +201,8 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, assert(out.dtype() == torch::kBFloat16); return cutlass_scaled_mm_dq_dispatcher< - sm8x_gemm, int8_t, - cutlass::bfloat16_t>(out, a, b, a_scales, b_scales); + sm8x_gemm>( + out, a, b, a_scales, b_scales); } else { assert(a.dtype() == torch::kFloat8_e4m3fn); assert(b.dtype() == torch::kFloat8_e4m3fn); @@ -221,10 +210,8 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, assert(b_scales.dtype() == torch::kFloat32); assert(out.dtype() == torch::kBFloat16); - return cutlass_scaled_mm_dq_dispatcher< - sm8x_gemm, - cutlass::float_e4m3_t, cutlass::bfloat16_t>(out, a, b, a_scales, - b_scales); + return cutlass_scaled_mm_dq_dispatcher>( + out, a, b, a_scales, b_scales); } } -#endif diff --git a/csrc/quantization/cutlass/scaled_mm_dq_sm90.cu b/csrc/quantization/cutlass/scaled_mm_dq_sm90.cu index 7e81089a3928e..1c6ab40a90e6f 100644 --- a/csrc/quantization/cutlass/scaled_mm_dq_sm90.cu +++ b/csrc/quantization/cutlass/scaled_mm_dq_sm90.cu @@ -12,274 +12,182 @@ #include "cute/atom/mma_atom.hpp" #include "cutlass/numeric_types.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/packed_stride.hpp" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "common.hpp" -// clang-format on - -///////////////////////////////////////// -// Begin automatically generated section -// clang-format off +// clang-format on using namespace cute; -namespace int8_kernel -{ - -using 
EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor< - cute::Shape<_128, _128, _128>, cutlass::epilogue::collective::EpilogueTileAuto, - cutlass::bfloat16_t, cutlass::bfloat16_t, - cutlass::epilogue::TmaWarpSpecialized ->; - -using Accum = cutlass::epilogue::fusion::Sm90AccFetch; - -using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, - cute::Stride, cute::Int<0>, cute::Int<0>> ->; - -using ScaleBDescriptor = cutlass::epilogue::collective::detail::RowBroadcastDescriptor; - -using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast< - ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape, - typename ScaleBDescriptor::Element, cute::Stride, cute::Int<1>, cute::Int<0>> ->; - -using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest ->; - -using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT< - Compute0, - ScaleB, - Accum>; - -using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, cutlass::bfloat16_t, float, - cutlass::FloatRoundStyle::round_to_nearest ->; - -using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT< - Compute1, - ScaleA, - EVTCompute0>; - -using ElementD = cutlass::bfloat16_t; -using StrideD = cute::Stride, cute::Int<0>>; -using ElementC = void; -using StrideC = StrideD; - - - -using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - cute::Shape, - cute::Shape, - cutlass::epilogue::collective::EpilogueTileAuto, - int32_t, float, - ElementC, StrideC, 4, - ElementD, StrideD, 4, - cutlass::epilogue::TmaWarpSpecialized, - EVTCompute1 - >::CollectiveOp; - -using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - int8_t, cutlass::layout::RowMajor, 16, - int8_t, cutlass::layout::ColumnMajor, 16, - int32_t, - cute::Shape, - cute::Shape, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedPingpong - >::CollectiveOp; - -// Gemm operator cutlass3x_sm90_tensorop_i64x128x32gemm_s8_s8_s32_bf16_bf16_128x128x128_2x1x1_0_tnt_align16_warpspecialized_pingpong_epi_tma -using cutlass3x_sm90_tensorop_i64x128x32gemm_s8_s8_s32_bf16_bf16_128x128x128_2x1x1_0_tnt_align16_warpspecialized_pingpong_epi_tma_base = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - cutlass::gemm::PersistentScheduler ->; - -// Define named type -struct GemmKernel : - public cutlass3x_sm90_tensorop_i64x128x32gemm_s8_s8_s32_bf16_bf16_128x128x128_2x1x1_0_tnt_align16_warpspecialized_pingpong_epi_tma_base { }; - -} // namespace int8_kernel - -namespace fp8_kernel -{ - -using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor< - cute::Shape<_256, _128, _128>, cutlass::epilogue::collective::EpilogueTileAuto, - cutlass::bfloat16_t, cutlass::bfloat16_t, - cutlass::epilogue::TmaWarpSpecializedCooperative ->; - -using Accum = cutlass::epilogue::fusion::Sm90AccFetch; - -using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, - cute::Stride, cute::Int<0>, cute::Int<0>> ->; - -using ScaleBDescriptor = cutlass::epilogue::collective::detail::RowBroadcastDescriptor; - -using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast< - 
ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape, - typename ScaleBDescriptor::Element, cute::Stride, cute::Int<1>, cute::Int<0>> ->; - -using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest ->; - -using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT< - Compute0, - ScaleB, - Accum>; - -using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, cutlass::bfloat16_t, float, - cutlass::FloatRoundStyle::round_to_nearest ->; - -using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT< - Compute1, - ScaleA, - EVTCompute0>; - -using ElementD = cutlass::bfloat16_t; -using StrideD = cute::Stride, cute::Int<0>>; -using ElementC = void; -using StrideC = StrideD; - - - -using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - cute::Shape, - cute::Shape, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - ElementC, StrideC, 1, - ElementD, StrideD, 1, - cutlass::epilogue::TmaWarpSpecializedCooperative, - EVTCompute1 - >::CollectiveOp; - -using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16, - cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, 16, - float, - cute::Shape, - cute::Shape, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedCooperative - >::CollectiveOp; - -// Gemm operator cutlass3x_sm90_tensorop_s64x128x32gemm_e4m3_e4m3_f32_bf16_bf16_256x128x128_1x2x1_0_tnt_align16_warpspecialized_cooperative_epi_tma -using cutlass3x_sm90_tensorop_s64x128x32gemm_e4m3_e4m3_f32_bf16_bf16_256x128x128_1x2x1_0_tnt_align16_warpspecialized_cooperative_epi_tma_base = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - cutlass::gemm::PersistentScheduler ->; - -// Define named type -struct GemmKernel : - public cutlass3x_sm90_tensorop_s64x128x32gemm_e4m3_e4m3_f32_bf16_bf16_256x128x128_1x2x1_0_tnt_align16_warpspecialized_cooperative_epi_tma_base { }; - -} // namespace fp8_kernel +///////////////////////////////////////// + +template +struct sm90_gemm { + using ElementAB = ElementAB_; + using ElementD = ElementD_; + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + + using EpilogueDescriptor = + cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, + ElementD, EpilogueSchedule>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, + Stride, Int<0>, Int<0>>>; + + using ScaleBDescriptor = + cutlass::epilogue::collective::detail::RowBroadcastDescriptor< + EpilogueDescriptor, float>; + + using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast< + ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape, + typename ScaleBDescriptor::Element, Stride, Int<1>, Int<0>>>; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + 
cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute1 = + cutlass::epilogue::fusion::Sm90EVT; + + using StrideD = Stride, Int<0>>; + using ElementC = void; + using StrideC = StrideD; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, + EpilogueSchedule, EVTCompute1>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, ElementAB, + cutlass::layout::RowMajor, 16, ElementAB, + cutlass::layout::ColumnMajor, 16, ElementAcc, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + CEStorageSize)>, + KernelSchedule>::CollectiveOp; + + using KernelType = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, CollectiveMainloop, CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>; + + struct GemmKernel : public KernelType {}; +}; -// clang-format on -// End automatically generated section ///////////////////////////////////////// -using StrideA = cute::Stride, cute::Int<0>>; -using StrideB = cute::Stride, cute::Int<0>>; +using StrideA = Stride, Int<0>>; +using StrideB = Stride, Int<0>>; -template +template void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, torch::Tensor const &b_scales) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; int32_t m = a.size(0); int32_t n = b.size(1); int32_t k = a.size(1); - StrideA a_stride = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1}); - StrideB b_stride = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1}); - StrideC c_stride = cutlass::make_cute_packed_stride(StrideC{}, {m, n, 1}); + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideC = typename Gemm::StrideC; + StrideA a_stride{lda, Int<1>{}, Int<0>{}}; + StrideB b_stride{ldb, Int<1>{}, Int<0>{}}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + using GemmKernel = typename Gemm::GemmKernel; typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, b_stride}; - auto c_ptr = static_cast(out.data_ptr()); + auto c_ptr = static_cast(out.data_ptr()); typename GemmKernel::EpilogueArguments epilogue_args{ {}, c_ptr, c_stride, c_ptr, c_stride}; typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, prob_shape, mainloop_args, epilogue_args}; - typename ScaleA::Arguments a_args = - a_scales.numel() == 1 - ? typename ScaleA::Arguments{nullptr, a_scales.item(), {}} - : typename ScaleA::Arguments{a_scales.data_ptr(), {}, {}}; + using ScaleA_Args = typename Gemm::ScaleA::Arguments; + using ScaleB_Args = typename Gemm::ScaleB::Arguments; + ScaleA_Args a_args = a_scales.numel() == 1 + ? ScaleA_Args{nullptr, a_scales.item(), {}} + : ScaleA_Args{a_scales.data_ptr(), {}, {}}; - typename ScaleB::Arguments b_args = - b_scales.numel() == 1 - ? 
typename ScaleB::Arguments{nullptr, b_scales.item(), {}} - : typename ScaleB::Arguments{b_scales.data_ptr(), {}, {}}; + ScaleB_Args b_args = b_scales.numel() == 1 + ? ScaleB_Args{nullptr, b_scales.item(), {}} + : ScaleB_Args{b_scales.data_ptr(), {}, {}}; args.epilogue.thread = {a_args, {b_args}}; // Launch the CUTLASS GEMM kernel. - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - Gemm gemm_op; + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + assert(workspace_size == 0); + cutlass::Status status = gemm_op.run(args); CUTLASS_CHECK(status); } -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, torch::Tensor const &b_scales) { if (a.dtype() == torch::kInt8) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = + typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; return cutlass_scaled_mm_dq_dispatcher< - int8_kernel::GemmKernel, int8_kernel::ScaleA, int8_kernel::ScaleB, - int8_kernel::StrideC, int8_t, cutlass::bfloat16_t>(out, a, b, a_scales, - b_scales); + sm90_gemm>(out, a, b, a_scales, + b_scales); } else { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = + typename cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; return cutlass_scaled_mm_dq_dispatcher< - fp8_kernel::GemmKernel, fp8_kernel::ScaleA, fp8_kernel::ScaleB, - fp8_kernel::StrideC, cutlass::float_e4m3_t, cutlass::bfloat16_t>( + sm90_gemm>( out, a, b, a_scales, b_scales); } } -#endif - diff --git a/csrc/third_party/cutlass b/csrc/third_party/cutlass deleted file mode 160000 index 5c447dd84f8ae..0000000000000 --- a/csrc/third_party/cutlass +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 5c447dd84f8ae0e1d48ff9a2eae26ce8c4958101 diff --git a/cutlass_fp8_fused_dq_and_scales.py b/cutlass_fp8_fused_dq_and_scales.py deleted file mode 100644 index a72b70545df34..0000000000000 --- a/cutlass_fp8_fused_dq_and_scales.py +++ /dev/null @@ -1,129 +0,0 @@ -import torch -import cutlass -from cutlass.epilogue import relu -from cutlass import Tensor as FakeTensor -from cutlass.utils.profiler import CUDAEventProfiler - -# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to -# omit this information. 
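This standalone fp8 script (deleted by this patch in favor of the kernel tests) traces an epilogue of the form scale_a * (scale_b * accum). Restated in plain PyTorch with the two broadcasts mapped to the fused multiplies; a sketch for orientation only, with an illustrative function name:

import torch

def fp8_epilogue_reference(accum: torch.Tensor, scale_a: torch.Tensor,
                           scale_b: torch.Tensor) -> torch.Tensor:
    # accum: (m, n) float32 accumulator produced by the fp8 GEMM.
    # scale_a: (m, 1) column broadcast, scale_b: (1, n) row broadcast.
    partial = scale_b * accum                       # first fused multiply
    return (scale_a * partial).to(torch.bfloat16)   # second multiply, narrowed to bf16
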
-print_module = True - -# The Epilogue Visitor feature currently only works for SM80 and 90 -from cutlass.backend.utils.device import device_cc - -if device_cc() not in [86, 80, 90]: - import sys - - sys.exit() - -m = 512 -n = 512 -k = 512 - -type_A = torch.float8_e4m3fn -type_B = torch.float8_e4m3fn -type_C = torch.bfloat16 -type_D = torch.bfloat16 - - -def to_fp8(tensor): - # Assuming input tensor is float32 - # Scale tensor to range of FP8 E4M3 by clamping exponent and truncating mantissa - max_exp = 2**4 - 1 # Maximum exponent for E4M3 - max_mantissa = 2**3 - 1 # Maximum mantissa for E4M3 - base = 2**max_exp - # Scale the mantissa - scaled = torch.clamp(tensor, -base, base) - # Quantize the mantissa - quantized = torch.round(scaled * max_mantissa) / max_mantissa - return quantized.to(dtype=torch.float8_e4m3fn) - - -torch.manual_seed(2023) -tensor_A = to_fp8(torch.rand(size=(m, k), device="cuda")) -tensor_B = to_fp8(torch.rand(size=(n, k), device="cuda").t()) -tensor_D = torch.zeros(size=(m, n), dtype=type_C, device="cuda") -tensor_C = torch.zeros(size=(m, n), dtype=type_C, device="cuda") - -tensor_scale_a = torch.rand(size=(m, 1), device="cuda") -tensor_scale_b = torch.rand(size=(1, n), device="cuda") - -plan = cutlass.op.Gemm( - element_A=type_A, - element_B=type_B, - element_C=type_C, - element_D=type_D, - layout_A=cutlass.LayoutType.RowMajor, - layout_B=cutlass.LayoutType.ColumnMajor, - layout_C=cutlass.LayoutType.RowMajor, - element_accumulator=torch.float32, - kernel_cc=90, -) - - -# Define epilogue visitor -def example_epilogue(accum, scale_a, scale_b): - D = scale_a * (scale_b * accum) - return D - - -# Construct inputs and outputs -epilogue_tensors = { - "accum": FakeTensor( - element=torch.float32, - shape=(m, n), - layout_tag=cutlass.LayoutType.RowMajor, - ), - "D": tensor_D, - "scale_a": tensor_scale_a, - "scale_b": tensor_scale_b, -} - -# Trace the epilogue visitor -epilogue_visitor = cutlass.epilogue.trace(example_epilogue, epilogue_tensors) - -visitor_args = {"scale_a": tensor_scale_a, "scale_b": tensor_scale_b, "D": tensor_D} - -plan.epilogue_visitor = epilogue_visitor -plan.run( - tensor_A, - tensor_B, - tensor_C, - tensor_D, - visitor_args=visitor_args, - print_module=print_module, -) - - -class TorchReference(torch.nn.Module): - def forward(self, A, B, C, scale_a, scale_b): - accum = torch.matmul(A.to(dtype=torch.float32), B.to(dtype=torch.float32)) - return example_epilogue(accum.to(dtype=torch.float32), scale_a, scale_b).to( - type_D - ) - - -torch_reference = TorchReference() -tensor_D_ref = torch_reference( - tensor_A, tensor_B, tensor_C, tensor_scale_a, tensor_scale_b -) - -print(tensor_D) -print(tensor_D_ref) -assert torch.allclose(tensor_D, tensor_D_ref, 1e-1) - -warmup_iterations = 10 -profile_iterations = 50 -# Profile CUTLASS fused kernel -duration = CUDAEventProfiler( - plan, - warmup_iterations, - profile_iterations, - tensor_A, - tensor_B, - tensor_C, - tensor_D, - visitor_args=visitor_args, -)() - -print(f"CUTLASS duration: {duration:.2f} ms") diff --git a/cutlass_int8_fused_dq_and_scales.py b/cutlass_int8_fused_dq_and_scales.py deleted file mode 100644 index 3fcaaffd76ffc..0000000000000 --- a/cutlass_int8_fused_dq_and_scales.py +++ /dev/null @@ -1,124 +0,0 @@ -import torch -import cutlass -from cutlass.epilogue import relu -from cutlass import Tensor as FakeTensor -from cutlass.utils.profiler import CUDAEventProfiler - -# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to -# omit this information. 
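The int8 variant of the script is structured the same way but quantizes symmetrically into [-127, 127] and accumulates in int32 (element_accumulator=torch.int32); its TorchReference module checks the result in float32. A condensed sketch of that reference, with illustrative names only:

import torch

def to_int8(t: torch.Tensor) -> torch.Tensor:
    # Symmetric quantization: clamp to [-127, 127], round to nearest.
    return torch.round(torch.clamp(t, -127, 127)).to(torch.int8)

def int8_scaled_mm_reference(a_q: torch.Tensor, b_q: torch.Tensor,
                             scale_a: torch.Tensor,
                             scale_b: torch.Tensor) -> torch.Tensor:
    # Float32 stand-in for the int32-accumulating kernel; the scales are
    # applied after the matmul, as in the fused epilogue.
    accum = torch.matmul(a_q.to(torch.float32), b_q.to(torch.float32))
    return (scale_a * (scale_b * accum)).to(torch.bfloat16)
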
-print_module = True - -# The Epilogue Visitor feature currently only works for SM80 and 90 -from cutlass.backend.utils.device import device_cc - -if device_cc() not in [86, 80, 90]: - import sys - - sys.exit() - -m = 512 -n = 512 -k = 512 - -type_A = torch.int8 -type_B = torch.int8 -type_C = torch.float32 -type_D = torch.bfloat16 - - -def to_int8(tensor): - min = -127 # use 127 for symmetry - max = 127 - scaled = torch.clamp(tensor, min, max) - quantized = torch.round(scaled) - return quantized.to(dtype=torch.int8) - - -torch.manual_seed(2023) -tensor_A = to_int8(torch.rand(size=(m, k), device="cuda") * 10) -tensor_B = to_int8(torch.rand(size=(n, k), device="cuda").t() * 10) -tensor_D = torch.zeros(size=(m, n), dtype=type_D, device="cuda") -tensor_C = torch.zeros(size=(m, n), dtype=type_C, device="cuda") - -tensor_scale_a = torch.rand(size=(m, 1), device="cuda") -tensor_scale_b = torch.rand(size=(1, n), device="cuda") - -plan = cutlass.op.Gemm( - element_A=type_A, - element_B=type_B, - element_C=type_C, - element_D=type_D, - layout_A=cutlass.LayoutType.RowMajor, - layout_B=cutlass.LayoutType.ColumnMajor, - layout_C=cutlass.LayoutType.RowMajor, - element_accumulator=torch.int32, - kernel_cc=80, -) - - -# Define epilogue visitor -def example_epilogue(accum, scale_a, scale_b): - D = scale_a * (scale_b * accum) - return D - - -# Construct inputs and outputs -epilogue_tensors = { - "accum": FakeTensor( - element=torch.int32, - shape=(m, n), - layout_tag=cutlass.LayoutType.RowMajor, - ), - "D": tensor_D, - "scale_a": tensor_scale_a, - "scale_b": tensor_scale_b, -} - -# Trace the epilogue visitor -epilogue_visitor = cutlass.epilogue.trace(example_epilogue, epilogue_tensors) - -visitor_args = {"scale_a": tensor_scale_a, "scale_b": tensor_scale_b, "D": tensor_D} - -plan.epilogue_visitor = epilogue_visitor -plan.run( - tensor_A, - tensor_B, - tensor_C, - tensor_D, - visitor_args=visitor_args, - print_module=print_module, -) - - -class TorchReference(torch.nn.Module): - def forward(self, A, B, C, scale_a, scale_b): - accum = torch.matmul(A.to(dtype=torch.float32), B.to(dtype=torch.float32)) - return example_epilogue(accum.to(dtype=torch.float32), scale_a, scale_b).to( - type_D - ) - - -torch_reference = TorchReference() -tensor_D_ref = torch_reference( - tensor_A, tensor_B, tensor_C, tensor_scale_a, tensor_scale_b -) - -print(tensor_D) -print(tensor_D_ref) -assert torch.allclose(tensor_D, tensor_D_ref, 1e-1) - -warmup_iterations = 10 -profile_iterations = 50 -# Profile CUTLASS fused kernel -duration = CUDAEventProfiler( - plan, - warmup_iterations, - profile_iterations, - tensor_A, - tensor_B, - tensor_C, - tensor_D, - visitor_args=visitor_args, -)() - -print(f"CUTLASS duration: {duration:.2f} ms") diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index bfefdb4cef6b7..adfa79d9f2a7d 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -10,9 +10,11 @@ capability = torch.cuda.get_device_capability() capability = capability[0] * 10 + capability[1] + def to_fp8(tensor): # Assuming input tensor is float32 - # Scale tensor to range of FP8 E4M3 by clamping exponent and truncating mantissa + # Scale tensor to range of FP8 E4M3 + # by clamping exponent and truncating mantissa max_exp = 2**4 - 1 # Maximum exponent for E4M3 max_mantissa = 2**3 - 1 # Maximum mantissa for E4M3 base = 2**max_exp @@ -22,9 +24,11 @@ def to_fp8(tensor): quantized = torch.round(scaled * max_mantissa) / max_mantissa return quantized.to(dtype=torch.float8_e4m3fn) + def 
to_int8(tensor): return torch.round(torch.clamp(tensor, -128, 127)).to(dtype=torch.int8) + def cutlass_fp8_gemm_helper( m: int, n: int, @@ -34,21 +38,25 @@ def cutlass_fp8_gemm_helper( ): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. - a = to_fp8(torch.randn((m, k), device='cuda')) - b = to_fp8(torch.randn((n, k), device='cuda').t()) - + a = to_fp8(torch.randn((m, k), device="cuda")) + b = to_fp8(torch.randn((n, k), device="cuda").t()) + m_a_scales = m if per_token_act_quant else 1 n_b_scales = n if per_out_channel_weight_quant else 1 - scale_a = torch.randn((m_a_scales,1), device='cuda', dtype=torch.float32) / 10 - scale_b = torch.randn((1,n_b_scales), device='cuda', dtype=torch.float32) / 10 + scale_a = (torch.randn( + (m_a_scales, 1), device="cuda", dtype=torch.float32) / 10) + scale_b = (torch.randn( + (1, n_b_scales), device="cuda", dtype=torch.float32) / 10) out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b) - baseline = torch.mm(scale_a * a.to(dtype=torch.float32), - scale_b * b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * + b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1) + def cutlass_int8_gemm_helper( m: int, n: int, @@ -58,59 +66,86 @@ def cutlass_int8_gemm_helper( ): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. - a = to_int8(torch.randn((m, k), device='cuda') * 5) - b = to_int8(torch.randn((n, k), device='cuda').t() * 5) + a = to_int8(torch.randn((m, k), device="cuda") * 5) + b = to_int8(torch.randn((n, k), device="cuda").t() * 5) m_a_scales = m if per_token_act_quant else 1 n_b_scales = n if per_out_channel_weight_quant else 1 - scale_a = torch.randn((m_a_scales,1), device='cuda', dtype=torch.float32) / 10 - scale_b = torch.randn((1,n_b_scales), device='cuda', dtype=torch.float32) / 10 + scale_a = (torch.randn( + (m_a_scales, 1), device="cuda", dtype=torch.float32) / 10) + scale_b = (torch.randn( + (1, n_b_scales), device="cuda", dtype=torch.float32) / 10) out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b) - baseline = torch.mm(scale_a * a.to(dtype=torch.float32), - scale_b * b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * + b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) + @pytest.mark.parametrize("m", [512, 222, 33, 1]) @pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) -@pytest.mark.skipif(capability < 89, reason="FP8 is not supported on this GPU type.") -def test_cutlass_fp8_gemm( - m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool -): - cutlass_fp8_gemm_helper(m,n,k, per_act_token, per_out_ch) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool, + per_out_ch: bool): + cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch) + @pytest.mark.parametrize("m", [512, 222, 33, 1]) @pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) -def 
test_cutlass_int8_gemm( - m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool -): - cutlass_int8_gemm_helper(m,n,k, per_act_token, per_out_ch) +def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool, + per_out_ch: bool): + cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch) + # For the following two tests: -# N and K correspond to the size of the weight matrix and likely to be multiples of a large power of two. -# In any case, the kernel will have a naive fallback when N and K are not divisible by 16 -# But M is the number of tokens and the kernel must handle any M thrown at it. +# N and K correspond to the size of the weight matrix and likely to be multiples +# of a large power of two. In any case, the kernel will have a naive fallback +# when N and K are not divisible by 16. But M is the number of tokens and the +# kernel must handle any M thrown at it. @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) -@pytest.mark.skipif(capability < 89, reason="FP8 is not supported on this GPU type.") +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool): - for nk in range(16, 256, 16): + for nk in range(32, 128, 32): for m in range(1, 128): - cutlass_fp8_gemm_helper(m,nk,nk, per_act_token, per_out_ch) + cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch) + @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool): - for nk in range(16, 256, 16): + for nk in range(32, 128, 32): for m in range(1, 128): - cutlass_int8_gemm_helper(m,nk,nk, per_act_token, per_out_ch) - + cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch) + + +# Test working with a subset of A and B +def test_cutlass_subset(): + big_m, big_n, big_k = 1024, 1024, 1024 + m, n, k = 512, 512, 512 + + whole_a = to_int8(torch.randn((big_m, big_k), device="cuda") * 5) + whole_b = to_int8(torch.randn((big_n, big_k), device="cuda").t() * 5) + a = whole_a[0:m, 0:k] + b = whole_b[0:k, 0:n] + scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 + scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 + + out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b) + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * + b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) + + assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 10ea9126a63b0..ca333b12cecae 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2,12 +2,12 @@ import torch -capability = torch.cuda.get_device_capability() -capability = capability[0] * 10 + capability[1] - try: from vllm._C import cache_ops as vllm_cache_ops from vllm._C import ops as vllm_ops + + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] except ImportError: pass @@ -49,10 +49,21 @@ def paged_attention_v1( kv_cache_dtype: str, kv_scale: float, ) -> None: - vllm_ops.paged_attention_v1(out, query, key_cache, value_cache, - num_kv_heads, scale, block_tables, seq_lens, - block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, kv_scale) + vllm_ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + 
kv_cache_dtype, + kv_scale, + ) def paged_attention_v2( @@ -73,11 +84,24 @@ def paged_attention_v2( kv_cache_dtype: str, kv_scale: float, ) -> None: - vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, - key_cache, value_cache, num_kv_heads, scale, - block_tables, seq_lens, block_size, - max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale) + vllm_ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + kv_scale, + ) # pos encoding ops @@ -89,126 +113,208 @@ def rotary_embedding( cos_sin_cache: torch.Tensor, is_neox: bool, ) -> None: - vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, - is_neox) + vllm_ops.rotary_embedding( + positions, query, key, head_size, cos_sin_cache, is_neox + ) -def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, head_size: int, - cos_sin_cache: torch.Tensor, is_neox: bool, - rot_dim: int, - cos_sin_cache_offsets: torch.Tensor) -> None: - vllm_ops.batched_rotary_embedding(positions, query, key, head_size, - cos_sin_cache, is_neox, rot_dim, - cos_sin_cache_offsets) +def batched_rotary_embedding( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor, +) -> None: + vllm_ops.batched_rotary_embedding( + positions, + query, + key, + head_size, + cos_sin_cache, + is_neox, + rot_dim, + cos_sin_cache_offsets, + ) # layer norm ops -def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - epsilon: float) -> None: +def rms_norm( + out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float +) -> None: vllm_ops.rms_norm(out, input, weight, epsilon) -def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, - weight: torch.Tensor, epsilon: float) -> None: +def fused_add_rms_norm( + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + epsilon: float, +) -> None: vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon) # quantization ops # awq -def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, - zeros: torch.Tensor, split_k_iters: int, thx: int, - thy: int) -> torch.Tensor: - return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, - thy) - - -def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, - scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: +def awq_dequantize( + qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + split_k_iters: int, + thx: int, + thy: int, +) -> torch.Tensor: + return vllm_ops.awq_dequantize( + qweight, scales, zeros, split_k_iters, thx, thy + ) + + +def awq_gemm( + input: torch.Tensor, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + split_k_iters: int, +) -> torch.Tensor: return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters) # gptq -def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, - b_g_idx: torch.Tensor, use_exllama: bool, - bit: int) -> torch.Tensor: - return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, - b_g_idx, use_exllama, bit) - - -def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, - bit: int) -> None: +def gptq_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + 
b_gptq_qzeros: torch.Tensor, + b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, + use_exllama: bool, + bit: int, +) -> torch.Tensor: + return vllm_ops.gptq_gemm( + a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_exllama, bit + ) + + +def gptq_shuffle( + q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int +) -> None: vllm_ops.gptq_shuffle(q_weight, q_perm, bit) # squeezellm -def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, - lookup_table: torch.Tensor) -> None: +def squeezellm_gemm( + vec: torch.Tensor, + mat: torch.Tensor, + mul: torch.Tensor, + lookup_table: torch.Tensor, +) -> None: vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table) # marlin -def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, - size_n: int, size_k: int) -> torch.Tensor: - return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, - size_n, size_k) - -# cutlass -def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, a_scales: torch.Tensor, - b_scales: torch.Tensor) -> torch.Tensor: - +def marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return vllm_ops.marlin_gemm( + a, b_q_weight, b_scales, workspace, size_m, size_n, size_k + ) + + +# cutlass +def cutlass_scaled_mm_dq( + a: torch.Tensor, + b: torch.Tensor, + a_scales: torch.Tensor, + b_scales: torch.Tensor, +) -> torch.Tensor: shape_fallback = b.shape[0] % 16 != 0 or b.shape[1] % 16 != 0 - if capability < 800 or shape_fallback: + if capability < 80 or shape_fallback: a_bf16 = a.to(dtype=torch.bfloat16) b_bf16 = b.to(dtype=torch.bfloat16) - return (b_scales * (a_scales * torch.mm(a_bf16, b_bf16))).to(dtype=torch.bfloat16) + return (b_scales * (a_scales * torch.mm(a_bf16, b_bf16))).to( + dtype=torch.bfloat16 + ) + else: m = a.shape[0] n = b.shape[1] - out = torch.empty((m,n), dtype=torch.bfloat16, device="cuda") - if capability >= 900: - vllm_ops.cutlass_scaled_mm_dq_sm90(out, a, b, a_scales, b_scales) - elif capability >= 890: - vllm_ops.cutlass_scaled_mm_dq_sm89(out, a, b, a_scales, b_scales) - else: - vllm_ops.cutlass_scaled_mm_dq_sm80(out, a, b, a_scales, b_scales) - return out + out = torch.empty((m, n), dtype=torch.bfloat16, device="cuda") + vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales) -# aqlm -def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, - codebooks: torch.Tensor, scales: torch.Tensor, - codebook_partition_sizes: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - return vllm_ops.aqlm_gemm(input, codes, codebooks, scales, - codebook_partition_sizes, bias) + return out -def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, - codebook_partition_sizes: torch.Tensor) -> torch.Tensor: +# aqlm +def aqlm_gemm( + input: torch.Tensor, + codes: torch.Tensor, + codebooks: torch.Tensor, + scales: torch.Tensor, + codebook_partition_sizes: torch.Tensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + return vllm_ops.aqlm_gemm( + input, codes, codebooks, scales, codebook_partition_sizes, bias + ) + + +def aqlm_dequant( + codes: torch.Tensor, + codebooks: torch.Tensor, + codebook_partition_sizes: torch.Tensor, +) -> torch.Tensor: return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes) # gptq_marlin -def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, - size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: - return 
vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, - num_bits) - - -def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, g_idx: torch.Tensor, - perm: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: int, size_n: int, size_k: int, - is_k_full: bool) -> torch.Tensor: - return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, - workspace, num_bits, size_m, size_n, - size_k, is_k_full) +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + return vllm_ops.gptq_marlin_repack( + b_q_weight, perm, size_k, size_n, num_bits + ) + + +def gptq_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, +) -> torch.Tensor: + return vllm_ops.gptq_marlin_gemm( + a, + b_q_weight, + b_scales, + g_idx, + perm, + workspace, + num_bits, + size_m, + size_n, + size_k, + is_k_full, + ) # fp8 @@ -252,13 +358,22 @@ def scaled_fp8_quant( # moe -def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, - block_size: int, sorted_token_ids: torch.Tensor, - experts_ids: torch.Tensor, - num_tokens_post_pad: torch.Tensor) -> None: - vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size, - sorted_token_ids, experts_ids, - num_tokens_post_pad) +def moe_align_block_size( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + vllm_ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_token_ids, + experts_ids, + num_tokens_post_pad, + ) def reshape_and_cache( @@ -270,8 +385,15 @@ def reshape_and_cache( kv_cache_dtype: str, kv_scale: float, ) -> None: - vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, kv_scale) + vllm_cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + kv_scale, + ) def reshape_and_cache_flash( @@ -282,12 +404,16 @@ def reshape_and_cache_flash( slot_mapping: torch.Tensor, kv_cache_dtype: str, ) -> None: - vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype) + vllm_cache_ops.reshape_and_cache_flash( + key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype + ) -def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, - block_mapping: torch.Tensor) -> None: +def copy_blocks( + key_caches: torch.Tensor, + value_caches: torch.Tensor, + block_mapping: torch.Tensor, +) -> None: vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) @@ -303,4 +429,4 @@ def convert_fp8(output: torch.Tensor, vllm_cache_ops.convert_fp8(output, input, scale, kv_dtype) -#TODO: cuda_utils, custom_ar +# TODO: cuda_utils, custom_ar From 9e746c5f02e1850b06bd8617f6ddb60168198972 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 13 May 2024 15:33:33 +0000 Subject: [PATCH 04/18] format --- vllm/_custom_ops.py | 50 ++++++++++++++++++--------------------------- 1 file changed, 20 insertions(+), 30 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ca333b12cecae..eba919d06464f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -113,9 +113,8 @@ def rotary_embedding( cos_sin_cache: torch.Tensor, is_neox: bool, 
) -> None: - vllm_ops.rotary_embedding( - positions, query, key, head_size, cos_sin_cache, is_neox - ) + vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, + is_neox) def batched_rotary_embedding( @@ -141,9 +140,8 @@ def batched_rotary_embedding( # layer norm ops -def rms_norm( - out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float -) -> None: +def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> None: vllm_ops.rms_norm(out, input, weight, epsilon) @@ -166,9 +164,8 @@ def awq_dequantize( thx: int, thy: int, ) -> torch.Tensor: - return vllm_ops.awq_dequantize( - qweight, scales, zeros, split_k_iters, thx, thy - ) + return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, + thy) def awq_gemm( @@ -191,14 +188,12 @@ def gptq_gemm( use_exllama: bool, bit: int, ) -> torch.Tensor: - return vllm_ops.gptq_gemm( - a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_exllama, bit - ) + return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, use_exllama, bit) -def gptq_shuffle( - q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int -) -> None: +def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, + bit: int) -> None: vllm_ops.gptq_shuffle(q_weight, q_perm, bit) @@ -222,9 +217,8 @@ def marlin_gemm( size_n: int, size_k: int, ) -> torch.Tensor: - return vllm_ops.marlin_gemm( - a, b_q_weight, b_scales, workspace, size_m, size_n, size_k - ) + return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, + size_n, size_k) # cutlass @@ -240,9 +234,8 @@ def cutlass_scaled_mm_dq( a_bf16 = a.to(dtype=torch.bfloat16) b_bf16 = b.to(dtype=torch.bfloat16) - return (b_scales * (a_scales * torch.mm(a_bf16, b_bf16))).to( - dtype=torch.bfloat16 - ) + return (b_scales * + (a_scales * torch.mm(a_bf16, b_bf16))).to(dtype=torch.bfloat16) else: m = a.shape[0] @@ -263,9 +256,8 @@ def aqlm_gemm( codebook_partition_sizes: torch.Tensor, bias: Optional[torch.Tensor], ) -> torch.Tensor: - return vllm_ops.aqlm_gemm( - input, codes, codebooks, scales, codebook_partition_sizes, bias - ) + return vllm_ops.aqlm_gemm(input, codes, codebooks, scales, + codebook_partition_sizes, bias) def aqlm_dequant( @@ -284,9 +276,8 @@ def gptq_marlin_repack( size_n: int, num_bits: int, ) -> torch.Tensor: - return vllm_ops.gptq_marlin_repack( - b_q_weight, perm, size_k, size_n, num_bits - ) + return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, + num_bits) def gptq_marlin_gemm( @@ -404,9 +395,8 @@ def reshape_and_cache_flash( slot_mapping: torch.Tensor, kv_cache_dtype: str, ) -> None: - vllm_cache_ops.reshape_and_cache_flash( - key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype - ) + vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype) def copy_blocks( From 15c46ca06a322c565c849d53d73cdd4cd96677f9 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 13 May 2024 15:46:13 +0000 Subject: [PATCH 05/18] fixup --- vllm/_custom_ops.py | 282 +++++++++++--------------------------------- 1 file changed, 70 insertions(+), 212 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index eba919d06464f..0a12046122d52 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,13 +1,10 @@ -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple import torch try: from vllm._C import cache_ops as vllm_cache_ops from vllm._C import ops as vllm_ops - - capability = 
torch.cuda.get_device_capability() - capability = capability[0] * 10 + capability[1] except ImportError: pass @@ -49,21 +46,10 @@ def paged_attention_v1( kv_cache_dtype: str, kv_scale: float, ) -> None: - vllm_ops.paged_attention_v1( - out, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - kv_scale, - ) + vllm_ops.paged_attention_v1(out, query, key_cache, value_cache, + num_kv_heads, scale, block_tables, seq_lens, + block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, kv_scale) def paged_attention_v2( @@ -84,24 +70,11 @@ def paged_attention_v2( kv_cache_dtype: str, kv_scale: float, ) -> None: - vllm_ops.paged_attention_v2( - out, - exp_sum, - max_logits, - tmp_out, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - kv_scale, - ) + vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, + key_cache, value_cache, num_kv_heads, scale, + block_tables, seq_lens, block_size, + max_seq_len, alibi_slopes, kv_cache_dtype, + kv_scale) # pos encoding ops @@ -117,26 +90,14 @@ def rotary_embedding( is_neox) -def batched_rotary_embedding( - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - head_size: int, - cos_sin_cache: torch.Tensor, - is_neox: bool, - rot_dim: int, - cos_sin_cache_offsets: torch.Tensor, -) -> None: - vllm_ops.batched_rotary_embedding( - positions, - query, - key, - head_size, - cos_sin_cache, - is_neox, - rot_dim, - cos_sin_cache_offsets, - ) +def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, + key: torch.Tensor, head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor) -> None: + vllm_ops.batched_rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox, rot_dim, + cos_sin_cache_offsets) # layer norm ops @@ -145,49 +106,30 @@ def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, vllm_ops.rms_norm(out, input, weight, epsilon) -def fused_add_rms_norm( - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - epsilon: float, -) -> None: +def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, epsilon: float) -> None: vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon) # quantization ops # awq -def awq_dequantize( - qweight: torch.Tensor, - scales: torch.Tensor, - zeros: torch.Tensor, - split_k_iters: int, - thx: int, - thy: int, -) -> torch.Tensor: +def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, + zeros: torch.Tensor, split_k_iters: int, thx: int, + thy: int) -> torch.Tensor: return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy) -def awq_gemm( - input: torch.Tensor, - qweight: torch.Tensor, - qzeros: torch.Tensor, - scales: torch.Tensor, - split_k_iters: int, -) -> torch.Tensor: +def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, + scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters) # gptq -def gptq_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_gptq_qzeros: torch.Tensor, - b_gptq_scales: torch.Tensor, - b_g_idx: torch.Tensor, - use_exllama: bool, - bit: int, -) -> torch.Tensor: +def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, + 
b_g_idx: torch.Tensor, use_exllama: bool, + bit: int) -> torch.Tensor: return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_exllama, bit) @@ -198,25 +140,15 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, # squeezellm -def squeezellm_gemm( - vec: torch.Tensor, - mat: torch.Tensor, - mul: torch.Tensor, - lookup_table: torch.Tensor, -) -> None: +def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, + lookup_table: torch.Tensor) -> None: vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table) # marlin -def marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: +def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, + size_n: int, size_k: int) -> torch.Tensor: return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, size_n, size_k) @@ -230,6 +162,9 @@ def cutlass_scaled_mm_dq( ) -> torch.Tensor: shape_fallback = b.shape[0] % 16 != 0 or b.shape[1] % 16 != 0 + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < 80 or shape_fallback: a_bf16 = a.to(dtype=torch.bfloat16) b_bf16 = b.to(dtype=torch.bfloat16) @@ -248,98 +183,43 @@ def cutlass_scaled_mm_dq( # aqlm -def aqlm_gemm( - input: torch.Tensor, - codes: torch.Tensor, - codebooks: torch.Tensor, - scales: torch.Tensor, - codebook_partition_sizes: torch.Tensor, - bias: Optional[torch.Tensor], -) -> torch.Tensor: +def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, + codebooks: torch.Tensor, scales: torch.Tensor, + codebook_partition_sizes: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: return vllm_ops.aqlm_gemm(input, codes, codebooks, scales, codebook_partition_sizes, bias) -def aqlm_dequant( - codes: torch.Tensor, - codebooks: torch.Tensor, - codebook_partition_sizes: torch.Tensor, -) -> torch.Tensor: +def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, + codebook_partition_sizes: torch.Tensor) -> torch.Tensor: return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes) # gptq_marlin -def gptq_marlin_repack( - b_q_weight: torch.Tensor, - perm: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int, -) -> torch.Tensor: +def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, -) -> torch.Tensor: - return vllm_ops.gptq_marlin_gemm( - a, - b_q_weight, - b_scales, - g_idx, - perm, - workspace, - num_bits, - size_m, - size_n, - size_k, - is_k_full, - ) +def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, g_idx: torch.Tensor, + perm: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: int, size_n: int, size_k: int, + is_k_full: bool) -> torch.Tensor: + return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, + workspace, num_bits, size_m, size_n, + size_k, is_k_full) # fp8 def scaled_fp8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, - batch_dim_padding: Optional[int] = None, ) -> 
Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensor for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - batch_dim_padding: If specified, pad the first dimension - of the output to at least this value. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. - """ - if batch_dim_padding: - shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:]) - output = torch.empty(shape, - device=input.device, - dtype=torch.float8_e4m3fn) - else: - output = torch.empty_like(input, dtype=torch.float8_e4m3fn) + output = torch.empty_like(input, dtype=torch.float8_e4m3fn) if scale is None: scale = torch.zeros(1, device=input.device, dtype=torch.float32) vllm_ops.dynamic_scaled_fp8_quant(output, input, scale) @@ -349,22 +229,13 @@ def scaled_fp8_quant( # moe -def moe_align_block_size( - topk_ids: torch.Tensor, - num_experts: int, - block_size: int, - sorted_token_ids: torch.Tensor, - experts_ids: torch.Tensor, - num_tokens_post_pad: torch.Tensor, -) -> None: - vllm_ops.moe_align_block_size( - topk_ids, - num_experts, - block_size, - sorted_token_ids, - experts_ids, - num_tokens_post_pad, - ) +def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, + block_size: int, sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor) -> None: + vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size, + sorted_token_ids, experts_ids, + num_tokens_post_pad) def reshape_and_cache( @@ -376,15 +247,8 @@ def reshape_and_cache( kv_cache_dtype: str, kv_scale: float, ) -> None: - vllm_cache_ops.reshape_and_cache( - key, - value, - key_cache, - value_cache, - slot_mapping, - kv_cache_dtype, - kv_scale, - ) + vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype, kv_scale) def reshape_and_cache_flash( @@ -399,24 +263,18 @@ def reshape_and_cache_flash( slot_mapping, kv_cache_dtype) -def copy_blocks( - key_caches: torch.Tensor, - value_caches: torch.Tensor, - block_mapping: torch.Tensor, -) -> None: +def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, + block_mapping: torch.Tensor) -> None: vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) def swap_blocks(src: torch.Tensor, dst: torch.Tensor, - block_mapping: torch.Tensor) -> None: + block_mapping: Dict[int, int]) -> None: vllm_cache_ops.swap_blocks(src, dst, block_mapping) -def convert_fp8(output: torch.Tensor, - input: torch.Tensor, - scale: float = 1.0, - kv_dtype: str = "fp8") -> None: - vllm_cache_ops.convert_fp8(output, input, scale, kv_dtype) +def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None: + vllm_cache_ops.convert_fp8(output, input) -# TODO: cuda_utils, custom_ar +#TODO: cuda_utils, custom_ar From 73fe18acbce77b58e7c0ddf2a7ecc43778f33443 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 13 May 2024 15:47:36 +0000 Subject: [PATCH 06/18] fixup --- vllm/_custom_ops.py | 39 ++++++++++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 
0a12046122d52..a2fccc27d2ad3 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple +from typing import Optional, Tuple import torch @@ -218,8 +218,34 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def scaled_fp8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, + batch_dim_padding: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - output = torch.empty_like(input, dtype=torch.float8_e4m3fn) + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensor for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + batch_dim_padding: If specified, pad the first dimension + of the output to at least this value. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + if batch_dim_padding: + shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:]) + output = torch.empty(shape, + device=input.device, + dtype=torch.float8_e4m3fn) + else: + output = torch.empty_like(input, dtype=torch.float8_e4m3fn) if scale is None: scale = torch.zeros(1, device=input.device, dtype=torch.float32) vllm_ops.dynamic_scaled_fp8_quant(output, input, scale) @@ -269,12 +295,15 @@ def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, def swap_blocks(src: torch.Tensor, dst: torch.Tensor, - block_mapping: Dict[int, int]) -> None: + block_mapping: torch.Tensor) -> None: vllm_cache_ops.swap_blocks(src, dst, block_mapping) -def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None: - vllm_cache_ops.convert_fp8(output, input) +def convert_fp8(output: torch.Tensor, + input: torch.Tensor, + scale: float = 1.0, + kv_dtype: str = "fp8") -> None: + vllm_cache_ops.convert_fp8(output, input, scale, kv_dtype) #TODO: cuda_utils, custom_ar From 22316910e3c20b20af785f7ed2869b94977f9af8 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 13 May 2024 21:26:08 +0000 Subject: [PATCH 07/18] Only compile the cutlass kernel for sm90a. Saves 17 MB --- CMakeLists.txt | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0ff0c3b9d50d2..986e205d5cf03 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,7 +17,6 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") # Supported NVIDIA architectures. set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0;9.0a") -set(CMAKE_CUDA_FLAGS "-gencode arch=compute_90a,code=sm_90a ${CMAKE_CUDA_FLAGS}") # Supported AMD GPU architectures. 
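The scaled_fp8_quant docstring restored above describes static scaling, dynamic scaling, and optional batch-dimension padding. A minimal sketch of the three call patterns, assuming vllm._custom_ops is importable on an FP8-capable GPU (the tensor shapes and the 0.05 scale below are illustrative only, not values used anywhere in this series):

    import torch
    from vllm import _custom_ops as ops

    x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)

    # Dynamic quantization: omit the scale and the kernel computes one.
    x_fp8, x_scale = ops.scaled_fp8_quant(x)

    # Static quantization: pass a precomputed per-tensor scale.
    static_scale = torch.tensor([0.05], device="cuda", dtype=torch.float32)
    y_fp8, y_scale = ops.scaled_fp8_quant(x, scale=static_scale)

    # Pad the first (batch) dimension of the output to at least 32 rows.
    z_fp8, z_scale = ops.scaled_fp8_quant(x, batch_dim_padding=32)
    assert z_fp8.shape[0] == 32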
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100") @@ -194,6 +193,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/cutlass/scaled_mm_dq_entry.cu" "csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu" "csrc/quantization/cutlass/scaled_mm_dq_sm90.cu") + + set_source_files_properties( + "csrc/quantization/cutlass/scaled_mm_dq_sm90.cu" + PROPERTIES + COMPILE_FLAGS + "-gencode arch=compute_90a,code=sm_90a") + endif() define_gpu_extension_target( From a4c88df383b2019354d687a0de3f97cfcaddc2e6 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 13 May 2024 21:35:40 +0000 Subject: [PATCH 08/18] add comment --- CMakeLists.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 986e205d5cf03..c7dbda77c89be 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -194,6 +194,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu" "csrc/quantization/cutlass/scaled_mm_dq_sm90.cu") + # + # The CUTLASS kernels for Hopper require sm90a to be enabled. + # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a. + # That adds an extra 17MB to compiled binary, so instead we selectively enable it. set_source_files_properties( "csrc/quantization/cutlass/scaled_mm_dq_sm90.cu" PROPERTIES From 47e95adff775399e5b4893d4afcdc035fc09e51d Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 14 May 2024 16:26:58 +0000 Subject: [PATCH 09/18] Optimizations for Ampere, scalar broadcast checks --- .../cutlass_visitor_2x_broadcast_epilogue.hpp | 30 +++++--- .../quantization/cutlass/scaled_mm_dq_sm8x.cu | 77 ++++++++++--------- 2 files changed, 59 insertions(+), 48 deletions(-) diff --git a/csrc/quantization/cutlass/cutlass_visitor_2x_broadcast_epilogue.hpp b/csrc/quantization/cutlass/cutlass_visitor_2x_broadcast_epilogue.hpp index d2c8a3324f766..ddbee15e54ab6 100644 --- a/csrc/quantization/cutlass/cutlass_visitor_2x_broadcast_epilogue.hpp +++ b/csrc/quantization/cutlass/cutlass_visitor_2x_broadcast_epilogue.hpp @@ -45,15 +45,13 @@ #include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" #include "cute/tensor.hpp" -// clang-format on +// clang-format on namespace cutlass::epilogue::threadblock { using namespace cute; using namespace detail; -using X = Underscore; - template< class ThreadMap, class Element, @@ -126,14 +124,16 @@ struct VisitorRowOrScalarBroadcast { auto src_v = filter(tC_gRow); auto coord_v = filter(tC_cRow); auto dst_v = filter(tC_rRow); - + if (params_ptr->ptr_row) { + // In this case we are loading from a row vector and broadcasting CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(src_v); ++i) { bool guard = get<1>(coord_v(i)) < n; cutlass::arch::global_load(dst_v(i), (void const*)&src_v(i), guard); } } else { + // In this case we are loading from a scalar and broadcasting VecType filled_vec; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < VecLength; i++) { @@ -142,7 +142,10 @@ struct VisitorRowOrScalarBroadcast { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(src_v); ++i) { - dst_v(i) = filled_vec; + if(get<1>(coord_v(i)) < n) + { + dst_v(i) = filled_vec; + } } } } @@ -268,19 +271,24 @@ struct VisitorColOrScalarBroadcast { begin_epilogue() { clear(tC_rCol); + Tensor pred = make_tensor(shape(tC_gCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tC_cCol(i)) < m; + } + if (params_ptr->ptr_col) { - Tensor pred = make_tensor(shape(tC_gCol)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(pred); ++i) { - pred(i) 
= get<0>(tC_cCol(i)) < m; - } + // In this case we are loading from a column vector and broadcasting copy_if(pred, tC_gCol, tC_rCol); } else { + // In this case we are loading from a scalar and broadcasting auto dst_v = filter(tC_rCol); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(dst_v); ++i) { - dst_v(i) = params_ptr->null_default; + if(pred(i)){ + dst_v(i) = params_ptr->null_default; + } } } } diff --git a/csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu b/csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu index 4bca3ca90197a..3111f91089eba 100644 --- a/csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu +++ b/csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu @@ -25,9 +25,12 @@ #include "common.hpp" // clang-format on +using namespace cute; + ///////////////////////////////////////// -template +template struct sm8x_gemm { using ElementAB = ElementAB_; using ElementD = ElementD_; @@ -42,20 +45,16 @@ struct sm8x_gemm { using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< - cutlass::gemm::GemmShape<256, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, float, 4, - 1 /* epilogue stages */ + TileShape, WarpShape, float, 4, 1 /* epilogue stages */ >; using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< - OutputTileThreadMap, float, - cute::Stride, cute::Int<0>, cute::Int<0>>>; + OutputTileThreadMap, float, Stride, Int<0>, Int<0>>>; using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< - OutputTileThreadMap, float, - cute::Stride, cute::Int<1>, cute::Int<0>>>; + OutputTileThreadMap, float, Stride, Int<1>, Int<0>>>; using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< cutlass::multiplies, float, float, @@ -74,27 +73,21 @@ struct sm8x_gemm { using D = cutlass::epilogue::threadblock::VisitorAuxStore< OutputTileThreadMap, cutlass::bfloat16_t, cutlass::FloatRoundStyle::round_to_nearest, - cute::Stride, cute::Int<0>>>; + Stride, Int<0>>>; using EVTD = cutlass::epilogue::threadblock::Sm80EVT; - // Gemm operator cutlass_tensorop_f32_i16832gemm_s8_256x128_64x3_tn_align16 - using cutlass_tensorop_f32_i16832gemm_s8_256x128_64x3_tn_align16_base = - typename cutlass::gemm::kernel::DefaultGemmWithVisitor< - ElementAB, cutlass::layout::RowMajor, - cutlass::ComplexTransform::kNone, 16, ElementAB, - cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 16, - float, cutlass::layout::RowMajor, 4, ElementAcc, float, - cutlass::arch::OpClassTensorOp, Arch, - cutlass::gemm::GemmShape<256, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, - cutlass::gemm::GemmShape<16, 8, 32>, EVTD, - cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, 3, Operator, - 1 /* epilogue stages */ - >::GemmKernel; - - using Op = cutlass::gemm::device::GemmUniversalAdapter< - cutlass_tensorop_f32_i16832gemm_s8_256x128_64x3_tn_align16_base>; + using KernelType = typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementAB, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, + 16, ElementAB, cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kNone, 16, float, cutlass::layout::RowMajor, 4, + ElementAcc, float, cutlass::arch::OpClassTensorOp, Arch, TileShape, + WarpShape, InstructionShape, EVTD, + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, 5, Operator, + 1 /* epilogue stages */ + >::GemmKernel; + + using Op = cutlass::gemm::device::GemmUniversalAdapter; }; ///////////////////////////////////////// @@ -116,8 +109,8 @@ void 
cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, int64_t ldb = b.stride(1); int64_t ldc = out.stride(0); - using StrideC = cute::Stride, cute::Int<0>>; - StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}}; + using StrideC = Stride, Int<0>>; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; auto a_ptr = static_cast(a.data_ptr()); auto b_ptr = static_cast(b.data_ptr()); @@ -148,9 +141,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, }; typename Gemm::Op::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGemm, // universal mode - problem_size, // problem size - 1, // batch count + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode + problem_size, // problem size + 1, // batch count epilogue_args, a_ptr, b_ptr, @@ -185,15 +178,24 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, assert(b_scales.dtype() == torch::kFloat32); assert(out.dtype() == torch::kBFloat16); + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + return cutlass_scaled_mm_dq_dispatcher< - sm8x_gemm>( - out, a, b, a_scales, b_scales); + sm8x_gemm>(out, a, b, a_scales, b_scales); } void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, torch::Tensor const &b_scales) { + + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + if (a.dtype() == torch::kInt8) { assert(b.dtype() == torch::kInt8); assert(a_scales.dtype() == torch::kFloat32); @@ -201,8 +203,8 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, assert(out.dtype() == torch::kBFloat16); return cutlass_scaled_mm_dq_dispatcher< - sm8x_gemm>( - out, a, b, a_scales, b_scales); + sm8x_gemm>(out, a, b, a_scales, b_scales); } else { assert(a.dtype() == torch::kFloat8_e4m3fn); assert(b.dtype() == torch::kFloat8_e4m3fn); @@ -210,8 +212,9 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, assert(b_scales.dtype() == torch::kFloat32); assert(out.dtype() == torch::kBFloat16); - return cutlass_scaled_mm_dq_dispatcher>( + return cutlass_scaled_mm_dq_dispatcher< + sm8x_gemm>( out, a, b, a_scales, b_scales); } } From 62a252bb8bd97777d7cd22dccdf59f5c810d933e Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 14 May 2024 17:07:41 +0000 Subject: [PATCH 10/18] review comments --- csrc/quantization/cutlass/common.hpp | 47 +++------------- .../cutlass/scaled_mm_dq_entry.cu | 5 +- .../quantization/cutlass/scaled_mm_dq_sm8x.cu | 53 ++++++++++++++----- .../quantization/cutlass/scaled_mm_dq_sm90.cu | 40 +++++++++++--- 4 files changed, 81 insertions(+), 64 deletions(-) diff --git a/csrc/quantization/cutlass/common.hpp b/csrc/quantization/cutlass/common.hpp index 39b72351f7030..999b7b251ab33 100644 --- a/csrc/quantization/cutlass/common.hpp +++ b/csrc/quantization/cutlass/common.hpp @@ -1,47 +1,12 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - #pragma once -// Taken from cutlass/examples/common/helper.h +#include "cutlass/cutlass.h" /** - * Panic wrapper for unwinding CUTLASS errors + * Helper function for checking CUTLASS errors */ -#define CUTLASS_CHECK(status) \ - { \ - cutlass::Status error = status; \ - if (error != cutlass::Status::kSuccess) { \ - std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ - << std::endl; \ - exit(EXIT_FAILURE); \ - } \ +#define CUTLASS_CHECK(status) \ + { \ + TORCH_CHECK(status == cutlass::Status::kSuccess, \ + cutlassGetStatusString(status)) \ } diff --git a/csrc/quantization/cutlass/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass/scaled_mm_dq_entry.cu index 6139acef44f3f..8e8f94c0b9bb6 100644 --- a/csrc/quantization/cutlass/scaled_mm_dq_entry.cu +++ b/csrc/quantization/cutlass/scaled_mm_dq_entry.cu @@ -34,9 +34,8 @@ void cutlass_scaled_mm_dq(torch::Tensor &out, torch::Tensor const &a, cutlass_scaled_mm_dq_sm90(out, a, b, a_scales, b_scales); } else if (version_num == 89) /* Ada Lovelace */ { cutlass_scaled_mm_dq_sm89(out, a, b, a_scales, b_scales); - } else if (version_num >= 80) /* Ampere */ { + } else /* Ampere */ { + TORCH_CHECK(version_num >= 80); cutlass_scaled_mm_dq_sm80(out, a, b, a_scales, b_scales); - } else { - throw std::runtime_error("Unsupported GPU architecture"); } } diff --git a/csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu b/csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu index 3111f91089eba..e7e2c5daf3c68 100644 --- a/csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu +++ b/csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu @@ -1,7 +1,6 @@ -#include - #include #include +#include // clang-format will break include orders // clang-format off @@ -29,11 +28,31 @@ using namespace cute; ///////////////////////////////////////// +/* + This defines a quantized GEMM operation with dequantized output, similar 
to + torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for + NVIDIA GPUs with SM versions prior to sm90 (Hopper). + + A and B may be either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. They must have + symmetric quantization. + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ + +namespace { + template struct sm8x_gemm { using ElementAB = ElementAB_; using ElementD = ElementD_; + using ElementAcc = typename std::conditional, int32_t, float>::type; @@ -77,15 +96,22 @@ struct sm8x_gemm { using EVTD = cutlass::epilogue::threadblock::Sm80EVT; - using KernelType = typename cutlass::gemm::kernel::DefaultGemmWithVisitor< - ElementAB, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, - 16, ElementAB, cutlass::layout::ColumnMajor, - cutlass::ComplexTransform::kNone, 16, float, cutlass::layout::RowMajor, 4, - ElementAcc, float, cutlass::arch::OpClassTensorOp, Arch, TileShape, - WarpShape, InstructionShape, EVTD, + // clang-format off + using RowMajor = typename cutlass::layout::RowMajor; + using ColumnMajor = typename cutlass::layout::ColumnMajor; + using KernelType = + typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16, + ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16, + float, cutlass::layout::RowMajor, 4, + ElementAcc, float, cutlass::arch::OpClassTensorOp, + Arch, + TileShape, WarpShape, InstructionShape, + EVTD, cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, 5, Operator, 1 /* epilogue stages */ >::GemmKernel; + // clang-format on using Op = cutlass::gemm::device::GemmUniversalAdapter; }; @@ -119,6 +145,8 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, auto a_scales_ptr = a_scales.data_ptr(); auto b_scales_ptr = b_scales.data_ptr(); + // If A and B are quantized per-tensor, then these scale tensors are scalars, + // and they are passed in via the second argument. using ScaleAArgs = typename Gemm::ScaleA::Arguments; ScaleAArgs a_args = a_scales.numel() == 1 ? 
ScaleAArgs{nullptr, a_scales.item(), {}} @@ -141,9 +169,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, }; typename Gemm::Op::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode - problem_size, // problem size - 1, // batch count + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode + problem_size, // problem size + 1, // batch count epilogue_args, a_ptr, b_ptr, @@ -168,6 +196,8 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, CUTLASS_CHECK(status); } +} // namespace + void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, @@ -191,7 +221,6 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, torch::Tensor const &b_scales) { - using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; diff --git a/csrc/quantization/cutlass/scaled_mm_dq_sm90.cu b/csrc/quantization/cutlass/scaled_mm_dq_sm90.cu index 1c6ab40a90e6f..8fb363f59f4ea 100644 --- a/csrc/quantization/cutlass/scaled_mm_dq_sm90.cu +++ b/csrc/quantization/cutlass/scaled_mm_dq_sm90.cu @@ -24,6 +24,25 @@ using namespace cute; ///////////////////////////////////////// +/* + This defines a quantized GEMM operation with dequantized output, similar to + torch._scaled_mm. It is defined using the CUTLASS 3.x API, and is used for + NVIDIA GPUs with sm90a (Hopper) or later. + + A and B may be either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. They must have + symmetric quantization. + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. 
+*/ + +namespace { + template @@ -80,15 +99,19 @@ struct sm90_gemm { static constexpr size_t CEStorageSize = sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + // clang-format off using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, ElementAB, - cutlass::layout::RowMajor, 16, ElementAB, - cutlass::layout::ColumnMajor, 16, ElementAcc, TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - CEStorageSize)>, + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementAB, cutlass::layout::RowMajor, 16, + ElementAB, cutlass::layout::ColumnMajor, 16, + ElementAcc, TileShape, ClusterShape, + Stages, KernelSchedule>::CollectiveOp; + // clang-format on using KernelType = cutlass::gemm::kernel::GemmUniversal< cute::Shape, CollectiveMainloop, CollectiveEpilogue, @@ -99,9 +122,6 @@ struct sm90_gemm { ///////////////////////////////////////// -using StrideA = Stride, Int<0>>; -using StrideB = Stride, Int<0>>; - template void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, @@ -118,7 +138,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, int64_t ldb = b.stride(1); int64_t ldc = out.stride(0); + using StrideA = Stride, Int<0>>; + using StrideB = Stride, Int<0>>; using StrideC = typename Gemm::StrideC; + StrideA a_stride{lda, Int<1>{}, Int<0>{}}; StrideB b_stride{ldb, Int<1>{}, Int<0>{}}; StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; @@ -161,6 +184,7 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, cutlass::Status status = gemm_op.run(args); CUTLASS_CHECK(status); } +} // namespace void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, From cdb8e0663a4375dad710f622f019f940ea397519 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 14 May 2024 21:50:10 +0000 Subject: [PATCH 11/18] Support SM75 and fp16 output types --- CMakeLists.txt | 6 +- ...aled_mm_dq_sm8x.cu => scaled_mm_dq_c2x.cu} | 96 ++++++++++++++----- ...aled_mm_dq_sm90.cu => scaled_mm_dq_c3x.cu} | 36 +++++-- .../cutlass/scaled_mm_dq_entry.cu | 14 ++- tests/kernels/test_cutlass.py | 27 +++++- vllm/_custom_ops.py | 14 +-- 6 files changed, 146 insertions(+), 47 deletions(-) rename csrc/quantization/cutlass/{scaled_mm_dq_sm8x.cu => scaled_mm_dq_c2x.cu} (70%) rename csrc/quantization/cutlass/{scaled_mm_dq_sm90.cu => scaled_mm_dq_c3x.cu} (86%) diff --git a/CMakeLists.txt b/CMakeLists.txt index c7dbda77c89be..584c8e071999f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -191,15 +191,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/custom_all_reduce.cu" "csrc/quantization/cutlass/scaled_mm_dq_entry.cu" - "csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu" - "csrc/quantization/cutlass/scaled_mm_dq_sm90.cu") + "csrc/quantization/cutlass/scaled_mm_dq_c2x.cu" + "csrc/quantization/cutlass/scaled_mm_dq_c3x.cu") # # The CUTLASS kernels for Hopper require sm90a to be enabled. # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a. # That adds an extra 17MB to compiled binary, so instead we selectively enable it. 
set_source_files_properties( - "csrc/quantization/cutlass/scaled_mm_dq_sm90.cu" + "csrc/quantization/cutlass/scaled_mm_dq_c3x.cu" PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a") diff --git a/csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu b/csrc/quantization/cutlass/scaled_mm_dq_c2x.cu similarity index 70% rename from csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu rename to csrc/quantization/cutlass/scaled_mm_dq_c2x.cu index e7e2c5daf3c68..cff1b3e7edb91 100644 --- a/csrc/quantization/cutlass/scaled_mm_dq_sm8x.cu +++ b/csrc/quantization/cutlass/scaled_mm_dq_c2x.cu @@ -12,6 +12,7 @@ #include "cutlass/cutlass.h" #include "cutlass/gemm_coord.h" +#include "cutlass/arch/mma_sm75.h" #include "cutlass/arch/arch.h" #include "cutlass/arch/mma.h" #include "cutlass/gemm/device/gemm.h" @@ -37,7 +38,7 @@ using namespace cute; per-row. B can be quantized per-tensor or per-column. They must have symmetric quantization. - So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the scales are applied elementwise with numpy-style broadcasting. ScaleA and ScaleB define the epilogue functions that apply the scales for @@ -48,8 +49,9 @@ using namespace cute; namespace { template -struct sm8x_gemm { + typename TileShape, typename WarpShape, typename InstructionShape, + int32_t MainLoopStages> +struct cutlass_2x_gemm { using ElementAB = ElementAB_; using ElementD = ElementD_; @@ -90,7 +92,7 @@ struct sm8x_gemm { cutlass::epilogue::threadblock::Sm80EVT; using D = cutlass::epilogue::threadblock::VisitorAuxStore< - OutputTileThreadMap, cutlass::bfloat16_t, + OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest, Stride, Int<0>>>; @@ -108,7 +110,8 @@ struct sm8x_gemm { Arch, TileShape, WarpShape, InstructionShape, EVTD, - cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, 5, Operator, + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, + MainLoopStages, Operator, 1 /* epilogue stages */ >::GemmKernel; // clang-format on @@ -169,9 +172,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, }; typename Gemm::Op::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode - problem_size, // problem size - 1, // batch count + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode + problem_size, // problem size + 1, // batch count epilogue_args, a_ptr, b_ptr, @@ -196,7 +199,34 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, CUTLASS_CHECK(status); } -} // namespace +} // namespace + +void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + assert(a.dtype() == torch::kInt8); + assert(b.dtype() == torch::kInt8); + assert(a_scales.dtype() == torch::kFloat32); + assert(b_scales.dtype() == torch::kFloat32); + + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>; + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } else { + assert(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>(out, a, b, a_scales, + b_scales); + } +} void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, 
torch::Tensor const &b, @@ -206,15 +236,23 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, assert(b.dtype() == torch::kInt8); assert(a_scales.dtype() == torch::kFloat32); assert(b_scales.dtype() == torch::kFloat32); - assert(out.dtype() == torch::kBFloat16); using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; - return cutlass_scaled_mm_dq_dispatcher< - sm8x_gemm>(out, a, b, a_scales, b_scales); + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } else { + assert(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>(out, a, b, a_scales, + b_scales); + } } void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, @@ -229,21 +267,35 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, assert(b.dtype() == torch::kInt8); assert(a_scales.dtype() == torch::kFloat32); assert(b_scales.dtype() == torch::kFloat32); - assert(out.dtype() == torch::kBFloat16); - return cutlass_scaled_mm_dq_dispatcher< - sm8x_gemm>(out, a, b, a_scales, b_scales); + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } else { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } } else { assert(a.dtype() == torch::kFloat8_e4m3fn); assert(b.dtype() == torch::kFloat8_e4m3fn); assert(a_scales.dtype() == torch::kFloat32); assert(b_scales.dtype() == torch::kFloat32); - assert(out.dtype() == torch::kBFloat16); - return cutlass_scaled_mm_dq_dispatcher< - sm8x_gemm>( - out, a, b, a_scales, b_scales); + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher>(out, a, b, a_scales, + b_scales); + } else { + assert(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher>(out, a, b, a_scales, + b_scales); + } } } diff --git a/csrc/quantization/cutlass/scaled_mm_dq_sm90.cu b/csrc/quantization/cutlass/scaled_mm_dq_c3x.cu similarity index 86% rename from csrc/quantization/cutlass/scaled_mm_dq_sm90.cu rename to csrc/quantization/cutlass/scaled_mm_dq_c3x.cu index 8fb363f59f4ea..be81d181fdda0 100644 --- a/csrc/quantization/cutlass/scaled_mm_dq_sm90.cu +++ b/csrc/quantization/cutlass/scaled_mm_dq_c3x.cu @@ -46,7 +46,7 @@ namespace { template -struct sm90_gemm { +struct cutlass_3x_gemm { using ElementAB = ElementAB_; using ElementD = ElementD_; using ElementAcc = @@ -184,7 +184,7 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, cutlass::Status status = gemm_op.run(args); CUTLASS_CHECK(status); } -} // namespace +} // namespace void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, @@ -197,10 +197,18 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - return cutlass_scaled_mm_dq_dispatcher< - sm90_gemm>(out, a, b, a_scales, - b_scales); + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } else { + assert(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, 
a_scales, b_scales); + } } else { using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_1, _2, _1>; @@ -209,9 +217,17 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecializedCooperative; - return cutlass_scaled_mm_dq_dispatcher< - sm90_gemm>( - out, a, b, a_scales, b_scales); + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } else { + assert(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } } } diff --git a/csrc/quantization/cutlass/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass/scaled_mm_dq_entry.cu index 8e8f94c0b9bb6..9ac6dea9a9033 100644 --- a/csrc/quantization/cutlass/scaled_mm_dq_entry.cu +++ b/csrc/quantization/cutlass/scaled_mm_dq_entry.cu @@ -1,13 +1,20 @@ #include +void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); + void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, torch::Tensor const &b_scales); + void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, torch::Tensor const &b_scales); + void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, @@ -34,8 +41,11 @@ void cutlass_scaled_mm_dq(torch::Tensor &out, torch::Tensor const &a, cutlass_scaled_mm_dq_sm90(out, a, b, a_scales, b_scales); } else if (version_num == 89) /* Ada Lovelace */ { cutlass_scaled_mm_dq_sm89(out, a, b, a_scales, b_scales); - } else /* Ampere */ { - TORCH_CHECK(version_num >= 80); + } else if (version_num >= 80) /* Ampere */ { cutlass_scaled_mm_dq_sm80(out, a, b, a_scales, b_scales); + } else + { + TORCH_CHECK(version_num >= 75); + cutlass_scaled_mm_dq_sm75(out, a, b, a_scales, b_scales); } } diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index adfa79d9f2a7d..0572d4362fd45 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -2,6 +2,8 @@ Run `pytest tests/kernels/test_cutlass.py`. """ +from typing import Type + import pytest import torch @@ -35,6 +37,7 @@ def cutlass_fp8_gemm_helper( k: int, per_token_act_quant: bool, per_out_channel_weight_quant: bool, + out_dtype: Type[torch.dtype] = torch.bfloat16 ): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. @@ -49,10 +52,10 @@ def cutlass_fp8_gemm_helper( scale_b = (torch.randn( (1, n_b_scales), device="cuda", dtype=torch.float32) / 10) - out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b) + out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype) baseline = torch.mm(scale_a * a.to(dtype=torch.float32), scale_b * - b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) + b.to(dtype=torch.float32)).to(out_dtype) assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1) @@ -63,6 +66,7 @@ def cutlass_int8_gemm_helper( k: int, per_token_act_quant: bool, per_out_channel_weight_quant: bool, + out_dtype: Type[torch.dtype] = torch.bfloat16 ): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. 
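As the test helpers above indicate, cutlass_scaled_mm_dq is checked against a straightforward float32 reference in which the scales are applied with numpy-style broadcasting before the matmul. A sketch of that reference, assuming per-token a_scales of shape (m, 1) and per-output-channel b_scales of shape (1, n); a scalar scale covers the per-tensor case:

    import torch

    def reference_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
                               a_scales: torch.Tensor, b_scales: torch.Tensor,
                               out_dtype: torch.dtype) -> torch.Tensor:
        # Dequantize both operands in float32, multiply, then cast down.
        # a_scales broadcasts across rows, b_scales across columns.
        a_dq = a_scales * a.to(torch.float32)
        b_dq = b_scales * b.to(torch.float32)
        return torch.mm(a_dq, b_dq).to(out_dtype)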
@@ -77,10 +81,10 @@ def cutlass_int8_gemm_helper( scale_b = (torch.randn( (1, n_b_scales), device="cuda", dtype=torch.float32) / 10) - out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b) + out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype) baseline = torch.mm(scale_a * a.to(dtype=torch.float32), scale_b * - b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) + b.to(dtype=torch.float32)).to(dtype=out_dtype) assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) @@ -106,6 +110,21 @@ def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool): cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) +def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, + per_out_ch: bool, out_dtype: Type[torch.dtype]): + cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, out_dtype) + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, + per_out_ch: bool, out_dtype: Type[torch.dtype]): + cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, out_dtype) # For the following two tests: # N and K correspond to the size of the weight matrix and likely to be multiples diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a2fccc27d2ad3..f4760c8f65849 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional, Tuple, Type import torch @@ -159,23 +159,25 @@ def cutlass_scaled_mm_dq( b: torch.Tensor, a_scales: torch.Tensor, b_scales: torch.Tensor, + out_dtype : Type[torch.dtype] = torch.bfloat16 ) -> torch.Tensor: shape_fallback = b.shape[0] % 16 != 0 or b.shape[1] % 16 != 0 + assert(out_dtype is torch.bfloat16 or out_dtype is torch.float16) capability = torch.cuda.get_device_capability() capability = capability[0] * 10 + capability[1] - if capability < 80 or shape_fallback: - a_bf16 = a.to(dtype=torch.bfloat16) - b_bf16 = b.to(dtype=torch.bfloat16) + if capability < 75 or shape_fallback: + a_float = a.to(out_dtype) + b_float = b.to(out_dtype) return (b_scales * - (a_scales * torch.mm(a_bf16, b_bf16))).to(dtype=torch.bfloat16) + (a_scales * torch.mm(a_float, b_float))).to(dtype=out_dtype) else: m = a.shape[0] n = b.shape[1] - out = torch.empty((m, n), dtype=torch.bfloat16, device="cuda") + out = torch.empty((m, n), dtype=out_dtype, device="cuda") vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales) From 5be8af785258a4874f36264f662cf58bc7718799 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 14 May 2024 21:58:47 +0000 Subject: [PATCH 12/18] format.sh --- tests/kernels/test_cutlass.py | 48 +++++++++++++++++------------------ vllm/_custom_ops.py | 13 +++++----- 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 0572d4362fd45..8554303b90ad4 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -31,14 +31,12 @@ def to_int8(tensor): return torch.round(torch.clamp(tensor, -128, 127)).to(dtype=torch.int8) -def cutlass_fp8_gemm_helper( - m: int, - n: int, - k: int, - 
per_token_act_quant: bool, - per_out_channel_weight_quant: bool, - out_dtype: Type[torch.dtype] = torch.bfloat16 -): +def cutlass_fp8_gemm_helper(m: int, + n: int, + k: int, + per_token_act_quant: bool, + per_out_channel_weight_quant: bool, + out_dtype: Type[torch.dtype] = torch.bfloat16): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. a = to_fp8(torch.randn((m, k), device="cuda")) @@ -54,20 +52,17 @@ def cutlass_fp8_gemm_helper( out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype) baseline = torch.mm(scale_a * a.to(dtype=torch.float32), - scale_b * - b.to(dtype=torch.float32)).to(out_dtype) + scale_b * b.to(dtype=torch.float32)).to(out_dtype) assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1) -def cutlass_int8_gemm_helper( - m: int, - n: int, - k: int, - per_token_act_quant: bool, - per_out_channel_weight_quant: bool, - out_dtype: Type[torch.dtype] = torch.bfloat16 -): +def cutlass_int8_gemm_helper(m: int, + n: int, + k: int, + per_token_act_quant: bool, + per_out_channel_weight_quant: bool, + out_dtype: Type[torch.dtype] = torch.bfloat16): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. a = to_int8(torch.randn((m, k), device="cuda") * 5) @@ -110,21 +105,26 @@ def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool): cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch) + @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) -def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, - per_out_ch: bool, out_dtype: Type[torch.dtype]): - cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, out_dtype) +def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, + out_dtype: Type[torch.dtype]): + cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + out_dtype) + @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.skipif(capability < 89, reason="FP8 is not supported on this GPU type.") -def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, - per_out_ch: bool, out_dtype: Type[torch.dtype]): - cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, out_dtype) +def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, + out_dtype: Type[torch.dtype]): + cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + out_dtype) + # For the following two tests: # N and K correspond to the size of the weight matrix and likely to be multiples diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index f4760c8f65849..8221fc7011306 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -155,14 +155,13 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, # cutlass def cutlass_scaled_mm_dq( - a: torch.Tensor, - b: torch.Tensor, - a_scales: torch.Tensor, - b_scales: torch.Tensor, - out_dtype : Type[torch.dtype] = torch.bfloat16 -) -> torch.Tensor: + a: torch.Tensor, + b: torch.Tensor, + a_scales: torch.Tensor, + b_scales: torch.Tensor, + out_dtype: Type[torch.dtype] = torch.bfloat16) -> torch.Tensor: shape_fallback = b.shape[0] % 16 != 0 or b.shape[1] % 16 != 0 - assert(out_dtype is torch.bfloat16 or out_dtype is torch.float16) + assert (out_dtype 
is torch.bfloat16 or out_dtype is torch.float16) capability = torch.cuda.get_device_capability() capability = capability[0] * 10 + capability[1] From ca43e2c59060b55e3ddf0a6f55705572cb2d0de2 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 15 May 2024 14:32:05 +0000 Subject: [PATCH 13/18] use cudaDeviceGetAttribute --- csrc/quantization/cutlass/scaled_mm_dq_c3x.cu | 2 ++ .../cutlass/scaled_mm_dq_entry.cu | 19 +++++++------------ 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/csrc/quantization/cutlass/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass/scaled_mm_dq_c3x.cu index be81d181fdda0..bf8c783f9b4ac 100644 --- a/csrc/quantization/cutlass/scaled_mm_dq_c3x.cu +++ b/csrc/quantization/cutlass/scaled_mm_dq_c3x.cu @@ -204,6 +204,7 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, out, a, b, a_scales, b_scales); } else { assert(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher< cutlass_3x_gemm>( @@ -224,6 +225,7 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, out, a, b, a_scales, b_scales); } else { assert(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher< cutlass_3x_gemm>( diff --git a/csrc/quantization/cutlass/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass/scaled_mm_dq_entry.cu index 9ac6dea9a9033..3e07e078dc238 100644 --- a/csrc/quantization/cutlass/scaled_mm_dq_entry.cu +++ b/csrc/quantization/cutlass/scaled_mm_dq_entry.cu @@ -1,4 +1,5 @@ #include +#include void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, @@ -23,16 +24,11 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, void cutlass_scaled_mm_dq(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, torch::Tensor const &b_scales) { - // TODO(tms): Hack. cudaGetDeviceProperties is very slow. 
- static std::optional maybe_version_num = std::nullopt; - if (!maybe_version_num.has_value()) { - int device; - cudaGetDevice(&device); - cudaDeviceProp properties; - cudaGetDeviceProperties(&properties, device); - maybe_version_num = properties.major * 10 + properties.minor; - } - int32_t version_num = maybe_version_num.value(); + int32_t major_capability; + int32_t minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, 0); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, 0); + int32_t version_num = major_capability * 10 + minor_capability; if (version_num >= 90) /* H100 */ { // TODO: This kernel only works for sm90a @@ -43,8 +39,7 @@ void cutlass_scaled_mm_dq(torch::Tensor &out, torch::Tensor const &a, cutlass_scaled_mm_dq_sm89(out, a, b, a_scales, b_scales); } else if (version_num >= 80) /* Ampere */ { cutlass_scaled_mm_dq_sm80(out, a, b, a_scales, b_scales); - } else - { + } else { TORCH_CHECK(version_num >= 75); cutlass_scaled_mm_dq_sm75(out, a, b, a_scales, b_scales); } From d09d9731f7b7d12a2661fec1810372f63b873a76 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 15 May 2024 16:38:33 +0000 Subject: [PATCH 14/18] Review comments --- CMakeLists.txt | 8 +-- .../{cutlass => cutlass_w8a8}/common.hpp | 0 .../cutlass_visitor_2x_broadcast_epilogue.hpp | 0 .../scaled_mm_dq_c2x.cu | 0 .../scaled_mm_dq_c3x.cu | 0 .../scaled_mm_dq_entry.cu | 3 ++ tests/kernels/test_cutlass.py | 53 ++++++++++++++----- 7 files changed, 48 insertions(+), 16 deletions(-) rename csrc/quantization/{cutlass => cutlass_w8a8}/common.hpp (100%) rename csrc/quantization/{cutlass => cutlass_w8a8}/cutlass_visitor_2x_broadcast_epilogue.hpp (100%) rename csrc/quantization/{cutlass => cutlass_w8a8}/scaled_mm_dq_c2x.cu (100%) rename csrc/quantization/{cutlass => cutlass_w8a8}/scaled_mm_dq_c3x.cu (100%) rename csrc/quantization/{cutlass => cutlass_w8a8}/scaled_mm_dq_entry.cu (95%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 584c8e071999f..bb4045dcb01fd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,7 +16,7 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") # Supported NVIDIA architectures. -set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0;9.0a") +set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") # Supported AMD GPU architectures. set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100") @@ -190,9 +190,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/custom_all_reduce.cu" - "csrc/quantization/cutlass/scaled_mm_dq_entry.cu" - "csrc/quantization/cutlass/scaled_mm_dq_c2x.cu" - "csrc/quantization/cutlass/scaled_mm_dq_c3x.cu") + "csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu" + "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu" + "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu") # # The CUTLASS kernels for Hopper require sm90a to be enabled. 
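For readers skimming the test changes that follow, the scale-layout convention this patch series settles on is easiest to see at the Python level. The sketch below mirrors the int8 helper in tests/kernels/test_cutlass.py; it is not part of any commit here, and the sizes, device string, and variable names are illustrative assumptions only.

    import torch
    from vllm import _custom_ops as ops

    def to_int8(t: torch.Tensor) -> torch.Tensor:
        # Symmetric int8 quantization stub, as used by the tests.
        return torch.round(t.clamp(min=-128, max=127)).to(dtype=torch.int8)

    m, n, k = 512, 512, 512
    a = to_int8(torch.randn((m, k), device="cuda") * 5)      # row-major activations
    b = to_int8(torch.randn((n, k), device="cuda").t() * 5)  # column-major weights

    # Per-token activation scales are (m, 1); per-output-channel weight scales are (1, n).
    scale_a = torch.randn((m, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10

    out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)

    # Reference semantics: D = (scale_a * A) @ (scale_b * B), broadcast elementwise.
    ref = torch.mm(scale_a * a.to(torch.float32),
                   scale_b * b.to(torch.float32)).to(torch.bfloat16)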
diff --git a/csrc/quantization/cutlass/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp similarity index 100% rename from csrc/quantization/cutlass/common.hpp rename to csrc/quantization/cutlass_w8a8/common.hpp diff --git a/csrc/quantization/cutlass/cutlass_visitor_2x_broadcast_epilogue.hpp b/csrc/quantization/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp similarity index 100% rename from csrc/quantization/cutlass/cutlass_visitor_2x_broadcast_epilogue.hpp rename to csrc/quantization/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp diff --git a/csrc/quantization/cutlass/scaled_mm_dq_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu similarity index 100% rename from csrc/quantization/cutlass/scaled_mm_dq_c2x.cu rename to csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu diff --git a/csrc/quantization/cutlass/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu similarity index 100% rename from csrc/quantization/cutlass/scaled_mm_dq_c3x.cu rename to csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu diff --git a/csrc/quantization/cutlass/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu similarity index 95% rename from csrc/quantization/cutlass/scaled_mm_dq_entry.cu rename to csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu index 3e07e078dc238..cefb89a4c8d81 100644 --- a/csrc/quantization/cutlass/scaled_mm_dq_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu @@ -1,5 +1,6 @@ #include #include +#include void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, @@ -30,6 +31,8 @@ void cutlass_scaled_mm_dq(torch::Tensor &out, torch::Tensor const &a, cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, 0); int32_t version_num = major_capability * 10 + minor_capability; + at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); + if (version_num >= 90) /* H100 */ { // TODO: This kernel only works for sm90a // -- figure out how to detect 90a vs 90 diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 8554303b90ad4..1700d341e7d0c 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -9,11 +9,15 @@ from vllm import _custom_ops as ops +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + capability = torch.cuda.get_device_capability() capability = capability[0] * 10 + capability[1] -def to_fp8(tensor): +def to_fp8(tensor: torch.tensor): # Assuming input tensor is float32 # Scale tensor to range of FP8 E4M3 # by clamping exponent and truncating mantissa @@ -27,7 +31,7 @@ def to_fp8(tensor): return quantized.to(dtype=torch.float8_e4m3fn) -def to_int8(tensor): +def to_int8(tensor: torch.tensor): return torch.round(torch.clamp(tensor, -128, 127)).to(dtype=torch.int8) @@ -36,19 +40,23 @@ def cutlass_fp8_gemm_helper(m: int, k: int, per_token_act_quant: bool, per_out_channel_weight_quant: bool, - out_dtype: Type[torch.dtype] = torch.bfloat16): + out_dtype: Type[torch.dtype] = torch.bfloat16, + device: str = "cuda"): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. 
- a = to_fp8(torch.randn((m, k), device="cuda")) - b = to_fp8(torch.randn((n, k), device="cuda").t()) + a = to_fp8(torch.randn((m, k), device=device)) + b = to_fp8(torch.randn((n, k), device=device).t()) + + print(a.device) + print(device) m_a_scales = m if per_token_act_quant else 1 n_b_scales = n if per_out_channel_weight_quant else 1 scale_a = (torch.randn( - (m_a_scales, 1), device="cuda", dtype=torch.float32) / 10) + (m_a_scales, 1), device=device, dtype=torch.float32) / 10) scale_b = (torch.randn( - (1, n_b_scales), device="cuda", dtype=torch.float32) / 10) + (1, n_b_scales), device=device, dtype=torch.float32) / 10) out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype) baseline = torch.mm(scale_a * a.to(dtype=torch.float32), @@ -62,19 +70,20 @@ def cutlass_int8_gemm_helper(m: int, k: int, per_token_act_quant: bool, per_out_channel_weight_quant: bool, - out_dtype: Type[torch.dtype] = torch.bfloat16): + out_dtype: Type[torch.dtype] = torch.bfloat16, + device: str = "cuda"): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. - a = to_int8(torch.randn((m, k), device="cuda") * 5) - b = to_int8(torch.randn((n, k), device="cuda").t() * 5) + a = to_int8(torch.randn((m, k), device=device) * 5) + b = to_int8(torch.randn((n, k), device=device).t() * 5) m_a_scales = m if per_token_act_quant else 1 n_b_scales = n if per_out_channel_weight_quant else 1 scale_a = (torch.randn( - (m_a_scales, 1), device="cuda", dtype=torch.float32) / 10) + (m_a_scales, 1), device=device, dtype=torch.float32) / 10) scale_b = (torch.randn( - (1, n_b_scales), device="cuda", dtype=torch.float32) / 10) + (1, n_b_scales), device=device, dtype=torch.float32) / 10) out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype) baseline = torch.mm(scale_a * a.to(dtype=torch.float32), @@ -125,6 +134,26 @@ def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, out_dtype) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool, + device: str): + cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + torch.bfloat16, device) + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool, + device: str): + cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + torch.bfloat16, device) + # For the following two tests: # N and K correspond to the size of the weight matrix and likely to be multiples From 9a07760851628ad4267d48e74e717e30b59959f5 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 15 May 2024 16:49:11 +0000 Subject: [PATCH 15/18] fixups, cruft --- CMakeLists.txt | 2 +- .../cutlass_w8a8/scaled_mm_dq_entry.cu | 55 ++++++++++++------- tests/kernels/test_cutlass.py | 11 ++-- vllm/_custom_ops.py | 2 +- 4 files changed, 42 insertions(+), 28 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index bb4045dcb01fd..ecf1cfea27ff6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -199,7 +199,7 @@ 
if(VLLM_GPU_LANG STREQUAL "CUDA") # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a. # That adds an extra 17MB to compiled binary, so instead we selectively enable it. set_source_files_properties( - "csrc/quantization/cutlass/scaled_mm_dq_c3x.cu" + "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu" PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a") diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu index cefb89a4c8d81..c221511722e1b 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu @@ -1,49 +1,66 @@ -#include -#include +#include #include +#include +#include -void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, +void cutlass_scaled_mm_dq_sm75(torch::Tensor &c, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, torch::Tensor const &b_scales); -void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, +void cutlass_scaled_mm_dq_sm80(torch::Tensor &c, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, torch::Tensor const &b_scales); -void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, +void cutlass_scaled_mm_dq_sm89(torch::Tensor &c, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, torch::Tensor const &b_scales); -void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, +void cutlass_scaled_mm_dq_sm90(torch::Tensor &c, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, torch::Tensor const &b_scales); -void cutlass_scaled_mm_dq(torch::Tensor &out, torch::Tensor const &a, +void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, torch::Tensor const &b_scales) { int32_t major_capability; int32_t minor_capability; - cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, 0); - cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, 0); + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, + 0); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, + 0); int32_t version_num = major_capability * 10 + minor_capability; - at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); + // Checks for conformality + assert(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); + assert(c.size(0) == a.size(0) && a.size(1) == b.size(0) && + b.size(1) == c.size(1)); + assert(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); + assert(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); + + // Check for strides and alignment + assert(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + assert(b.stride(0) == 1); // Column-major + assert(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0); // 16 Byte Alignment + assert(a_scales.is_contiguous() && b_scales.is_contiguous()); - if (version_num >= 90) /* H100 */ { - // TODO: This kernel only works for sm90a - // -- figure out how to detect 90a vs 90 + at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); - cutlass_scaled_mm_dq_sm90(out, a, b, a_scales, b_scales); - } else if (version_num == 89) /* Ada Lovelace */ { - cutlass_scaled_mm_dq_sm89(out, a, b, a_scales, b_scales); - } else if (version_num >= 80) /* Ampere */ { - cutlass_scaled_mm_dq_sm80(out, a, b, a_scales, b_scales); + if (version_num >= 90) { + // 
Hopper + cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales); + } else if (version_num == 89) { + // Ada Lovelace + cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales); + } else if (version_num >= 80) { + // Ampere + cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales); } else { + // Turing TORCH_CHECK(version_num >= 75); - cutlass_scaled_mm_dq_sm75(out, a, b, a_scales, b_scales); + cutlass_scaled_mm_dq_sm75(c, a, b, a_scales, b_scales); } } diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 1700d341e7d0c..c835b5b8a11f0 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -47,9 +47,6 @@ def cutlass_fp8_gemm_helper(m: int, a = to_fp8(torch.randn((m, k), device=device)) b = to_fp8(torch.randn((n, k), device=device).t()) - print(a.device) - print(device) - m_a_scales = m if per_token_act_quant else 1 n_b_scales = n if per_out_channel_weight_quant else 1 @@ -134,6 +131,7 @@ def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, out_dtype) + @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -144,15 +142,14 @@ def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool, cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, torch.bfloat16, device) + @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(capability < 89, - reason="FP8 is not supported on this GPU type.") def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool, - device: str): + device: str): cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, - torch.bfloat16, device) + torch.bfloat16, device) # For the following two tests: diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 8221fc7011306..963c291e474f5 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -176,7 +176,7 @@ def cutlass_scaled_mm_dq( else: m = a.shape[0] n = b.shape[1] - out = torch.empty((m, n), dtype=out_dtype, device="cuda") + out = torch.empty((m, n), dtype=out_dtype, device=a.device) vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales) From 4909784981c9a89af1e4886f85fb829191e87cee Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 15 May 2024 21:43:18 +0000 Subject: [PATCH 16/18] review comments --- .../cutlass_w8a8/scaled_mm_dq_c2x.cu | 23 ++++++------- .../cutlass_w8a8/scaled_mm_dq_c3x.cu | 13 +++----- tests/kernels/test_cutlass.py | 26 +++++++-------- vllm/_custom_ops.py | 32 ++++++------------- 4 files changed, 35 insertions(+), 59 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu index cff1b3e7edb91..5014183449efe 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu @@ -27,16 +27,15 @@ using namespace cute; -///////////////////////////////////////// - /* This defines a quantized GEMM operation with dequantized output, similar to torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for NVIDIA GPUs with SM versions prior to sm90 (Hopper). - A and B may be either int8 or fp8_e4m3. A can be quantized per-tensor or - per-row. B can be quantized per-tensor or per-column. They must have - symmetric quantization. 
+ A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). So the GEMM operation is D = (a_scales * A) (b_scales * B), where the scales are applied elementwise with numpy-style broadcasting. @@ -92,8 +91,7 @@ struct cutlass_2x_gemm { cutlass::epilogue::threadblock::Sm80EVT; using D = cutlass::epilogue::threadblock::VisitorAuxStore< - OutputTileThreadMap, ElementD, - cutlass::FloatRoundStyle::round_to_nearest, + OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest, Stride, Int<0>>>; using EVTD = cutlass::epilogue::threadblock::Sm80EVT; @@ -119,8 +117,6 @@ struct cutlass_2x_gemm { using Op = cutlass::gemm::device::GemmUniversalAdapter; }; -///////////////////////////////////////// - template void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, @@ -172,9 +168,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, }; typename Gemm::Op::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode - problem_size, // problem size - 1, // batch count + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode + problem_size, // problem size + 1, // batch count epilogue_args, a_ptr, b_ptr, @@ -199,7 +195,7 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, CUTLASS_CHECK(status); } -} // namespace +} // namespace void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, @@ -274,6 +270,7 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, TileShape, WarpShape, InstructionShape, 5>>( out, a, b, a_scales, b_scales); } else { + assert(out.dtype() == torch::kFloat16); return cutlass_scaled_mm_dq_dispatcher< cutlass_2x_gemm>( diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu index bf8c783f9b4ac..115020ad9a474 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu @@ -22,16 +22,15 @@ using namespace cute; -///////////////////////////////////////// - /* This defines a quantized GEMM operation with dequantized output, similar to torch._scaled_mm. It is defined using the CUTLASS 3.x API, and is used for NVIDIA GPUs with sm90a (Hopper) or later. - A and B may be either int8 or fp8_e4m3. A can be quantized per-tensor or - per-row. B can be quantized per-tensor or per-column. They must have - symmetric quantization. + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). So the GEMM operation is D = (a_scales * A) (b_scales * B), where the scales are applied elementwise with numpy-style broadcasting. 
@@ -120,8 +119,6 @@ struct cutlass_3x_gemm { struct GemmKernel : public KernelType {}; }; -///////////////////////////////////////// - template void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, @@ -184,7 +181,7 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, cutlass::Status status = gemm_op.run(args); CUTLASS_CHECK(status); } -} // namespace +} // namespace void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index c835b5b8a11f0..fdfd1dee29ce6 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -18,21 +18,13 @@ def to_fp8(tensor: torch.tensor): - # Assuming input tensor is float32 - # Scale tensor to range of FP8 E4M3 - # by clamping exponent and truncating mantissa - max_exp = 2**4 - 1 # Maximum exponent for E4M3 - max_mantissa = 2**3 - 1 # Maximum mantissa for E4M3 - base = 2**max_exp - # Scale the mantissa - scaled = torch.clamp(tensor, -base, base) - # Quantize the mantissa - quantized = torch.round(scaled * max_mantissa) / max_mantissa - return quantized.to(dtype=torch.float8_e4m3fn) + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) def to_int8(tensor: torch.tensor): - return torch.round(torch.clamp(tensor, -128, 127)).to(dtype=torch.int8) + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) def cutlass_fp8_gemm_helper(m: int, @@ -92,7 +84,7 @@ def cutlass_int8_gemm_helper(m: int, @pytest.mark.parametrize("m", [512, 222, 33, 1]) @pytest.mark.parametrize("n", [2048, 256, 1024]) -@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("k", [128, 496, 1024]) @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.skipif(capability < 89, @@ -104,7 +96,7 @@ def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool, @pytest.mark.parametrize("m", [512, 222, 33, 1]) @pytest.mark.parametrize("n", [2048, 256, 1024]) -@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("k", [128, 496, 1024]) @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool, @@ -188,7 +180,11 @@ def test_cutlass_subset(): scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 - out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b) + out = ops.cutlass_scaled_mm_dq(a, + b, + scale_a, + scale_b, + out_dtype=torch.bfloat16) baseline = torch.mm(scale_a * a.to(dtype=torch.float32), scale_b * b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 963c291e474f5..4bb16be71518e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -154,33 +154,19 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, # cutlass -def cutlass_scaled_mm_dq( - a: torch.Tensor, - b: torch.Tensor, - a_scales: torch.Tensor, - b_scales: torch.Tensor, - out_dtype: Type[torch.dtype] = torch.bfloat16) -> torch.Tensor: - shape_fallback = b.shape[0] % 16 != 0 or b.shape[1] % 16 != 0 +def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, + a_scales: torch.Tensor, b_scales: torch.Tensor, + out_dtype: 
Type[torch.dtype]) -> torch.Tensor: + assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - capability = torch.cuda.get_device_capability() - capability = capability[0] * 10 + capability[1] + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) - if capability < 75 or shape_fallback: - a_float = a.to(out_dtype) - b_float = b.to(out_dtype) + vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales) - return (b_scales * - (a_scales * torch.mm(a_float, b_float))).to(dtype=out_dtype) - - else: - m = a.shape[0] - n = b.shape[1] - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales) - - return out + return out # aqlm From 26f5890df186eac3b6c912b27788768968609900 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 16 May 2024 18:01:27 +0000 Subject: [PATCH 17/18] assert -> TORCH_CHECK --- .../cutlass_w8a8/scaled_mm_dq_c2x.cu | 36 +++++++++---------- .../cutlass_w8a8/scaled_mm_dq_c3x.cu | 14 ++++++-- .../cutlass_w8a8/scaled_mm_dq_entry.cu | 17 +++++---- 3 files changed, 36 insertions(+), 31 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu index 5014183449efe..3ec454f78c654 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu @@ -1,4 +1,3 @@ -#include #include #include @@ -201,10 +200,10 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, torch::Tensor const &b_scales) { - assert(a.dtype() == torch::kInt8); - assert(b.dtype() == torch::kInt8); - assert(a_scales.dtype() == torch::kFloat32); - assert(b_scales.dtype() == torch::kFloat32); + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(b.dtype() == torch::kInt8); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; @@ -216,7 +215,7 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, TileShape, WarpShape, InstructionShape, 2>>( out, a, b, a_scales, b_scales); } else { - assert(out.dtype() == torch::kFloat16); + TORCH_CHECK(out.dtype() == torch::kFloat16); return cutlass_scaled_mm_dq_dispatcher< cutlass_2x_gemm>(out, a, b, a_scales, @@ -228,10 +227,10 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, torch::Tensor const &b_scales) { - assert(a.dtype() == torch::kInt8); - assert(b.dtype() == torch::kInt8); - assert(a_scales.dtype() == torch::kFloat32); - assert(b_scales.dtype() == torch::kFloat32); + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(b.dtype() == torch::kInt8); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; @@ -243,7 +242,7 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, TileShape, WarpShape, InstructionShape, 5>>( out, a, b, a_scales, b_scales); } else { - assert(out.dtype() == torch::kFloat16); + TORCH_CHECK(out.dtype() == torch::kFloat16); return cutlass_scaled_mm_dq_dispatcher< cutlass_2x_gemm>(out, a, 
b, a_scales, @@ -259,10 +258,11 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + if (a.dtype() == torch::kInt8) { - assert(b.dtype() == torch::kInt8); - assert(a_scales.dtype() == torch::kFloat32); - assert(b_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { return cutlass_scaled_mm_dq_dispatcher< @@ -277,10 +277,8 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, out, a, b, a_scales, b_scales); } } else { - assert(a.dtype() == torch::kFloat8_e4m3fn); - assert(b.dtype() == torch::kFloat8_e4m3fn); - assert(a_scales.dtype() == torch::kFloat32); - assert(b_scales.dtype() == torch::kFloat32); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); if (out.dtype() == torch::kBFloat16) { return cutlass_scaled_mm_dq_dispatcher>(out, a, b, a_scales, b_scales); } else { - assert(out.dtype() == torch::kFloat16); + TORCH_CHECK(out.dtype() == torch::kFloat16); return cutlass_scaled_mm_dq_dispatcher>(out, a, b, a_scales, diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu index 115020ad9a474..37b096de23e3b 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu @@ -176,7 +176,7 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, CUTLASS_CHECK(gemm_op.can_implement(args)); size_t workspace_size = gemm_op.get_workspace_size(args); - assert(workspace_size == 0); + TORCH_CHECK(workspace_size == 0); cutlass::Status status = gemm_op.run(args); CUTLASS_CHECK(status); @@ -187,7 +187,12 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b, torch::Tensor const &a_scales, torch::Tensor const &b_scales) { + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + if (a.dtype() == torch::kInt8) { + TORCH_CHECK(b.dtype() == torch::kInt8); + using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_1, _2, _1>; using KernelSchedule = @@ -200,7 +205,7 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, KernelSchedule, EpilogueSchedule>>( out, a, b, a_scales, b_scales); } else { - assert(out.dtype() == torch::kFloat16); + TORCH_CHECK(out.dtype() == torch::kFloat16); return cutlass_scaled_mm_dq_dispatcher< cutlass_3x_gemm; using ClusterShape = Shape<_1, _2, _1>; using KernelSchedule = @@ -221,7 +229,7 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, ClusterShape, KernelSchedule, EpilogueSchedule>>( out, a, b, a_scales, b_scales); } else { - assert(out.dtype() == torch::kFloat16); + TORCH_CHECK(out.dtype() == torch::kFloat16); return cutlass_scaled_mm_dq_dispatcher< cutlass_3x_gemm #include #include #include @@ -35,17 +34,17 @@ void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a, int32_t version_num = major_capability * 10 + minor_capability; // Checks for conformality - assert(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); - assert(c.size(0) == a.size(0) && a.size(1) == b.size(0) && + TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); + TORCH_CHECK(c.size(0) == a.size(0) && 
a.size(1) == b.size(0) && b.size(1) == c.size(1)); - assert(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); - assert(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); + TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); + TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); // Check for strides and alignment - assert(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - assert(b.stride(0) == 1); // Column-major - assert(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0); // 16 Byte Alignment - assert(a_scales.is_contiguous() && b_scales.is_contiguous()); + TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0); // 16 Byte Alignment + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); From 39b838f5650360b1a658faf567a26e1f3a349628 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 16 May 2024 18:11:53 +0000 Subject: [PATCH 18/18] format --- vllm/_custom_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index f7378e27b77b2..9e7d0d96bf004 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -162,6 +162,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, workspace, num_bits, size_m, size_n, size_k) + # cutlass def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, a_scales: torch.Tensor, b_scales: torch.Tensor,
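The series ends with the trimmed-down Python wrapper above. As a closing note, the preconditions now enforced by the wrapper's asserts in vllm/_custom_ops.py and the TORCH_CHECKs in scaled_mm_dq_entry.cu can be summarized in plain Python. This sketch is not part of any commit; the helper name is an assumption used only for illustration.

    import torch

    def check_scaled_mm_dq_args(a: torch.Tensor, b: torch.Tensor,
                                a_scales: torch.Tensor, b_scales: torch.Tensor,
                                out_dtype: torch.dtype) -> None:
        # Output dtype is restricted to bfloat16 or float16.
        assert out_dtype in (torch.bfloat16, torch.float16)
        # Conformality: a is (m, k), b is (k, n).
        assert a.dim() == 2 and b.dim() == 2 and a.shape[1] == b.shape[0]
        # The Python wrapper requires both dimensions of b to be multiples of 16.
        assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
        # Layout: a is row-major, b is column-major.
        assert a.stride(1) == 1 and b.stride(0) == 1
        # Scales are per-tensor or per-token / per-output-channel, and contiguous.
        assert a_scales.numel() in (1, a.shape[0])
        assert b_scales.numel() in (1, b.shape[1])
        assert a_scales.is_contiguous() and b_scales.is_contiguous()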