diff --git a/include/mulfft.hpp b/include/mulfft.hpp
index 13986b6..f570e54 100644
--- a/include/mulfft.hpp
+++ b/include/mulfft.hpp
@@ -144,6 +144,24 @@ inline void TwistIFFT(PolynomialInFD
&res, const Polynomial
&a)
else
static_assert(false_v, "Undefined TwistIFFT!");
}
+template
+inline void MulInFD(std::array &res, const std::array &b)
+{
+ #ifdef USE_INTERLEAVED_FORMAT
+ for(int i = 0; i < N / 2; i++){
+ const std::complex tmp = std::complex(res[2*i], res[2*i+1]) * std::complex(b[2*i], b[2*i+1]);
+ res[2*i] = tmp.real();
+ res[2*i+1] = tmp.imag();
+ }
+ #else
+ for (int i = 0; i < N / 2; i++) {
+ double aimbim = res[i + N / 2] * b[i + N / 2];
+ double arebim = res[i] * b[i + N / 2];
+ res[i] = std::fma(res[i], b[i], -aimbim);
+ res[i + N / 2] = std::fma(res[i + N / 2], b[i], arebim);
+ }
+ #endif
+}
template
inline void MulInFD(std::array &res, const std::array &a,
@@ -156,11 +174,36 @@ inline void MulInFD(std::array &res, const std::array &a,
res[2*i+1] = tmp.imag();
}
#else
+ // for (int i = 0; i < N / 2; i++) {
+ // double aimbim = a[i + N / 2] * b[i + N / 2];
+ // double arebim = a[i] * b[i + N / 2];
+ // res[i] = std::fma(a[i], b[i], -aimbim);
+ // res[i + N / 2] = std::fma(a[i + N / 2], b[i], arebim);
+ // }
+
+ // for (int i = 0; i < N / 2; i++) {
+ // res[i] = a[i + N / 2] * b[i + N / 2];
+ // res[i + N / 2] = a[i] * b[i + N / 2];
+ // }
+ // for (int i = 0; i < N / 2; i++) {
+ // res[i] = std::fma(a[i], b[i], -res[i]);
+ // res[i + N / 2] = std::fma(a[i + N / 2], b[i], res[i + N / 2]);
+ // }
+
+ // for (int i = 0; i < N / 2; i++) {
+ // double arebre = a[i] * b[i];
+ // double aimbre = a[i + N/2] * b[i];
+ // res[i] = std::fma(- a[i + N / 2] , b[i + N / 2],arebre);
+ // res[i + N / 2] = std::fma(a[i], b[i + N / 2], aimbre);
+ // }
+
for (int i = 0; i < N / 2; i++) {
- double aimbim = a[i + N / 2] * b[i + N / 2];
- double arebim = a[i] * b[i + N / 2];
- res[i] = std::fma(a[i], b[i], -aimbim);
- res[i + N / 2] = std::fma(a[i + N / 2], b[i], arebim);
+ res[i] = a[i] * b[i];
+ res[i + N / 2] = a[i + N / 2] * b[i];
+ }
+ for (int i = 0; i < N / 2; i++) {
+ res[i + N / 2] += a[i] * b[i + N / 2];
+ res[i] -= a[i + N / 2] * b[i + N / 2];
}
#endif
}
@@ -204,8 +247,11 @@ inline void PolyMul(Polynomial &res, const Polynomial
&a,
TwistIFFT
(ffta, a);
alignas(64) PolynomialInFD
fftb;
TwistIFFT
(fftb, b);
- MulInFD(ffta, ffta, fftb);
+ MulInFD(ffta, fftb);
TwistFFT(res, ffta);
+ // alignas(64) PolynomialInFD
fftres;
+ // MulInFD(fftres, ffta, fftb);
+ // TwistFFT(res, fftres);
}
else if constexpr (std::is_same_v) {
// Naieve
diff --git a/test/polymul.cpp b/test/polymul.cpp
index 8d642c8..ff521d2 100644
--- a/test/polymul.cpp
+++ b/test/polymul.cpp
@@ -34,15 +34,15 @@ int main()
cout << "FFT Passed" << endl;
for (int test = 0; test < num_test; test++) {
- array a;
+ alignas(64) array a;
for (int i = 0; i < lvl1param::n; i++)
a[i] = Bgdist(engine) - lvl1param::Bg / 2;
for (typename TFHEpp::lvl1param::T &i : a)
i = Bgdist(engine) - lvl1param::Bg / 2;
- array b;
+ alignas(64) array b;
for (typename TFHEpp::lvl1param::T &i : b) i = Torus32dist(engine);
- Polynomial polymul;
+ alignas(64) Polynomial polymul;
TFHEpp::PolyMul(polymul, a, b);
Polynomial naieve = {};
for (int i = 0; i < lvl1param::n; i++) {