From 436b17988e373ea68ed85919e2c3d919b3b8af2e Mon Sep 17 00:00:00 2001 From: taozha2 Date: Thu, 19 Dec 2024 16:30:28 +0800 Subject: [PATCH 1/2] enable collective column major gemm and add case --- benchmarks/pvc/benchmarks.hpp | 33 ++++ benchmarks/pvc/input.in | 6 + include/cute/atom/copy_traits_xe.hpp | 181 +++++++++++++++------ include/cutlass/gemm/collective/xe_mma.hpp | 77 ++++----- 4 files changed, 197 insertions(+), 100 deletions(-) diff --git a/benchmarks/pvc/benchmarks.hpp b/benchmarks/pvc/benchmarks.hpp index 3745a0108..d653d6178 100644 --- a/benchmarks/pvc/benchmarks.hpp +++ b/benchmarks/pvc/benchmarks.hpp @@ -80,11 +80,41 @@ using PvcGemmBF16BF16FP32_RRR_5 = cutlass::gemm::device::GemmConfiguration< TiledMMA>>, XE_2D_U16x8x32_LD_N, XE_2D_U16x32x32_LD_V>; +using PvcGemmBF16BF16FP32_RRR_6 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + float, cutlass::layout::RowMajor, + float, Shape<_8, _128, _32>, + TiledMMA>>, + XE_2D_U16x8x32_LD_N, XE_2D_U16x16x16_LD_T>; + +using PvcGemmBF16BF16FP32_RRR_7 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, Shape<_8, _128, _32>, + TiledMMA>>, + XE_2D_U16x16x16_LD_T, XE_2D_U16x32x32_LD_V>; + +using PvcGemmBF16BF16FP32_RRR_8 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + float, cutlass::layout::RowMajor, + float, Shape<_8, _128, _32>, + TiledMMA>>, + XE_2D_U16x16x16_LD_T, XE_2D_U16x16x16_LD_T>; + CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1); CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2); CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3); CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4); CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_6); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_7); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_8); static void register_benchmarks() { CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1); @@ -92,4 +122,7 @@ static void register_benchmarks() { CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3); CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4); CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_6); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_7); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_8); } diff --git a/benchmarks/pvc/input.in b/benchmarks/pvc/input.in index 4f5d47648..fa2fab239 100644 --- a/benchmarks/pvc/input.in +++ b/benchmarks/pvc/input.in @@ -21,3 +21,9 @@ PvcGemmBF16BF16FP32_RRR_5 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n= PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096 PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128 PvcGemmBF16BF16FP32_RRR_3 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128 +PvcGemmBF16BF16FP32_RRR_6 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=4096 +PvcGemmBF16BF16FP32_RRR_6 --bm_name=bf16_bf16_fp32 --l=32 --m=256 --k=2048 --n=16384 +PvcGemmBF16BF16FP32_RRR_7 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=4096 +PvcGemmBF16BF16FP32_RRR_7 --bm_name=bf16_bf16_fp32 --l=32 --m=128 --k=1024 --n=8192 +PvcGemmBF16BF16FP32_RRR_8 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=4096 +PvcGemmBF16BF16FP32_RRR_8 --bm_name=bf16_bf16_fp32 --l=32 --m=16384 --k=4096 --n=1024 diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index daf78eb04..82fec9f04 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -40,48 +40,81 @@ namespace cute { namespace detail { - template - struct is_transpose : bool_constant {}; - - template<> - struct is_transpose : bool_constant{}; - - template<> - struct is_transpose : bool_constant{}; - - template<> - struct is_transpose : bool_constant{}; - - template<> - struct is_transpose : bool_constant{}; - - template<> - struct is_transpose : bool_constant{}; - - template<> - struct is_transpose : bool_constant{}; - - template<> - struct is_transpose : bool_constant{}; - - template<> - struct is_transpose : bool_constant{}; + struct MKL_Indicator {}; + struct NKL_Indicator {}; + + template + struct is_MKL_layout { + static constexpr bool value = false; + }; + + template + struct is_MKL_layout>> { + static constexpr bool value = true; + }; + + template + struct is_NKL_layout { + static constexpr bool value = false; + }; + + template + struct is_NKL_layout>> { + static constexpr bool value = true; + }; + + template + struct is_transpose_load{ + static constexpr bool value = (is_MKL_layout::value + && std::is_same_v, cutlass::layout::ColumnMajor>) + || (is_NKL_layout::value + && std::is_same_v, cutlass::layout::ColumnMajor>); + }; template constexpr bool has_inst_dtype = false; template constexpr bool has_inst_dtype> = true; + + template + struct size_of_inst { + static constexpr auto value = sizeof(dtype); + }; + + template + struct size_of_inst>> { + static constexpr auto value = sizeof(typename T::inst_dtype); + }; + } // namespace detail end -template struct XE_2D_LD_Unpack { +template , int64_t>> +struct XE_2D_LD_Unpack { const void *base_ptr; uint32_t width; uint32_t height; uint32_t pitch; - - XE_2D_LD_Unpack(const void *ptr, uint32_t const &w, - uint32_t const &h, uint32_t const &p) - : base_ptr(ptr), width(w), height(h), pitch(p) {} + + static constexpr bool is_mkl = detail::is_MKL_layout::value; + static constexpr bool is_nkl = detail::is_NKL_layout::value; + static constexpr bool is_transpose = detail::is_transpose_load::value; + + static_assert(is_mkl != is_nkl); + + XE_2D_LD_Unpack(const void *ptr, uint32_t const &y, + uint32_t const &x, uint32_t const &p = 0) : base_ptr(ptr) { + if (is_nkl) { + width = is_transpose ? x : y; + height = is_transpose ? y : x; + pitch = (p == 0 ? width : p); + } else { + width = is_transpose ? y : x; + height = is_transpose ? x : y; + pitch = (p == 0 ? width : p); + } + } template XE_2D_LD_Unpack(TraitsArgs const &traits) : base_ptr(traits.base_ptr), @@ -89,7 +122,7 @@ template struct XE_2D_LD_Unpack { XE_2D_LD_Unpack() {} - using Traits_LD_t = Copy_Traits; + using Traits_LD_t = Copy_Traits; template CUTE_HOST_DEVICE friend constexpr void @@ -100,19 +133,23 @@ template struct XE_2D_LD_Unpack { using dtype = typename Tensor::value_type; dtype *base_addr = (dtype *)traits.base_ptr; - - auto [m, n, l] = src.data().coord_; - - auto inst_size = sizeof(dtype); - - if constexpr (detail::has_inst_dtype) { - inst_size = sizeof(typename CopyOp::inst_dtype); + + int x, y; + auto [coord_0, coord_1, z] = src.data().coord_; + if constexpr (is_mkl ^ is_transpose) { + x = coord_1; + y = coord_0; + } else { + x = coord_0; + y = coord_1; } - CopyOp::copy(base_addr + l * traits.width * traits.height, + static constexpr auto inst_size = detail::size_of_inst::value; + + CopyOp::copy(base_addr + z * traits.width * traits.height, traits.width * sizeof(dtype), traits.height, traits.pitch * sizeof(dtype), - intel::coord_t{(int)(n * sizeof(dtype) / inst_size), (int)(m)}, + intel::coord_t{(int)(x * sizeof(dtype) / inst_size), y}, &*dst.data()); } @@ -134,20 +171,62 @@ template struct XE_2D_LD_Unpack { intel::coord_t{(int)n, (int)m}); } - template {})> - CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(GCoord const &coord, + template + CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(int m_coord, int n_coord, int l_coord, + GShape const &shape) const { + + auto R = rank(GShape{}); + static_assert(R == 3, "mismatch rank"); + + auto t_shape = cute::tuple_cat(make_shape(_1{}), take<1, R>(shape)); + + auto basis = make_seq{}; + + if constexpr (is_mkl) { + if constexpr (!is_transpose) { + auto t_stride = cute::tuple_cat(make_stride(_1{}), transform(basis, typename CopyOp::Shape_MN{}, + [&](auto i, auto s){ + return E{} * s; + })); + return make_tensor(make_inttuple_iter(make_coord(m_coord, n_coord, l_coord)), + make_layout(t_shape, t_stride)); + } else { + auto t_stride = cute::tuple_cat(make_stride(_1{}), transform((basis), typename CopyOp::Shape_MN{}, + [&](auto i, auto s){ + return E{} * s; + })); + return make_tensor(make_inttuple_iter(make_coord(m_coord, n_coord, l_coord)), + make_layout(t_shape, t_stride)); + } + } else if constexpr (is_nkl) { + if constexpr (!is_transpose) { + auto t_stride = cute::tuple_cat(make_stride(_1{}), transform(reverse(basis), typename CopyOp::Shape_MN{}, + [&](auto i, auto s){ + return E{} * s; + })); + return make_tensor(make_inttuple_iter(make_coord(m_coord, n_coord, l_coord)), + make_layout(t_shape, t_stride)); + } else { + auto t_stride = cute::tuple_cat(make_stride(_1{}), transform(reverse(basis), typename CopyOp::Shape_MN{}, + [&](auto i, auto s){ + return E{} * s; + })); + return make_tensor(make_inttuple_iter(make_coord(m_coord, n_coord, l_coord)), + make_layout(t_shape, t_stride)); + } + } + } + + template + CUTE_HOST_DEVICE constexpr auto get_pvc_tensor_B(int m_coord, int n_coord, int l, GShape const &shape, - GStride const &stride, - Basis const & basis = {}) const { + Direction const& direction) const { auto R = rank(GShape{}); - static_assert(R == 3 || R == 4, "mismatch rank"); + static_assert(R == 3, "mismatch rank"); auto t_shape = cute::tuple_cat(make_shape(_1{}), take<1, R>(shape)); - auto t_stride = cute::tuple_cat(make_stride(_1{}), transform(basis, stride, [&](auto i, auto s){ - return E{} * s; - })); - return make_tensor(make_inttuple_iter(coord), - make_layout(t_shape, t_stride)); + + } template diff --git a/include/cutlass/gemm/collective/xe_mma.hpp b/include/cutlass/gemm/collective/xe_mma.hpp index adba580df..fae674fa0 100644 --- a/include/cutlass/gemm/collective/xe_mma.hpp +++ b/include/cutlass/gemm/collective/xe_mma.hpp @@ -98,9 +98,6 @@ struct CollectiveMma< using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; - static constexpr bool a_row_major = std::is_same_v, cutlass::layout::RowMajor>; - static constexpr bool b_row_major = std::is_same_v, cutlass::layout::RowMajor>;; - static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; using MmaAtomShape = typename TiledMma::AtomShape_MNK; @@ -129,14 +126,14 @@ struct CollectiveMma< using PrefetchBTileSize = decltype(ceil_div(Shape, Int>{},PrefetchBThrShape{})); static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); - using traits_load_A = Copy_Traits; + using traits_load_A = Copy_Traits; using atom_load_A = Copy_Atom; using XE_Copy_A = decltype(make_tiled_copy(atom_load_A{} .with(static_cast(nullptr), int32_t(0), int32_t(0), int32_t(0)), Layout>>{}, make_layout(make_shape(get<0>(typename traits_load_A::Shape_MN{}), get<1>(typename traits_load_A::Shape_MN{}) / Int{})))); - using traits_load_B = Copy_Traits; + using traits_load_B = Copy_Traits; using atom_load_B = Copy_Atom; using XE_Copy_B = decltype(make_tiled_copy(atom_load_B{} .with(static_cast(nullptr), int32_t(0), int32_t(0), int32_t(0)), @@ -175,17 +172,11 @@ struct CollectiveMma< auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M,N,K,L] = problem_shape_MNKL; - XE_Copy_A copyA = make_tiled_copy((a_row_major ? Copy_Atom, ElementA>{}.with( - args.ptr_A, K, M, K) - : Copy_Atom, ElementA>{}.with( - args.ptr_A, M, K, M)), + XE_Copy_A copyA = make_tiled_copy(Copy_Atom, ElementA>{}.with(args.ptr_A, M, K), Layout>>{}, make_layout(make_shape(get<0>(typename traits_load_A::Shape_MN{}), get<1>(typename traits_load_A::Shape_MN{}) / Int{}))); - XE_Copy_B copyB = make_tiled_copy((b_row_major ? Copy_Atom, ElementB>{}.with( - args.ptr_B, N, K, N) - : Copy_Atom, ElementB>{}.with( - args.ptr_B, K, N, K)), + XE_Copy_B copyB = make_tiled_copy(Copy_Atom, ElementB>{}.with(args.ptr_B, N, K), Layout>>{}, make_layout(make_shape(get<0>(typename traits_load_B::Shape_MN{}), get<1>(typename traits_load_B::Shape_MN{}) / Int{}))); @@ -194,22 +185,9 @@ struct CollectiveMma< return Params{copyA, copyB, prefetchA, prefetchB}; } - template - static constexpr auto get_pvc_tensor_a(tile_copy_t const &tile_copy, int m, int n, int l, shape_t const &shape, stride_t const &stride) { - if constexpr (row_major) { - return tile_copy.get_pvc_tensor(make_coord(m, n, l), shape, stride, seq<0, 1, 1>{}); - } else { - return tile_copy.get_pvc_tensor(make_coord(n, m, l), shape, stride, seq<1, 0, 0>{}); - } - } - - template - static auto get_pvc_tensor_b(tile_copy_t const &tile_copy, int m, int n, int l, shape_t const &shape, stride_t const &stride) { - if constexpr (row_major) { - return tile_copy.get_pvc_tensor(make_coord(m, n, l), shape, stride, seq<0, 1, 0>{}); - } else { - return tile_copy.get_pvc_tensor(make_coord(n, m, l), shape, stride, seq<1, 0, 1>{}); - } + template + static constexpr auto append_pvc_tensor_with_layout(Tensor_t const &t0, Layout_t const & layout) { + return make_tensor(make_inttuple_iter(t0.data()), append(t0.layout(), layout)); } /// Perform a subgroup-scoped matrix multiply-accumulate @@ -298,28 +276,29 @@ struct CollectiveMma< const int n_coord = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; #endif const int l_coord = l_idx; - Tensor iter_a = get_pvc_tensor_a(mainloop.gmem_tiled_copy_a, - m_coord, 0, l_coord, - append<4>(tCrA_copy_view.shape(), k_tile_count), - append<3>(typename XE_Copy_A::Shape_MN{}, BLK_K)); - Tensor iter_b = get_pvc_tensor_b(mainloop.gmem_tiled_copy_b, - 0, n_coord, l_coord, - append<4>(tCrB_copy_view.shape(), k_tile_count), - append<3>(typename XE_Copy_B::Shape_MN{}, BLK_K)); + + Tensor block2d_copy_iter_a = mainloop.gmem_tiled_copy_a.get_pvc_tensor(m_coord, 0, l_coord, tCrA_copy_view.shape()); + auto copy_iter_a = append_pvc_tensor_with_layout(block2d_copy_iter_a, make_layout(make_shape(k_tile_count), make_stride(E<1>{} *BLK_K))); + + Tensor block2d_copy_iter_b = mainloop.gmem_tiled_copy_b.get_pvc_tensor(n_coord, 0, l_coord, tCrB_copy_view.shape()); + auto copy_iter_b = append_pvc_tensor_with_layout(block2d_copy_iter_b, make_layout(make_shape(k_tile_count), make_stride(E<1>{} *BLK_K))); const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K)); int prefetch_k = 0; - Tensor prefetch_iter_a = mainloop.gmem_prefetch_a.get_pvc_tensor( - make_coord(m_coord + (get_sub_group_id() % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}), - (k_start_idx + (get_sub_group_id() % ATOM_N) % get<1>(PrefetchAThrShape{})) * PrefetchStrideA, l_coord), - append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count), - append<3>(make_shape(SG_M, SG_K), BLK_K), seq<0, 1, 1>{}); - Tensor prefetch_iter_b = mainloop.gmem_prefetch_b.get_pvc_tensor( - make_coord(((get_sub_group_id() / ATOM_N) / get<1>(PrefetchBThrShape{}) + k_start_idx) * PrefetchStrideB, - n_coord + (get_sub_group_id() / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}), l_coord), - append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count), - append<3>(make_shape(SG_K, SG_N), BLK_K), seq<0,1,0>{}); + Tensor blocked_prefetch_iter_a = mainloop.gmem_prefetch_a.get_pvc_tensor( + m_coord + (get_sub_group_id() % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}), + (k_start_idx + (get_sub_group_id() % ATOM_N) % get<1>(PrefetchAThrShape{})) * PrefetchStrideA, + l_coord, + make_shape(_1{}, _1{}, _1{})); + auto prefetch_iter_a = append_pvc_tensor_with_layout(blocked_prefetch_iter_a, make_layout(make_shape(k_tile_count), make_stride(E<1>{} *BLK_K))); + + Tensor blocked_prefetch_iter_b = mainloop.gmem_prefetch_b.get_pvc_tensor( + (get_sub_group_id() / ATOM_N / get<1>(PrefetchBThrShape{}) + k_start_idx) * PrefetchStrideB, + n_coord + (get_sub_group_id() / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}), + l_coord, + make_shape(_1{}, _1{}, _1{})); + auto prefetch_iter_b = append_pvc_tensor_with_layout(blocked_prefetch_iter_b, make_layout(make_shape(k_tile_count), make_stride(E<0>{} *BLK_K))); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) { @@ -334,8 +313,8 @@ struct CollectiveMma< CUTLASS_PRAGMA_UNROLL for (int k_tile = 0, k = k_start_idx; k_tile < k_tile_count; ++k_tile, ++k, ++prefetch_k) { // Copy gmem to rmem for the first k_tile - copy(mainloop.gmem_tiled_copy_a, iter_a(_,_,_,k), tCrA_copy_view); - copy(mainloop.gmem_tiled_copy_b, iter_b(_,_,_,k), tCrB_copy_view); + copy(mainloop.gmem_tiled_copy_a, copy_iter_a(_,_,_,k), tCrA_copy_view); + copy(mainloop.gmem_tiled_copy_b, copy_iter_b(_,_,_,k), tCrB_copy_view); if(prefetch_k < k_tile_count) { if constexpr(cute::detail::has_prefetch) { From 450397441139a339fb556133f3c2999edfbc829e Mon Sep 17 00:00:00 2001 From: taozha2 Date: Thu, 19 Dec 2024 16:35:29 +0800 Subject: [PATCH 2/2] revert pvc_gemm.cpp --- examples/sycl/pvc/pvc_gemm.cpp | 74 +++++++++++++--------------------- 1 file changed, 28 insertions(+), 46 deletions(-) diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index f03cf2a1c..ee20a51ee 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -254,10 +254,7 @@ struct ExampleRunner { float cute_time = timer.seconds() / options.iterations; double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; - printf("Cutlass GEMM (A %s, B %s) Performance: [%4.3f]TFlop/s (%6.4f)ms\n\n", - std::is_same_v ? "RowMajor" : "ColumnMajor", - std::is_same_v ? "RowMajor" : "ColumnMajor", - tflops / cute_time, cute_time*1000); + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); } return; @@ -265,8 +262,26 @@ struct ExampleRunner { }; -template -static constexpr auto gemm_run(Options const& options) { +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + // // Run examples // @@ -285,17 +300,17 @@ static constexpr auto gemm_run(Options const& options) { // elements in input matrices. using ElementAccumulator = float; // <- data type of accumulator using ElementComputeEpilogue = float; // <- data type of epilogue operations - using ElementInputA = a_type; // <- data type of elements in input matrix A - using ElementInputB = b_type; // <- data type of elements in input matrix B - using ElementOutput = c_type; // <- data type of elements in output matrix D + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D - using LayoutA = std::conditional_t; - using LayoutB = std::conditional_t; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - using GmemTiledCopyA = std::conditional_t; - using GmemTiledCopyB = std::conditional_t; + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; @@ -350,39 +365,6 @@ static constexpr auto gemm_run(Options const& options) { ExampleRunner runner; runner.run(options, hw_info); -} - -int main(int argc, const char** argv) -{ - // - // Parse options - // - - Options options; - - options.parse(argc, argv); - - if (options.help) { - options.print_usage(std::cout) << std::endl; - return 0; - } - - if (options.error) { - std::cerr << "Aborting execution." << std::endl; - return -1; - } - - // row major A, row major B - gemm_run(options); - - // row major A, column major B - gemm_run(options); - - // column major A, row major B - gemm_run(options); - - // column major A, column major B - gemm_run(options); return 0; }