diff --git a/include/mulfft.hpp b/include/mulfft.hpp index 13986b6..f570e54 100644 --- a/include/mulfft.hpp +++ b/include/mulfft.hpp @@ -144,6 +144,24 @@ inline void TwistIFFT(PolynomialInFD

&res, const Polynomial

&a) else static_assert(false_v, "Undefined TwistIFFT!"); } +template +inline void MulInFD(std::array &res, const std::array &b) +{ + #ifdef USE_INTERLEAVED_FORMAT + for(int i = 0; i < N / 2; i++){ + const std::complex tmp = std::complex(res[2*i], res[2*i+1]) * std::complex(b[2*i], b[2*i+1]); + res[2*i] = tmp.real(); + res[2*i+1] = tmp.imag(); + } + #else + for (int i = 0; i < N / 2; i++) { + double aimbim = res[i + N / 2] * b[i + N / 2]; + double arebim = res[i] * b[i + N / 2]; + res[i] = std::fma(res[i], b[i], -aimbim); + res[i + N / 2] = std::fma(res[i + N / 2], b[i], arebim); + } + #endif +} template inline void MulInFD(std::array &res, const std::array &a, @@ -156,11 +174,36 @@ inline void MulInFD(std::array &res, const std::array &a, res[2*i+1] = tmp.imag(); } #else + // for (int i = 0; i < N / 2; i++) { + // double aimbim = a[i + N / 2] * b[i + N / 2]; + // double arebim = a[i] * b[i + N / 2]; + // res[i] = std::fma(a[i], b[i], -aimbim); + // res[i + N / 2] = std::fma(a[i + N / 2], b[i], arebim); + // } + + // for (int i = 0; i < N / 2; i++) { + // res[i] = a[i + N / 2] * b[i + N / 2]; + // res[i + N / 2] = a[i] * b[i + N / 2]; + // } + // for (int i = 0; i < N / 2; i++) { + // res[i] = std::fma(a[i], b[i], -res[i]); + // res[i + N / 2] = std::fma(a[i + N / 2], b[i], res[i + N / 2]); + // } + + // for (int i = 0; i < N / 2; i++) { + // double arebre = a[i] * b[i]; + // double aimbre = a[i + N/2] * b[i]; + // res[i] = std::fma(- a[i + N / 2] , b[i + N / 2],arebre); + // res[i + N / 2] = std::fma(a[i], b[i + N / 2], aimbre); + // } + for (int i = 0; i < N / 2; i++) { - double aimbim = a[i + N / 2] * b[i + N / 2]; - double arebim = a[i] * b[i + N / 2]; - res[i] = std::fma(a[i], b[i], -aimbim); - res[i + N / 2] = std::fma(a[i + N / 2], b[i], arebim); + res[i] = a[i] * b[i]; + res[i + N / 2] = a[i + N / 2] * b[i]; + } + for (int i = 0; i < N / 2; i++) { + res[i + N / 2] += a[i] * b[i + N / 2]; + res[i] -= a[i + N / 2] * b[i + N / 2]; } #endif } @@ -204,8 +247,11 @@ inline void PolyMul(Polynomial

&res, const Polynomial

&a, TwistIFFT

(ffta, a); alignas(64) PolynomialInFD

fftb; TwistIFFT

(fftb, b); - MulInFD(ffta, ffta, fftb); + MulInFD(ffta, fftb); TwistFFT

(res, ffta); + // alignas(64) PolynomialInFD

fftres; + // MulInFD(fftres, ffta, fftb); + // TwistFFT

(res, fftres); } else if constexpr (std::is_same_v) { // Naieve diff --git a/test/polymul.cpp b/test/polymul.cpp index 8d642c8..ff521d2 100644 --- a/test/polymul.cpp +++ b/test/polymul.cpp @@ -34,15 +34,15 @@ int main() cout << "FFT Passed" << endl; for (int test = 0; test < num_test; test++) { - array a; + alignas(64) array a; for (int i = 0; i < lvl1param::n; i++) a[i] = Bgdist(engine) - lvl1param::Bg / 2; for (typename TFHEpp::lvl1param::T &i : a) i = Bgdist(engine) - lvl1param::Bg / 2; - array b; + alignas(64) array b; for (typename TFHEpp::lvl1param::T &i : b) i = Torus32dist(engine); - Polynomial polymul; + alignas(64) Polynomial polymul; TFHEpp::PolyMul(polymul, a, b); Polynomial naieve = {}; for (int i = 0; i < lvl1param::n; i++) {