// Patch overview (reviewer summary; the patch text below is unchanged).
//
// Files touched: inc/zoo/swar/SWAR.h, inc/zoo/swar/associative_iteration.h,
// test/swar/BasicOperations.cpp.
//
// inc/zoo/swar/SWAR.h
//   * struct SWAR gains two static constexpr members:
//     - evenLaneMask() static_asserts that S::Lanes is even and returns
//       (D::LeastSignificantBit << S::NBits) - D::LeastSignificantBit.
//       The updated test pins it to 0x0F0F0F0F for SWAR<4, uint32_t>.
//     - oddLaneMask() returns the bitwise complement of evenLaneMask().value().
//   * Otherwise only blank lines move (two added after an #endif, two removed
//     before #if ZOO_USE_LEASTNBITSMASK).
//
// inc/zoo/swar/associative_iteration.h
//   * Adds consumeMSB(s): masks out S::MostSignificantBit, then shifts the
//     underlying value left by one. The halver lambda inside
//     multiplication_OverflowUnsafe_SpecificBitCount now calls it instead of
//     spelling the two statements inline.
//   * Removes exponentiation_OverflowUnsafe_SpecificBitCount and doublingMask.
//     doublePrecision and halvePrecision now read their mask from
//     evenLaneMask().
//   * Adds, after halvePrecision:
//     - MultiplicationResult: struct with members lower and upper.
//     - doublePrecisionMultiplication: applies doublePrecision to both
//       operands, multiplies even with even and odd with odd through
//       multiplication_OverflowUnsafe, returns std::make_pair(lower, upper).
//     - deinterleaveLanesOfPair: variable-template lambda. Shifts each input
//       right by H::NBits with shiftIntraLaneRight, then calls halvePrecision
//       on the unshifted pair (lower) and on the shifted pair (upper).
//     - wideningMultiplication: doublePrecisionMultiplication followed by
//       deinterleaveLanesOfPair.
//     - saturatingMultiplication: ORs the lower result with the lane mask
//       obtained from greaterEqual(overflow, One).MSBtoLaneMask().
//     - exponentiation: takes the multiplication as a callable and runs
//       associativeOperatorIterated_regressive with LSB as the neutral value,
//       MSB as the fifth argument, NumBitsPerLane as the count and consumeMSB
//       as the halver. The operation keeps the product where
//       makeLaneMaskFromMSB(counts) is set and keeps left elsewhere.
//     - saturatingExponentation and exponentiation_OverflowUnsafe: forward to
//       exponentiation with saturatingMultiplication and
//       multiplication_OverflowUnsafe respectively.
//
// test/swar/BasicOperations.cpp
//   * The doublingMask static_assert now uses SWAR<4, uint32_t>::evenLaneMask().
//   * Adds five static_asserts through testSaturatingMultiplication and a
//     test_deinterleaving namespace with one static_assert.
//
// NOTE(review): in this copy the original line breaks are collapsed and most
// template parameter lists, template arguments and static_cast target types
// appear to be missing (for example a bare template keyword followed directly
// by constexpr). Confirm against the original patch before applying.
// NOTE(review): MultiplicationResult is declared, but every function visible
// in this patch returns std::make_pair(lower, upper). Confirm whether the
// struct is meant to be the return type.
// NOTE(review): testSaturatingMultiplication calls saturatingExponentation,
// and its expectations match exponentiation (2 and 3 give 0x08, not 0x06).
// The helper name looks like a leftover; confirm and rename in the test.
// NOTE(review): saturatingExponentation is presumably a misspelling of
// saturatingExponentiation. It is a new public name, so fix it before merge
// rather than after.
// NOTE(review): the added test comment reads -- notice the vertical groups
// becomes horizontal pairs -- which should say become.
diff --git a/inc/zoo/swar/SWAR.h b/inc/zoo/swar/SWAR.h index 60ba954..4de6bde 100644 --- a/inc/zoo/swar/SWAR.h +++ b/inc/zoo/swar/SWAR.h @@ -70,6 +70,8 @@ constexpr __uint128_t lsbIndex(__uint128_t v) noexcept { } #endif + + /// Core abstraction around SIMD Within A Register (SWAR). Specifies 'lanes' /// of NBits width against a type T, and provides an abstraction for performing /// SIMD operations against that primitive type T treated as a SIMD register. @@ -108,6 +110,17 @@ struct SWAR { return result; } + constexpr static auto evenLaneMask() { + using S = SWAR; + static_assert(0 == S::Lanes % 2, "Only even number of elements supported"); + using D = SWAR; + return S{(D::LeastSignificantBit << S::NBits) - D::LeastSignificantBit}; + } + + constexpr static auto oddLaneMask() { + return SWAR{static_cast(~evenLaneMask().value())}; + } + template constexpr static auto from(const Range &values) noexcept { using std::begin; using std::end; @@ -245,8 +258,6 @@ constexpr auto horizontalEquality(SWAR left, SWAR right) { return left.m_v == right.m_v; } - - #if ZOO_USE_LEASTNBITSMASK template constexpr auto isolate(T pattern) { diff --git a/inc/zoo/swar/associative_iteration.h b/inc/zoo/swar/associative_iteration.h index 00bc9e7..d68f14e 100644 --- a/inc/zoo/swar/associative_iteration.h +++ b/inc/zoo/swar/associative_iteration.h @@ -41,6 +41,13 @@ std::ostream &operator<<(std::ostream &out, zoo::swar::SWAR s) { namespace zoo::swar { +template +constexpr static auto consumeMSB(SWAR s) noexcept { + using S = SWAR; + auto msbCleared = s & ~S{S::MostSignificantBit}; + return S{msbCleared.value() << 1}; +} + template constexpr auto parallelSuffix(S input) { auto @@ -393,8 +400,7 @@ constexpr auto multiplication_OverflowUnsafe_SpecificBitCount( }; auto halver = [](auto counts) { - auto msbCleared = counts & ~S{S::MostSignificantBit}; - return S{msbCleared.value() << 1}; + return swar::consumeMSB(counts); }; auto shifted = S{multiplier.value() << (NB - ActualBits)}; @@ -426,38 
+432,6 @@ constexpr auto multiplication_OverflowUnsafe_SpecificBitCount_deprecated( return product; } -// TODO(Jamie): Add tests from other PR. -template -constexpr auto exponentiation_OverflowUnsafe_SpecificBitCount( - SWAR x, - SWAR exponent -) { - using S = SWAR; - - auto operation = [](auto left, auto right, auto counts) { - const auto mask = makeLaneMaskFromMSB(counts); - const auto product = - multiplication_OverflowUnsafe_SpecificBitCount(left, right); - return (product & mask) | (left & ~mask); - }; - - // halver should work same as multiplication... i think... - auto halver = [](auto counts) { - auto msbCleared = counts & ~S{S::MostSignificantBit}; - return S{static_cast(msbCleared.value() << 1)}; - }; - - exponent = S{static_cast(exponent.value() << (NB - ActualBits))}; - return associativeOperatorIterated_regressive( - x, - S{meta::BitmaskMaker().value}, // neutral is lane wise.. - exponent, - S{S::MostSignificantBit}, - operation, - ActualBits, - halver - ); -} template constexpr auto multiplication_OverflowUnsafe( @@ -475,14 +449,6 @@ struct SWAR_Pair{ SWAR even, odd; }; -template -constexpr SWAR doublingMask() { - using S = SWAR; - static_assert(0 == S::Lanes % 2, "Only even number of elements supported"); - using D = SWAR; - return S{(D::LeastSignificantBit << NB) - D::LeastSignificantBit}; -} - template constexpr auto doublePrecision(SWAR input) { using S = SWAR; @@ -491,7 +457,7 @@ constexpr auto doublePrecision(SWAR input) { "Precision can only be doubled for SWARs of even element count" ); using RV = SWAR; - constexpr auto DM = doublingMask(); + constexpr auto DM = SWAR::evenLaneMask(); return SWAR_Pair{ RV{(input & DM).value()}, RV{(input.value() >> NB) & DM.value()} @@ -503,13 +469,125 @@ constexpr auto halvePrecision(SWAR even, SWAR odd) { using S = SWAR; static_assert(0 == NB % 2, "Only even lane-bitcounts supported"); using RV = SWAR; - constexpr auto HalvingMask = doublingMask(); + constexpr auto HalvingMask = SWAR::evenLaneMask(); auto 
evenHalf = RV{even.value()} & HalvingMask, oddHalf = RV{(RV{odd.value()} & HalvingMask).value() << NB/2}; + return evenHalf | oddHalf; } + +template struct MultiplicationResult { + SWAR lower; + SWAR upper; +}; + +template +constexpr +auto +doublePrecisionMultiplication(SWAR multiplicand, SWAR multiplier) { + auto + icand = doublePrecision(multiplicand), + plier = doublePrecision(multiplier); + auto + lower = multiplication_OverflowUnsafe(icand.even, plier.even), + upper = multiplication_OverflowUnsafe(icand.odd, plier.odd); + return std::make_pair(lower, upper); +} + +template +constexpr auto deinterleaveLanesOfPair = [](auto even, auto odd) { + using S = SWAR; + using H = SWAR; + constexpr auto + HalfLane = H::NBits, + UpperHalfOfLanes = H::oddLaneMask().value(); + auto + upper_even = even.shiftIntraLaneRight(HalfLane, S{UpperHalfOfLanes}), + upper_odd = odd.shiftIntraLaneRight(HalfLane, S{UpperHalfOfLanes}); + auto + lower = halvePrecision(even, odd), // throws away the upper bits + upper = halvePrecision(upper_even, upper_odd); // preserve the upper bits + return std::make_pair(lower, upper); +}; + +template +constexpr auto +wideningMultiplication(SWAR multiplicand, SWAR multiplier) { + auto [even, odd] = doublePrecisionMultiplication(multiplicand, multiplier); + auto [lower, upper] = deinterleaveLanesOfPair(even, odd); + return std::make_pair(lower, upper); +} + +template +constexpr +auto saturatingMultiplication(SWAR multiplicand, SWAR multiplier) { + using S = SWAR; + constexpr auto One = S{S::LeastSignificantBit}; + auto [result, overflow] = wideningMultiplication(multiplicand, multiplier); + auto did_overflow = zoo::swar::greaterEqual(overflow, One); + auto lane_mask = did_overflow.MSBtoLaneMask(); + return S{result | lane_mask}; +} + +template +constexpr auto exponentiation ( + SWAR x, + SWAR exponent, + MultiplicationFn&& multiplicationFn +) { + using S = SWAR; + constexpr auto NumBitsPerLane = S::NBits; + constexpr auto + MSB = S{S::MostSignificantBit}, 
+ LSB = S{S::LeastSignificantBit}; + + auto operation = [&multiplicationFn](auto left, auto right, auto counts) { + auto mask = makeLaneMaskFromMSB(counts); + auto product = multiplicationFn(left, right); + return (product & mask) | (left & ~mask); + }; + + auto halver = [](auto counts) { + return swar::consumeMSB(counts); + }; + + return associativeOperatorIterated_regressive( + x, + LSB, + exponent, + MSB, + operation, + NumBitsPerLane, + halver + ); +} + +template +constexpr auto saturatingExponentation( + SWAR x, + SWAR exponent +) { + return exponentiation( + x, + exponent, + saturatingMultiplication + ); +} + +template +constexpr auto exponentiation_OverflowUnsafe( + SWAR x, + SWAR exponent +) { + return exponentiation( + x, + exponent, + multiplication_OverflowUnsafe + ); +} + } #endif diff --git a/test/swar/BasicOperations.cpp b/test/swar/BasicOperations.cpp index 602384a..bc29cf1 100644 --- a/test/swar/BasicOperations.cpp +++ b/test/swar/BasicOperations.cpp @@ -211,7 +211,7 @@ static_assert(BooleanSWAR{Literals<4, u16>, namespace Multiplication { static_assert(~int64_t(0) == negate(S4_64{S4_64::LeastSignificantBit}).value()); -static_assert(0x0F0F0F0F == doublingMask<4, uint32_t>().value()); +static_assert(0x0F0F0F0F == SWAR<4, uint32_t>::evenLaneMask().value()); constexpr auto PrecisionFixtureTest = 0x89ABCDEF; constexpr auto Doubled = @@ -255,6 +255,68 @@ HE(3, u8, 0xFF, 0x7); HE(2, u8, 0xAA, 0x2); #undef HE +template +constexpr auto testSaturatingMultiplication(T left, T right, T expected) { + using S = SWAR; + return saturatingExponentation(S{left}, S{right}).value() == expected; +} +static_assert( + testSaturatingMultiplication<8, u32>( + 0x09'40'03'01, + 0x37'03'C0'01, + 0xFF'FF'FF'01 +)); +static_assert( + testSaturatingMultiplication<8, u32>( + 0x02'02'02'02, + 0x02'02'02'02, + 0x04'04'04'04 +)); +static_assert( + testSaturatingMultiplication<8, u32>( + 0xFF'FF'FF'FF, + 0x04'03'02'01, + 0xFF'FF'FF'FF +)); +static_assert( 
testSaturatingMultiplication<8, u32>( + 0x02'FF'FF'FF, + 0x03'03'02'01, + 0x08'FF'FF'FF +)); +static_assert( + testSaturatingMultiplication<4, u32>( + 0x1243'0003, + 0x0002'0002, + 0x1119'1119 +)); + +namespace test_deinterleaving { + +template +auto test = [](auto a, auto b, auto expected_lower, auto expected_upper) { + auto [lower, upper] = deinterleaveLanesOfPair(a, b); + auto lower_ok = lower.value() == expected_lower.value(); + auto upper_ok = upper.value() == expected_upper.value(); + return lower_ok && upper_ok; +}; + + +// notice the vertical groups becomes horizontal pairs +using S = SWAR<8, uint32_t>; +static_assert(test<8, uint32_t>( + S{0xFDFCFBFA}, // input a + S{0xF4F3F2F1}, // input b +/* C A + 3 1 +*/ S{0x4D3C2B1A}, // expected lower +/* 3C 1A */ + S{0xFFFFFFFF} // expected upper +)); + +} // namespace test_deinterleaving + + TEST_CASE("Old multiply version", "[deprecated][swar]") { SWAR<8, u32> Micand{0x5030201}; SWAR<8, u32> Mplier{0xA050301};