From 8caf0347ab6cc0bd4b00ef56252ba4e489d6a05f Mon Sep 17 00:00:00 2001 From: huocun <131865681+huocun-ant@users.noreply.github.com> Date: Wed, 11 Dec 2024 21:02:25 +0800 Subject: [PATCH] repo sync (#215) --- README.md | 1 + benchmark/docker-compose/.env | 3 +- benchmark/stats.py | 3 +- docker/entry.sh | 2 +- experiment/pir/pps/BUILD.bazel | 1 + experiment/pir/pps/client.cc | 41 ++++++++++++++-------- experiment/pir/pps/pps_pir_benchmark.cc | 22 ++++++------ experiment/pir/pps/sender.cc | 2 +- psi/apsi_wrapper/api/receiver_c_wrapper.cc | 2 +- psi/rr22/rr22_psi.h | 13 ++++--- psi/sealpir/seal_pir.cc | 9 ++--- psi/utils/ub_psi_cache.h | 11 ++++-- 12 files changed, 69 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index f70d7b7..b2cc375 100644 --- a/README.md +++ b/README.md @@ -205,6 +205,7 @@ chmod +x traceconv ``` 4. Open chrome://tracing in your chrome and load JSON file. + ## PSI V2 Benchamrk Please refer to [PSI V2 Benchmark](docs/user_guide/psi_v2_benchmark.md) diff --git a/benchmark/docker-compose/.env b/benchmark/docker-compose/.env index c9edf13..d72d4a3 100644 --- a/benchmark/docker-compose/.env +++ b/benchmark/docker-compose/.env @@ -1,6 +1,5 @@ -# OPENSOURCE-CLEANUP GSUB psi:latest secretflow/psi:latest # docker env -IMAGE_WITH_TAG=secretflow/psi-anolis8:0.4.2b0 +IMAGE_WITH_TAG=secretflow/psi:latest # network env # LATENCY=10ms diff --git a/benchmark/stats.py b/benchmark/stats.py index 8f4bc16..84b1040 100644 --- a/benchmark/stats.py +++ b/benchmark/stats.py @@ -16,6 +16,7 @@ import json import csv import sys +import os import time from datetime import datetime @@ -40,7 +41,7 @@ def stream_container_stats(container_name, output_file): data = json.loads(stats) running_time_s = int(time.time()) - start_unix_time cpu_percent = ((data['cpu_stats']['cpu_usage']['total_usage'] - prev_cpu_total) / - (data['cpu_stats']['system_cpu_usage'] - prev_cpu_system)) * 100 + (data['cpu_stats']['system_cpu_usage'] - prev_cpu_system)) * 100 * os.cpu_count() mem_usage = (data['memory_stats']['usage'] - data['memory_stats']['stats']['inactive_file']) / 1024 / 1024 mem_limit = data['memory_stats']['limit'] / 1024 / 1024 net_tx = 0 diff --git a/docker/entry.sh b/docker/entry.sh index e53ae0a..115a42b 100755 --- a/docker/entry.sh +++ b/docker/entry.sh @@ -7,7 +7,7 @@ cd src_copied conda install -y perl=5.20.3.1 -bazel build psi:main -c opt --config=linux-release --repository_cache=/tmp/bazel_repo_cache +bazel build psi:main -c opt --config=linux-release --remote_timeout=300s --remote_retries=10 chmod 777 bazel-bin/psi/main mkdir -p ../src/docker/linux/amd64 cp bazel-bin/psi/main ../src/docker/linux/amd64 diff --git a/experiment/pir/pps/BUILD.bazel b/experiment/pir/pps/BUILD.bazel index a221070..e8738ae 100644 --- a/experiment/pir/pps/BUILD.bazel +++ b/experiment/pir/pps/BUILD.bazel @@ -34,6 +34,7 @@ psi_cc_library( deps = [ ":ggm_pset", "@yacl//yacl/base:dynamic_bitset", + "@yacl//yacl/base:exception", "@yacl//yacl/crypto/rand", "@yacl//yacl/crypto/tools:prg", ], diff --git a/experiment/pir/pps/client.cc b/experiment/pir/pps/client.cc index b9c34b8..cbc0f71 100644 --- a/experiment/pir/pps/client.cc +++ b/experiment/pir/pps/client.cc @@ -16,6 +16,8 @@ #include <spdlog/spdlog.h> +#include "yacl/base/exception.h" + namespace pir::pps { bool PpsPirClient::Bernoulli() { @@ -34,17 +36,22 @@ uint64_t PpsPirClient::GetRandomU64Less() { // Generate sk and m random numbers \in [n] void PpsPirClient::Setup(PIRKey& sk, std::set<uint64_t>& deltas) { sk = pps_.Gen(lambda_); + + size_t max_try_count = 10 * M(); + size_t count = 0; + // The map.size() must be equal to SET_SIZE. - std::vector<uint64_t> rand = - yacl::crypto::PrgAesCtr<uint64_t>(yacl::crypto::RandU64(), M()); - for (uint64_t i = 0; i < M(); i++) { - // The most expensive operation. - uint64_t r = LemireTrick(rand[i], universe_size_); + size_t i = 0; + while (i < M() && count < max_try_count) { + count += 1; + uint64_t r = LemireTrick(yacl::crypto::RandU64(), universe_size_); if (!deltas.insert(r).second) { - rand[i] = yacl::crypto::RandU64(); - i--; + continue; } + ++i; } + + YACL_ENFORCE(count < max_try_count); } // Params: @@ -91,18 +98,24 @@ void PpsPirClient::Setup(std::vector<PIRKeyUnion>& ck, std::vector<std::unordered_set<uint64_t>>& v) { ck.resize(MM()); v.resize(MM()); - std::vector<uint128_t> rand = - yacl::crypto::PrgAesCtr<uint128_t>(yacl::crypto::RandU128(), MM()); - for (uint64_t i = 0; i < MM(); ++i) { - pps_.Eval(rand[i], v[i]); + + size_t max_try_count = 10 * MM(); + size_t count = 0; + + size_t i = 0; + while (i < MM() && count < max_try_count) { + count += 1; + auto rand = yacl::crypto::RandU128(); + pps_.Eval(rand, v[i]); if (v[i].size() == set_size_) { - ck[i] = PIRKeyUnion(rand[i]); + ck[i] = PIRKeyUnion(rand); } else { v[i].clear(); - rand[i] = yacl::crypto::RandU128(); - --i; + continue; } + ++i; } + YACL_ENFORCE(count < max_try_count); } void PpsPirClient::Query(uint64_t i, std::vector<PIRKeyUnion>& ck, diff --git a/experiment/pir/pps/pps_pir_benchmark.cc b/experiment/pir/pps/pps_pir_benchmark.cc index 016f153..c9daed9 100644 --- a/experiment/pir/pps/pps_pir_benchmark.cc +++ b/experiment/pir/pps/pps_pir_benchmark.cc @@ -46,15 +46,15 @@ static void BM_PpsSingleBitPir(benchmark::State& state) { pir::pps::PpsPirServer pirOfflineServer(n * n, n); pir::pps::PpsPirServer pirOnlineServer(n * n, n); - pir::pps::PIRKey pirKey, pirKeyOffline; - pir::pps::PIRQueryParam pirQueryParam; - pir::pps::PIRPuncKey pirPuncKey, pirPuncKeyOnline; - std::set<uint64_t> deltas, deltasOffline; + pir::pps::PIRKey pirKey{}, pirKeyOffline{}; + pir::pps::PIRQueryParam pirQueryParam{}; + pir::pps::PIRPuncKey pirPuncKey{}, pirPuncKeyOnline{}; + std::set<uint64_t> deltas{}, deltasOffline{}; yacl::dynamic_bitset<> bits; GenerateRandomBitString(bits, n * n); yacl::dynamic_bitset<> h, hOffline; uint64_t query_index = pirClient.GetRandomU64Less(); - bool query_result; + bool query_result{}; constexpr int kWorldSize = 2; const auto contextsOffline = yacl::link::test::SetupWorld(kWorldSize); @@ -102,7 +102,7 @@ static void BM_PpsSingleBitPir(benchmark::State& state) { recver_future.get(); bool a = pirOnlineServer.Answer(pirPuncKeyOnline, bits); - bool aClient; + bool aClient{}; sender_future = std::async(std::launch::async, pir::pps::OnlineServerSendToClient, @@ -129,13 +129,13 @@ static void BM_PpsMultiBitsPir(benchmark::State& state) { pir::pps::PpsPirServer pirOfflineServer(n * n, n); pir::pps::PpsPirServer pirOnlineServer(n * n, n); - std::vector<pir::pps::PIRKeyUnion> pirKey, pirKeyOffline; + std::vector<pir::pps::PIRKeyUnion> pirKey{}, pirKeyOffline{}; yacl::dynamic_bitset<> bits; GenerateRandomBitString(bits, n * n); yacl::dynamic_bitset<> h, hOffline; - pir::pps::PIRQueryParam pirParam; + pir::pps::PIRQueryParam pirParam{}; - bool aLeft, aRight, aLeftOnline, aRightOnline, queryResult; + bool aLeft{}, aRight{}, aLeftOnline{}, aRightOnline{}, queryResult{}; std::vector<std::unordered_set<uint64_t>> v; constexpr int kWorldSize = 2; @@ -170,8 +170,8 @@ static void BM_PpsMultiBitsPir(benchmark::State& state) { recver_future.get(); for (uint i = 0; i < n * n; ++i) { - pir::pps::PIRPuncKey pirPuncKeyL, pirPuncKeyR; - pir::pps::PIRPuncKey pirPuncKeyLOnline, pirPuncKeyROnline; + pir::pps::PIRPuncKey pirPuncKeyL{}, pirPuncKeyR{}; + pir::pps::PIRPuncKey pirPuncKeyLOnline{}, pirPuncKeyROnline{}; pirClient.Query(i, pirKey, v, pirParam, pirPuncKeyL, pirPuncKeyR); diff --git a/experiment/pir/pps/sender.cc b/experiment/pir/pps/sender.cc index 6bd7341..66a9413 100644 --- a/experiment/pir/pps/sender.cc +++ b/experiment/pir/pps/sender.cc @@ -22,7 +22,7 @@ namespace pir::pps { std::array<std::byte, 16> Uint128_to_bytes(PIRKey sk) { - std::array<std::byte, 16> bytes; + std::array<std::byte, 16> bytes{}; uint64_t high = static_cast<uint64_t>(sk >> 64); uint64_t low = static_cast<uint64_t>(sk & 0xFFFFFFFFFFFFFFFF); std::memcpy(bytes.data(), &high, sizeof(high)); diff --git a/psi/apsi_wrapper/api/receiver_c_wrapper.cc b/psi/apsi_wrapper/api/receiver_c_wrapper.cc index 5d7e709..358adfe 100644 --- a/psi/apsi_wrapper/api/receiver_c_wrapper.cc +++ b/psi/apsi_wrapper/api/receiver_c_wrapper.cc @@ -36,7 +36,7 @@ Receiver* BucketReceiverMake(size_t bucket_cnt, size_t thread_count) { } void BucketReceiverFree(Receiver** receiver) { - if (receiver != nullptr || *receiver == nullptr) { + if (receiver == nullptr || *receiver == nullptr) { return; } (void)std::unique_ptr<ApiReceiver>(reinterpret_cast<ApiReceiver*>(*receiver)); diff --git a/psi/rr22/rr22_psi.h b/psi/rr22/rr22_psi.h index c538472..6375bac 100644 --- a/psi/rr22/rr22_psi.h +++ b/psi/rr22/rr22_psi.h @@ -257,13 +257,18 @@ class Rr22Runner { futures[i] = std::async( std::launch::async, [&](size_t thread_idx) { + std::shared_ptr<yacl::link::Context> spawn_read_lctx = + read_lctx_->Spawn(std::to_string(thread_idx)); + std::shared_ptr<yacl::link::Context> spawn_run_lctx = + run_lctx_->Spawn(std::to_string(thread_idx)); + std::shared_ptr<yacl::link::Context> spawn_intersection_lctx = + intersection_lctx_->Spawn(std::to_string(thread_idx)); for (size_t j = 0; j < bucket_num_; j++) { if (j % parallel_num == thread_idx) { auto runner = CreateBucketRunner(j, is_sender); - runner->Prepare(read_lctx_->Spawn(std::to_string(thread_idx))); - runner->RunOprf(run_lctx_->Spawn(std::to_string(thread_idx))); - runner->GetIntersection( - intersection_lctx_->Spawn(std::to_string(thread_idx))); + runner->Prepare(spawn_read_lctx); + runner->RunOprf(spawn_run_lctx); + runner->GetIntersection(spawn_intersection_lctx); } } }, diff --git a/psi/sealpir/seal_pir.cc b/psi/sealpir/seal_pir.cc index 69b9ee1..3c6a062 100644 --- a/psi/sealpir/seal_pir.cc +++ b/psi/sealpir/seal_pir.cc @@ -52,6 +52,7 @@ uint32_t ComputeExpansionRatio(seal::EncryptionParameters params) { double logqi = log2(params.coeff_modulus()[i].value()); expansion_ratio += ceil(logqi / logt); } + YACL_ENFORCE(expansion_ratio > 0, "expansion_ratio must be greater than 0"); return expansion_ratio; } uint64_t CoefficientsPerElement(uint32_t logt, uint64_t ele_size) { @@ -169,7 +170,7 @@ vector<seal::Plaintext> DecomposeToPlaintexts(seal::EncryptionParameters params, const auto N = params.poly_modulus_degree(); const auto coeff_mod_count = params.coeff_modulus().size(); const uint32_t logt = log2(params.plain_modulus().value()); - const uint64_t pt_bitmask = (1 << logt) - 1; + const uint64_t pt_bitmask = (1ULL << logt) - 1; vector<seal::Plaintext> result(ComputeExpansionRatio(params) * ct.size()); auto pt_iter = result.begin(); @@ -750,7 +751,7 @@ inline vector<Ciphertext> SealPirServer::ExpandQuery( for (uint32_t i = 0; i < logm - 1; ++i) { vector<Ciphertext> new_tmp(tmp.size() << 1); - int index_raw = (N << 1) - (1 << i); + int index_raw = (N << 1) - (1ULL << i); int index = (index_raw + N) % (N << 1); // int index = (index_raw * galelts[i]) % (N << 1); @@ -768,13 +769,13 @@ inline vector<Ciphertext> SealPirServer::ExpandQuery( } vector<Ciphertext> new_tmp(tmp.size() << 1); - int index_raw = (N << 1) - (1 << (logm - 1)); + int index_raw = (N << 1) - (1ULL << (logm - 1)); int index = (index_raw + N) % (N << 1); // int index = (index_raw * galelts[logm - 1]) % (N << 1); Plaintext two("2"); for (uint32_t j = 0; j < tmp.size(); ++j) { - if (j < (m - (1 << (logm - 1)))) { + if (j < (m - (1ULL << (logm - 1)))) { evaluator_->apply_galois(tmp[j], galelts[logm - 1], galkey, tmpctxt_rotated); evaluator_->add(tmp[j], tmpctxt_rotated, new_tmp[j]); diff --git a/psi/utils/ub_psi_cache.h b/psi/utils/ub_psi_cache.h index 28b5f34..e85b92a 100644 --- a/psi/utils/ub_psi_cache.h +++ b/psi/utils/ub_psi_cache.h @@ -21,6 +21,7 @@ #include <utility> #include <vector> +#include "spdlog/spdlog.h" #include "yacl/base/byte_container_view.h" #include "psi/utils/batch_provider.h" @@ -90,8 +91,14 @@ class UbPsiCache : public IUbPsiCache { std::vector<uint8_t> private_key); ~UbPsiCache() { - Flush(); - out_stream_->Close(); + try { + Flush(); + if (out_stream_) { + out_stream_->Close(); + } + } catch (const std::exception& e) { + SPDLOG_ERROR("UbPsiCache flush failed: {}", e.what()); + } } void SaveData(yacl::ByteContainerView item, size_t index,