Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sealPir 其他安全强度支持(4096安全参数支持) #114

Open
Yinbenxin opened this issue Apr 18, 2024 · 4 comments
Open

sealPir 其他安全强度支持(4096安全参数支持) #114

Yinbenxin opened this issue Apr 18, 2024 · 4 comments
Assignees
Labels
bug Something isn't working

Comments

@Yinbenxin
Copy link

此问题出现在尝试使用 poly_modulus_degree = 4096 的安全参数进行匿踪查询(PIR)时。经过约一周的排查,终于定位并解决了该问题。
问题代码:
`seal_pir.cc` 中的 `std::vector<seal::Ciphertext> SealPirServer::ExpandQuery` 函数:

// Problem version of ExpandQuery as reported in this issue: expands one PIR
// query ciphertext into m ciphertexts.
//
// Level j (step = 2^j) splits each ciphertext c0 into
//   results2[k]        = c0 + apply_galois(c0, galois_elts[j])
//   results2[k + step] = (c0 - apply_galois(c0, galois_elts[j])) * x^(-2^j)
// The monomial factors are applied with Evaluator::multiply_plain (pt0/pt1),
// which consumes noise budget on every level. With poly_modulus_degree = 4096
// the budget is exhausted after a few levels and decryption goes wrong. The
// commented-out multiply_power_of_X calls compute the same products as a
// negacyclic coefficient shift, which consumes no noise budget.
//
// Further problems visible in this version:
//  - m == 1 gives logm == 0, so `logm - 1` (uint32_t) underflows in the loop
//    bound and in galois_elts[logm - 1]; m == 0 passes log2(0) = -inf to a
//    uint32_t conversion.
//  - Leftover debug std::cout of full plaintexts on every level.
//
// @param encrypted  query ciphertext to expand.
// @param m          number of ciphertexts to return (m <= n is enforced).
// @return           the first m expanded ciphertexts.
std::vector<seal::Ciphertext> SealPirServer::ExpandQuery(
    const seal::Ciphertext &encrypted, std::uint32_t m) {
  // Plaintext modulus t; pt0 below encodes -1 as t - 1.
  uint64_t plain_mod = seal_params_->plain_modulus().value();

  seal::GaloisKeys &galkey = galois_key_;

  // Assume that m is a power of 2. If not, round it to the next power of 2.
  uint32_t logm = std::ceil(std::log2(m));

  std::vector<int> galois_elts;
  auto n = seal_params_->poly_modulus_degree();
  YACL_ENFORCE(logm <= std::ceil(std::log2(n)), "m > n is not allowed.");

  // galois_elts[i] = (n + 2^i) / 2^i = n / 2^i + 1.
  galois_elts.reserve(std::ceil(std::log2(n)));
  for (size_t i = 0; i < std::ceil(std::log2(n)); i++) {
    galois_elts.push_back((n + seal::util::exponentiate_uint(2, i)) /
                          seal::util::exponentiate_uint(2, i));
  }

  std::vector<seal::Ciphertext> results(1);
  results[0] = encrypted;
  seal::Plaintext tempPt;  // unused
  // Every level but the last doubles the number of ciphertexts.
  // NOTE: `logm - 1` underflows when logm == 0 (m == 1).
  for (size_t j = 0; j < logm - 1; j++) {
    std::vector<seal::Ciphertext> results2(1 << (j + 1));
    int step = 1 << j;
    seal::Plaintext pt0(n);
    seal::Plaintext pt1(n);

    // pt0 = -x^(n - step) = x^(-step) in Z_t[x]/(x^n + 1).
    pt0.set_zero();
    pt0[n - step] = plain_mod - 1;
    std::cout << "plain_mods:" << plain_mod << std::endl;  // debug output
    int index_raw = (n << 1) - (1 << j);  // -2^j (mod 2n)
    // index == n - 2^j (mod 2n). NOTE: this int product overflows for
    // n = 32768 (65535 * 32769 > INT_MAX).
    int index = (index_raw * galois_elts[j]) % (n << 1);
    // pt1 = x^(n - 2^j) = -x^(-2^j).
    pt1.set_zero();
    pt1[index] = 1;
    std::cout << "pt0:" << pt0.to_string() << std::endl;  // debug output
    std::cout << "pt1:" << pt1.to_string() << std::endl;  // debug output
    // int nstep = -step;
    yacl::parallel_for(0, step, [&](int64_t begin, int64_t end) {
      for (int k = begin; k < end; k++) {
        seal::Ciphertext c0;
        seal::Ciphertext c1;
        seal::Ciphertext t0;
        seal::Ciphertext t1;

        c0 = results[k];

        // SPDLOG_INFO("apply_galois j:{} k:{}", j, k);
        evaluator_->apply_galois(c0, galois_elts[j], galkey,
                                 t0);          // t0 = Sub(c0,N/(2^i)+1)
        evaluator_->add(c0, t0, results2[k]);  // c0 + Sub(c0,N/(2^i)+1)
        // multiply_power_of_X(c0, c1, index_raw);
        evaluator_->multiply_plain(c0, pt0, c1);  // c1 = c0 * x^(-2^j); costs noise budget
        evaluator_->multiply_plain(t0, pt1, t1);  // t1 = -t0 * x^(-2^j); costs noise budget
        // results2[k + step] = c1 + t1 = (c0 - t0) * x^(-2^j).
        evaluator_->add(c1, t1, results2[k + step]);
      }
    });
    results = results2;  // copy-assigns every ciphertext
  }

  // Last step of the loop: results.size() == 2^(logm - 1) here.
  std::vector<seal::Ciphertext> results2(results.size() << 1);
  seal::Plaintext two("2");

  seal::Plaintext pt0(n);
  seal::Plaintext pt1(n);

  // Same monomials as above with step = results.size().
  pt0.set_zero();
  pt0[n - results.size()] = plain_mod - 1;

  int index_raw = (n << 1) - (1 << (logm - 1));
  int index = (index_raw * galois_elts[logm - 1]) % (n << 1);
  pt1.set_zero();
  pt1[index] = 1;

  for (uint32_t k = 0; k < results.size(); k++) {
    // The odd output k + results.size() would be >= m and is dropped, so the
    // ciphertext is only doubled instead of split.
    if (k >= (m - (1 << (logm - 1)))) {  // corner case.
      evaluator_->multiply_plain(results[k], two,
                                 results2[k]);  // plain multiplication by 2.
    } else {
      seal::Ciphertext c0;
      seal::Ciphertext c1;
      seal::Ciphertext t0;
      seal::Ciphertext t1;

      c0 = results[k];
      evaluator_->apply_galois(c0, galois_elts[logm - 1], galkey, t0);
      evaluator_->add(c0, t0, results2[k]);
      // multiply_power_of_X(c0, c1, index_raw);

      // Same noise-consuming multiply_plain pair as in the loop above.
      evaluator_->multiply_plain(c0, pt0, c1);
      evaluator_->multiply_plain(t0, pt1, t1);
      evaluator_->add(c1, t1, results2[k + results.size()]);
    }
  }

  // Keep only the first m ciphertexts (copied into a new vector).
  auto first = results2.begin();
  auto last = results2.begin() + m;
  std::vector<seal::Ciphertext> new_vec(first, last);
  return new_vec;
}

