Skip to content

Commit

Permalink
Fix BFVpp
Browse files Browse the repository at this point in the history
  • Loading branch information
nindanaoto committed Aug 15, 2024
1 parent 6cb21e2 commit 8ba24e2
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 19 deletions.
27 changes: 11 additions & 16 deletions include/bfv++.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@

namespace TFHEpp {
template <class P>
inline void RemoveSign(UnsignedPolynomial<P> &res, const Polynomial<P> &a)
{
for (int i = 0; i < P::n; i++) res[i] = (a[i] + 1) / 2;
}
template <class P>
void TRLWEMultWithoutRelinerization(TRLWE3<P> &res, const TRLWE<P> &a,
const TRLWE<P> &b)
{
Expand All @@ -34,18 +29,18 @@ void TRLWEMultWithoutRelinerization(TRLWE3<P> &res, const TRLWE<P> &a,
TwistFFTrescale<P>(res[0], fftc);

PolyMulRescaleUnsigned<P>(res[1], aa[1], bb[1]);
// PolyMulRescaleUnsigned<P>(res[2], aa[0], bb[0]);
PolyMulRescaleUnsigned<P>(res[2], aa[0], bb[0]);

for (int i = 0; i < P::n; i++) {
uint64_t ri = 0;
for (int j = 0; j <= i; j++)
ri += static_cast<uint64_t>(P::plain_modulus) *
static_cast<uint64_t>(a[0][j]) * b[0][i - j];
for (int j = i + 1; j < P::n; j++)
ri -= P::plain_modulus * static_cast<uint64_t>(a[0][j]) *
b[0][P::n + i - j];
res[2][i] = (ri + (1ULL << 31)) >> 32;
}
// for (int i = 0; i < P::n; i++) {
// uint64_t ri = 0;
// for (int j = 0; j <= i; j++)
// ri += static_cast<uint64_t>(P::plain_modulus) *
// static_cast<uint64_t>(a[0][j]) * b[0][i - j];
// for (int j = i + 1; j < P::n; j++)
// ri -= P::plain_modulus * static_cast<uint64_t>(a[0][j]) *
// b[0][P::n + i - j];
// res[2][i] = (ri + (1ULL << 31)) >> 32;
// }
}

template <class P>
Expand Down
35 changes: 34 additions & 1 deletion include/mulfft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ inline void TwistIFFT(PolynomialInFD<P> &res, const Polynomial<P> &a)
else
static_assert(false_v<typename P::T>, "Undefined TwistIFFT!");
}

template <uint32_t N>
inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &b)
{
Expand Down Expand Up @@ -308,7 +309,7 @@ inline void PolyMulRescaleUnsigned(Polynomial<P> &res,
PolynomialInFD<P> ffta, fftb;
TwistIFFT<P>(ffta, a);
TwistIFFT<P>(fftb, b);
MulInFD<P::n>(ffta, ffta, fftb);
MulInFD<P::n>(ffta, fftb);
TwistFFTrescale<P>(res, ffta);
// }
// else
Expand All @@ -333,6 +334,38 @@ inline void PolyMulNaive(Polynomial<P> &res, const Polynomial<P> &a,
}
}

template <class P>
inline void PolyMulRescale(Polynomial<P> &res, const Polynomial<P> &a,
const Polynomial<P> &b)
{
if constexpr (std::is_same_v<typename P::T, uint32_t>) {
UnsignedPolynomial<P> aa, bb;
RemoveSign<P>(aa, a);
RemoveSign<P>(bb, b);
PolyMulRescaleUnsigned<P>(res, aa, bb);
}
else
static_assert(false_v<typename P::T>, "Undefined PolyMul!");
}

template <class P>
inline void PolyMulNaieveRescale(Polynomial<P> &res, const Polynomial<P> &a,
const Polynomial<P> &b)
{
Polynomial<P> aa, bb;
for (int i = 0; i < P::n; i++) aa[i] = (a[i] + 1) / 2;
for (int i = 0; i < P::n; i++) bb[i] = (b[i] + 1) / 2;
for (int i = 0; i < P::n; i++) {
__int128_t ri = 0;
for (int j = 0; j <= i; j++)
ri += static_cast<__int128_t>(aa[j]) * bb[i - j];
for (int j = i + 1; j < P::n; j++)
ri -= static_cast<__int128_t>(aa[j]) * bb[P::n + i - j];
// res[i] = static_cast<typename P::T>((ri) >> (std::numeric_limits<typename P::T>::digits - 3));
res[i] = static_cast<typename P::T>((ri) >> 29);
}
}

template <class P>
std::unique_ptr<std::array<PolynomialInFD<P>, 2 * P::n>> XaittGen()
{
Expand Down
6 changes: 6 additions & 0 deletions include/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,4 +221,10 @@ inline void Automorphism(Polynomial<P> &res, const Polynomial<P> &poly,
}
}

template <class P>
inline void RemoveSign(UnsignedPolynomial<P> &res, const Polynomial<P> &a)
{
for (int i = 0; i < P::n; i++) res[i] = (a[i] + 1) / 2;
}

} // namespace TFHEpp
4 changes: 4 additions & 0 deletions src/bfv.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// #include<bfv++.hpp>

