Skip to content

Commit

Permalink
Added HomDecomp
Browse files Browse the repository at this point in the history
  • Loading branch information
nindanaoto committed May 3, 2024
1 parent c19189a commit 8ec8d3f
Show file tree
Hide file tree
Showing 14 changed files with 394 additions and 6 deletions.
25 changes: 22 additions & 3 deletions include/cloudkey.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,9 @@ struct EvalKey {
lweParams params;
// BootstrapingKey
std::shared_ptr<BootstrappingKey<lvl01param>> bklvl01;
std::shared_ptr<BootstrappingKey<lvl01param>> bklvlh1;
std::shared_ptr<BootstrappingKey<lvlh1param>> bklvlh1;
std::shared_ptr<BootstrappingKey<lvl02param>> bklvl02;
std::shared_ptr<BootstrappingKey<lvl02param>> bklvlh2;
std::shared_ptr<BootstrappingKey<lvlh2param>> bklvlh2;
// BoostrappingKeyFFT
std::shared_ptr<BootstrappingKeyFFT<lvl01param>> bkfftlvl01;
std::shared_ptr<BootstrappingKeyFFT<lvlh1param>> bkfftlvlh1;
Expand All @@ -354,6 +354,7 @@ struct EvalKey {
std::shared_ptr<KeySwitchingKey<lvl20param>> iksklvl20;
std::shared_ptr<KeySwitchingKey<lvl21param>> iksklvl21;
std::shared_ptr<KeySwitchingKey<lvl22param>> iksklvl22;
std::shared_ptr<KeySwitchingKey<lvl31param>> iksklvl31;
// SubsetKeySwitchingKey
std::shared_ptr<SubsetKeySwitchingKey<lvl21param>> subiksklvl21;
// PrivateKeySwitchingKey
Expand All @@ -378,7 +379,7 @@ struct EvalKey {
void serialize(Archive& archive)
{
archive(params, bklvl01, bklvlh1, bklvl02, bklvlh2, bkfftlvl01, bkfftlvlh1, bkfftlvl02, bkfftlvlh2, bknttlvl01,
bknttlvlh1, bknttlvl02, bknttlvlh2, iksklvl10, iksklvl1h, iksklvl20, iksklvl21, iksklvl22,
bknttlvlh1, bknttlvl02, bknttlvlh2, iksklvl10, iksklvl1h, iksklvl20, iksklvl21, iksklvl22, iksklvl31,
privksklvl11, privksklvl21, privksklvl22);
}

Expand All @@ -391,11 +392,21 @@ struct EvalKey {
std::make_unique_for_overwrite<BootstrappingKey<lvl01param>>();
bkgen<lvl01param>(*bklvl01, sk);
}
else if constexpr (std::is_same_v<P, lvlh1param>) {
bklvlh1 =
std::make_unique_for_overwrite<BootstrappingKey<lvlh1param>>();
bkgen<lvlh1param>(*bklvlh1, sk);
}
else if constexpr (std::is_same_v<P, lvl02param>) {
bklvl02 =
std::make_unique_for_overwrite<BootstrappingKey<lvl02param>>();
bkgen<lvl02param>(*bklvl02, sk);
}
else if constexpr (std::is_same_v<P, lvlh2param>) {
bklvlh2 =
std::make_unique_for_overwrite<BootstrappingKey<lvlh2param>>();
bkgen<lvlh2param>(*bklvlh2, sk);
}
else
static_assert(false_v<typename P::T>, "Not predefined parameter!");
}
Expand Down Expand Up @@ -548,6 +559,11 @@ struct EvalKey {
std::make_unique_for_overwrite<KeySwitchingKey<lvl22param>>();
ikskgen<lvl22param>(*iksklvl22, sk);
}
else if constexpr (std::is_same_v<P, lvl31param>) {
iksklvl31 =
std::make_unique_for_overwrite<KeySwitchingKey<lvl31param>>();
ikskgen<lvl31param>(*iksklvl31, sk);
}
else
static_assert(false_v<typename P::T>, "Not predefined parameter!");
}
Expand Down Expand Up @@ -703,6 +719,9 @@ struct EvalKey {
else if constexpr (std::is_same_v<P, lvl22param>) {
return *iksklvl22;
}
else if constexpr (std::is_same_v<P, lvl31param>) {
return *iksklvl31;
}
else
static_assert(false_v<typename P::T>, "Not predefined parameter!");
}
Expand Down
69 changes: 69 additions & 0 deletions include/homdecomp.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/**
* @author Kotaro Matsuoka
*/

#pragma once

#include "gatebootstrapping.hpp"

namespace TFHEpp {
/*!
* @brief Generates a Polynomial with each coefficient subtracted by a base value
* @tparam P The parameter set for the Polynomial
* @tparam basebit The base to be subtracted
* @return A Polynomial of the parameter type with coefficients subtracted by the base value
*/
template <class P, uint basebit> constexpr Polynomial<P> subtractpolygen() {
Polynomial<P> poly;
for (int i = 0; i < P::n; i++)
poly[i] = 1ULL << (std::numeric_limits<typename P::T>::digits - basebit - 2);
return poly;
}

// https://eprint.iacr.org/2023/645
/*!
* @brief Homomorphically decomposes an input ciphertext into an array of level 1 ciphertexts
* @tparam high2midP The parameter set for the transition from high to mid level
* @tparam mid2lowP The parameter set for the transition from mid to low level
* @tparam brP The bootstrapping parameter set
* @tparam basebit The base value
* @tparam numdigit The number of digits
* @param cres Array of output ciphertexts
* @param cin Input ciphertext
* @param kskh2m The key switching key for high to mid level
* @param kskm2l The key switching key for mid to low level
* @param bkfft The bootstrapping key FFT
*/
template <class high2midP, class mid2lowP, class brP, uint basebit, uint numdigit>
void HomDecomp(std::array<TLWE<typename high2midP::targetP>, numdigit> &cres,
const TLWE<typename high2midP::domainP> &cin, const KeySwitchingKey<high2midP> &kskh2m,
const KeySwitchingKey<mid2lowP> &kskm2l, const BootstrappingKeyFFT<brP> &bkfft) {
TFHEpp::TLWE<typename mid2lowP::targetP> tlwelvlhalf;
TFHEpp::TLWE<typename high2midP::targetP> subtlwe;

// cres will be used as a reusable buffer
constexpr uint32_t plain_modulusbit = basebit * numdigit;
#pragma omp parallel for default(none) shared(cin, cres, kskh2m)
for (int digit = 1; digit <= numdigit; digit++) {
TFHEpp::TLWE<typename high2midP::domainP> switchedtlwe;
for (int i = 0; i <= high2midP::domainP::k * high2midP::domainP::n; i++)
switchedtlwe[i] = cin[i] << (plain_modulusbit - basebit * digit);
IdentityKeySwitch<high2midP>(cres[digit - 1], switchedtlwe, kskh2m);
}
for (int digit = 1; digit <= numdigit; digit++) {
if (digit != 1) {
for (int i = 0; i <= high2midP::targetP::k * high2midP::targetP::n; i++)
cres[digit - 1][i] += subtlwe[i];
cres[digit - 1][high2midP::targetP::k * high2midP::targetP::n] -=
1ULL << (std::numeric_limits<typename high2midP::targetP::T>::digits - basebit - 1);
}
IdentityKeySwitch<mid2lowP>(tlwelvlhalf, cres[digit - 1], kskm2l);
tlwelvlhalf[mid2lowP::targetP::k * mid2lowP::targetP::n] +=
1ULL << (std::numeric_limits<typename mid2lowP::targetP::T>::digits - basebit - 1);
if (digit != numdigit)
GateBootstrappingTLWE2TLWEFFT<brP>(subtlwe, tlwelvlhalf, bkfft,
subtractpolygen<typename high2midP::targetP, basebit>());
}
}

} // namespace TFHEpp
4 changes: 3 additions & 1 deletion include/key.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ namespace TFHEpp {
using namespace std;
struct lweKey {
Key<lvl0param> lvl0;
Key<lvlhalfparam> lvlhalf;
Key<lvl1param> lvl1;
Key<lvl2param> lvl2;
Key<lvl3param> lvl3;
lweKey();
template <class P>
Key<P> get() const;
Expand All @@ -31,7 +33,7 @@ struct SecretKey {
template <class Archive>
void serialize(Archive &archive)
{
archive(key.lvl0, key.lvl1, key.lvl2, params);
archive(key.lvl0, key.lvlhalf, key.lvl1, key.lvl2, key.lvl3, params);
}
};
} // namespace TFHEpp
3 changes: 2 additions & 1 deletion include/params.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ using relinKeyFFT = std::array<TRLWEInFD<P>, P::l>;
fun(lvl0param); \
fun(lvlhalfparam); \
fun(lvl1param); \
fun(lvl2param);
fun(lvl2param); \
fun(lvl3param);
#define TFHEPP_EXPLICIT_INSTANTIATION_TRLWE(fun) \
fun(lvl1param); \
fun(lvl2param);
Expand Down
28 changes: 28 additions & 0 deletions include/params/128bit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,26 @@ struct lvl2param {
static_cast<double>(1ULL << (std::numeric_limits<T>::digits - 4));
};

struct lvl3param {
static constexpr int32_t key_value_max = 1;
static constexpr int32_t key_value_min = -1;
static const std::uint32_t nbit = 13; // dimension must be a power of 2 for
// ease of polynomial multiplication.
static constexpr std::uint32_t n = 1 << nbit; // dimension
static constexpr std::uint32_t k = 1;
static constexpr std::uint32_t l = 4;
static constexpr std::uint32_t Bgbit = 9;
static constexpr std::uint32_t Bg = 1 << Bgbit;
static constexpr ErrorDistribution errordist =
ErrorDistribution::ModularGaussian;
static const inline double α = std::pow(2.0, -47); // fresh noise
using T = uint64_t; // Torus representation
static constexpr T μ = 1ULL << 61;
static constexpr uint32_t plain_modulusbit = 31;
static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit;
static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1);
};

// Key Switching parameters
struct lvl10param {
static constexpr std::uint32_t t = 7; // number of addition in keyswitching
Expand Down Expand Up @@ -143,3 +163,11 @@ struct lvl22param {
using domainP = lvl2param;
using targetP = lvl2param;
};

struct lvl31param {
static constexpr std::uint32_t t = 7; // number of addition in keyswitching
static constexpr std::uint32_t basebit = 2; // how many bit should be encrypted in keyswitching key
static const inline double α = lvl1param::α; // key noise
using domainP = lvl3param;
using targetP = lvl1param;
};
33 changes: 32 additions & 1 deletion include/params/CGGI16.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct lvl0param {
plain_modulus;
};

//Dummy
struct lvlhalfparam {
static constexpr int32_t key_value_max = 1;
static constexpr int32_t key_value_min = 0;
Expand Down Expand Up @@ -73,6 +74,27 @@ struct lvl2param {
static constexpr double Δ = μ;
};

//Dummy
struct lvl3param {
static constexpr int32_t key_value_max = 1;
static constexpr int32_t key_value_min = -1;
static const std::uint32_t nbit = 13; // dimension must be a power of 2 for
// ease of polynomial multiplication.
static constexpr std::uint32_t n = 1 << nbit; // dimension
static constexpr std::uint32_t k = 1;
static constexpr std::uint32_t l = 4;
static constexpr std::uint32_t Bgbit = 9;
static constexpr std::uint32_t Bg = 1 << Bgbit;
static constexpr ErrorDistribution errordist =
ErrorDistribution::ModularGaussian;
static const inline double α = std::pow(2.0, -47); // fresh noise
using T = uint64_t; // Torus representation
static constexpr T μ = 1ULL << 61;
static constexpr uint32_t plain_modulusbit = 31;
static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit;
static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1);
};

struct lvl10param {
static constexpr std::uint32_t t = 8;
static constexpr std::uint32_t basebit = 2;
Expand All @@ -83,6 +105,7 @@ struct lvl10param {
using targetP = lvl0param;
};

//Dummy
struct lvl1hparam {
static constexpr std::uint32_t t = 10; // number of addition in keyswitching
static constexpr std::uint32_t basebit = 3; // how many bit should be encrypted in keyswitching key
Expand All @@ -91,7 +114,7 @@ struct lvl1hparam {
using targetP = lvlhalfparam;
};

// dummy
// Dummy
struct lvl11param {
static constexpr std::uint32_t t = 0; // number of addition in keyswitching
static constexpr std::uint32_t basebit =
Expand Down Expand Up @@ -134,4 +157,12 @@ struct lvl22param {
static const inline double α = lvl2param::α;
using domainP = lvl2param;
using targetP = lvl2param;
};

struct lvl31param {
static constexpr std::uint32_t t = 7; // number of addition in keyswitching
static constexpr std::uint32_t basebit = 2; // how many bit should be encrypted in keyswitching key
static const inline double α = lvl1param::α; // key noise
using domainP = lvl3param;
using targetP = lvl1param;
};
32 changes: 32 additions & 0 deletions include/params/CGGI19.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ struct lvl0param {
plain_modulus;
};

//Dummy
struct lvlhalfparam {
static constexpr int32_t key_value_max = 1;
static constexpr int32_t key_value_min = 0;
Expand Down Expand Up @@ -75,6 +76,27 @@ struct lvl2param {
static constexpr double Δ = μ;
};

//Dummy
struct lvl3param {
static constexpr int32_t key_value_max = 1;
static constexpr int32_t key_value_min = -1;
static const std::uint32_t nbit = 13; // dimension must be a power of 2 for
// ease of polynomial multiplication.
static constexpr std::uint32_t n = 1 << nbit; // dimension
static constexpr std::uint32_t k = 1;
static constexpr std::uint32_t l = 4;
static constexpr std::uint32_t Bgbit = 9;
static constexpr std::uint32_t Bg = 1 << Bgbit;
static constexpr ErrorDistribution errordist =
ErrorDistribution::ModularGaussian;
static const inline double α = std::pow(2.0, -47); // fresh noise
using T = uint64_t; // Torus representation
static constexpr T μ = 1ULL << 61;
static constexpr uint32_t plain_modulusbit = 31;
static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit;
static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1);
};

// Dummy
struct lvl11param {
static constexpr std::uint32_t t = 0; // number of addition in keyswitching
Expand All @@ -97,6 +119,7 @@ struct lvl10param {
using targetP = lvl0param;
};

//Dummy
struct lvl1hparam {
static constexpr std::uint32_t t = 10; // number of addition in keyswitching
static constexpr std::uint32_t basebit = 3; // how many bit should be encrypted in keyswitching key
Expand Down Expand Up @@ -137,3 +160,12 @@ struct lvl22param {
using domainP = lvl2param;
using targetP = lvl2param;
};

//Dummy
struct lvl31param {
static constexpr std::uint32_t t = 7; // number of addition in keyswitching
static constexpr std::uint32_t basebit = 2; // how many bit should be encrypted in keyswitching key
static const inline double α = lvl1param::α; // key noise
using domainP = lvl3param;
using targetP = lvl1param;
};
33 changes: 33 additions & 0 deletions include/params/compress.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ struct lvl0param {
plain_modulus;
};

//Dummy
struct lvlhalfparam {
static constexpr int32_t key_value_max = 1;
static constexpr int32_t key_value_min = 0;
Expand Down Expand Up @@ -84,6 +85,28 @@ struct lvl2param {
static constexpr double Δ = μ;
};


//Dummy
struct lvl3param {
static constexpr int32_t key_value_max = 1;
static constexpr int32_t key_value_min = -1;
static const std::uint32_t nbit = 13; // dimension must be a power of 2 for
// ease of polynomial multiplication.
static constexpr std::uint32_t n = 1 << nbit; // dimension
static constexpr std::uint32_t k = 1;
static constexpr std::uint32_t l = 4;
static constexpr std::uint32_t Bgbit = 9;
static constexpr std::uint32_t Bg = 1 << Bgbit;
static constexpr ErrorDistribution errordist =
ErrorDistribution::ModularGaussian;
static const inline double α = std::pow(2.0, -47); // fresh noise
using T = uint64_t; // Torus representation
static constexpr T μ = 1ULL << 61;
static constexpr uint32_t plain_modulusbit = 31;
static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit;
static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1);
};

// Key Switching parameters
struct lvl10param {
static constexpr std::uint32_t t = 5; // number of addition in keyswitching
Expand All @@ -93,6 +116,7 @@ struct lvl10param {
using targetP = lvl0param;
};

//Dummy
struct lvl1hparam {
static constexpr std::uint32_t t = 10; // number of addition in keyswitching
static constexpr std::uint32_t basebit = 3; // how many bit should be encrypted in keyswitching key
Expand Down Expand Up @@ -134,3 +158,12 @@ struct lvl22param {
using domainP = lvl2param;
using targetP = lvl2param;
};

//Dummy
struct lvl31param {
static constexpr std::uint32_t t = 7; // number of addition in keyswitching
static constexpr std::uint32_t basebit = 2; // how many bit should be encrypted in keyswitching key
static const inline double α = lvl1param::α; // key noise
using domainP = lvl3param;
using targetP = lvl1param;
};
Loading

0 comments on commit 8ec8d3f

Please sign in to comment.