建议修改为:

// Expands a single PIR query ciphertext into m ciphertexts (SealPIR-style
// oblivious query expansion).
//
// Level j (step = 2^j) splits every ciphertext c into
//   even = c + g(c)
//   odd  = (c - g(c)) * x^(-step)
// where g is apply_galois with galois_elts[j] = (n + 2^j) / 2^j. The monomial
// factor is applied with multiply_power_of_X, a negacyclic coefficient shift
// that, unlike Evaluator::multiply_plain, consumes no noise budget. This keeps
// the expansion within the noise budget of smaller parameter sets such as
// poly_modulus_degree = 4096.
//
// On the last level only the first m outputs are kept; inputs whose odd half
// would be dropped are multiplied by 2 instead of split (TODO(review): confirm
// the client removes the resulting 2^ceil(log2(m)) factor).
//
// @param encrypted  query ciphertext to expand.
// @param m          number of ciphertexts to produce; 1 <= m <= n.
// @return           the first m expanded ciphertexts.
std::vector<seal::Ciphertext> SealPirServer::ExpandQuery(
    const seal::Ciphertext &encrypted, std::uint32_t m) {
  // log2(0) is -inf, and converting that to uint32_t is undefined behaviour.
  YACL_ENFORCE(m > 0, "m must be positive.");

  seal::GaloisKeys &galkey = galois_key_;

  // Assume that m is a power of 2. If not, round it to the next power of 2.
  uint32_t logm = std::ceil(std::log2(m));

  auto n = seal_params_->poly_modulus_degree();
  YACL_ENFORCE(logm <= std::ceil(std::log2(n)), "m > n is not allowed.");

  // Nothing to expand for m == 1; the level logic below uses logm - 1, which
  // would underflow.
  if (logm == 0) {
    return {encrypted};
  }

  // galois_elts[i] = (n + 2^i) / 2^i = n / 2^i + 1.
  std::vector<int> galois_elts;
  galois_elts.reserve(std::ceil(std::log2(n)));
  for (size_t i = 0; i < std::ceil(std::log2(n)); i++) {
    galois_elts.push_back((n + seal::util::exponentiate_uint(2, i)) /
                          seal::util::exponentiate_uint(2, i));
  }

  // Writes even = c + g(c) and odd = (c - g(c)) * x^(-step).
  //
  // In Z_q[x]/(x^n + 1), x^(n - step) = -x^(-step), hence
  //   (c - g(c)) * x^(-step) = (g(c) - c) * x^(n - step),
  // which equals c * x^(2n - step) + g(c) * x^(n - step) but needs a single
  // shift, and the shift amount n - step always stays below n. Computing the
  // shift from n directly also avoids the int overflow of
  // (2n - step) * galois_elt for n = 32768.
  auto split = [&](const seal::Ciphertext &c, int galois_elt, size_t step,
                   seal::Ciphertext &even, seal::Ciphertext &odd) {
    seal::Ciphertext rotated;
    evaluator_->apply_galois(c, galois_elt, galkey, rotated);
    evaluator_->add(c, rotated, even);
    seal::Ciphertext diff;
    evaluator_->sub(rotated, c, diff);
    multiply_power_of_X(diff, odd, static_cast<uint32_t>(n - step));
  };

  std::vector<seal::Ciphertext> results(1);
  results[0] = encrypted;

  // All levels except the last double the number of ciphertexts.
  for (size_t j = 0; j + 1 < logm; j++) {
    int step = 1 << j;
    std::vector<seal::Ciphertext> results2(step << 1);
    yacl::parallel_for(0, step, [&](int64_t begin, int64_t end) {
      for (int64_t k = begin; k < end; k++) {
        split(results[k], galois_elts[j], step, results2[k],
              results2[k + step]);
      }
    });
    results.swap(results2);  // avoid copying every ciphertext
  }

  // Last level: results.size() == 2^(logm - 1) and only the first m of the
  // 2^logm outputs are needed. When the odd output k + half would be >= m it
  // is dropped, so the input is multiplied by 2 instead of split.
  const size_t half = results.size();
  std::vector<seal::Ciphertext> results2(half << 1);
  seal::Plaintext two("2");
  for (size_t k = 0; k < half; k++) {
    if (k >= m - half) {  // corner case.
      evaluator_->multiply_plain(results[k], two,
                                 results2[k]);  // plain multiplication by 2.
    } else {
      split(results[k], galois_elts[logm - 1], half, results2[k],
            results2[k + half]);
    }
  }

  // m <= 2 * half because logm = ceil(log2(m)); drop the unused tail.
  results2.resize(m);
  return results2;
}

