Skip to content

Commit 56a200d

Browse files
authored
[Xe] [Reorder] Support broadcasting reorders (#589)
Adds support for broadcasting in reorders, where a single src value maps to multiple dst values. This is useful for reordering scales/zero points during dequantization.
1 parent ac1e946 commit 56a200d

File tree

2 files changed

+37
-16
lines changed

2 files changed

+37
-16
lines changed

include/cute/algorithm/reorder.hpp

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -165,21 +165,41 @@ reorder_impl(ReorderAtom const& atom,
165165
constexpr int values = size(SLayout{}) / size<0>(SLayout{});
166166
constexpr int vchunk = sizeof_bits_v<RegistersSrc> / sizeof_bits_v<SType>;
167167

168-
// Calculate mapping from src val -> dst val on a chunk-by-chunk basis. Unlike a plain copy, there is no intrinsic
169-
// correspondence of src/dst values for subgroup reorders.
170-
auto rlayout = coalesce(composition(right_inverse(dlayout), slayout)); // src index -> dst index
171-
auto vrlayout = composition(composition(Layout<Shape<_SG, Int<values>>, Stride<_0, _1>>{},
172-
rlayout),
173-
Layout<Shape<_1, Int<values>>, Stride<_0, _SG>>{}); // src val -> dst val
174-
175-
CUTE_UNROLL
176-
for (int sv = 0; sv < values; sv += vchunk) {
177-
auto pS = recast_ptr<RegTypeSrc>(src.data() + sv);
178-
auto pD = recast_ptr<RegTypeDst>(dst.data() + vrlayout(sv));
179-
180-
detail::explode(detail::CallReorder<ReorderAtom>{},
181-
pS, make_int_sequence<RegNumSrc>{},
182-
pD, make_int_sequence<RegNumDst>{});
168+
static constexpr bool has_broadcast = (size(DLayoutWI{}) > size(SLayoutWI{}));
169+
170+
if (!has_broadcast) {
171+
// Calculate mapping from src val -> dst val on a chunk-by-chunk basis. Unlike a plain copy, there is no intrinsic
172+
// correspondence of src/dst values for subgroup reorders.
173+
auto rlayout = coalesce(composition(right_inverse(dlayout), slayout)); // src index -> dst index
174+
auto vrlayout = composition(composition(Layout<Shape<_SG, Int<values>>, Stride<_0, _1>>{},
175+
rlayout),
176+
Layout<Shape<_1, Int<values>>, Stride<_0, _SG>>{}); // src val -> dst val
177+
178+
CUTE_UNROLL
179+
for (int sv = 0; sv < values; sv += vchunk) {
180+
auto pS = recast_ptr<RegTypeSrc>(src.data() + sv);
181+
auto pD = recast_ptr<RegTypeDst>(dst.data() + vrlayout(sv));
182+
183+
detail::explode(detail::CallReorder<ReorderAtom>{},
184+
pS, make_int_sequence<RegNumSrc>{},
185+
pD, make_int_sequence<RegNumDst>{});
186+
}
187+
} else {
188+
// When broadcasting (a single src value maps to multiple dst values), loop over dst values instead.
189+
auto rlayout = coalesce(composition(right_inverse(slayout), dlayout)); // dst index -> src index
190+
auto vrlayout = composition(composition(Layout<Shape<_SG, Int<values>>, Stride<_0, _1>>{},
191+
rlayout),
192+
Layout<Shape<_1, Int<values>>, Stride<_0, _SG>>{}); // dst val -> src val
193+
194+
CUTE_UNROLL
195+
for (int dv = 0; dv < values; dv += vchunk) {
196+
auto pS = recast_ptr<RegTypeSrc>(src.data() + vrlayout(dv));
197+
auto pD = recast_ptr<RegTypeDst>(dst.data() + dv);
198+
199+
detail::explode(detail::CallReorder<ReorderAtom>{},
200+
pS, make_int_sequence<RegNumSrc>{},
201+
pD, make_int_sequence<RegNumDst>{});
202+
}
183203
}
184204
}
185205

include/cute/atom/reorder_atom_xe.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ reorder_impl(ReorderDispatchXeGeneric const&,
190190
static constexpr int elems_per_grf = 64 / sizeof(SrcType);
191191
static constexpr int ds_vl = cute::min(32, cute::min(shape<0>(rlayout), elems_per_grf / stride<0>(rlayout)));
192192
static constexpr int ss_vl = cute::min(32, cute::min(shape<0>(ilayout), elems_per_grf / stride<0>(ilayout)));
193+
static constexpr bool has_broadcast = (size(DLayoutWI{}) > size(SLayoutWI{}));
193194

194195
// Make dst live, to prevent compiler from inserting its own initialization.
195196
#ifdef __SYCL_DEVICE_ONLY__
@@ -202,7 +203,7 @@ reorder_impl(ReorderDispatchXeGeneric const&,
202203
}
203204
#endif
204205

205-
if constexpr (ss_vl >= ds_vl) {
206+
if constexpr (ss_vl >= ds_vl || has_broadcast) {
206207
// Stride on src. For simplicity, take 1 GRF at a time.
207208
for_each(make_seq<size(SLayout{}) / ss_vl>{}, [&](auto i) {
208209
constexpr auto didx = i * ss_vl;

0 commit comments

Comments
 (0)