Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
d41cf57
add flash attention interface
jikunshang Jul 31, 2025
ce9f31d
update interface
jikunshang Aug 1, 2025
fb6784f
add cutlass deps (#1)
jikunshang Aug 4, 2025
ce27fa2
add chunk_prefill step<1>
YizhouZ Aug 7, 2025
ed0f846
fix register
YizhouZ Aug 7, 2025
b02a5a8
fix cmake
YizhouZ Aug 7, 2025
a4a76ee
debug msg
YizhouZ Aug 8, 2025
ee1b719
functional ready
YizhouZ Aug 11, 2025
4ef938f
dev base
Liangliang-Ma Aug 21, 2025
480c72f
base of grouped_gemm_fp8
Liangliang-Ma Aug 22, 2025
24709b8
update func
Liangliang-Ma Aug 26, 2025
f5757a9
add test
Liangliang-Ma Aug 29, 2025
435e6df
update functor
Liangliang-Ma Aug 29, 2025
f76fb97
update grouped_gemm
Liangliang-Ma Aug 30, 2025
9408e94
build ready
Liangliang-Ma Aug 31, 2025
439cf3c
base integration done
Liangliang-Ma Sep 1, 2025
48abd9f
grouped gemm base ready
Liangliang-Ma Sep 2, 2025
67eeb47
gemm2 use cutlass grouped_mm
Liangliang-Ma Sep 3, 2025
a62752f
gemm1 use cutlass group_mm
Liangliang-Ma Sep 3, 2025
cfb724b
rm flash_attn in this pr
Liangliang-Ma Sep 4, 2025
f7518e0
rebase CMakeLists
Liangliang-Ma Sep 4, 2025
083bde5
use main Cmakes
Liangliang-Ma Sep 4, 2025
48a4808
use main setup
Liangliang-Ma Sep 4, 2025
22d1ade
mv utils
Liangliang-Ma Sep 4, 2025
c0e70c4
Merge branch 'main' into grouped_gemm_cutlass
Liangliang-Ma Sep 4, 2025
1c7f46d
finish rebase
Liangliang-Ma Sep 4, 2025
df0b915
add profile and change to col-maj
Liangliang-Ma Sep 5, 2025
76fe4bc
dont not reserve block_C
Liangliang-Ma Sep 9, 2025
ad0fdd6
remove redundant allocation
Liangliang-Ma Sep 11, 2025
54e64a7
e2e debug
Liangliang-Ma Sep 11, 2025
3c40008
add release func
Liangliang-Ma Sep 11, 2025
985004d
gemm args allocate once
Liangliang-Ma Sep 11, 2025
9c18092
hidden_states copy
Liangliang-Ma Sep 11, 2025
a47ecef
output bf16
Liangliang-Ma Sep 14, 2025
1a2d655
use static tensor buffer
Liangliang-Ma Sep 14, 2025
f7dee65
remove ptr_C
Liangliang-Ma Sep 15, 2025
ad2dc48
fix device lost
Liangliang-Ma Sep 17, 2025
56cb570
acc and oom fixed
Liangliang-Ma Sep 19, 2025
81555ab
Fix acc and oom issue
Liangliang-Ma Sep 19, 2025
d1edf17
base
Liangliang-Ma Sep 23, 2025
55f36a8
update CMakeLists
Liangliang-Ma Sep 23, 2025
54e7219
Merge branch 'main' into grouped_gemm_cutlass
Liangliang-Ma Sep 23, 2025
513377a
refactor csrc of cutlass
Liangliang-Ma Sep 23, 2025
534c7c3
put src in vllm
Liangliang-Ma Sep 23, 2025
1fc6959
add adapter src
Liangliang-Ma Sep 23, 2025
db6b292
clean up
Liangliang-Ma Sep 24, 2025
d651d9d
add test
Liangliang-Ma Sep 24, 2025
a29cfa6
clean up
Liangliang-Ma Sep 24, 2025
c66f152
fix format
Liangliang-Ma Sep 24, 2025
a681e73
fix format f841
Liangliang-Ma Sep 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -171,20 +171,21 @@ if(VLLM_GPU_LANG STREQUAL "SYCL")
set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library")

# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
set(CUTLASS_REVISION "main" CACHE STRING "CUTLASS revision to use")
set(CUTLASS_REVISION "9baca2cff3a28590fcd03e55515e2d91ff2cbc8b" CACHE STRING "CUTLASS revision to use")

# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
FetchContent_Declare(
cutlass-sycl
GIT_REPOSITORY https://github.com/intel/cutlass-sycl

# Please keep this in sync with CUTLASS_REVISION line above.
GIT_TAG ${CUTLASS_REVISION}
GIT_PROGRESS TRUE

# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
GIT_SHALLOW TRUE
GIT_SHALLOW FALSE
)

# cutlass compilation flags
Expand All @@ -196,7 +197,6 @@ if(VLLM_GPU_LANG STREQUAL "SYCL")
set(CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT OFF CACHE BOOL "DISABLE CUDA")
# list(APPEND CMAKE_CXX_FLAGS "-ftemplate-backtrace-limit=0 " )
# list(APPEND CMAKE_CXX_FLAGS "-fdiagnostics-color=always " )


FetchContent_MakeAvailable(cutlass-sycl)
set(CUTLASS_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/include CACHE PATH "CUTLASS Header Library")
Expand Down Expand Up @@ -269,11 +269,15 @@ endif ()
#
# xpu only ops/kernels, implemented with cutlass/onednn/sycl.
#
file(GLOB CUTLASS_BACKEND_SRCS
csrc/xpu/cutlass_kernels/*.cpp
)
if(VLLM_GPU_LANG STREQUAL "SYCL")
set(VLLM_EXT_XPU_SRC
"csrc/xpu/torch_bindings.cpp"
"csrc/xpu/lora/lora_shrink.cpp"
"csrc/xpu/lora/lora_expand.cpp"
${CUTLASS_BACKEND_SRCS}
)
include_directories("/usr/include")
set(CMPLR_ROOT $ENV{CMPLR_ROOT})
Expand All @@ -282,6 +286,12 @@ if(VLLM_GPU_LANG STREQUAL "SYCL")
list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" )
list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64")
list(APPEND VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "torch" )
# CUTLASS FLAGS
list(APPEND VLLM_GPU_FLAGS "-O3" "-DNDEBUG")
list(APPEND VLLM_GPU_FLAGS "-gline-tables-only")
list(APPEND VLLM_GPU_FLAGS "-fsycl" "-fsycl-targets=spir64_gen" "-ftemplate-backtrace-limit=10")
list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64_gen")
list(APPEND VLLM_GPU_LINK_FLAGS -Xsycl-target-backend=spir64_gen "-device bmg-g21-a0 -internal_options -cl-intel-256-GRF-per-thread")
endif()

if(ONEDNN_FOUND)
Expand All @@ -305,6 +315,8 @@ define_gpu_extension_target(
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)

Expand Down
1 change: 0 additions & 1 deletion csrc/core/registration.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#pragma once

#include <Python.h>

#define _CONCAT(A, B) A##B
Expand Down
306 changes: 306 additions & 0 deletions csrc/xpu/cutlass_kernels/collective/gemm/default_gemm_universal.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/

/*! \file
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix
multiply-add with the appropriate threadblock-scoped epilogue.

Note, CUTLASS epilogues universally target row-major outputs. Column-major
outputs are accommodated by exchanging A and B operands and assuming
transposed layouts. Partial specializations here choose
'device::GemmTransposed' to implement this functionality.

*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/complex.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"

#include "gemm_universal_k.h"
#include "cutlass/gemm/kernel/gemm_universal_streamk.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"

#include "cutlass/layout/permute.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Primary template of the kernel-level universal GEMM selector.
///
/// Only declared here, never defined: a usable definition comes from one of
/// the partial specializations in this file, which are selected through the
/// trailing `Enable` parameter depending on whether `ElementAccumulator` is a
/// complex type (cutlass::is_complex). The defaults below (no shared-memory
/// clear, no gather/scatter, no permutation) apply to every specialization.
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Gather operand A by using an index array
bool GatherA = false,
/// Gather operand B by using an index array
bool GatherB = false,
/// Scatter result D by using an index array
bool ScatterD = false,
/// Permute result D
typename PermuteDLayout = layout::NoPermute,
/// Permute operand A
typename PermuteALayout_ = layout::NoPermute,
/// Permute operand B
typename PermuteBLayout_ = layout::NoPermute,
/// SFINAE hook: callers leave this at `void`; the partial specializations
/// fill it with a platform::enable_if<...>::type on ElementAccumulator.
typename Enable = void>
struct DefaultGemmUniversal;

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Real-valued GEMM kernels
//

/// Partial specialization for real-valued GEMM.
///
/// Matches when both complex transforms are ComplexTransform::kNone and
/// `ElementAccumulator` is not a complex type. It builds the threadblock-level
/// Mma and Epilogue through kernel::DefaultGemm and exposes the final kernel
/// type as `GemmKernel`.
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Layout type for C and D matrix operands
typename LayoutC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Gather operand A by using an index array
bool GatherA,
/// Gather operand B by using an index array
bool GatherB,
/// Scatter result D by using an index array
bool ScatterD,
/// Permute result D
typename PermuteDLayout,
/// Permute operand A
typename PermuteALayout,
/// Permute operand B
typename PermuteBLayout>
struct DefaultGemmUniversal<
ElementA, LayoutA,
ComplexTransform::kNone, // transform A
kAlignmentA, ElementB, LayoutB,
ComplexTransform::kNone, // transform B
kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag,
ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
ThreadblockSwizzle, Stages, Operator, SharedMemoryClear, GatherA, GatherB,
ScatterD, PermuteDLayout, PermuteALayout, PermuteBLayout,
typename platform::enable_if<
!cutlass::is_complex<ElementAccumulator>::value>::type> {
/// Underlying non-universal GEMM kernel; only its Mma and Epilogue member
/// types are consumed below.
/// NOTE(review): the literal `true` passed after Stages is presumably
/// DefaultGemm's SplitKSerial flag -- confirm against
/// cutlass/gemm/kernel/default_gemm.h.
using DefaultGemmKernel = typename kernel::DefaultGemm<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages,
true, Operator, SharedMemoryClear, GatherA, GatherB, ScatterD,
PermuteDLayout, PermuteALayout, PermuteBLayout>::GemmKernel;

/// Universal kernel without StreamkFeature member type.
/// Fallback used whenever the specialization below does not match.
/// NOTE(review): kernel::GemmUniversal is expected to be provided by the
/// local "gemm_universal_k.h" include of this file -- confirm.
template <class SwizzleT, class Enable = void>
class SelectBase
: public kernel::GemmUniversal<typename DefaultGemmKernel::Mma,
typename DefaultGemmKernel::Epilogue,
SwizzleT> {};

/// Universal kernel with StreamkFeature member type.
/// Because `Enable` defaults to void, this specialization is picked only when
/// `SwizzleT::StreamkFeature` exists and names `void`; if the member type is
/// missing, substitution fails and the primary SelectBase is used instead.
template <class SwizzleT>
class SelectBase<SwizzleT, typename SwizzleT::StreamkFeature>
: public kernel::GemmUniversalStreamk<
typename DefaultGemmKernel::Mma,
typename DefaultGemmKernel::Epilogue, SwizzleT> {};

/// Select kernel by ThreadblockSwizzle's support for StreamkFeature
using GemmKernel = SelectBase<ThreadblockSwizzle>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

//
// Complex-valued GEMM kernels
//

/// Partial specialization for complex-valued GEMM.
///
/// Matches when `ElementAccumulator` is a complex type AND the gather/scatter
/// flags are all false and the three permute layouts are layout::NoPermute
/// (these are fixed in the specialization pattern below, so any other
/// combination finds no definition). It builds the Mma and Epilogue through
/// kernel::DefaultGemmComplex and exposes the final kernel type as
/// `GemmKernel`.
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Layout type for C and D matrix operands
typename LayoutC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear>
struct DefaultGemmUniversal<
ElementA, LayoutA, TransformA, kAlignmentA, ElementB, LayoutB, TransformB,
kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag,
ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
ThreadblockSwizzle, Stages, Operator, SharedMemoryClear, false, false,
false, layout::NoPermute, layout::NoPermute, layout::NoPermute,
typename platform::enable_if<
cutlass::is_complex<ElementAccumulator>::value>::type> {
/// Underlying complex GEMM kernel; only its Mma and Epilogue member types are
/// consumed below. kAlignmentA, kAlignmentB and SharedMemoryClear take part in
/// matching this specialization but are not forwarded to DefaultGemmComplex.
/// NOTE(review): the trailing literal `false` is presumably
/// DefaultGemmComplex's SplitKSerial flag -- confirm against
/// cutlass/gemm/kernel/default_gemm_complex.h.
using DefaultGemmKernel = typename kernel::DefaultGemmComplex<
ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape,
InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages,
TransformA, TransformB, Operator, false>::GemmKernel;

/// Universal kernel without StreamkFeature member type.
/// Fallback used whenever the specialization below does not match.
template <class SwizzleT, class Enable = void>
class SelectBase
: public kernel::GemmUniversal<typename DefaultGemmKernel::Mma,
typename DefaultGemmKernel::Epilogue,
SwizzleT> {};

/// Universal kernel with StreamkFeature member type.
/// Because `Enable` defaults to void, this specialization is picked only when
/// `SwizzleT::StreamkFeature` exists and names `void`; if the member type is
/// missing, substitution fails and the primary SelectBase is used instead.
template <class SwizzleT>
class SelectBase<SwizzleT, typename SwizzleT::StreamkFeature>
: public kernel::GemmUniversalStreamk<
typename DefaultGemmKernel::Mma,
typename DefaultGemmKernel::Epilogue, SwizzleT> {};

/// Select kernel by ThreadblockSwizzle's support for StreamkFeature
using GemmKernel = SelectBase<ThreadblockSwizzle>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
Loading