From b21a3feecee94e561081622a01f2947ae2c0df7b Mon Sep 17 00:00:00 2001 From: cyjseagull Date: Wed, 21 Aug 2024 22:57:48 +0800 Subject: [PATCH 1/5] add changelog and ci (#3) * add changelog * add workflow * use vcpkg as submodule * fix ut --- .github/stale.yml | 17 + .github/workflows/cpp_workflow.yml | 75 ++++ .gitmodules | 3 + ChangeLog.md | 21 ++ README.md | 26 +- cpp/CMakeLists.txt | 2 +- cpp/cmake/InstallBcosUtilities.cmake | 2 +- cpp/cmake/Options.cmake | 4 + cpp/ppc-crypto-core/src/CMakeLists.txt | 3 +- cpp/ppc-gateway/test/unittests/MockCache.h | 104 +++--- cpp/ppc-mpc/CMakeLists.txt | 10 +- cpp/ppc-pir/tests/CMakeLists.txt | 2 +- cpp/ppc-pir/tests/FakeOtPIRFactory.h | 72 ++-- cpp/ppc-pir/tests/TestBaseOT.cpp | 390 +++++++++++---------- cpp/ppc-pir/tests/data/AysPreDataset.csv | 3 + cpp/ppc-psi/tests/CMakeLists.txt | 2 +- cpp/vcpkg | 1 + cpp/vcpkg-configuration.json | 2 +- cpp/vcpkg.json | 9 +- 19 files changed, 450 insertions(+), 298 deletions(-) create mode 100644 .github/stale.yml create mode 100644 .github/workflows/cpp_workflow.yml create mode 100644 .gitmodules create mode 100644 ChangeLog.md create mode 100644 cpp/ppc-pir/tests/data/AysPreDataset.csv create mode 160000 cpp/vcpkg diff --git a/.github/stale.yml b/.github/stale.yml new file mode 100644 index 00000000..e6e32871 --- /dev/null +++ b/.github/stale.yml @@ -0,0 +1,17 @@ +# Number of days of inactivity before an issue becomes stale +daysUntilStale: 120 +# Number of days of inactivity before a stale issue is closed +daysUntilClose: 7 +# Issues with these labels will never be considered stale +exemptLabels: + - pinned + - security +# Label to use when marking an issue as stale +staleLabel: wontfix +# Comment to post when marking an issue as stale. Set to `false` to disable +markComment: > + This issue has been automatically marked as stale because it has not had + recent activity. It will be closed if no further activity occurs. Thank you + for your contributions. +# Comment to post when closing a stale issue. 
Set to `false` to disable
+closeComment: false
diff --git a/.github/workflows/cpp_workflow.yml b/.github/workflows/cpp_workflow.yml
new file mode 100644
index 00000000..bdec90b6
--- /dev/null
+++ b/.github/workflows/cpp_workflow.yml
@@ -0,0 +1,75 @@
+name: WeDPR-Component ci(cpp)
+on:
+  push:
+    paths-ignore:
+      - "docs/**"
+      - "Changelog.md"
+      - "README.md"
+  pull_request:
+    paths-ignore:
+      - "docs/**"
+      - "Changelog.md"
+      - "README.md"
+  release:
+    types: [published, push]
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+  cancel-in-progress: true
+
+
+jobs:
+  build:
+    name: build
+    runs-on: ${{ matrix.os }}
+    continue-on-error: true
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [macos-12, ubuntu-20.04]
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 5
+      - uses: actions-rs/toolchain@v1
+        with:
+          toolchain: nightly-2022-07-28
+          override: true
+      - name: Prepare vcpkg
+        if: runner.os != 'Windows'
+        uses: friendlyanon/setup-vcpkg@v1
+        with: { committish: 51b14cd4e1230dd51c11ffeff6f7d53c61cc5297 }
+      - uses: actions/cache@v2
+        id: deps_cache
+        with:
+          path: |
+            deps/
+            c:/vcpkg
+            !c:/vcpkg/.git
+            !c:/vcpkg/buildtrees
+            !c:/vcpkg/packages
+            !c:/vcpkg/downloads
+          key: build-${{ matrix.os }}-${{ github.base_ref }}-${{ hashFiles('.github/workflows/workflow.yml') }}
+          restore-keys: |
+            build-${{ matrix.os }}-${{ github.base_ref }}-${{ hashFiles('.github/workflows/workflow.yml') }}
+            build-${{ matrix.os }}-${{ github.base_ref }}-
+            build-${{ matrix.os }}-
+      - name: Build for linux
+        if: runner.os == 'Linux'
+        run: |
+          sudo apt install -y lcov ccache wget libgmp-dev python3-dev
+          export GCC='gcc-10'
+          export CXX='g++-10'
+          bash cpp/tools/install_depends.sh -o ubuntu
+          mkdir -p cpp/build && cd cpp/build && cmake -DTESTS=ON -DCOVERAGE=ON -DCMAKE_TOOLCHAIN_FILE=${{ env.VCPKG_ROOT }}/scripts/buildsystems/vcpkg.cmake ../
+          make -j3
+      - name: Build for macos
+        if: runner.os == 'macOS'
+        run: |
+          bash cpp/tools/install_depends.sh -o macos
+          mkdir -p cpp/build && cd cpp/build && cmake -DTESTS=ON -DCOVERAGE=ON -DCMAKE_TOOLCHAIN_FILE=${{ env.VCPKG_ROOT }}/scripts/buildsystems/vcpkg.cmake ../
+          make -j3
+      - name: Test
+        if: runner.os != 'Windows'
+        run: |
+          cd cpp/build && CTEST_OUTPUT_ON_FAILURE=TRUE ctest
+          make cov
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000..b647ee05
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "cpp/vcpkg"]
+	path = cpp/vcpkg
+	url = https://github.com/microsoft/vcpkg
diff --git a/ChangeLog.md b/ChangeLog.md
new file mode 100644
index 00000000..8b7b54d4
--- /dev/null
+++ b/ChangeLog.md
@@ -0,0 +1,21 @@
+### 1.0.0-rc1
+(2024-08-21)
+
+**Added**
+
+#### Privacy-computing components
+
+- **PSI**: multiple private set intersection algorithms, including CM2020, RA2018, ECDH-PSI, ECDH-Multi-PSI, etc.
+- **MPC**: secure multi-party computation components
+- **MPCSQL**: joint analysis and query tasks built on secure multi-party computation protocols
+- **PIR**: private information retrieval (anonymous query) component
+
+#### Privacy-computing interoperability
+
+- ECDH-PSI interoperability with SecretFlow (隐语)
+
+#### Privacy-preserving modeling components
+
+- Multi-party (2+) joint XGB components (training + offline prediction)
+- Multi-party (2+) feature-engineering components (feature binning, WOE/IV calculation, etc.)
+- Preprocessing components
\ No newline at end of file
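A side note on the ECDH-PSI entries above: that protocol family rests on commutative blinding — hashing an item into a group and exponentiating by each party's secret yields the same doubly-blinded value in either order, H(x)^(ab) = H(x)^(ba), which is what lets two parties match blinded sets without exposing raw items. The sketch below demonstrates only that commutativity, substituting modular exponentiation over a Mersenne prime for the elliptic-curve operations; it is a toy for intuition, not the repository's implementation, and is not secure.

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>

// Square-and-multiply modular exponentiation; __int128 avoids overflow.
uint64_t powMod(uint64_t base, uint64_t exp, uint64_t mod) {
    unsigned __int128 result = 1, b = base % mod;
    while (exp > 0) {
        if (exp & 1) result = result * b % mod;
        b = b * b % mod;
        exp >>= 1;
    }
    return (uint64_t)result;
}

int main() {
    const uint64_t prime = 2305843009213693951ULL;  // 2^61 - 1
    // "Hash to group": std::hash stands in for a real hash-to-curve step.
    uint64_t h = std::hash<std::string>{}("item-42") % prime;
    uint64_t a = 123456789, b = 987654321;  // the two parties' secrets
    // Blind in both orders; equal results let the parties compare sets.
    uint64_t ab = powMod(powMod(h, a, prime), b, prime);
    uint64_t ba = powMod(powMod(h, b, prime), a, prime);
    std::cout << (ab == ba ? "commutative blinding holds" : "mismatch") << "\n";
    return 0;
}
```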
diff --git a/README.md b/README.md
index b984af1d..8bdd4c71 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,25 @@
-# WeDPR-Component
\ No newline at end of file
+# WeDPR
+
+![](https://wedpr-lab.readthedocs.io/zh_CN/latest/_static/images/wedpr_logo.png)
+
+
+[![CodeFactor](https://www.codefactor.io/repository/github/webankblockchain/wedpr-component/badge?s=a4c3fb6ffd39e7618378fe13b6bd06c5846cc103)](https://www.codefactor.io/repository/github/webankblockchain/wedpr-component)
+[![contributors](https://img.shields.io/github/contributors/WeBankBlockchain/WeDPR)](https://github.com/WeBankBlockchain/WeDPR-Component/graphs/contributors)
+[![GitHub activity](https://img.shields.io/github/commit-activity/m/WeBankBlockchain/WeDPR-Component)](https://github.com/WeBankBlockchain/WeDPR-Component/pulse)
+[![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)](http://makeapullrequest.com)
+
+Core component library of [WeDPR](https://github.com/WeBankBlockchain/WeDPR), WeBank's multi-party big-data privacy-computing platform, including:
+
+- A rich set of privacy-computing algorithm components covering private set intersection, private information retrieval, multi-party joint analysis (privacy-preserving SQL), data preprocessing, and privacy-preserving modeling and prediction, to serve diverse business scenarios
+- Unified gateway: a stable cross-institution network communication component supporting RIP-based routing and forwarding
+
+
+## Documentation
+
+- [Documentation](https://wedpr-lab.readthedocs.io/zh-cn/latest/)
+- [Code](https://github.com/WeBankBlockchain/WeDPR-Component)
+
+
+## License
+
+WeDPR-Component is open-sourced under the Apache License 2.0; see [LICENSE](LICENSE) for details.
\ No newline at end of file
diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 42aef4a7..f15fcba6 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -55,6 +55,7 @@ include(BuildInfoGenerator)
 
 find_package(OpenSSL REQUIRED)
 include(InstallBcosUtilities)
+find_package(Boost COMPONENTS unit_test_framework)
 if(BUILD_SDK)
     ##### the sdk-dependencies #####
     # find JNI
@@ -89,7 +90,6 @@ if(NOT BUILD_SDK AND NOT BUILD_UDF)
     ##### the full-dependencies #####
     find_package(TBB REQUIRED)
     find_package(jsoncpp REQUIRED)
-    find_package(Boost REQUIRED unit_test_framework)
 
     find_package(${BCOS_BOOSTSSL_TARGET} REQUIRED)
     # tcmalloc
diff --git a/cpp/cmake/InstallBcosUtilities.cmake b/cpp/cmake/InstallBcosUtilities.cmake
index 5fc51ace..83ad5079 100644
--- a/cpp/cmake/InstallBcosUtilities.cmake
+++ b/cpp/cmake/InstallBcosUtilities.cmake
@@ -1,3 +1,3 @@
-find_package(Boost COMPONENTS log filesystem chrono thread serialization iostreams system)
+find_package(Boost COMPONENTS log filesystem chrono thread serialization iostreams system )
 find_package(ZLIB REQUIRED)
 find_package(bcos-utilities REQUIRED)
\ No newline at end of file
diff --git a/cpp/cmake/Options.cmake b/cpp/cmake/Options.cmake
index ab407a53..74a16d18 100644
--- a/cpp/cmake/Options.cmake
+++ b/cpp/cmake/Options.cmake
@@ -93,6 +93,10 @@ elseif(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Windows")
     endif ()
 endif ()
 
+if(ENABLE_SSE)
+    list(APPEND VCPKG_MANIFEST_FEATURES "sse")
+endif()
+
 set(ENABLE_CPU_FEATURES OFF)
 # only ENABLE_CPU_FEATURES for aarch64 and x86
 if ("${ARCHITECTURE}" MATCHES "aarch64")
diff --git a/cpp/ppc-crypto-core/src/CMakeLists.txt b/cpp/ppc-crypto-core/src/CMakeLists.txt
index ed0fa3f2..d40633ab 100644
--- a/cpp/ppc-crypto-core/src/CMakeLists.txt
+++ b/cpp/ppc-crypto-core/src/CMakeLists.txt
@@ -5,4 +5,5 @@ find_package(OpenSSL REQUIRED)
 message(STATUS "OPENSSL_INCLUDE_DIR: ${OPENSSL_INCLUDE_DIR}")
 message(STATUS "OPENSSL_LIBRARIES: ${OPENSSL_LIBRARIES}")
 
-target_link_libraries(${CRYPTO_CORE_TARGET} PUBLIC ${BCOS_UTILITIES_TARGET} OpenSSL::Crypto ${CPU_FEATURES_LIB})
\ No newline at end of file
+find_package(unofficial-sodium REQUIRED)
+target_link_libraries(${CRYPTO_CORE_TARGET} PUBLIC unofficial-sodium::sodium ${BCOS_UTILITIES_TARGET} OpenSSL::Crypto ${CPU_FEATURES_LIB})
\ No newline at end of file
diff --git a/cpp/ppc-gateway/test/unittests/MockCache.h b/cpp/ppc-gateway/test/unittests/MockCache.h
index f052b652..a974faed 100644
--- a/cpp/ppc-gateway/test/unittests/MockCache.h
+++ b/cpp/ppc-gateway/test/unittests/MockCache.h
@@ -21,66 +21,66 @@
 #include "ppc-framework/storage/CacheStorage.h"
 #include <unordered_map>
 
-namespace ppc::mock
-{
-class MockCache : public storage::CacheStorage
-{
+namespace ppc::mock {
+class MockCache : public storage::CacheStorage {
 public:
-    using Ptr = std::shared_ptr<MockCache>;
-    MockCache() = default;
-    ~MockCache() override {}
+  using Ptr = std::shared_ptr<MockCache>;
+  MockCache() = default;
+  ~MockCache() override {}
 
-    /// Note: all these interfaces throw an exception when an error happens
-    /**
-     * @brief: check whether the key exists
-     * @param _key: key
-     * @return whether the key exists
-     */
-    bool exists(const std::string& _key) override { return m_kv.find(_key) != m_kv.end(); }
-
-    /**
-     * @brief: set key value
-     * @param _expirationTime: timeout of key, seconds
-     */
-    void setValue(const std::string& _key, const std::string& _value,
-        int32_t _expirationSeconds = -1) override
-    {
-        m_kv.emplace(_key, _value);
-    }
+  /// Note: all these interfaces throw an exception when an error happens
+  /**
+   * @brief: check whether the key exists
+   * @param _key: key
+   * @return whether the key exists
+   */
+  bool exists(const std::string &_key) override {
+    return m_kv.find(_key) != m_kv.end();
+  }
+
+  /**
+   * @brief: set key value
+   * @param _expirationTime: timeout of key, seconds
+   */
+  void setValue(const std::string &_key, const std::string &_value,
+                int32_t _expirationSeconds = -1) override {
+    m_kv.emplace(_key, _value);
+  }
 
-    /**
-     * @brief: get value by key
-     * @param _key: key
-     * @return value
-     */
-    Optional<std::string> getValue(const std::string& _key) override
-    {
-        auto it = m_kv.find(_key);
-        if (it == m_kv.end())
-        {
-            return std::nullopt;
-        }
-
-        return it->second;
-    }
+  /**
+   * @brief: get value by key
+   * @param _key: key
+   * @return value
+   */
+  std::optional<std::string> getValue(const std::string &_key) override {
+    auto it = m_kv.find(_key);
+    if (it == m_kv.end()) {
+      return std::nullopt;
+    }
+
+    return it->second;
+  }
 
-    /**
-     * @brief: set a timeout on key
-     * @param _expirationTime: timeout of key, ms
-     * @return whether setting is successful
-     */
-    bool expireKey(const std::string& _key, uint32_t _expirationTime) override { return true; }
+  /**
+   * @brief: set a timeout on key
+   * @param _expirationTime: timeout of key, ms
+   * @return whether setting is successful
+   */
+  bool expireKey(const std::string &_key, uint32_t _expirationTime) override {
+    return true;
+  }
 
-    /**
-     * @brief: delete key
-     * @param _key: key
-     * @return the number of keys deleted
-     */
-    uint64_t deleteKey(const std::string& _key) override { return m_kv.erase(_key); }
+  /**
+   * @brief: delete key
+   * @param _key: key
+   * @return the number of keys deleted
+   */
+  uint64_t deleteKey(const std::string &_key) override {
+    return m_kv.erase(_key);
+  }
 
 private:
-    std::unordered_map<std::string, std::string> m_kv;
+  std::unordered_map<std::string, std::string> m_kv;
 };
-}  // namespace ppc::mock
\ No newline at end of file
+} // namespace ppc::mock
\ No newline at end of file
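The MockCache above backs the CacheStorage interface with a plain std::unordered_map, so gateway unit tests can exercise cache-dependent code paths without a running Redis. A minimal usage sketch against the interface shown above (the test name is hypothetical, and the include assumes the repository's test tree is on the include path):

```cpp
#include "MockCache.h"
#include <boost/test/unit_test.hpp>

BOOST_AUTO_TEST_CASE(testMockCacheRoundTrip) {
    ppc::mock::MockCache cache;
    cache.setValue("key", "value");  // the expiration argument is ignored by the mock
    BOOST_CHECK(cache.exists("key"));
    BOOST_CHECK(cache.getValue("key").value() == "value");
    BOOST_CHECK(cache.expireKey("key", 1000));  // the mock always reports success
    BOOST_CHECK(cache.deleteKey("key") == 1);   // returns the number of keys erased
    BOOST_CHECK(!cache.exists("key"));
}
```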
diff --git a/cpp/ppc-mpc/CMakeLists.txt b/cpp/ppc-mpc/CMakeLists.txt
index cdb60d56..53720a81 100644
--- a/cpp/ppc-mpc/CMakeLists.txt
+++ b/cpp/ppc-mpc/CMakeLists.txt
@@ -1,9 +1,9 @@
 project(ppc-mpc VERSION ${VERSION})
 add_subdirectory(src)
 
-if (TESTS)
-    enable_testing()
-    set(CTEST_OUTPUT_ON_FAILURE TRUE)
-    add_subdirectory(tests)
-endif()
+#if (TESTS)
+#    enable_testing()
+#    set(CTEST_OUTPUT_ON_FAILURE TRUE)
+#    add_subdirectory(tests)
+#endif()
 
diff --git a/cpp/ppc-pir/tests/CMakeLists.txt b/cpp/ppc-pir/tests/CMakeLists.txt
index c7255ab0..6a969962 100644
--- a/cpp/ppc-pir/tests/CMakeLists.txt
+++ b/cpp/ppc-pir/tests/CMakeLists.txt
@@ -9,4 +9,4 @@ target_include_directories(${TEST_BINARY_NAME} PRIVATE .)
 
 # target_link_libraries(${TEST_BINARY_NAME} ${PIR_TARGET} ${RPC_TARGET} ${CRYPTO_TARGET} ${BOOST_UNIT_TEST})
 target_link_libraries(${TEST_BINARY_NAME} PUBLIC ${IO_TARGET} ${FRONT_TARGET} ${BCOS_UTILITIES_TARGET} ${TARS_PROTOCOL_TARGET} ${PIR_TARGET} ${RPC_TARGET} ${CRYPTO_TARGET} ${PROTOCOL_TARGET} ${BOOST_UNIT_TEST})
-add_test(NAME test-ays WORKING_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY} COMMAND ${TEST_BINARY_NAME})
\ No newline at end of file
+add_test(NAME test-psi WORKING_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY} COMMAND ${TEST_BINARY_NAME})
\ No newline at end of file
diff --git a/cpp/ppc-pir/tests/FakeOtPIRFactory.h b/cpp/ppc-pir/tests/FakeOtPIRFactory.h
index 7393a0b8..c38a1b71 100644
--- a/cpp/ppc-pir/tests/FakeOtPIRFactory.h
+++ b/cpp/ppc-pir/tests/FakeOtPIRFactory.h
@@ -19,10 +19,10 @@
  */
 #pragma once
+#include "ppc-crypto-core/src/hash/BLAKE2bHash.h"
+#include "ppc-crypto-core/src/hash/Sha512Hash.h"
 #include "ppc-crypto/src/ecc/Ed25519EccCrypto.h"
 #include "ppc-crypto/src/ecc/OpenSSLEccCrypto.h"
-#include "ppc-crypto/src/hash/BLAKE2bHash.h"
-#include "ppc-crypto/src/hash/Sha512Hash.h"
 #include "ppc-framework/crypto/CryptoBox.h"
 #include "ppc-io/src/DataResourceLoaderImpl.h"
 #include "ppc-tools/src/config/PPCConfig.h"
@@ -38,55 +38,51 @@ using namespace ppc::io;
 using namespace ppc::front;
 using namespace ppc::tools;
 
-namespace ppc::test
-{
+namespace ppc::test {
 
-class FakeOtPIRImpl : public OtPIRImpl
-{
+class FakeOtPIRImpl : public OtPIRImpl {
 public:
-    using Ptr = std::shared_ptr<FakeOtPIRImpl>;
-    FakeOtPIRImpl(OtPIRConfig::Ptr const& _config, unsigned _idleTimeMs = 0)
-      : OtPIRImpl(_config, _idleTimeMs)
-    {
-        m_enableOutputExists = true;
-    }
-    ~FakeOtPIRImpl() override = default;
+  using Ptr = std::shared_ptr<FakeOtPIRImpl>;
+  FakeOtPIRImpl(OtPIRConfig::Ptr const &_config, unsigned _idleTimeMs = 0)
+      : OtPIRImpl(_config, _idleTimeMs) {
+    m_enableOutputExists = true;
+  }
+  ~FakeOtPIRImpl() override = default;
 };
 
-class FakeOtPIRFactory : public OtPIRFactory
-{
+class FakeOtPIRFactory : public OtPIRFactory {
 public:
-    using Ptr = std::shared_ptr<FakeOtPIRFactory>;
+  using Ptr = std::shared_ptr<FakeOtPIRFactory>;
 
-    FakeOtPIRFactory()
+  FakeOtPIRFactory()
       : m_front(std::make_shared<FakeFront>()),
         m_dataResourceLoader(std::make_shared<DataResourceLoaderImpl>(
            nullptr, nullptr, nullptr, nullptr, nullptr, nullptr)),
-        m_threadPool(std::make_shared<ThreadPool>("ot-pir", 4))
-    {
-        auto hashImpl = std::make_shared<Sha512Hash>();
-        auto eccCrypto = std::make_shared<OpenSSLEccCrypto>(hashImpl, ppc::protocol::ECCCurve::P256);
-        m_cryptoBox = std::make_shared<CryptoBox>(hashImpl, eccCrypto);
-    }
+        m_threadPool(std::make_shared<ThreadPool>("ot-pir", 4)) {
+    auto hashImpl = std::make_shared<Sha512Hash>();
+    auto eccCrypto = std::make_shared<OpenSSLEccCrypto>(
+        hashImpl, ppc::protocol::ECCCurve::P256);
+    m_cryptoBox = std::make_shared<CryptoBox>(hashImpl, eccCrypto);
+  }
 
-    ~FakeOtPIRFactory() override = default;
+  ~FakeOtPIRFactory() override = default;
 
-    OtPIRImpl::Ptr createOtPIR(std::string const& _selfParty)
-    {
-        auto config = std::make_shared<OtPIRConfig>(
-            _selfParty, m_front, m_cryptoBox, m_threadPool, m_dataResourceLoader, 1);
-
-        return std::make_shared<FakeOtPIRImpl>(config);
-    }
+  OtPIRImpl::Ptr createOtPIR(std::string const &_selfParty) {
+    auto config =
+        std::make_shared<OtPIRConfig>(_selfParty, m_front, m_cryptoBox,
+                                      m_threadPool, m_dataResourceLoader, 1);
+
+    return std::make_shared<FakeOtPIRImpl>(config);
+  }
 
-    DataResourceLoaderImpl::Ptr resourceLoader() { return m_dataResourceLoader; }
-    FakeFront::Ptr front() { return m_front; }
-    CryptoBox::Ptr cryptoBox() { return m_cryptoBox; }
+  DataResourceLoaderImpl::Ptr resourceLoader() { return m_dataResourceLoader; }
+  FakeFront::Ptr front() { return m_front; }
+  CryptoBox::Ptr cryptoBox() { return m_cryptoBox; }
 
 private:
-    FakeFront::Ptr m_front;
-    DataResourceLoaderImpl::Ptr m_dataResourceLoader;
-    ThreadPool::Ptr m_threadPool;
-    CryptoBox::Ptr m_cryptoBox;
+  FakeFront::Ptr m_front;
+  DataResourceLoaderImpl::Ptr m_dataResourceLoader;
+  ThreadPool::Ptr m_threadPool;
+  CryptoBox::Ptr m_cryptoBox;
 };
-}  // namespace ppc::test
\ No newline at end of file
+} // namespace ppc::test
\ No newline at end of file
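FakeOtPIRFactory above exists so the PIR tests can run both protocol roles in a single process: every endpoint it creates shares one FakeFront, which stands in for the real network transport. A sketch of the wiring, using only calls that appear in this patch (the wireUp helper itself is hypothetical):

```cpp
#include "FakeOtPIRFactory.h"

// Build two in-process PIR endpoints that exchange messages through the
// factory's shared FakeFront instead of a real gateway connection.
void wireUp() {
    auto factory = std::make_shared<ppc::test::FakeOtPIRFactory>();
    auto sender = factory->createOtPIR("sender");
    auto receiver = factory->createOtPIR("receiver");
    factory->front()->registerOTPIR("sender", sender);
    factory->front()->registerOTPIR("receiver", receiver);
    // Calling asyncRunTask(...) on each endpoint then drives the OT-PIR
    // protocol end to end, as the tests below demonstrate.
}
```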
diff --git a/cpp/ppc-pir/tests/TestBaseOT.cpp b/cpp/ppc-pir/tests/TestBaseOT.cpp
index 7aa264e7..9f5477fb 100644
--- a/cpp/ppc-pir/tests/TestBaseOT.cpp
+++ b/cpp/ppc-pir/tests/TestBaseOT.cpp
@@ -18,19 +18,18 @@
  * @date 2023-03-13
  */
 #include "FakeOtPIRFactory.h"
+#include "ppc-crypto-core/src/hash/HashFactoryImpl.h"
+#include "ppc-crypto/src/ecc/EccCryptoFactoryImpl.h"
+#include "ppc-crypto/src/ecc/OpenSSLEccCrypto.h"
 #include "ppc-framework/protocol/Protocol.h"
+#include "ppc-pir/src/BaseOT.h"
 #include "ppc-pir/src/Common.h"
 #include "ppc-pir/src/OtPIRImpl.h"
-#include "ppc-pir/src/BaseOT.h"
+#include "test-utils/TaskMock.h"
 #include
 #include
 #include
 #include
-#include "ppc-crypto/src/ecc/OpenSSLEccCrypto.h"
-#include "ppc-crypto/src/ecc/EccCryptoFactoryImpl.h"
-#include "ppc-crypto/src/hash/HashFactoryImpl.h"
-#include "test-utils/TaskMock.h"
-
 using namespace ppc::pir;
 using namespace ppc::crypto;
@@ -38,213 +37,216 @@ using namespace ppc::pir;
 using namespace bcos;
 using namespace bcos::test;
 
-
-namespace ppc::test
-{
+namespace ppc::test {
 BOOST_FIXTURE_TEST_SUITE(OtPIRest, TestPromptFixture)
-
-BOOST_AUTO_TEST_CASE(testBaseOT)
-{
-    // measure the execution time of the function
-    std::cout << "testBaseOT" << std::endl;
-
-    auto eccFactory = std::make_shared<EccCryptoFactoryImpl>();
-    auto hashFactory = std::make_shared<HashFactoryImpl>();
-    auto hash = hashFactory->createHashImpl((int8_t)ppc::protocol::HashImplName::SHA512);
-    EccCrypto::Ptr ecc = eccFactory->createEccCrypto((int8_t)ppc::protocol::ECCCurve::P256, hash);
-    // auto aysService = std::make_shared();
-    // auto aes = std::make_shared(
-    //     OpenSSLAES::AESType::AES128,
-    //     SymCrypto::OperationMode::CBC, _seed, bcos::bytes());
-
-
-    std::string datasetPath = "../../../ppc-pir/tests/data/AysPreDataset.csv";
-    // std::cout<< "aysService->prepareDataset" << std::endl;
-    std::string prefix = "testmsg1";
-    bcos::bytes sendObfuscatedHash(prefix.begin(), prefix.end());
-
-
-    uint32_t obfuscatedOrder = 10;
-    auto baseOT = std::make_shared<BaseOT>(ecc, hash);
-    auto messageKeypair = baseOT->prepareDataset(sendObfuscatedHash, datasetPath);
-    // for(auto iter = messageKeypair.begin(); iter != messageKeypair.end(); ++iter)
-    // {
-    //     std::string pairKey(iter->first.begin(), iter->first.end());
-    //     std::string pairValue(iter->second.begin(), iter->second.end());
-    //     // for(uint32_t i = 0; i < messageKeypair.size(); ++i) {
-    //     //     std::cout<< "pairKey:"<< pairKey << std::endl;
-    //     //     std::cout<< "pairValue:"<< pairValue << std::endl;
-    // }
-    auto start = std::chrono::high_resolution_clock::now();
-
-    std::string choice = "testmsg1100";
-    // std::cout<< "baseOT->senderGenerateCipher" << std::endl;
-    auto senderMessage = baseOT->senderGenerateCipher(bcos::bytes(choice.begin(), choice.end()), obfuscatedOrder);
-    // std::cout<< "baseOT->receiverGenerateMessage" << std::endl;
-    auto receiverMessage = baseOT->receiverGenerateMessage(senderMessage.pointX, senderMessage.pointY, messageKeypair, senderMessage.pointZ);
-
-    // std::cout<< "baseOT->finishSender" << std::endl;
-    auto result = baseOT->finishSender(senderMessage.scalarBlidingB, receiverMessage.pointWList, receiverMessage.encryptMessagePair, receiverMessage.encryptCipher);
-
-    auto end = std::chrono::high_resolution_clock::now();
-
-    auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
-    std::cout << "execution time: " << duration.count() << " microseconds" << std::endl;
-
-    if(result.size() == 0){
-        std::cout<< "final result: message not found" << std::endl;
-    }
-    else {
-        std::cout<< "final result: " << std::string(result.begin(), result.end()) << std::endl;
-    }
-    // for(uint32_t i = 0; i < result.size(); ++i)
-    // {
-    //     std::cout<< std::string(result[i].begin(), result[i].end()) << std::endl;
-    // }
-
-    BOOST_CHECK(true);
+BOOST_AUTO_TEST_CASE(testBaseOT) {
+  // measure the execution time of the function
+  std::cout << "testBaseOT" << std::endl;
+
+  auto eccFactory = std::make_shared<EccCryptoFactoryImpl>();
+  auto hashFactory = std::make_shared<HashFactoryImpl>();
+  auto hash =
+      hashFactory->createHashImpl((int8_t)ppc::protocol::HashImplName::SHA512);
+  EccCrypto::Ptr ecc =
+      eccFactory->createEccCrypto((int8_t)ppc::protocol::ECCCurve::P256, hash);
+  // auto aysService = std::make_shared();
+  // auto aes = std::make_shared(
+  //     OpenSSLAES::AESType::AES128,
+  //     SymCrypto::OperationMode::CBC, _seed, bcos::bytes());
+
+  std::string datasetPath = "../../../ppc-pir/tests/data/AysPreDataset.csv";
+  // std::cout<< "aysService->prepareDataset" << std::endl;
+  std::string prefix = "testmsg1";
+  bcos::bytes sendObfuscatedHash(prefix.begin(), prefix.end());
+
+  uint32_t obfuscatedOrder = 10;
+  auto baseOT = std::make_shared<BaseOT>(ecc, hash);
+  auto messageKeypair = baseOT->prepareDataset(sendObfuscatedHash, datasetPath);
+  // for(auto iter = messageKeypair.begin(); iter != messageKeypair.end();
+  // ++iter)
+  // {
+  //   std::string pairKey(iter->first.begin(), iter->first.end());
+  //   std::string pairValue(iter->second.begin(), iter->second.end());
+  //   // for(uint32_t i = 0; i < messageKeypair.size(); ++i) {
+  //   //   std::cout<< "pairKey:"<< pairKey << std::endl;
+  //   //   std::cout<< "pairValue:"<< pairValue << std::endl;
+  // }
+  auto start = std::chrono::high_resolution_clock::now();
+
+  std::string choice = "testmsg1100";
+  // std::cout<< "baseOT->senderGenerateCipher" << std::endl;
+  auto senderMessage = baseOT->senderGenerateCipher(
+      bcos::bytes(choice.begin(), choice.end()), obfuscatedOrder);
+  // std::cout<< "baseOT->receiverGenerateMessage" << std::endl;
+  auto receiverMessage = baseOT->receiverGenerateMessage(
+      senderMessage.pointX, senderMessage.pointY, messageKeypair,
+      senderMessage.pointZ);
+
+  // std::cout<< "baseOT->finishSender" << std::endl;
+  auto result = baseOT->finishSender(
+      senderMessage.scalarBlidingB, receiverMessage.pointWList,
+      receiverMessage.encryptMessagePair, receiverMessage.encryptCipher);
+
+  auto end = std::chrono::high_resolution_clock::now();
+
+  auto duration =
+      std::chrono::duration_cast<std::chrono::microseconds>(end - start);
+  std::cout << "execution time: " << duration.count() << " microseconds"
+            << std::endl;
+
+  if (result.size() == 0) {
+    std::cout << "final result: message not found" << std::endl;
+  } else {
+    std::cout << "final result: " << std::string(result.begin(), result.end())
+              << std::endl;
+  }
+  // for(uint32_t i = 0; i < result.size(); ++i)
+  // {
+  //   std::cout<< std::string(result[i].begin(), result[i].end()) <<
+  //   std::endl;
+  // }
+
+  BOOST_CHECK(true);
 }
 
 void testOTPIR(FakeOtPIRFactory::Ptr _factory, OtPIRImpl::Ptr _sender,
-    OtPIRImpl::Ptr _receiver, ppc::protocol::Task::ConstPtr _senderPirTask,
-    ppc::protocol::Task::ConstPtr _receiverPirTask,
-    std::vector<std::string> const& _expectedPIRResult, bool _expectedSuccess,
-    int _expectedErrorCode = 0)
-{
-    std::atomic<int> flag = 0;
-
-    _sender->asyncRunTask(_senderPirTask, [_senderPirTask, _expectedSuccess, _expectedErrorCode,
-                                              &flag](ppc::protocol::TaskResult::Ptr&& _response) {
-        if (_expectedSuccess)
-        {
-            BOOST_CHECK(_response->error() == nullptr || _response->error()->errorCode() == 0);
-            BOOST_CHECK(_response->taskID() == _senderPirTask->id());
-            auto result = _response->error();
-            BOOST_CHECK(result == nullptr || result->errorCode() == 0);
-        }
-        else
-        {
-            BOOST_CHECK(_response->error() != nullptr);
-            auto result = _response->error();
-            BOOST_CHECK(result != nullptr);
-            BOOST_CHECK(_response->error()->errorCode() == _expectedErrorCode);
-        }
-        flag++;
-    });
-    _sender->start();
-
-    _receiver->asyncRunTask(_receiverPirTask,
-        [_receiverPirTask, _expectedSuccess, &flag](ppc::protocol::TaskResult::Ptr&& _response) {
-            if (_expectedSuccess)
-            {
-                BOOST_CHECK(_response->error() == nullptr || _response->error()->errorCode() == 0);
-                BOOST_CHECK(_response->taskID() == _receiverPirTask->id());
-                auto result = _response->error();
-                BOOST_CHECK(result == nullptr || result->errorCode() == 0);
-            }
-            else
-            {
-                BOOST_CHECK(_response->error() != nullptr);
-                auto result = _response->error();
-                BOOST_CHECK(result != nullptr);
-            }
-            flag++;
-        });
-    _receiver->start();
-
-    // wait for the task finish and check
-    while (flag < 2)
-    {
-        std::this_thread::sleep_for(std::chrono::milliseconds(100));
-    }
-
-    _sender->stop();
-    _receiver->stop();
-
-    // if (_expectedSuccess && !_expectedPIRResult.empty())
-    // {
-    //     checkTaskPIRResult(_factory->resourceLoader(), _receiverPirTask, _expectedPIRResult.size(),
-    //         _expectedPIRResult);
-    // }
+               OtPIRImpl::Ptr _receiver,
+               ppc::protocol::Task::ConstPtr _senderPirTask,
+               ppc::protocol::Task::ConstPtr _receiverPirTask,
+               std::vector<std::string> const &_expectedPIRResult,
+               bool _expectedSuccess, int _expectedErrorCode = 0) {
+  std::atomic<int> flag = 0;
+
+  _sender->asyncRunTask(
+      _senderPirTask, [_senderPirTask, _expectedSuccess, _expectedErrorCode,
+                       &flag](ppc::protocol::TaskResult::Ptr &&_response) {
+        if (_expectedSuccess) {
+          BOOST_CHECK(_response->error() == nullptr ||
+                      _response->error()->errorCode() == 0);
+          BOOST_CHECK(_response->taskID() == _senderPirTask->id());
+          auto result = _response->error();
+          BOOST_CHECK(result == nullptr || result->errorCode() == 0);
+        } else {
+          BOOST_CHECK(_response->error() != nullptr);
+          auto result = _response->error();
+          BOOST_CHECK(result != nullptr);
+          BOOST_CHECK(_response->error()->errorCode() == _expectedErrorCode);
+        }
+        flag++;
+      });
+  _sender->start();
+
+  _receiver->asyncRunTask(
+      _receiverPirTask, [_receiverPirTask, _expectedSuccess,
+                         &flag](ppc::protocol::TaskResult::Ptr &&_response) {
+        if (_expectedSuccess) {
+          BOOST_CHECK(_response->error() == nullptr ||
+                      _response->error()->errorCode() == 0);
+          BOOST_CHECK(_response->taskID() == _receiverPirTask->id());
+          auto result = _response->error();
+          BOOST_CHECK(result == nullptr || result->errorCode() == 0);
+        } else {
+          BOOST_CHECK(_response->error() != nullptr);
+          auto result = _response->error();
+          BOOST_CHECK(result != nullptr);
+        }
+        flag++;
+      });
+  _receiver->start();
+
+  // wait for the task finish and check
+  while (flag < 2) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(100));
+  }
+
+  _sender->stop();
+  _receiver->stop();
+
+  // if (_expectedSuccess && !_expectedPIRResult.empty())
+  // {
+  //   checkTaskPIRResult(_factory->resourceLoader(), _receiverPirTask,
+  //   _expectedPIRResult.size(),
+  //       _expectedPIRResult);
+  // }
 }
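testOTPIR above coordinates two asynchronous asyncRunTask callbacks with a polled std::atomic counter instead of futures or condition variables: each completion callback increments the counter, and the caller spins until both have fired before stopping the endpoints. A self-contained reduction of that pattern:

```cpp
#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

int main() {
    std::atomic<int> flag = 0;
    // Two "tasks" complete at different times; each callback bumps the counter.
    std::thread sender([&] {
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
        flag++;
    });
    std::thread receiver([&] {
        std::this_thread::sleep_for(std::chrono::milliseconds(80));
        flag++;
    });
    // Wait for both tasks to finish, as the PIR test does.
    while (flag < 2) {
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
    sender.join();
    receiver.join();
    std::cout << "both tasks finished" << std::endl;
    return 0;
}
```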
 
-void testOTPIRImplFunc(const std::string& _taskID, const std::string& _params,
-    bool _syncResults, PartyResource::Ptr _senderParty, PartyResource::Ptr _receiverParty,
-    std::vector<std::string> const& _expectedPIRResult, bool _expectedSuccess,
-    int _expectedErrorCode = 0)
-{
-    auto factory = std::make_shared<FakeOtPIRFactory>();
-
-    // fake the sender
-    std::string senderAgencyName = "sender";
-    auto senderPIR = factory->createOtPIR("sender");
-
-    // fake the receiver
-    std::string receiverAgencyName = _receiverParty->id();
-    auto receiverPIR = factory->createOtPIR(receiverAgencyName);
-
-    // register the server-pir into the front
-    factory->front()->registerOTPIR(senderAgencyName, senderPIR);
-    factory->front()->registerOTPIR(receiverAgencyName, receiverPIR);
-
-    // trigger the pir task
-    auto senderPIRTask = std::make_shared(senderAgencyName);
-    senderPIRTask->setId(_taskID);
-    senderPIRTask->setParam(_params);
-    senderPIRTask->setSelf(_senderParty);
-    senderPIRTask->addParty(_receiverParty);
-    senderPIRTask->setSyncResultToPeer(_syncResults);
-    senderPIRTask->setAlgorithm((uint8_t)PSIAlgorithmType::OT_PIR_2PC);
-    senderPIRTask->setType((uint8_t)ppc::protocol::TaskType::PIR);
-
-    auto receiverPIRTask = std::make_shared(receiverAgencyName);
-    receiverPIRTask->setId(_taskID);
-    receiverPIRTask->setParam(_params);
-    receiverPIRTask->setSelf(_receiverParty);
-    receiverPIRTask->addParty(_senderParty);
-    receiverPIRTask->setSyncResultToPeer(_syncResults);
-    receiverPIRTask->setAlgorithm((uint8_t)PSIAlgorithmType::OT_PIR_2PC);
-    receiverPIRTask->setType((uint8_t)ppc::protocol::TaskType::PIR);
-
-    testOTPIR(factory, senderPIR, receiverPIR, senderPIRTask, receiverPIRTask,
-        _expectedPIRResult, _expectedSuccess, _expectedErrorCode);
+void testOTPIRImplFunc(const std::string &_taskID, const std::string &_params,
+                       bool _syncResults, PartyResource::Ptr _senderParty,
+                       PartyResource::Ptr _receiverParty,
+                       std::vector<std::string> const &_expectedPIRResult,
+                       bool _expectedSuccess, int _expectedErrorCode = 0) {
+  auto factory = std::make_shared<FakeOtPIRFactory>();
+
+  // fake the sender
+  std::string senderAgencyName = "sender";
+  auto senderPIR = factory->createOtPIR("sender");
+
+  // fake the receiver
+  std::string receiverAgencyName = _receiverParty->id();
+  auto receiverPIR = factory->createOtPIR(receiverAgencyName);
+
+  // register the server-pir into the front
+  factory->front()->registerOTPIR(senderAgencyName, senderPIR);
+  factory->front()->registerOTPIR(receiverAgencyName, receiverPIR);
+
+  // trigger the pir task
+  auto senderPIRTask = std::make_shared(senderAgencyName);
+  senderPIRTask->setId(_taskID);
+  senderPIRTask->setParam(_params);
+  senderPIRTask->setSelf(_senderParty);
+  senderPIRTask->addParty(_receiverParty);
+  senderPIRTask->setSyncResultToPeer(_syncResults);
+  senderPIRTask->setAlgorithm((uint8_t)PSIAlgorithmType::OT_PIR_2PC);
+  senderPIRTask->setType((uint8_t)ppc::protocol::TaskType::PIR);
+
+  auto receiverPIRTask = std::make_shared(receiverAgencyName);
+  receiverPIRTask->setId(_taskID);
+  receiverPIRTask->setParam(_params);
+  receiverPIRTask->setSelf(_receiverParty);
+  receiverPIRTask->addParty(_senderParty);
+  receiverPIRTask->setSyncResultToPeer(_syncResults);
+  receiverPIRTask->setAlgorithm((uint8_t)PSIAlgorithmType::OT_PIR_2PC);
+  receiverPIRTask->setType((uint8_t)ppc::protocol::TaskType::PIR);
+
+  testOTPIR(factory, senderPIR, receiverPIR, senderPIRTask, receiverPIRTask,
+            _expectedPIRResult, _expectedSuccess, _expectedErrorCode);
 }
 
-BOOST_AUTO_TEST_CASE(testNormalOtPIRCase)
-{
-    std::string otSearchPath = "../../../ppc-pir/tests/data/AysPreDataset.csv";
-    std::string outputPath = "../../../ppc-pir/tests/data/output.csv";
-
-    uint32_t count = 513;
-    // prepareInputs(senderPath, count, receiverPath, count, count);
-
-    auto senderParty = mockParty((uint16_t)ppc::protocol::PartyType::Client, "sender",
-        "senderPartyResource", "sender_inputs", protocol::DataResourceType::FILE, "");
-    auto senderOutputDesc = std::make_shared<DataResourceDesc>();
-    senderOutputDesc->setPath(outputPath);
-    senderParty->mutableDataResource()->setOutputDesc(senderOutputDesc);
-
-    auto receiverParty = mockParty((uint16_t)ppc::protocol::PartyType::Server, "receiver",
-        "receiverPartyResource", "receiver_inputs", DataResourceType::FILE, otSearchPath);
-
-    // auto receiverOutputDesc = std::make_shared<DataResourceDesc>();
-    // receiverOutputDesc->setPath(outputPath);
-    // receiverParty->mutableDataResource()->setOutputDesc(receiverOutputDesc);
-
-    std::vector<std::string> expectedResult;
-    for (uint32_t i = 0; i < count; i++)
-    {
-        expectedResult.emplace_back(std::to_string(100000 + i));
-    }
-
-    std::string jobParams = "{\"searchId\":\"testmsg1100\",\"requestAgencyId\":\"receiver\",\"prefixLength\":6}";
-
-    testOTPIRImplFunc(
-        "0x12345678", jobParams, true, senderParty, receiverParty, expectedResult, true, 0);
+BOOST_AUTO_TEST_CASE(testNormalOtPIRCase) {
+  std::string otSearchPath = "../../../ppc-pir/tests/data/AysPreDataset.csv";
+  std::string outputPath = "../../../ppc-pir/tests/data/output.csv";
+
+  uint32_t count = 513;
+  // prepareInputs(senderPath, count, receiverPath, count, count);
+
+  auto senderParty = mockParty((uint16_t)ppc::protocol::PartyType::Client,
+                               "sender", "senderPartyResource", "sender_inputs",
+                               protocol::DataResourceType::FILE, "");
+  auto senderOutputDesc = std::make_shared<DataResourceDesc>();
+  senderOutputDesc->setPath(outputPath);
+  senderParty->mutableDataResource()->setOutputDesc(senderOutputDesc);
+
+  auto receiverParty =
+      mockParty((uint16_t)ppc::protocol::PartyType::Server, "receiver",
+                "receiverPartyResource", "receiver_inputs",
+                DataResourceType::FILE, otSearchPath);
+
+  // auto receiverOutputDesc = std::make_shared<DataResourceDesc>();
+  // receiverOutputDesc->setPath(outputPath);
+  // receiverParty->mutableDataResource()->setOutputDesc(receiverOutputDesc);
+
+  std::vector<std::string> expectedResult;
+  for (uint32_t i = 0; i < count; i++) {
+    expectedResult.emplace_back(std::to_string(100000 + i));
+  }
+
+  std::string jobParams = "{\"searchId\":\"testmsg1100\",\"requestAgencyId\":"
+                          "\"receiver\",\"prefixLength\":6}";
+
+  testOTPIRImplFunc("0x12345678", jobParams, true, senderParty, receiverParty,
+                    expectedResult, true, 0);
 }
 
 BOOST_AUTO_TEST_SUITE_END()
-}  // namespace ppc::test
+} // namespace ppc::test
diff --git a/cpp/ppc-pir/tests/data/AysPreDataset.csv b/cpp/ppc-pir/tests/data/AysPreDataset.csv
new file mode 100644
index 00000000..590b61f2
--- /dev/null
+++ b/cpp/ppc-pir/tests/data/AysPreDataset.csv
@@ -0,0 +1,3 @@
+id,x1
+1,test
+2,test2
diff --git a/cpp/ppc-psi/tests/CMakeLists.txt b/cpp/ppc-psi/tests/CMakeLists.txt
index f2fa68c3..26b00337 100644
--- a/cpp/ppc-psi/tests/CMakeLists.txt
+++ b/cpp/ppc-psi/tests/CMakeLists.txt
@@ -6,6 +6,6 @@ set(TEST_BINARY_NAME test-ppc-psi)
 
 add_executable(${TEST_BINARY_NAME} ${SOURCES})
 target_include_directories(${TEST_BINARY_NAME} PRIVATE .)
-target_link_libraries(${TEST_BINARY_NAME} ${ECDH_CONN_PSI_TARGET} ${RA2018_PSI_TARGET} ${LABELED_PSI_TARGET} ${CM2020_PSI_TARGET} ${ECDH_2PC_PSI_TARGET} ${PROTOCOL_TARGET} ${IO_TARGET} ${LABELED_PSI_TARGET} ${CRYPTO_TARGET} ${BOOST_UNIT_TEST}) +target_link_libraries(${TEST_BINARY_NAME} ${BS_ECDH_PSI_TARGET} ${ECDH_CONN_PSI_TARGET} ${RA2018_PSI_TARGET} ${LABELED_PSI_TARGET} ${CM2020_PSI_TARGET} ${ECDH_2PC_PSI_TARGET} ${PROTOCOL_TARGET} ${IO_TARGET} ${LABELED_PSI_TARGET} ${CRYPTO_TARGET} ${BOOST_UNIT_TEST}) target_link_libraries(${TEST_BINARY_NAME} ${RA2018_PSI_TARGET} ${LABELED_PSI_TARGET} ${CM2020_PSI_TARGET} ${ECDH_2PC_PSI_TARGET} ${PROTOCOL_TARGET} ${IO_TARGET} ${LABELED_PSI_TARGET} ${CRYPTO_TARGET} ${BOOST_UNIT_TEST}) add_test(NAME test-ppc-psi WORKING_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY} COMMAND ${TEST_BINARY_NAME}) \ No newline at end of file diff --git a/cpp/vcpkg b/cpp/vcpkg new file mode 160000 index 00000000..51b14cd4 --- /dev/null +++ b/cpp/vcpkg @@ -0,0 +1 @@ +Subproject commit 51b14cd4e1230dd51c11ffeff6f7d53c61cc5297 diff --git a/cpp/vcpkg-configuration.json b/cpp/vcpkg-configuration.json index 71d2a9ae..5913a7c1 100644 --- a/cpp/vcpkg-configuration.json +++ b/cpp/vcpkg-configuration.json @@ -3,7 +3,7 @@ { "kind": "git", "repository": "https://github.com/FISCO-BCOS/registry", - "baseline": "a3508ded2bb7f83d95dd3f7406b05b2500a1fdbe", + "baseline": "070f336149afdac5cc9ace97df01de7ee31aab30", "packages": [ "openssl", "bcos-utilities", diff --git a/cpp/vcpkg.json b/cpp/vcpkg.json index c85989f4..34c765ca 100644 --- a/cpp/vcpkg.json +++ b/cpp/vcpkg.json @@ -65,7 +65,7 @@ }, { "name": "libhdfs3", - "version": "2024-04-27" + "version": "2024-04-27" }, { "name": "tbb", @@ -75,5 +75,10 @@ "name": "tarscpp", "version": "3.0.3-1#1" } - ] + ], + "features": { + "sse": { + "description": "Enable SSE4.2 for libhdfs3" + } + } } \ No newline at end of file From ccec2073842042b8d1cd425f0d0f516f4e00b3fc Mon Sep 17 00:00:00 2001 From: cyjseagull Date: Thu, 22 Aug 2024 01:34:17 +0800 Subject: [PATCH 2/5] fix ci --- .github/workflows/cpp_workflow.yml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/cpp_workflow.yml b/.github/workflows/cpp_workflow.yml index bdec90b6..8d9f76e8 100644 --- a/.github/workflows/cpp_workflow.yml +++ b/.github/workflows/cpp_workflow.yml @@ -72,4 +72,10 @@ jobs: if: runner.os != 'Windows' run: | cd cpp/build && CTEST_OUTPUT_ON_FAILURE=TRUE ctest - make cov \ No newline at end of file + make cov + - uses: webiny/action-post-run@3.1.0 + id: post-run-command + with: + if: runner.os == 'Linux' + run: cat vcpkg/buildtrees/libhdfs3/config-x64-linux-dbg-err.log + working-directory: /home/runner/work/WeDPR-Component/WeDPR-Component/ \ No newline at end of file From e5b27b7c4ba191744912f218eab89aef1c1d8530 Mon Sep 17 00:00:00 2001 From: cyjseagull Date: Thu, 22 Aug 2024 11:18:14 +0800 Subject: [PATCH 3/5] refactor cmake (#5) * refactor cmake * add build sdk ci --- .github/workflows/cpp_workflow.yml | 67 +++++++++++- cpp/CMakeLists.txt | 109 ++++++------------ cpp/cmake/Dependencies.cmake | 51 +++++++++ cpp/cmake/FindGSasl.cmake | 27 ----- cpp/cmake/InstallBcosUtilities.cmake | 2 +- cpp/cmake/Options.cmake | 146 ++++++++++++++----------- cpp/ppc-crypto-core/src/CMakeLists.txt | 4 +- cpp/ppc-crypto/src/CMakeLists.txt | 4 +- cpp/vcpkg.json | 78 +++++++------ 9 files changed, 280 insertions(+), 208 deletions(-) create mode 100644 cpp/cmake/Dependencies.cmake delete mode 100644 cpp/cmake/FindGSasl.cmake diff --git 
a/.github/workflows/cpp_workflow.yml b/.github/workflows/cpp_workflow.yml index 8d9f76e8..5f75d524 100644 --- a/.github/workflows/cpp_workflow.yml +++ b/.github/workflows/cpp_workflow.yml @@ -19,7 +19,7 @@ concurrency: jobs: build: - name: build + name: build all runs-on: ${{ matrix.os }} continue-on-error: true strategy: @@ -73,9 +73,64 @@ jobs: run: | cd cpp/build && CTEST_OUTPUT_ON_FAILURE=TRUE ctest make cov - - uses: webiny/action-post-run@3.1.0 - id: post-run-command + + build_sdk: + name: build sdk + runs-on: ${{ matrix.os }} + continue-on-error: true + strategy: + fail-fast: false + matrix: + os: [ubuntu-22.04, windows-2019, macos-12] + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 5 + - uses: actions-rs/toolchain@v1 with: - if: runner.os == 'Linux' - run: cat vcpkg/buildtrees/libhdfs3/config-x64-linux-dbg-err.log - working-directory: /home/runner/work/WeDPR-Component/WeDPR-Component/ \ No newline at end of file + toolchain: nightly-2022-07-28 + override: true + - name: Prepare vcpkg + if: runner.os != 'Windows' + uses: friendlyanon/setup-vcpkg@v1 + with: { committish: 51b14cd4e1230dd51c11ffeff6f7d53c61cc5297 } + - uses: actions/cache@v2 + id: deps_cache + with: + path: | + deps/ + c:/vcpkg + !c:/vcpkg/.git + !c:/vcpkg/buildtrees + !c:/vcpkg/packages + !c:/vcpkg/downloads + key: build-${{ matrix.os }}-${{ github.base_ref }}-${{ hashFiles('.github/workflows/workflow.yml') }} + restore-keys: | + build-${{ matrix.os }}-${{ github.base_ref }}-${{ hashFiles('.github/workflows/workflow.yml') }} + build-${{ matrix.os }}-${{ github.base_ref }}- + build-${{ matrix.os }}- + - name: Build for windows + if: runner.os == 'Windows' + run: | + mkdir -p build && cd build && cmake -DCMAKE_BUILD_TYPE=Release -DTESTS=OFF -DBUILD_SDK=ON -DVCPKG_TARGET_TRIPLET=x64-windows-static -DCMAKE_TOOLCHAIN_FILE=c:/vcpkg/scripts/buildsystems/vcpkg.cmake ../ + cmake --build . 
--parallel 3 + - name: Build for linux + if: runner.os == 'Linux' + run: | + sudo apt install -y lcov ccache wget libgmp-dev python3-dev + export GCC='gcc-10' + export CXX='g++-10' + bash cpp/tools/install_depends.sh -o ubuntu + mkdir -p cpp/build && cd cpp/build && cmake -DTESTS=ON -DCOVERAGE=ON -DBUILD_SDK=ON -DCMAKE_TOOLCHAIN_FILE=${{ env.VCPKG_ROOT }}/scripts/buildsystems/vcpkg.cmake ../ + make -j3 + - name: Build for macos + if: runner.os == 'macOS' + run: | + bash cpp/tools/install_depends.sh -o macos + mkdir -p cpp/build && cd cpp/build && cmake -DTESTS=ON -DCOVERAGE=ON -DBUILD_SDK=ON -DCMAKE_TOOLCHAIN_FILE=${{ env.VCPKG_ROOT }}/scripts/buildsystems/vcpkg.cmake ../ + make -j3 + - name: Test + if: runner.os != 'Windows' + run: | + cd cpp/build && CTEST_OUTPUT_ON_FAILURE=TRUE ctest + make cov \ No newline at end of file diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index f15fcba6..0fb668fc 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -35,6 +35,13 @@ if(WIN32) set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS "ON") endif() + + +# basic settings +include(Options) +configure_project() + + # vcpkg init if(NOT DEFINED CMAKE_TOOLCHAIN_FILE) find_package(Git REQUIRED) @@ -46,88 +53,40 @@ endif() include(Version) project(WeDPR-Component VERSION ${VERSION}) -# basic settings -include(Options) -configure_project() include(CompilerSettings) include(BuildInfoGenerator) - -find_package(OpenSSL REQUIRED) -include(InstallBcosUtilities) -find_package(Boost COMPONENTS unit_test_framework) -if(BUILD_SDK) - ##### the sdk-dependencies ##### - # find JNI - set(JAVA_AWT_LIBRARY NotNeeded) - set(JAVA_JVM_LIBRARY NotNeeded) - find_package(JNI REQUIRED) - include_directories(${JNI_INCLUDE_DIRS}) -endif() - -# ipp-crypto -#if(ENABLE_IPP_CRYPTO) -# hunter_add_package(ipp-crypto) -#endif() - -if(ENABLE_CPU_FEATURES) - find_package(CpuFeatures REQUIRED) -endif() - include(IncludeDirectories) +# the target settings include(TargetSettings) - -if(BUILD_SDK) - add_subdirectory(ppc-crypto-c-sdk) - set(JNI_SOURCE_PATH ppc-crypto-c-sdk/bindings/java/src/main/c) - add_subdirectory(${JNI_SOURCE_PATH}) +# dependencies +include(Dependencies) + +########### set the sources ########### +set(JNI_SOURCE_PATH ppc-crypto-c-sdk/bindings/java/src/main/c) +set(SDK_SOURCE_LIST ppc-homo ppc-crypto-core ppc-crypto-c-sdk ${JNI_SOURCE_PATH}) +# Note: the udf depends on mysql, not enabled in the full node mode +set(UDF_SOURCE_LIST ${SDK_SOURCE_LIST} ppc-udf) +set(ALL_SOURCE_LIST + ${SDK_SOURCE_LIST} ppc-crypto + libhelper libinitializer ppc-io ppc-protocol + ppc-gateway ppc-front ppc-tars-protocol + ppc-tools ppc-storage ppc-psi ppc-rpc + ppc-http ppc-mpc ppc-pir + ${CEM_SOURCE} ppc-main) +set(CEM_SOURCE "") +if(BUILD_CEM) + set(CEM_SOURCE "ppc-cem") endif() -add_subdirectory(ppc-crypto-core) -add_subdirectory(ppc-homo) - -# when BUILD_SDK, the following modules no need to compile -if(NOT BUILD_SDK AND NOT BUILD_UDF) - ##### the full-dependencies ##### - find_package(TBB REQUIRED) - find_package(jsoncpp REQUIRED) - - find_package(${BCOS_BOOSTSSL_TARGET} REQUIRED) - # tcmalloc - include(ProjectTCMalloc) - - find_package(SEAL REQUIRED) - find_package(Kuku REQUIRED) - - # APSI: Note: APSI depends on seal 4.0 and Kuku 2.1 - include(ProjectAPSI) - # Wedpr Crypto - include(ProjectWedprCrypto) - include(FindGSasl) - include(Installlibhdfs3) - - add_subdirectory(ppc-crypto) - add_subdirectory(libhelper) - add_subdirectory(libinitializer) - add_subdirectory(ppc-io) - add_subdirectory(ppc-protocol) - add_subdirectory(ppc-gateway) - 
add_subdirectory(ppc-front) - add_subdirectory(ppc-tars-protocol) - add_subdirectory(ppc-tools) - add_subdirectory(ppc-storage) - add_subdirectory(ppc-psi) - add_subdirectory(ppc-rpc) - add_subdirectory(ppc-http) - add_subdirectory(ppc-mpc) - add_subdirectory(ppc-pir) - if(BUILD_CEM) - add_subdirectory(ppc-cem) - endif () - add_subdirectory(ppc-main) -endif() -if(BUILD_UDF) - add_subdirectory(ppc-udf) + +if(BUILD_ALL) + add_sources("${ALL_SOURCE_LIST}") +elseif(BUILD_UDF) + add_sources("${UDF_SOURCE_LIST}") +elseif(BUILD_SDK) + add_sources("${SDK_SOURCE_LIST}") endif() +########### set the sources end ########### if (TESTS) enable_testing() diff --git a/cpp/cmake/Dependencies.cmake b/cpp/cmake/Dependencies.cmake new file mode 100644 index 00000000..4354187e --- /dev/null +++ b/cpp/cmake/Dependencies.cmake @@ -0,0 +1,51 @@ + +# ipp-crypto +#if(ENABLE_IPP_CRYPTO) +# hunter_add_package(ipp-crypto) +#endif() + +######## common dependencies ######## +find_package(OpenSSL REQUIRED) +include(InstallBcosUtilities) + +if (TESTS) + find_package(Boost COMPONENTS unit_test_framework) +endif() + +# cpp_features +if(ENABLE_CPU_FEATURES) + find_package(CpuFeatures REQUIRED) +endif() +find_package(unofficial-sodium CONFIG REQUIRED) +######## common dependencies end ######## + + +##### the full-dependencies ##### +if(BUILD_ALL) + find_package(TBB REQUIRED) + find_package(jsoncpp REQUIRED) + + find_package(${BCOS_BOOSTSSL_TARGET} REQUIRED) + # tcmalloc + include(ProjectTCMalloc) + + find_package(SEAL REQUIRED) + find_package(Kuku REQUIRED) + + # APSI: Note: APSI depends on seal 4.0 and Kuku 2.1 + include(ProjectAPSI) + # Wedpr Crypto + include(ProjectWedprCrypto) + include(Installlibhdfs3) +endif() +##### the full-dependencies end ##### + +##### the sdk-dependencies ##### +if(BUILD_SDK) + # find JNI + set(JAVA_AWT_LIBRARY NotNeeded) + set(JAVA_JVM_LIBRARY NotNeeded) + find_package(JNI REQUIRED) + include_directories(${JNI_INCLUDE_DIRS}) +endif() +##### the sdk-dependencies end##### \ No newline at end of file diff --git a/cpp/cmake/FindGSasl.cmake b/cpp/cmake/FindGSasl.cmake deleted file mode 100644 index caefeea1..00000000 --- a/cpp/cmake/FindGSasl.cmake +++ /dev/null @@ -1,27 +0,0 @@ -# - Try to find the GNU sasl library (gsasl) -# -# Once done this will define -# -# GSASL_FOUND - System has gnutls -# GSASL_INCLUDE_DIR - The gnutls include directory -# GSASL_LIBRARIES - The libraries needed to use gnutls -# GSASL_DEFINITIONS - Compiler switches required for using gnutls - - -IF (GSASL_INCLUDE_DIR AND GSASL_LIBRARIES) - # in cache already - SET(GSasl_FIND_QUIETLY TRUE) -ENDIF (GSASL_INCLUDE_DIR AND GSASL_LIBRARIES) - -FIND_PATH(GSASL_INCLUDE_DIR gsasl.h) - -FIND_LIBRARY(GSASL_LIBRARIES gsasl) -FIND_LIBRARY(GSASL_STATIC_LIBRARIES NAMES "libgsasl.a") - -INCLUDE(FindPackageHandleStandardArgs) - -# handle the QUIETLY and REQUIRED arguments and set GSASL_FOUND to TRUE if -# all listed variables are TRUE -FIND_PACKAGE_HANDLE_STANDARD_ARGS(GSASL DEFAULT_MSG GSASL_LIBRARIES GSASL_INCLUDE_DIR) - -MARK_AS_ADVANCED(GSASL_INCLUDE_DIR GSASL_LIBRARIES) \ No newline at end of file diff --git a/cpp/cmake/InstallBcosUtilities.cmake b/cpp/cmake/InstallBcosUtilities.cmake index 83ad5079..6177ba72 100644 --- a/cpp/cmake/InstallBcosUtilities.cmake +++ b/cpp/cmake/InstallBcosUtilities.cmake @@ -1,3 +1,3 @@ find_package(Boost COMPONENTS log filesystem chrono thread serialization iostreams system ) find_package(ZLIB REQUIRED) -find_package(bcos-utilities REQUIRED) \ No newline at end of file +find_package(bcos-utilities CONFIG 
REQUIRED) \ No newline at end of file diff --git a/cpp/cmake/Options.cmake b/cpp/cmake/Options.cmake index 74a16d18..0588372d 100644 --- a/cpp/cmake/Options.cmake +++ b/cpp/cmake/Options.cmake @@ -9,6 +9,11 @@ macro(default_option O DEF) set(${O} ${DEF}) endif () endmacro() +macro(add_sources source_list) + foreach(source ${source_list}) + add_subdirectory(${source}) + endforeach() +endmacro() # common settings if ("${CMAKE_SIZEOF_VOID_P}" STREQUAL "4") @@ -44,89 +49,102 @@ macro(configure_project) default_option(ENABLE_DEMO OFF) # sdk + default_option(BUILD_ALL ON) default_option(BUILD_SDK OFF) default_option(BUILD_UDF OFF) - if(BUILD_UDF) - set(BUILD_SDK ON) - endif() # Suffix like "-rc1" e.t.c. to append to versions wherever needed. if (NOT DEFINED VERSION_SUFFIX) set(VERSION_SUFFIX "") endif () - print_config(${NAME}) -endmacro() -# for boost-ssl enable/disable native -set(ARCH_NATIVE OFF) -if ("${ARCHITECTURE}" MATCHES "aarch64" OR "${ARCHITECTURE}" MATCHES "arm64") - set(ARCH_NATIVE ON) -endif () + # for boost-ssl enable/disable native + set(ARCH_NATIVE OFF) + if ("${ARCHITECTURE}" MATCHES "aarch64" OR "${ARCHITECTURE}" MATCHES "arm64") + set(ARCH_NATIVE ON) + endif () -set(VISIBILITY_FLAG " -fvisibility=hidden -fvisibility-inlines-hidden") -if (BUILD_UDF) - set(VISIBILITY_FLAG "") -endif() -if (BUILD_SDK) - set(VISIBILITY_FLAG "") -endif() -set(MARCH_TYPE "-march=x86-64 -mtune=generic ${VISIBILITY_FLAG}") -if (ARCH_NATIVE) - set(MARCH_TYPE "-march=native -mtune=native ${VISIBILITY_FLAG}") -endif () + set(VISIBILITY_FLAG " -fvisibility=hidden -fvisibility-inlines-hidden") -# for enable sse4.2(hdfs used) -set(ENABLE_SSE OFF) -# for enable/disable ipp-crypto -if (APPLE) - EXECUTE_PROCESS(COMMAND sysctl -a COMMAND grep "machdep.cpu.*features" COMMAND tr -d '\n' OUTPUT_VARIABLE SUPPORTED_INSTRUCTIONS) - message("* SUPPORTED_INSTRUCTIONS: ${SUPPORTED_INSTRUCTIONS}") - # detect sse4.2 - if (${SUPPORTED_INSTRUCTIONS} MATCHES ".*SSE4.2.*") - set(ENABLE_SSE ON) + set(MARCH_TYPE "-march=x86-64 -mtune=generic ${VISIBILITY_FLAG}") + if (ARCH_NATIVE) + set(MARCH_TYPE "-march=native -mtune=native ${VISIBILITY_FLAG}") endif () -elseif(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Windows") - # detect sse4_2 - FILE(READ "/proc/cpuinfo" SUPPORTED_INSTRUCTIONS) - if (${SUPPORTED_INSTRUCTIONS} MATCHES ".*sse4_2.*") - set(ENABLE_SSE ON) + + # for enable sse4.2(hdfs used) + set(ENABLE_SSE OFF) + # for enable/disable ipp-crypto + if (APPLE) + EXECUTE_PROCESS(COMMAND sysctl -a COMMAND grep "machdep.cpu.*features" COMMAND tr -d '\n' OUTPUT_VARIABLE SUPPORTED_INSTRUCTIONS) + message("* SUPPORTED_INSTRUCTIONS: ${SUPPORTED_INSTRUCTIONS}") + # detect sse4.2 + if (${SUPPORTED_INSTRUCTIONS} MATCHES ".*SSE4.2.*") + set(ENABLE_SSE ON) + endif () + elseif(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Windows") + # detect sse4_2 + FILE(READ "/proc/cpuinfo" SUPPORTED_INSTRUCTIONS) + if (${SUPPORTED_INSTRUCTIONS} MATCHES ".*sse4_2.*") + set(ENABLE_SSE ON) + endif () endif () -endif () -if(ENABLE_SSE) - list(APPEND VCPKG_MANIFEST_FEATURES "sse") -endif() + set(ENABLE_CPU_FEATURES OFF) + # only ENABLE_CPU_FEATURES for aarch64 and x86 + if ("${ARCHITECTURE}" MATCHES "aarch64") + add_definitions(-DARCH) + set(ENABLE_CPU_FEATURES ON) + endif () -set(ENABLE_CPU_FEATURES OFF) -# only ENABLE_CPU_FEATURES for aarch64 and x86 -if ("${ARCHITECTURE}" MATCHES "aarch64") - add_definitions(-DARCH) - set(ENABLE_CPU_FEATURES ON) -endif () + if ("${ARCHITECTURE}" MATCHES "x86_64") + add_definitions(-DX86) + set(ENABLE_CPU_FEATURES ON) + endif () -if 
("${ARCHITECTURE}" MATCHES "x86_64") - add_definitions(-DX86) - set(ENABLE_CPU_FEATURES ON) -endif () + if (ENABLE_CPU_FEATURES) + add_definitions(-DENABLE_CPU_FEATURES) + endif () -if (ENABLE_CPU_FEATURES) - add_definitions(-DENABLE_CPU_FEATURES) -endif () + # Enable CONN_PSI Joint Running With Ant Company + if (ENABLE_CONN) + add_definitions(-DENABLE_CONN) + endif () -# Enable CONN_PSI Joint Running With Ant Company -if (ENABLE_CONN) - add_definitions(-DENABLE_CONN) -endif () + set(ENABLE_IPP_CRYPTO OFF) + # Note: only ENABLE_CRYPTO_MB for x86_64 + # if ("${ARCHITECTURE}" MATCHES "x86_64") + # set(ENABLE_IPP_CRYPTO ON) + # add_definitions(-DENABLE_CRYPTO_MB) + # endif () -set(ENABLE_IPP_CRYPTO OFF) -# Note: only ENABLE_CRYPTO_MB for x86_64 -# if ("${ARCHITECTURE}" MATCHES "x86_64") -# set(ENABLE_IPP_CRYPTO ON) -# add_definitions(-DENABLE_CRYPTO_MB) -# endif () + # fix the boost beast build failed for [call to 'async_teardown' is ambiguous] + add_definitions(-DBOOST_ASIO_DISABLE_CONCEPTS) -# fix the boost beast build failed for [call to 'async_teardown' is ambiguous] -add_definitions(-DBOOST_ASIO_DISABLE_CONCEPTS) + ####### options settings ###### + if (BUILD_UDF) + set(VISIBILITY_FLAG "") + set(BUILD_ALL OFF) + endif() + if (BUILD_SDK) + set(VISIBILITY_FLAG "") + set(BUILD_ALL OFF) + endif() + if (BUILD_ALL) + # install all dependencies + list(APPEND VCPKG_MANIFEST_FEATURES "all") + endif() + if(ENABLE_SSE) + # enable sse for libhdfs3 + list(APPEND VCPKG_MANIFEST_FEATURES "sse") + endif() + # cpp_features + if(ENABLE_CPU_FEATURES) + list(APPEND VCPKG_MANIFEST_FEATURES "cpufeatures") + message("##### append cpp_features: ${VCPKG_MANIFEST_FEATURES}") + endif() + ####### options settings ###### + print_config("WeDPR-Component") +endmacro() macro(print_config NAME) message("") diff --git a/cpp/ppc-crypto-core/src/CMakeLists.txt b/cpp/ppc-crypto-core/src/CMakeLists.txt index d40633ab..3387e7a2 100644 --- a/cpp/ppc-crypto-core/src/CMakeLists.txt +++ b/cpp/ppc-crypto-core/src/CMakeLists.txt @@ -5,5 +5,5 @@ find_package(OpenSSL REQUIRED) message(STATUS "OPENSSL_INCLUDE_DIR: ${OPENSSL_INCLUDE_DIR}") message(STATUS "OPENSSL_LIBRARIES: ${OPENSSL_LIBRARIES}") -find_package(unofficial-sodium REQUIRED) -target_link_libraries(${CRYPTO_CORE_TARGET} PUBLIC unofficial-sodium::sodium ${BCOS_UTILITIES_TARGET} OpenSSL::Crypto ${CPU_FEATURES_LIB}) \ No newline at end of file +target_link_libraries(${CRYPTO_CORE_TARGET} PUBLIC unofficial-sodium::sodium + unofficial-sodium::sodium_config_public ${BCOS_UTILITIES_TARGET} OpenSSL::Crypto ${CPU_FEATURES_LIB}) \ No newline at end of file diff --git a/cpp/ppc-crypto/src/CMakeLists.txt b/cpp/ppc-crypto/src/CMakeLists.txt index b699d88c..954f1ca6 100644 --- a/cpp/ppc-crypto/src/CMakeLists.txt +++ b/cpp/ppc-crypto/src/CMakeLists.txt @@ -1,12 +1,12 @@ file(GLOB_RECURSE SRCS *.cpp) add_library(${CRYPTO_TARGET} ${SRCS}) -find_package(unofficial-sodium REQUIRED) + find_package(OpenSSL REQUIRED) message(STATUS "OPENSSL_INCLUDE_DIR: ${OPENSSL_INCLUDE_DIR}") message(STATUS "OPENSSL_LIBRARIES: ${OPENSSL_LIBRARIES}") -target_link_libraries(${CRYPTO_TARGET} PUBLIC ${BCOS_UTILITIES_TARGET} ${CRYPTO_CORE_TARGET} OpenSSL::Crypto unofficial-sodium::sodium TBB::tbb ${CPU_FEATURES_LIB}) +target_link_libraries(${CRYPTO_TARGET} PUBLIC ${BCOS_UTILITIES_TARGET} ${CRYPTO_CORE_TARGET} OpenSSL::Crypto unofficial-sodium::sodium unofficial-sodium::sodium_config_public TBB::tbb ${CPU_FEATURES_LIB}) if (ENABLE_IPP_CRYPTO) find_package(ipp-crypto REQUIRED) diff --git a/cpp/vcpkg.json 
b/cpp/vcpkg.json index 34c765ca..635e6990 100644 --- a/cpp/vcpkg.json +++ b/cpp/vcpkg.json @@ -16,42 +16,15 @@ "name": "openssl", "version>=": "1.1.1-tassl" }, - "tarscpp", "libsodium", + { + "name": "libsodium", + "version>=": "1.0.18" + }, { "name": "bcos-utilities", "version>=": "1.0.0" } - , - { - "name": "bcos-boostssl", - "version>=": "3.2.3" - }, - { - "name": "seal", - "version>=": "4.0.0", - "features": ["no-throw-tran"] - }, - { - "name": "kuku", - "version>=": "2.1" - }, - { - "name": "redis-plus-plus", - "version>=": "1.3.6" - }, - { - "name": "mysql-connector-cpp", - "version>=": "8.0.32" - }, - { - "name": "cpu-features", - "version>=": "0.9.0" - }, - "libhdfs3", - "tarscpp", - "tbb", - "libxml2" ], "builtin-baseline": "51b14cd4e1230dd51c11ffeff6f7d53c61cc5297", "overrides": [ @@ -74,11 +47,54 @@ { "name": "tarscpp", "version": "3.0.3-1#1" + }, + { + "name": "libsodium", + "version": "1.0.18#9" } ], "features": { "sse": { "description": "Enable SSE4.2 for libhdfs3" + }, + "cpufeatures":{ + "description": "Enable cpu features", + "dependencies": [ + { + "name": "cpu-features", + "version>=": "0.9.0" + } + ] + }, + "all": { + "description": "all dependencies", + "dependencies": [ + { + "name": "bcos-boostssl", + "version>=": "3.2.3" + }, + { + "name": "seal", + "version>=": "4.0.0", + "features": ["no-throw-tran"] + }, + { + "name": "kuku", + "version>=": "2.1" + }, + { + "name": "redis-plus-plus", + "version>=": "1.3.6" + }, + { + "name": "mysql-connector-cpp", + "version>=": "8.0.32" + }, + "libhdfs3", + "tarscpp", + "tbb", + "libxml2" + ] } } } \ No newline at end of file From ae2217b66cafef2abe55a2d96ec9ce19fadc9b8e Mon Sep 17 00:00:00 2001 From: cyjseagull Date: Thu, 22 Aug 2024 11:22:32 +0800 Subject: [PATCH 4/5] fix ci && upload artifact --- .github/workflows/cpp_workflow.yml | 96 ++++++++++++++++++- cpp/CMakeLists.txt | 2 +- cpp/cmake/CompilerSettings.cmake | 20 ++-- cpp/cmake/Coverage.cmake | 41 ++++++++ cpp/cmake/Dependencies.cmake | 12 +-- cpp/cmake/Options.cmake | 3 +- .../bindings/java/src/main/c/CMakeLists.txt | 3 +- 7 files changed, 152 insertions(+), 25 deletions(-) create mode 100644 cpp/cmake/Coverage.cmake diff --git a/.github/workflows/cpp_workflow.yml b/.github/workflows/cpp_workflow.yml index 5f75d524..3ec7f9c1 100644 --- a/.github/workflows/cpp_workflow.yml +++ b/.github/workflows/cpp_workflow.yml @@ -16,6 +16,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true +env: + RUST_BACKTRACE: 1 + ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true jobs: build: @@ -72,7 +75,7 @@ jobs: if: runner.os != 'Windows' run: | cd cpp/build && CTEST_OUTPUT_ON_FAILURE=TRUE ctest - make cov + make coverage build_sdk: name: build sdk @@ -112,7 +115,7 @@ jobs: - name: Build for windows if: runner.os == 'Windows' run: | - mkdir -p build && cd build && cmake -DCMAKE_BUILD_TYPE=Release -DTESTS=OFF -DBUILD_SDK=ON -DVCPKG_TARGET_TRIPLET=x64-windows-static -DCMAKE_TOOLCHAIN_FILE=c:/vcpkg/scripts/buildsystems/vcpkg.cmake ../ + mkdir -p cpp/build && cd cpp/build && cmake -DCMAKE_BUILD_TYPE=Release -DTESTS=OFF -DBUILD_SDK=ON -DVCPKG_TARGET_TRIPLET=x64-windows-static -DCMAKE_TOOLCHAIN_FILE=c:/vcpkg/scripts/buildsystems/vcpkg.cmake ../ cmake --build . 
--parallel 3 - name: Build for linux if: runner.os == 'Linux' @@ -132,5 +135,90 @@ jobs: - name: Test if: runner.os != 'Windows' run: | - cd cpp/build && CTEST_OUTPUT_ON_FAILURE=TRUE ctest - make cov \ No newline at end of file + cd cpp/build && CTEST_OUTPUT_ON_FAILURE=TRUE make test + make coverage + - uses: actions/upload-artifact@v2 + if: runner.os == 'macos' + with: + name: libppc-crypto-sdk-jni.dylib + path: ./cpp/ppc-crypto-c-sdk/bindings/java/src/main/resources/META-INF/native/libppc-crypto-sdk-jni.dylib + - uses: actions/upload-artifact@v2 + if: runner.os == 'Windows' + with: + name: libppc-crypto-sdk-jni.dylib + path: D:\a\WeDPR-Component\cpp\ppc-crypto-c-sdk\bindings\java\src\main\resources\META-INF\native\Release\ppc-crypto-sdk-jni.dll + + build_centos: + name: build_centos full node + runs-on: ${{ matrix.os }} + continue-on-error: true + strategy: + fail-fast: false + matrix: + os: [ubuntu-20.04] + container: docker.io/centos:7 + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 5 + - uses: actions/cache@v2 + id: deps_cache + with: + path: | + /home/runner/.ccache + /Users/runner/.ccache/ + deps/ + key: centos-notest-${{ matrix.os }}-${{ github.base_ref }}-${{ hashFiles('.github/workflows/workflow.yml') }} + restore-keys: | + centos-notest-${{ matrix.os }}-${{ github.base_ref }}-${{ hashFiles('.github/workflows/workflow.yml') }} + centos-notest-${{ matrix.os }}-${{ github.base_ref }}- + centos-notest-${{ matrix.os }}- + - name: Prepare centos tools + run: | + sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo + sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo + sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo + yum install -y java-11-openjdk-devel git make gcc gcc-c++ glibc-static glibc-devel openssl cmake3 ccache devtoolset-11 llvm-toolset-7.0 rh-perl530-perl libzstd-devel zlib-devel flex bison python-devel python3-devel + yum install -y devtoolset-10 devtoolset-11 llvm-toolset-7 rh-perl530-perl cmake3 zlib-devel ccache lcov python-devel python3-devel + yum install -y https://packages.endpointdev.com/rhel/7/os/x86_64/endpoint-repo.x86_64.rpm + yum install -y git + - name: Prepare vcpkg + if: runner.os != 'Windows' + uses: friendlyanon/setup-vcpkg@v1 + with: { committish: 7e3dcf74e37034eea358934a90a11d618520e139 } + - uses: actions-rs/toolchain@v1 + with: + toolchain: nightly-2022-07-28 + override: true + - name: Build + run: | + alias cmake='cmake3' + . /opt/rh/devtoolset-10/enable + . /opt/rh/rh-perl530/enable + export LIBCLANG_PATH=/opt/rh/llvm-toolset-7/root/lib64/ + export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH" + . /opt/rh/llvm-toolset-7/enable + mkdir -p cpp/build + cd cpp/build + cmake3 -DCMAKE_BUILD_TYPE=Release -DTESTS=ON -DCMAKE_TOOLCHAIN_FILE=${{ env.VCPKG_ROOT }}/scripts/buildsystems/vcpkg.cmake ../ + cmake3 --build . 
--parallel 3 + - name: Test + run: | + export OMP_NUM_THREADS=1 + cd build && CTEST_OUTPUT_ON_FAILURE=TRUE make test + - uses: actions/upload-artifact@v2 + with: + name: ppc-air-node-centos-x64 + path: ./cpp/build/bin/ppc-air-node + - uses: actions/upload-artifact@v2 + with: + name: ppc-pro-node-centos-x64 + path: ./cpp/build/bin/ppc-pro-node + - uses: actions/upload-artifact@v2 + with: + name: ppc-gateway-service-centos-x64 + path: ./cpp/build/bin/ppc-gateway-service + - uses: actions/upload-artifact@v2 + with: + name: libppc-crypto-sdk-jni.so + path: ./cpp/ppc-crypto-c-sdk/bindings/java/src/main/resources/META-INF/native/libppc-crypto-sdk-jni.so \ No newline at end of file diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 0fb668fc..698343ec 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -96,5 +96,5 @@ endif () # for code coverage if (COVERAGE) include(Coverage) - config_coverage("coverage" "'/usr*' 'boost/*'") + config_coverage("coverage" "") endif () diff --git a/cpp/cmake/CompilerSettings.cmake b/cpp/cmake/CompilerSettings.cmake index 005d130a..5e1140e3 100644 --- a/cpp/cmake/CompilerSettings.cmake +++ b/cpp/cmake/CompilerSettings.cmake @@ -111,16 +111,16 @@ if (("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") OR ("${CMAKE_CXX_COMPILER_ID}" MA if (COVERAGE) set(TESTS ON) - if ("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") - set(CMAKE_CXX_FLAGS "-g --coverage ${CMAKE_CXX_FLAGS}") - set(CMAKE_C_FLAGS "-g --coverage ${CMAKE_C_FLAGS}") - set(CMAKE_SHARED_LINKER_FLAGS "--coverage ${CMAKE_SHARED_LINKER_FLAGS}") - set(CMAKE_EXE_LINKER_FLAGS "--coverage ${CMAKE_EXE_LINKER_FLAGS}") - elseif ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") - add_compile_options(-Wno-unused-command-line-argument) - set(CMAKE_CXX_FLAGS "-g -fprofile-arcs -ftest-coverage ${CMAKE_CXX_FLAGS}") - set(CMAKE_C_FLAGS "-g -fprofile-arcs -ftest-coverage ${CMAKE_C_FLAGS}") - endif() + if ("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") + set(CMAKE_CXX_FLAGS "-g --coverage ${CMAKE_CXX_FLAGS}") + set(CMAKE_C_FLAGS "-g --coverage ${CMAKE_C_FLAGS}") + set(CMAKE_SHARED_LINKER_FLAGS "--coverage ${CMAKE_SHARED_LINKER_FLAGS}") + set(CMAKE_EXE_LINKER_FLAGS "--coverage ${CMAKE_EXE_LINKER_FLAGS}") + elseif ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") + add_compile_options(-Wno-unused-command-line-argument) + set(CMAKE_CXX_FLAGS "-g -fprofile-arcs -ftest-coverage ${CMAKE_CXX_FLAGS}") + set(CMAKE_C_FLAGS "-g -fprofile-arcs -ftest-coverage ${CMAKE_C_FLAGS}") + endif() endif () elseif("${CMAKE_CXX_COMPILER_ID}" MATCHES "MSVC") diff --git a/cpp/cmake/Coverage.cmake b/cpp/cmake/Coverage.cmake new file mode 100644 index 00000000..b76a97e2 --- /dev/null +++ b/cpp/cmake/Coverage.cmake @@ -0,0 +1,41 @@ +# ------------------------------------------------------------------------------ +# Copyright (C) 2021 FISCO BCOS. +# SPDX-License-Identifier: Apache-2.0 +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
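Both test jobs now finish with `make coverage`, the target that cpp/CMakeLists.txt creates via `config_coverage` from the new Coverage.cmake below. As a rough standalone sketch of the lcov/genhtml flow that target wires up, assuming lcov is installed and an instrumented build tree exists; the build directory and filter patterns here are illustrative, not project settings:

#!/usr/bin/env python3
"""Rough sketch of the coverage flow automated by the `coverage` target."""
import subprocess


def generate_coverage(build_dir: str = "cpp/build") -> None:
    # Capture baseline coverage data from the instrumented build tree
    subprocess.run(["lcov", "--rc", "lcov_branch_coverage=1",
                    "-c", "-i", "-d", build_dir,
                    "-o", f"{build_dir}/coverage.info.in"], check=True)
    # Strip system, vcpkg, boost, test and build files from the trace
    subprocess.run(["lcov", "--rc", "lcov_branch_coverage=1",
                    "-r", f"{build_dir}/coverage.info.in",
                    "/usr*", ".*vcpkg_installed*", ".*boost/*", "*test*",
                    "-o", f"{build_dir}/coverage.info"], check=True)
    # Render the filtered trace as an HTML report
    subprocess.run(["genhtml", "--rc", "lcov_branch_coverage=1", "-q",
                    "-o", f"{build_dir}/CodeCoverage",
                    f"{build_dir}/coverage.info"], check=True)


if __name__ == "__main__":
    generate_coverage()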
+# ------------------------------------------------------------------------------
+# File: Coverage.cmake
+# Function: Define coverage related functions
+# ------------------------------------------------------------------------------
+# REMOVE_FILE_PATTERN e.g.: '/usr*' '${CMAKE_SOURCE_DIR}/deps**' '${CMAKE_SOURCE_DIR}/evmc*' '${CMAKE_SOURCE_DIR}/fisco-bcos*'
+function(config_coverage TARGET REMOVE_FILE_PATTERN)
+    find_program(LCOV_TOOL lcov)
+    message(STATUS "lcov tool: ${LCOV_TOOL}")
+    if (LCOV_TOOL)
+        message(STATUS "coverage dir: " ${CMAKE_BINARY_DIR})
+        message(STATUS "coverage TARGET: " ${TARGET})
+        message(STATUS "coverage REMOVE_FILE_PATTERN: " ${REMOVE_FILE_PATTERN})
+        if (APPLE)
+            add_custom_target(${TARGET}
+                COMMAND ${LCOV_TOOL} -keep-going --ignore-errors inconsistent,unmapped,source --rc lcov_branch_coverage=1 -o ${CMAKE_BINARY_DIR}/coverage.info.in -c -i -d ${CMAKE_BINARY_DIR}/
+                COMMAND ${LCOV_TOOL} -keep-going --ignore-errors inconsistent,unmapped,source --rc lcov_branch_coverage=1 -r ${CMAKE_BINARY_DIR}/coverage.info.in '*MacOS*' '/usr*' '.*vcpkg_installed*' '.*boost/*' '*test*' '*build*' '*deps*' ${REMOVE_FILE_PATTERN} -o ${CMAKE_BINARY_DIR}/coverage.info
+                COMMAND genhtml --keep-going --ignore-errors inconsistent,unmapped,source --rc lcov_branch_coverage=1 -q -o ${CMAKE_BINARY_DIR}/CodeCoverage ${CMAKE_BINARY_DIR}/coverage.info)
+        else()
+            add_custom_target(${TARGET}
+                COMMAND ${LCOV_TOOL} --rc lcov_branch_coverage=1 -o ${CMAKE_BINARY_DIR}/coverage.info.in -c -i -d ${CMAKE_BINARY_DIR}/
+                COMMAND ${LCOV_TOOL} --rc lcov_branch_coverage=1 -r ${CMAKE_BINARY_DIR}/coverage.info.in '*MacOS*' '/usr*' '.*vcpkg_installed*' '.*boost/*' '*test*' '*build*' '*deps*' ${REMOVE_FILE_PATTERN} -o ${CMAKE_BINARY_DIR}/coverage.info
+                COMMAND genhtml --rc lcov_branch_coverage=1 -q -o ${CMAKE_BINARY_DIR}/CodeCoverage ${CMAKE_BINARY_DIR}/coverage.info)
+        endif()
+    else ()
+        message(FATAL_ERROR "Can't find lcov tool.
Please install lcov") + endif() +endfunction() \ No newline at end of file diff --git a/cpp/cmake/Dependencies.cmake b/cpp/cmake/Dependencies.cmake index 4354187e..ff1000b4 100644 --- a/cpp/cmake/Dependencies.cmake +++ b/cpp/cmake/Dependencies.cmake @@ -41,11 +41,9 @@ endif() ##### the full-dependencies end ##### ##### the sdk-dependencies ##### -if(BUILD_SDK) - # find JNI - set(JAVA_AWT_LIBRARY NotNeeded) - set(JAVA_JVM_LIBRARY NotNeeded) - find_package(JNI REQUIRED) - include_directories(${JNI_INCLUDE_DIRS}) -endif() +# find JNI +set(JAVA_AWT_LIBRARY NotNeeded) +set(JAVA_JVM_LIBRARY NotNeeded) +find_package(JNI REQUIRED) +include_directories(${JNI_INCLUDE_DIRS}) ##### the sdk-dependencies end##### \ No newline at end of file diff --git a/cpp/cmake/Options.cmake b/cpp/cmake/Options.cmake index 0588372d..1b4245ab 100644 --- a/cpp/cmake/Options.cmake +++ b/cpp/cmake/Options.cmake @@ -140,7 +140,6 @@ macro(configure_project) # cpp_features if(ENABLE_CPU_FEATURES) list(APPEND VCPKG_MANIFEST_FEATURES "cpufeatures") - message("##### append cpp_features: ${VCPKG_MANIFEST_FEATURES}") endif() ####### options settings ###### print_config("WeDPR-Component") @@ -154,6 +153,8 @@ macro(print_config NAME) message("-- CMake Cmake version and location ${CMAKE_VERSION} (${CMAKE_COMMAND})") message("-- Compiler C++ compiler version ${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}") message("-- CMAKE_BUILD_TYPE Build type ${CMAKE_BUILD_TYPE}") + message("-- VCPKG_MANIFEST_FEATURES VCPKG manifest features ${VCPKG_MANIFEST_FEATURES}") + message("-- CMAKE_TOOLCHAIN_FILE Cmake toolchain file ${CMAKE_TOOLCHAIN_FILE}") message("-- TARGET_PLATFORM Target platform ${CMAKE_SYSTEM_NAME} ${ARCHITECTURE}") message("-- BUILD_STATIC Build static ${BUILD_STATIC}") message("-- COVERAGE Build code coverage ${COVERAGE}") diff --git a/cpp/ppc-crypto-c-sdk/bindings/java/src/main/c/CMakeLists.txt b/cpp/ppc-crypto-c-sdk/bindings/java/src/main/c/CMakeLists.txt index 3741b8c1..94a27864 100644 --- a/cpp/ppc-crypto-c-sdk/bindings/java/src/main/c/CMakeLists.txt +++ b/cpp/ppc-crypto-c-sdk/bindings/java/src/main/c/CMakeLists.txt @@ -11,5 +11,4 @@ target_link_libraries(${PPC_CRYPTO_SDK_JNI_STATIC_TARGET} PUBLIC ${PPC_CRYPTO_C_ SET(LIBRARY_OUTPUT_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../resources/META-INF/native/) message(STATUS "CMAKE_INSTALL_INCLUDEDIR => ${CMAKE_INSTALL_INCLUDEDIR}") -message(STATUS "CMAKE_CURRENT_SOURCE_DIR => ${CMAKE_CURRENT_SOURCE_DIR}") -message(STATUS "LIB_DIR_PATH => ${CMAKE_CURRENT_SOURCE_DIR}/../resources/META-INF/native/") \ No newline at end of file +message(STATUS "CMAKE_CURRENT_SOURCE_DIR => ${CMAKE_CURRENT_SOURCE_DIR}") \ No newline at end of file From c0bdf753597ee4bca88bcbb5f76fd8165443f609 Mon Sep 17 00:00:00 2001 From: cyjseagull Date: Thu, 22 Aug 2024 14:29:52 +0800 Subject: [PATCH 5/5] format python code (#6) --- python/ppc_common/ppc_crypto/ihc_cipher.py | 27 ++-- .../ppc_common/ppc_crypto/paillier_cipher.py | 12 +- .../ppc_common/ppc_crypto/paillier_codec.py | 2 +- .../ppc_crypto/test/phe_unittest.py | 8 +- .../ppc_protos/generated/ppc_model_pb2.py | 58 ++++---- python/ppc_common/ppc_utils/http_utils.py | 4 +- .../ppc_utils/ppc_model_config_parser.py | 2 +- python/ppc_model/common/base_context.py | 135 ++++++++++++------ python/ppc_model/common/global_context.py | 3 +- python/ppc_model/common/model_result.py | 63 ++++---- python/ppc_model/common/protocol.py | 3 +- python/ppc_model/datasets/dataset.py | 16 ++- .../test/test_feature_binning.py | 16 ++- .../ppc_model/datasets/test/test_dataset.py | 8 
+-
 .../feature_engineering_engine.py             |   6 +-
 .../vertical/active_party.py                  |  12 +-
 .../vertical/passive_party.py                 |  24 ++--
 .../feature_engineering/vertical/utils.py     |   6 +-
 python/ppc_model/metrics/evaluation.py        |  90 +++++++-----
 python/ppc_model/metrics/model_plot.py        |  50 ++++---
 python/ppc_model/metrics/test/test_metrics.py |  45 +++---
 .../network/http/model_controller.py          |   6 +-
 python/ppc_model/network/http/restx.py        |   3 +-
 python/ppc_model/network/stub.py              |   3 +-
 python/ppc_model/ppc_model_app.py             |  50 +++----
 .../local_processing/preprocessing.py         |   6 +-
 .../preprocessing/tests/test_preprocessing.py |  27 ++--
 .../ppc_model/secure_lgbm/monitor/callback.py |   6 +-
 .../secure_lgbm/monitor/early_stopping.py     |   3 +-
 .../secure_lgbm/monitor/evaluation_monitor.py |   3 +-
 .../monitor/train_callback_unittest.py        |  12 +-
 .../secure_lgbm_prediction_engine.py          |   2 +-
 .../secure_lgbm/test/test_cipher_packing.py   |  24 ++--
 .../secure_lgbm/test/test_pack_gh.py          |   6 +-
 .../secure_lgbm/test/test_save_load_model.py  |  14 +-
 .../test_secure_lgbm_performance_training.py  |  24 ++--
 .../test/test_secure_lgbm_training.py         |  12 +-
 .../secure_lgbm/vertical/active_party.py      | 101 ++++++++-----
 .../ppc_model/secure_lgbm/vertical/booster.py |  77 ++++++----
 .../secure_lgbm/vertical/passive_party.py     |  38 +++--
 python/ppc_model/task/task_manager.py         |  12 +-
 .../task/test/task_manager_unittest.py        |   6 +-
 .../ppc_model_gateway_app.py                  |  27 ++--
 python/ppc_model_gateway/test/server.py       |   3 +-
 44 files changed, 640 insertions(+), 415 deletions(-)

diff --git a/python/ppc_common/ppc_crypto/ihc_cipher.py b/python/ppc_common/ppc_crypto/ihc_cipher.py
index 7cc385ff..54061208 100644
--- a/python/ppc_common/ppc_crypto/ihc_cipher.py
+++ b/python/ppc_common/ppc_crypto/ihc_cipher.py
@@ -10,19 +10,19 @@
 @dataclass
 class IhcCiphertext():
     __slots__ = ['c_left', 'c_right']
-
+
     def __init__(self, c_left: int, c_right: int) -> None:
         self.c_left = c_left
         self.c_right = c_right
-
+
     def __add__(self, other):
         cipher_left = self.c_left + other.c_left
         cipher_right = self.c_right + other.c_right
         return IhcCiphertext(cipher_left, cipher_right)
-
+
     def __eq__(self, other):
         return self.c_left == other.c_left and self.c_right == other.c_right
-
+
     def encode(self) -> bytes:
         # Compute the byte length of each integer
         len_c_left = (self.c_left.bit_length() + 7) // 8
@@ -37,17 +37,20 @@ def encode(self) -> bytes:
         # Return all of the data
         return len_bytes + c_left_bytes + c_right_bytes
-
+
     @classmethod
     def decode(cls, encoded_data: bytes):
         # Decode the integer lengths
         len_c_left, len_c_right = struct.unpack('>II', encoded_data[:8])
 
         # Decode the integers according to their lengths
-        c_left = int.from_bytes(encoded_data[8:8 + len_c_left], byteorder='big')
-        c_right = int.from_bytes(encoded_data[8 + len_c_left:8 + len_c_left + len_c_right], byteorder='big')
+        c_left = int.from_bytes(
+            encoded_data[8:8 + len_c_left], byteorder='big')
+        c_right = int.from_bytes(
+            encoded_data[8 + len_c_left:8 + len_c_left + len_c_right], byteorder='big')
 
         return cls(c_left, c_right)
-
+
+
 class IhcCipher(PheCipher):
     def __init__(self, key_length: int = 256, iter_round: int = 16) -> None:
         super().__init__(key_length)
@@ -56,9 +59,9 @@ def __init__(self, key_length: int = 256, iter_round: int = 16) -> None:
         self.private_key = key
         self.iter_round = iter_round
         self.key_length = key_length
-
+
         self.max_mod = 1 << key_length
-
+
     def encrypt(self, number: int) -> IhcCiphertext:
         random_u = secrets.randbits(self.key_length)
         x_this = number
@@ -70,7 +73,7 @@ def encrypt(self, number: int) -> IhcCiphertext:
         # cipher = IhcCiphertext(x_this, x_last, self.max_mod)
         cipher = IhcCiphertext(x_this, x_last)
         return cipher
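The IhcCiphertext above adds component-wise via __add__ and serializes through a length-prefixed wire format via encode/decode; the additive property is exactly what phe_unittest.py exercises. A minimal usage sketch, assuming python/ is on PYTHONPATH so the module imports as below (decrypt is defined in the next hunk):

from ppc_common.ppc_crypto.ihc_cipher import IhcCipher, IhcCiphertext

ihc = IhcCipher(key_length=256, iter_round=16)

# Additive homomorphism: Enc(a) + Enc(b) decrypts to a + b
enc_a, enc_b = ihc.encrypt(12345), ihc.encrypt(67890)
assert ihc.decrypt(enc_a + enc_b) == 12345 + 67890

# Round trip through the length-prefixed byte encoding
wire = enc_a.encode()
assert IhcCiphertext.decode(wire) == enc_a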
-
+
     def decrypt(self, cipher: IhcCiphertext) -> int:
         x_this = cipher.c_right
         x_last = cipher.c_left
@@ -79,7 +82,7 @@ def decrypt(self, cipher: IhcCiphertext) -> int:
             x_last = x_this
             x_this = x_tmp
         return x_this
-
+
     def encrypt_batch(self, numbers) -> list:
         return [self.encrypt(num) for num in numbers]
 
diff --git a/python/ppc_common/ppc_crypto/paillier_cipher.py b/python/ppc_common/ppc_crypto/paillier_cipher.py
index d2e0232a..822093be 100644
--- a/python/ppc_common/ppc_crypto/paillier_cipher.py
+++ b/python/ppc_common/ppc_crypto/paillier_cipher.py
@@ -29,17 +29,21 @@ def decrypt_batch(self, ciphers) -> list:
 
     def encrypt_batch_parallel(self, numbers) -> list:
         num_cores = os.cpu_count()
         batch_size = math.ceil(len(numbers) / num_cores)
-        batches = [numbers[i:i + batch_size] for i in range(0, len(numbers), batch_size)]
+        batches = [numbers[i:i + batch_size]
+                   for i in range(0, len(numbers), batch_size)]
         with ProcessPoolExecutor(max_workers=num_cores) as executor:
-            futures = [executor.submit(self.encrypt_batch, batch) for batch in batches]
+            futures = [executor.submit(self.encrypt_batch, batch)
+                       for batch in batches]
             result = [future.result() for future in futures]
         return [item for sublist in result for item in sublist]
 
     def decrypt_batch_parallel(self, ciphers) -> list:
         num_cores = os.cpu_count()
         batch_size = math.ceil(len(ciphers) / num_cores)
-        batches = [ciphers[i:i + batch_size] for i in range(0, len(ciphers), batch_size)]
+        batches = [ciphers[i:i + batch_size]
+                   for i in range(0, len(ciphers), batch_size)]
         with ProcessPoolExecutor(max_workers=num_cores) as executor:
-            futures = [executor.submit(self.decrypt_batch, batch) for batch in batches]
+            futures = [executor.submit(self.decrypt_batch, batch)
+                       for batch in batches]
             result = [future.result() for future in futures]
         return [item for sublist in result for item in sublist]
diff --git a/python/ppc_common/ppc_crypto/paillier_codec.py b/python/ppc_common/ppc_crypto/paillier_codec.py
index fb5f3bf0..d66f3c95 100644
--- a/python/ppc_common/ppc_crypto/paillier_codec.py
+++ b/python/ppc_common/ppc_crypto/paillier_codec.py
@@ -24,7 +24,7 @@ def decode_enc_key(public_key_bytes: bytes) -> PaillierPublicKey:
     @staticmethod
     def encode_cipher(cipher: EncryptedNumber, be_secure=True) -> Tuple[bytes, bytes]:
         return PaillierCodec._int_to_bytes(cipher.ciphertext(be_secure=be_secure)), \
-               PaillierCodec._int_to_bytes(cipher.exponent)
+            PaillierCodec._int_to_bytes(cipher.exponent)
 
     @staticmethod
     def decode_cipher(public_key: PaillierPublicKey, ciphertext: bytes, exponent: bytes) -> EncryptedNumber:
diff --git a/python/ppc_common/ppc_crypto/test/phe_unittest.py b/python/ppc_common/ppc_crypto/test/phe_unittest.py
index 5d036b75..a8ef12a7 100644
--- a/python/ppc_common/ppc_crypto/test/phe_unittest.py
+++ b/python/ppc_common/ppc_crypto/test/phe_unittest.py
@@ -30,7 +30,7 @@ def test_enc_and_dec_parallel(self):
         print("dec_p:", end_time - start_time, "seconds")
 
         self.assertListEqual(list(inputs), list(outputs))
-
+
     def test_ihc_enc_and_dec_parallel(self):
         ihc = IhcCipher(key_length=256)
         try_size = 100000
@@ -48,17 +48,17 @@ def test_ihc_enc_and_dec_parallel(self):
         cipher_start = ciphers[0]
         for i in range(1, len(ciphers)):
             cipher_left = (cipher_start.c_left + ciphers[i].c_left)
-            cipher_right = (cipher_start.c_right + ciphers[i].c_right )
+            cipher_right = (cipher_start.c_right + ciphers[i].c_right)
             # IhcCiphertext(cipher_left, cipher_right, cipher_start.max_mod)
             IhcCiphertext(cipher_left, cipher_right)
         end_time = time.time()
         print(f"size:{try_size},
add_p raw with class: {end_time - start_time} seconds, average times: {(end_time - start_time)/try_size * 1000 * 1000} us") - + start_time = time.time() cipher_start = ciphers[0] for i in range(1, len(ciphers)): cipher_left = (cipher_start.c_left + ciphers[i].c_left) - cipher_right = (cipher_start.c_right + ciphers[i].c_right ) + cipher_right = (cipher_start.c_right + ciphers[i].c_right) # IhcCiphertext(cipher_left, cipher_right) end_time = time.time() print(f"size:{try_size}, add_p raw: {end_time - start_time} seconds, average times: {(end_time - start_time)/try_size * 1000 * 1000} us") diff --git a/python/ppc_common/ppc_protos/generated/ppc_model_pb2.py b/python/ppc_common/ppc_protos/generated/ppc_model_pb2.py index 57f7ca3a..ced82b9e 100644 --- a/python/ppc_common/ppc_protos/generated/ppc_model_pb2.py +++ b/python/ppc_common/ppc_protos/generated/ppc_model_pb2.py @@ -12,40 +12,38 @@ _sym_db = _symbol_database.Default() - - DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fppc_model.proto\x12\tppc.model\"|\n\x0cModelRequest\x12\x0e\n\x06sender\x18\x01 \x01(\t\x12\x10\n\x08receiver\x18\x02 \x01(\t\x12\x0f\n\x07task_id\x18\x03 \x01(\t\x12\x0b\n\x03key\x18\x04 \x01(\t\x12\x0b\n\x03seq\x18\x05 \x01(\x03\x12\x11\n\tslice_num\x18\x06 \x01(\x03\x12\x0c\n\x04\x64\x61ta\x18\x07 \x01(\x0c\"3\n\x0c\x42\x61seResponse\x12\x12\n\nerror_code\x18\x01 \x01(\x03\x12\x0f\n\x07message\x18\x02 \x01(\t\"M\n\rModelResponse\x12.\n\rbase_response\x18\x01 \x01(\x0b\x32\x17.ppc.model.BaseResponse\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"#\n\rPlainBoolList\x12\x12\n\nplain_list\x18\x01 \x03(\x08\"\xb1\x01\n\rBestSplitInfo\x12\x0f\n\x07tree_id\x18\x01 \x01(\x03\x12\x0f\n\x07leaf_id\x18\x02 \x01(\x03\x12\x0f\n\x07\x66\x65\x61ture\x18\x03 \x01(\x03\x12\r\n\x05value\x18\x04 \x01(\x03\x12\x12\n\nagency_idx\x18\x05 \x01(\x03\x12\x16\n\x0e\x61gency_feature\x18\x06 \x01(\x03\x12\x11\n\tbest_gain\x18\x07 \x01(\x02\x12\x0e\n\x06w_left\x18\x08 \x01(\x02\x12\x0f\n\x07w_right\x18\t \x01(\x02\"3\n\x0bModelCipher\x12\x12\n\nciphertext\x18\x01 \x01(\x0c\x12\x10\n\x08\x65xponent\x18\x02 \x01(\x0c\"M\n\nCipherList\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\x12+\n\x0b\x63ipher_list\x18\x02 \x03(\x0b\x32\x16.ppc.model.ModelCipher\"=\n\x0e\x43ipher1DimList\x12+\n\x0b\x63ipher_list\x18\x01 \x03(\x0b\x32\x16.ppc.model.ModelCipher\"W\n\x0e\x43ipher2DimList\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\x12\x31\n\x0e\x63ipher_1d_list\x18\x02 \x03(\x0b\x32\x19.ppc.model.Cipher1DimList\"_\n\rEncAggrLabels\x12\r\n\x05\x66ield\x18\x01 \x01(\t\x12\x12\n\ncount_list\x18\x02 \x03(\x03\x12+\n\x0b\x63ipher_list\x18\x03 \x03(\x0b\x32\x16.ppc.model.ModelCipher\"_\n\x11\x45ncAggrLabelsList\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\x12\x36\n\x14\x65nc_aggr_labels_list\x18\x02 \x03(\x0b\x32\x18.ppc.model.EncAggrLabels\"/\n\x10IterationRequest\x12\r\n\x05\x65poch\x18\x01 \x01(\x03\x12\x0c\n\x04stop\x18\x02 \x01(\x08\x32Y\n\x0cModelService\x12I\n\x12MessageInteraction\x12\x17.ppc.model.ModelRequest\x1a\x18.ppc.model.ModelResponse\"\x00\x42\x08P\x01\xa2\x02\x03PPCb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ppc_model_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - _globals['DESCRIPTOR']._options = None - _globals['DESCRIPTOR']._serialized_options = b'P\001\242\002\003PPC' - _globals['_MODELREQUEST']._serialized_start=30 - _globals['_MODELREQUEST']._serialized_end=154 - _globals['_BASERESPONSE']._serialized_start=156 - 
_globals['_BASERESPONSE']._serialized_end=207 - _globals['_MODELRESPONSE']._serialized_start=209 - _globals['_MODELRESPONSE']._serialized_end=286 - _globals['_PLAINBOOLLIST']._serialized_start=288 - _globals['_PLAINBOOLLIST']._serialized_end=323 - _globals['_BESTSPLITINFO']._serialized_start=326 - _globals['_BESTSPLITINFO']._serialized_end=503 - _globals['_MODELCIPHER']._serialized_start=505 - _globals['_MODELCIPHER']._serialized_end=556 - _globals['_CIPHERLIST']._serialized_start=558 - _globals['_CIPHERLIST']._serialized_end=635 - _globals['_CIPHER1DIMLIST']._serialized_start=637 - _globals['_CIPHER1DIMLIST']._serialized_end=698 - _globals['_CIPHER2DIMLIST']._serialized_start=700 - _globals['_CIPHER2DIMLIST']._serialized_end=787 - _globals['_ENCAGGRLABELS']._serialized_start=789 - _globals['_ENCAGGRLABELS']._serialized_end=884 - _globals['_ENCAGGRLABELSLIST']._serialized_start=886 - _globals['_ENCAGGRLABELSLIST']._serialized_end=981 - _globals['_ITERATIONREQUEST']._serialized_start=983 - _globals['_ITERATIONREQUEST']._serialized_end=1030 - _globals['_MODELSERVICE']._serialized_start=1032 - _globals['_MODELSERVICE']._serialized_end=1121 + _globals['DESCRIPTOR']._options = None + _globals['DESCRIPTOR']._serialized_options = b'P\001\242\002\003PPC' + _globals['_MODELREQUEST']._serialized_start = 30 + _globals['_MODELREQUEST']._serialized_end = 154 + _globals['_BASERESPONSE']._serialized_start = 156 + _globals['_BASERESPONSE']._serialized_end = 207 + _globals['_MODELRESPONSE']._serialized_start = 209 + _globals['_MODELRESPONSE']._serialized_end = 286 + _globals['_PLAINBOOLLIST']._serialized_start = 288 + _globals['_PLAINBOOLLIST']._serialized_end = 323 + _globals['_BESTSPLITINFO']._serialized_start = 326 + _globals['_BESTSPLITINFO']._serialized_end = 503 + _globals['_MODELCIPHER']._serialized_start = 505 + _globals['_MODELCIPHER']._serialized_end = 556 + _globals['_CIPHERLIST']._serialized_start = 558 + _globals['_CIPHERLIST']._serialized_end = 635 + _globals['_CIPHER1DIMLIST']._serialized_start = 637 + _globals['_CIPHER1DIMLIST']._serialized_end = 698 + _globals['_CIPHER2DIMLIST']._serialized_start = 700 + _globals['_CIPHER2DIMLIST']._serialized_end = 787 + _globals['_ENCAGGRLABELS']._serialized_start = 789 + _globals['_ENCAGGRLABELS']._serialized_end = 884 + _globals['_ENCAGGRLABELSLIST']._serialized_start = 886 + _globals['_ENCAGGRLABELSLIST']._serialized_end = 981 + _globals['_ITERATIONREQUEST']._serialized_start = 983 + _globals['_ITERATIONREQUEST']._serialized_end = 1030 + _globals['_MODELSERVICE']._serialized_start = 1032 + _globals['_MODELSERVICE']._serialized_end = 1121 # @@protoc_insertion_point(module_scope) diff --git a/python/ppc_common/ppc_utils/http_utils.py b/python/ppc_common/ppc_utils/http_utils.py index 4489240d..8dc43a56 100644 --- a/python/ppc_common/ppc_utils/http_utils.py +++ b/python/ppc_common/ppc_utils/http_utils.py @@ -28,7 +28,8 @@ def send_get_request(endpoint, uri, params=None, headers=None): else: url = f"http://{endpoint}" log.debug(f"send a get request, url: {url}, params: {params}") - response = requests.get(url=url, params=params, headers=headers, timeout=30) + response = requests.get(url=url, params=params, + headers=headers, timeout=30) log.debug(f"response: {response.text}") check_response(response) response_data = json.loads(response.text) @@ -99,4 +100,3 @@ def send_upload_request(endpoint, uri, params=None, headers=None, data=None): except JSONDecodeError: response_data = response.text return response_data - diff --git 
a/python/ppc_common/ppc_utils/ppc_model_config_parser.py b/python/ppc_common/ppc_utils/ppc_model_config_parser.py index ab2cce26..04ea012f 100644 --- a/python/ppc_common/ppc_utils/ppc_model_config_parser.py +++ b/python/ppc_common/ppc_utils/ppc_model_config_parser.py @@ -471,4 +471,4 @@ def generate_mpc_predict_algorithm(algorithm_name, layers, participants, is_psi) # PpcErrorCode.ALGORITHM_PPC_MODEL_THREADS_ERROR.get_msg()) # except BaseException as e: # raise PpcException(PpcErrorCode.ALGORITHM_PPC_MODEL_THREADS_ERROR.get_code(), - # PpcErrorCode.ALGORITHM_PPC_MODEL_THREADS_ERROR.get_msg()) \ No newline at end of file + # PpcErrorCode.ALGORITHM_PPC_MODEL_THREADS_ERROR.get_msg()) diff --git a/python/ppc_model/common/base_context.py b/python/ppc_model/common/base_context.py index a6c1b582..43f1e873 100644 --- a/python/ppc_model/common/base_context.py +++ b/python/ppc_model/common/base_context.py @@ -27,56 +27,101 @@ def __init__(self, job_id: str, job_temp_dir: str): self.workspace = os.path.join(job_temp_dir, self.job_id) if not os.path.exists(self.workspace): os.makedirs(self.workspace) - self.psi_result_path = os.path.join(self.workspace, self.PSI_RESULT_FILE) - self.model_prepare_file = os.path.join(self.workspace, self.MODEL_PREPARE_FILE) - self.preprocessing_result_file = os.path.join(self.workspace, self.PREPROCESSING_RESULT_FILE) - self.eval_column_file = os.path.join(self.workspace, self.EVAL_COLUMN_FILE) + self.psi_result_path = os.path.join( + self.workspace, self.PSI_RESULT_FILE) + self.model_prepare_file = os.path.join( + self.workspace, self.MODEL_PREPARE_FILE) + self.preprocessing_result_file = os.path.join( + self.workspace, self.PREPROCESSING_RESULT_FILE) + self.eval_column_file = os.path.join( + self.workspace, self.EVAL_COLUMN_FILE) self.woe_iv_file = os.path.join(self.workspace, self.WOE_IV_FILE) - self.iv_selected_file = os.path.join(self.workspace, self.IV_SELECTED_FILE) - self.selected_col_file = os.path.join(self.workspace, self.SELECTED_COL_FILE) - self.remote_selected_col_file = os.path.join(self.job_id, self.SELECTED_COL_FILE) + self.iv_selected_file = os.path.join( + self.workspace, self.IV_SELECTED_FILE) + self.selected_col_file = os.path.join( + self.workspace, self.SELECTED_COL_FILE) + self.remote_selected_col_file = os.path.join( + self.job_id, self.SELECTED_COL_FILE) - self.summary_evaluation_file = os.path.join(self.workspace, utils.MPC_XGB_EVALUATION_TABLE) - self.feature_importance_file = os.path.join(self.workspace, utils.XGB_FEATURE_IMPORTANCE_TABLE) - self.feature_bin_file = os.path.join(self.workspace, self.FEATURE_BIN_FILE) - self.model_data_file = os.path.join(self.workspace, self.MODEL_DATA_FILE) - self.test_model_result_file = os.path.join(self.workspace, self.TEST_MODEL_RESULT_FILE) - self.test_model_output_file = os.path.join(self.workspace, self.TEST_MODEL_OUTPUT_FILE) - self.train_model_result_file = os.path.join(self.workspace, self.TRAIN_MODEL_RESULT_FILE) - self.train_model_output_file = os.path.join(self.workspace, self.TRAIN_MODEL_OUTPUT_FILE) + self.summary_evaluation_file = os.path.join( + self.workspace, utils.MPC_XGB_EVALUATION_TABLE) + self.feature_importance_file = os.path.join( + self.workspace, utils.XGB_FEATURE_IMPORTANCE_TABLE) + self.feature_bin_file = os.path.join( + self.workspace, self.FEATURE_BIN_FILE) + self.model_data_file = os.path.join( + self.workspace, self.MODEL_DATA_FILE) + self.test_model_result_file = os.path.join( + self.workspace, self.TEST_MODEL_RESULT_FILE) + self.test_model_output_file = os.path.join( + 
self.workspace, self.TEST_MODEL_OUTPUT_FILE) + self.train_model_result_file = os.path.join( + self.workspace, self.TRAIN_MODEL_RESULT_FILE) + self.train_model_output_file = os.path.join( + self.workspace, self.TRAIN_MODEL_OUTPUT_FILE) - self.train_metric_roc_file = os.path.join(self.workspace, utils.MPC_TRAIN_SET_METRIC_ROC_FILE) - self.train_metric_ks_file = os.path.join(self.workspace, utils.MPC_TRAIN_SET_METRIC_KS_FILE) - self.train_metric_pr_file = os.path.join(self.workspace, utils.MPC_TRAIN_SET_METRIC_PR_FILE) - self.train_metric_acc_file = os.path.join(self.workspace, utils.MPC_TRAIN_SET_METRIC_ACCURACY_FILE) - self.test_metric_roc_file = os.path.join(self.workspace, utils.MPC_TRAIN_METRIC_ROC_FILE) - self.test_metric_ks_file = os.path.join(self.workspace, utils.MPC_TRAIN_METRIC_KS_FILE) - self.test_metric_pr_file = os.path.join(self.workspace, utils.MPC_TRAIN_METRIC_PR_FILE) - self.test_metric_acc_file = os.path.join(self.workspace, utils.MPC_TRAIN_METRIC_ACCURACY_FILE) - self.train_metric_ks_table = os.path.join(self.workspace, utils.MPC_TRAIN_SET_METRIC_KS_TABLE) - self.test_metric_ks_table = os.path.join(self.workspace, utils.MPC_TRAIN_METRIC_KS_TABLE) - self.model_tree_prefix = os.path.join(self.workspace, utils.XGB_TREE_PERFIX) - self.metrics_iteration_file = os.path.join(self.workspace, utils.METRICS_OVER_ITERATION_FILE) + self.train_metric_roc_file = os.path.join( + self.workspace, utils.MPC_TRAIN_SET_METRIC_ROC_FILE) + self.train_metric_ks_file = os.path.join( + self.workspace, utils.MPC_TRAIN_SET_METRIC_KS_FILE) + self.train_metric_pr_file = os.path.join( + self.workspace, utils.MPC_TRAIN_SET_METRIC_PR_FILE) + self.train_metric_acc_file = os.path.join( + self.workspace, utils.MPC_TRAIN_SET_METRIC_ACCURACY_FILE) + self.test_metric_roc_file = os.path.join( + self.workspace, utils.MPC_TRAIN_METRIC_ROC_FILE) + self.test_metric_ks_file = os.path.join( + self.workspace, utils.MPC_TRAIN_METRIC_KS_FILE) + self.test_metric_pr_file = os.path.join( + self.workspace, utils.MPC_TRAIN_METRIC_PR_FILE) + self.test_metric_acc_file = os.path.join( + self.workspace, utils.MPC_TRAIN_METRIC_ACCURACY_FILE) + self.train_metric_ks_table = os.path.join( + self.workspace, utils.MPC_TRAIN_SET_METRIC_KS_TABLE) + self.test_metric_ks_table = os.path.join( + self.workspace, utils.MPC_TRAIN_METRIC_KS_TABLE) + self.model_tree_prefix = os.path.join( + self.workspace, utils.XGB_TREE_PERFIX) + self.metrics_iteration_file = os.path.join( + self.workspace, utils.METRICS_OVER_ITERATION_FILE) - self.remote_summary_evaluation_file = os.path.join(self.job_id, utils.MPC_XGB_EVALUATION_TABLE) - self.remote_feature_importance_file = os.path.join(self.job_id, utils.XGB_FEATURE_IMPORTANCE_TABLE) - self.remote_feature_bin_file = os.path.join(self.job_id, self.FEATURE_BIN_FILE) - self.remote_model_data_file = os.path.join(self.job_id, self.MODEL_DATA_FILE) - self.remote_test_model_output_file = os.path.join(self.job_id, self.TEST_MODEL_OUTPUT_FILE) - self.remote_train_model_output_file = os.path.join(self.job_id, self.TRAIN_MODEL_OUTPUT_FILE) + self.remote_summary_evaluation_file = os.path.join( + self.job_id, utils.MPC_XGB_EVALUATION_TABLE) + self.remote_feature_importance_file = os.path.join( + self.job_id, utils.XGB_FEATURE_IMPORTANCE_TABLE) + self.remote_feature_bin_file = os.path.join( + self.job_id, self.FEATURE_BIN_FILE) + self.remote_model_data_file = os.path.join( + self.job_id, self.MODEL_DATA_FILE) + self.remote_test_model_output_file = os.path.join( + self.job_id, self.TEST_MODEL_OUTPUT_FILE) + 
self.remote_train_model_output_file = os.path.join( + self.job_id, self.TRAIN_MODEL_OUTPUT_FILE) - self.remote_train_metric_roc_file = os.path.join(self.job_id, utils.MPC_TRAIN_SET_METRIC_ROC_FILE) - self.remote_train_metric_ks_file = os.path.join(self.job_id, utils.MPC_TRAIN_SET_METRIC_KS_FILE) - self.remote_train_metric_pr_file = os.path.join(self.job_id, utils.MPC_TRAIN_SET_METRIC_PR_FILE) - self.remote_train_metric_acc_file = os.path.join(self.job_id, utils.MPC_TRAIN_SET_METRIC_ACCURACY_FILE) - self.remote_test_metric_roc_file = os.path.join(self.job_id, utils.MPC_TRAIN_METRIC_ROC_FILE) - self.remote_test_metric_ks_file = os.path.join(self.job_id, utils.MPC_TRAIN_METRIC_KS_FILE) - self.remote_test_metric_pr_file = os.path.join(self.job_id, utils.MPC_TRAIN_METRIC_PR_FILE) - self.remote_test_metric_acc_file = os.path.join(self.job_id, utils.MPC_TRAIN_METRIC_ACCURACY_FILE) - self.remote_train_metric_ks_table = os.path.join(self.job_id, utils.MPC_TRAIN_SET_METRIC_KS_TABLE) - self.remote_test_metric_ks_table = os.path.join(self.job_id, utils.MPC_TRAIN_METRIC_KS_TABLE) - self.remote_model_tree_prefix = os.path.join(self.job_id, utils.XGB_TREE_PERFIX) - self.remote_metrics_iteration_file = os.path.join(self.job_id, utils.METRICS_OVER_ITERATION_FILE) + self.remote_train_metric_roc_file = os.path.join( + self.job_id, utils.MPC_TRAIN_SET_METRIC_ROC_FILE) + self.remote_train_metric_ks_file = os.path.join( + self.job_id, utils.MPC_TRAIN_SET_METRIC_KS_FILE) + self.remote_train_metric_pr_file = os.path.join( + self.job_id, utils.MPC_TRAIN_SET_METRIC_PR_FILE) + self.remote_train_metric_acc_file = os.path.join( + self.job_id, utils.MPC_TRAIN_SET_METRIC_ACCURACY_FILE) + self.remote_test_metric_roc_file = os.path.join( + self.job_id, utils.MPC_TRAIN_METRIC_ROC_FILE) + self.remote_test_metric_ks_file = os.path.join( + self.job_id, utils.MPC_TRAIN_METRIC_KS_FILE) + self.remote_test_metric_pr_file = os.path.join( + self.job_id, utils.MPC_TRAIN_METRIC_PR_FILE) + self.remote_test_metric_acc_file = os.path.join( + self.job_id, utils.MPC_TRAIN_METRIC_ACCURACY_FILE) + self.remote_train_metric_ks_table = os.path.join( + self.job_id, utils.MPC_TRAIN_SET_METRIC_KS_TABLE) + self.remote_test_metric_ks_table = os.path.join( + self.job_id, utils.MPC_TRAIN_METRIC_KS_TABLE) + self.remote_model_tree_prefix = os.path.join( + self.job_id, utils.XGB_TREE_PERFIX) + self.remote_metrics_iteration_file = os.path.join( + self.job_id, utils.METRICS_OVER_ITERATION_FILE) @staticmethod def feature_engineering_input_path(job_id: str, job_temp_dir: str): diff --git a/python/ppc_model/common/global_context.py b/python/ppc_model/common/global_context.py index 13437d90..d552bef5 100644 --- a/python/ppc_model/common/global_context.py +++ b/python/ppc_model/common/global_context.py @@ -7,7 +7,8 @@ # config_path = '{}/../application.yml'.format(dirName) config_path = "application.yml" -components = Initializer(log_config_path='logging.conf', config_path=config_path) +components = Initializer( + log_config_path='logging.conf', config_path=config_path) # matplotlib 线程不安全,并行任务绘图增加全局锁 plot_lock = threading.Lock() diff --git a/python/ppc_model/common/model_result.py b/python/ppc_model/common/model_result.py index 5b156948..0cb55142 100644 --- a/python/ppc_model/common/model_result.py +++ b/python/ppc_model/common/model_result.py @@ -25,22 +25,27 @@ def __init__(self, ctx: Context) -> None: # Synchronization result file if (len(ctx.result_receiver_id_list) == 1 and ctx.participant_id_list[0] != ctx.result_receiver_id_list[0]) \ - or 
len(ctx.result_receiver_id_list) > 1: + or len(ctx.result_receiver_id_list) > 1: self._sync_result_files() - + def _process_fe_result(self): if os.path.exists(self.ctx.preprocessing_result_file): - column_info_fm = pd.read_csv(self.ctx.preprocessing_result_file, index_col=0) + column_info_fm = pd.read_csv( + self.ctx.preprocessing_result_file, index_col=0) if os.path.exists(self.ctx.iv_selected_file): - column_info_iv_fm = pd.read_csv(self.ctx.iv_selected_file, index_col=0) - merged_df = self.union_column_info(column_info_fm, column_info_iv_fm) + column_info_iv_fm = pd.read_csv( + self.ctx.iv_selected_file, index_col=0) + merged_df = self.union_column_info( + column_info_fm, column_info_iv_fm) else: merged_df = column_info_fm merged_df.fillna("None", inplace=True) - merged_df.to_csv(self.ctx.selected_col_file, sep=utils.CSV_SEP, header=True, index_label='id') + merged_df.to_csv(self.ctx.selected_col_file, + sep=utils.CSV_SEP, header=True, index_label='id') # 存储column_info到hdfs给前端展示 - self._upload_file(self.ctx.components.storage_client, self.ctx.selected_col_file, self.ctx.remote_selected_col_file) + self._upload_file(self.ctx.components.storage_client, + self.ctx.selected_col_file, self.ctx.remote_selected_col_file) @staticmethod def union_column_info(column_info1: pd.DataFrame, column_info2: pd.DataFrame): @@ -55,10 +60,12 @@ def union_column_info(column_info1: pd.DataFrame, column_info2: pd.DataFrame): column_info_merge (DataFrame): The union column_info. """ # 将column_info1和column_info2按照left_index=True, right_index=True的方式进行合并 如果列有缺失则赋值为None 行的顺序按照column_info1 - column_info_conbine = column_info1.merge(column_info2, how='outer', left_index=True, right_index=True, sort=False) + column_info_conbine = column_info1.merge( + column_info2, how='outer', left_index=True, right_index=True, sort=False) col1_index_list = column_info1.index.to_list() col2_index_list = column_info2.index.to_list() - merged_list = col1_index_list + [item for item in col2_index_list if item not in col1_index_list] + merged_list = col1_index_list + \ + [item for item in col2_index_list if item not in col1_index_list] column_info_conbine = column_info_conbine.reindex(merged_list) return column_info_conbine @@ -71,6 +78,7 @@ def _upload_file(storage_client, local_file, remote_file): def _download_file(storage_client, local_file, remote_file): if storage_client is not None and not os.path.exists(local_file): storage_client.download_file(remote_file, local_file) + @staticmethod def make_graph_data(components, job_id, graph_file_name): graph_format = 'svg+xml' @@ -119,37 +127,39 @@ def make_csv_data(components, job_id, csv_file_name): def _remove_workspace(self): if os.path.exists(self.ctx.workspace): shutil.rmtree(self.ctx.workspace) - self.log.info(f'job {self.ctx.job_id}: {self.ctx.workspace} has been removed.') + self.log.info( + f'job {self.ctx.job_id}: {self.ctx.workspace} has been removed.') else: - self.log.info(f'job {self.ctx.job_id}: {self.ctx.workspace} does not exist.') + self.log.info( + f'job {self.ctx.job_id}: {self.ctx.workspace} does not exist.') def _sync_result_files(self): if self.ctx.algorithm_type == AlgorithmType.Train.name: - self.sync_result_file(self.ctx, self.ctx.metrics_iteration_file, + self.sync_result_file(self.ctx, self.ctx.metrics_iteration_file, self.ctx.remote_metrics_iteration_file, 'f1') - self.sync_result_file(self.ctx, self.ctx.feature_importance_file, + self.sync_result_file(self.ctx, self.ctx.feature_importance_file, self.ctx.remote_feature_importance_file, 'f2') - 
self.sync_result_file(self.ctx, self.ctx.summary_evaluation_file, + self.sync_result_file(self.ctx, self.ctx.summary_evaluation_file, self.ctx.remote_summary_evaluation_file, 'f3') - self.sync_result_file(self.ctx, self.ctx.train_metric_ks_table, + self.sync_result_file(self.ctx, self.ctx.train_metric_ks_table, self.ctx.remote_train_metric_ks_table, 'f4') - self.sync_result_file(self.ctx, self.ctx.train_metric_roc_file, + self.sync_result_file(self.ctx, self.ctx.train_metric_roc_file, self.ctx.remote_train_metric_roc_file, 'f5') - self.sync_result_file(self.ctx, self.ctx.train_metric_ks_file, + self.sync_result_file(self.ctx, self.ctx.train_metric_ks_file, self.ctx.remote_train_metric_ks_file, 'f6') - self.sync_result_file(self.ctx, self.ctx.train_metric_pr_file, + self.sync_result_file(self.ctx, self.ctx.train_metric_pr_file, self.ctx.remote_train_metric_pr_file, 'f7') - self.sync_result_file(self.ctx, self.ctx.train_metric_acc_file, + self.sync_result_file(self.ctx, self.ctx.train_metric_acc_file, self.ctx.remote_train_metric_acc_file, 'f8') - self.sync_result_file(self.ctx, self.ctx.test_metric_ks_table, + self.sync_result_file(self.ctx, self.ctx.test_metric_ks_table, self.ctx.remote_test_metric_ks_table, 'f9') - self.sync_result_file(self.ctx, self.ctx.test_metric_roc_file, + self.sync_result_file(self.ctx, self.ctx.test_metric_roc_file, self.ctx.remote_test_metric_roc_file, 'f10') - self.sync_result_file(self.ctx, self.ctx.test_metric_ks_file, + self.sync_result_file(self.ctx, self.ctx.test_metric_ks_file, self.ctx.remote_test_metric_ks_file, 'f11') - self.sync_result_file(self.ctx, self.ctx.test_metric_pr_file, + self.sync_result_file(self.ctx, self.ctx.test_metric_pr_file, self.ctx.remote_test_metric_pr_file, 'f12') - self.sync_result_file(self.ctx, self.ctx.test_metric_acc_file, + self.sync_result_file(self.ctx, self.ctx.test_metric_acc_file, self.ctx.remote_test_metric_acc_file, 'f13') @staticmethod @@ -163,11 +173,12 @@ def sync_result_file(ctx, local_file, remote_file, key_file): byte_data, partner_index) else: if ctx.components.config_data['AGENCY_ID'] in ctx.result_receiver_id_list: - byte_data = SendMessage._receive_byte_data(ctx.components.stub, ctx, + byte_data = SendMessage._receive_byte_data(ctx.components.stub, ctx, f'{CommonMessage.SYNC_FILE.value}_{key_file}', 0) with open(local_file, 'wb') as f: f.write(byte_data) - ResultFileHandling._upload_file(ctx.components.storage_client, local_file, remote_file) + ResultFileHandling._upload_file( + ctx.components.storage_client, local_file, remote_file) class CommonMessage(Enum): diff --git a/python/ppc_model/common/protocol.py b/python/ppc_model/common/protocol.py index c09a4da4..2ec8f65d 100644 --- a/python/ppc_model/common/protocol.py +++ b/python/ppc_model/common/protocol.py @@ -37,7 +37,8 @@ def packing_data(codec, public_key, cipher_list): for cipher in cipher_list: model_cipher = ModelCipher() - model_cipher.ciphertext, model_cipher.exponent = codec.encode_cipher(cipher) + model_cipher.ciphertext, model_cipher.exponent = codec.encode_cipher( + cipher) enc_data_pb.cipher_list.append(model_cipher) return utils.pb_to_bytes(enc_data_pb) diff --git a/python/ppc_model/datasets/dataset.py b/python/ppc_model/datasets/dataset.py index 0278c51c..a8db9c58 100644 --- a/python/ppc_model/datasets/dataset.py +++ b/python/ppc_model/datasets/dataset.py @@ -112,7 +112,7 @@ def _random_split_dataset(self): def _customized_split_dataset(self): if self.ctx.role == TaskRole.ACTIVE_PARTY: for partner_index in range(1, 
len(self.ctx.participant_id_list)): - byte_data = SendMessage._receive_byte_data(self.ctx.components.stub, self.ctx, + byte_data = SendMessage._receive_byte_data(self.ctx.components.stub, self.ctx, f'{CommonMessage.EVAL_SET_FILE.value}', partner_index) if not os.path.exists(self.eval_column_file) and byte_data != bytes(): with open(self.eval_column_file, 'wb') as f: @@ -130,15 +130,17 @@ def _customized_split_dataset(self): byte_data = f.read() SendMessage._send_byte_data(self.ctx.components.stub, self.ctx, f'{CommonMessage.EVAL_SET_FILE.value}', byte_data, 0) - byte_data = SendMessage._receive_byte_data(self.ctx.components.stub, self.ctx, - f'{CommonMessage.EVAL_SET_FILE.value}', 0) + byte_data = SendMessage._receive_byte_data(self.ctx.components.stub, self.ctx, + f'{CommonMessage.EVAL_SET_FILE.value}', 0) if not os.path.exists(self.eval_column_file): with open(self.eval_column_file, 'wb') as f: f.write(byte_data) - - eval_set_df = pd.read_csv(self.eval_column_file, header=0) - train_data = self.model_data[eval_set_df[self.eval_set_column] == self.train_set_value] - test_data = self.model_data[eval_set_df[self.eval_set_column] == self.eval_set_value] + + eval_set_df = pd.read_csv(self.eval_column_file, header=0) + train_data = self.model_data[eval_set_df[self.eval_set_column] + == self.train_set_value] + test_data = self.model_data[eval_set_df[self.eval_set_column] + == self.eval_set_value] return train_data, test_data diff --git a/python/ppc_model/datasets/feature_binning/test/test_feature_binning.py b/python/ppc_model/datasets/feature_binning/test/test_feature_binning.py index e06b82b2..20592c87 100644 --- a/python/ppc_model/datasets/feature_binning/test/test_feature_binning.py +++ b/python/ppc_model/datasets/feature_binning/test/test_feature_binning.py @@ -31,16 +31,17 @@ def test_train_feature_binning(self): 'algorithm_type': 'Train', 'algorithm_subtype': None, 'model_dict': { - 'objective': 'regression', - 'max_bin': 10, - 'n_estimators': 6, - 'max_depth': 3, - 'use_goss': 1 + 'objective': 'regression', + 'max_bin': 10, + 'n_estimators': 6, + 'max_depth': 3, + 'use_goss': 1 } } task_info = SecureLGBMContext(args, self.components) - model_data = SecureDataset.simulate_dataset(data_size, feature_dim, has_label) + model_data = SecureDataset.simulate_dataset( + data_size, feature_dim, has_label) secure_dataset = SecureDataset(task_info, model_data) print(secure_dataset.train_idx.shape) print(secure_dataset.train_X.shape) @@ -79,7 +80,8 @@ def test_test_feature_binning(self): } task_info = SecureLGBMContext(args, self.components) - model_data = SecureDataset.simulate_dataset(data_size, feature_dim, has_label) + model_data = SecureDataset.simulate_dataset( + data_size, feature_dim, has_label) secure_dataset = SecureDataset(task_info, model_data) print(secure_dataset.train_idx.shape) print(secure_dataset.train_X.shape) diff --git a/python/ppc_model/datasets/test/test_dataset.py b/python/ppc_model/datasets/test/test_dataset.py index 7a544cfc..e7519454 100644 --- a/python/ppc_model/datasets/test/test_dataset.py +++ b/python/ppc_model/datasets/test/test_dataset.py @@ -48,7 +48,7 @@ class TestSecureDataset(unittest.TestCase): iv_selected_file = './iv_selected.csv' if not os.path.exists(iv_selected_file): iv_selected = pd.DataFrame( - {'feature': [f'x{i + 1}' for i in range(30)], + {'feature': [f'x{i + 1}' for i in range(30)], 'iv_selected': np.random.binomial(n=1, p=0.5, size=30)}) iv_selected.to_csv(iv_selected_file, index=None) @@ -203,9 +203,11 @@ def test_read_dataset(self): df = 
pd.DataFrame(origin_data, columns=columns) csv_file = '/tmp/data_x1_to_x10.csv' df.to_csv(csv_file, index=False) - field_list, label, feature = SecureDataset.read_dataset(csv_file, False, delimiter=',') + field_list, label, feature = SecureDataset.read_dataset( + csv_file, False, delimiter=',') self.assertEqual(['id'] + field_list, columns) - field_list, label, feature = SecureDataset.read_dataset(csv_file, True, delimiter=',') + field_list, label, feature = SecureDataset.read_dataset( + csv_file, True, delimiter=',') self.assertEqual(['id'] + field_list, columns) diff --git a/python/ppc_model/feature_engineering/feature_engineering_engine.py b/python/ppc_model/feature_engineering/feature_engineering_engine.py index 16acdb7b..0b3bb88f 100644 --- a/python/ppc_model/feature_engineering/feature_engineering_engine.py +++ b/python/ppc_model/feature_engineering/feature_engineering_engine.py @@ -17,7 +17,8 @@ def run(args): args['job_id'], components.config_data['JOB_TEMP_DIR']) if args['is_label_holder']: - field_list, label, feature = SecureDataset.read_dataset(input_path, True) + field_list, label, feature = SecureDataset.read_dataset( + input_path, True) context = FeatureEngineeringContext( args=args, components=components, @@ -28,7 +29,8 @@ def run(args): ) vfe = VerticalFeatureEngineeringActiveParty(context) else: - field_list, _, feature = SecureDataset.read_dataset(input_path, False) + field_list, _, feature = SecureDataset.read_dataset( + input_path, False) context = FeatureEngineeringContext( args=args, components=components, diff --git a/python/ppc_model/feature_engineering/vertical/active_party.py b/python/ppc_model/feature_engineering/vertical/active_party.py index 6710c47d..82523420 100644 --- a/python/ppc_model/feature_engineering/vertical/active_party.py +++ b/python/ppc_model/feature_engineering/vertical/active_party.py @@ -139,7 +139,8 @@ def _get_all_enc_aggr_labels(self, partner_id): enc_aggr_labels_list_pb = EncAggrLabelsList() utils.bytes_to_pb(enc_aggr_labels_list_pb, data) - public_key = self.ctx.codec.decode_enc_key(enc_aggr_labels_list_pb.public_key) + public_key = self.ctx.codec.decode_enc_key( + enc_aggr_labels_list_pb.public_key) res = [] for enc_aggr_labels_pb in enc_aggr_labels_list_pb.enc_aggr_labels_list: @@ -160,7 +161,8 @@ def _send_enc_labels(self, enc_labels, receiver): log = self.ctx.components.logger() start_time = time.time() - data = PheMessage.packing_data(self.ctx.codec, self.ctx.phe.public_key, enc_labels) + data = PheMessage.packing_data( + self.ctx.codec, self.ctx.phe.public_key, enc_labels) self.ctx.components.stub.push(PushRequest( receiver=receiver, task_id=self.ctx.task_id, @@ -174,8 +176,10 @@ def _send_enc_labels(self, enc_labels, receiver): def _save_and_sync_fe_results(self): log = self.ctx.components.logger() task_id = self.ctx.task_id - self.woe_iv_df.to_csv(self.ctx.woe_iv_file, sep=',', header=True, index=None) - self.iv_selected_df.to_csv(self.ctx.iv_selected_file, sep=',', header=True, index=None) + self.woe_iv_df.to_csv(self.ctx.woe_iv_file, + sep=',', header=True, index=None) + self.iv_selected_df.to_csv( + self.ctx.iv_selected_file, sep=',', header=True, index=None) self.ctx.components.storage_client.upload_file(self.ctx.woe_iv_file, self.ctx.job_id + os.sep + self.ctx.WOE_IV_FILE) log.info(f"Saving fe results finished, task_id: {task_id}") diff --git a/python/ppc_model/feature_engineering/vertical/passive_party.py b/python/ppc_model/feature_engineering/vertical/passive_party.py index 8aa63996..4b3ffd47 100644 --- 
a/python/ppc_model/feature_engineering/vertical/passive_party.py
+++ b/python/ppc_model/feature_engineering/vertical/passive_party.py
@@ -30,7 +30,8 @@ def fit(self, *args, **kwargs) -> None:
         public_key, enc_labels = self._get_enc_labels()
 
         # Aggregate the encrypted labels according to the feature binning
-        aggr_labels_bytes_list = self._binning_and_aggregating_all(public_key, enc_labels)
+        aggr_labels_bytes_list = self._binning_and_aggregating_all(
+            public_key, enc_labels)
 
         # Send the aggregated encrypted labels
         self._send_all_enc_aggr_labels(public_key, aggr_labels_bytes_list)
@@ -49,7 +50,8 @@ def _get_enc_labels(self):
             key=FeMessage.ENC_LABELS.value
         ))
 
-        public_key, enc_labels = PheMessage.unpacking_data(self.ctx.codec, data)
+        public_key, enc_labels = PheMessage.unpacking_data(
+            self.ctx.codec, data)
         log.info(f"All enc labels received, task_id: {self.ctx.task_id}, label_num: {len(enc_labels)}, "
                  f"size: {len(data) / 1024}KB, timecost: {time.time() - start_time}s")
         return public_key, enc_labels
@@ -59,7 +61,8 @@ def _binning_and_aggregating_all(self, public_key, enc_labels) -> list:
         start_time = time.time()
         params = []
         for i in range(self.ctx.feature.shape[1]):
-            is_continuous = is_continuous_feature(self.ctx.categorical, self.ctx.feature_name_list[i])
+            is_continuous = is_continuous_feature(
+                self.ctx.categorical, self.ctx.feature_name_list[i])
             params.append({
                 'is_continuous': is_continuous,
                 'feature_index': i,
@@ -89,7 +92,8 @@ def _binning_and_aggregating_all(self, public_key, enc_labels) -> list:
     def _binning_and_aggregating_one(param):
         feature = param['feature']
         if param['is_continuous']:
-            bins = FeatureBinning.binning_continuous_feature(feature, param['group_num'])[0]
+            bins = FeatureBinning.binning_continuous_feature(
+                feature, param['group_num'])[0]
         else:
             bins = FeatureBinning.binning_categorical_feature(feature)[0]
 
@@ -103,8 +107,10 @@ def _binning_and_aggregating_one(param):
             else:
                 data_dict[key] = {'count': 1, 'sum': value}
 
-        count_list = [data_dict[key]['count'] for key in sorted(data_dict.keys())]
-        aggr_enc_labels = [data_dict[key]['sum'] for key in sorted(data_dict.keys())]
+        count_list = [data_dict[key]['count']
+                      for key in sorted(data_dict.keys())]
+        aggr_enc_labels = [data_dict[key]['sum']
+                           for key in sorted(data_dict.keys())]
 
         return VerticalFeatureEngineeringPassiveParty._encode_enc_aggr_labels(
             param['codec'], param['field'], count_list, aggr_enc_labels)
@@ -125,12 +131,14 @@ def _encode_enc_aggr_labels(codec, field, count_list, aggr_enc_labels):
     def _send_all_enc_aggr_labels(self, public_key, aggr_labels_bytes_list):
         start_time = time.time()
         enc_aggr_labels_list_pb = EncAggrLabelsList()
-        enc_aggr_labels_list_pb.public_key = self.ctx.codec.encode_enc_key(public_key)
+        enc_aggr_labels_list_pb.public_key = self.ctx.codec.encode_enc_key(
+            public_key)
 
         for aggr_labels_bytes in aggr_labels_bytes_list:
             enc_aggr_labels_pb = EncAggrLabels()
             utils.bytes_to_pb(enc_aggr_labels_pb, aggr_labels_bytes)
-            enc_aggr_labels_list_pb.enc_aggr_labels_list.append(enc_aggr_labels_pb)
+            enc_aggr_labels_list_pb.enc_aggr_labels_list.append(
+                enc_aggr_labels_pb)
 
         data = utils.pb_to_bytes(enc_aggr_labels_list_pb)
diff --git a/python/ppc_model/feature_engineering/vertical/utils.py b/python/ppc_model/feature_engineering/vertical/utils.py
index cdeeceaf..32faa2fc 100644
--- a/python/ppc_model/feature_engineering/vertical/utils.py
+++ b/python/ppc_model/feature_engineering/vertical/utils.py
@@ -41,9 +41,11 @@ def calculate_woe_iv(feature: np.ndarray, label: np.ndarray, num_bins: int = 10,
     combined = pd.DataFrame({'feature': feature, 'label': label})
     # Bin the dataset by feature value
     if is_continuous:
-        combined['bins'] = FeatureBinning.binning_continuous_feature(feature, num_bins, is_equal_freq)[0]
+        combined['bins'] = FeatureBinning.binning_continuous_feature(
+            feature, num_bins, is_equal_freq)[0]
     else:
-        combined['bins'] = FeatureBinning.binning_categorical_feature(feature)[0]
+        combined['bins'] = FeatureBinning.binning_categorical_feature(feature)[
+            0]
     # Count the positive/negative samples and the total samples in each bin
     grouped = combined.groupby('bins')['label'].agg(['count', 'sum'])
     grouped = grouped.rename(columns={'sum': 'pos_event'})
diff --git a/python/ppc_model/metrics/evaluation.py b/python/ppc_model/metrics/evaluation.py
index a759dc9d..2bf12ebb 100644
--- a/python/ppc_model/metrics/evaluation.py
+++ b/python/ppc_model/metrics/evaluation.py
@@ -24,11 +24,11 @@
 
 class Evaluation:
 
-    def __init__(self,
-                 ctx: Context,
-                 dataset: SecureDataset,
-                 train_praba:np.ndarray = None,
-                 test_praba:np.ndarray = None) -> None:
+    def __init__(self,
+                 ctx: Context,
+                 dataset: SecureDataset,
+                 train_praba: np.ndarray = None,
+                 test_praba: np.ndarray = None) -> None:
 
         self.job_id = ctx.job_id
         self.storage_client = ctx.components.storage_client
@@ -59,7 +59,8 @@ def __init__(self,
             train_ks, train_auc = self.evaluation_file(
                 ctx, dataset.train_idx, dataset.train_y, train_praba, 'train')
             if dataset.train_y is not None:
-                self.summary_evaluation(dataset, test_ks, test_auc, train_ks, train_auc)
+                self.summary_evaluation(
+                    dataset, test_ks, test_auc, train_ks, train_auc)
 
     @staticmethod
     def fevaluation(
@@ -98,33 +99,39 @@ def summary_evaluation(self, dataset, test_ks, test_auc, train_ks, train_auc):
     @staticmethod
     def calculate_ks_and_stats(predicted_proba, actual_label, num_buckets=10):
         # Merge the predicted probabilities and actual labels into one DataFrame
-        df = pd.DataFrame({'predicted_proba': predicted_proba.reshape(-1), 'actual_label': actual_label.reshape(-1)})
+        df = pd.DataFrame({'predicted_proba': predicted_proba.reshape(-1),
+                           'actual_label': actual_label.reshape(-1)})
         # Sort by predicted probability in descending order
         df_sorted = df.sort_values(by='predicted_proba', ascending=False)
         # Split the data into num_buckets buckets
         try:
-            df_sorted['bucket'] = pd.qcut(df_sorted['predicted_proba'], num_buckets, retbins=True, labels=False)[0]
+            df_sorted['bucket'] = pd.qcut(
+                df_sorted['predicted_proba'], num_buckets, retbins=True, labels=False)[0]
         except Exception:
-            df_sorted['bucket'] = pd.cut(df_sorted['predicted_proba'], num_buckets, retbins=True, labels=False)[0]
+            df_sorted['bucket'] = pd.cut(
+                df_sorted['predicted_proba'], num_buckets, retbins=True, labels=False)[0]
         # Aggregate the statistics of each bucket
         stats = df_sorted.groupby('bucket').agg({
             'actual_label': ['count', 'sum'],
             'predicted_proba': ['min', 'max']
         })
         # Compute the derived metrics
-        stats.columns = ['count', 'positive_count', 'predict_proba_min', 'predict_proba_max']
+        stats.columns = ['count', 'positive_count',
+                         'predict_proba_min', 'predict_proba_max']
         stats['positive_ratio'] = stats['positive_count'] / stats['count']
        stats['negative_ratio'] = 1 - stats['positive_ratio']
         stats['count_ratio'] = stats['count'] / stats['count'].sum()
         # stats['累计坏客户占比'] = stats['坏客户数'].cumsum() / stats['坏客户数'].sum()
         # Compute the cumulative positive ratio, accumulated starting from bucket 9
-        stats['cum_positive_ratio'] = stats['positive_count'].iloc[::-1].cumsum()[::-1] / stats['positive_count'].sum()
+        stats['cum_positive_ratio'] = stats['positive_count'].iloc[::-
+                                                                   1].cumsum()[::-1] / stats['positive_count'].sum()
         stats = stats[['count_ratio', 'count', 'positive_count',
-                       'positive_ratio', 'negative_ratio',
'cum_positive_ratio']].reset_index() + stats.columns = ['分组', '样本占比', '样本数', + '正样本数', '正样本比例', '负样本比例', '累积正样本占比'] return stats - def evaluation_file(self, ctx, data_index: np.ndarray, + def evaluation_file(self, ctx, data_index: np.ndarray, y_true: np.ndarray, y_praba: np.ndarray, label: str = 'test'): if label == 'train': self.model_result_file = ctx.train_model_result_file @@ -150,45 +157,55 @@ def evaluation_file(self, ctx, data_index: np.ndarray, retry_num += 1 try: with plot_lock: - ks_value, auc_value = Evaluation.plot_two_class_graph(self, y_true, y_praba) + ks_value, auc_value = Evaluation.plot_two_class_graph( + self, y_true, y_praba) except: - ctx.components.logger().info(f'y_true = {len(y_true)}, {y_true[0:2]}') - ctx.components.logger().info(f'y_praba = {len(y_praba)}, {y_praba[0:2]}') + ctx.components.logger().info( + f'y_true = {len(y_true)}, {y_true[0:2]}') + ctx.components.logger().info( + f'y_praba = {len(y_praba)}, {y_praba[0:2]}') err = traceback.format_exc() # ctx.components.logger().exception(err) ctx.components.logger().info( f'plot metrics in times-{retry_num} failed, traceback: {err}.') time.sleep(random.uniform(0.1, 3)) - - ResultFileHandling._upload_file(self.storage_client, self.metric_roc_file, self.remote_metric_roc_file) - ResultFileHandling._upload_file(self.storage_client, self.metric_ks_file, self.remote_metric_ks_file) - ResultFileHandling._upload_file(self.storage_client, self.metric_pr_file, self.remote_metric_pr_file) - ResultFileHandling._upload_file(self.storage_client, self.metric_acc_file, self.remote_metric_acc_file) + + ResultFileHandling._upload_file( + self.storage_client, self.metric_roc_file, self.remote_metric_roc_file) + ResultFileHandling._upload_file( + self.storage_client, self.metric_ks_file, self.remote_metric_ks_file) + ResultFileHandling._upload_file( + self.storage_client, self.metric_pr_file, self.remote_metric_pr_file) + ResultFileHandling._upload_file( + self.storage_client, self.metric_acc_file, self.remote_metric_acc_file) # ks table ks_table = self.calculate_ks_and_stats(y_praba, y_true) ks_table.to_csv(self.metric_ks_table, header=True, index=None) - ResultFileHandling._upload_file(self.storage_client, self.metric_ks_table, self.remote_metric_ks_table) + ResultFileHandling._upload_file( + self.storage_client, self.metric_ks_table, self.remote_metric_ks_table) else: ks_value = auc_value = None - + # predict result self._parse_model_result(data_index, y_true, y_praba) - ResultFileHandling._upload_file(self.storage_client, self.model_output_file, self.remote_model_output_file) + ResultFileHandling._upload_file( + self.storage_client, self.model_output_file, self.remote_model_output_file) return ks_value, auc_value def _parse_model_result(self, data_index, y_true=None, y_praba=None): - + np.savetxt(self.model_result_file, y_praba, delimiter=',', fmt='%f') if y_true is None: - df = pd.DataFrame(np.column_stack((data_index, y_praba)), columns=['id', 'class_pred']) + df = pd.DataFrame(np.column_stack( + (data_index, y_praba)), columns=['id', 'class_pred']) else: - df = pd.DataFrame(np.column_stack((data_index, y_true, y_praba)), + df = pd.DataFrame(np.column_stack((data_index, y_true, y_praba)), columns=['id', 'class_label', 'class_pred']) df['class_label'] = df['class_label'].astype(int) - + df['id'] = df['id'].astype(int) df['class_pred'] = df['class_pred'].astype(float) df = df.sort_values(by='id') @@ -202,7 +219,8 @@ def plot_two_class_graph(self, y_true, y_scores): plt.rcParams['figure.figsize'] = (12.0, 8.0) # plot ROC - fpr, 
tpr, thresholds = roc_curve(y_label_probs, y_pred_probs, pos_label=1) + fpr, tpr, thresholds = roc_curve( + y_label_probs, y_pred_probs, pos_label=1) auc_value = auc(fpr, tpr) plt.figure(f'roc-{self.job_id}') plt.title('ROC Curve') # give plot a title @@ -210,12 +228,12 @@ def plot_two_class_graph(self, y_true, y_scores): plt.ylabel('True Positive Rate (Sensitivity)') plt.plot([0, 1], [0, 1], 'k--', lw=2) plt.plot(fpr, tpr, label='area = {0:0.5f}' - ''.format(auc_value)) + ''.format(auc_value)) plt.legend(loc="lower right") plt.savefig(self.metric_roc_file, dpi=1000) # plt.show() - plt.close('all') + plt.close('all') gc.collect() # plot KS @@ -230,12 +248,12 @@ def plot_two_class_graph(self, y_true, y_scores): # mark the maximum KS value x_index = np.argwhere(abs(fpr - tpr) == ks_value)[0, 0] plt.plot((threshold_x[x_index], threshold_x[x_index]), (fpr[x_index], tpr[x_index]), - label='ks = {:.3f}'.format(ks_value), color='r', marker='o', markerfacecolor='r', markersize=5) + label='ks = {:.3f}'.format(ks_value), color='r', marker='o', markerfacecolor='r', markersize=5) plt.legend(loc="lower right") plt.savefig(self.metric_ks_file, dpi=1000) # plt.show() - plt.close('all') + plt.close('all') gc.collect() # plot Precision Recall @@ -251,7 +269,7 @@ def plot_two_class_graph(self, y_true, y_scores): plt.savefig(self.metric_pr_file, dpi=1000) # plt.show() - plt.close('all') + plt.close('all') gc.collect() # plot accuracy @@ -271,6 +289,6 @@ def plot_two_class_graph(self, y_true, y_scores): plt.savefig(self.metric_acc_file, dpi=1000) # plt.show() - plt.close('all') + plt.close('all') gc.collect() return (ks_value, auc_value) diff --git a/python/ppc_model/metrics/model_plot.py b/python/ppc_model/metrics/model_plot.py index ca169859..bdae534b 100644 --- a/python/ppc_model/metrics/model_plot.py +++ b/python/ppc_model/metrics/model_plot.py @@ -12,9 +12,9 @@ class ModelPlot: - + def __init__(self, model: VerticalBooster) -> None: - + self.ctx = model.ctx self.model = model self._tree_id = 0 @@ -24,17 +24,19 @@ def __init__(self, model: VerticalBooster) -> None: self.storage_client = self.ctx.components.storage_client if model._trees is not None and \ - self.ctx.components.config_data['AGENCY_ID'] in self.ctx.result_receiver_id_list: + self.ctx.components.config_data['AGENCY_ID'] in self.ctx.result_receiver_id_list: self.plot_tree() - + def plot_tree(self): trees = self.model._trees self._split = self.model._X_split for i, tree in enumerate(trees): if i < 6: - tree_file_path = self.ctx.model_tree_prefix + '_' + str(self._tree_id)+'.svg' - remote_tree_file_path = self.ctx.remote_model_tree_prefix + '_' + str(self._tree_id)+'.svg' + tree_file_path = self.ctx.model_tree_prefix + \ + '_' + str(self._tree_id)+'.svg' + remote_tree_file_path = self.ctx.remote_model_tree_prefix + \ + '_' + str(self._tree_id)+'.svg' self._tree_id += 1 self._leaf_id = 0 self._G = DiGraphTree() @@ -49,9 +51,11 @@ def plot_tree(self): retry_num += 1 try: with plot_lock: - self._G.tree_plot(figsize=(10, 5), save_filename=tree_file_path) + self._G.tree_plot( + figsize=(10, 5), save_filename=tree_file_path) except: - self.ctx.components.logger().info(f'tree_id = {i}, tree = {tree}') + self.ctx.components.logger().info( + f'tree_id = {i}, tree = {tree}') self.ctx.components.logger().info(f'G = {self._G}') err = traceback.format_exc() # self.ctx.components.logger().exception(err) @@ -59,15 +63,18 @@ def plot_tree(self): f'plot tree-{i} in times-{retry_num} failed, traceback: {err}.') time.sleep(random.uniform(0.1, 3)) - 
ResultFileHandling._upload_file(self.storage_client, tree_file_path, remote_tree_file_path) + ResultFileHandling._upload_file( + self.storage_client, tree_file_path, remote_tree_file_path) def _graph_gtree(self, tree, leaf_id=0, depth=0, orient=None, split_info=None): self._leaf_id += 1 self._G.add_node(self._leaf_id) if split_info is not None: if self.ctx.participant_id_list[split_info.agency_idx] == self.ctx.components.config_data['AGENCY_ID']: - feature = str(self.model.dataset.feature_name[split_info.agency_feature]) - value = str(round(float(self._split[split_info.agency_feature][split_info.value]), 4)) + feature = str( + self.model.dataset.feature_name[split_info.agency_feature]) + value = str( + round(float(self._split[split_info.agency_feature][split_info.value]), 4)) else: feature = str(split_info.feature) value = str(split_info.value) @@ -84,8 +91,10 @@ def _graph_gtree(self, tree, leaf_id=0, depth=0, orient=None, split_info=None): self._G.add_weighted_edges_from( [(leaf_id, self._leaf_id, orient+'_'+feature+'_'+value+'_'+str(split_info.w_right))]) my_leaf_id = self._leaf_id - self._graph_gtree(left_tree, my_leaf_id, depth+1, 'left', best_split_info) - self._graph_gtree(right_tree, my_leaf_id, depth+1, 'right', best_split_info) + self._graph_gtree(left_tree, my_leaf_id, depth + + 1, 'left', best_split_info) + self._graph_gtree(right_tree, my_leaf_id, depth + + 1, 'right', best_split_info) else: if leaf_id != 0: self._G.add_weighted_edges_from( @@ -99,7 +108,8 @@ def __init__(self): super().__init__() def tree_leaves(self): - leaves_list = [x for x in self.nodes() if self.out_degree(x)==0 and self.in_degree(x)<=1] + leaves_list = [x for x in self.nodes() if self.out_degree( + x) == 0 and self.in_degree(x) <= 1] return leaves_list def tree_dfs_nodes(self): @@ -107,7 +117,8 @@ def tree_dfs_nodes(self): return nodes_list def tree_dfs_leaves(self): - dfs_leaves = [x for x in self.tree_dfs_nodes() if x in self.tree_leaves()] + dfs_leaves = [x for x in self.tree_dfs_nodes() + if x in self.tree_leaves()] return dfs_leaves def tree_depth(self): @@ -127,7 +138,8 @@ def tree_plot(self, split=True, figsize=(20, 10), dpi=300, save_filename=None): if split: labels = {} # leaves = self.tree_leaves() - leaves = [x for x in self.nodes() if self.out_degree(x)==0 and self.in_degree(x)<=1] + leaves = [x for x in self.nodes() if self.out_degree(x) == + 0 and self.in_degree(x) <= 1] if leaves == [0]: leaves = [] @@ -146,7 +158,8 @@ def tree_plot(self, split=True, figsize=(20, 10), dpi=300, save_filename=None): else: in_node = list(nx.neighbors(self, n))[0] weight = edge_labels[(n, in_node)] - labels[n] = weight.split('_')[1] + ':' + weight.split('_')[2] + labels[n] = weight.split( + '_')[1] + ':' + weight.split('_')[2] # for key, value in edge_labels.items(): # edge_labels[key] = round(float(value.split('_')[3]), 4) @@ -169,7 +182,8 @@ def tree_plot(self, split=True, figsize=(20, 10), dpi=300, save_filename=None): '-' + str(round(float(value.split('_')[3]), 4)) plt.figure(figsize=figsize, dpi=dpi) - nx.draw(self, pos, with_labels=True, labels=labels, font_weight='bold') + nx.draw(self, pos, with_labels=True, + labels=labels, font_weight='bold') nx.draw_networkx_edge_labels(self, pos, edge_labels=edge_labels) # plt.show() if save_filename is not None: diff --git a/python/ppc_model/metrics/test/test_metrics.py b/python/ppc_model/metrics/test/test_metrics.py index 3545aaa7..5382d5d8 100644 --- a/python/ppc_model/metrics/test/test_metrics.py +++ b/python/ppc_model/metrics/test/test_metrics.py @@ -66,13 
+66,14 @@ class TestXgboostTraining(unittest.TestCase): args_a, args_b = mock_args() def test_active_metrics(self): - + active_components = Initializer(log_config_path='', config_path='') active_components.config_data = { 'JOB_TEMP_DIR': '/tmp/active', 'AGENCY_ID': ACTIVE_PARTY} active_components.mock_logger = MockLogger() task_info_a = SecureLGBMContext(self.args_a, active_components) - model_data = SecureDataset.simulate_dataset(data_size, feature_dim, has_label=True) + model_data = SecureDataset.simulate_dataset( + data_size, feature_dim, has_label=True) secure_dataset_a = SecureDataset(task_info_a, model_data) booster_a = VerticalLGBMActiveParty(task_info_a, secure_dataset_a) print(secure_dataset_a.feature_name) @@ -86,7 +87,8 @@ def test_active_metrics(self): # booster_a._train_praba = np.random.rand(len(secure_dataset_a.train_y)) booster_a._test_praba = np.random.rand(len(secure_dataset_a.test_y)) - Evaluation(task_info_a, secure_dataset_a, booster_a._train_praba, booster_a._test_praba) + Evaluation(task_info_a, secure_dataset_a, + booster_a._train_praba, booster_a._test_praba) def test_passive_metrics(self): @@ -95,7 +97,8 @@ def test_passive_metrics(self): 'JOB_TEMP_DIR': '/tmp/passive', 'AGENCY_ID': PASSIVE_PARTY} passive_components.mock_logger = MockLogger() task_info_b = SecureLGBMContext(self.args_b, passive_components) - model_data = SecureDataset.simulate_dataset(data_size, feature_dim, has_label=False) + model_data = SecureDataset.simulate_dataset( + data_size, feature_dim, has_label=False) secure_dataset_b = SecureDataset(task_info_b, model_data) booster_b = VerticalLGBMPassiveParty(task_info_b, secure_dataset_b) print(secure_dataset_b.feature_name) @@ -107,7 +110,8 @@ def test_passive_metrics(self): # booster_b._train_praba = np.random.rand(len(secure_dataset_b.train_idx)) booster_b._test_praba = np.random.rand(len(secure_dataset_b.test_idx)) - Evaluation(task_info_b, secure_dataset_b, booster_b._train_praba, booster_b._test_praba) + Evaluation(task_info_b, secure_dataset_b, + booster_b._train_praba, booster_b._test_praba) def test_model_plot(self): @@ -116,7 +120,8 @@ def test_model_plot(self): 'JOB_TEMP_DIR': '/tmp/active', 'AGENCY_ID': ACTIVE_PARTY} active_components.mock_logger = MockLogger() task_info_a = SecureLGBMContext(self.args_a, active_components) - model_data = SecureDataset.simulate_dataset(data_size, feature_dim, has_label=True) + model_data = SecureDataset.simulate_dataset( + data_size, feature_dim, has_label=True) secure_dataset_a = SecureDataset(task_info_a, model_data) booster_a = VerticalLGBMActiveParty(task_info_a, secure_dataset_a) if os.path.exists(booster_a.ctx.model_data_file): @@ -128,32 +133,32 @@ def test_digraphtree(self): Gtree.add_node(0) Gtree.add_nodes_from([1, 2]) Gtree.add_weighted_edges_from( - [(0, 1, 'left_'+str(2)+'_'+str(3)+'_'+str(0.5)), - (0, 2, 'right_'+str(2)+'_'+str(3)+'_'+str(0.9))]) + [(0, 1, 'left_'+str(2)+'_'+str(3)+'_'+str(0.5)), + (0, 2, 'right_'+str(2)+'_'+str(3)+'_'+str(0.9))]) Gtree.add_nodes_from([3, 4]) Gtree.add_weighted_edges_from( - [(1, 3, 'left_'+str(20)+'_'+str(4)+'_'+str(0.5)), - (1, 4, 'right_'+str(20)+'_'+str(4)+'_'+str(0.9))]) + [(1, 3, 'left_'+str(20)+'_'+str(4)+'_'+str(0.5)), + (1, 4, 'right_'+str(20)+'_'+str(4)+'_'+str(0.9))]) Gtree.add_nodes_from([5, 6]) Gtree.add_weighted_edges_from( - [(2, 5, 'left_'+str(2)+'_'+str(7)+'_'+str(0.5)), - (2, 6, 'right_'+str(2)+'_'+str(7)+'_'+str(0.9))]) + [(2, 5, 'left_'+str(2)+'_'+str(7)+'_'+str(0.5)), + (2, 6, 'right_'+str(2)+'_'+str(7)+'_'+str(0.9))]) 
Gtree.add_nodes_from([7, 8]) Gtree.add_weighted_edges_from( - [(3, 7, 'left_'+str(1)+'_'+str(11)+'_'+str(0.5)), - (3, 8, 'right_'+str(1)+'_'+str(11)+'_'+str(0.9))]) + [(3, 7, 'left_'+str(1)+'_'+str(11)+'_'+str(0.5)), + (3, 8, 'right_'+str(1)+'_'+str(11)+'_'+str(0.9))]) Gtree.add_nodes_from([9, 10]) Gtree.add_weighted_edges_from( - [(4, 9, 'left_'+str(18)+'_'+str(2)+'_'+str(0.5)), - (4, 10, 'right_'+str(18)+'_'+str(2)+'_'+str(0.9))]) + [(4, 9, 'left_'+str(18)+'_'+str(2)+'_'+str(0.5)), + (4, 10, 'right_'+str(18)+'_'+str(2)+'_'+str(0.9))]) Gtree.add_nodes_from([11, 12]) Gtree.add_weighted_edges_from( - [(5, 11, 'left_'+str(23)+'_'+str(25)+'_'+str(0.5)), - (5, 12, 'right_'+str(23)+'_'+str(25)+'_'+str(0.9))]) + [(5, 11, 'left_'+str(23)+'_'+str(25)+'_'+str(0.5)), + (5, 12, 'right_'+str(23)+'_'+str(25)+'_'+str(0.9))]) Gtree.add_nodes_from([13, 14]) Gtree.add_weighted_edges_from( - [(6, 13, 'left_'+str(16)+'_'+str(10)+'_'+str(0.5)), - (6, 14, 'right_'+str(16)+'_'+str(10)+'_'+str(0.9))]) + [(6, 13, 'left_'+str(16)+'_'+str(10)+'_'+str(0.5)), + (6, 14, 'right_'+str(16)+'_'+str(10)+'_'+str(0.9))]) # Gtree.tree_plot() # Gtree.tree_plot(split=False, figsize=(10, 5)) diff --git a/python/ppc_model/network/http/model_controller.py b/python/ppc_model/network/http/model_controller.py index 3ebbac2f..f8b0130f 100644 --- a/python/ppc_model/network/http/model_controller.py +++ b/python/ppc_model/network/http/model_controller.py @@ -30,7 +30,8 @@ def post(self, model_id): """ args = request.get_json() task_id = model_id - components.logger().info(f"run task request, task_id: {task_id}, args: {args}") + components.logger().info( + f"run task request, task_id: {task_id}, args: {args}") task_type = args['task_type'] components.task_manager.run_task( task_id, ModelTask(task_type), (args,)) @@ -43,7 +44,8 @@ def get(self, model_id): """ response = utils.BASE_RESPONSE task_id = model_id - status, traffic_volume, time_costs = components.task_manager.status(task_id) + status, traffic_volume, time_costs = components.task_manager.status( + task_id) response['data'] = { 'status': status, 'traffic_volume': traffic_volume, diff --git a/python/ppc_model/network/http/restx.py b/python/ppc_model/network/http/restx.py index 83dba0ce..553337ae 100644 --- a/python/ppc_model/network/http/restx.py +++ b/python/ppc_model/network/http/restx.py @@ -22,7 +22,8 @@ def default_error_handler(e): components.logger().exception(e) info = e.to_dict() response = {'errorCode': info['code'], 'message': info['message']} - components.logger().error(f"OnError: code: {info['code']}, message: {info['message']}") + components.logger().error( + f"OnError: code: {info['code']}, message: {info['message']}") return response, 500 diff --git a/python/ppc_model/network/stub.py b/python/ppc_model/network/stub.py index 7db22b1e..d472847b 100644 --- a/python/ppc_model/network/stub.py +++ b/python/ppc_model/network/stub.py @@ -47,7 +47,8 @@ def __init__( self.agency_id = agency_id self._thread_event_manager = thread_event_manager self._rpc_client = rpc_client - self._executor = ThreadPoolExecutor(max_workers=max(1, os.cpu_count() - 1)) + self._executor = ThreadPoolExecutor( + max_workers=max(1, os.cpu_count() - 1)) self._send_retry_times = send_retry_times self._retry_interval_s = retry_interval_s # cache of received messages [task_id:[sender:[key:[seq: data]]]] diff --git a/python/ppc_model/ppc_model_app.py b/python/ppc_model/ppc_model_app.py index baef888b..79060c82 100644 --- a/python/ppc_model/ppc_model_app.py +++ b/python/ppc_model/ppc_model_app.py @@ -1,30 +1,28 @@ # Note: 
here can't be refactored by autopep8 +from ppc_model.secure_lgbm.secure_lgbm_training_engine import SecureLGBMTrainingEngine +from ppc_model.secure_lgbm.secure_lgbm_prediction_engine import SecureLGBMPredictionEngine +from ppc_model.preprocessing.preprocessing_engine import PreprocessingEngine +from ppc_model.network.http.restx import api +from ppc_model.network.http.model_controller import ns2 as log_namespace +from ppc_model.network.http.model_controller import ns as task_namespace +from ppc_model.network.grpc.grpc_server import ModelService +from ppc_model.feature_engineering.feature_engineering_engine import FeatureEngineeringEngine +from ppc_model.common.protocol import ModelTask +from ppc_model.common.global_context import components +from ppc_common.ppc_utils import utils +from ppc_common.ppc_protos.generated import ppc_model_pb2_grpc +from paste.translogger import TransLogger +from flask import Flask, Blueprint +from cheroot.wsgi import Server as WSGIServer +from cheroot.ssl.builtin import BuiltinSSLAdapter +import grpc +from threading import Thread +from concurrent import futures +import os +import multiprocessing import sys sys.path.append("../") -import multiprocessing -import os -from concurrent import futures -from threading import Thread - -import grpc -from cheroot.ssl.builtin import BuiltinSSLAdapter -from cheroot.wsgi import Server as WSGIServer -from flask import Flask, Blueprint -from paste.translogger import TransLogger - -from ppc_common.ppc_protos.generated import ppc_model_pb2_grpc -from ppc_common.ppc_utils import utils -from ppc_model.common.global_context import components -from ppc_model.common.protocol import ModelTask -from ppc_model.feature_engineering.feature_engineering_engine import FeatureEngineeringEngine -from ppc_model.network.grpc.grpc_server import ModelService -from ppc_model.network.http.model_controller import ns as task_namespace -from ppc_model.network.http.model_controller import ns2 as log_namespace -from ppc_model.network.http.restx import api -from ppc_model.preprocessing.preprocessing_engine import PreprocessingEngine -from ppc_model.secure_lgbm.secure_lgbm_prediction_engine import SecureLGBMPredictionEngine -from ppc_model.secure_lgbm.secure_lgbm_training_engine import SecureLGBMTrainingEngine app = Flask(__name__) @@ -57,7 +55,8 @@ def model_serve(): if app.config['SSL_SWITCH'] == 0: ppc_serve = grpc.server(futures.ThreadPoolExecutor(max_workers=max(1, os.cpu_count() - 1)), options=components.grpc_options) - ppc_model_pb2_grpc.add_ModelServiceServicer_to_server(ModelService(), ppc_serve) + ppc_model_pb2_grpc.add_ModelServiceServicer_to_server( + ModelService(), ppc_serve) address = "[::]:{}".format(app.config['RPC_PORT']) ppc_serve.add_insecure_port(address) else: @@ -74,7 +73,8 @@ def model_serve(): ppc_serve = grpc.server(futures.ThreadPoolExecutor(max_workers=max(1, os.cpu_count() - 1)), options=components.grpc_options) - ppc_model_pb2_grpc.add_ModelServiceServicer_to_server(ModelService(), ppc_serve) + ppc_model_pb2_grpc.add_ModelServiceServicer_to_server( + ModelService(), ppc_serve) address = "[::]:{}".format(app.config['RPC_PORT']) ppc_serve.add_secure_port(address, server_credentials) diff --git a/python/ppc_model/preprocessing/local_processing/preprocessing.py b/python/ppc_model/preprocessing/local_processing/preprocessing.py index 55082a62..f0d39faf 100644 --- a/python/ppc_model/preprocessing/local_processing/preprocessing.py +++ b/python/ppc_model/preprocessing/local_processing/preprocessing.py @@ -103,8 +103,10 @@ def 
process_dataframe(dataset_df: pd.DataFrame, model_setting: ModelSetting, xgb if model_setting.eval_set_column is not None: if model_setting.eval_set_column in dataset_df.columns: eval_column = model_setting.eval_set_column - dataset_df[['id', eval_column]].to_csv(ctx.eval_column_file, index=None) - ctx.components.storage_client.upload_file(ctx.eval_column_file, job_id + os.sep + ctx.EVAL_COLUMN_FILE) + dataset_df[['id', eval_column]].to_csv( + ctx.eval_column_file, index=None) + ctx.components.storage_client.upload_file( + ctx.eval_column_file, job_id + os.sep + ctx.EVAL_COLUMN_FILE) if model_setting.eval_set_column != model_setting.psi_select_col: dataset_df = dataset_df.drop(columns=[eval_column]) diff --git a/python/ppc_model/preprocessing/tests/test_preprocessing.py b/python/ppc_model/preprocessing/tests/test_preprocessing.py index 279e00e6..244c6a41 100644 --- a/python/ppc_model/preprocessing/tests/test_preprocessing.py +++ b/python/ppc_model/preprocessing/tests/test_preprocessing.py @@ -534,6 +534,7 @@ def test_process_train_dataframe_with_additional_columns(): # Assert that the processed DataFrame matches the expected output assert processed_df.equals(expected_output) + def test_merge_column_info_from_file(): col_info_file_path = "./test_column_info_merge.csv" iv_info_file_path = "./test_column_info_iv.csv" @@ -548,7 +549,8 @@ def test_merge_column_info_from_file(): # assert expected_df.equals(union_df) column_info_str = json.dumps(column_info_fm.to_dict(orient='index')) assert column_info_str == col_str_expected - + + def construct_dataset(num_samples, num_features, file_path): np.random.seed(0) # generate the label column labels = np.random.randint(0, 2, size=num_samples) @@ -557,24 +559,26 @@ def construct_dataset(num_samples, num_features, file_path): features = np.random.rand(num_samples, num_features) # convert the labels to a DataFrame labels_df = pd.DataFrame(labels, columns=['Label']) - + # convert the features to a DataFrame features_df = pd.DataFrame(features) - + # concatenate the label and feature DataFrames dataset_df = pd.concat([labels_df, features_df], axis=1) - + # write the DataFrame to a CSV file dataset_df.to_csv(file_path, index=False) - + return labels, features + def test_gen_file(): num_samples = 400000 num_features = 100 file_path = "./dataset-{}-{}.csv".format(num_samples, num_features) construct_dataset(num_samples, num_features, file_path) + def test_large_process_train_dataframe(): num_samples = 400000 num_features = 100 @@ -624,15 +628,13 @@ def test_large_process_train_dataframe(): column_info1 = process_dataframe( df_filled, xgb_dict, "./xgb_data_file_path", utils.AlgorithmType.Train.name, "j-123456") end_time = time.time() - print(f"test_large_process_train_dataframe time cost:{end_time-start_time}, num_samples: {num_samples}, num_features: {num_features}") + print( + f"test_large_process_train_dataframe time cost:{end_time-start_time}, num_samples: {num_samples}, num_features: {num_features}") + - - - - # Run the tests # pytest.main() -if __name__=="__main__": +if __name__ == "__main__": import time # test_large_process_train_dataframe() time1 = time.time() @@ -667,6 +669,7 @@ def test_large_process_train_dataframe(): print(f"test_process_psi time cost: {time8-time7}") print(f"test_process_dataframe time cost: {time9-time8}") print(f"test_process_train_dataframe time cost: {time10-time9}") - print(f"test_process_train_dataframe_with_additional_columns time cost: {time11-time10}") + print( + f"test_process_train_dataframe_with_additional_columns time cost: {time11-time10}") print(f"test_merge_column_info_from_file time cost: {time12-time11}") print("All tests pass!") diff --git 
a/python/ppc_model/secure_lgbm/monitor/callback.py b/python/ppc_model/secure_lgbm/monitor/callback.py index fa1247de..cd145668 100644 --- a/python/ppc_model/secure_lgbm/monitor/callback.py +++ b/python/ppc_model/secure_lgbm/monitor/callback.py @@ -42,7 +42,8 @@ def __init__( self.callbacks = set(callbacks) for cb in callbacks: if not isinstance(cb, TrainingCallback): - raise TypeError("callback must be an instance of `TrainingCallback`.") + raise TypeError( + "callback must be an instance of `TrainingCallback`.") msg = ( "feval must be callable object for monitoring. For builtin metrics" @@ -79,5 +80,6 @@ def after_iteration( ) -> bool: model.after_iteration(pred, eval_on_test) model.eval(self.feval) - ret = any(c.after_iteration(model, model.get_epoch()) for c in self.callbacks) + ret = any(c.after_iteration(model, model.get_epoch()) + for c in self.callbacks) return ret diff --git a/python/ppc_model/secure_lgbm/monitor/early_stopping.py b/python/ppc_model/secure_lgbm/monitor/early_stopping.py index de718ee3..c6140680 100644 --- a/python/ppc_model/secure_lgbm/monitor/early_stopping.py +++ b/python/ppc_model/secure_lgbm/monitor/early_stopping.py @@ -111,7 +111,8 @@ def after_iteration( ) -> bool: history = model.get_history() if len(history.keys()) < 1: - raise ValueError("Must have at least 1 validation dataset for early stopping.") + raise ValueError( + "Must have at least 1 validation dataset for early stopping.") metric_name = self.metric_name # The latest score diff --git a/python/ppc_model/secure_lgbm/monitor/evaluation_monitor.py b/python/ppc_model/secure_lgbm/monitor/evaluation_monitor.py index c1038bbb..432149a7 100644 --- a/python/ppc_model/secure_lgbm/monitor/evaluation_monitor.py +++ b/python/ppc_model/secure_lgbm/monitor/evaluation_monitor.py @@ -25,7 +25,8 @@ def _draw_figure(model: _Model): plt.plot(iterations, values, label=metric) max_index = values.index(max(values)) plt.scatter(max_index + 1, values[max_index], color='green') - plt.text(max_index + 1, values[max_index], f'{values[max_index]:.4f}', fontsize=9, ha='right') + plt.text(max_index + 1, values[max_index], + f'{values[max_index]:.4f}', fontsize=9, ha='right') plt.legend() plt.title('Metrics Over Iterations') diff --git a/python/ppc_model/secure_lgbm/monitor/train_callback_unittest.py b/python/ppc_model/secure_lgbm/monitor/train_callback_unittest.py index 6429807f..6d995868 100644 --- a/python/ppc_model/secure_lgbm/monitor/train_callback_unittest.py +++ b/python/ppc_model/secure_lgbm/monitor/train_callback_unittest.py @@ -49,7 +49,8 @@ def setUp(self): self.test_y_true = np.random.randint(0, 2, 10000) self.y_pred = np.random.rand(10000) self.model = Booster(self.y_true, self.test_y_true) - self.early_stopping = EarlyStopping(rounds=4, metric_name='auc', maximize=True) + self.early_stopping = EarlyStopping( + rounds=4, metric_name='auc', maximize=True) def test_early_stopping(self): stop = False @@ -58,7 +59,8 @@ def test_early_stopping(self): y_pred = np.random.rand(10000) self.model.after_iteration(y_pred) self.model.eval(fevaluation) - stop = self.early_stopping.after_iteration(self.model, self.model.epoch) + stop = self.early_stopping.after_iteration( + self.model, self.model.epoch) print(self.model.epoch, stop) @@ -88,9 +90,11 @@ def setUp(self): self.test_y_true = np.random.randint(0, 2, 10000) self.y_pred = np.random.rand(10000) self.model = Booster(self.y_true, self.test_y_true, 'tmp') - self.early_stopping = EarlyStopping(rounds=4, metric_name='auc', maximize=True) + self.early_stopping = EarlyStopping( + 
rounds=4, metric_name='auc', maximize=True) self.monitor = EvaluationMonitor(log, period=2) - self.container = CallbackContainer([self.early_stopping, self.monitor], fevaluation) + self.container = CallbackContainer( + [self.early_stopping, self.monitor], fevaluation) def test_callback_container(self): stop = False diff --git a/python/ppc_model/secure_lgbm/secure_lgbm_prediction_engine.py b/python/ppc_model/secure_lgbm/secure_lgbm_prediction_engine.py index 1f481a78..203f71e1 100644 --- a/python/ppc_model/secure_lgbm/secure_lgbm_prediction_engine.py +++ b/python/ppc_model/secure_lgbm/secure_lgbm_prediction_engine.py @@ -34,5 +34,5 @@ def run(args): # compute evaluation metrics for the test-set predictions Evaluation(task_info, secure_dataset, test_praba=test_praba) - + ResultFileHandling(task_info) diff --git a/python/ppc_model/secure_lgbm/test/test_cipher_packing.py b/python/ppc_model/secure_lgbm/test/test_cipher_packing.py index 1f205a71..1120c6b3 100644 --- a/python/ppc_model/secure_lgbm/test/test_cipher_packing.py +++ b/python/ppc_model/secure_lgbm/test/test_cipher_packing.py @@ -20,39 +20,45 @@ def test_cipher_list(self): ciphers = paillier.encrypt_batch_parallel(data_list) # ciphers = paillier.encrypt_batch(data_list) print("enc:", time.time() - start_time, "seconds") - + start_time = time.time() enc_data_pb = CipherList() - enc_data_pb.public_key = PaillierCodec.encode_enc_key(paillier.public_key) + enc_data_pb.public_key = PaillierCodec.encode_enc_key( + paillier.public_key) for cipher in ciphers: paillier_cipher = ModelCipher() - paillier_cipher.ciphertext, paillier_cipher.exponent = PaillierCodec.encode_cipher(cipher) + paillier_cipher.ciphertext, paillier_cipher.exponent = PaillierCodec.encode_cipher( + cipher) enc_data_pb.cipher_list.append(paillier_cipher) print("pack ciphers:", time.time() - start_time, "seconds") ciphers2 = [] for i in range(100): ciphers2.append(np.array(ciphers[10*i:10*(i+1)]).sum()) - + start_time = time.time() enc_data_pb2 = CipherList() - enc_data_pb2.public_key = PaillierCodec.encode_enc_key(paillier.public_key) + enc_data_pb2.public_key = PaillierCodec.encode_enc_key( + paillier.public_key) for cipher in ciphers2: paillier_cipher2 = ModelCipher() - paillier_cipher2.ciphertext, paillier_cipher2.exponent = PaillierCodec.encode_cipher(cipher, be_secure=False) + paillier_cipher2.ciphertext, paillier_cipher2.exponent = PaillierCodec.encode_cipher( + cipher, be_secure=False) enc_data_pb2.cipher_list.append(paillier_cipher2) print("pack ciphers:", time.time() - start_time, "seconds") ciphers3 = [] for i in range(100): ciphers3.append(np.array(ciphers[10*i:10*(i+1)]).sum()) - + start_time = time.time() enc_data_pb3 = CipherList() - enc_data_pb3.public_key = PaillierCodec.encode_enc_key(paillier.public_key) + enc_data_pb3.public_key = PaillierCodec.encode_enc_key( + paillier.public_key) for cipher in ciphers3: paillier_cipher3 = ModelCipher() - paillier_cipher3.ciphertext, paillier_cipher3.exponent = PaillierCodec.encode_cipher(cipher) + paillier_cipher3.ciphertext, paillier_cipher3.exponent = PaillierCodec.encode_cipher( + cipher) enc_data_pb3.cipher_list.append(paillier_cipher3) print("pack ciphers:", time.time() - start_time, "seconds") diff --git a/python/ppc_model/secure_lgbm/test/test_pack_gh.py b/python/ppc_model/secure_lgbm/test/test_pack_gh.py index dba82280..798de369 100644 --- a/python/ppc_model/secure_lgbm/test/test_pack_gh.py +++ b/python/ppc_model/secure_lgbm/test/test_pack_gh.py @@ -15,16 +15,16 @@ def test_pack_gh(self): result_array = np.array( [429496329600000000000000002000, 
200000000000000000001500, - 429496599600000000004294965896, 0, + 429496599600000000004294965896, 0, 4294965616, 429495194200000000000000001235], dtype=object) - + assert np.array_equal(gh_list, result_array) def test_unpack_gh(self): gh_list = np.array( [429496329600000000000000002000, 200000000000000000001500, - 429496599600000000004294965896, 0, + 429496599600000000004294965896, 0, 4294965616, 429495194200000000000000001235], dtype=object) gh_sum_list = np.array([sum(gh_list), sum(gh_list)*2]) diff --git a/python/ppc_model/secure_lgbm/test/test_save_load_model.py b/python/ppc_model/secure_lgbm/test/test_save_load_model.py index a899b805..e42c4369 100644 --- a/python/ppc_model/secure_lgbm/test/test_save_load_model.py +++ b/python/ppc_model/secure_lgbm/test/test_save_load_model.py @@ -56,7 +56,7 @@ def test_save_load_model(self): booster_predict = VerticalBooster(task_info, dataset=None) booster_predict.load_model() - + assert x_split == booster_predict._X_split assert trees == booster_predict._trees @@ -66,15 +66,15 @@ def _build_tree(max_depth, depth=0, weight=0): if depth == max_depth: return weight - best_split_info = BestSplitInfo( - feature=np.random.randint(0,10), - value=np.random.randint(0,4), + best_split_info = BestSplitInfo( + feature=np.random.randint(0, 10), + value=np.random.randint(0, 4), best_gain=np.random.rand(), w_left=np.random.rand(), w_right=np.random.rand(), - agency_idx=np.random.randint(0,2), - agency_feature=np.random.randint(0,5) - ) + agency_idx=np.random.randint(0, 2), + agency_feature=np.random.randint(0, 5) + ) # print(best_split_info) if best_split_info.best_gain > 0.2: diff --git a/python/ppc_model/secure_lgbm/test/test_secure_lgbm_performance_training.py b/python/ppc_model/secure_lgbm/test/test_secure_lgbm_performance_training.py index 936a53ac..e992415d 100644 --- a/python/ppc_model/secure_lgbm/test/test_secure_lgbm_performance_training.py +++ b/python/ppc_model/secure_lgbm/test/test_secure_lgbm_performance_training.py @@ -84,8 +84,10 @@ def setUp(self): send_retry_times=3, retry_interval_s=0.1 ) - self._active_rpc_client.set_message_handler(self._passive_stub.on_message_received) - self._passive_rpc_client.set_message_handler(self._active_stub.on_message_received) + self._active_rpc_client.set_message_handler( + self._passive_stub.on_message_received) + self._passive_rpc_client.set_message_handler( + self._active_stub.on_message_received) def test_fit(self): args_a, args_b = mock_args() @@ -96,7 +98,8 @@ def test_fit(self): 'JOB_TEMP_DIR': '/tmp/active', 'AGENCY_ID': ACTIVE_PARTY} active_components.mock_logger = MockLogger() task_info_a = SecureLGBMContext(args_a, active_components) - model_data = SecureDataset.simulate_dataset(data_size, feature_dim, has_label=True) + model_data = SecureDataset.simulate_dataset( + data_size, feature_dim, has_label=True) secure_dataset_a = SecureDataset(task_info_a, model_data) booster_a = VerticalLGBMActiveParty(task_info_a, secure_dataset_a) print(secure_dataset_a.feature_name) @@ -113,7 +116,8 @@ def test_fit(self): 'JOB_TEMP_DIR': '/tmp/passive', 'AGENCY_ID': PASSIVE_PARTY} passive_components.mock_logger = MockLogger() task_info_b = SecureLGBMContext(args_b, passive_components) - model_data = SecureDataset.simulate_dataset(data_size, feature_dim, has_label=False) + model_data = SecureDataset.simulate_dataset( + data_size, feature_dim, has_label=False) secure_dataset_b = SecureDataset(task_info_b, model_data) booster_b = VerticalLGBMPassiveParty(task_info_b, secure_dataset_b) print(secure_dataset_b.feature_name) @@ 
-128,14 +132,16 @@ def active_worker(): booster_a.save_model() train_praba = booster_a.get_train_praba() test_praba = booster_a.get_test_praba() - Evaluation(task_info_a, secure_dataset_a, train_praba, test_praba) + Evaluation(task_info_a, secure_dataset_a, + train_praba, test_praba) # ModelPlot(booster_a) ResultFileHandling(task_info_a) booster_a.load_model() booster_a.predict() test_praba = booster_a.get_test_praba() task_info_a.algorithm_type = 'PPC_PREDICT' - Evaluation(task_info_a, secure_dataset_a, test_praba=test_praba) + Evaluation(task_info_a, secure_dataset_a, + test_praba=test_praba) ResultFileHandling(task_info_a) except Exception as e: task_info_a.components.logger().info(traceback.format_exc()) @@ -146,14 +152,16 @@ def passive_worker(): booster_b.save_model() train_praba = booster_b.get_train_praba() test_praba = booster_b.get_test_praba() - Evaluation(task_info_b, secure_dataset_b, train_praba, test_praba) + Evaluation(task_info_b, secure_dataset_b, + train_praba, test_praba) # ModelPlot(booster_b) ResultFileHandling(task_info_b) booster_b.load_model() booster_b.predict() test_praba = booster_b.get_test_praba() task_info_b.algorithm_type = 'PPC_PREDICT' - Evaluation(task_info_b, secure_dataset_b, test_praba=test_praba) + Evaluation(task_info_b, secure_dataset_b, + test_praba=test_praba) ResultFileHandling(task_info_b) except Exception as e: task_info_b.components.logger().info(traceback.format_exc()) diff --git a/python/ppc_model/secure_lgbm/test/test_secure_lgbm_training.py b/python/ppc_model/secure_lgbm/test/test_secure_lgbm_training.py index 7cddb6db..ea0806ce 100644 --- a/python/ppc_model/secure_lgbm/test/test_secure_lgbm_training.py +++ b/python/ppc_model/secure_lgbm/test/test_secure_lgbm_training.py @@ -134,14 +134,16 @@ def active_worker(): booster_a.save_model() train_praba = booster_a.get_train_praba() test_praba = booster_a.get_test_praba() - Evaluation(task_info_a, secure_dataset_a, train_praba, test_praba) + Evaluation(task_info_a, secure_dataset_a, + train_praba, test_praba) ModelPlot(booster_a) ResultFileHandling(task_info_a) booster_a.load_model() booster_a.predict() test_praba = booster_a.get_test_praba() task_info_a.algorithm_type = 'PPC_PREDICT' - Evaluation(task_info_a, secure_dataset_a, test_praba=test_praba) + Evaluation(task_info_a, secure_dataset_a, + test_praba=test_praba) ResultFileHandling(task_info_a) except Exception as e: task_info_a.components.logger().info(traceback.format_exc()) @@ -152,14 +154,16 @@ def passive_worker(): booster_b.save_model() train_praba = booster_b.get_train_praba() test_praba = booster_b.get_test_praba() - Evaluation(task_info_b, secure_dataset_b, train_praba, test_praba) + Evaluation(task_info_b, secure_dataset_b, + train_praba, test_praba) ModelPlot(booster_b) ResultFileHandling(task_info_b) booster_b.load_model() booster_b.predict() test_praba = booster_b.get_test_praba() task_info_b.algorithm_type = 'PPC_PREDICT' - Evaluation(task_info_b, secure_dataset_b, test_praba=test_praba) + Evaluation(task_info_b, secure_dataset_b, + test_praba=test_praba) ResultFileHandling(task_info_b) except Exception as e: task_info_b.components.logger().info(traceback.format_exc()) diff --git a/python/ppc_model/secure_lgbm/vertical/active_party.py b/python/ppc_model/secure_lgbm/vertical/active_party.py index 524c8266..6451db6b 100644 --- a/python/ppc_model/secure_lgbm/vertical/active_party.py +++ b/python/ppc_model/secure_lgbm/vertical/active_party.py @@ -35,7 +35,8 @@ def __init__(self, ctx: SecureLGBMContext, dataset: SecureDataset) -> 
None: self.storage_client = ctx.components.storage_client self.feature_importance_store = FeatureImportanceStore( FeatureImportanceStore.DEFAULT_IMPORTANCE_LIST, None, self.log) - self.log.info(f'task {self.ctx.task_id}: print all params: {self.params.get_all_params()}') + self.log.info( + f'task {self.ctx.task_id}: print all params: {self.params.get_all_params()}') def fit( self, @@ -51,7 +52,8 @@ def fit( for _ in range(self.params.n_estimators): self._tree_id += 1 start_time = time.time() - self.log.info(f'task {self.ctx.task_id}: Starting n_estimators-{self._tree_id} in active party.') + self.log.info( + f'task {self.ctx.task_id}: Starting n_estimators-{self._tree_id} in active party.') # initialization feature_select, instance, used_glist, used_hlist = self._init_each_tree() @@ -72,21 +74,27 @@ def fit( # evaluation if not self.params.silent and self.dataset.train_y is not None: - auc = Evaluation.fevaluation(self.dataset.train_y, self._train_praba)['auc'] - self.log.info(f'task {self.ctx.task_id}: n_estimators-{self._tree_id}, auc: {auc}.') + auc = Evaluation.fevaluation( + self.dataset.train_y, self._train_praba)['auc'] + self.log.info( + f'task {self.ctx.task_id}: n_estimators-{self._tree_id}, auc: {auc}.') self.log.info(f'task {self.ctx.task_id}: Ending n_estimators-{self._tree_id}, ' f'time_costs: {time.time() - start_time}s.') # predict on the validation set self._test_weights += self._predict_tree( - tree, self._test_X_bin, np.ones(self._test_X_bin.shape[0], dtype=bool), + tree, self._test_X_bin, np.ones( + self._test_X_bin.shape[0], dtype=bool), LGBMMessage.TEST_LEAF_MASK.value) self._test_praba = self._loss_func.sigmoid(self._test_weights) if not self.params.silent and self.dataset.test_y is not None: - auc = Evaluation.fevaluation(self.dataset.test_y, self._test_praba)['auc'] - self.log.info(f'task {self.ctx.task_id}: n_estimators-{self._tree_id}, test auc: {auc}.') + auc = Evaluation.fevaluation( + self.dataset.test_y, self._test_praba)['auc'] + self.log.info( + f'task {self.ctx.task_id}: n_estimators-{self._tree_id}, test auc: {auc}.') if self._iteration_early_stop(): - self.log.info(f"task {self.ctx.task_id}: lgbm early stop after {self._tree_id} iterations.") + self.log.info( + f"task {self.ctx.task_id}: lgbm early stop after {self._tree_id} iterations.") break self._end_active_data() @@ -103,7 +111,8 @@ def predict(self, dataset: SecureDataset = None) -> np.ndarray: dataset.feature_name, self.params.categorical_feature) test_weights = self._init_weight(dataset.test_X.shape[0]) - test_X_bin = self._split_test_data(self.ctx, dataset.test_X, self._X_split) + test_X_bin = self._split_test_data( + self.ctx, dataset.test_X, self._X_split) for tree in self._trees: test_weights += self._predict_tree( @@ -114,7 +123,8 @@ def predict(self, dataset: SecureDataset = None) -> np.ndarray: if dataset.test_y is not None: auc = Evaluation.fevaluation(dataset.test_y, test_praba)['auc'] self.log.info(f'task {self.ctx.task_id}: predict test auc: {auc}.') - self.log.info(f'task {self.ctx.task_id}: Ending predict, time_costs: {time.time() - start_time}s.') + self.log.info( + f'task {self.ctx.task_id}: Ending predict, time_costs: {time.time() - start_time}s.') self._end_active_data(is_train=False) @@ -127,9 +137,12 @@ def _init_active_data(self): # initialize the features of all participants for i in range(1, len(self.ctx.participant_id_list)): - feature_name_bytes = self._receive_byte_data(self.ctx, LGBMMessage.FEATURE_NAME.value, i) - self._all_feature_name.append([s.decode('utf-8') for s in feature_name_bytes.split(b' ') if s]) - self._all_feature_num += 
len([s.decode('utf-8') for s in feature_name_bytes.split(b' ') if s]) + feature_name_bytes = self._receive_byte_data( + self.ctx, LGBMMessage.FEATURE_NAME.value, i) + self._all_feature_name.append( + [s.decode('utf-8') for s in feature_name_bytes.split(b' ') if s]) + self._all_feature_num += len([s.decode('utf-8') + for s in feature_name_bytes.split(b' ') if s]) self.log.info(f'task {self.ctx.task_id}: total feature number:{self._all_feature_num}, ' f'total feature name: {self._all_feature_name}.') @@ -139,18 +152,21 @@ def _init_active_data(self): self.dataset.feature_name, self.params.categorical_feature) # update the feature list in feature_importance - self.feature_importance_store.set_init(list(itertools.chain(*self._all_feature_name))) + self.feature_importance_store.set_init( + list(itertools.chain(*self._all_feature_name))) # initialize the binned dataset feat_bin = FeatureBinning(self.ctx) - self._X_bin, self._X_split = feat_bin.data_binning(self.dataset.train_X) + self._X_bin, self._X_split = feat_bin.data_binning( + self.dataset.train_X) def _init_each_tree(self): if self.callback_container: self.callback_container.before_iteration(self.model) - gradient = self._loss_func.compute_gradient(self.dataset.train_y, self._train_praba) + gradient = self._loss_func.compute_gradient( + self.dataset.train_y, self._train_praba) hessian = self._loss_func.compute_hessian(self._train_praba) feature_select = FeatureSelection.feature_selecting( @@ -171,7 +187,8 @@ def _send_gh_instance_list(self, instance, glist, hlist): start_time = time.time() self.log.info(f'task {self.ctx.task_id}: Starting n_estimators-{self._tree_id} ' f'encrypt g & h in active party.') - enc_ghlist = self.ctx.phe.encrypt_batch_parallel((gh_list).astype('object')) + enc_ghlist = self.ctx.phe.encrypt_batch_parallel( + (gh_list).astype('object')) self.log.info(f'task {self.ctx.task_id}: Finished n_estimators-{self._tree_id} ' f'encrypt gradient & hessian time_costs: {time.time() - start_time}.') @@ -194,18 +211,22 @@ def _build_tree(self, feature_select, instance, glist, hlist, depth=0, weight=0) if self.params.colsample_bylevel > 0 and self.params.colsample_bylevel < 1: feature_select_level = sorted(np.random.choice( feature_select, size=int(len(feature_select) * self.params.colsample_bylevel), replace=False)) - best_split_info = self._find_best_split(feature_select_level, instance, glist, hlist) + best_split_info = self._find_best_split( + feature_select_level, instance, glist, hlist) else: - best_split_info = self._find_best_split(feature_select, instance, glist, hlist) + best_split_info = self._find_best_split( + feature_select, instance, glist, hlist) if best_split_info.best_gain > 0 and best_split_info.best_gain > self.params.min_split_gain: gain_list = {FeatureImportanceType.GAIN: best_split_info.best_gain, FeatureImportanceType.WEIGHT: 1} - self.feature_importance_store.update_feature_importance(best_split_info.feature, gain_list) - left_mask, right_mask = self._get_leaf_mask(best_split_info, instance) + self.feature_importance_store.update_feature_importance( + best_split_info.feature, gain_list) + left_mask, right_mask = self._get_leaf_mask( + best_split_info, instance) if (abs(best_split_info.w_left) * sum(left_mask) / self.params.lr) < self.params.min_child_weight or \ - (abs(best_split_info.w_right) * sum(right_mask) / self.params.lr) < self.params.min_child_weight: + (abs(best_split_info.w_right) * sum(right_mask) / self.params.lr) < self.params.min_child_weight: return weight if sum(left_mask) < self.params.min_child_samples or sum(right_mask) < 
self.params.min_child_samples: return weight @@ -229,9 +250,11 @@ def _predict_tree(self, tree, X_bin, leaf_mask, key_type): if self.ctx.participant_id_list[best_split_info.agency_idx] == \ self.ctx.components.config_data['AGENCY_ID']: if best_split_info.agency_feature in self.params.my_categorical_idx: - left_mask = X_bin[:, best_split_info.agency_feature] == best_split_info.value + left_mask = X_bin[:, + best_split_info.agency_feature] == best_split_info.value else: - left_mask = X_bin[:, best_split_info.agency_feature] <= best_split_info.value + left_mask = X_bin[:, + best_split_info.agency_feature] <= best_split_info.value else: left_mask = np.frombuffer( self._receive_byte_data( @@ -239,8 +262,10 @@ def _predict_tree(self, tree, X_bin, leaf_mask, key_type): f'{key_type}_{best_split_info.tree_id}_{best_split_info.leaf_id}', best_split_info.agency_idx), dtype='bool') right_mask = ~left_mask - left_weight = self._predict_tree(left_subtree, X_bin, leaf_mask * left_mask, key_type) - right_weight = self._predict_tree(right_subtree, X_bin, leaf_mask * right_mask, key_type) + left_weight = self._predict_tree( + left_subtree, X_bin, leaf_mask * left_mask, key_type) + right_weight = self._predict_tree( + right_subtree, X_bin, leaf_mask * right_mask, key_type) return left_weight + right_weight def _find_best_split(self, feature_select, instance, glist, hlist): @@ -248,7 +273,8 @@ def _find_best_split(self, feature_select, instance, glist, hlist): self.log.info(f'task {self.ctx.task_id}: Starting n_estimators-{self._tree_id} ' f'leaf-{self._leaf_id} in active party.') grad_hist, hess_hist = self._get_gh_hist(instance, glist, hlist) - best_split_info = self._get_best_split_point(feature_select, glist, hlist, grad_hist, hess_hist) + best_split_info = self._get_best_split_point( + feature_select, glist, hlist, grad_hist, hess_hist) # print('grad_hist_sum', [sum(sublist) for sublist in grad_hist]) best_split_info.tree_id = self._tree_id @@ -271,7 +297,8 @@ def _find_best_split(self, feature_select, instance, glist, hlist): return best_split_info def _get_gh_hist(self, instance, glist, hlist): - ghist, hhist = self._calculate_hist(self._X_bin, instance, glist, hlist) + ghist, hhist = self._calculate_hist( + self._X_bin, instance, glist, hlist) for partner_index in range(1, len(self.ctx.participant_id_list)): partner_feature_name = self._all_feature_name[partner_index] @@ -283,7 +310,8 @@ def _get_gh_hist(self, instance, glist, hlist): partner_index, matrix_data=True) for feature_index in range(len(partner_feature_name)): - ghk_hist = np.array(self.ctx.phe.decrypt_batch(gh_hist[feature_index]), dtype='object') + ghk_hist = np.array(self.ctx.phe.decrypt_batch( + gh_hist[feature_index]), dtype='object') gk_hist, hk_hist = self.unpacking_gh(ghk_hist) partner_ghist[feature_index] = gk_hist partner_hhist[feature_index] = hk_hist @@ -374,7 +402,8 @@ def _get_best_split_agency(all_feature_name, feature): def _init_valid_data(self): self._test_weights = self._init_weight(self.dataset.test_X.shape[0]) - self._test_X_bin = self._split_test_data(self.ctx, self.dataset.test_X, self._X_split) + self._test_X_bin = self._split_test_data( + self.ctx, self.dataset.test_X, self._X_split) def _init_early_stop(self): @@ -382,17 +411,20 @@ def _init_early_stop(self): early_stopping_rounds = self.params.early_stopping_rounds if early_stopping_rounds != 0: eval_metric = self.params.eval_metric - early_stopping = EarlyStopping(rounds=early_stopping_rounds, metric_name=eval_metric, save_best=True) + early_stopping = 
EarlyStopping( + rounds=early_stopping_rounds, metric_name=eval_metric, save_best=True) callbacks.append(early_stopping) verbose_eval = self.params.verbose_eval if verbose_eval != 0: - evaluation_monitor = EvaluationMonitor(logger=self.log, period=verbose_eval) + evaluation_monitor = EvaluationMonitor( + logger=self.log, period=verbose_eval) callbacks.append(evaluation_monitor) callback_container = None if len(callbacks) != 0: - callback_container = CallbackContainer(callbacks=callbacks, feval=Evaluation.fevaluation) + callback_container = CallbackContainer( + callbacks=callbacks, feval=Evaluation.fevaluation) model = Booster(y_true=self.dataset.train_y, test_y_true=self.dataset.test_y, workspace=self.ctx.workspace, job_id=self.ctx.job_id, @@ -419,7 +451,8 @@ def _iteration_early_stop(self): stop = self.callback_container.after_iteration(model=self.model, pred=pred, eval_on_test=eval_on_test) - self.log.info(f"task {self.ctx.task_id}: after iteration {self._tree_id} iterations, stop: {stop}.") + self.log.info( + f"task {self.ctx.task_id}: after iteration {self._tree_id} iterations, stop: {stop}.") iteration_request = IterationRequest() iteration_request.epoch = self._tree_id - 1 diff --git a/python/ppc_model/secure_lgbm/vertical/booster.py b/python/ppc_model/secure_lgbm/vertical/booster.py index 468158cf..c191b6f2 100644 --- a/python/ppc_model/secure_lgbm/vertical/booster.py +++ b/python/ppc_model/secure_lgbm/vertical/booster.py @@ -51,7 +51,7 @@ def _init_weight(self, n): return np.zeros(n, dtype=float) @staticmethod - def _get_categorical_idx(feature_name, categorical_feature = []): + def _get_categorical_idx(feature_name, categorical_feature=[]): categorical_idx = [] if len(categorical_feature) > 0: for i in categorical_feature: @@ -86,7 +86,7 @@ def _compute_leaf_weight(lr, λ, gl, hl, gr, hr, reg_alpha): @staticmethod def _calulate_weight(lr, λ, g, h, reg_alpha): - + # weight = lr * - g / (h + λ) if (h + λ) != 0 and g > reg_alpha: weight = lr * - (g - reg_alpha) / (h + λ) @@ -117,15 +117,15 @@ def _get_leaf_mask(self, split_info, instance): for partner_index in range(0, len(self.ctx.participant_id_list)): if self.ctx.participant_id_list[partner_index] != self.ctx.components.config_data['AGENCY_ID']: self._send_byte_data( - self.ctx, f'{LGBMMessage.INSTANCE_MASK.value}_{self._tree_id}_{self._leaf_id}', + self.ctx, f'{LGBMMessage.INSTANCE_MASK.value}_{self._tree_id}_{self._leaf_id}', left_mask.astype('bool').tobytes(), partner_index) else: left_mask = np.frombuffer( self._receive_byte_data( - self.ctx, f'{LGBMMessage.INSTANCE_MASK.value}_{self._tree_id}_{self._leaf_id}', + self.ctx, f'{LGBMMessage.INSTANCE_MASK.value}_{self._tree_id}_{self._leaf_id}', split_info.agency_idx), dtype='bool') right_mask = ~left_mask - + return left_mask, right_mask def _send_enc_data(self, ctx, key_type, enc_data, partner_index, matrix_data=False): @@ -138,14 +138,16 @@ def _send_enc_data(self, ctx, key_type, enc_data, partner_index, matrix_data=Fal receiver=partner_id, task_id=ctx.task_id, key=key_type, - data=PheMessage.packing_2dim_data(ctx.codec, ctx.phe.public_key, enc_data) + data=PheMessage.packing_2dim_data( + ctx.codec, ctx.phe.public_key, enc_data) )) else: self._stub.push(PushRequest( receiver=partner_id, task_id=ctx.task_id, key=key_type, - data=PheMessage.packing_data(ctx.codec, ctx.phe.public_key, enc_data) + data=PheMessage.packing_data( + ctx.codec, ctx.phe.public_key, enc_data) )) log.info( @@ -164,9 +166,11 @@ def _receive_enc_data(self, ctx, key_type, partner_index, matrix_data=False): )) 
if matrix_data: - public_key, enc_data = PheMessage.unpacking_2dim_data(ctx.codec, byte_data) + public_key, enc_data = PheMessage.unpacking_2dim_data( + ctx.codec, byte_data) else: - public_key, enc_data = PheMessage.unpacking_data(ctx.codec, byte_data) + public_key, enc_data = PheMessage.unpacking_data( + ctx.codec, byte_data) log.info( f"task {ctx.task_id}: Received {key_type} from {partner_id} finished, " @@ -213,51 +217,60 @@ def _split_test_data(ctx, test_X, X_split): def save_model(self, file_path=None): log = self.ctx.components.logger() if file_path is not None: - self.ctx.feature_bin_file = os.path.join(file_path, self.ctx.FEATURE_BIN_FILE) - self.ctx.model_data_file = os.path.join(file_path, self.ctx.MODEL_DATA_FILE) - + self.ctx.feature_bin_file = os.path.join( + file_path, self.ctx.FEATURE_BIN_FILE) + self.ctx.model_data_file = os.path.join( + file_path, self.ctx.MODEL_DATA_FILE) + if self._X_split is not None and not os.path.exists(self.ctx.feature_bin_file): - X_split_dict = {k: v for k, v in zip(self.dataset.feature_name, self._X_split)} + X_split_dict = {k: v for k, v in zip( + self.dataset.feature_name, self._X_split)} with open(self.ctx.feature_bin_file, 'w') as f: json.dump(X_split_dict, f) - ResultFileHandling._upload_file(self.ctx.components.storage_client, + ResultFileHandling._upload_file(self.ctx.components.storage_client, self.ctx.feature_bin_file, self.ctx.remote_feature_bin_file) - log.info(f"task {self.ctx.task_id}: Saved x_split to {self.ctx.feature_bin_file} finished.") - + log.info( + f"task {self.ctx.task_id}: Saved x_split to {self.ctx.feature_bin_file} finished.") + if not os.path.exists(self.ctx.model_data_file): serial_trees = [self._serial_tree(tree) for tree in self._trees] with open(self.ctx.model_data_file, 'w') as f: json.dump(serial_trees, f) - ResultFileHandling._upload_file(self.ctx.components.storage_client, + ResultFileHandling._upload_file(self.ctx.components.storage_client, self.ctx.model_data_file, self.ctx.remote_model_data_file) - log.info(f"task {self.ctx.task_id}: Saved serial_trees to {self.ctx.model_data_file} finished.") - + log.info( + f"task {self.ctx.task_id}: Saved serial_trees to {self.ctx.model_data_file} finished.") + def load_model(self, file_path=None): log = self.ctx.components.logger() if file_path is not None: - self.ctx.feature_bin_file = os.path.join(file_path, self.ctx.FEATURE_BIN_FILE) - self.ctx.model_data_file = os.path.join(file_path, self.ctx.MODEL_DATA_FILE) + self.ctx.feature_bin_file = os.path.join( + file_path, self.ctx.FEATURE_BIN_FILE) + self.ctx.model_data_file = os.path.join( + file_path, self.ctx.MODEL_DATA_FILE) if self.ctx.algorithm_type == AlgorithmType.Predict.name: self.ctx.remote_feature_bin_file = os.path.join( self.ctx.lgbm_params.training_job_id, self.ctx.FEATURE_BIN_FILE) self.ctx.remote_model_data_file = os.path.join( self.ctx.lgbm_params.training_job_id, self.ctx.MODEL_DATA_FILE) - - ResultFileHandling._download_file(self.ctx.components.storage_client, + + ResultFileHandling._download_file(self.ctx.components.storage_client, self.ctx.feature_bin_file, self.ctx.remote_feature_bin_file) - ResultFileHandling._download_file(self.ctx.components.storage_client, + ResultFileHandling._download_file(self.ctx.components.storage_client, self.ctx.model_data_file, self.ctx.remote_model_data_file) with open(self.ctx.feature_bin_file, 'r') as f: X_split_dict = json.load(f) feature_name = list(X_split_dict.keys()) x_split = list(X_split_dict.values()) - log.info(f"task {self.ctx.task_id}: Load x_split from 
{self.ctx.feature_bin_file} finished.") + log.info( + f"task {self.ctx.task_id}: Load x_split from {self.ctx.feature_bin_file} finished.") assert len(feature_name) == len(self.dataset.feature_name) with open(self.ctx.model_data_file, 'r') as f: serial_trees = json.load(f) - log.info(f"task {self.ctx.task_id}: Load serial_trees from {self.ctx.model_data_file} finished.") + log.info( + f"task {self.ctx.task_id}: Load serial_trees from {self.ctx.model_data_file} finished.") trees = [self._deserial_tree(tree) for tree in serial_trees] self._X_split = x_split @@ -270,7 +283,8 @@ def _serial_tree(tree): best_split_info, left_tree, right_tree = tree[0] best_split_info_list = [] for field in best_split_info.DESCRIPTOR.fields: - best_split_info_list.append(getattr(best_split_info, field.name)) + best_split_info_list.append( + getattr(best_split_info, field.name)) left_tree = VerticalBooster._serial_tree(left_tree) right_tree = VerticalBooster._serial_tree(right_tree) best_split_info_list.extend([left_tree, right_tree]) @@ -312,7 +326,7 @@ def packing_gh(g_list: np.ndarray, h_list: np.ndarray, expand=1000, mod_length=3 Floats are converted to integers by multiplying by 1000 by default (keeping 3 decimal places). With at most 1,000,000 samples, the summed g/h value is bounded by 1000 * 10**6 = 10**9. Based on this bound, negative values are shifted into the positive range by adding 2**32 (4.29*10**9) before the modulo. - + 2. After the shift, a positive g/h value is at most 2**32-1, so summing 1,000,000 samples needs 10**6 of headroom. When packing g and h, g is multiplied by 10**20, reserving a total of 20 decimal digits for h. mod_n = 2 ** mod_length pos_int_glist = ((g_list * expand).astype('int64') + mod_n) % mod_n pos_int_hlist = ((h_list * expand).astype('int64') + mod_n) % mod_n - - gh_list = pos_int_glist.astype('object') * 10**pack_length + pos_int_hlist.astype('object') - + + gh_list = pos_int_glist.astype( + 'object') * 10**pack_length + pos_int_hlist.astype('object') + return gh_list @staticmethod diff --git a/python/ppc_model/secure_lgbm/vertical/passive_party.py b/python/ppc_model/secure_lgbm/vertical/passive_party.py index 2f09f6d1..321f3651 100644 --- a/python/ppc_model/secure_lgbm/vertical/passive_party.py +++ b/python/ppc_model/secure_lgbm/vertical/passive_party.py @@ -17,7 +17,8 @@ def __init__(self, ctx: SecureLGBMContext, dataset: SecureDataset) -> None: super().__init__(ctx, dataset) self.params = ctx.lgbm_params self.log = ctx.components.logger() - self.log.info(f'task {self.ctx.task_id}: print all params: {self.params.get_all_params()}') + self.log.info( + f'task {self.ctx.task_id}: print all params: {self.params.get_all_params()}') def fit( self, @@ -27,31 +28,37 @@ def fit( self.log.info( f'task {self.ctx.task_id}: Starting the lgbm on the passive party.') self._init_passive_data() - self._test_X_bin = self._split_test_data(self.ctx, self.dataset.test_X, self._X_split) + self._test_X_bin = self._split_test_data( + self.ctx, self.dataset.test_X, self._X_split) for _ in range(self.params.n_estimators): self._tree_id += 1 start_time = time.time() - self.log.info(f'task {self.ctx.task_id}: Starting n_estimators-{self._tree_id} in passive party.') + self.log.info( + f'task {self.ctx.task_id}: Starting n_estimators-{self._tree_id} in passive party.') # initialization instance, used_ghlist, public_key = self._receive_gh_instance_list() self.ctx.phe.public_key = public_key - self.log.info(f'task {self.ctx.task_id}: Sampling number: {len(instance)}.') + self.log.info( + f'task {self.ctx.task_id}: Sampling number: {len(instance)}.') # build tree = self._build_tree(instance, used_ghlist) self._trees.append(tree) # predict - self._predict_tree(tree, self._X_bin, LGBMMessage.PREDICT_LEAF_MASK.value) + 
self._predict_tree(tree, self._X_bin, + LGBMMessage.PREDICT_LEAF_MASK.value) self.log.info(f'task {self.ctx.task_id}: Ending n_estimators-{self._tree_id}, ' f'time_costs: {time.time() - start_time}s.') # predict on the validation set - self._predict_tree(tree, self._test_X_bin, LGBMMessage.TEST_LEAF_MASK.value) + self._predict_tree(tree, self._test_X_bin, + LGBMMessage.TEST_LEAF_MASK.value) if self._iteration_early_stop(): - self.log.info(f"task {self.ctx.task_id}: lgbm early stop after {self._tree_id} iterations.") + self.log.info( + f"task {self.ctx.task_id}: lgbm early stop after {self._tree_id} iterations.") break self._end_passive_data() @@ -67,11 +74,13 @@ def predict(self, dataset: SecureDataset = None) -> np.ndarray: self.params.my_categorical_idx = self._get_categorical_idx( dataset.feature_name, self.params.categorical_feature) - test_X_bin = self._split_test_data(self.ctx, dataset.test_X, self._X_split) + test_X_bin = self._split_test_data( + self.ctx, dataset.test_X, self._X_split) [self._predict_tree( tree, test_X_bin, LGBMMessage.VALID_LEAF_MASK.value) for tree in self._trees] - self.log.info(f'task {self.ctx.task_id}: Ending predict, time_costs: {time.time() - start_time}s.') + self.log.info( + f'task {self.ctx.task_id}: Ending predict, time_costs: {time.time() - start_time}s.') self._end_passive_data(is_train=False) @@ -116,10 +125,11 @@ def _build_tree(self, instance, ghlist, depth=0, weight=0): best_split_info = self._find_best_split(instance, ghlist) if best_split_info.best_gain > 0 and best_split_info.best_gain > self.params.min_split_gain: - left_mask, right_mask = self._get_leaf_mask(best_split_info, instance) + left_mask, right_mask = self._get_leaf_mask( + best_split_info, instance) if (abs(best_split_info.w_left) * sum(left_mask) / self.params.lr) < self.params.min_child_weight or \ - (abs(best_split_info.w_right) * sum(right_mask) / self.params.lr) < self.params.min_child_weight: + (abs(best_split_info.w_right) * sum(right_mask) / self.params.lr) < self.params.min_child_weight: return weight if sum(left_mask) < self.params.min_child_samples or sum(right_mask) < self.params.min_child_samples: return weight @@ -143,9 +153,11 @@ def _predict_tree(self, tree, X_bin, key_type): if self.ctx.participant_id_list[best_split_info.agency_idx] == \ self.ctx.components.config_data['AGENCY_ID']: if best_split_info.agency_feature in self.params.my_categorical_idx: - left_mask = X_bin[:, best_split_info.agency_feature] == best_split_info.value + left_mask = X_bin[:, + best_split_info.agency_feature] == best_split_info.value else: - left_mask = X_bin[:, best_split_info.agency_feature] <= best_split_info.value + left_mask = X_bin[:, + best_split_info.agency_feature] <= best_split_info.value self._send_byte_data( self.ctx, f'{key_type}_{best_split_info.tree_id}_{best_split_info.leaf_id}', diff --git a/python/ppc_model/task/task_manager.py b/python/ppc_model/task/task_manager.py index d6a308d0..d3ea8e33 100644 --- a/python/ppc_model/task/task_manager.py +++ b/python/ppc_model/task/task_manager.py @@ -51,16 +51,19 @@ def run_task(self, task_id: str, task_type: ModelTask, args=()): job_id = args[0]['job_id'] with self._rw_lock.gen_wlock(): if task_id in self._tasks: - self.logger.info(f"Task already exists, task_id: {task_id}, status: {self._tasks[task_id][0]}") + self.logger.info( + f"Task already exists, task_id: {task_id}, status: {self._tasks[task_id][0]}") return - self._tasks[task_id] = [TaskStatus.RUNNING.value, datetime.datetime.now(), 0, args[0]['job_id']] + self._tasks[task_id] = 
diff --git a/python/ppc_model/task/task_manager.py b/python/ppc_model/task/task_manager.py
index d6a308d0..d3ea8e33 100644
--- a/python/ppc_model/task/task_manager.py
+++ b/python/ppc_model/task/task_manager.py
@@ -51,16 +51,19 @@ def run_task(self, task_id: str, task_type: ModelTask, args=()):
         job_id = args[0]['job_id']
         with self._rw_lock.gen_wlock():
             if task_id in self._tasks:
-                self.logger.info(f"Task already exists, task_id: {task_id}, status: {self._tasks[task_id][0]}")
+                self.logger.info(
+                    f"Task already exists, task_id: {task_id}, status: {self._tasks[task_id][0]}")
                 return
-            self._tasks[task_id] = [TaskStatus.RUNNING.value, datetime.datetime.now(), 0, args[0]['job_id']]
+            self._tasks[task_id] = [TaskStatus.RUNNING.value,
+                                    datetime.datetime.now(), 0, args[0]['job_id']]
             if job_id in self._jobs:
                 self._jobs[job_id].add(task_id)
             else:
                 self._jobs[job_id] = {task_id}
         self.logger.info(LOG_START_FLAG_FORMATTER.format(job_id=job_id))
         self.logger.info(f"Run task, job_id: {job_id}, task_id: {task_id}")
-        self._async_executor.execute(task_id, self._handlers[task_type.value], self._on_task_finish, args)
+        self._async_executor.execute(
+            task_id, self._handlers[task_type.value], self._on_task_finish, args)

     def kill_task(self, job_id: str):
         """
@@ -152,7 +155,8 @@ def _cleanup_finished_tasks(self):
                 del self._jobs[job_id]
                 self._thread_event_manager.remove_event(task_id)
                 self._stub.cleanup_cache(task_id)
-                self.logger.info(f"Cleanup task cache, task_id: {task_id}, job_id: {job_id}")
+                self.logger.info(
+                    f"Cleanup task cache, task_id: {task_id}, job_id: {job_id}")

     def record_model_job_log(self, job_id):
         log_file = self._get_log_file_path()
diff --git a/python/ppc_model/task/test/task_manager_unittest.py b/python/ppc_model/task/test/task_manager_unittest.py
index 64f2a86f..060d0d6c 100644
--- a/python/ppc_model/task/test/task_manager_unittest.py
+++ b/python/ppc_model/task/test/task_manager_unittest.py
@@ -126,8 +126,10 @@ def test_kill_task(self):
             'job_id': '0x123456789',
             'key': 'TEST_MESSAGE',
         }
-        self._task_manager.run_task("my_long_task", ModelTask.XGB_PREDICTING, (args,))
-        self.assertEqual(self._task_manager.status("my_long_task")[0], 'RUNNING')
+        self._task_manager.run_task(
+            "my_long_task", ModelTask.XGB_PREDICTING, (args,))
+        self.assertEqual(self._task_manager.status(
+            "my_long_task")[0], 'RUNNING')
         self._task_manager.kill_task("0x123456789")
         time.sleep(1)
         self.assertEqual(self._task_manager.status(
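Reviewer note: the `run_task` hunks are again pure wrapping, but they expose the manager's bookkeeping: each task gets a `[status, start_time, progress, job_id]` record, a `job_id -> task_ids` index is maintained alongside it, and both are mutated only inside the write lock. A condensed sketch of that structure, with a plain mutex standing in for the reader-writer lock (`self._rw_lock.gen_wlock()`) the real `TaskManager` uses; `TaskRegistry` is an illustrative name for this note:

```python
import datetime
import threading


class TaskRegistry:
    """Sketch of run_task's bookkeeping under a stand-in mutex."""

    RUNNING = 'RUNNING'

    def __init__(self):
        self._lock = threading.Lock()
        self._tasks = {}  # task_id -> [status, start_time, progress, job_id]
        self._jobs = {}   # job_id  -> set of task_ids

    def register(self, task_id: str, job_id: str) -> bool:
        with self._lock:
            if task_id in self._tasks:
                return False  # duplicate submission: ignore, as run_task does
            self._tasks[task_id] = [self.RUNNING,
                                    datetime.datetime.now(), 0, job_id]
            # One-step equivalent of the if/else job-index update above.
            self._jobs.setdefault(job_id, set()).add(task_id)
            return True


registry = TaskRegistry()
assert registry.register('t1', 'job-0x123')
assert not registry.register('t1', 'job-0x123')  # second submit is a no-op
```

Keeping the job index in the same critical section is what lets `kill_task(job_id)` later fan out to every task of a job without racing new submissions.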
diff --git a/python/ppc_model_gateway/ppc_model_gateway_app.py b/python/ppc_model_gateway/ppc_model_gateway_app.py
index f64a96c4..ef55c4e6 100644
--- a/python/ppc_model_gateway/ppc_model_gateway_app.py
+++ b/python/ppc_model_gateway/ppc_model_gateway_app.py
@@ -1,18 +1,16 @@
+from ppc_model_gateway.endpoints.partner_to_node import PartnerToNodeService
+from ppc_model_gateway.endpoints.node_to_partner import NodeToPartnerService
+from ppc_model_gateway import config
+from ppc_common.ppc_utils import utils
+from ppc_common.ppc_protos.generated import ppc_model_pb2_grpc
+import grpc
+from threading import Thread
+from concurrent import futures
 import os

 # Note: this block can't be refactored by autopep8
 import sys
 sys.path.append("../")

-from concurrent import futures
-from threading import Thread
-
-import grpc
-
-from ppc_common.ppc_protos.generated import ppc_model_pb2_grpc
-from ppc_common.ppc_utils import utils
-from ppc_model_gateway import config
-from ppc_model_gateway.endpoints.node_to_partner import NodeToPartnerService
-from ppc_model_gateway.endpoints.partner_to_node import PartnerToNodeService

 log = config.get_logger()
@@ -22,7 +20,8 @@ def node_to_partner_serve():
     ppc_serve = grpc.server(futures.ThreadPoolExecutor(max_workers=max(1, os.cpu_count() - 1)),
                             options=config.grpc_options)
-    ppc_model_pb2_grpc.add_ModelServiceServicer_to_server(NodeToPartnerService(), ppc_serve)
+    ppc_model_pb2_grpc.add_ModelServiceServicer_to_server(
+        NodeToPartnerService(), ppc_serve)
     address = "[::]:{}".format(rpc_port)
     ppc_serve.add_insecure_port(address)
@@ -40,7 +39,8 @@ def partner_to_node_serve():
     if config.CONFIG_DATA['SSL_SWITCH'] == 0:
         ppc_serve = grpc.server(futures.ThreadPoolExecutor(max_workers=max(1, os.cpu_count() - 1)),
                                 options=config.grpc_options)
-        ppc_model_pb2_grpc.add_ModelServiceServicer_to_server(PartnerToNodeService(), ppc_serve)
+        ppc_model_pb2_grpc.add_ModelServiceServicer_to_server(
+            PartnerToNodeService(), ppc_serve)
         address = "[::]:{}".format(rpc_port)
         ppc_serve.add_insecure_port(address)
     else:
@@ -57,7 +57,8 @@ def partner_to_node_serve():
         ppc_serve = grpc.server(futures.ThreadPoolExecutor(max_workers=max(1, os.cpu_count() - 1)),
                                 options=config.grpc_options)
-        ppc_model_pb2_grpc.add_ModelServiceServicer_to_server(PartnerToNodeService(), ppc_serve)
+        ppc_model_pb2_grpc.add_ModelServiceServicer_to_server(
+            PartnerToNodeService(), ppc_serve)
         address = "[::]:{}".format(rpc_port)
         ppc_serve.add_secure_port(address, server_credentials)
diff --git a/python/ppc_model_gateway/test/server.py b/python/ppc_model_gateway/test/server.py
index 1df95ecf..03d2a93a 100644
--- a/python/ppc_model_gateway/test/server.py
+++ b/python/ppc_model_gateway/test/server.py
@@ -21,7 +21,8 @@ def serve():
     server = grpc.server(futures.ThreadPoolExecutor(
         max_workers=max(1, os.cpu_count() - 1)), options=config.grpc_options)
-    ppc_model_pb2_grpc.add_ModelServiceServicer_to_server(ModelService(), server)
+    ppc_model_pb2_grpc.add_ModelServiceServicer_to_server(
+        ModelService(), server)
     server.add_insecure_port(f'[::]:{port}')
     server.start()
     print(f'Started serving successfully at {port}.')
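Reviewer note: every entrypoint in these last two diffs repeats the same construction: a thread-pool gRPC server sized to `max(1, os.cpu_count() - 1)`, one generated-servicer registration, then an insecure or TLS port depending on `SSL_SWITCH`. A minimal sketch of that shared pattern using only stock grpc-python calls; `build_server` and `servicer_adder` are illustrative names (the adder stands in for the generated `add_ModelServiceServicer_to_server`), and the mutual-TLS details are assumptions since the patch elides the credential loading:

```python
import os
from concurrent import futures

import grpc


def build_server(servicer_adder, servicer, port: int,
                 ssl_switch: int = 0, key: bytes = None,
                 cert: bytes = None, ca: bytes = None):
    """Build and start a gRPC server following the gateway's pattern.

    servicer_adder is a generated add_*Servicer_to_server function;
    key/cert/ca are PEM-encoded bytes loaded by the caller.
    """
    # Same sizing rule as the gateway: leave one core for other threads.
    server = grpc.server(futures.ThreadPoolExecutor(
        max_workers=max(1, os.cpu_count() - 1)))
    servicer_adder(servicer, server)
    address = f'[::]:{port}'
    if ssl_switch == 0:
        server.add_insecure_port(address)
    else:
        # Mutual TLS, assumed to mirror the SSL_SWITCH branch above.
        credentials = grpc.ssl_server_credentials(
            [(key, cert)], root_certificates=ca, require_client_auth=True)
        server.add_secure_port(address, credentials)
    server.start()
    return server
```

The caller still has to block to keep the process alive, e.g. with `server.wait_for_termination()`, just as the gateway keeps its two servers running on separate threads.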