Skip to content

Commit

Permalink
Fix review comments
Browse files Browse the repository at this point in the history
Remove OrZero suffixes
Remove masked interleaves. Will be in a future PR with x86 specialisations
Correct template naming
Correct multishift header comment
Optimise MultiShift
Add x86 TODOs for multishift
Improve MultiShift docs
  • Loading branch information
wbb-ccl committed Jan 31, 2025
1 parent ad2e8d0 commit c4a44dd
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 114 deletions.
14 changes: 4 additions & 10 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -1024,8 +1024,8 @@ A compound shift on 64-bit values:

* `V`: `{u,i}64`, `VI`: `{u,i}8` \
<code>V **MultiShift**(V vals, VI indices)</code>: returns a
vector with `(vals[i] >> indices[i*8+j]) & 0xff` in byte `j` of `r[i]` for each
`j` between 0 and 7.
vector with `(vals[i] >> indices[i*8+j]) & 0xff` in byte `j` of vector `r[i]`
for each `j` between 0 and 7.

If `indices[i*8+j]` is less than 0 or greater than 63, byte `j` of `r[i]` is
implementation-defined.
Expand All @@ -1048,11 +1048,11 @@ A compound shift on 64-bit values:
#### Masked Shifts
* `V`: `{u,i}` \
<code>V **MaskedShiftLeftOrZero**&lt;int&gt;(M mask, V a)</code> returns `a[i] << int` or `0` if
<code>V **MaskedShiftLeft**&lt;int&gt;(M mask, V a)</code> returns `a[i] << int` or `0` if
`mask[i]` is false.
* `V`: `{u,i}` \
<code>V **MaskedShiftRightOrZero**&lt;int&gt;(M mask, V a)</code> returns `a[i] >> int` or `0` if
<code>V **MaskedShiftRight**&lt;int&gt;(M mask, V a)</code> returns `a[i] >> int` or `0` if
`mask[i]` is false.
* `V`: `{u,i}` \
Expand Down Expand Up @@ -2202,12 +2202,6 @@ Ops in this section are only available if `HWY_TARGET != HWY_SCALAR`:
`InterleaveOdd(d, a, b)` is usually more efficient than `OddEven(b,
DupOdd(a))`.
* <code>V **InterleaveEvenOrZero**(M m, V a, V b)</code>: Performs the same
operation as InterleaveEven, but returns zero in lanes where `m[i]` is false.
* <code>V **InterleaveOddOrZero**(M m, V a, V b)</code>: Performs the same
operation as InterleaveOdd, but returns zero in lanes where `m[i]` is false.
#### Zip
* `Ret`: `MakeWide<T>`; `V`: `{u,i}{8,16,32}` \
Expand Down
8 changes: 4 additions & 4 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1103,17 +1103,17 @@ HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_N, ShiftRight, asr_n)

#undef HWY_SVE_SHIFT_N

// ------------------------------ MaskedShift[Left/Right]SameOrZero
// ------------------------------ MaskedShift[Left/Right]Same

#define HWY_SVE_SHIFT_Z(BASE, CHAR, BITS, HALF, NAME, OP) \
template <int kBits> \
HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \
auto shifts = static_cast<HWY_SVE_T(uint, BITS)>(kBits); \
return sv##OP##_##CHAR##BITS##_z(m, v, shifts); \
}
HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_Z, MaskedShiftLeftOrZero, lsl_n)
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_Z, MaskedShiftRightOrZero, asr_n)
HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_Z, MaskedShiftRightOrZero, lsr_n)
HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_Z, MaskedShiftLeft, lsl_n)
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_Z, MaskedShiftRight, asr_n)
HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_Z, MaskedShiftRight, lsr_n)

#undef HWY_SVE_SHIFT_Z

Expand Down
44 changes: 13 additions & 31 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -518,24 +518,6 @@ HWY_API V InterleaveEven(V a, V b) {
}
#endif

// ------------------------------ InterleaveEvenOrZero

#if HWY_TARGET != HWY_SCALAR || HWY_IDE
template <class V, class M>
HWY_API V InterleaveEvenOrZero(M m, V a, V b) {
return IfThenElseZero(m, InterleaveEven(DFromV<V>(), a, b));
}
#endif

// ------------------------------ InterleaveOddOrZero

#if HWY_TARGET != HWY_SCALAR || HWY_IDE
template <class V, class M>
HWY_API V InterleaveOddOrZero(M m, V a, V b) {
return IfThenElseZero(m, InterleaveOdd(DFromV<V>(), a, b));
}
#endif

// ------------------------------ MinMagnitude/MaxMagnitude

#if (defined(HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE) == defined(HWY_TARGET_TOGGLE))
Expand Down Expand Up @@ -680,19 +662,19 @@ HWY_API V MaskedSatSubOr(V no, M m, V a, V b) {
#endif // HWY_NATIVE_MASKED_ARITH

// ------------------------------ MaskedShift
template <int kshift, class V, class M>
HWY_API V MaskedShiftLeftOrZero(M m, V a) {
return IfThenElseZero(m, ShiftLeft<kshift>(a));
template <int kShift, class V, class M>
HWY_API V MaskedShiftLeft(M m, V a) {
return IfThenElseZero(m, ShiftLeft<kShift>(a));
}

template <int kshift, class V, class M>
HWY_API V MaskedShiftRightOrZero(M m, V a) {
return IfThenElseZero(m, ShiftRight<kshift>(a));
template <int kShift, class V, class M>
HWY_API V MaskedShiftRight(M m, V a) {
return IfThenElseZero(m, ShiftRight<kShift>(a));
}

template <int kshift, class V, class M>
template <int kShift, class V, class M>
HWY_API V MaskedShiftRightOr(V no, M m, V a) {
return IfThenElse(m, ShiftRight<kshift>(a), no);
return IfThenElse(m, ShiftRight<kShift>(a), no);
}

template <class V, class M>
Expand Down Expand Up @@ -7573,7 +7555,8 @@ HWY_API bool AllBits0(V a) {
return AllTrue(d, Eq(a, Zero(d)));
}
#endif // HWY_NATIVE_ALLZEROS
// ------------------------------ MultiShift (Rol)

// ------------------------------ MultiShift
#if (defined(HWY_NATIVE_MULTISHIFT) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MULTISHIFT
#undef HWY_NATIVE_MULTISHIFT
Expand Down Expand Up @@ -7618,17 +7601,16 @@ HWY_API V MultiShift(V v, VI idx) {
Shr(BitCast(du16, odd_segments), odd_idx_shift);

// Extract the even bytes of each 128 bit block and pack into lower 64 bits
const auto extract_mask = Dup128VecFromValues(du8, 0, 2, 4, 6, 8, 10, 12, 14,
0, 0, 0, 0, 0, 0, 0, 0);
const auto even_lanes =
BitCast(d64, TableLookupBytes(extracted_even_bytes, extract_mask));
BitCast(d64, ConcatEven(extracted_even_bytes, Zero(du16)));
const auto odd_lanes =
BitCast(d64, TableLookupBytes(extracted_odd_bytes, extract_mask));
BitCast(d64, ConcatEven(extracted_odd_bytes, Zero(du16)));
// Interleave at 64 bit level
return InterleaveLower(even_lanes, odd_lanes);
}

#endif

// ================================================== Operator wrapper

// SVE* and RVV currently cannot define operators and have already defined
Expand Down
2 changes: 2 additions & 0 deletions hwy/ops/x86_128-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13664,6 +13664,8 @@ HWY_API V BitShuffle(V v, VI idx) {
}
#endif // HWY_TARGET <= HWY_AVX3_DL

// TODO: Implement MultiShift using _mm_multishift_epi64_epi8

// ------------------------------ Lt128

namespace detail {
Expand Down
2 changes: 2 additions & 0 deletions hwy/ops/x86_256-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8774,6 +8774,8 @@ HWY_API V BitShuffle(V v, VI idx) {
}
#endif // HWY_TARGET <= HWY_AVX3_DL

// TODO: Implement MultiShift using _mm256_multishift_epi64_epi8

// ------------------------------ LeadingZeroCount

#if HWY_TARGET <= HWY_AVX3
Expand Down
2 changes: 2 additions & 0 deletions hwy/ops/x86_512-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7541,6 +7541,8 @@ HWY_API V BitShuffle(V v, VI idx) {
}
#endif // HWY_TARGET <= HWY_AVX3_DL

// TODO: Implement MultiShift using _mm512_multishift_epi64_epi8

// -------------------- LeadingZeroCount

template <class V, HWY_IF_UI32(TFromV<V>), HWY_IF_V_SIZE_V(V, 64)>
Expand Down
60 changes: 0 additions & 60 deletions hwy/tests/blockwise_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -274,72 +274,12 @@ struct TestInterleaveOdd {
}
};

struct TestMaskedInterleaveEven {
template <class T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
const size_t N = Lanes(d);
const MFromD<D> first_3 = FirstN(d, 3);
auto even_lanes = AllocateAligned<T>(N);
auto odd_lanes = AllocateAligned<T>(N);
auto expected = AllocateAligned<T>(N);
HWY_ASSERT(even_lanes && odd_lanes && expected);
for (size_t i = 0; i < N; ++i) {
even_lanes[i] = ConvertScalarTo<T>(2 * i + 0);
odd_lanes[i] = ConvertScalarTo<T>(2 * i + 1);
}
const auto even = Load(d, even_lanes.get());
const auto odd = Load(d, odd_lanes.get());

for (size_t i = 0; i < N; ++i) {
if (i < 3) {
expected[i] = ConvertScalarTo<T>(2 * i - (i & 1));
} else {
expected[i] = ConvertScalarTo<T>(0);
}
}

HWY_ASSERT_VEC_EQ(d, expected.get(),
InterleaveEvenOrZero(first_3, even, odd));
}
};

struct TestMaskedInterleaveOdd {
template <class T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
const size_t N = Lanes(d);
const MFromD<D> first_3 = FirstN(d, 3);
auto even_lanes = AllocateAligned<T>(N);
auto odd_lanes = AllocateAligned<T>(N);
auto expected = AllocateAligned<T>(N);
HWY_ASSERT(even_lanes && odd_lanes && expected);
for (size_t i = 0; i < N; ++i) {
even_lanes[i] = ConvertScalarTo<T>(2 * i + 0);
odd_lanes[i] = ConvertScalarTo<T>(2 * i + 1);
}
const auto even = Load(d, even_lanes.get());
const auto odd = Load(d, odd_lanes.get());

for (size_t i = 0; i < N; ++i) {
if (i < 3) {
expected[i] = ConvertScalarTo<T>((2 * i) - (i & 1) + 2);
} else {
expected[i] = ConvertScalarTo<T>(0);
}
}

HWY_ASSERT_VEC_EQ(d, expected.get(),
InterleaveOddOrZero(first_3, even, odd));
}
};

HWY_NOINLINE void TestAllInterleave() {
// Not DemoteVectors because this cannot be supported by HWY_SCALAR.
ForAllTypes(ForShrinkableVectors<TestInterleaveLower>());
ForAllTypes(ForShrinkableVectors<TestInterleaveUpper>());
ForAllTypes(ForShrinkableVectors<TestInterleaveEven>());
ForAllTypes(ForShrinkableVectors<TestInterleaveOdd>());
ForAllTypes(ForShrinkableVectors<TestMaskedInterleaveEven>());
ForAllTypes(ForShrinkableVectors<TestMaskedInterleaveOdd>());
}

struct TestZipLower {
Expand Down
15 changes: 6 additions & 9 deletions hwy/tests/shift_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -502,25 +502,22 @@ HWY_NOINLINE void TestAllVariableRoundingShr() {
ForIntegerTypes(ForPartialVectors<TestVariableRoundingShr>());
}

struct TestMaskedShiftOrZero {
struct TestMaskedShift {
template <typename T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
const MFromD<D> all_true = MaskTrue(d);
const Vec<D> v0 = Zero(d);
const auto v1 = Iota(d, 1);
const MFromD<D> first_five = FirstN(d, 5);

HWY_ASSERT_VEC_EQ(d, ShiftLeft<1>(v1),
MaskedShiftLeftOrZero<1>(all_true, v1));
HWY_ASSERT_VEC_EQ(d, ShiftRight<1>(v1),
MaskedShiftRightOrZero<1>(all_true, v1));
HWY_ASSERT_VEC_EQ(d, ShiftLeft<1>(v1), MaskedShiftLeft<1>(all_true, v1));
HWY_ASSERT_VEC_EQ(d, ShiftRight<1>(v1), MaskedShiftRight<1>(all_true, v1));

const Vec<D> v1_exp_left = IfThenElse(first_five, ShiftLeft<1>(v1), v0);
HWY_ASSERT_VEC_EQ(d, v1_exp_left, MaskedShiftLeftOrZero<1>(first_five, v1));
HWY_ASSERT_VEC_EQ(d, v1_exp_left, MaskedShiftLeft<1>(first_five, v1));

const Vec<D> v1_exp_right = IfThenElse(first_five, ShiftRight<1>(v1), v0);
HWY_ASSERT_VEC_EQ(d, v1_exp_right,
MaskedShiftRightOrZero<1>(first_five, v1));
HWY_ASSERT_VEC_EQ(d, v1_exp_right, MaskedShiftRight<1>(first_five, v1));
}
};
struct TestMaskedShiftRightOr {
Expand Down Expand Up @@ -551,7 +548,7 @@ struct TestMaskedShrOr {
};

HWY_NOINLINE void TestAllMaskedShift() {
ForIntegerTypes(ForPartialVectors<TestMaskedShiftOrZero>());
ForIntegerTypes(ForPartialVectors<TestMaskedShift>());
ForIntegerTypes(ForPartialVectors<TestMaskedShiftRightOr>());
ForSignedTypes(ForPartialVectors<TestMaskedShrOr>());
}
Expand Down

0 comments on commit c4a44dd

Please sign in to comment.