Skip to content

Commit

Permalink
misc: refactor cutlass includes (flashinfer-ai#594)
Browse files Browse the repository at this point in the history
Move the cutlass utilities out of the gemm folder, as we are about to use
cutlass for attention kernels as well.

This PR also bumps cutlass version to v3.6.0, where the cutlass include
order issue (flashinfer-ai#589) was
resolved (it was fixed by
https://github.com/NVIDIA/cutlass/blob/d656afbd2a01112c0e4d90aafe0f8f78145c6585/include/cutlass/epilogue/collective/collective_builder.hpp#L106
).
  • Loading branch information
yzh119 authored Nov 8, 2024
1 parent ce44799 commit 3467617
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 96 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/cutlass
Submodule cutlass updated 366 files
75 changes: 75 additions & 0 deletions include/flashinfer/cutlass_utils.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_CUTLASS_UTILS_CUH_
#define FLASHINFER_CUTLASS_UTILS_CUH_

#include <cuda_runtime.h>
#include <cutlass/cutlass.h>

#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

namespace flashinfer {

/*!
 * \brief Type trait mapping a CUDA scalar type to its CUTLASS equivalent.
 *
 * The primary template is the identity mapping; specializations below cover
 * the CUDA narrow-precision types (fp16, bf16, fp8-e4m3, fp8-e5m2) that have
 * distinct CUTLASS wrapper types.
 *
 * NOTE(review): the member is a *type* alias despite being named `value`
 * (std convention would call it `type`); the name is kept as-is because call
 * sites already spell `typename cutlass_dtype<T>::value`.
 */
template <typename T>
struct cutlass_dtype {
  // Default: the type already is its own CUTLASS representation
  // (e.g. float, double, int).
  using value = T;
};

//! fp16: CUDA `half` -> `cutlass::half_t`.
template <>
struct cutlass_dtype<half> {
  using value = cutlass::half_t;
};

//! bf16: CUDA `nv_bfloat16` -> `cutlass::bfloat16_t`.
template <>
struct cutlass_dtype<nv_bfloat16> {
  using value = cutlass::bfloat16_t;
};

//! fp8 (4-bit exponent, 3-bit mantissa): `__nv_fp8_e4m3` -> `cutlass::float_e4m3_t`.
template <>
struct cutlass_dtype<__nv_fp8_e4m3> {
  using value = cutlass::float_e4m3_t;
};

//! fp8 (5-bit exponent, 2-bit mantissa): `__nv_fp8_e5m2` -> `cutlass::float_e5m2_t`.
template <>
struct cutlass_dtype<__nv_fp8_e5m2> {
  using value = cutlass::float_e5m2_t;
};

/*!
 * \brief Convenience alias template, analogous to the `std::..._t` helpers.
 *
 * Lets call sites write `cutlass_dtype_t<T>` instead of
 * `typename cutlass_dtype<T>::value`. Purely additive and
 * backward-compatible: existing spellings keep working.
 */
template <typename T>
using cutlass_dtype_t = typename cutlass_dtype<T>::value;

} // namespace flashinfer

#endif // FLASHINFER_CUTLASS_UTILS_CUH_
2 changes: 1 addition & 1 deletion include/flashinfer/gemm/group_gemm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include <sstream>

#include "../allocator.h"
#include "group_gemm_cutlass.cuh"
#include "../cutlass_utils.cuh"

namespace flashinfer {

Expand Down
63 changes: 0 additions & 63 deletions include/flashinfer/gemm/group_gemm_cutlass.cuh

This file was deleted.

26 changes: 1 addition & 25 deletions include/flashinfer/gemm/group_gemm_sm90.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,11 @@
#ifndef FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_
#define FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_

// clang-format off
// NOTE: This header needs to be included before cutlass headers.
// See: https://github.com/NVIDIA/cutlass/issues/1827
#include "group_gemm_cutlass.cuh"
// clang-format on

#include <sstream>

#include "../allocator.h"
#include "../cutlass_utils.cuh"
#include "../utils.cuh"
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

namespace flashinfer {

Expand Down
2 changes: 1 addition & 1 deletion python/csrc/group_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ void CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor all_proble
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_x_data.scalar_type(), c_type, [&] {
using cutlass_t = typename cutlass_dtype<c_type>::type;
using cutlass_t = typename cutlass_dtype<c_type>::value;
auto status = CutlassSegmentGEMMRun<cutlass_t>(
workspace_buffer.data_ptr(), workspace_buffer.element_size() * workspace_buffer.size(0),
all_problems.data_ptr(), batch_size, x_ptr.data_ptr(), w_ptr.data_ptr(), y_ptr.data_ptr(),
Expand Down
2 changes: 1 addition & 1 deletion python/csrc/group_gemm_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer,
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());

DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_x_data.scalar_type(), c_type, [&] {
using cutlass_t = typename cutlass_dtype<c_type>::type;
using cutlass_t = typename cutlass_dtype<c_type>::value;
auto status = CutlassSegmentGEMMSM90Run<cutlass_t, cutlass_t>(
float_workspace_buffer.data_ptr(),
float_workspace_buffer.element_size() * float_workspace_buffer.size(0),
Expand Down
8 changes: 4 additions & 4 deletions tests/test_group_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_segment_gemm(


if __name__ == "__main__":
test_segment_gemm(199, 99, 128, 1024, False, False)
test_segment_gemm(199, 99, 128, 1024, False, True)
test_segment_gemm(199, 99, 128, 1024, True, False)
test_segment_gemm(199, 99, 128, 1024, True, True)
test_segment_gemm(199, 17, 128, 1024, False, False, torch.float16, "cuda:0")
test_segment_gemm(199, 17, 128, 1024, False, True, torch.float16, "cuda:0")
test_segment_gemm(199, 17, 128, 1024, True, False, torch.float16, "cuda:0")
test_segment_gemm(199, 17, 128, 1024, True, True, torch.float16, "cuda:0")

0 comments on commit 3467617

Please sign in to comment.