diff --git a/g3doc/quick_reference.md b/g3doc/quick_reference.md
index aa873a559e..1bb2096c25 100644
--- a/g3doc/quick_reference.md
+++ b/g3doc/quick_reference.md
@@ -1024,8 +1024,8 @@ A compound shift on 64-bit values:
 
 *   `V`: `{u,i}64`, `VI`: `{u,i}8` \
     V **MultiShift**(V vals, VI indices): returns a
-    vector with `(vals[i] >> indices[i*8+j]) & 0xff` in byte `j` of `r[i]` for each
-    `j` between 0 and 7.
+    vector with `(vals[i] >> indices[i*8+j]) & 0xff` in byte `j` of vector `r[i]`
+    for each `j` between 0 and 7.
     If `indices[i*8+j]` is less than 0 or greater than 63, byte `j` of `r[i]` is
     implementation-defined.
 
@@ -1048,11 +1048,11 @@
 #### Masked Shifts
 
 *   `V`: `{u,i}` \
-    V **MaskedShiftLeftOrZero**<int>(M mask, V a) returns `a[i] << int` or `0` if
+    V **MaskedShiftLeft**<int>(M mask, V a) returns `a[i] << int` or `0` if
     `mask[i]` is false.
 
 *   `V`: `{u,i}` \
-    V **MaskedShiftRightOrZero**<int>(M mask, V a) returns `a[i] >> int` or `0` if
+    V **MaskedShiftRight**<int>(M mask, V a) returns `a[i] >> int` or `0` if
     `mask[i]` is false.
 
 *   `V`: `{u,i}` \
@@ -2202,12 +2202,6 @@ Ops in this section are only available if `HWY_TARGET != HWY_SCALAR`:
     `InterleaveOdd(d, a, b)` is usually more efficient than
     `OddEven(b, DupOdd(a))`.
 
-*   V **InterleaveEvenOrZero**(M m, V a, V b): Performs the same
-    operation as InterleaveEven, but returns zero in lanes where `m[i]` is false.
-
-*   V **InterleaveOddOrZero**(M m, V a, V b): Performs the same
-    operation as InterleaveOdd, but returns zero in lanes where `m[i]` is false.
-
 #### Zip
 
 *   `Ret`: `MakeWide<T>`; `V`: `{u,i}{8,16,32}` \
diff --git a/hwy/ops/arm_sve-inl.h b/hwy/ops/arm_sve-inl.h
index 073f20d787..5939f93aa1 100644
--- a/hwy/ops/arm_sve-inl.h
+++ b/hwy/ops/arm_sve-inl.h
@@ -1103,7 +1103,7 @@ HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_N, ShiftRight, asr_n)
 
 #undef HWY_SVE_SHIFT_N
 
-// ------------------------------ MaskedShift[Left/Right]SameOrZero
+// ------------------------------ MaskedShift[Left/Right]Same
 
 #define HWY_SVE_SHIFT_Z(BASE, CHAR, BITS, HALF, NAME, OP)                  \
   template <int kBits>                                                     \
@@ -1111,9 +1111,9 @@ HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_N, ShiftRight, asr_n)
     auto shifts = static_cast<uint64_t>(kBits);                            \
     return sv##OP##_##CHAR##BITS##_z(m, v, shifts);                        \
   }
-HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_Z, MaskedShiftLeftOrZero, lsl_n)
-HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_Z, MaskedShiftRightOrZero, asr_n)
-HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_Z, MaskedShiftRightOrZero, lsr_n)
+HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_Z, MaskedShiftLeft, lsl_n)
+HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_Z, MaskedShiftRight, asr_n)
+HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_Z, MaskedShiftRight, lsr_n)
 
 #undef HWY_SVE_SHIFT_Z
 
diff --git a/hwy/ops/generic_ops-inl.h b/hwy/ops/generic_ops-inl.h
index d890a3220e..ea5f3b2e2a 100644
--- a/hwy/ops/generic_ops-inl.h
+++ b/hwy/ops/generic_ops-inl.h
@@ -518,24 +518,6 @@ HWY_API V InterleaveEven(V a, V b) {
 }
 #endif
 
-// ------------------------------ InterleaveEvenOrZero
-
-#if HWY_TARGET != HWY_SCALAR || HWY_IDE
-template <class V, class M>
-HWY_API V InterleaveEvenOrZero(M m, V a, V b) {
-  return IfThenElseZero(m, InterleaveEven(DFromV<V>(), a, b));
-}
-#endif
-
-// ------------------------------ InterleaveOddOrZero
-
-#if HWY_TARGET != HWY_SCALAR || HWY_IDE
-template <class V, class M>
-HWY_API V InterleaveOddOrZero(M m, V a, V b) {
-  return IfThenElseZero(m, InterleaveOdd(DFromV<V>(), a, b));
-}
-#endif
-
 // ------------------------------ MinMagnitude/MaxMagnitude
 
 #if (defined(HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE) == defined(HWY_TARGET_TOGGLE))
@@ -680,19 +662,19 @@ HWY_API V MaskedSatSubOr(V no, M m, V a, V b) {
 #endif  // HWY_NATIVE_MASKED_ARITH
 
 // ------------------------------ MaskedShift
-template <int shift, class V, class M>
-HWY_API V MaskedShiftLeftOrZero(M m, V a) {
-  return IfThenElseZero(m, ShiftLeft<shift>(a));
+template <int kShift, class V, class M>
+HWY_API V MaskedShiftLeft(M m, V a) {
+  return IfThenElseZero(m, ShiftLeft<kShift>(a));
 }
 
-template <int shift, class V, class M>
-HWY_API V MaskedShiftRightOrZero(M m, V a) {
-  return IfThenElseZero(m, ShiftRight<shift>(a));
+template <int kShift, class V, class M>
+HWY_API V MaskedShiftRight(M m, V a) {
+  return IfThenElseZero(m, ShiftRight<kShift>(a));
 }
 
-template <int shift, class V, class M>
+template <int kShift, class V, class M>
 HWY_API V MaskedShiftRightOr(V no, M m, V a) {
-  return IfThenElse(m, ShiftRight<shift>(a), no);
+  return IfThenElse(m, ShiftRight<kShift>(a), no);
 }
 
 template <class V, class M>
@@ -7573,7 +7555,8 @@ HWY_API bool AllBits0(V a) {
   return AllTrue(d, Eq(a, Zero(d)));
 }
 #endif  // HWY_NATIVE_ALLZEROS
-// ------------------------------ MultiShift (Rol)
+
+// ------------------------------ MultiShift
 #if (defined(HWY_NATIVE_MULTISHIFT) == defined(HWY_TARGET_TOGGLE))
 #ifdef HWY_NATIVE_MULTISHIFT
 #undef HWY_NATIVE_MULTISHIFT
@@ -7618,17 +7601,16 @@ HWY_API V MultiShift(V v, VI idx) {
       Shr(BitCast(du16, odd_segments), odd_idx_shift);
 
   // Extract the even bytes of each 128 bit block and pack into lower 64 bits
-  const auto extract_mask = Dup128VecFromValues(du8, 0, 2, 4, 6, 8, 10, 12, 14,
-                                                0, 0, 0, 0, 0, 0, 0, 0);
   const auto even_lanes =
-      BitCast(d64, TableLookupBytes(extracted_even_bytes, extract_mask));
+      BitCast(d64, concatEven(extracted_even_bytes, Zero(du16)));
   const auto odd_lanes =
-      BitCast(d64, TableLookupBytes(extracted_odd_bytes, extract_mask));
+      BitCast(d64, concatEven(extracted_odd_bytes, Zero(du16)));
 
   // Interleave at 64 bit level
   return InterleaveLower(even_lanes, odd_lanes);
 }
 #endif
+
 // ================================================== Operator wrapper
 
 // SVE* and RVV currently cannot define operators and have already defined
diff --git a/hwy/ops/x86_128-inl.h b/hwy/ops/x86_128-inl.h
index ee40181594..aafb55d548 100644
--- a/hwy/ops/x86_128-inl.h
+++ b/hwy/ops/x86_128-inl.h
@@ -13664,6 +13664,8 @@ HWY_API V BitShuffle(V v, VI idx) {
 }
 #endif  // HWY_TARGET <= HWY_AVX3_DL
 
+// TODO: Implement MultiShift using _mm_multishift_epi64_epi8
+
 // ------------------------------ Lt128
 
 namespace detail {
diff --git a/hwy/ops/x86_256-inl.h b/hwy/ops/x86_256-inl.h
index 59e93bfe5b..4e86488c91 100644
--- a/hwy/ops/x86_256-inl.h
+++ b/hwy/ops/x86_256-inl.h
@@ -8774,6 +8774,8 @@ HWY_API V BitShuffle(V v, VI idx) {
 }
 #endif  // HWY_TARGET <= HWY_AVX3_DL
 
+// TODO: Implement MultiShift using _mm256_multishift_epi64_epi8
+
 // ------------------------------ LeadingZeroCount
 
 #if HWY_TARGET <= HWY_AVX3
diff --git a/hwy/ops/x86_512-inl.h b/hwy/ops/x86_512-inl.h
index 32ae7c2d80..a23938e7c1 100644
--- a/hwy/ops/x86_512-inl.h
+++ b/hwy/ops/x86_512-inl.h
@@ -7541,6 +7541,8 @@ HWY_API V BitShuffle(V v, VI idx) {
 }
 #endif  // HWY_TARGET <= HWY_AVX3_DL
 
+// TODO: Implement MultiShift using _mm512_multishift_epi64_epi8
+
 // -------------------- LeadingZeroCount
 
 template <class V, HWY_IF_UI32(TFromV<V>), HWY_IF_V_SIZE_V(V, 64)>
diff --git a/hwy/tests/blockwise_test.cc b/hwy/tests/blockwise_test.cc
index 99811096e9..03ee45ccbf 100644
--- a/hwy/tests/blockwise_test.cc
+++ b/hwy/tests/blockwise_test.cc
@@ -274,72 +274,12 @@ struct TestInterleaveOdd {
   }
 };
 
-struct TestMaskedInterleaveEven {
-  template <class T, class D>
-  HWY_NOINLINE void operator()(T /*unused*/, D d) {
-    const size_t N = Lanes(d);
-    const MFromD<D> first_3 = FirstN(d, 3);
-    auto even_lanes = AllocateAligned<T>(N);
-    auto odd_lanes = AllocateAligned<T>(N);
-    auto expected = AllocateAligned<T>(N);
-    HWY_ASSERT(even_lanes && odd_lanes && expected);
-    for (size_t i = 0; i < N; ++i) {
-      even_lanes[i] = ConvertScalarTo<T>(2 * i + 0);
-      odd_lanes[i] = ConvertScalarTo<T>(2 * i + 1);
-    }
-    const auto even = Load(d, even_lanes.get());
-    const auto odd = Load(d, odd_lanes.get());
-
-    for (size_t i = 0; i < N; ++i) {
-      if (i < 3) {
-        expected[i] = ConvertScalarTo<T>(2 * i - (i & 1));
-      } else {
-        expected[i] = ConvertScalarTo<T>(0);
-      }
-    }
-
-    HWY_ASSERT_VEC_EQ(d, expected.get(),
-                      InterleaveEvenOrZero(first_3, even, odd));
-  }
-};
-
-struct TestMaskedInterleaveOdd {
-  template <class T, class D>
-  HWY_NOINLINE void operator()(T /*unused*/, D d) {
-    const size_t N = Lanes(d);
-    const MFromD<D> first_3 = FirstN(d, 3);
-    auto even_lanes = AllocateAligned<T>(N);
-    auto odd_lanes = AllocateAligned<T>(N);
-    auto expected = AllocateAligned<T>(N);
-    HWY_ASSERT(even_lanes && odd_lanes && expected);
-    for (size_t i = 0; i < N; ++i) {
-      even_lanes[i] = ConvertScalarTo<T>(2 * i + 0);
-      odd_lanes[i] = ConvertScalarTo<T>(2 * i + 1);
-    }
-    const auto even = Load(d, even_lanes.get());
-    const auto odd = Load(d, odd_lanes.get());
-
-    for (size_t i = 0; i < N; ++i) {
-      if (i < 3) {
-        expected[i] = ConvertScalarTo<T>((2 * i) - (i & 1) + 2);
-      } else {
-        expected[i] = ConvertScalarTo<T>(0);
-      }
-    }
-
-    HWY_ASSERT_VEC_EQ(d, expected.get(),
-                      InterleaveOddOrZero(first_3, even, odd));
-  }
-};
-
 HWY_NOINLINE void TestAllInterleave() {
   // Not DemoteVectors because this cannot be supported by HWY_SCALAR.
   ForAllTypes(ForShrinkableVectors<TestInterleaveLower>());
   ForAllTypes(ForShrinkableVectors<TestInterleaveUpper>());
   ForAllTypes(ForShrinkableVectors<TestInterleaveEven>());
   ForAllTypes(ForShrinkableVectors<TestInterleaveOdd>());
-  ForAllTypes(ForShrinkableVectors<TestMaskedInterleaveEven>());
-  ForAllTypes(ForShrinkableVectors<TestMaskedInterleaveOdd>());
 }
 
 struct TestZipLower {
diff --git a/hwy/tests/shift_test.cc b/hwy/tests/shift_test.cc
index 7207403fcc..a0c661cb3f 100644
--- a/hwy/tests/shift_test.cc
+++ b/hwy/tests/shift_test.cc
@@ -502,7 +502,7 @@ HWY_NOINLINE void TestAllVariableRoundingShr() {
   ForIntegerTypes(ForPartialVectors<TestVariableRoundingShr>());
 }
 
-struct TestMaskedShiftOrZero {
+struct TestMaskedShift {
   template <class T, class D>
   HWY_NOINLINE void operator()(T /*unused*/, D d) {
     const MFromD<D> all_true = MaskTrue(d);
@@ -510,17 +510,14 @@
     const auto v1 = Iota(d, 1);
     const MFromD<D> first_five = FirstN(d, 5);
 
-    HWY_ASSERT_VEC_EQ(d, ShiftLeft<1>(v1),
-                      MaskedShiftLeftOrZero<1>(all_true, v1));
-    HWY_ASSERT_VEC_EQ(d, ShiftRight<1>(v1),
-                      MaskedShiftRightOrZero<1>(all_true, v1));
+    HWY_ASSERT_VEC_EQ(d, ShiftLeft<1>(v1), MaskedShiftLeft<1>(all_true, v1));
+    HWY_ASSERT_VEC_EQ(d, ShiftRight<1>(v1), MaskedShiftRight<1>(all_true, v1));
 
     const Vec<D> v1_exp_left = IfThenElse(first_five, ShiftLeft<1>(v1), v0);
-    HWY_ASSERT_VEC_EQ(d, v1_exp_left, MaskedShiftLeftOrZero<1>(first_five, v1));
+    HWY_ASSERT_VEC_EQ(d, v1_exp_left, MaskedShiftLeft<1>(first_five, v1));
 
     const Vec<D> v1_exp_right = IfThenElse(first_five, ShiftRight<1>(v1), v0);
-    HWY_ASSERT_VEC_EQ(d, v1_exp_right,
-                      MaskedShiftRightOrZero<1>(first_five, v1));
+    HWY_ASSERT_VEC_EQ(d, v1_exp_right, MaskedShiftRight<1>(first_five, v1));
   }
 };
 struct TestMaskedShiftRightOr {
@@ -551,7 +548,7 @@ struct TestMaskedShrOr {
 };
 
 HWY_NOINLINE void TestAllMaskedShift() {
-  ForIntegerTypes(ForPartialVectors<TestMaskedShiftOrZero>());
+  ForIntegerTypes(ForPartialVectors<TestMaskedShift>());
   ForIntegerTypes(ForPartialVectors<TestMaskedShiftRightOr>());
   ForSignedTypes(ForPartialVectors<TestMaskedShrOr>());
 }
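Not part of the patch: for review context, a minimal usage sketch of the renamed ops and of MultiShift as documented in the quick_reference hunks above. It assumes static dispatch and a non-scalar static target (the byte Repartition used below does not exist on HWY_SCALAR); the constants and the main() wrapper are illustrative only.

// Sketch only: assumes the statically chosen target is not HWY_SCALAR.
#include <cstdint>
#include <cstdio>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

int main() {
  // MaskedShiftLeft/MaskedShiftRight: lanes where the mask is false are zero.
  const hn::ScalableTag<uint32_t> d32;
  const auto v = hn::Iota(d32, 1);                   // 1, 2, 3, 4, ...
  const auto m = hn::FirstN(d32, 3);                 // first three lanes true
  const auto left = hn::MaskedShiftLeft<1>(m, v);    // 2, 4, 6, 0, 0, ...
  const auto right = hn::MaskedShiftRight<1>(m, v);  // 0, 1, 1, 0, 0, ...
  printf("left[0]=%u right[0]=%u\n", hn::GetLane(left), hn::GetLane(right));

  // MultiShift: byte j of r[i] is (vals[i] >> indices[i*8+j]) & 0xff.
  const hn::ScalableTag<uint64_t> d64;
  const hn::Repartition<uint8_t, decltype(d64)> d8;
  const auto vals = hn::Set(d64, 0x0807060504030201ULL);
  const auto idx = hn::Set(d8, uint8_t{8});  // every byte reads bits [8, 16)
  const auto r = hn::MultiShift(vals, idx);  // each byte of r[i] is 0x02
  printf("r[0]=%016llx\n", static_cast<unsigned long long>(hn::GetLane(r)));
  return 0;
}

The expected values in the comments follow directly from the documented semantics: lanes with a false mask become zero, and each byte of a 64-bit result lane is `(vals[i] >> indices[i*8+j]) & 0xff`.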