diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2051d7560be25..35846fd1cfa99 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -173,6 +173,16 @@ set(VLLM_EXT_SRC
   "csrc/pybind.cpp")
 
 if(VLLM_GPU_LANG STREQUAL "CUDA")
+  include(FetchContent)
+  set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
+  FetchContent_Declare(
+        cutlass
+        GIT_REPOSITORY https://github.com/nvidia/cutlass.git
+        # CUTLASS 3.5.0
+        GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc
+  )
+  FetchContent_MakeAvailable(cutlass)
+
   list(APPEND VLLM_EXT_SRC
     "csrc/quantization/aqlm/gemm_kernels.cu"
     "csrc/quantization/awq/gemm_kernels.cu"
@@ -180,7 +190,21 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
-    "csrc/custom_all_reduce.cu")
+    "csrc/custom_all_reduce.cu"
+    "csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu"
+    "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu"
+    "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu")
+
+  #
+  # The CUTLASS kernels for Hopper require sm90a to be enabled.
+  # This is done via the gencode option below, but that creates kernels for both sm90 and sm90a.
+  # That adds an extra 17MB to the compiled binary, so instead we enable it selectively.
+  set_source_files_properties(
+        "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
+        PROPERTIES
+        COMPILE_FLAGS
+        "-gencode arch=compute_90a,code=sm_90a")
+
 endif()
 
 define_gpu_extension_target(
@@ -190,6 +214,7 @@ define_gpu_extension_target(
   SOURCES ${VLLM_EXT_SRC}
   COMPILE_FLAGS ${VLLM_GPU_FLAGS}
   ARCHITECTURES ${VLLM_GPU_ARCHES}
+  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
   WITH_SOABI)
 
 #
diff --git a/csrc/ops.h b/csrc/ops.h
index ef37131c962f8..8c2c2ae6e1f5a 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -155,6 +155,14 @@ torch::Tensor gptq_marlin_repack(
   int64_t size_k,
   int64_t size_n,
   int64_t num_bits);
+
+void cutlass_scaled_mm_dq(
+  torch::Tensor& out,
+  torch::Tensor const &a,
+  torch::Tensor const &b,
+  torch::Tensor const &a_scales,
+  torch::Tensor const &b_scales);
+
 #endif
 
 void squeezellm_gemm(
diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp
index 0339eba70c013..f5b4865506568 100644
--- a/csrc/pybind.cpp
+++ b/csrc/pybind.cpp
@@ -71,6 +71,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ");
   ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ");
   ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
+  ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column quantization.");
 #endif
 
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp
new file mode 100644
index 0000000000000..999b7b251ab33
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/common.hpp
@@ -0,0 +1,12 @@
+#pragma once
+
+#include "cutlass/cutlass.h"
+
+/**
+ * Helper macro for checking CUTLASS errors
+ */
+#define CUTLASS_CHECK(status)                        \
+  {                                                  \
+    TORCH_CHECK(status == cutlass::Status::kSuccess, \
+                cutlassGetStatusString(status))      \
+  }
diff --git a/csrc/quantization/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp b/csrc/quantization/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp
new file mode 100644
index 0000000000000..ddbee15e54ab6
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp
@@ -0,0 +1,340 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
+ * reserved. SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *    this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ *    contributors may be used to endorse or promote products derived from
+ *    this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+//
+// This file is a modified excerpt of
+// include/cutlass/epilogue/fusion/visitor_load.hpp from
+// https://github.com/NVIDIA/cutlass. It has been modified to support either
+// row/column or scalar broadcasting, as is already supported in CUTLASS 3.x.
+// This is important because it reduces the number of kernels that must be
+// compiled by a factor of 4.
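+//
+// Without this, a separate kernel instantiation would be needed for each of the
+// four combinations {row, scalar} scale for B x {column, scalar} scale for A.
+// Here a single visitor checks at runtime whether ptr_row / ptr_col is null and,
+// if so, broadcasts null_default as a scalar instead of loading a vector.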
+// +#pragma once + +// clang-format off + +#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" +#include "cute/tensor.hpp" + +// clang-format on + +namespace cutlass::epilogue::threadblock { + +using namespace cute; +using namespace detail; + +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrScalarBroadcast { + + struct Arguments { + Element const* ptr_row = nullptr; + Element null_default = Element(0); + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->ptr_row) { + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load(dst_v(i), (void const*)&src_v(i), guard); + } + } else { + // In this case we are loading from a scalar and broadcasting + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = params_ptr->null_default; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if(get<1>(coord_v(i)) < n) + { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + 
decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + class ThreadMap, + class Element, + class StrideMNL = Stride<_1,_0,_0> +> +struct VisitorColOrScalarBroadcast { + + struct Arguments { + Element const* ptr_col = nullptr; + Element null_default = Element(0); + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage { }; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gCol, + RTensor&& tC_rCol, + CTensor&& tC_cCol, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gCol(cute::forward(tC_gCol)), + tC_rCol(cute::forward(tC_rCol)), + tC_cCol(cute::forward(tC_cCol)), + m(get<0>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gCol; + RTensor tC_rCol; + CTensor tC_cCol; + Params const* params_ptr; + int m; + + // This function is modified from VisitorColBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rCol); + + Tensor pred = make_tensor(shape(tC_gCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tC_cCol(i)) < m; + } + + if (params_ptr->ptr_col) { + // In this case we are loading from a column vector and broadcasting + copy_if(pred, tC_gCol, tC_rCol); + } else { + // In this case we are loading from a scalar and broadcasting + auto dst_v = filter(tC_rCol); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(dst_v); ++i) { + if(pred(i)){ + dst_v(i) = params_ptr->null_default; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Array frg_col; + frg_col.fill(tC_rCol(row_idx,iter_idx)); + return frg_col; + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mCol = make_tensor( + make_gmem_ptr(params_ptr->ptr_col), + problem_shape, + params_ptr->dCol); + + // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER + Tensor tC_gCol = group_modes<1,4>( + ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + Tensor tC_rCol = make_tensor_like(tC_gCol); + + // Generate the pred tensor + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tC_cCol = group_modes<1,4>( + ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + + return Callbacks< + decltype(tC_gCol), decltype(tC_rCol), + decltype(tC_cCol), ProblemShape>( + cute::move(tC_gCol), + 
cute::move(tC_rCol), + cute::move(tC_cCol), + problem_shape, + params_ptr + ); + } +}; + +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu new file mode 100644 index 0000000000000..3ec454f78c654 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu @@ -0,0 +1,296 @@ +#include +#include + +// clang-format will break include orders +// clang-format off +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/util/device_memory.h" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm_coord.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" +#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" + +#include "cutlass_visitor_2x_broadcast_epilogue.hpp" +#include "common.hpp" +// clang-format on + +using namespace cute; + +/* + This defines a quantized GEMM operation with dequantized output, similar to + torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for + NVIDIA GPUs with SM versions prior to sm90 (Hopper). + + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. 
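+
+  As a concrete reference, with A of shape (m, k), B of shape (k, n), per-row
+  a_scales of shape (m, 1) and per-column b_scales of shape (1, n), the result is
+
+    D[i, j] = a_scales[i] * b_scales[j] * sum_k A[i, k] * B[k, j]
+
+  and a per-tensor scale is simply the (1, 1) case of the same broadcast.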
+*/ + +namespace { + +template +struct cutlass_2x_gemm { + using ElementAB = ElementAB_; + using ElementD = ElementD_; + + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + + using Operator = + typename std::conditional, + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type; + + using OutputTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + TileShape, WarpShape, float, 4, 1 /* epilogue stages */ + >; + + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< + OutputTileThreadMap, float, Stride, Int<0>, Int<0>>>; + + using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< + OutputTileThreadMap, float, Stride, Int<1>, Int<0>>>; + + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute1 = + cutlass::epilogue::threadblock::Sm80EVT; + + using D = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest, + Stride, Int<0>>>; + + using EVTD = cutlass::epilogue::threadblock::Sm80EVT; + + // clang-format off + using RowMajor = typename cutlass::layout::RowMajor; + using ColumnMajor = typename cutlass::layout::ColumnMajor; + using KernelType = + typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16, + ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16, + float, cutlass::layout::RowMajor, 4, + ElementAcc, float, cutlass::arch::OpClassTensorOp, + Arch, + TileShape, WarpShape, InstructionShape, + EVTD, + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, + MainLoopStages, Operator, + 1 /* epilogue stages */ + >::GemmKernel; + // clang-format on + + using Op = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + cutlass::gemm::GemmCoord problem_size{m, n, k}; + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideC = Stride, Int<0>>; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto c_ptr = static_cast(out.data_ptr()); + + auto a_scales_ptr = a_scales.data_ptr(); + auto b_scales_ptr = b_scales.data_ptr(); + + // If A and B are quantized per-tensor, then these scale tensors are scalars, + // and they are passed in via the second argument. + using ScaleAArgs = typename Gemm::ScaleA::Arguments; + ScaleAArgs a_args = a_scales.numel() == 1 + ? ScaleAArgs{nullptr, a_scales.item(), {}} + : ScaleAArgs{a_scales.data_ptr(), {}, {}}; + + using ScaleBArgs = typename Gemm::ScaleB::Arguments; + ScaleBArgs b_args = b_scales.numel() == 1 + ? 
ScaleBArgs{nullptr, b_scales.item(), {}} + : ScaleBArgs{b_scales.data_ptr(), {}, {}}; + + typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args}; + + typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args, + evt0_compute_args}; + typename Gemm::D::Arguments d_args{c_ptr, c_stride}; + + typename Gemm::EVTD::Arguments epilogue_args{ + evt1_compute_args, + d_args, + }; + + typename Gemm::Op::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode + problem_size, // problem size + 1, // batch count + epilogue_args, + a_ptr, + b_ptr, + nullptr, + nullptr, + 0, + 0, + 0, + 0, + lda, + ldb, + ldc, + ldc}; + + // Launch the CUTLASS GEMM kernel. + typename Gemm::Op gemm_op; + size_t workspace_size = gemm_op.get_workspace_size(args); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm_op.can_implement(args)); + cutlass::Status status = gemm_op(args, workspace.get()); + CUTLASS_CHECK(status); +} + +} // namespace + +void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(b.dtype() == torch::kInt8); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>; + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>(out, a, b, a_scales, + b_scales); + } +} + +void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(b.dtype() == torch::kInt8); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>(out, a, b, a_scales, + b_scales); + } +} + +void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + if (a.dtype() == torch::kInt8) { + TORCH_CHECK(b.dtype() == torch::kInt8); + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } else { + assert(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher< + 
cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } + } else { + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher>(out, a, b, a_scales, + b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher>(out, a, b, a_scales, + b_scales); + } + } +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu new file mode 100644 index 0000000000000..37b096de23e3b --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu @@ -0,0 +1,240 @@ +#include + +#include +#include +#include + +// clang-format will break include orders +// clang-format off +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "common.hpp" +// clang-format on + +using namespace cute; + +/* + This defines a quantized GEMM operation with dequantized output, similar to + torch._scaled_mm. It is defined using the CUTLASS 3.x API, and is used for + NVIDIA GPUs with sm90a (Hopper) or later. + + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. 
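+
+  When a_scales or b_scales holds a single element (per-tensor quantization),
+  the dispatcher below passes a null pointer plus that scalar value to the
+  Sm90ColBroadcast / Sm90RowBroadcast epilogue nodes, which already support
+  falling back to a scalar broadcast in CUTLASS 3.x.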
+*/ + +namespace { + +template +struct cutlass_3x_gemm { + using ElementAB = ElementAB_; + using ElementD = ElementD_; + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + + using EpilogueDescriptor = + cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, + ElementD, EpilogueSchedule>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, + Stride, Int<0>, Int<0>>>; + + using ScaleBDescriptor = + cutlass::epilogue::collective::detail::RowBroadcastDescriptor< + EpilogueDescriptor, float>; + + using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast< + ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape, + typename ScaleBDescriptor::Element, Stride, Int<1>, Int<0>>>; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute1 = + cutlass::epilogue::fusion::Sm90EVT; + + using StrideD = Stride, Int<0>>; + using ElementC = void; + using StrideC = StrideD; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, + EpilogueSchedule, EVTCompute1>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + // clang-format off + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementAB, cutlass::layout::RowMajor, 16, + ElementAB, cutlass::layout::ColumnMajor, 16, + ElementAcc, TileShape, ClusterShape, + Stages, + KernelSchedule>::CollectiveOp; + // clang-format on + + using KernelType = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, CollectiveMainloop, CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>; + + struct GemmKernel : public KernelType {}; +}; + +template +void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = Stride, Int<0>>; + using StrideB = Stride, Int<0>>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, Int<0>{}}; + StrideB b_stride{ldb, Int<1>{}, Int<0>{}}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + using GemmKernel = typename Gemm::GemmKernel; + typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, + b_stride}; + + auto c_ptr = 
static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + + typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + prob_shape, mainloop_args, epilogue_args}; + + using ScaleA_Args = typename Gemm::ScaleA::Arguments; + using ScaleB_Args = typename Gemm::ScaleB::Arguments; + ScaleA_Args a_args = a_scales.numel() == 1 + ? ScaleA_Args{nullptr, a_scales.item(), {}} + : ScaleA_Args{a_scales.data_ptr(), {}, {}}; + + ScaleB_Args b_args = b_scales.numel() == 1 + ? ScaleB_Args{nullptr, b_scales.item(), {}} + : ScaleB_Args{b_scales.data_ptr(), {}, {}}; + + args.epilogue.thread = {a_args, {b_args}}; + + // Launch the CUTLASS GEMM kernel. + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + TORCH_CHECK(workspace_size == 0); + + cutlass::Status status = gemm_op.run(args); + CUTLASS_CHECK(status); +} +} // namespace + +void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + if (a.dtype() == torch::kInt8) { + TORCH_CHECK(b.dtype() == torch::kInt8); + + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = + typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } + } else { + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = + typename cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } + } +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu new file mode 100644 index 0000000000000..a4e696d4a3322 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu @@ -0,0 +1,65 @@ +#include +#include +#include + +void cutlass_scaled_mm_dq_sm75(torch::Tensor &c, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); + +void cutlass_scaled_mm_dq_sm80(torch::Tensor &c, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); + +void cutlass_scaled_mm_dq_sm89(torch::Tensor &c, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); + +void cutlass_scaled_mm_dq_sm90(torch::Tensor &c, torch::Tensor const &a, + torch::Tensor const &b, + 
torch::Tensor const &a_scales,
+                               torch::Tensor const &b_scales);
+
+void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a,
+                          torch::Tensor const &b, torch::Tensor const &a_scales,
+                          torch::Tensor const &b_scales) {
+  int32_t major_capability;
+  int32_t minor_capability;
+  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
+                         0);
+  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
+                         0);
+  int32_t version_num = major_capability * 10 + minor_capability;
+
+  // Checks for conformality
+  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
+  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
+              b.size(1) == c.size(1));
+  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
+  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
+
+  // Check for strides and alignment
+  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
+  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
+  TORCH_CHECK(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0);  // 16-byte alignment
+  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+
+  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
+
+  if (version_num >= 90) {
+    // Hopper
+    cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales);
+  } else if (version_num == 89) {
+    // Ada Lovelace
+    cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales);
+  } else if (version_num >= 80) {
+    // Ampere
+    cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
+  } else {
+    // Turing
+    TORCH_CHECK(version_num >= 75);
+    cutlass_scaled_mm_dq_sm75(c, a, b, a_scales, b_scales);
+  }
+}
diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py
new file mode 100644
index 0000000000000..fdfd1dee29ce6
--- /dev/null
+++ b/tests/kernels/test_cutlass.py
@@ -0,0 +1,192 @@
+"""Tests for CUTLASS kernels.
+
+Run `pytest tests/kernels/test_cutlass.py`.
+"""
+from typing import Type
+
+import pytest
+import torch
+
+from vllm import _custom_ops as ops
+
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]
+
+capability = torch.cuda.get_device_capability()
+capability = capability[0] * 10 + capability[1]
+
+
+def to_fp8(tensor: torch.Tensor):
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    return torch.round(tensor.clamp(
+        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
+
+
+def to_int8(tensor: torch.Tensor):
+    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
+
+
+def cutlass_fp8_gemm_helper(m: int,
+                            n: int,
+                            k: int,
+                            per_token_act_quant: bool,
+                            per_out_channel_weight_quant: bool,
+                            out_dtype: Type[torch.dtype] = torch.bfloat16,
+                            device: str = "cuda"):
+    # Test for a cutlass kernel with per-token activation quantization
+    # and per-output channel weight quantization.
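+    # a is (m, k) fp8 and row-major; b is created as (n, k) and transposed to
+    # (k, n) so that it is column-major, as the kernel requires.
+    # scale_a is (m, 1) for per-token quantization or (1, 1) for per-tensor;
+    # scale_b is (1, n) for per-output-channel quantization or (1, 1) for per-tensor.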
+ a = to_fp8(torch.randn((m, k), device=device)) + b = to_fp8(torch.randn((n, k), device=device).t()) + + m_a_scales = m if per_token_act_quant else 1 + n_b_scales = n if per_out_channel_weight_quant else 1 + + scale_a = (torch.randn( + (m_a_scales, 1), device=device, dtype=torch.float32) / 10) + scale_b = (torch.randn( + (1, n_b_scales), device=device, dtype=torch.float32) / 10) + + out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype) + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * b.to(dtype=torch.float32)).to(out_dtype) + + assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1) + + +def cutlass_int8_gemm_helper(m: int, + n: int, + k: int, + per_token_act_quant: bool, + per_out_channel_weight_quant: bool, + out_dtype: Type[torch.dtype] = torch.bfloat16, + device: str = "cuda"): + # Test for a cutlass kernel with per-token activation quantization + # and per-output channel weight quantization. + a = to_int8(torch.randn((m, k), device=device) * 5) + b = to_int8(torch.randn((n, k), device=device).t() * 5) + + m_a_scales = m if per_token_act_quant else 1 + n_b_scales = n if per_out_channel_weight_quant else 1 + + scale_a = (torch.randn( + (m_a_scales, 1), device=device, dtype=torch.float32) / 10) + scale_b = (torch.randn( + (1, n_b_scales), device=device, dtype=torch.float32) / 10) + + out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype) + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * + b.to(dtype=torch.float32)).to(dtype=out_dtype) + + assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) + + +@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 496, 1024]) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool, + per_out_ch: bool): + cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch) + + +@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 496, 1024]) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool, + per_out_ch: bool): + cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) +def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, + out_dtype: Type[torch.dtype]): + cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + out_dtype) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, + out_dtype: Type[torch.dtype]): + cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + out_dtype) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) 
+@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool, + device: str): + cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + torch.bfloat16, device) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool, + device: str): + cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + torch.bfloat16, device) + + +# For the following two tests: +# N and K correspond to the size of the weight matrix and likely to be multiples +# of a large power of two. In any case, the kernel will have a naive fallback +# when N and K are not divisible by 16. But M is the number of tokens and the +# kernel must handle any M thrown at it. +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool): + for nk in range(32, 128, 32): + for m in range(1, 128): + cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool): + for nk in range(32, 128, 32): + for m in range(1, 128): + cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch) + + +# Test working with a subset of A and B +def test_cutlass_subset(): + big_m, big_n, big_k = 1024, 1024, 1024 + m, n, k = 512, 512, 512 + + whole_a = to_int8(torch.randn((big_m, big_k), device="cuda") * 5) + whole_b = to_int8(torch.randn((big_n, big_k), device="cuda").t() * 5) + a = whole_a[0:m, 0:k] + b = whole_b[0:k, 0:n] + + scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 + scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 + + out = ops.cutlass_scaled_mm_dq(a, + b, + scale_a, + scale_b, + out_dtype=torch.bfloat16) + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * + b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) + + assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 95baa84262658..9e7d0d96bf004 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional, Tuple, Type import torch @@ -163,6 +163,22 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, size_k) +# cutlass +def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, + a_scales: torch.Tensor, b_scales: torch.Tensor, + out_dtype: Type[torch.dtype]) -> torch.Tensor: + assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales) + + return out + + # aqlm def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor,