Skip to content

Commit

Permalink
Added CatIKS Koga's opt
Browse files Browse the repository at this point in the history
  • Loading branch information
nindanaoto committed Aug 14, 2024
1 parent 5daa838 commit f1524e0
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions include/keyswitch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ void IdentityKeySwitch(TLWE<typename P::targetP> &res,
(j + 1) * P::basebit)) &
mask)-halfbase;
if(aij > 0)
for (int k = 0; k <= P::targetP::k * P::targetP::n; k++)
res[k] -= ksk[i][j][aij - 1][k];
for (int k = 0; k <= P::targetP::k * P::targetP::n; k++)
res[k] -= ksk[i][j][aij - 1][k];
else if(aij < 0)
for (int k = 0; k <= P::targetP::k * P::targetP::n; k++)
res[k] += ksk[i][j][std::abs(aij) - 1][k];
for (int k = 0; k <= P::targetP::k * P::targetP::n; k++)
res[k] += ksk[i][j][-aij - 1][k];
}
}
}
Expand All @@ -75,13 +75,12 @@ void CatIdentityKeySwitch(
const std::array<TLWE<typename P::domainP>, numcat> &tlwe,
const KeySwitchingKey<P> &ksk)
{
constexpr uint32_t mask = (1U << P::basebit) - 1;
res = {};
constexpr uint domain_digit =
std::numeric_limits<typename P::domainP::T>::digits;
constexpr uint target_digit =
std::numeric_limits<typename P::targetP::T>::digits;
constexpr typename P::domainP::T prec_offset =
constexpr typename P::domainP::T roundoffset =
(P::basebit * P::t) < domain_digit
? 1ULL << (domain_digit - (1 + P::basebit * P::t))
: 0;
Expand All @@ -101,20 +100,27 @@ void CatIdentityKeySwitch(
tlwe[cat][P::domainP::k * P::domainP::n])
<< (target_digit - domain_digit);
}

//Koga's Optimization
constexpr typename P::domainP::T offset = iksoffsetgen<P>();
constexpr typename P::domainP::T mask = (1ULL << P::basebit) - 1;
constexpr typename P::domainP::T halfbase = 1ULL << (P::basebit - 1);
for (int i = 0; i < P::domainP::k * P::domainP::n; i++) {
std::array<typename P::domainP::T, numcat> aibarcat;
for (int cat = 0; cat < numcat; cat++)
aibarcat[cat] = tlwe[cat][i] + prec_offset;
aibarcat[cat] = tlwe[cat][i] + offset + roundoffset;
for (int j = 0; j < P::t; j++) {
for (int cat = 0; cat < numcat; cat++) {
const uint32_t aij =
(aibarcat[cat] >>
(std::numeric_limits<typename P::domainP::T>::digits -
(j + 1) * P::basebit)) &
mask;
if (aij != 0)
const int32_t aij =
((aibar >> (std::numeric_limits<typename P::domainP::T>::digits -
(j + 1) * P::basebit)) &
mask)-halfbase;
if (aij > 0)
for (int k = 0; k <= P::targetP::k * P::targetP::n; k++)
res[cat][k] -= ksk[i][j][aij - 1][k];
if (aij < 0)
for (int k = 0; k <= P::targetP::k * P::targetP::n; k++)
res[cat][k] += ksk[i][j][-aij - 1][k];
}
}
}
Expand Down

0 comments on commit f1524e0

Please sign in to comment.