@@ -165,21 +165,41 @@ reorder_impl(ReorderAtom const& atom,
165165 constexpr int values = size (SLayout{}) / size<0 >(SLayout{});
166166 constexpr int vchunk = sizeof_bits_v<RegistersSrc> / sizeof_bits_v<SType>;
167167
168- // Calculate mapping from src val -> dst val on a chunk-by-chunk basis. Unlike a plain copy, there is no intrinsic
169- // correspondence of src/dst values for subgroup reorders.
170- auto rlayout = coalesce (composition (right_inverse (dlayout), slayout)); // src index -> dst index
171- auto vrlayout = composition (composition (Layout<Shape<_SG, Int<values>>, Stride<_0, _1>>{},
172- rlayout),
173- Layout<Shape<_1, Int<values>>, Stride<_0, _SG>>{}); // src val -> dst val
174-
175- CUTE_UNROLL
176- for (int sv = 0 ; sv < values; sv += vchunk) {
177- auto pS = recast_ptr<RegTypeSrc>(src.data () + sv);
178- auto pD = recast_ptr<RegTypeDst>(dst.data () + vrlayout (sv));
179-
180- detail::explode (detail::CallReorder<ReorderAtom>{},
181- pS, make_int_sequence<RegNumSrc>{},
182- pD, make_int_sequence<RegNumDst>{});
168+ static constexpr bool has_broadcast = (size (DLayoutWI{}) > size (SLayoutWI{}));
169+
170+ if (!has_broadcast) {
171+ // Calculate mapping from src val -> dst val on a chunk-by-chunk basis. Unlike a plain copy, there is no intrinsic
172+ // correspondence of src/dst values for subgroup reorders.
173+ auto rlayout = coalesce (composition (right_inverse (dlayout), slayout)); // src index -> dst index
174+ auto vrlayout = composition (composition (Layout<Shape<_SG, Int<values>>, Stride<_0, _1>>{},
175+ rlayout),
176+ Layout<Shape<_1, Int<values>>, Stride<_0, _SG>>{}); // src val -> dst val
177+
178+ CUTE_UNROLL
179+ for (int sv = 0 ; sv < values; sv += vchunk) {
180+ auto pS = recast_ptr<RegTypeSrc>(src.data () + sv);
181+ auto pD = recast_ptr<RegTypeDst>(dst.data () + vrlayout (sv));
182+
183+ detail::explode (detail::CallReorder<ReorderAtom>{},
184+ pS, make_int_sequence<RegNumSrc>{},
185+ pD, make_int_sequence<RegNumDst>{});
186+ }
187+ } else {
188+ // If there is broadcast happening, then we need to loop over dst values instead.
189+ auto rlayout = coalesce (composition (right_inverse (slayout), dlayout)); // dst index -> src index
190+ auto vrlayout = composition (composition (Layout<Shape<_SG, Int<values>>, Stride<_0, _1>>{},
191+ rlayout),
192+ Layout<Shape<_1, Int<values>>, Stride<_0, _SG>>{}); // dst val -> src val
193+
194+ CUTE_UNROLL
195+ for (int dv = 0 ; dv < values; dv += vchunk) {
196+ auto pS = recast_ptr<RegTypeSrc>(src.data () + vrlayout (dv));
197+ auto pD = recast_ptr<RegTypeDst>(dst.data () + dv);
198+
199+ detail::explode (detail::CallReorder<ReorderAtom>{},
200+ pS, make_int_sequence<RegNumSrc>{},
201+ pD, make_int_sequence<RegNumDst>{});
202+ }
183203 }
184204}
185205
0 commit comments