// Multiplies `encrypted` by the monomial x^index in Z_q[x]/(x^n + 1) and
// writes the result to `destination`.
//
// Each ciphertext polynomial is negacyclically shifted per RNS prime, so the
// noise is only permuted (with sign flips); unlike Evaluator::multiply_plain
// this consumes no noise budget.
//
// Assumes `encrypted` is in coefficient (non-NTT) form, as BFV ciphertexts are
// in SEAL by default.
// NOTE(review): SEAL debug builds appear to reject a shift >= n inside
// negacyclic_shift_poly_coeffmod, while release builds accept index < 2n.
// Confirm against the SEAL version in use.
//
// @param encrypted    ciphertext to multiply; may be the same object as
//                     `destination`.
// @param destination  receives encrypted * x^index.
// @param index        exponent of x.
void SealPirServer::multiply_power_of_X(const seal::Ciphertext &encrypted,
                                        seal::Ciphertext &destination,
                                        uint32_t index) {
  // negacyclic_shift_poly_coeffmod must not shift a buffer in place.
  if (&destination == &encrypted) {
    seal::Ciphertext source = encrypted;
    multiply_power_of_X(source, destination, index);
    return;
  }

  // Take the RNS size from the ciphertext itself. coeff_modulus().size() - 1
  // only matches a ciphertext at the first data level with key switching
  // enabled: it over-runs a mod-switched ciphertext and skips the shift
  // entirely when the modulus has a single prime. SEAL drops primes from the
  // end of the chain, so the first coeff_mod_count entries of coeff_modulus()
  // are exactly the ciphertext's primes.
  const size_t coeff_mod_count = encrypted.coeff_modulus_size();
  const size_t coeff_count = encrypted.poly_modulus_degree();
  const auto &coeff_modulus = seal_params_->coeff_modulus();

  destination = encrypted;
  for (size_t i = 0; i < encrypted.size(); i++) {
    for (size_t j = 0; j < coeff_mod_count; j++) {
      seal::util::negacyclic_shift_poly_coeffmod(
          encrypted.data(i) + (j * coeff_count), coeff_count, index,
          coeff_modulus[j], destination.data(i) + (j * coeff_count));
    }
  }
}

