MultiShift, Masked shift and masked interleave
mazimkhan authored and wbb-ccl committed Jan 31, 2025
1 parent e96d4d3 commit ad2e8d0
Showing 5 changed files with 379 additions and 1 deletion.
49 changes: 49 additions & 0 deletions g3doc/quick_reference.md
@@ -1020,6 +1020,49 @@
Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2):
(a[i] << ((sizeof(T)*8 - b[i]) & shift_amt_mask))`, where `shift_amt_mask` is
equal to `sizeof(T)*8 - 1`.

A compound shift on 64-bit values:

* `V`: `{u,i}64`, `VI`: `{u,i}8` \
<code>V **MultiShift**(V vals, VI indices)</code>: returns a vector with the
8 bits of `vals[i]` starting at bit `indices[i*8+j]` (wrapping around, as in a
64-bit rotate) in byte `j` of `r[i]`, for each `j` between 0 and 7.

If `indices[i*8+j]` is less than 0 or greater than 63, byte `j` of `r[i]` is
implementation-defined.

`VI` must be either `Vec<Repartition<int8_t, DFromV<V>>>` or
`Vec<Repartition<uint8_t, DFromV<V>>>`.

`MultiShift(V vals, VI indices)` is equivalent to the following loop (where `N`
is equal to `Lanes(DFromV<V>())`):
```
for (size_t i = 0; i < N; i++) {
  uint64_t shift_result = 0;
  for (int j = 0; j < 8; j++) {
    // Rotate vals[i] right by indices[i*8+j] bits. The `& 63` avoids an
    // out-of-bounds shift count when the index is 0.
    const uint64_t rot_result = (vals[i] >> indices[i*8+j]) |
                                (vals[i] << ((64 - indices[i*8+j]) & 63));
    // Byte j of the result is the low byte of the rotated value.
    shift_result |= (rot_result & 0xff) << (j * 8);
  }
  r[i] = shift_result;
}
```
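
For illustration (a sketch, not part of this commit), `MultiShift` can
byte-swap each 64-bit lane by selecting the bit offsets 56, 48, ..., 0:

```
const ScalableTag<uint64_t> d64;
const Repartition<uint8_t, decltype(d64)> du8;
// One group of 8 bit offsets per 64-bit lane, repeated across the vector.
const auto offsets = Dup128VecFromValues(du8, 56, 48, 40, 32, 24, 16, 8, 0,
                                         56, 48, 40, 32, 24, 16, 8, 0);
// Each lane of `swapped` is 0x0807060504030201.
const auto swapped = MultiShift(Set(d64, 0x0102030405060708ULL), offsets);
```
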
#### Masked Shifts

A usage sketch follows this list.

* `V`: `{u,i}` \
<code>V **MaskedShiftLeftOrZero**&lt;int&gt;(M mask, V a)</code>: returns
`a[i] << int` or `0` if `mask[i]` is false.
* `V`: `{u,i}` \
<code>V **MaskedShiftRightOrZero**&lt;int&gt;(M mask, V a)</code>: returns
`a[i] >> int` or `0` if `mask[i]` is false.
* `V`: `{u,i}` \
<code>V **MaskedShiftRightOr**&lt;int&gt;(V no, M mask, V a)</code>: returns
`a[i] >> int` or `no[i]` if `mask[i]` is false.
* `V`: `{u,i}` \
<code>V **MaskedShrOr**(V no, M mask, V a, V shifts)</code>: returns
`a[i] >> shifts[i]` or `no[i]` if `mask[i]` is false.
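
For example (a sketch, not part of this commit), one can halve only the
positive lanes and keep the rest unchanged:

```
const ScalableTag<int32_t> d;
const auto v = Iota(d, -2);       // -2, -1, 0, 1, ...
const auto pos = Gt(v, Zero(d));  // mask of positive lanes
// r[i] = pos[i] ? (v[i] >> 1) : v[i]
const auto r = MaskedShiftRightOr<1>(v, pos, v);
```
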
#### Floating-point rounding
* `V`: `{f}` \
Expand Down Expand Up @@ -2159,6 +2202,12 @@ Ops in this section are only available if `HWY_TARGET != HWY_SCALAR`:
`InterleaveOdd(d, a, b)` is usually more efficient than `OddEven(b,
DupOdd(a))`.
* <code>V **InterleaveEvenOrZero**(M m, V a, V b)</code>: performs the same
operation as `InterleaveEven`, but returns zero in lanes where `m[i]` is false.
* <code>V **InterleaveOddOrZero**(M m, V a, V b)</code>: performs the same
operation as `InterleaveOdd`, but returns zero in lanes where `m[i]` is false.
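
For instance (a sketch, not part of this commit), the masked variants can keep
just the leading interleaved lanes:

```
const ScalableTag<uint32_t> d;
const auto a = Iota(d, 0);    // 0, 1, 2, 3, ...
const auto b = Iota(d, 100);  // 100, 101, 102, ...
// InterleaveEven would give a[0], b[0], a[2], b[2], ...; the mask zeroes all
// but the first two lanes: 0, 100, 0, 0, ...
const auto r = InterleaveEvenOrZero(FirstN(d, 2), a, b);
```
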
#### Zip
* `Ret`: `MakeWide<T>`; `V`: `{u,i}{8,16,32}` \
29 changes: 29 additions & 0 deletions hwy/ops/arm_sve-inl.h
@@ -1103,6 +1103,35 @@
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_N, ShiftRight, asr_n)

#undef HWY_SVE_SHIFT_N

// ------------------------------ MaskedShift[Left/Right]OrZero

// The `_z` (zeroing) forms of the SVE shift intrinsics clear lanes where the
// governing predicate `m` is false.
#define HWY_SVE_SHIFT_Z(BASE, CHAR, BITS, HALF, NAME, OP) \
template <int kBits> \
HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \
auto shifts = static_cast<HWY_SVE_T(uint, BITS)>(kBits); \
return sv##OP##_##CHAR##BITS##_z(m, v, shifts); \
}
HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_Z, MaskedShiftLeftOrZero, lsl_n)
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_Z, MaskedShiftRightOrZero, asr_n)
HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_Z, MaskedShiftRightOrZero, lsr_n)

#undef HWY_SVE_SHIFT_Z

// ------------------------------ MaskedShiftRightOr

// Shift with zeroing predication, then svsel reinstates `no` in lanes where
// `m` is false.
#define HWY_SVE_SHIFT_OR(BASE, CHAR, BITS, HALF, NAME, OP) \
template <int kBits> \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_V(BASE, BITS) v) { \
auto shifts = static_cast<HWY_SVE_T(uint, BITS)>(kBits); \
return svsel##_##CHAR##BITS(m, sv##OP##_##CHAR##BITS##_z(m, v, shifts), \
no); \
}
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_OR, MaskedShiftRightOr, asr_n)
HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_OR, MaskedShiftRightOr, lsr_n)

#undef HWY_SVE_SHIFT_OR

// ------------------------------ RotateRight

#if HWY_SVE_HAVE_2
95 changes: 95 additions & 0 deletions hwy/ops/generic_ops-inl.h
@@ -518,6 +518,24 @@
HWY_API V InterleaveEven(V a, V b) {
}
#endif

// ------------------------------ InterleaveEvenOrZero

#if HWY_TARGET != HWY_SCALAR || HWY_IDE
template <class V, class M>
HWY_API V InterleaveEvenOrZero(M m, V a, V b) {
return IfThenElseZero(m, InterleaveEven(DFromV<V>(), a, b));
}
#endif

// ------------------------------ InterleaveOddOrZero

#if HWY_TARGET != HWY_SCALAR || HWY_IDE
template <class V, class M>
HWY_API V InterleaveOddOrZero(M m, V a, V b) {
return IfThenElseZero(m, InterleaveOdd(DFromV<V>(), a, b));
}
#endif

// ------------------------------ MinMagnitude/MaxMagnitude

#if (defined(HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE) == defined(HWY_TARGET_TOGGLE))
@@ -661,6 +679,27 @@
HWY_API V MaskedSatSubOr(V no, M m, V a, V b) {
}
#endif // HWY_NATIVE_MASKED_ARITH

// ------------------------------ MaskedShift
template <int kshift, class V, class M>
HWY_API V MaskedShiftLeftOrZero(M m, V a) {
return IfThenElseZero(m, ShiftLeft<kshift>(a));
}

template <int kshift, class V, class M>
HWY_API V MaskedShiftRightOrZero(M m, V a) {
return IfThenElseZero(m, ShiftRight<kshift>(a));
}

template <int kshift, class V, class M>
HWY_API V MaskedShiftRightOr(V no, M m, V a) {
return IfThenElse(m, ShiftRight<kshift>(a), no);
}

template <class V, class M>
HWY_API V MaskedShrOr(V no, M m, V a, V shifts) {
return IfThenElse(m, Shr(a, shifts), no);
}
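
// Example (illustrative only, not part of this header): divide lanes selected
// by `m` by four, keeping the original value elsewhere:
//   const auto r = MaskedShiftRightOr<2>(v, m, v);  // m[i] ? v[i]>>2 : v[i]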

// ------------------------------ IfNegativeThenNegOrUndefIfZero

#if (defined(HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG) == \
@@ -7534,6 +7573,62 @@
HWY_API bool AllBits0(V a) {
return AllTrue(d, Eq(a, Zero(d)));
}
#endif  // HWY_NATIVE_ALLZEROS

// ------------------------------ MultiShift (Rol)
#if (defined(HWY_NATIVE_MULTISHIFT) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MULTISHIFT
#undef HWY_NATIVE_MULTISHIFT
#else
#define HWY_NATIVE_MULTISHIFT
#endif

template <class V, class VI, HWY_IF_UI64(TFromV<V>), HWY_IF_UI8(TFromV<VI>)>
HWY_API V MultiShift(V v, VI idx) {
const DFromV<V> d64;
const Repartition<uint8_t, decltype(d64)> du8;
const Repartition<uint16_t, decltype(d64)> du16;
const auto k7 = Set(du8, uint8_t{0x07});
const auto k63 = Set(du8, uint8_t{0x3F});

const auto masked_idx = And(k63, BitCast(du8, idx));
const auto byte_idx = ShiftRight<3>(masked_idx);
const auto idx_shift = And(k7, masked_idx);

// Calculate even lanes
const auto even_src = DupEven(v);
// Expand indexes to pull out 16 bit segments of idx and idx + 1
const auto even_idx =
InterleaveLower(byte_idx, Add(byte_idx, Set(du8, uint8_t{1})));
// TableLookupBytes indexes select from within a 16 byte block
const auto even_segments = TableLookupBytes(even_src, even_idx);
// Extract unaligned bytes from 16 bit segments
const auto even_idx_shift = ZipLower(idx_shift, Zero(du8));
const auto extracted_even_bytes =
Shr(BitCast(du16, even_segments), even_idx_shift);

// Calculate odd lanes
const auto odd_src = DupOdd(v);
// Expand indexes to pull out 16 bit segments of idx and idx + 1
const auto odd_idx =
InterleaveUpper(du8, byte_idx, Add(byte_idx, Set(du8, uint8_t{1})));
// TableLookupBytes indexes select from within a 16 byte block
const auto odd_segments = TableLookupBytes(odd_src, odd_idx);
// Extract unaligned bytes from 16 bit segments
const auto odd_idx_shift = ZipUpper(du16, idx_shift, Zero(du8));
const auto extracted_odd_bytes =
Shr(BitCast(du16, odd_segments), odd_idx_shift);

// Extract the even bytes of each 128 bit block and pack into lower 64 bits
const auto extract_mask = Dup128VecFromValues(du8, 0, 2, 4, 6, 8, 10, 12, 14,
0, 0, 0, 0, 0, 0, 0, 0);
const auto even_lanes =
BitCast(d64, TableLookupBytes(extracted_even_bytes, extract_mask));
const auto odd_lanes =
BitCast(d64, TableLookupBytes(extracted_odd_bytes, extract_mask));
// Interleave at 64 bit level
return InterleaveLower(even_lanes, odd_lanes);
}

#endif  // HWY_NATIVE_MULTISHIFT

// ================================================== Operator wrapper

// SVE* and RVV currently cannot define operators and have already defined
60 changes: 60 additions & 0 deletions hwy/tests/blockwise_test.cc
@@ -274,12 +274,72 @@
struct TestInterleaveOdd {
}
};

struct TestMaskedInterleaveEven {
template <class T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
const size_t N = Lanes(d);
const MFromD<D> first_3 = FirstN(d, 3);
auto even_lanes = AllocateAligned<T>(N);
auto odd_lanes = AllocateAligned<T>(N);
auto expected = AllocateAligned<T>(N);
HWY_ASSERT(even_lanes && odd_lanes && expected);
for (size_t i = 0; i < N; ++i) {
even_lanes[i] = ConvertScalarTo<T>(2 * i + 0);
odd_lanes[i] = ConvertScalarTo<T>(2 * i + 1);
}
const auto even = Load(d, even_lanes.get());
const auto odd = Load(d, odd_lanes.get());

for (size_t i = 0; i < N; ++i) {
if (i < 3) {
expected[i] = ConvertScalarTo<T>(2 * i - (i & 1));
} else {
expected[i] = ConvertScalarTo<T>(0);
}
}

HWY_ASSERT_VEC_EQ(d, expected.get(),
InterleaveEvenOrZero(first_3, even, odd));
}
};

struct TestMaskedInterleaveOdd {
template <class T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
const size_t N = Lanes(d);
const MFromD<D> first_3 = FirstN(d, 3);
auto even_lanes = AllocateAligned<T>(N);
auto odd_lanes = AllocateAligned<T>(N);
auto expected = AllocateAligned<T>(N);
HWY_ASSERT(even_lanes && odd_lanes && expected);
for (size_t i = 0; i < N; ++i) {
even_lanes[i] = ConvertScalarTo<T>(2 * i + 0);
odd_lanes[i] = ConvertScalarTo<T>(2 * i + 1);
}
const auto even = Load(d, even_lanes.get());
const auto odd = Load(d, odd_lanes.get());

for (size_t i = 0; i < N; ++i) {
if (i < 3) {
expected[i] = ConvertScalarTo<T>((2 * i) - (i & 1) + 2);
} else {
expected[i] = ConvertScalarTo<T>(0);
}
}

HWY_ASSERT_VEC_EQ(d, expected.get(),
InterleaveOddOrZero(first_3, even, odd));
}
};

HWY_NOINLINE void TestAllInterleave() {
// Not DemoteVectors because this cannot be supported by HWY_SCALAR.
ForAllTypes(ForShrinkableVectors<TestInterleaveLower>());
ForAllTypes(ForShrinkableVectors<TestInterleaveUpper>());
ForAllTypes(ForShrinkableVectors<TestInterleaveEven>());
ForAllTypes(ForShrinkableVectors<TestInterleaveOdd>());
ForAllTypes(ForShrinkableVectors<TestMaskedInterleaveEven>());
ForAllTypes(ForShrinkableVectors<TestMaskedInterleaveOdd>());
}

struct TestZipLower {