@@ -118,12 +118,12 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
118118 make_identity_tensor (select<0 ,1 >(TiledMMA{}.tile_mnk()))));
119119
120120 using FragS = FragC<TiledMMAQK>;
121- using FragSRow = decltype (reduce<1 >(FragS{}, sycl::plus{}));
121+ using FragSRow = decltype (reduce<1 >(FragS{}, sycl::plus< void > {}));
122122 using ElementS = typename TiledMMAQK::ValTypeD;
123123
124124 using SingleFragA = FragC<TiledMMAPV>; // (atom val,q',v')
125125 using FragA = expand_sg_fragment_t <SingleFragA, 1 , VTiles>; // (atom val,q',v',VV)
126- using FragARow = decltype (reduce<1 >(FragA{}, sycl::plus{}));
126+ using FragARow = decltype (reduce<1 >(FragA{}, sycl::plus< void > {}));
127127 using ElementA = typename TiledMMAPV::ValTypeD;
128128
129129 static constexpr bool CausalMask = CausalMask_;
@@ -293,9 +293,11 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
293293 if (check_remainder_k && K == blk_k1 - 1 ) {
294294 FragSRow k_rem_mask;
295295 int k = get<0 >(tKgK (0 ,0 ,0 ,K,0 )) + get_sub_group ().get_local_id ()[0 ];
296+ CUTLASS_PRAGMA_UNROLL
296297 for (int i = 0 ; i < k_rem_mask.size (); i++, k += intel::sg_size) {
297298 k_rem_mask (i) = (k < shape<0 >(K_2D)) ? ElementS (sycl::nan (0u )) : ElementS (-INFINITY);
298299 }
300+ CUTLASS_PRAGMA_UNROLL
299301 for (int i = 0 ; i < tSrS.size (); i++) {
300302 tSrS (i) = sycl::fmin (tSrS (i), broadcast<1 >(k_rem_mask, tSrS, i));
301303 }
@@ -309,7 +311,7 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
309311#if 0
310312 reorder(tSrS, tArP);
311313#else
312- for (int i = 0 ; i < tArP.size (); i++)
314+ for (int i = 0 ; i < tArP.size (); i++) // SYCL compiler currently is not correctly handling the above reorder.
313315 tArP (i) = static_cast <typename TiledMMAPV::ValTypeA>(tSrS (i));
314316#endif
315317
@@ -370,7 +372,7 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
370372 }
371373
372374 /* Update sums */
373- auto tS_bsum = reduce<1 >(tS, sycl::plus{});
375+ auto tS_bsum = reduce<1 >(tS, sycl::plus< void > {});
374376 for (int i = 0 ; i < tS_sum.size (); i++)
375377 tS_sum (i) += tS_bsum (i);
376378 }
0 commit comments