Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions examples/cute/tutorial/xe_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,22 @@ choose_tiled_mma(ATensor const& A, BTensor const& B, CTensor const&)
auto op = choose_mma_op<TA,TB,TC>();

constexpr bool byte = (cute::max(sizeof_bits_v<TA>, sizeof_bits_v<TB>) <= 8);
constexpr bool use_1x_dpas_per_k = is_constant_v<1, decltype(stride<0>(A))> // Use one DPAS in k dimension for A^T case
|| (byte && is_constant_v<1, decltype(stride<0>(B))>); // pending compiler improvements (also int8 B^N)
constexpr bool a_t = is_constant_v<1, decltype(stride<0>(A))>;
constexpr bool b_n = is_constant_v<1, decltype(stride<0>(B))>;

constexpr bool use_1x_dpas_per_k = a_t // Use one DPAS in k dimension for A^T case
|| (byte && b_n); // pending compiler improvements (also int8 B^N).
constexpr bool use_4x8_sg = ((sizeof_bits_v<TB> < sizeof_bits_v<TA>) // Use smaller B loads for expensive reorders.
&& !(is_same_v<TB, cute::float_e5m2_t>))
|| (b_n && sizeof_bits_v<TB> < 8);

using _K = conditional_t<use_1x_dpas_per_k,
C<op.K>, C<op.K*2>>;

using WGTile = Shape<_256, _256, _K>; // 256x256 WG tile size
using SGLayout = Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>; // 8x4 SG tiling, n-major
using SGLayout8x4 = Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>; // 8x4 SG tiling, n-major
using SGLayout4x8 = Layout<Shape<_4, _8, _1>, Stride<_8, _1, _0>>; // 4x8 SG tiling, n-major
using SGLayout = conditional_t<use_4x8_sg, SGLayout4x8, SGLayout8x4>;

using MMA = typename TiledMMAHelper<MMA_Atom<decltype(op)>, Layout<WGTile>, SGLayout>::TiledMMA;

Expand Down