// thread_local intel::hexl::NTT nttlvl1(TFHEpp::lvl1param::n, TFHEpp::lvl1param::q);
// // thread_local intel::hexl::NTT nttlvl2(TFHEpp::lvl2param::n, TFHEpp::lvl2param::q);
20 changes: 20 additions & 0 deletions test/polymul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,25 @@ int main()
}
cout << "FFT Passed" << endl;

std::cout << "PolyMulRescale Test" << std::endl;
for (int test = 0; test < num_test; test++) {
std::random_device seed_gen;
std::default_random_engine engine(seed_gen());
std::uniform_int_distribution<typename TFHEpp::lvl1param::T> message(
0, (1ULL << 32) - 1);

TFHEpp::Polynomial<TFHEpp::lvl1param> p0, p1, pres, ptrue;
for (typename TFHEpp::lvl1param::T &i : p0) i = message(engine);
for (typename TFHEpp::lvl1param::T &i : p1) i = message(engine);

TFHEpp::PolyMulRescale<TFHEpp::lvl1param>(pres, p0, p1);
TFHEpp::PolyMulNaieveRescale<TFHEpp::lvl1param>(ptrue, p0, p1);

for (int i = 0; i < TFHEpp::lvl1param::n; i++) {
// std::cout<<pres[i]<<":"<<ptrue[i]<<std::endl;
assert(abs(static_cast<int>(pres[i] - ptrue[i])) <= 2);
}
}
std::cout << "PolyMulRescale Passed" << std::endl;
return 0;
}
2 changes: 1 addition & 1 deletion test/trlwemult.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ int main()
for (int i = 0; i < P::n; i++) ptrue[i] %= P::plain_modulus;

// for (int i = 0; i < P::n; i++)
// std::cout<<pres[i]<<":"<<ptrue[i]<<std::endl;
// std::cout<<pres[i]<<":"<<ptrue[i]<<std::endl;
for (int i = 0; i < P::n; i++) assert(pres[i] == ptrue[i]);
}
std::cout << "Passed" << std::endl;
Expand Down
2 changes: 1 addition & 1 deletion thirdparties/spqlios/fft_processor_spqlios.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ void FFT_Processor_Spqlios::execute_direct_torus32_rescale(uint32_t *res, const
);
}
fft(tables_direct, real_inout_direct);
for (int32_t i = 0; i < N; i++) res[i] = uint32_t(std::round(real_inout_direct[i]/(Δ/4)));
for (int32_t i = 0; i < N; i++) res[i] = static_cast<uint32_t>(int64_t(real_inout_direct[i]/(Δ/4)));
}

void FFT_Processor_Spqlios::execute_direct_torus64(uint64_t* res, const double* a) {
Expand Down

0 comments on commit 8ba24e2

Please sign in to comment.