主要原因是:multiply_plain 会大量消耗 SEAL 密文的噪声预算(noise budget),而 negacyclic_shift_poly_coeffmod 只是对密文多项式的系数做负循环移位,不会增加噪声;并且在乘以单项式 x^k 时,该函数的计算速度也更快。

@Yinbenxin Yinbenxin added the bug Something isn't working label Apr 18, 2024
@Yinbenxin
Copy link
Author

可以用以下例子来说明这个问题:

#include <iostream>
#include "seal/seal.h"
#include "seal/util/polyarithsmallmod.h"
using namespace std;
using namespace seal;
using namespace seal::util;
inline void multiply_power_of_X(const Ciphertext &encrypted,EncryptionParameters enc_params_,
                                           Ciphertext &destination,
                                           uint32_t index) {

    auto coeff_mod_count = enc_params_.coeff_modulus().size() - 1;
    auto coeff_count = enc_params_.poly_modulus_degree();
    auto encrypted_count = encrypted.size();

    destination = encrypted;

    for (int i = 0; i < encrypted_count; i++) {
        for (int j = 0; j < coeff_mod_count; j++) {
            negacyclic_shift_poly_coeffmod(encrypted.data(i) + (j * coeff_count),
                                           coeff_count, index,
                                           enc_params_.coeff_modulus()[j],
                                           destination.data(i) + (j * coeff_count));
        }
    }
}
// Demo: repeatedly multiplies an encryption of 10 * x by x^(-16), once with
// multiply_power_of_X (coefficient shift) and once with
// Evaluator::multiply_plain, printing the decrypted result and the remaining
// noise budget after each round.
int main() {
    // Initialize SEAL: BFV with N = 4096 and the default coefficient modulus.
    int N = 4096;
    EncryptionParameters parms(scheme_type::bfv);
    parms.set_poly_modulus_degree(N);
    parms.set_coeff_modulus(CoeffModulus::BFVDefault(N));
    parms.set_plain_modulus(PlainModulus::Batching(N, 20));
    auto context = SEALContext(parms);
    uint64_t plain_mod = parms.plain_modulus().value();

    // Generate keys.
    seal::PublicKey public_key;
    seal::SecretKey secret_key;
    KeyGenerator keygen(context);
    keygen.create_public_key(public_key);
    secret_key = keygen.secret_key();

    Encryptor encryptor(context, public_key);

    // Encrypt the polynomial 10 * x.
    Plaintext plain_coefficients(N);
    plain_coefficients.set_zero();
    plain_coefficients[1] = 10;
    Ciphertext ciphertext;
    encryptor.encrypt(plain_coefficients, ciphertext);

    // Both paths multiply by x^(-16) in Z_t[x]/(x^N + 1):
    //  - plain_power = -x^(N - 16), because x^(N - 16) = -x^(-16);
    //  - index_raw   = 2N - 16,     because x^(2N) = 1.
    const int step = 1 << 4;
    Plaintext plain_power(N);
    plain_power.set_zero();
    plain_power[N - step] = plain_mod - 1;
    const int index_raw = (N << 1) - step;

    Evaluator evaluator(context);
    Decryptor decryptor(context, secret_key);
    Ciphertext mpfx = ciphertext;  // advanced with multiply_power_of_X
    Ciphertext mp = ciphertext;    // advanced with multiply_plain

    for (int i = 0; i < 4; ++i) {
        Ciphertext mpfx_result;
        Ciphertext mp_result;
        Plaintext mpfx_plaint;
        Plaintext mp_plaint;
        multiply_power_of_X(mpfx, parms, mpfx_result, index_raw);
        evaluator.multiply_plain(mp, plain_power, mp_result);
        decryptor.decrypt(mpfx_result, mpfx_plaint);
        decryptor.decrypt(mp_result, mp_plaint);

        cout << "multiply_power_of_X result: " << mpfx_plaint.to_string().substr(0,50) << '\n';
        cout << "multiply_plain result: " << mp_plaint.to_string().substr(0,50) << '\n';
        cout << "multiply_power_of_X 剩余可用噪音: " << decryptor.invariant_noise_budget(mpfx_result) << '\n';
        cout << "multiply_plain 剩余可用噪音: " << decryptor.invariant_noise_budget(mp_result) << '\n';

        mpfx = mpfx_result;
        mp = mp_result;
    }
    return 0;
}

