Skip to content

Commit

Permalink
May be fix TFHE-RS param?
Browse files Browse the repository at this point in the history
  • Loading branch information
nindanaoto committed May 23, 2024
1 parent 905b692 commit 0f4ea6b
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 19 deletions.
22 changes: 13 additions & 9 deletions include/mulfft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,14 @@ inline void TwistFFT(Polynomial<P> &res, const PolynomialInFD<P> &a)
template <class P>
inline void TwistFFTrescale(Polynomial<P> &res, const PolynomialInFD<P> &a)
{
if constexpr (std::is_same_v<typename P::T, uint32_t>)
fftplvl1.execute_direct_torus32_rescale(res.data(), a.data(), P::Δ);
// else if constexpr (std::is_same_v<typename P::T, uint64_t>)
// fftplvl2.execute_direct_torus64_rescale(res.data(), a.data());
if constexpr (std::is_same_v<P, lvl1param>){
if constexpr(std::is_same_v<typename P::T, uint32_t>)
fftplvl1.execute_direct_torus32_rescale(res.data(), a.data(), P::Δ);
else if constexpr(std::is_same_v<typename P::T, uint64_t>)
fftplvl1.execute_direct_torus64_rescale(res.data(), a.data(), P::Δ);
}
else if constexpr (std::is_same_v<P, lvl2param>)
fftplvl2.execute_direct_torus64_rescale(res.data(), a.data(), P::Δ);
else
static_assert(false_v<typename P::T>, "Undefined TwistFFT!");
}
Expand Down Expand Up @@ -179,7 +183,7 @@ inline void PolyMul(Polynomial<P> &res, const Polynomial<P> &a,
MulInFD<P::n>(ffta, ffta, fftb);
TwistFFT<P>(res, ffta);
}
else if constexpr (std::is_same_v<P, lvl2param>) {
else if constexpr (std::is_same_v<typename P::T, uint64_t>) {
// Naieve
// for (int i = 0; i < P::n; i++) {
// typename P::T ri = 0;
Expand Down Expand Up @@ -227,15 +231,15 @@ inline void PolyMulRescaleUnsigned(Polynomial<P> &res,
const UnsignedPolynomial<P> &a,
const UnsignedPolynomial<P> &b)
{
if constexpr (std::is_same_v<typename P::T, uint32_t>) {
// if constexpr (std::is_same_v<typename P::T, uint32_t>) {
PolynomialInFD<P> ffta, fftb;
TwistIFFT<P>(ffta, a);
TwistIFFT<P>(fftb, b);
MulInFD<P::n>(ffta, ffta, fftb);
TwistFFTrescale<P>(res, ffta);
}
else
static_assert(false_v<typename P::T>, "Undefined PolyMul!");
// }
// else
// static_assert(false_v<typename P::T>, "Undefined PolyMul!");
}

template <class P>
Expand Down
14 changes: 7 additions & 7 deletions include/params/tfhe-rs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ struct lvl0param {
static constexpr int32_t key_value_max = 1;
static constexpr int32_t key_value_min = 0;
static constexpr int32_t key_value_diff = key_value_max - key_value_min;
static constexpr std::uint32_t n = 776; // dimension
static constexpr std::uint32_t n = 636; // dimension
static constexpr std::uint32_t k = 1;
static constexpr ErrorDistribution errordist =
ErrorDistribution::ModularGaussian;
static constexpr double α =
5.033523219195911e-06; // fresh noise, 2^{-17.6}
3.2192861177056265e-06; // fresh noise, 2^{-17.6}
using T = uint32_t; // Torus representation
static constexpr T μ = 1U << (std::numeric_limits<T>::digits - 3);
static constexpr std::make_signed_t<T> μ = 1U << (std::numeric_limits<T>::digits - 3);
static constexpr uint32_t plain_modulus = 2;
static constexpr double Δ =
static_cast<double>(1ULL << std::numeric_limits<T>::digits) /
Expand All @@ -37,7 +37,7 @@ struct lvlhalfparam {
ErrorDistribution::ModularGaussian;
static const inline double α = std::pow(2.0, -17); // fresh noise
using T = uint32_t; // Torus representation
static constexpr T μ = 1U << (std::numeric_limits<T>::digits - 3);
static constexpr std::make_signed_t<T> μ = 1U << (std::numeric_limits<T>::digits - 3);
static constexpr uint32_t plain_modulus = 8;
static constexpr double Δ =
static_cast<double>(1ULL << std::numeric_limits<T>::digits) /
Expand All @@ -58,9 +58,9 @@ struct lvl1param {
static constexpr ErrorDistribution errordist =
ErrorDistribution::ModularGaussian;
static const inline double α =
0.0000000000034525330484572114; // fresh noise, 2^{-24.8...}
3.966608917163306e-12; // fresh noise, 2^{-24.8...}
using T = uint64_t; // Torus representation
static constexpr T μ = 1ULL << 61;
static constexpr std::make_signed_t<T> μ = 1ULL << 61;
static constexpr uint32_t plain_modulus = 2;
static constexpr double Δ =
2 * static_cast<double>(1ULL << (std::numeric_limits<T>::digits - 1)) /
Expand All @@ -81,7 +81,7 @@ struct lvl2param {
ErrorDistribution::ModularGaussian;
static const inline double α = std::pow(2.0, -44); // fresh noise
using T = uint64_t; // Torus representation
static constexpr T μ = 1ULL << 61;
static constexpr std::make_signed_t<T> μ = 1ULL << 61;
static constexpr uint32_t plain_modulus = 8;
static constexpr double Δ = μ;
};
Expand Down
6 changes: 3 additions & 3 deletions include/trgsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,21 +106,21 @@ void trgswfftExternalProduct(TRLWE<P> &res, const TRLWE<P> &trlwe,
DecomposedPolynomial<P> decpoly;
Decomposition<P>(decpoly, trlwe[0]);
PolynomialInFD<P> decpolyfft;
__builtin_prefetch(trgswfft[0].data());
// __builtin_prefetch(trgswfft[0].data());
TwistIFFT<P>(decpolyfft, decpoly[0]);
TRLWEInFD<P> restrlwefft;
for (int m = 0; m < P::k + 1; m++)
MulInFD<P::n>(restrlwefft[m], decpolyfft, trgswfft[0][m]);
for (int i = 1; i < P::l; i++) {
__builtin_prefetch(trgswfft[i].data());
// __builtin_prefetch(trgswfft[i].data());
TwistIFFT<P>(decpolyfft, decpoly[i]);
for (int m = 0; m < P::k + 1; m++)
FMAInFD<P::n>(restrlwefft[m], decpolyfft, trgswfft[i][m]);
}
for (int k = 1; k < P::k + 1; k++) {
Decomposition<P>(decpoly, trlwe[k]);
for (int i = 0; i < P::l; i++) {
__builtin_prefetch(trgswfft[i + k * P::l].data());
// __builtin_prefetch(trgswfft[i + k * P::l].data());
TwistIFFT<P>(decpolyfft, decpoly[i]);
for (int m = 0; m < P::k + 1; m++)
FMAInFD<P::n>(restrlwefft[m], decpolyfft,
Expand Down
29 changes: 29 additions & 0 deletions thirdparties/spqlios/fft_processor_spqlios.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,35 @@ void FFT_Processor_Spqlios::execute_direct_torus64(uint64_t* res, const double*
#endif
}

void FFT_Processor_Spqlios::execute_direct_torus64_rescale(uint64_t* res, const double* a, const double Δ) {
static const double _2sN = double(2)/double(N);
//static const double _2p64 = pow(2.,64);
//for (int i=0; i<N; i++) real_inout_direct[i]=a[i]*_2sn;
{
double* dst = real_inout_direct;
const double* sit = a;
const double* send = a+N;
//double __2sN = 2./N;
const double* bla = &_2sN;
__asm__ __volatile__ (
"vbroadcastsd (%3),%%ymm2\n"
"1:\n"
"vmovupd (%1),%%ymm0\n"
"vmulpd %%ymm2,%%ymm0,%%ymm0\n"
"vmovapd %%ymm0,(%0)\n"
"addq $32,%1\n"
"addq $32,%0\n"
"cmpq %2,%1\n"
"jb 1b\n"
: "=r"(dst),"=r"(sit),"=r"(send),"=r"(bla)
: "0"(dst),"1"(sit),"2"(send),"3"(bla)
: "%ymm0","%ymm2","memory"
);
}
fft(tables_direct,real_inout_direct);
for (int i=0; i<N; i++) res[i] = uint64_t(std::round(real_inout_direct[i]/(Δ/4)));
}

FFT_Processor_Spqlios::~FFT_Processor_Spqlios() {
//delete (tables_direct);
//delete (tables_reverse);
Expand Down
2 changes: 2 additions & 0 deletions thirdparties/spqlios/fft_processor_spqlios.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class FFT_Processor_Spqlios {

void execute_direct_torus64(uint64_t* res, const double* a);

void execute_direct_torus64_rescale(uint64_t* res, const double* a, const double Δ);

~FFT_Processor_Spqlios();
};

Expand Down

0 comments on commit 0f4ea6b

Please sign in to comment.