From 8caf0347ab6cc0bd4b00ef56252ba4e489d6a05f Mon Sep 17 00:00:00 2001
From: huocun <131865681+huocun-ant@users.noreply.github.com>
Date: Wed, 11 Dec 2024 21:02:25 +0800
Subject: [PATCH] repo sync (#215)

---
 README.md                                  |  1 +
 benchmark/docker-compose/.env              |  3 +-
 benchmark/stats.py                         |  3 +-
 docker/entry.sh                            |  2 +-
 experiment/pir/pps/BUILD.bazel             |  1 +
 experiment/pir/pps/client.cc               | 41 ++++++++++++++--------
 experiment/pir/pps/pps_pir_benchmark.cc    | 22 ++++++------
 experiment/pir/pps/sender.cc               |  2 +-
 psi/apsi_wrapper/api/receiver_c_wrapper.cc |  2 +-
 psi/rr22/rr22_psi.h                        | 13 ++++---
 psi/sealpir/seal_pir.cc                    |  9 ++---
 psi/utils/ub_psi_cache.h                   | 11 ++++--
 12 files changed, 69 insertions(+), 41 deletions(-)

diff --git a/README.md b/README.md
index f70d7b7..b2cc375 100644
--- a/README.md
+++ b/README.md
@@ -205,6 +205,7 @@ chmod +x traceconv
 ```
 4. Open chrome://tracing in your chrome and load JSON file.
 
+
 ## PSI V2 Benchamrk
 
 Please refer to [PSI V2 Benchmark](docs/user_guide/psi_v2_benchmark.md)
diff --git a/benchmark/docker-compose/.env b/benchmark/docker-compose/.env
index c9edf13..d72d4a3 100644
--- a/benchmark/docker-compose/.env
+++ b/benchmark/docker-compose/.env
@@ -1,6 +1,5 @@
-# OPENSOURCE-CLEANUP GSUB psi:latest secretflow/psi:latest
 # docker env
-IMAGE_WITH_TAG=secretflow/psi-anolis8:0.4.2b0
+IMAGE_WITH_TAG=secretflow/psi:latest
 
 # network env
 # LATENCY=10ms
diff --git a/benchmark/stats.py b/benchmark/stats.py
index 8f4bc16..84b1040 100644
--- a/benchmark/stats.py
+++ b/benchmark/stats.py
@@ -16,6 +16,7 @@
 import json
 import csv
 import sys
+import os
 import time
 from datetime import datetime
 
@@ -40,7 +41,7 @@ def stream_container_stats(container_name, output_file):
                 data = json.loads(stats)
                 running_time_s = int(time.time()) - start_unix_time
                 cpu_percent = ((data['cpu_stats']['cpu_usage']['total_usage'] - prev_cpu_total) /
-                           (data['cpu_stats']['system_cpu_usage'] - prev_cpu_system)) * 100
+                           (data['cpu_stats']['system_cpu_usage'] - prev_cpu_system)) * 100 * os.cpu_count()
                 mem_usage = (data['memory_stats']['usage'] - data['memory_stats']['stats']['inactive_file']) / 1024 / 1024
                 mem_limit = data['memory_stats']['limit'] / 1024 / 1024
                 net_tx = 0
diff --git a/docker/entry.sh b/docker/entry.sh
index e53ae0a..115a42b 100755
--- a/docker/entry.sh
+++ b/docker/entry.sh
@@ -7,7 +7,7 @@ cd src_copied
 
 conda install -y perl=5.20.3.1
 
-bazel build psi:main -c opt --config=linux-release --repository_cache=/tmp/bazel_repo_cache
+bazel build psi:main -c opt --config=linux-release --remote_timeout=300s --remote_retries=10
 chmod 777 bazel-bin/psi/main
 mkdir -p ../src/docker/linux/amd64
 cp bazel-bin/psi/main ../src/docker/linux/amd64
diff --git a/experiment/pir/pps/BUILD.bazel b/experiment/pir/pps/BUILD.bazel
index a221070..e8738ae 100644
--- a/experiment/pir/pps/BUILD.bazel
+++ b/experiment/pir/pps/BUILD.bazel
@@ -34,6 +34,7 @@ psi_cc_library(
     deps = [
         ":ggm_pset",
         "@yacl//yacl/base:dynamic_bitset",
+        "@yacl//yacl/base:exception",
         "@yacl//yacl/crypto/rand",
         "@yacl//yacl/crypto/tools:prg",
     ],
diff --git a/experiment/pir/pps/client.cc b/experiment/pir/pps/client.cc
index b9c34b8..cbc0f71 100644
--- a/experiment/pir/pps/client.cc
+++ b/experiment/pir/pps/client.cc
@@ -16,6 +16,8 @@
 
 #include <spdlog/spdlog.h>
 
+#include "yacl/base/exception.h"
+
 namespace pir::pps {
 
 bool PpsPirClient::Bernoulli() {
@@ -34,17 +36,22 @@ uint64_t PpsPirClient::GetRandomU64Less() {
 // Generate sk and m random numbers \in [n]
 void PpsPirClient::Setup(PIRKey& sk, std::set<uint64_t>& deltas) {
   sk = pps_.Gen(lambda_);
+
+  size_t max_try_count = 10 * M();
+  size_t count = 0;
+
   // The map.size() must be equal to SET_SIZE.
-  std::vector<uint64_t> rand =
-      yacl::crypto::PrgAesCtr<uint64_t>(yacl::crypto::RandU64(), M());
-  for (uint64_t i = 0; i < M(); i++) {
-    // The most expensive operation.
-    uint64_t r = LemireTrick(rand[i], universe_size_);
+  size_t i = 0;
+  while (i < M() && count < max_try_count) {
+    count += 1;
+    uint64_t r = LemireTrick(yacl::crypto::RandU64(), universe_size_);
     if (!deltas.insert(r).second) {
-      rand[i] = yacl::crypto::RandU64();
-      i--;
+      continue;
     }
+    ++i;
   }
+
+  YACL_ENFORCE(count < max_try_count);
 }
 
 // Params:
@@ -91,18 +98,24 @@ void PpsPirClient::Setup(std::vector<PIRKeyUnion>& ck,
                          std::vector<std::unordered_set<uint64_t>>& v) {
   ck.resize(MM());
   v.resize(MM());
-  std::vector<uint128_t> rand =
-      yacl::crypto::PrgAesCtr<uint128_t>(yacl::crypto::RandU128(), MM());
-  for (uint64_t i = 0; i < MM(); ++i) {
-    pps_.Eval(rand[i], v[i]);
+
+  size_t max_try_count = 10 * MM();
+  size_t count = 0;
+
+  size_t i = 0;
+  while (i < MM() && count < max_try_count) {
+    count += 1;
+    auto rand = yacl::crypto::RandU128();
+    pps_.Eval(rand, v[i]);
     if (v[i].size() == set_size_) {
-      ck[i] = PIRKeyUnion(rand[i]);
+      ck[i] = PIRKeyUnion(rand);
     } else {
       v[i].clear();
-      rand[i] = yacl::crypto::RandU128();
-      --i;
+      continue;
     }
+    ++i;
   }
+  YACL_ENFORCE(count < max_try_count);
 }
 
 void PpsPirClient::Query(uint64_t i, std::vector<PIRKeyUnion>& ck,
diff --git a/experiment/pir/pps/pps_pir_benchmark.cc b/experiment/pir/pps/pps_pir_benchmark.cc
index 016f153..c9daed9 100644
--- a/experiment/pir/pps/pps_pir_benchmark.cc
+++ b/experiment/pir/pps/pps_pir_benchmark.cc
@@ -46,15 +46,15 @@ static void BM_PpsSingleBitPir(benchmark::State& state) {
     pir::pps::PpsPirServer pirOfflineServer(n * n, n);
     pir::pps::PpsPirServer pirOnlineServer(n * n, n);
 
-    pir::pps::PIRKey pirKey, pirKeyOffline;
-    pir::pps::PIRQueryParam pirQueryParam;
-    pir::pps::PIRPuncKey pirPuncKey, pirPuncKeyOnline;
-    std::set<uint64_t> deltas, deltasOffline;
+    pir::pps::PIRKey pirKey{}, pirKeyOffline{};
+    pir::pps::PIRQueryParam pirQueryParam{};
+    pir::pps::PIRPuncKey pirPuncKey{}, pirPuncKeyOnline{};
+    std::set<uint64_t> deltas{}, deltasOffline{};
     yacl::dynamic_bitset<> bits;
     GenerateRandomBitString(bits, n * n);
     yacl::dynamic_bitset<> h, hOffline;
     uint64_t query_index = pirClient.GetRandomU64Less();
-    bool query_result;
+    bool query_result{};
 
     constexpr int kWorldSize = 2;
     const auto contextsOffline = yacl::link::test::SetupWorld(kWorldSize);
@@ -102,7 +102,7 @@ static void BM_PpsSingleBitPir(benchmark::State& state) {
     recver_future.get();
 
     bool a = pirOnlineServer.Answer(pirPuncKeyOnline, bits);
-    bool aClient;
+    bool aClient{};
 
     sender_future =
         std::async(std::launch::async, pir::pps::OnlineServerSendToClient,
@@ -129,13 +129,13 @@ static void BM_PpsMultiBitsPir(benchmark::State& state) {
     pir::pps::PpsPirServer pirOfflineServer(n * n, n);
     pir::pps::PpsPirServer pirOnlineServer(n * n, n);
 
-    std::vector<pir::pps::PIRKeyUnion> pirKey, pirKeyOffline;
+    std::vector<pir::pps::PIRKeyUnion> pirKey{}, pirKeyOffline{};
     yacl::dynamic_bitset<> bits;
     GenerateRandomBitString(bits, n * n);
     yacl::dynamic_bitset<> h, hOffline;
-    pir::pps::PIRQueryParam pirParam;
+    pir::pps::PIRQueryParam pirParam{};
 
-    bool aLeft, aRight, aLeftOnline, aRightOnline, queryResult;
+    bool aLeft{}, aRight{}, aLeftOnline{}, aRightOnline{}, queryResult{};
     std::vector<std::unordered_set<uint64_t>> v;
 
     constexpr int kWorldSize = 2;
@@ -170,8 +170,8 @@ static void BM_PpsMultiBitsPir(benchmark::State& state) {
     recver_future.get();
 
     for (uint i = 0; i < n * n; ++i) {
-      pir::pps::PIRPuncKey pirPuncKeyL, pirPuncKeyR;
-      pir::pps::PIRPuncKey pirPuncKeyLOnline, pirPuncKeyROnline;
+      pir::pps::PIRPuncKey pirPuncKeyL{}, pirPuncKeyR{};
+      pir::pps::PIRPuncKey pirPuncKeyLOnline{}, pirPuncKeyROnline{};
 
       pirClient.Query(i, pirKey, v, pirParam, pirPuncKeyL, pirPuncKeyR);
 
diff --git a/experiment/pir/pps/sender.cc b/experiment/pir/pps/sender.cc
index 6bd7341..66a9413 100644
--- a/experiment/pir/pps/sender.cc
+++ b/experiment/pir/pps/sender.cc
@@ -22,7 +22,7 @@
 
 namespace pir::pps {
 std::array<std::byte, 16> Uint128_to_bytes(PIRKey sk) {
-  std::array<std::byte, 16> bytes;
+  std::array<std::byte, 16> bytes{};
   uint64_t high = static_cast<uint64_t>(sk >> 64);
   uint64_t low = static_cast<uint64_t>(sk & 0xFFFFFFFFFFFFFFFF);
   std::memcpy(bytes.data(), &high, sizeof(high));
diff --git a/psi/apsi_wrapper/api/receiver_c_wrapper.cc b/psi/apsi_wrapper/api/receiver_c_wrapper.cc
index 5d7e709..358adfe 100644
--- a/psi/apsi_wrapper/api/receiver_c_wrapper.cc
+++ b/psi/apsi_wrapper/api/receiver_c_wrapper.cc
@@ -36,7 +36,7 @@ Receiver* BucketReceiverMake(size_t bucket_cnt, size_t thread_count) {
 }
 
 void BucketReceiverFree(Receiver** receiver) {
-  if (receiver != nullptr || *receiver == nullptr) {
+  if (receiver == nullptr || *receiver == nullptr) {
     return;
   }
   (void)std::unique_ptr<ApiReceiver>(reinterpret_cast<ApiReceiver*>(*receiver));
diff --git a/psi/rr22/rr22_psi.h b/psi/rr22/rr22_psi.h
index c538472..6375bac 100644
--- a/psi/rr22/rr22_psi.h
+++ b/psi/rr22/rr22_psi.h
@@ -257,13 +257,18 @@ class Rr22Runner {
       futures[i] = std::async(
           std::launch::async,
           [&](size_t thread_idx) {
+            std::shared_ptr<yacl::link::Context> spawn_read_lctx =
+                read_lctx_->Spawn(std::to_string(thread_idx));
+            std::shared_ptr<yacl::link::Context> spawn_run_lctx =
+                run_lctx_->Spawn(std::to_string(thread_idx));
+            std::shared_ptr<yacl::link::Context> spawn_intersection_lctx =
+                intersection_lctx_->Spawn(std::to_string(thread_idx));
             for (size_t j = 0; j < bucket_num_; j++) {
               if (j % parallel_num == thread_idx) {
                 auto runner = CreateBucketRunner(j, is_sender);
-                runner->Prepare(read_lctx_->Spawn(std::to_string(thread_idx)));
-                runner->RunOprf(run_lctx_->Spawn(std::to_string(thread_idx)));
-                runner->GetIntersection(
-                    intersection_lctx_->Spawn(std::to_string(thread_idx)));
+                runner->Prepare(spawn_read_lctx);
+                runner->RunOprf(spawn_run_lctx);
+                runner->GetIntersection(spawn_intersection_lctx);
               }
             }
           },
diff --git a/psi/sealpir/seal_pir.cc b/psi/sealpir/seal_pir.cc
index 69b9ee1..3c6a062 100644
--- a/psi/sealpir/seal_pir.cc
+++ b/psi/sealpir/seal_pir.cc
@@ -52,6 +52,7 @@ uint32_t ComputeExpansionRatio(seal::EncryptionParameters params) {
     double logqi = log2(params.coeff_modulus()[i].value());
     expansion_ratio += ceil(logqi / logt);
   }
+  YACL_ENFORCE(expansion_ratio > 0, "expansion_ratio must be greater than 0");
   return expansion_ratio;
 }
 uint64_t CoefficientsPerElement(uint32_t logt, uint64_t ele_size) {
@@ -169,7 +170,7 @@ vector<seal::Plaintext> DecomposeToPlaintexts(seal::EncryptionParameters params,
   const auto N = params.poly_modulus_degree();
   const auto coeff_mod_count = params.coeff_modulus().size();
   const uint32_t logt = log2(params.plain_modulus().value());
-  const uint64_t pt_bitmask = (1 << logt) - 1;
+  const uint64_t pt_bitmask = (1ULL << logt) - 1;
 
   vector<seal::Plaintext> result(ComputeExpansionRatio(params) * ct.size());
   auto pt_iter = result.begin();
@@ -750,7 +751,7 @@ inline vector<Ciphertext> SealPirServer::ExpandQuery(
 
   for (uint32_t i = 0; i < logm - 1; ++i) {
     vector<Ciphertext> new_tmp(tmp.size() << 1);
-    int index_raw = (N << 1) - (1 << i);
+    int index_raw = (N << 1) - (1ULL << i);
     int index = (index_raw + N) % (N << 1);
     // int index = (index_raw * galelts[i]) % (N << 1);
 
@@ -768,13 +769,13 @@ inline vector<Ciphertext> SealPirServer::ExpandQuery(
   }
 
   vector<Ciphertext> new_tmp(tmp.size() << 1);
-  int index_raw = (N << 1) - (1 << (logm - 1));
+  int index_raw = (N << 1) - (1ULL << (logm - 1));
   int index = (index_raw + N) % (N << 1);
   // int index = (index_raw * galelts[logm - 1]) % (N << 1);
   Plaintext two("2");
 
   for (uint32_t j = 0; j < tmp.size(); ++j) {
-    if (j < (m - (1 << (logm - 1)))) {
+    if (j < (m - (1ULL << (logm - 1)))) {
       evaluator_->apply_galois(tmp[j], galelts[logm - 1], galkey,
                                tmpctxt_rotated);
       evaluator_->add(tmp[j], tmpctxt_rotated, new_tmp[j]);
diff --git a/psi/utils/ub_psi_cache.h b/psi/utils/ub_psi_cache.h
index 28b5f34..e85b92a 100644
--- a/psi/utils/ub_psi_cache.h
+++ b/psi/utils/ub_psi_cache.h
@@ -21,6 +21,7 @@
 #include <utility>
 #include <vector>
 
+#include "spdlog/spdlog.h"
 #include "yacl/base/byte_container_view.h"
 
 #include "psi/utils/batch_provider.h"
@@ -90,8 +91,14 @@ class UbPsiCache : public IUbPsiCache {
              std::vector<uint8_t> private_key);
 
   ~UbPsiCache() {
-    Flush();
-    out_stream_->Close();
+    try {
+      Flush();
+      if (out_stream_) {
+        out_stream_->Close();
+      }
+    } catch (const std::exception& e) {
+      SPDLOG_ERROR("UbPsiCache flush failed: {}", e.what());
+    }
   }
 
   void SaveData(yacl::ByteContainerView item, size_t index,