rebase
taozha2 committed Dec 16, 2024
1 parent 1baf8fe commit 8273148
Showing 2 changed files with 56 additions and 48 deletions.
34 changes: 14 additions & 20 deletions examples/sycl/pvc/pvc_gemm.cpp
@@ -45,7 +45,7 @@
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "common.h"
#include "common.hpp"

using namespace cute;

@@ -227,7 +227,10 @@ struct ExampleRunner {
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

gemm_op.can_implement(arguments);
if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess){
std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::exit(1);
}

gemm_op.initialize(arguments, workspace.get());

@@ -286,16 +289,13 @@ static constexpr auto gemm_run(Options const& options) {
using ElementInputB = b_type; // <- data type of elements in input matrix B
using ElementOutput = c_type; // <- data type of elements in output matrix D

// using LayoutA = cutlass::layout::ColumnMajor;
// using LayoutB = cutlass::layout::RowMajor;
using LayoutA = std::conditional_t<a_row_major, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>;
using LayoutB = std::conditional_t<b_row_major, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;

// using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; // row major load
// using GmemTiledCopyA = XE_2D_U16x16x16_LD_T; // column major load

// using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; // row major load
// using GmemTiledCopyB = XE_2D_U16x16x16_LD_T; // column major load
using copy_traits_a = std::conditional_t<a_row_major, XE_2D_U16x32x32_LD_N, XE_2D_U16x16x16_LD_T>;
using copy_traits_b = std::conditional_t<b_row_major, XE_2D_U16x32x32_LD_V, XE_2D_U16x16x16_LD_T>;

// Workgroup-level tile
using TileShape = Shape<_256, _256, _32>;
@@ -326,23 +326,17 @@ static constexpr auto gemm_run(Options const& options) {
XE_2D_U32x8x16_ST_N,
void, void>;

    // Mainloop
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
GEMMDispatchPolicy,
TileShape,
ElementInputA,
cutlass::gemm::TagToStrideA_t<std::conditional_t<a_row_major,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor>>,
cutlass::gemm::TagToStrideA_t<LayoutA>,
ElementInputB,
cutlass::gemm::TagToStrideB_t<std::conditional_t<b_row_major,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor>>,
cutlass::gemm::TagToStrideB_t<LayoutB>,
TiledMma,
std::conditional_t<a_row_major, XE_2D_U16x32x32_LD_N, XE_2D_U16x16x16_LD_T>, // A
void, void, cute::identity, // A
std::conditional_t<b_row_major, XE_2D_U16x32x32_LD_V, XE_2D_U16x16x16_LD_T>, // B
void, void, cute::identity // B
copy_traits_a, void, void, cute::identity, // A
copy_traits_b, void, void, cute::identity // B
>;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
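The change to pvc_gemm.cpp replaces the hard-coded A/B layouts and commented-out copy atoms with `std::conditional_t` selections driven by the `a_row_major`/`b_row_major` flags, and turns the previously ignored `can_implement()` result into a hard error. A minimal, self-contained sketch of that pattern in plain C++ is shown below; the `RowMajor`/`ColumnMajor` tags and the `can_implement` stub are placeholders for illustration, not the CUTLASS definitions:

```cpp
#include <cstdlib>
#include <iostream>
#include <type_traits>

// Placeholder layout tags standing in for cutlass::layout::{RowMajor,ColumnMajor}.
struct RowMajor    { static constexpr const char* name = "RowMajor"; };
struct ColumnMajor { static constexpr const char* name = "ColumnMajor"; };

enum class Status { kSuccess, kErrorInvalidProblem };

// Stand-in for Gemm::can_implement(arguments): reject non-positive extents.
Status can_implement(int m, int n, int k, int l) {
  return (m > 0 && n > 0 && k > 0 && l > 0) ? Status::kSuccess
                                            : Status::kErrorInvalidProblem;
}

template <bool a_row_major, bool b_row_major>
void run_gemm(int m, int n, int k, int l) {
  // Same pattern as the example: pick the layout (and, in the real code,
  // the matching XE_2D_* copy atom) at compile time from a bool parameter.
  using LayoutA = std::conditional_t<a_row_major, RowMajor, ColumnMajor>;
  using LayoutB = std::conditional_t<b_row_major, RowMajor, ColumnMajor>;

  // Mirror the new early exit instead of silently discarding the status.
  if (can_implement(m, n, k, l) != Status::kSuccess) {
    std::cout << "Invalid Problem Size: " << m << 'x' << n << 'x' << k << 'x' << l << std::endl;
    std::exit(1);
  }
  std::cout << "A is " << LayoutA::name << ", B is " << LayoutB::name << '\n';
}

int main() {
  run_gemm<true, false>(4096, 4096, 4096, 1);  // row-major A, column-major B
}
```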
70 changes: 42 additions & 28 deletions include/cutlass/gemm/collective/xe_mma.hpp
@@ -214,12 +214,15 @@ struct CollectiveMma<

/// Perform a subgroup-scoped matrix multiply-accumulate
template <
int PrefetchStrideA,
int PrefetchStrideB,
class FrgTensorD,
class TensorA,
class TensorB,
class FrgTensorC,
class KTileIterator,
class ResidueMNK
class ResidueMNK,
class BlkCoord
>
CUTLASS_DEVICE void
operator() (
@@ -229,6 +232,8 @@ struct CollectiveMma<
FrgTensorC const &src_accum,
KTileIterator k_tile_iter, int k_tile_count,
ResidueMNK residue_mnk,
BlkCoord const &blk_coord,
int const &K,
int thread_idx,
char *smem_buf,
Params const& mainloop)
@@ -284,11 +289,15 @@ struct CollectiveMma<
//
// Mainloop
//
int sub_group_id = get_sub_group_id();
const int m_coord = BlockIdxY() * BLK_M + (sub_group_id / ATOM_N) * SG_M;
const int n_coord = BlockIdxX() * BLK_N + (sub_group_id % ATOM_N) * SG_N;
const int l_coord = BlockIdxZ();

auto [m_idx, n_idx, k_idx, l_idx] = blk_coord;
#ifdef CUTLASS_SYCL_SWITCH_WG
const int m_coord = n_idx * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M;
const int n_coord = m_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
#else
const int m_coord = m_idx * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M;
const int n_coord = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
#endif
const int l_coord = l_idx;
Tensor iter_a = get_pvc_tensor_a<a_row_major>(mainloop.gmem_tiled_copy_a,
m_coord, 0, l_coord,
append<4>(tCrA_copy_view.shape(), k_tile_count),
@@ -298,44 +307,49 @@
append<4>(tCrB_copy_view.shape(), k_tile_count),
append<3>(typename XE_Copy_B::Shape_MN{}, BLK_K));

const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K));
int prefetch_k = 0;

Tensor prefetch_iter_a = mainloop.gmem_prefetch_a.get_pvc_tensor(
make_coord(m_coord + (sub_group_id % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}),
(sub_group_id % ATOM_N) % get<1>(PrefetchAThrShape{}) * get<1>(PrefetchATileSize{}), l_coord),
make_coord(m_coord + (get_sub_group_id() % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}),
(k_start_idx + (get_sub_group_id() % ATOM_N) % get<1>(PrefetchAThrShape{})) * PrefetchStrideA, l_coord),
append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count),
append<3>(make_shape(SG_M, SG_K), BLK_K), seq<0, 1, 1>{});
Tensor prefetch_iter_b = mainloop.gmem_prefetch_b.get_pvc_tensor(
make_coord((sub_group_id / ATOM_N) / get<1>(PrefetchBThrShape{}) * get<0>(PrefetchBTileSize{}),
n_coord + (sub_group_id / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}), l_coord),
make_coord(((get_sub_group_id() / ATOM_N) / get<1>(PrefetchBThrShape{}) + k_start_idx) * PrefetchStrideB,
n_coord + (get_sub_group_id() / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}), l_coord),
append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count),
append<3>(make_shape(SG_K, SG_N), BLK_K), seq<0,1,0>{});

#pragma unroll
for (int i = 0; i < DispatchPolicy::Stages; i++) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) {
if constexpr(cute::detail::has_prefetch<GmemTiledCopyA>) {
prefetch(mainloop.gmem_tiled_copy_a, prefetch_iter_a(_,_,_,i));
prefetch(mainloop.gmem_tiled_copy_a, prefetch_iter_a(_,_,_,prefetch_k));
}
if constexpr(cute::detail::has_prefetch<GmemTiledCopyB>) {
prefetch(mainloop.gmem_tiled_copy_b, prefetch_iter_b(_,_,_,i));
prefetch(mainloop.gmem_tiled_copy_b, prefetch_iter_b(_,_,_,prefetch_k));
}
}
#pragma unroll
for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) {
// Copy gmem to rmem for the first k_tile
copy(mainloop.gmem_tiled_copy_a, iter_a(_,_,_,k_tile), tCrA_copy_view);
copy(mainloop.gmem_tiled_copy_b, iter_b(_,_,_,k_tile), tCrB_copy_view);

if(k_tile + DispatchPolicy::Stages < k_tile_count) {
if constexpr(cute::detail::has_prefetch<GmemTiledCopyA>) {
prefetch(mainloop.gmem_tiled_copy_a, prefetch_iter_a(_,_,_,k_tile + DispatchPolicy::Stages));
CUTLASS_PRAGMA_UNROLL
for (int k_tile = 0, k = k_start_idx; k_tile < k_tile_count; ++k_tile, ++k, ++prefetch_k) {
// Copy gmem to rmem for the first k_tile
copy(mainloop.gmem_tiled_copy_a, iter_a(_,_,_,k), tCrA_copy_view);
copy(mainloop.gmem_tiled_copy_b, iter_b(_,_,_,k), tCrB_copy_view);

if(prefetch_k < k_tile_count) {
if constexpr(cute::detail::has_prefetch<GmemTiledCopyA>) {
prefetch(mainloop.gmem_tiled_copy_a, prefetch_iter_a(_,_,_,prefetch_k));
}
if constexpr(cute::detail::has_prefetch<GmemTiledCopyB>) {
prefetch(mainloop.gmem_tiled_copy_b, prefetch_iter_b(_,_,_,prefetch_k));
}
}
if constexpr(cute::detail::has_prefetch<GmemTiledCopyB>) {
prefetch(mainloop.gmem_tiled_copy_b, prefetch_iter_b(_,_,_,k_tile + DispatchPolicy::Stages));

for (int i = 0; i < SG_K / SubgroupSize; i++) {
cute::gemm(tiled_mma, accum, tCrA(_, _, i), tCrB(_, i, _), src_accum);
}
}
for (int i = 0; i < SG_K / SubgroupSize; i++) {
cute::gemm(tiled_mma, accum, tCrA(_, _, i), tCrB(_, i, _), src_accum);
}
}
}
};

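In xe_mma.hpp, the mainloop now receives the work-group coordinate (`blk_coord`) and the K extent directly, derives `m_coord`/`n_coord` from it (optionally swapped under `CUTLASS_SYCL_SWITCH_WG`), and advances prefetching with a separate `prefetch_k` counter that starts `DispatchPolicy::Stages` tiles ahead of the compute loop. A simplified, host-side sketch of that prefetch pipelining is given below; the `prefetch_tile`/`load_tile`/`compute_tile` calls are stubs for illustration, not the CUTLASS device code:

```cpp
#include <iostream>

constexpr int Stages = 3;  // stand-in for DispatchPolicy::Stages

void prefetch_tile(int k) { std::cout << "prefetch k=" << k << '\n'; }
void load_tile(int k)     { std::cout << "load     k=" << k << '\n'; }
void compute_tile(int k)  { std::cout << "compute  k=" << k << '\n'; }

// k_start_idx models crd2idx(*k_tile_iter, make_shape(K)): the first K tile
// this work-group owns; prefetch_k counts tiles issued ahead of the compute loop.
void mainloop(int k_start_idx, int k_tile_count) {
  int prefetch_k = 0;

  // Warm-up: issue `Stages` prefetches before any compute.
  for (int i = 0; i < Stages; ++i, ++prefetch_k) {
    prefetch_tile(k_start_idx + prefetch_k);
  }

  // Steady state: load and compute tile k while prefetching `Stages` tiles ahead,
  // guarded so the pipeline does not run past the last tile.
  for (int k_tile = 0, k = k_start_idx; k_tile < k_tile_count; ++k_tile, ++k, ++prefetch_k) {
    load_tile(k);
    if (prefetch_k < k_tile_count) {
      prefetch_tile(k_start_idx + prefetch_k);
    }
    compute_tile(k);
  }
}

int main() { mainloop(/*k_start_idx=*/0, /*k_tile_count=*/8); }
```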
