hot fix: batch gemm & revert to previous config (#149)
Co-authored-by: Alejandro Acosta <[email protected]>
jiyang1011 and aacostadiaz authored Oct 30, 2024
1 parent 641f717 commit 68e1449
Showing 4 changed files with 24 additions and 16 deletions.
10 changes: 5 additions & 5 deletions benchmarks/pvc/gemm_configuration.hpp
@@ -76,12 +76,12 @@ struct Gemm_OperandB;
 
 template<>
 struct Gemm_OperandA<bfloat16_t, layout::RowMajor> {
-  using GmemTiledCopy = XE_2D_U16x8x16_LD_N;
+  using GmemTiledCopy = XE_2D_U16x32x32_LD_N;
 };
 
 template<>
 struct Gemm_OperandB<bfloat16_t, layout::RowMajor> {
-  using GmemTiledCopy = XE_2D_U16x16x16_LD_V;
+  using GmemTiledCopy = XE_2D_U16x32x32_LD_V;
 };
 
 } // namespace details
@@ -93,12 +93,12 @@ struct GemmConfiguration<
   bfloat16_t, LayoutB,
   float, LayoutC,
   float> {
-  using TileShape = Shape<_256, _256, _16>;
+  using TileShape = Shape<_256, _256, _32>;
   using DispatchPolicy = MainloopIntelPVC<3>;
   using TiledMma = TiledMMA<
     MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
-    Layout<Shape<_1,_8,_1>>,
-    Tile<_64,_128,_16>>;
+    Layout<Shape<_8,_4,_1>>,
+    Tile<_64,_64,_32>>;
 
   // A
   using OperandA = detail::Gemm_OperandA<bfloat16_t, LayoutA>;
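Note on the revert: the three K extents now move in lockstep. Assuming the XE_2D_U16x32x32_* names describe 32x32 blocks of 16-bit elements, the block-load atoms, the workgroup tile's K mode, and the tiled MMA's K tile all read 32. A minimal consistency sketch follows (constants are read off the diff above; the checks themselves are illustrative, not repo code):

// Constants read off the diff; the static_asserts are illustrative only.
constexpr int BLK_M = 256, BLK_N = 256, BLK_K = 32;  // TileShape
constexpr int MMA_M = 64,  MMA_N = 64,  MMA_K = 32;  // Tile<_64,_64,_32>
constexpr int COPY_K = 32;                           // K extent of a 32x32 block load

static_assert(BLK_M % MMA_M == 0 && BLK_N % MMA_N == 0 && BLK_K % MMA_K == 0,
              "the tiled MMA must evenly cover the workgroup tile");
static_assert(BLK_K % COPY_K == 0,
              "block loads must evenly cover the tile's K extent");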
10 changes: 5 additions & 5 deletions examples/sycl/pvc/pvc_gemm.cpp
@@ -306,15 +306,15 @@ int main(int argc, const char** argv)
   using LayoutC = cutlass::layout::RowMajor;
   using LayoutD = cutlass::layout::RowMajor;
 
-  using GmemTiledCopyA = XE_2D_U16x8x16_LD_N;
-  using GmemTiledCopyB = XE_2D_U16x16x16_LD_V;
+  using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
+  using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
 
   // Workgroup-level tile
-  using TileShape = Shape<_256, _128, _16>;
+  using TileShape = Shape<_256, _256, _32>;
 
   using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
-                            Layout<Shape<_8,_2,_1>>,
-                            Tile<_64,_32,_16>>; // Subgroup-level tile
+                            Layout<Shape<_8,_4,_1>>,
+                            Tile<_64,_64,_32>>; // Subgroup-level tile
 
   constexpr int PipelineStages = 3;
   using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC<PipelineStages>;
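The example is brought back to the same configuration as the benchmark above. As a back-of-envelope check (assuming each XE_2D_U16x32x32_* atom moves a 32x32 element block), one K step of a workgroup touches a 256x32 slice of A and a 32x256 slice of B, eight block loads each:

// Illustrative load counting per workgroup per K step (not repo code).
constexpr int BLK_M = 256, BLK_N = 256, BLK_K = 32;  // TileShape
constexpr int ATOM_ROWS = 32, ATOM_COLS = 32;        // per XE_2D_U16x32x32_* load
constexpr int loads_A = (BLK_M / ATOM_ROWS) * (BLK_K / ATOM_COLS);  // 8 for A
constexpr int loads_B = (BLK_K / ATOM_ROWS) * (BLK_N / ATOM_COLS);  // 8 for B
static_assert(loads_A == 8 && loads_B == 8, "arithmetic check");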
2 changes: 1 addition & 1 deletion include/cutlass/epilogue/collective/xe_epilogue.hpp
@@ -296,7 +296,7 @@ class CollectiveEpilogue<
     Tensor trC = make_tensor<typename TiledMma::ValTypeC>(Shape<Int<FragmentSize>>{});
     Tensor trD = make_tensor<typename TiledMma::ValTypeD>(Shape<Int<FragmentSize>>{});
     Tensor tOuti = params.xe_store_d.get_pvc_tensor(
-        make_coord(m_offset, n_offset, l_offset),
+        make_coord(m_offset, n_offset, 0),
         make_shape(_, Int<FragsM>{}, Int<FragsN>{}, L),
         make_stride(Int<get<0>(MmaAtomShape{})>{}, Int<get<1>(MmaAtomShape{})>{}, _1{}));
 
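This one-character change is the batch-GEMM half of the hot fix: the store tensor already walks the batch mode through its shape (the trailing L), so seeding the base coordinate with l_offset would apply the batch offset a second time. A generic illustration of the double-offset pattern (hypothetical addressing helper, not the repo's API):

#include <cassert>

// Hypothetical addressing helper, just to show the double-offset bug pattern.
int element_offset(int m, int l, int batch_stride, int base_l) {
  return base_l * batch_stride   // batch offset baked into the base coordinate
       + l * batch_stride        // batch offset applied again by the iterator
       + m;
}

int main() {
  // With the offset in both places, batch 1 lands at 2 * batch_stride:
  assert(element_offset(0, 1, 1024, /*base_l=*/1) == 2048);
  // Passing 0 for the base coordinate's batch mode fixes it:
  assert(element_offset(0, 1, 1024, /*base_l=*/0) == 1024);
  return 0;
}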
18 changes: 13 additions & 5 deletions include/cutlass/gemm/collective/xe_mma.hpp
@@ -200,8 +200,14 @@ struct CollectiveMma<
     // Instantiate the MMA object
     TiledMma tiled_mma;
     auto thread_mma = tiled_mma.get_slice(thread_idx);
-    Tensor tCrA = thread_mma.partition_fragment_A(gA(_, _, 0));
-    Tensor tCrB = thread_mma.partition_fragment_B(gB(_, _, 0));
+    Tensor tCrA_partition = thread_mma.partition_fragment_A(gA(_, _, 0));
+    Tensor tCrA = make_tensor(static_cast<decltype(tCrA_partition) &&>(tCrA_partition).data(),
+                              tCrA_partition.shape());
+    Tensor tCrB_partition = thread_mma.partition_fragment_B(gB(_, _, 0));
+    Tensor tCrB = make_tensor(static_cast<decltype(tCrB_partition) &&>(tCrB_partition).data(),
+                              make_shape(size<0>(tCrB_partition.shape()),
+                                         size<2>(tCrB_partition.shape()),
+                                         size<1>(tCrB_partition.shape())));
     // Partition the copying of A and B tiles across the threads
     auto gmem_thr_copy_A = mainloop.gmem_tiled_copy_a.get_slice(thread_idx);
     auto gmem_thr_copy_B = mainloop.gmem_tiled_copy_b.get_slice(thread_idx);
@@ -234,10 +240,10 @@ struct CollectiveMma<
     const int n_coord = BlockIdxX() * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
     const int l_coord = BlockIdxZ();
     Tensor iter_a = mainloop.gmem_tiled_copy_a.get_pvc_tensor(
-      make_coord(m_coord, 0, l_coord), make_shape(_, size<1>(tCrA_copy_view.shape()), size<2>(tCrA_copy_view.shape()), k_tile_count),
+      make_coord(m_coord, 0, l_coord), append<4>(tCrA_copy_view.shape(), k_tile_count),
       append<3>(typename XE_Copy_A::Shape_MN{}, BLK_K), seq<0,1,1>{});
     Tensor iter_b = mainloop.gmem_tiled_copy_b.get_pvc_tensor(
-      make_coord(0, n_coord, l_coord), make_shape(_, size<2>(tCrB_copy_view.shape()), size<1>(tCrB_copy_view.shape()), k_tile_count),
+      make_coord(0, n_coord, l_coord), append<4>(tCrB_copy_view.shape(), k_tile_count),
       append<3>(typename XE_Copy_B::Shape_MN{}, BLK_K), seq<0,1,0>{});
 #pragma unroll
     for (int i = 0; i < DispatchPolicy::Stages; i++) {
@@ -261,7 +267,9 @@ struct CollectiveMma<
           prefetch(mainloop.gmem_tiled_copy_b, iter_b(_,_,_,k_tile + DispatchPolicy::Stages));
         }
       }
-      cute::gemm(tiled_mma, accum, tCrA, tCrB, src_accum);
+      for (int i = 0; i < SG_K / SubgroupSize; i++) {
+        cute::gemm(tiled_mma, accum, tCrA(_, _, i), tCrB(_, i, _), src_accum);
+      }
     }
   }
 };
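Two details in this last file are easy to miss. First, static_cast<decltype(t) &&>(t) is std::move(t) written out longhand. Second, tCrB is re-viewed rather than copied: make_tensor(ptr, shape) builds a compact view over the same register fragment with the last two extents swapped, so mode 1 indexes the K slices and the mainloop can issue one MMA per slice as tCrB(_, i, _). A plain-C++ analogy of such a re-view (illustrative only; CuTe expresses this with layouts):

#include <cstddef>

// Illustrative 3-mode view with compact, mode-0-fastest strides, standing in
// for what make_tensor(data, shape) builds over the existing fragment storage.
template <typename T>
struct View3 {
  T* data;                 // same storage as the original fragment; no copy
  std::size_t d0, d1, d2;  // extents, mode 0 fastest
  T& operator()(std::size_t i0, std::size_t i1, std::size_t i2) {
    return data[i0 + d0 * (i1 + d1 * i2)];
  }
};

// Re-view a (V, N, K)-shaped fragment as (V, K, N): only the shape changes,
// the registers are reused, and mode 1 of the result indexes the K slices.
template <typename T>
View3<T> swap_last_two(T* data, std::size_t v, std::size_t n, std::size_t k) {
  return {data, v, k, n};  // new shape (v, k, n); no data movement
}

With that view, the loop over i in [0, SG_K / SubgroupSize) hands one K slice per cute::gemm call instead of passing the whole fragments at once.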
