Skip to content

Commit

Permalink
Fixed PolyMul to avoid implicitly assuming different arrays for res a…
Browse files Browse the repository at this point in the history
…nd a
  • Loading branch information
nindanaoto committed Jul 15, 2024
1 parent 72f8475 commit b432281
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 8 deletions.
56 changes: 51 additions & 5 deletions include/mulfft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,24 @@ inline void TwistIFFT(PolynomialInFD<P> &res, const Polynomial<P> &a)
else
static_assert(false_v<typename P::T>, "Undefined TwistIFFT!");
}
template <uint32_t N>
inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &b)
{
#ifdef USE_INTERLEAVED_FORMAT
for(int i = 0; i < N / 2; i++){
const std::complex tmp = std::complex(res[2*i], res[2*i+1]) * std::complex(b[2*i], b[2*i+1]);
res[2*i] = tmp.real();
res[2*i+1] = tmp.imag();
}
#else
for (int i = 0; i < N / 2; i++) {
double aimbim = res[i + N / 2] * b[i + N / 2];
double arebim = res[i] * b[i + N / 2];
res[i] = std::fma(res[i], b[i], -aimbim);
res[i + N / 2] = std::fma(res[i + N / 2], b[i], arebim);
}
#endif
}

template <uint32_t N>
inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &a,
Expand All @@ -156,11 +174,36 @@ inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &a,
res[2*i+1] = tmp.imag();
}
#else
// for (int i = 0; i < N / 2; i++) {
// double aimbim = a[i + N / 2] * b[i + N / 2];
// double arebim = a[i] * b[i + N / 2];
// res[i] = std::fma(a[i], b[i], -aimbim);
// res[i + N / 2] = std::fma(a[i + N / 2], b[i], arebim);
// }

// for (int i = 0; i < N / 2; i++) {
// res[i] = a[i + N / 2] * b[i + N / 2];
// res[i + N / 2] = a[i] * b[i + N / 2];
// }
// for (int i = 0; i < N / 2; i++) {
// res[i] = std::fma(a[i], b[i], -res[i]);
// res[i + N / 2] = std::fma(a[i + N / 2], b[i], res[i + N / 2]);
// }

// for (int i = 0; i < N / 2; i++) {
// double arebre = a[i] * b[i];
// double aimbre = a[i + N/2] * b[i];
// res[i] = std::fma(- a[i + N / 2] , b[i + N / 2],arebre);
// res[i + N / 2] = std::fma(a[i], b[i + N / 2], aimbre);
// }

for (int i = 0; i < N / 2; i++) {
double aimbim = a[i + N / 2] * b[i + N / 2];
double arebim = a[i] * b[i + N / 2];
res[i] = std::fma(a[i], b[i], -aimbim);
res[i + N / 2] = std::fma(a[i + N / 2], b[i], arebim);
res[i] = a[i] * b[i];
res[i + N / 2] = a[i + N / 2] * b[i];
}
for (int i = 0; i < N / 2; i++) {
res[i + N / 2] += a[i] * b[i + N / 2];
res[i] -= a[i + N / 2] * b[i + N / 2];
}
#endif
}
Expand Down Expand Up @@ -204,8 +247,11 @@ inline void PolyMul(Polynomial<P> &res, const Polynomial<P> &a,
TwistIFFT<P>(ffta, a);
alignas(64) PolynomialInFD<P> fftb;
TwistIFFT<P>(fftb, b);
MulInFD<P::n>(ffta, ffta, fftb);
MulInFD<P::n>(ffta, fftb);
TwistFFT<P>(res, ffta);
// alignas(64) PolynomialInFD<P> fftres;
// MulInFD<P::n>(fftres, ffta, fftb);
// TwistFFT<P>(res, fftres);
}
else if constexpr (std::is_same_v<typename P::T, uint64_t>) {
// Naieve
Expand Down
6 changes: 3 additions & 3 deletions test/polymul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ int main()
cout << "FFT Passed" << endl;

for (int test = 0; test < num_test; test++) {
array<typename TFHEpp::lvl1param::T, lvl1param::n> a;
alignas(64) array<typename TFHEpp::lvl1param::T, lvl1param::n> a;
for (int i = 0; i < lvl1param::n; i++)
a[i] = Bgdist(engine) - lvl1param::Bg / 2;
for (typename TFHEpp::lvl1param::T &i : a)
i = Bgdist(engine) - lvl1param::Bg / 2;
array<typename TFHEpp::lvl1param::T, lvl1param::n> b;
alignas(64) array<typename TFHEpp::lvl1param::T, lvl1param::n> b;
for (typename TFHEpp::lvl1param::T &i : b) i = Torus32dist(engine);

Polynomial<lvl1param> polymul;
alignas(64) Polynomial<lvl1param> polymul;
TFHEpp::PolyMul<lvl1param>(polymul, a, b);
Polynomial<lvl1param> naieve = {};
for (int i = 0; i < lvl1param::n; i++) {
Expand Down

0 comments on commit b432281

Please sign in to comment.