Skip to content

Commit

Permalink
formatted
Browse files Browse the repository at this point in the history
  • Loading branch information
nindanaoto committed Jul 18, 2024
1 parent faa0592 commit 71e0a9b
Show file tree
Hide file tree
Showing 23 changed files with 188 additions and 134 deletions.
4 changes: 1 addition & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ if(NOT USE_TERNARY)
endif()
endif()


# if(USE_AVX512)
# string(APPEND CMAKE_CXX_FLAGS " -mprefer-vector-width=512")
# if(USE_AVX512) string(APPEND CMAKE_CXX_FLAGS " -mprefer-vector-width=512")
# endif()

if(USE_FFTW3)
Expand Down
6 changes: 4 additions & 2 deletions include/circuitbootstrapping.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ void CircuitBootstrapping(TRGSW<typename privksP::targetP> &trgsw,
const TLWE<typename bkP::domainP> &tlwe,
const EvalKey &ek)
{
alignas(64) std::array<TLWE<typename bkP::targetP>, privksP::targetP::l> temp;
alignas(64) std::array<TLWE<typename bkP::targetP>, privksP::targetP::l>
temp;
GateBootstrappingManyLUT<bkP, privksP::targetP::l>(
temp, tlwe, ek.getbkfft<bkP>(), CBtestvector<privksP>());
for (int i = 0; i < privksP::targetP::l; i++) {
Expand Down Expand Up @@ -81,7 +82,8 @@ void CircuitBootstrappingSub(TRGSW<typename privksP::targetP> &trgsw,
{
alignas(64) TLWE<typename bkP::domainP> tlwelvl0;
IdentityKeySwitch<iksP>(tlwelvl0, tlwe, ek.getiksk<iksP>());
alignas(64) std::array<TLWE<typename bkP::targetP>, privksP::targetP::l> temp;
alignas(64) std::array<TLWE<typename bkP::targetP>, privksP::targetP::l>
temp;
GateBootstrappingManyLUT<bkP, privksP::targetP::l>(
temp, tlwelvl0, ek.getbkfft<bkP>(), CBtestvector<privksP>());
for (int i = 0; i < privksP::targetP::l; i++) {
Expand Down
42 changes: 29 additions & 13 deletions include/cloudkey.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,19 +416,23 @@ struct EvalKey {
void emplacebkfft(const SecretKey& sk)
{
if constexpr (std::is_same_v<P, lvl01param>) {
bkfftlvl01 = std::unique_ptr<BootstrappingKeyFFT<lvl01param>>(new (std::align_val_t(64)) BootstrappingKeyFFT<lvl01param>());
bkfftlvl01 = std::unique_ptr<BootstrappingKeyFFT<lvl01param>>(
new (std::align_val_t(64)) BootstrappingKeyFFT<lvl01param>());
bkfftgen<lvl01param>(*bkfftlvl01, sk);
}
else if constexpr (std::is_same_v<P, lvlh1param>) {
bkfftlvlh1 = std::unique_ptr<BootstrappingKeyFFT<lvlh1param>>(new (std::align_val_t(64)) BootstrappingKeyFFT<lvlh1param>());
bkfftlvlh1 = std::unique_ptr<BootstrappingKeyFFT<lvlh1param>>(
new (std::align_val_t(64)) BootstrappingKeyFFT<lvlh1param>());
bkfftgen<lvlh1param>(*bkfftlvlh1, sk);
}
else if constexpr (std::is_same_v<P, lvl02param>) {
bkfftlvl02 = std::unique_ptr<BootstrappingKeyFFT<lvl02param>>(new (std::align_val_t(64)) BootstrappingKeyFFT<lvl02param>());
bkfftlvl02 = std::unique_ptr<BootstrappingKeyFFT<lvl02param>>(
new (std::align_val_t(64)) BootstrappingKeyFFT<lvl02param>());
bkfftgen<lvl02param>(*bkfftlvl02, sk);
}
else if constexpr (std::is_same_v<P, lvlh2param>) {
bkfftlvlh2 = std::unique_ptr<BootstrappingKeyFFT<lvlh2param>>(new (std::align_val_t(64)) BootstrappingKeyFFT<lvlh2param>());
bkfftlvlh2 = std::unique_ptr<BootstrappingKeyFFT<lvlh2param>>(
new (std::align_val_t(64)) BootstrappingKeyFFT<lvlh2param>());
bkfftgen<lvlh2param>(*bkfftlvlh2, sk);
}
else
Expand Down Expand Up @@ -528,15 +532,18 @@ struct EvalKey {
void emplaceiksk(const SecretKey& sk)
{
if constexpr (std::is_same_v<P, lvl10param>) {
iksklvl10 = std::unique_ptr<KeySwitchingKey<lvl10param>>(new (std::align_val_t(64)) KeySwitchingKey<lvl10param>());
iksklvl10 = std::unique_ptr<KeySwitchingKey<lvl10param>>(
new (std::align_val_t(64)) KeySwitchingKey<lvl10param>());
ikskgen<lvl10param>(*iksklvl10, sk);
}
else if constexpr (std::is_same_v<P, lvl1hparam>) {
iksklvl1h = std::unique_ptr<KeySwitchingKey<lvl1hparam>>(new (std::align_val_t(64)) KeySwitchingKey<lvl1hparam>());
iksklvl1h = std::unique_ptr<KeySwitchingKey<lvl1hparam>>(
new (std::align_val_t(64)) KeySwitchingKey<lvl1hparam>());
ikskgen<lvl1hparam>(*iksklvl1h, sk);
}
else if constexpr (std::is_same_v<P, lvl20param>) {
iksklvl20 = std::unique_ptr<KeySwitchingKey<lvl20param>>(new (std::align_val_t(64)) KeySwitchingKey<lvl20param>());
iksklvl20 = std::unique_ptr<KeySwitchingKey<lvl20param>>(
new (std::align_val_t(64)) KeySwitchingKey<lvl20param>());
ikskgen<lvl20param>(*iksklvl20, sk);
}
// else if constexpr (std::is_same_v<P, lvl2hparam>) {
Expand All @@ -545,15 +552,18 @@ struct EvalKey {
// ikskgen<lvlh2param>(*iksklvlh2, sk);
// }
else if constexpr (std::is_same_v<P, lvl21param>) {
iksklvl21 = std::unique_ptr<KeySwitchingKey<lvl21param>>(new (std::align_val_t(64)) KeySwitchingKey<lvl21param>());
iksklvl21 = std::unique_ptr<KeySwitchingKey<lvl21param>>(
new (std::align_val_t(64)) KeySwitchingKey<lvl21param>());
ikskgen<lvl21param>(*iksklvl21, sk);
}
else if constexpr (std::is_same_v<P, lvl22param>) {
iksklvl22 = std::unique_ptr<KeySwitchingKey<lvl22param>>(new (std::align_val_t(64)) KeySwitchingKey<lvl22param>());
iksklvl22 = std::unique_ptr<KeySwitchingKey<lvl22param>>(
new (std::align_val_t(64)) KeySwitchingKey<lvl22param>());
ikskgen<lvl22param>(*iksklvl22, sk);
}
else if constexpr (std::is_same_v<P, lvl31param>) {
iksklvl31 = std::unique_ptr<KeySwitchingKey<lvl31param>>(new (std::align_val_t(64)) KeySwitchingKey<lvl31param>());
iksklvl31 = std::unique_ptr<KeySwitchingKey<lvl31param>>(
new (std::align_val_t(64)) KeySwitchingKey<lvl31param>());
ikskgen<lvl31param>(*iksklvl31, sk);
}
else
Expand All @@ -576,15 +586,21 @@ struct EvalKey {
const SecretKey& sk)
{
if constexpr (std::is_same_v<P, lvl11param>) {
privksklvl11[key] = std::unique_ptr<PrivateKeySwitchingKey<lvl11param>>(new (std::align_val_t(64)) PrivateKeySwitchingKey<lvl11param>());
privksklvl11[key] =
std::unique_ptr<PrivateKeySwitchingKey<lvl11param>>(new (
std::align_val_t(64)) PrivateKeySwitchingKey<lvl11param>());
privkskgen<lvl11param>(*privksklvl11[key], func, sk);
}
else if constexpr (std::is_same_v<P, lvl21param>) {
privksklvl21[key] = std::unique_ptr<PrivateKeySwitchingKey<lvl21param>>(new (std::align_val_t(64)) PrivateKeySwitchingKey<lvl21param>());
privksklvl21[key] =
std::unique_ptr<PrivateKeySwitchingKey<lvl21param>>(new (
std::align_val_t(64)) PrivateKeySwitchingKey<lvl21param>());
privkskgen<lvl21param>(*privksklvl21[key], func, sk);
}
else if constexpr (std::is_same_v<P, lvl22param>) {
privksklvl22[key] = std::unique_ptr<PrivateKeySwitchingKey<lvl22param>>(new (std::align_val_t(64)) PrivateKeySwitchingKey<lvl22param>());
privksklvl22[key] =
std::unique_ptr<PrivateKeySwitchingKey<lvl22param>>(new (
std::align_val_t(64)) PrivateKeySwitchingKey<lvl22param>());
privkskgen<lvl22param>(*privksklvl22[key], func, sk);
}
else
Expand Down
81 changes: 42 additions & 39 deletions include/mulfft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ inline void TwistFFT(Polynomial<P> &res, const PolynomialInFD<P> &a)
template <class P>
inline void TwistFFTrescale(Polynomial<P> &res, const PolynomialInFD<P> &a)
{
if constexpr (std::is_same_v<P, lvl1param>){
if constexpr(std::is_same_v<typename P::T, uint32_t>)
if constexpr (std::is_same_v<P, lvl1param>) {
if constexpr (std::is_same_v<typename P::T, uint32_t>)
fftplvl1.execute_direct_torus32_rescale(res.data(), a.data(), P::Δ);
else if constexpr(std::is_same_v<typename P::T, uint64_t>)
else if constexpr (std::is_same_v<typename P::T, uint64_t>)
fftplvl1.execute_direct_torus64_rescale(res.data(), a.data(), P::Δ);
}
else if constexpr (std::is_same_v<P, lvl2param>)
Expand Down Expand Up @@ -147,33 +147,35 @@ inline void TwistIFFT(PolynomialInFD<P> &res, const Polynomial<P> &a)
template <uint32_t N>
inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &b)
{
#ifdef USE_INTERLEAVED_FORMAT
for(int i = 0; i < N / 2; i++){
const std::complex tmp = std::complex(res[2*i], res[2*i+1]) * std::complex(b[2*i], b[2*i+1]);
res[2*i] = tmp.real();
res[2*i+1] = tmp.imag();
#ifdef USE_INTERLEAVED_FORMAT
for (int i = 0; i < N / 2; i++) {
const std::complex tmp = std::complex(res[2 * i], res[2 * i + 1]) *
std::complex(b[2 * i], b[2 * i + 1]);
res[2 * i] = tmp.real();
res[2 * i + 1] = tmp.imag();
}
#else
#else
for (int i = 0; i < N / 2; i++) {
double aimbim = res[i + N / 2] * b[i + N / 2];
double arebim = res[i] * b[i + N / 2];
res[i] = std::fma(res[i], b[i], -aimbim);
res[i + N / 2] = std::fma(res[i + N / 2], b[i], arebim);
}
#endif
#endif
}

template <uint32_t N>
inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &a,
const std::array<double, N> &b)
{
#ifdef USE_INTERLEAVED_FORMAT
for(int i = 0; i < N / 2; i++){
const std::complex tmp = std::complex(a[2*i], a[2*i+1]) * std::complex(b[2*i], b[2*i+1]);
res[2*i] = tmp.real();
res[2*i+1] = tmp.imag();
#ifdef USE_INTERLEAVED_FORMAT
for (int i = 0; i < N / 2; i++) {
const std::complex tmp = std::complex(a[2 * i], a[2 * i + 1]) *
std::complex(b[2 * i], b[2 * i + 1]);
res[2 * i] = tmp.real();
res[2 * i + 1] = tmp.imag();
}
#else
#else
// for (int i = 0; i < N / 2; i++) {
// double aimbim = a[i + N / 2] * b[i + N / 2];
// double arebim = a[i] * b[i + N / 2];
Expand Down Expand Up @@ -205,7 +207,7 @@ inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &a,
res[i + N / 2] += a[i] * b[i + N / 2];
res[i] -= a[i + N / 2] * b[i + N / 2];
}
#endif
#endif
}

// Be careful about memory accesss (We assume b has relatively high memory
Expand All @@ -214,13 +216,14 @@ template <uint32_t N>
inline void FMAInFD(std::array<double, N> &res, const std::array<double, N> &a,
const std::array<double, N> &b)
{
#ifdef USE_INTERLEAVED_FORMAT
for(int i = 0; i < N / 2; i++){
std::complex tmp = std::complex(a[2*i], a[2*i+1]) * std::complex(b[2*i], b[2*i+1]);
res[2*i] += tmp.real();
res[2*i+1] += tmp.imag();
#ifdef USE_INTERLEAVED_FORMAT
for (int i = 0; i < N / 2; i++) {
std::complex tmp = std::complex(a[2 * i], a[2 * i + 1]) *
std::complex(b[2 * i], b[2 * i + 1]);
res[2 * i] += tmp.real();
res[2 * i + 1] += tmp.imag();
}
#else
#else
for (int i = 0; i < N / 2; i++) {
res[i] = std::fma(a[i], b[i], res[i]);
res[i + N / 2] = std::fma(a[i + N / 2], b[i], res[i + N / 2]);
Expand All @@ -229,13 +232,13 @@ inline void FMAInFD(std::array<double, N> &res, const std::array<double, N> &a,
res[i + N / 2] = std::fma(a[i], b[i + N / 2], res[i + N / 2]);
res[i] -= a[i + N / 2] * b[i + N / 2];
}
// for (int i = 0; i < N / 2; i++) {
// res[i] = std::fma(a[i + N / 2], b[i + N / 2], -res[i]);
// res[i] = std::fma(a[i], b[i], -res[i]);
// res[i + N / 2] = std::fma(a[i], b[i + N / 2], res[i + N / 2]);
// res[i + N / 2] = std::fma(a[i + N / 2], b[i], res[i + N / 2]);
// }
#endif
// for (int i = 0; i < N / 2; i++) {
// res[i] = std::fma(a[i + N / 2], b[i + N / 2], -res[i]);
// res[i] = std::fma(a[i], b[i], -res[i]);
// res[i + N / 2] = std::fma(a[i], b[i + N / 2], res[i + N / 2]);
// res[i + N / 2] = std::fma(a[i + N / 2], b[i], res[i + N / 2]);
// }
#endif
}

template <class P>
Expand Down Expand Up @@ -302,11 +305,11 @@ inline void PolyMulRescaleUnsigned(Polynomial<P> &res,
const UnsignedPolynomial<P> &b)
{
// if constexpr (std::is_same_v<typename P::T, uint32_t>) {
PolynomialInFD<P> ffta, fftb;
TwistIFFT<P>(ffta, a);
TwistIFFT<P>(fftb, b);
MulInFD<P::n>(ffta, ffta, fftb);
TwistFFTrescale<P>(res, ffta);
PolynomialInFD<P> ffta, fftb;
TwistIFFT<P>(ffta, a);
TwistIFFT<P>(fftb, b);
MulInFD<P::n>(ffta, ffta, fftb);
TwistFFTrescale<P>(res, ffta);
// }
// else
// static_assert(false_v<typename P::T>, "Undefined PolyMul!");
Expand All @@ -333,7 +336,7 @@ inline void PolyMulNaive(Polynomial<P> &res, const Polynomial<P> &a,
template <class P>
std::unique_ptr<std::array<PolynomialInFD<P>, 2 * P::n>> XaittGen()
{
std::unique_ptr<std::array<PolynomialInFD<P>, 2 *P::n>> xaitt =
std::unique_ptr<std::array<PolynomialInFD<P>, 2 * P::n>> xaitt =
std::make_unique<std::array<PolynomialInFD<P>, 2 * P::n>>();
for (int i = 0; i < 2 * P::n; i++) {
std::array<typename P::T, P::n> xai = {};
Expand All @@ -350,7 +353,7 @@ std::unique_ptr<std::array<PolynomialInFD<P>, 2 * P::n>> XaittGen()
template <class P>
std::unique_ptr<std::array<PolynomialNTT<P>, 2 * P::n>> XaittGenNTT()
{
std::unique_ptr<std::array<PolynomialNTT<P>, 2 *P::n>> xaitt =
std::unique_ptr<std::array<PolynomialNTT<P>, 2 * P::n>> xaitt =
std::make_unique<std::array<PolynomialNTT<P>, 2 * P::n>>();
for (int i = 0; i < 2 * P::n; i++) {
std::array<typename P::T, P::n> xai = {};
Expand All @@ -366,10 +369,10 @@ std::unique_ptr<std::array<PolynomialNTT<P>, 2 * P::n>> XaittGenNTT()

#if defined(USE_TERNARY) || defined(USE_KEY_BUNDLE)
alignas(64) static const std::unique_ptr<
const std::array<PolynomialInFD<lvl1param>, 2 *lvl1param::n>> xaittlvl1 =
const std::array<PolynomialInFD<lvl1param>, 2 * lvl1param::n>> xaittlvl1 =
XaittGen<lvl1param>();
alignas(64) static const std::unique_ptr<
const std::array<PolynomialInFD<lvl2param>, 2 *lvl2param::n>> xaittlvl2 =
const std::array<PolynomialInFD<lvl2param>, 2 * lvl2param::n>> xaittlvl2 =
XaittGen<lvl2param>();
#endif
#ifdef USE_TERNARY
Expand Down
3 changes: 2 additions & 1 deletion include/params.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

namespace TFHEpp {

template<class T, size_t N> struct alignas( 64 ) aligned_array : public std::array<T,N> { };
template <class T, size_t N>
struct alignas(64) aligned_array : public std::array<T, N> {};

enum class ErrorDistribution { ModularGaussian, CenteredBinomial };

Expand Down
10 changes: 6 additions & 4 deletions include/params/tfhe-rs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ struct lvl0param {
ErrorDistribution::ModularGaussian;
static constexpr double α =
3.2192861177056265e-06; // fresh noise, 2^{-17.6}
using T = uint32_t; // Torus representation
static constexpr std::make_signed_t<T> μ = 1U << (std::numeric_limits<T>::digits - 3);
using T = uint32_t; // Torus representation
static constexpr std::make_signed_t<T> μ =
1U << (std::numeric_limits<T>::digits - 3);
static constexpr uint32_t plain_modulus = 2;
static constexpr double Δ =
static_cast<double>(1ULL << std::numeric_limits<T>::digits) /
Expand All @@ -37,7 +38,8 @@ struct lvlhalfparam {
ErrorDistribution::ModularGaussian;
static const inline double α = std::pow(2.0, -17); // fresh noise
using T = uint32_t; // Torus representation
static constexpr std::make_signed_t<T> μ = 1U << (std::numeric_limits<T>::digits - 3);
static constexpr std::make_signed_t<T> μ =
1U << (std::numeric_limits<T>::digits - 3);
static constexpr uint32_t plain_modulus = 8;
static constexpr double Δ =
static_cast<double>(1ULL << std::numeric_limits<T>::digits) /
Expand All @@ -59,7 +61,7 @@ struct lvl1param {
ErrorDistribution::ModularGaussian;
static const inline double α =
3.966608917163306e-12; // fresh noise, 2^{-24.8...}
using T = uint64_t; // Torus representation
using T = uint64_t; // Torus representation
static constexpr std::make_signed_t<T> μ = 1ULL << 61;
static constexpr uint32_t plain_modulus = 2;
static constexpr double Δ =
Expand Down
4 changes: 2 additions & 2 deletions include/trgsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ constexpr typename P::T offsetgen()
}

template <class P>
inline void Decomposition(DecomposedPolynomial<P> &decpoly, const Polynomial<P> &poly,
typename P::T randbits = 0)
inline void Decomposition(DecomposedPolynomial<P> &decpoly,
const Polynomial<P> &poly, typename P::T randbits = 0)
{
#ifdef USE_OPTIMAL_DECOMPOSITION
// https://eprint.iacr.org/2021/1161
Expand Down
Loading

0 comments on commit 71e0a9b

Please sign in to comment.