multiply_power_of_X result: FBFF7x^4081
multiply_plain result: FBFF7x^4081
multiply_power_of_X 剩余可用噪音: 45
multiply_plain 剩余可用噪音: 25
multiply_power_of_X result: FBFF7x^4065
multiply_plain result: FBFF7x^4065
multiply_power_of_X 剩余可用噪音: 45
multiply_plain 剩余可用噪音: 5
multiply_power_of_X result: FBFF7x^4049
multiply_plain result: FAD3Ax^4095 + 1E1x^4094 + F0x^4093 + FBF11x^4092 +
multiply_power_of_X 剩余可用噪音: 45
multiply_plain 剩余可用噪音: 0

multiply_power_of_X result: FBFF7x^4033
multiply_plain result: 5DCD2x^4095 + 4065Ex^4094 + 4065Ex^4093 + 9E32Fx^4
multiply_power_of_X 剩余可用噪音: 45
multiply_plain 剩余可用噪音: 0

可以看到,使用 multiply_plain 时剩余噪声预算迅速下降(25 → 5 → 0),预算耗尽后解密结果出错;而使用 multiply_power_of_X 时噪声预算始终保持为 45。

@Yinbenxin
Copy link
Author

之前只支持 8192,是因为在查询的数据总量较小时噪声预算尚未耗尽;在数据量较大时,就会出现因噪声预算不足而导致的计算错误。

@Jamie-Cui
Copy link
Collaborator

@qxzhou1010 Would you mind to take a look at this?

@qxzhou1010
Copy link
Contributor

@Yinbenxin 非常感谢您提出这个 issue,并给出了优化的实现。这里我们想在密文下计算 $c_1 = c_0 \cdot x^{-2^j}$。由于 BFV 的多项式环是负循环环 $\mathbb{Z}_q[x]/(x^N+1)$,乘以单项式 $x^k$ 本质上就是对 $c_0$ 的系数做负循环移位。因此可以使用 negacyclic_shift_poly_coeffmod 来加速这一运算;而且这个过程不消耗噪声预算,因为它只是对密文多项式的系数做简单的移位(和取反),并不会增大密文中所含的噪声。

multiply_plain 涉及密文与明文相乘,结果密文中的噪声项会被放大,因此每一次操作都会消耗噪声预算。

实际上,SealPIR 官方仓库采用的正是这种实现,可参考:https://github.com/microsoft/SealPIR/blob/ee1a5a3922fc9250f9bb4e2416ff5d02bfef7e52/src/pir_server.cpp#L415。

我们后续将会对这个点的实现进行优化,再次感谢您提出的问题和进行的验证。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

No branches or pull requests

3 participants