Skip to content

Add more multiplication primitives #107

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 25 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions inc/zoo/swar/SWAR.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ constexpr __uint128_t lsbIndex(__uint128_t v) noexcept {
}
#endif



/// Core abstraction around SIMD Within A Register (SWAR). Specifies 'lanes'
/// of NBits width against a type T, and provides an abstraction for performing
/// SIMD operations against that primitive type T treated as a SIMD register.
Expand Down Expand Up @@ -108,6 +110,17 @@ struct SWAR {
return result;
}

constexpr static auto evenLaneMask() {
using S = SWAR<NBits, T>;
static_assert(0 == S::Lanes % 2, "Only even number of elements supported");
using D = SWAR<NBits * 2, T>;
return S{(D::LeastSignificantBit << S::NBits) - D::LeastSignificantBit};
}

/// Builds the mask that keeps the odd-indexed lanes: the bitwise complement
/// of evenLaneMask(), converted back to the base type T.
constexpr static auto oddLaneMask() {
    const auto evenBits = evenLaneMask().value();
    // The cast undoes the integral promotion that ~ applies to narrow types.
    return SWAR<NBits, T>{static_cast<T>(~evenBits)};
}

template <typename Range>
constexpr static auto from(const Range &values) noexcept {
using std::begin; using std::end;
Expand Down Expand Up @@ -245,8 +258,6 @@ constexpr auto horizontalEquality(SWAR<NBits, T> left, SWAR<NBits, T> right) {
return left.m_v == right.m_v;
}



#if ZOO_USE_LEASTNBITSMASK
template<int NBits, typename T = uint64_t>
constexpr auto isolate(T pattern) {
Expand Down
166 changes: 122 additions & 44 deletions inc/zoo/swar/associative_iteration.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ std::ostream &operator<<(std::ostream &out, zoo::swar::SWAR<NB, B> s) {

namespace zoo::swar {

/// Discards the most significant bit of every lane and moves the remaining
/// bits of each lane one position towards the most significant end.
///
/// The lane MSBs are cleared before shifting so that they cannot spill into
/// the neighbouring lane.
///
/// The shifted value is cast back to T explicitly: for base types narrower
/// than int the shift is evaluated in int, and brace-initializing the SWAR
/// from that int would be a narrowing conversion.
///
/// Deliberately not declared `static`: this lives in a header and is called
/// from function templates with external linkage, so it should have external
/// linkage as well.
///
/// \param s lanes whose most significant bits are to be consumed
/// \return a SWAR of the same type holding the shifted lanes
template <int NBits, typename T>
constexpr auto consumeMSB(SWAR<NBits, T> s) noexcept {
    using S = SWAR<NBits, T>;
    const auto msbCleared = s & ~S{S::MostSignificantBit};
    return S{static_cast<T>(msbCleared.value() << 1)};
}

template<typename S>
constexpr auto parallelSuffix(S input) {
auto
Expand Down Expand Up @@ -393,8 +400,7 @@ constexpr auto multiplication_OverflowUnsafe_SpecificBitCount(
};

auto halver = [](auto counts) {
auto msbCleared = counts & ~S{S::MostSignificantBit};
return S{msbCleared.value() << 1};
return swar::consumeMSB(counts);
};

auto shifted = S{multiplier.value() << (NB - ActualBits)};
Expand Down Expand Up @@ -426,38 +432,6 @@ constexpr auto multiplication_OverflowUnsafe_SpecificBitCount_deprecated(
return product;
}

// TODO(Jamie): Add tests from other PR.
/// Lane-wise exponentiation that only considers the ActualBits low bits of
/// every exponent lane and performs no overflow detection.
///
/// The exponent is first shifted left by NB - ActualBits, which places its
/// ActualBits significant bits at the top of each lane. The work is then
/// delegated to associativeOperatorIterated_regressive with ActualBits as the
/// iteration count.
///
/// \param x lanes holding the bases
/// \param exponent lanes holding the exponents
/// \return whatever associativeOperatorIterated_regressive produces for these
///         arguments (presumably the lane-wise powers -- confirm against its
///         definition, which is not visible here)
template<int ActualBits, int NB, typename T>
constexpr auto exponentiation_OverflowUnsafe_SpecificBitCount(
    SWAR<NB, T> x,
    SWAR<NB, T> exponent
) {
    using S = SWAR<NB, T>;

    // Takes the product in the lanes selected by the mask derived from the
    // counts; all other lanes keep the value of the left operand.
    auto multiplyWhereSelected = [](auto accumulated, auto factor, auto remaining) {
        const auto selected = makeLaneMaskFromMSB(remaining);
        const auto multiplied =
            multiplication_OverflowUnsafe_SpecificBitCount<ActualBits>(accumulated, factor);
        return (multiplied & selected) | (accumulated & ~selected);
    };

    // Same count-advancing step the multiplication uses: clear the top bit of
    // every lane, then shift the rest up by one.
    auto dropTopBit = [](auto remaining) {
        const auto withoutTop = remaining & ~S{S::MostSignificantBit};
        return S{static_cast<T>(withoutTop.value() << 1)};
    };

    exponent = S{static_cast<T>(exponent.value() << (NB - ActualBits))};
    return associativeOperatorIterated_regressive(
        x,
        S{meta::BitmaskMaker<T, 1, NB>().value}, // the neutral element, per lane
        exponent,
        S{S::MostSignificantBit},
        multiplyWhereSelected,
        ActualBits,
        dropTopBit
    );
}

template<int NB, typename T>
constexpr auto multiplication_OverflowUnsafe(
Expand All @@ -475,14 +449,6 @@ struct SWAR_Pair{
SWAR<NB, T> even, odd;
};

/// Mask with the NB low bits set in every group of 2 * NB bits, i.e. the
/// even-indexed lanes of SWAR<NB, T> (0x0F0F0F0F for NB = 4 and a 32-bit T).
///
/// Only available when SWAR<NB, T> has an even number of lanes.
template<int NB, typename T>
constexpr SWAR<NB, T> doublingMask() {
    using Narrow = SWAR<NB, T>;
    static_assert(0 == Narrow::Lanes % 2, "Only even number of elements supported");
    using Wide = SWAR<NB * 2, T>;
    // One set bit at the bottom of every double-width lane.
    constexpr auto OnePerWideLane = Wide::LeastSignificantBit;
    return Narrow{(OnePerWideLane << NB) - OnePerWideLane};
}

template<int NB, typename T>
constexpr auto doublePrecision(SWAR<NB, T> input) {
using S = SWAR<NB, T>;
Expand All @@ -491,7 +457,7 @@ constexpr auto doublePrecision(SWAR<NB, T> input) {
"Precision can only be doubled for SWARs of even element count"
);
using RV = SWAR<NB * 2, T>;
constexpr auto DM = doublingMask<NB, T>();
constexpr auto DM = SWAR<NB, T>::evenLaneMask();
return SWAR_Pair<NB * 2, T>{
RV{(input & DM).value()},
RV{(input.value() >> NB) & DM.value()}
Expand All @@ -503,13 +469,125 @@ constexpr auto halvePrecision(SWAR<NB, T> even, SWAR<NB, T> odd) {
using S = SWAR<NB, T>;
static_assert(0 == NB % 2, "Only even lane-bitcounts supported");
using RV = SWAR<NB/2, T>;
constexpr auto HalvingMask = doublingMask<NB/2, T>();
constexpr auto HalvingMask = SWAR<NB/2, T>::evenLaneMask();
auto
evenHalf = RV{even.value()} & HalvingMask,
oddHalf = RV{(RV{odd.value()} & HalvingMask).value() << NB/2};

return evenHalf | oddHalf;
}


/// Aggregate of two SWAR values of identical layout, named `lower` and
/// `upper`.
// NOTE(review): nothing visible in this file uses this type; the
// multiplication helpers below return std::make_pair instead. Presumably it
// is meant to carry the low and high halves of a widening multiplication --
// confirm, or remove.
template <int NB, typename T> struct MultiplicationResult {
SWAR<NB, T> lower;
SWAR<NB, T> upper;
Comment on lines +482 to +483
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

merge

};

/// Multiplies two SWARs lane by lane after widening every lane to 2 * NB bits.
///
/// Both operands are split by doublePrecision into their even-indexed and
/// odd-indexed lanes; the even parts are multiplied together and so are the
/// odd parts, each with multiplication_OverflowUnsafe.
///
/// \param multiplicand left operand
/// \param multiplier right operand
/// \return std::pair whose first element is the product of the even lanes and
///         whose second element is the product of the odd lanes
template <int NB, typename T>
constexpr
auto
doublePrecisionMultiplication(SWAR<NB, T> multiplicand, SWAR<NB, T> multiplier) {
auto
icand = doublePrecision(multiplicand),
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! never thought about omitting the prefix

plier = doublePrecision(multiplier);
// NOTE(review): despite their names, `lower` and `upper` hold the products of
// the even and the odd lanes respectively, not low/high halves.
auto
lower = multiplication_OverflowUnsafe(icand.even, plier.even),
upper = multiplication_OverflowUnsafe(icand.odd, plier.odd);
return std::make_pair(lower, upper);
}

template <int NB, typename T>
constexpr auto deinterleaveLanesOfPair = [](auto even, auto odd) {
using S = SWAR<NB, T>;
using H = SWAR<NB / 2, T>;
constexpr auto
HalfLane = H::NBits,
UpperHalfOfLanes = H::oddLaneMask().value();
auto
upper_even = even.shiftIntraLaneRight(HalfLane, S{UpperHalfOfLanes}),
upper_odd = odd.shiftIntraLaneRight(HalfLane, S{UpperHalfOfLanes});
auto
lower = halvePrecision(even, odd), // throws away the upper bits
upper = halvePrecision(upper_even, upper_odd); // preserve the upper bits
return std::make_pair(lower, upper);
};

/// Lane-wise multiplication that keeps the full 2 * NB bit product of every
/// lane.
///
/// \param multiplicand left operand
/// \param multiplier right operand
/// \return std::pair of SWAR<NB, T>: first holds the low NB bits of every
///         product, second the high NB bits
template <int NB, typename T>
constexpr auto
wideningMultiplication(SWAR<NB, T> multiplicand, SWAR<NB, T> multiplier) {
    // Products are computed in lanes of twice the width, even and odd lanes
    // separately.
    auto [evenProducts, oddProducts] =
        doublePrecisionMultiplication(multiplicand, multiplier);
    // Regroup them into the low and high halves, back at the original width.
    return deinterleaveLanesOfPair<NB * 2, T>(evenProducts, oddProducts);
}

/// Lane-wise multiplication in which every lane whose product does not fit
/// in NB bits is replaced by all ones.
///
/// \param multiplicand left operand
/// \param multiplier right operand
/// \return the low NB bits of every product, OR-ed with a full-lane mask in
///         the lanes whose high product half is at least 1
template <int NB, typename T>
constexpr
auto saturatingMultiplication(SWAR<NB, T> multiplicand, SWAR<NB, T> multiplier) {
    using S = SWAR<NB, T>;
    constexpr S One{S::LeastSignificantBit};
    auto [lowHalf, highHalf] = wideningMultiplication(multiplicand, multiplier);
    // A non-zero high half means the product overflowed the lane.
    auto overflowed = zoo::swar::greaterEqual(highHalf, One);
    auto saturate = overflowed.MSBtoLaneMask();
    return S{lowHalf | saturate};
}

/// Lane-wise exponentiation parameterized on the multiplication to use.
///
/// Delegates to associativeOperatorIterated_regressive, iterating once per
/// bit of a lane. At every step the lanes selected by the mask derived from
/// the remaining exponent bits take the product; the others keep the left
/// operand.
///
/// \param x lanes holding the bases
/// \param exponent lanes holding the exponents
/// \param multiplicationFn callable taking two SWAR<NB, T> and returning
///        their lane-wise product; it decides the overflow behaviour
/// \return whatever associativeOperatorIterated_regressive produces for these
///         arguments (presumably the lane-wise powers -- confirm against its
///         definition, which is not visible here)
template<int NB, typename T, typename MultiplicationFn>
constexpr auto exponentiation (
    SWAR<NB, T> x,
    SWAR<NB, T> exponent,
    MultiplicationFn&& multiplicationFn
) {
    using S = SWAR<NB, T>;
    constexpr auto LaneWidth = S::NBits;
    constexpr auto
        TopBitOfEveryLane = S{S::MostSignificantBit},
        BottomBitOfEveryLane = S{S::LeastSignificantBit};

    auto multiplyWhereSelected =
        [&multiplicationFn](auto accumulated, auto factor, auto remaining) {
            auto selected = makeLaneMaskFromMSB(remaining);
            auto multiplied = multiplicationFn(accumulated, factor);
            return (multiplied & selected) | (accumulated & ~selected);
        };

    // Advances the counts by discarding the bit that was just examined.
    auto dropTopBit = [](auto remaining) { return swar::consumeMSB(remaining); };

    return associativeOperatorIterated_regressive(
        x,
        BottomBitOfEveryLane, // passed in the position of the neutral element
        exponent,
        TopBitOfEveryLane,
        multiplyWhereSelected,
        LaneWidth,
        dropTopBit
    );
}

/// Lane-wise exponentiation built on saturatingMultiplication.
///
/// \param x lanes holding the bases
/// \param exponent lanes holding the exponents
/// \return the result of exponentiation() using saturatingMultiplication as
///         the multiplication
template<int NB, typename T>
constexpr auto saturatingExponentation(
    SWAR<NB, T> x,
    SWAR<NB, T> exponent
) {
    auto &multiply = saturatingMultiplication<NB, T>;
    return exponentiation(x, exponent, multiply);
}

/// Lane-wise exponentiation built on multiplication_OverflowUnsafe; as the
/// name says, no provision is made for products that overflow a lane.
///
/// \param x lanes holding the bases
/// \param exponent lanes holding the exponents
/// \return the result of exponentiation() using multiplication_OverflowUnsafe
///         as the multiplication
template<int NB, typename T>
constexpr auto exponentiation_OverflowUnsafe(
    SWAR<NB, T> x,
    SWAR<NB, T> exponent
) {
    auto &multiply = multiplication_OverflowUnsafe<NB, T>;
    return exponentiation(x, exponent, multiply);
}

}

#endif
64 changes: 63 additions & 1 deletion test/swar/BasicOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ static_assert(BooleanSWAR{Literals<4, u16>,
namespace Multiplication {

static_assert(~int64_t(0) == negate(S4_64{S4_64::LeastSignificantBit}).value());
static_assert(0x0F0F0F0F == doublingMask<4, uint32_t>().value());
static_assert(0x0F0F0F0F == SWAR<4, uint32_t>::evenLaneMask().value());

constexpr auto PrecisionFixtureTest = 0x89ABCDEF;
constexpr auto Doubled =
Expand Down Expand Up @@ -255,6 +255,68 @@ HE(3, u8, 0xFF, 0x7);
HE(2, u8, 0xAA, 0x2);
#undef HE

template<int NB, typename T>
constexpr auto testSaturatingMultiplication(T left, T right, T expected) {
using S = SWAR<NB, T>;
return saturatingExponentation(S{left}, S{right}).value() == expected;
}
// Each case reads lane by lane: base (first literal) raised to the exponent
// (second literal) must equal the third literal, with results that do not fit
// in the lane expected to be all ones.

// 0x09^0x37, 0x40^0x03 and 0x03^0xC0 do not fit in 8 bits; 1^1 == 1.
static_assert(
testSaturatingMultiplication<8, u32>(
0x09'40'03'01,
0x37'03'C0'01,
0xFF'FF'FF'01
));
// 2^2 == 4 in every lane, no saturation.
static_assert(
testSaturatingMultiplication<8, u32>(
0x02'02'02'02,
0x02'02'02'02,
0x04'04'04'04
));
// 0xFF to any power of at least 2 saturates; 0xFF^1 is 0xFF itself.
static_assert(
testSaturatingMultiplication<8, u32>(
0xFF'FF'FF'FF,
0x04'03'02'01,
0xFF'FF'FF'FF
));
// Mixed: 2^3 == 8 fits, the 0xFF lanes stay at 0xFF.
static_assert(
testSaturatingMultiplication<8, u32>(
0x02'FF'FF'FF,
0x03'03'02'01,
0x08'FF'FF'FF
));
// 4-bit lanes: exponent 0 yields 1 (including 0^0), and 3^2 == 9.
static_assert(
testSaturatingMultiplication<4, u32>(
0x1243'0003,
0x0002'0002,
0x1119'1119
));

namespace test_deinterleaving {

/// Runs deinterleaveLanesOfPair<NB, T> on (a, b) and reports whether both
/// elements of the resulting pair match the expected raw values.
template <int NB, typename T>
auto test = [](auto a, auto b, auto expected_lower, auto expected_upper) {
    auto [low, high] = deinterleaveLanesOfPair<NB, T>(a, b);
    const auto low_matches = low.value() == expected_lower.value();
    const auto high_matches = high.value() == expected_upper.value();
    return low_matches && high_matches;
};

// Notice how each vertical group (a nibble of `a` above the nibble of `b` in
// the same position) becomes a horizontal pair in the output.
using S = SWAR<8, uint32_t>;
static_assert(test<8, uint32_t>(
    S{0xFDFCFBFA}, // input a: low nibbles D C B A, high nibbles all F
    S{0xF4F3F2F1}, // input b: low nibbles 4 3 2 1, high nibbles all F
    // Low nibbles, b's placed above a's in every byte: 4D 3C 2B 1A.
    S{0x4D3C2B1A}, // expected lower
    // High nibbles are F in both inputs, so every output nibble is F.
    S{0xFFFFFFFF}  // expected upper
));

} // namespace test_deinterleaving


TEST_CASE("Old multiply version", "[deprecated][swar]") {
SWAR<8, u32> Micand{0x5030201};
SWAR<8, u32> Mplier{0xA050301};
Expand Down
Loading