Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the code of FindBootstrapRotationIndices() and its helper functions (standardize the use of uint32_t: phase 1) #923

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions src/pke/include/scheme/ckksrns/ckksrns-fhe.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,6 @@ class FHECKKSRNS : public FHERNS {
Ciphertext<DCRTPoly> EvalBootstrap(ConstCiphertext<DCRTPoly> ciphertext, uint32_t numIterations,
uint32_t precision) const override;

//------------------------------------------------------------------------------
// Find Rotation Indices
//------------------------------------------------------------------------------

std::vector<int32_t> FindBootstrapRotationIndices(uint32_t slots, uint32_t M);

std::vector<int32_t> FindLinearTransformRotationIndices(uint32_t slots, uint32_t M);

std::vector<int32_t> FindCoeffsToSlotsRotationIndices(uint32_t slots, uint32_t M);

std::vector<int32_t> FindSlotsToCoeffsRotationIndices(uint32_t slots, uint32_t M);

//------------------------------------------------------------------------------
// Precomputations for CoeffsToSlots and SlotsToCoeffs
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -232,6 +220,18 @@ class FHECKKSRNS : public FHERNS {
}

private:
//------------------------------------------------------------------------------
// Find Rotation Indices
//------------------------------------------------------------------------------
std::vector<int32_t> FindBootstrapRotationIndices(uint32_t slots, uint32_t M);

// ATTN: The following 3 functions are helper methods to be called in FindBootstrapRotationIndices() only.
// so they DO NOT remove possible duplicates and automorphisms corresponding to 0 and M/4.
// These methods completely depend on FindBootstrapRotationIndices() to do that.
std::vector<uint32_t> FindLinearTransformRotationIndices(uint32_t slots, uint32_t M);
std::vector<uint32_t> FindCoeffsToSlotsRotationIndices(uint32_t slots, uint32_t M);
std::vector<uint32_t> FindSlotsToCoeffsRotationIndices(uint32_t slots, uint32_t M);

//------------------------------------------------------------------------------
// Auxiliary Bootstrap Functions
//------------------------------------------------------------------------------
Expand Down
219 changes: 98 additions & 121 deletions src/pke/lib/scheme/ckksrns/ckksrns-fhe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -751,13 +751,13 @@ std::vector<int32_t> FHECKKSRNS::FindBootstrapRotationIndices(uint32_t slots, ui
auto pair = m_bootPrecomMap.find(slots);
if (pair == m_bootPrecomMap.end()) {
std::string errorMsg(std::string("Precomputations for ") + std::to_string(slots) +
std::string(" slots were not generated") +
std::string(" slots were not generated.") +
std::string(" Need to call EvalBootstrapSetup to proceed"));
OPENFHE_THROW(errorMsg);
}
const std::shared_ptr<CKKSBootstrapPrecom> precom = pair->second;

std::vector<int32_t> fullIndexList;
std::vector<uint32_t> fullIndexList;

bool isLTBootstrap = (precom->m_paramsEnc[CKKS_BOOT_PARAMS::LEVEL_BUDGET] == 1) &&
(precom->m_paramsDec[CKKS_BOOT_PARAMS::LEVEL_BUDGET] == 1);
Expand All @@ -768,212 +768,189 @@ std::vector<int32_t> FHECKKSRNS::FindBootstrapRotationIndices(uint32_t slots, ui
else {
fullIndexList = FindCoeffsToSlotsRotationIndices(slots, M);

std::vector<int32_t> indexListStC = FindSlotsToCoeffsRotationIndices(slots, M);
fullIndexList.insert(fullIndexList.end(), indexListStC.begin(), indexListStC.end());
std::vector<uint32_t> indexListStC{FindSlotsToCoeffsRotationIndices(slots, M)};
fullIndexList.insert(fullIndexList.end(),
std::make_move_iterator(indexListStC.begin()),
std::make_move_iterator(indexListStC.end()));
}

// Remove possible duplicates
sort(fullIndexList.begin(), fullIndexList.end());
fullIndexList.erase(unique(fullIndexList.begin(), fullIndexList.end()), fullIndexList.end());

// remove automorphisms corresponding to 0
fullIndexList.erase(std::remove(fullIndexList.begin(), fullIndexList.end(), 0), fullIndexList.end());
fullIndexList.erase(std::remove(fullIndexList.begin(), fullIndexList.end(), M / 4), fullIndexList.end());
// Remove possible duplicates and remove automorphisms corresponding to 0 and M/4 by using std::set
std::set<uint32_t> s(fullIndexList.begin(), fullIndexList.end());
s.erase(0);
s.erase(M/4);

return fullIndexList;
return std::vector<int32_t>(s.begin(), s.end());
}

std::vector<int32_t> FHECKKSRNS::FindLinearTransformRotationIndices(uint32_t slots, uint32_t M) {
// ATTN: This function is a helper methods to be called in FindBootstrapRotationIndices() only.
// so it DOES NOT remove possible duplicates and automorphisms corresponding to 0 and M/4.
// This method completely depends on FindBootstrapRotationIndices() to do that.
std::vector<uint32_t> FHECKKSRNS::FindLinearTransformRotationIndices(uint32_t slots, uint32_t M) {
auto pair = m_bootPrecomMap.find(slots);
if (pair == m_bootPrecomMap.end()) {
std::string errorMsg(std::string("Precomputations for ") + std::to_string(slots) +
std::string(" slots were not generated") +
std::string(" slots were not generated.") +
std::string(" Need to call EvalBootstrapSetup to proceed"));
OPENFHE_THROW(errorMsg);
}
const std::shared_ptr<CKKSBootstrapPrecom> precom = pair->second;

std::vector<int32_t> indexList;

// Computing the baby-step g and the giant-step h.
int g = (precom->m_dim1 == 0) ? ceil(sqrt(slots)) : precom->m_dim1;
int h = ceil(static_cast<double>(slots) / g);
uint32_t g = (precom->m_dim1 == 0) ? static_cast<uint32_t>(std::ceil(std::sqrt(slots))) : precom->m_dim1;
uint32_t h = static_cast<uint32_t>(std::ceil(static_cast<double>(slots) / g));

std::vector<uint32_t> indexList;
// To avoid overflowing uint32_t variables, we do some math operations below in a specific order
// computing all indices for baby-step giant-step procedure
// ATTN: resize() is used as indexListEvalLT may be empty here
indexList.reserve(g + h + M - 2);
for (int i = 0; i < g; i++) {
indexList.emplace_back(i + 1);
int32_t indexListSz = static_cast<int32_t>(g) + h + M - 2;
if(indexListSz < 0) {
OPENFHE_THROW("indexListSz can not be negative");
}
for (int i = 2; i < h; i++) {
indexList.reserve(indexListSz);
for (size_t i = 1; i <= g; ++i) {
indexList.emplace_back(i);
}
for (size_t i = 2; i < h; ++i) {
indexList.emplace_back(g * i);
}

uint32_t m = slots * 4;
// additional automorphisms are needed for sparse bootstrapping
if (m != M) {
for (uint32_t j = 1; j < M / m; j <<= 1) {
for (size_t j = 1; j < M / m; j <<= 1) {
indexList.emplace_back(j * slots);
}
}
// Remove possible duplicates
sort(indexList.begin(), indexList.end());
indexList.erase(unique(indexList.begin(), indexList.end()), indexList.end());

// remove automorphisms corresponding to 0
indexList.erase(std::remove(indexList.begin(), indexList.end(), 0), indexList.end());
indexList.erase(std::remove(indexList.begin(), indexList.end(), M / 4), indexList.end());

return indexList;
}

std::vector<int32_t> FHECKKSRNS::FindCoeffsToSlotsRotationIndices(uint32_t slots, uint32_t M) {
// ATTN: This function is a helper methods to be called in FindBootstrapRotationIndices() only.
// so it DOES NOT remove possible duplicates and automorphisms corresponding to 0 and M/4.
// This method completely depends on FindBootstrapRotationIndices() to do that.
std::vector<uint32_t> FHECKKSRNS::FindCoeffsToSlotsRotationIndices(uint32_t slots, uint32_t M) {
auto pair = m_bootPrecomMap.find(slots);
if (pair == m_bootPrecomMap.end()) {
std::string errorMsg(std::string("Precomputations for ") + std::to_string(slots) +
std::string(" slots were not generated") +
std::string(" slots were not generated.") +
std::string(" Need to call EvalBootstrapSetup to proceed"));
OPENFHE_THROW(errorMsg);
}
const std::shared_ptr<CKKSBootstrapPrecom> precom = pair->second;

std::vector<int32_t> indexList;

int32_t levelBudget = precom->m_paramsEnc[CKKS_BOOT_PARAMS::LEVEL_BUDGET];
int32_t layersCollapse = precom->m_paramsEnc[CKKS_BOOT_PARAMS::LAYERS_COLL];
int32_t remCollapse = precom->m_paramsEnc[CKKS_BOOT_PARAMS::LAYERS_REM];
int32_t numRotations = precom->m_paramsEnc[CKKS_BOOT_PARAMS::NUM_ROTATIONS];
int32_t b = precom->m_paramsEnc[CKKS_BOOT_PARAMS::BABY_STEP];
int32_t g = precom->m_paramsEnc[CKKS_BOOT_PARAMS::GIANT_STEP];
int32_t numRotationsRem = precom->m_paramsEnc[CKKS_BOOT_PARAMS::NUM_ROTATIONS_REM];
int32_t bRem = precom->m_paramsEnc[CKKS_BOOT_PARAMS::BABY_STEP_REM];
int32_t gRem = precom->m_paramsEnc[CKKS_BOOT_PARAMS::GIANT_STEP_REM];
uint32_t levelBudget = precom->m_paramsEnc[CKKS_BOOT_PARAMS::LEVEL_BUDGET];
uint32_t layersCollapse = precom->m_paramsEnc[CKKS_BOOT_PARAMS::LAYERS_COLL];
uint32_t remCollapse = precom->m_paramsEnc[CKKS_BOOT_PARAMS::LAYERS_REM];
uint32_t numRotations = precom->m_paramsEnc[CKKS_BOOT_PARAMS::NUM_ROTATIONS];
uint32_t b = precom->m_paramsEnc[CKKS_BOOT_PARAMS::BABY_STEP];
uint32_t g = precom->m_paramsEnc[CKKS_BOOT_PARAMS::GIANT_STEP];
uint32_t numRotationsRem = precom->m_paramsEnc[CKKS_BOOT_PARAMS::NUM_ROTATIONS_REM];
uint32_t bRem = precom->m_paramsEnc[CKKS_BOOT_PARAMS::BABY_STEP_REM];
uint32_t gRem = precom->m_paramsEnc[CKKS_BOOT_PARAMS::GIANT_STEP_REM];

int32_t stop;
int32_t flagRem;
if (remCollapse == 0) {
stop = -1;
flagRem = 0;
}
else {
stop = 0;
flagRem = 1;
}
uint32_t flagRem = (remCollapse == 0) ? 0 : 1;

std::vector<uint32_t> indexList;
// To avoid overflowing uint32_t variables, we do some math operations below in a specific order
// Computing all indices for baby-step giant-step procedure for encoding and decoding
indexList.reserve(b + g - 2 + bRem + gRem - 2 + 1 + M);
int32_t indexListSz = static_cast<int32_t>(b) + g - 2 + bRem + gRem - 2 + 1 + M;
if(indexListSz < 0) {
OPENFHE_THROW("indexListSz can not be negative");
}
indexList.reserve(indexListSz);

for (int32_t s = int32_t(levelBudget) - 1; s > stop; s--) {
for (int32_t j = 0; j < g; j++) {
indexList.emplace_back(ReduceRotation(
(j - int32_t((numRotations + 1) / 2) + 1) * (1 << ((s - flagRem) * layersCollapse + remCollapse)),
slots));
for (int32_t s = static_cast<int32_t>(levelBudget) - 1; s >= static_cast<int32_t>(flagRem); --s) {
const uint32_t scalingFactor = 1U << ((s - flagRem) * layersCollapse + remCollapse);
for (int32_t j = (1 - (numRotations + 1) / 2); j < static_cast<int32_t>(g); ++j) {
indexList.emplace_back(ReduceRotation(j * scalingFactor, slots));
}

for (int32_t i = 0; i < b; i++) {
indexList.emplace_back(
ReduceRotation((g * i) * (1 << ((s - flagRem) * layersCollapse + remCollapse)), M / 4));
for (size_t i = 0; i < b; i++) {
indexList.emplace_back(ReduceRotation((g * i) * scalingFactor, M / 4));
}
}

if (flagRem) {
for (int32_t j = 0; j < gRem; j++) {
indexList.emplace_back(ReduceRotation((j - int32_t((numRotationsRem + 1) / 2) + 1), slots));
for (int32_t j = (1 - (numRotationsRem + 1) / 2); j < static_cast<int32_t>(gRem); ++j) {
indexList.emplace_back(ReduceRotation(j, slots));
}
for (int32_t i = 0; i < bRem; i++) {
for (size_t i = 0; i < bRem; i++) {
indexList.emplace_back(ReduceRotation(gRem * i, M / 4));
}
}

uint32_t m = slots * 4;
// additional automorphisms are needed for sparse bootstrapping
if (m != M) {
for (uint32_t j = 1; j < M / m; j <<= 1) {
for (size_t j = 1; j < M / m; j <<= 1) {
indexList.emplace_back(j * slots);
}
}

// Remove possible duplicates
sort(indexList.begin(), indexList.end());
indexList.erase(unique(indexList.begin(), indexList.end()), indexList.end());

// remove automorphisms corresponding to 0
indexList.erase(std::remove(indexList.begin(), indexList.end(), 0), indexList.end());
indexList.erase(std::remove(indexList.begin(), indexList.end(), M / 4), indexList.end());

return indexList;
}

std::vector<int32_t> FHECKKSRNS::FindSlotsToCoeffsRotationIndices(uint32_t slots, uint32_t M) {
std::vector<uint32_t> FHECKKSRNS::FindSlotsToCoeffsRotationIndices(uint32_t slots, uint32_t M) {
auto pair = m_bootPrecomMap.find(slots);
if (pair == m_bootPrecomMap.end()) {
std::string errorMsg(std::string("Precomputations for ") + std::to_string(slots) +
std::string(" slots were not generated") +
std::string(" slots were not generated.") +
std::string(" Need to call EvalBootstrapSetup to proceed"));
OPENFHE_THROW(errorMsg);
}
const std::shared_ptr<CKKSBootstrapPrecom> precom = pair->second;

std::vector<int32_t> indexList;

int32_t levelBudget = precom->m_paramsDec[CKKS_BOOT_PARAMS::LEVEL_BUDGET];
int32_t layersCollapse = precom->m_paramsDec[CKKS_BOOT_PARAMS::LAYERS_COLL];
int32_t remCollapse = precom->m_paramsDec[CKKS_BOOT_PARAMS::LAYERS_REM];
int32_t numRotations = precom->m_paramsDec[CKKS_BOOT_PARAMS::NUM_ROTATIONS];
int32_t b = precom->m_paramsDec[CKKS_BOOT_PARAMS::BABY_STEP];
int32_t g = precom->m_paramsDec[CKKS_BOOT_PARAMS::GIANT_STEP];
int32_t numRotationsRem = precom->m_paramsDec[CKKS_BOOT_PARAMS::NUM_ROTATIONS_REM];
int32_t bRem = precom->m_paramsDec[CKKS_BOOT_PARAMS::BABY_STEP_REM];
int32_t gRem = precom->m_paramsDec[CKKS_BOOT_PARAMS::GIANT_STEP_REM];

int32_t flagRem;
if (remCollapse == 0) {
flagRem = 0;
}
else {
flagRem = 1;
}

uint32_t levelBudget = precom->m_paramsDec[CKKS_BOOT_PARAMS::LEVEL_BUDGET];
uint32_t layersCollapse = precom->m_paramsDec[CKKS_BOOT_PARAMS::LAYERS_COLL];
uint32_t remCollapse = precom->m_paramsDec[CKKS_BOOT_PARAMS::LAYERS_REM];
uint32_t numRotations = precom->m_paramsDec[CKKS_BOOT_PARAMS::NUM_ROTATIONS];
uint32_t b = precom->m_paramsDec[CKKS_BOOT_PARAMS::BABY_STEP];
uint32_t g = precom->m_paramsDec[CKKS_BOOT_PARAMS::GIANT_STEP];
uint32_t numRotationsRem = precom->m_paramsDec[CKKS_BOOT_PARAMS::NUM_ROTATIONS_REM];
uint32_t bRem = precom->m_paramsDec[CKKS_BOOT_PARAMS::BABY_STEP_REM];
uint32_t gRem = precom->m_paramsDec[CKKS_BOOT_PARAMS::GIANT_STEP_REM];

uint32_t flagRem = (remCollapse == 0) ? 0 : 1;
if(levelBudget < flagRem) {
OPENFHE_THROW("levelBudget can not be less than flagRem");
}

std::vector<uint32_t> indexList;
// To avoid overflowing uint32_t variables, we do some math operations below in a specific order
// Computing all indices for baby-step giant-step procedure for encoding and decoding
indexList.reserve(b + g - 2 + bRem + gRem - 2 + 1 + M);
int32_t indexListSz = static_cast<int32_t>(b) + g - 2 + bRem + gRem - 2 + 1 + M;
if(indexListSz < 0) {
OPENFHE_THROW("indexListSz can not be negative");
}
indexList.reserve(indexListSz);

for (int32_t s = 0; s < int32_t(levelBudget) - flagRem; s++) {
for (int32_t j = 0; j < g; j++) {
indexList.emplace_back(
ReduceRotation((j - (numRotations + 1) / 2 + 1) * (1 << (s * layersCollapse)), M / 4));
for (size_t s = 0; s < (levelBudget - flagRem); ++s) {
const uint32_t scalingFactor = 1U << (s * layersCollapse);
for (int32_t j = (1 - (numRotations + 1) / 2); j <= static_cast<int32_t>(g); ++j) {
indexList.emplace_back(ReduceRotation(j * scalingFactor, M / 4));
}
for (int32_t i = 0; i < b; i++) {
indexList.emplace_back(ReduceRotation((g * i) * (1 << (s * layersCollapse)), M / 4));
for (size_t i = 0; i < b; ++i) {
indexList.emplace_back(ReduceRotation((g * i) * scalingFactor, M / 4));
}
}

if (flagRem) {
int32_t s = int32_t(levelBudget) - flagRem;
for (int32_t j = 0; j < gRem; j++) {
indexList.emplace_back(
ReduceRotation((j - (numRotationsRem + 1) / 2 + 1) * (1 << (s * layersCollapse)), M / 4));
uint32_t s = levelBudget - flagRem;
const uint32_t scalingFactor = 1U << (s * layersCollapse);
for (int32_t j = (1 - (numRotationsRem + 1) / 2); j <= static_cast<int32_t>(gRem); ++j) {
indexList.emplace_back(ReduceRotation(j * scalingFactor, M / 4));
}
for (int32_t i = 0; i < bRem; i++) {
indexList.emplace_back(ReduceRotation((gRem * i) * (1 << (s * layersCollapse)), M / 4));
for (size_t i = 0; i < bRem; ++i) {
indexList.emplace_back(ReduceRotation((gRem * i) * scalingFactor, M / 4));
}
}

uint32_t m = slots * 4;
// additional automorphisms are needed for sparse bootstrapping
if (m != M) {
for (uint32_t j = 1; j < M / m; j <<= 1) {
for (size_t j = 1; j < M / m; j <<= 1) {
indexList.emplace_back(j * slots);
}
}

// Remove possible duplicates
sort(indexList.begin(), indexList.end());
indexList.erase(unique(indexList.begin(), indexList.end()), indexList.end());

// remove automorphisms corresponding to 0
indexList.erase(std::remove(indexList.begin(), indexList.end(), 0), indexList.end());
indexList.erase(std::remove(indexList.begin(), indexList.end(), M / 4), indexList.end());

return indexList;
}

Expand Down