diff --git a/include/cloudkey.hpp b/include/cloudkey.hpp
index ac4a15d..af2325d 100644
--- a/include/cloudkey.hpp
+++ b/include/cloudkey.hpp
@@ -193,7 +193,7 @@ void ikskgen(KeySwitchingKey
& ksk, const Key& domainkey,
for (int l = 0; l < P::domainP::k; l++)
for (int i = 0; i < P::domainP::n; i++)
for (int j = 0; j < P::t; j++)
- for (uint32_t k = 0; k < (1 << P::basebit) - 1; k++)
+ for (uint32_t k = 0; k < 1U << (P::basebit-1); k++)
ksk[l * P::domainP::n + i][j][k] =
tlweSymEncrypt(
domainkey[l * P::domainP::n + i] * (k + 1) *
diff --git a/include/keyswitch.hpp b/include/keyswitch.hpp
index 383acd5..f0977cb 100644
--- a/include/keyswitch.hpp
+++ b/include/keyswitch.hpp
@@ -8,22 +8,31 @@
namespace TFHEpp {
+template
+constexpr typename P::domainP::T iksoffsetgen()
+{
+ typename P::domainP::T offset = 0;
+ for (int i = 1; i <= P::t; i++)
+ offset += (1ULL<::digits -
+ i * P::basebit));
+ return offset;
+}
+
template
void IdentityKeySwitch(TLWE &res,
const TLWE &tlwe,
const KeySwitchingKey &ksk)
{
- constexpr uint32_t mask = (1U << P::basebit) - 1;
res = {};
constexpr uint domain_digit =
std::numeric_limits::digits;
constexpr uint target_digit =
std::numeric_limits::digits;
- constexpr typename P::domainP::T prec_offset =
+ constexpr typename P::domainP::T roundoffset =
(P::basebit * P::t) < domain_digit
? 1ULL << (domain_digit - (1 + P::basebit * P::t))
: 0;
-
if constexpr (domain_digit == target_digit)
res[P::targetP::k * P::targetP::n] =
tlwe[P::domainP::k * P::domainP::n];
@@ -37,16 +46,25 @@ void IdentityKeySwitch(TLWE &res,
static_cast(
tlwe[P::domainP::k * P::domainP::n])
<< (target_digit - domain_digit);
+
+//Koga's Optimization
+ constexpr typename P::domainP::T offset = iksoffsetgen();
+ constexpr typename P::domainP::T mask = (1ULL << P::basebit) - 1;
+ constexpr typename P::domainP::T halfbase = 1ULL << (P::basebit - 1);
+
for (int i = 0; i < P::domainP::k * P::domainP::n; i++) {
- const typename P::domainP::T aibar = tlwe[i] + prec_offset;
+ const typename P::domainP::T aibar = tlwe[i] + offset + roundoffset;
for (int j = 0; j < P::t; j++) {
- const uint32_t aij =
- (aibar >> (std::numeric_limits::digits -
+ const int32_t aij =
+ ((aibar >> (std::numeric_limits::digits -
(j + 1) * P::basebit)) &
- mask;
- if (aij != 0)
- for (int k = 0; k <= P::targetP::k * P::targetP::n; k++)
- res[k] -= ksk[i][j][aij - 1][k];
+ mask)-halfbase;
+ if(aij > 0)
+ for (int k = 0; k <= P::targetP::k * P::targetP::n; k++)
+ res[k] -= ksk[i][j][aij - 1][k];
+ else if(aij < 0)
+ for (int k = 0; k <= P::targetP::k * P::targetP::n; k++)
+ res[k] += ksk[i][j][std::abs(aij) - 1][k];
}
}
}
diff --git a/include/params.hpp b/include/params.hpp
index 2eca7c8..98e03db 100644
--- a/include/params.hpp
+++ b/include/params.hpp
@@ -147,7 +147,7 @@ using BootstrappingKeyRAINTT =
template
using KeySwitchingKey = std::array<
- std::array, (1 << P::basebit) - 1>,
+ std::array, (1 << (P::basebit-1))>,
P::t>,
P::domainP::k * P::domainP::n>;
template