@@ -205,14 +205,22 @@ choose_tiled_mma(ATensor const& A, BTensor const& B, CTensor const&)
205205 auto op = choose_mma_op<TA,TB,TC>();
206206
207207 constexpr bool byte = (cute::max (sizeof_bits_v<TA>, sizeof_bits_v<TB>) <= 8 );
208- constexpr bool use_1x_dpas_per_k = is_constant_v<1 , decltype (stride<0 >(A))> // Use one DPAS in k dimension for A^T case
209- || (byte && is_constant_v<1 , decltype (stride<0 >(B))>); // pending compiler improvements (also int8 B^N)
208+ constexpr bool a_t = is_constant_v<1 , decltype (stride<0 >(A))>;
209+ constexpr bool b_n = is_constant_v<1 , decltype (stride<0 >(B))>;
210+
211+ constexpr bool use_1x_dpas_per_k = a_t // Use one DPAS in k dimension for A^T case
212+ || (byte && b_n); // pending compiler improvements (also int8 B^N).
213+ constexpr bool use_4x8_sg = ((sizeof_bits_v<TB> <= sizeof_bits_v<TA>) // Use smaller B loads for expensive reorders.
214+ && !(is_same_v<TB, cute::float_e5m2_t > && is_same_v<TA, cute::half_t >))
215+ || (b_n && sizeof_bits_v<TB> < 8 );
210216
211217 using _K = conditional_t <use_1x_dpas_per_k,
212218 C<op.K >, C<op.K *2 >>;
213219
214220 using WGTile = Shape<_256, _256, _K>; // 256x256 WG tile size
215- using SGLayout = Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>; // 8x4 SG tiling, n-major
221+ using SGLayout8x4 = Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>; // 8x4 SG tiling, n-major
222+ using SGLayout4x8 = Layout<Shape<_4, _8, _1>, Stride<_8, _1, _0>>; // 4x8 SG tiling, n-major
223+ using SGLayout = conditional_t <use_4x8_sg, SGLayout4x8, SGLayout8x4>;
216224
217225 using MMA = typename TiledMMAHelper<MMA_Atom<decltype (op)>, Layout<WGTile>, SGLayout>::TiledMMA;
218226
0 commit comments