Skip to content

Commit

Permalink
Huocun/self main (#206)
Browse files Browse the repository at this point in the history
* add prefix to patch ref

* fix config issues

* add io_flush

* delete funny enforce
  • Loading branch information
huocun-ant authored Nov 15, 2024
1 parent 8a129e7 commit 6e67a19
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 59 deletions.
6 changes: 3 additions & 3 deletions bazel/psi.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ FAST_FLAGS = ["-O1"]

def _psi_copts():
return select({
"//bazel:psi_build_as_release": RELEASE_FLAGS,
"//bazel:psi_build_as_debug": DEBUG_FLAGS,
"//bazel:psi_build_as_fast": FAST_FLAGS,
"@psi//bazel:psi_build_as_release": RELEASE_FLAGS,
"@psi//bazel:psi_build_as_debug": DEBUG_FLAGS,
"@psi//bazel:psi_build_as_fast": FAST_FLAGS,
"//conditions:default": FAST_FLAGS,
}) + WARNING_FLAGS

Expand Down
72 changes: 36 additions & 36 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _com_github_facebook_zstd():
maybe(
http_archive,
name = "com_github_facebook_zstd",
build_file = "//bazel:zstd.BUILD",
build_file = "@psi//bazel:zstd.BUILD",
strip_prefix = "zstd-1.5.5",
sha256 = "98e9c3d949d1b924e28e01eccb7deed865eefebf25c2f21c702e5cd5b63b85e1",
type = ".tar.gz",
Expand All @@ -93,7 +93,7 @@ def _upb():
],
patch_args = ["-p1"],
patches = [
"//bazel/patches:upb.patch",
"@psi//bazel/patches:upb.patch",
],
)

Expand All @@ -106,14 +106,14 @@ def _com_github_emptoolkit_emp_tool():
type = "tar.gz",
patch_args = ["-p1"],
patches = [
"//bazel/patches:emp-tool.patch",
"//bazel/patches:emp-tool-cmake.patch",
"//bazel/patches:emp-tool-sse2neon.patch",
"@psi//bazel/patches:emp-tool.patch",
"@psi//bazel/patches:emp-tool-cmake.patch",
"@psi//bazel/patches:emp-tool-sse2neon.patch",
],
urls = [
"https://github.com/emp-toolkit/emp-tool/archive/refs/tags/0.2.5.tar.gz",
],
build_file = "//bazel:emp-tool.BUILD",
build_file = "@psi//bazel:emp-tool.BUILD",
)

def _com_github_intel_ipp():
Expand All @@ -122,10 +122,10 @@ def _com_github_intel_ipp():
name = "com_github_intel_ipp",
sha256 = "d70f42832337775edb022ca8ac1ac418f272e791ec147778ef7942aede414cdc",
strip_prefix = "cryptography-primitives-ippcp_2021.8",
build_file = "//bazel:ipp.BUILD",
build_file = "@psi//bazel:ipp.BUILD",
patch_args = ["-p1"],
patches = [
"//bazel/patches:ippcp.patch",
"@psi//bazel/patches:ippcp.patch",
],
urls = [
"https://github.com/intel/cryptography-primitives/archive/refs/tags/ippcp_2021.8.tar.gz",
Expand All @@ -140,11 +140,11 @@ def _com_github_microsoft_seal():
strip_prefix = "SEAL-4.1.1",
type = "tar.gz",
patch_args = ["-p1"],
patches = ["//bazel/patches:seal.patch"],
patches = ["@psi//bazel/patches:seal.patch"],
urls = [
"https://github.com/microsoft/SEAL/archive/refs/tags/v4.1.1.tar.gz",
],
build_file = "//bazel:seal.BUILD",
build_file = "@psi//bazel:seal.BUILD",
)

def _com_github_microsoft_apsi():
Expand All @@ -156,11 +156,11 @@ def _com_github_microsoft_apsi():
urls = [
"https://github.com/microsoft/APSI/archive/refs/tags/v0.11.0.tar.gz",
],
build_file = "//bazel:microsoft_apsi.BUILD",
build_file = "@psi//bazel:microsoft_apsi.BUILD",
patch_args = ["-p1"],
patches = [
"//bazel/patches:apsi.patch",
"//bazel/patches:apsi-fourq.patch",
"@psi//bazel/patches:apsi.patch",
"@psi//bazel/patches:apsi-fourq.patch",
],
patch_cmds = [
"rm -rf common/apsi/fourq",
Expand All @@ -177,7 +177,7 @@ def _com_github_microsoft_gsl():
urls = [
"https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.tar.gz",
],
build_file = "//bazel:microsoft_gsl.BUILD",
build_file = "@psi//bazel:microsoft_gsl.BUILD",
)

def _com_github_microsoft_kuku():
Expand All @@ -190,7 +190,7 @@ def _com_github_microsoft_kuku():
urls = [
"https://github.com/microsoft/Kuku/archive/refs/tags/v2.1.0.tar.gz",
],
build_file = "//bazel:microsoft_kuku.BUILD",
build_file = "@psi//bazel:microsoft_kuku.BUILD",
)

def _com_google_flatbuffers():
Expand All @@ -208,7 +208,7 @@ def _com_google_flatbuffers():
"rm grpc/src/compiler/BUILD.bazel",
"rm src/BUILD.bazel",
],
build_file = "//bazel:flatbuffers.BUILD",
build_file = "@psi//bazel:flatbuffers.BUILD",
)

def _org_apache_arrow():
Expand All @@ -220,7 +220,7 @@ def _org_apache_arrow():
],
sha256 = "2852b21f93ee84185a9d838809c9a9c41bf6deca741bed1744e0fdba6cc19e3f",
strip_prefix = "arrow-apache-arrow-10.0.0",
build_file = "//bazel:arrow.BUILD",
build_file = "@psi//bazel:arrow.BUILD",
)

def _com_github_grpc_grpc():
Expand All @@ -231,7 +231,7 @@ def _com_github_grpc_grpc():
strip_prefix = "grpc-1.51.0",
type = "tar.gz",
patch_args = ["-p1"],
patches = ["//bazel/patches:grpc.patch"],
patches = ["@psi//bazel/patches:grpc.patch"],
urls = [
"https://github.com/grpc/grpc/archive/refs/tags/v1.51.0.tar.gz",
],
Expand All @@ -246,7 +246,7 @@ def _com_github_nelhage_rules_boost():
sha256 = "a7c42df432fae9db0587ff778d84f9dc46519d67a984eff8c79ae35e45f277c1",
strip_prefix = "rules_boost-%s" % RULES_BOOST_COMMIT,
patch_args = ["-p1"],
patches = ["//bazel/patches:boost.patch"],
patches = ["@psi//bazel/patches:boost.patch"],
urls = [
"https://github.com/nelhage/rules_boost/archive/%s.tar.gz" % RULES_BOOST_COMMIT,
],
Expand All @@ -261,7 +261,7 @@ def _com_github_tencent_rapidjson():
],
sha256 = "bf7ced29704a1e696fbccf2a2b4ea068e7774fa37f6d7dd4039d0787f8bed98e",
strip_prefix = "rapidjson-1.1.0",
build_file = "//bazel:rapidjson.BUILD",
build_file = "@psi//bazel:rapidjson.BUILD",
)

def _com_github_xtensor_xsimd():
Expand All @@ -274,14 +274,14 @@ def _com_github_xtensor_xsimd():
sha256 = "d52551360d37709675237d2a0418e28f70995b5b7cdad7c674626bcfbbf48328",
type = "tar.gz",
strip_prefix = "xsimd-8.1.0",
build_file = "//bazel:xsimd.BUILD",
build_file = "@psi//bazel:xsimd.BUILD",
)

def _brotli():
maybe(
http_archive,
name = "brotli",
build_file = "//bazel:brotli.BUILD",
build_file = "@psi//bazel:brotli.BUILD",
sha256 = "e720a6ca29428b803f4ad165371771f5398faba397edf6778837a18599ea13ff",
strip_prefix = "brotli-1.1.0",
urls = [
Expand All @@ -299,14 +299,14 @@ def _com_github_lz4_lz4():
sha256 = "030644df4611007ff7dc962d981f390361e6c97a34e5cbc393ddfbe019ffe2c1",
type = "tar.gz",
strip_prefix = "lz4-1.9.3",
build_file = "//bazel:lz4.BUILD",
build_file = "@psi//bazel:lz4.BUILD",
)

def _org_apache_thrift():
maybe(
http_archive,
name = "org_apache_thrift",
build_file = "//bazel:thrift.BUILD",
build_file = "@psi//bazel:thrift.BUILD",
sha256 = "31e46de96a7b36b8b8a457cecd2ee8266f81a83f8e238a9d324d8c6f42a717bc",
strip_prefix = "thrift-0.21.0",
urls = [
Expand All @@ -320,7 +320,7 @@ def _com_google_double_conversion():
name = "com_google_double_conversion",
sha256 = "04ec44461850abbf33824da84978043b22554896b552c5fd11a9c5ae4b4d296e",
strip_prefix = "double-conversion-3.3.0",
build_file = "//bazel:double-conversion.BUILD",
build_file = "@psi//bazel:double-conversion.BUILD",
urls = [
"https://github.com/google/double-conversion/archive/refs/tags/v3.3.0.tar.gz",
],
Expand All @@ -330,7 +330,7 @@ def _bzip2():
maybe(
http_archive,
name = "bzip2",
build_file = "//bazel:bzip2.BUILD",
build_file = "@psi//bazel:bzip2.BUILD",
sha256 = "ab5a03176ee106d3f0fa90e381da478ddae405918153cca248e682cd0c4a2269",
strip_prefix = "bzip2-1.0.8",
urls = [
Expand All @@ -347,7 +347,7 @@ def _com_github_google_snappy():
],
sha256 = "75c1fbb3d618dd3a0483bff0e26d0a92b495bbe5059c8b4f1c962b478b6e06e7",
strip_prefix = "snappy-1.1.9",
build_file = "//bazel:snappy.BUILD",
build_file = "@psi//bazel:snappy.BUILD",
)

def _com_github_google_perfetto():
Expand All @@ -360,8 +360,8 @@ def _com_github_google_perfetto():
sha256 = "4c8fe8a609fcc77ca653ec85f387ab6c3a048fcd8df9275a1aa8087984b89db8",
strip_prefix = "perfetto-41.0",
patch_args = ["-p1"],
patches = ["//bazel/patches:perfetto.patch"],
build_file = "//bazel:perfetto.BUILD",
patches = ["@psi//bazel/patches:perfetto.patch"],
build_file = "@psi//bazel:perfetto.BUILD",
)

def _com_github_floodyberry_curve25519_donna():
Expand All @@ -371,7 +371,7 @@ def _com_github_floodyberry_curve25519_donna():
strip_prefix = "curve25519-donna-2fe66b65ea1acb788024f40a3373b8b3e6f4bbb2",
sha256 = "ba57d538c241ad30ff85f49102ab2c8dd996148456ed238a8c319f263b7b149a",
type = "tar.gz",
build_file = "//bazel:curve25519-donna.BUILD",
build_file = "@psi//bazel:curve25519-donna.BUILD",
urls = [
"https://github.com/floodyberry/curve25519-donna/archive/2fe66b65ea1acb788024f40a3373b8b3e6f4bbb2.tar.gz",
],
Expand All @@ -386,7 +386,7 @@ def _com_github_ridiculousfish_libdivide():
],
sha256 = "01ffdf90bc475e42170741d381eb9cfb631d9d7ddac7337368bcd80df8c98356",
strip_prefix = "libdivide-5.0",
build_file = "//bazel:libdivide.BUILD",
build_file = "@psi//bazel:libdivide.BUILD",
)

def _com_github_sparsehash_sparsehash():
Expand All @@ -398,14 +398,14 @@ def _com_github_sparsehash_sparsehash():
],
sha256 = "8cd1a95827dfd8270927894eb77f62b4087735cbede953884647f16c521c7e58",
strip_prefix = "sparsehash-sparsehash-2.0.4",
build_file = "//bazel:sparsehash.BUILD",
build_file = "@psi//bazel:sparsehash.BUILD",
)

def _com_github_zeromq_cppzmq():
maybe(
http_archive,
name = "com_github_zeromq_cppzmq",
build_file = "//bazel:cppzmq.BUILD",
build_file = "@psi//bazel:cppzmq.BUILD",
strip_prefix = "cppzmq-4.10.0",
sha256 = "c81c81bba8a7644c84932225f018b5088743a22999c6d82a2b5f5cd1e6942b74",
type = ".tar.gz",
Expand All @@ -418,7 +418,7 @@ def _com_github_zeromq_libzmq():
maybe(
http_archive,
name = "com_github_zeromq_libzmq",
build_file = "//bazel:libzmq.BUILD",
build_file = "@psi//bazel:libzmq.BUILD",
strip_prefix = "libzmq-4.3.5",
sha256 = "6c972d1e6a91a0ecd79c3236f04cf0126f2f4dfbbad407d72b4606a7ba93f9c6",
type = ".tar.gz",
Expand All @@ -431,7 +431,7 @@ def _com_github_log4cplus_log4cplus():
maybe(
http_archive,
name = "com_github_log4cplus_log4cplus",
build_file = "//bazel:log4cplus.BUILD",
build_file = "@psi//bazel:log4cplus.BUILD",
strip_prefix = "log4cplus-2.1.1",
sha256 = "42dc435928917fd2f847046c4a0c6086b2af23664d198c7fc1b982c0bfe600c1",
type = ".tar.gz",
Expand All @@ -444,7 +444,7 @@ def _com_github_open_source_parsers_jsoncpp():
maybe(
http_archive,
name = "com_github_open_source_parsers_jsoncpp",
build_file = "//bazel:jsoncpp.BUILD",
build_file = "@psi//bazel:jsoncpp.BUILD",
strip_prefix = "jsoncpp-1.9.6",
sha256 = "f93b6dd7ce796b13d02c108bc9f79812245a82e577581c4c9aabe57075c90ea2",
type = ".tar.gz",
Expand Down
14 changes: 0 additions & 14 deletions psi/interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,20 +219,6 @@ void AbstractPsiParty::CheckSelfConfig() {
YACL_ENFORCE_EQ(static_cast<int>(keys_set.size()), config_.keys().size(),
"Duplicated key is not allowed.");

if (!config_.protocol_config().broadcast_result() &&
config_.advanced_join_type() !=
v2::PsiConfig::ADVANCED_JOIN_TYPE_UNSPECIFIED) {
SPDLOG_WARN(
"broadcast_result turns off while advanced join is enabled. "
"broadcast_result is modified to true since intersection has to be "
"sent to both parties.");

YACL_ENFORCE(!config_.output_config().path().empty(),
"You have to provide path of output.");

config_.mutable_protocol_config()->set_broadcast_result(true);
}

if (!config_.skip_duplicates_check() &&
config_.advanced_join_type() !=
v2::PsiConfig::ADVANCED_JOIN_TYPE_UNSPECIFIED) {
Expand Down
18 changes: 12 additions & 6 deletions psi/utils/join_processor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,17 @@ JoinProcessor::JoinProcessor(const v2::UbPsiConfig& ub_psi_config,
v2::UbPsiConfig::MODE_FULL,
};

if (gen_output_mode.find(ub_psi_config.mode()) != gen_output_mode.end()) {
YACL_ENFORCE(
ub_psi_config.output_config().type() == v2::IoType::IO_TYPE_FILE_CSV,
"unsupport output format {}",
v2::IoType_Name(ub_psi_config.input_config().type()));
output_path_ = ub_psi_config.output_config().path();
// Output settings are only validated and recorded for a party that is
// configured to receive the result: the server when server_get_result() is
// set, or the client when client_get_result() is set.
const bool gen_output =
    (role_ == v2::ROLE_SERVER && ub_psi_config.server_get_result()) ||
    (role_ == v2::ROLE_CLIENT && ub_psi_config.client_get_result());
if (gen_output) {
  if (gen_output_mode.find(ub_psi_config.mode()) != gen_output_mode.end()) {
    // The check is on the *output* type, so the failure message must name the
    // output type as well (it previously printed the input type).
    YACL_ENFORCE(
        ub_psi_config.output_config().type() == v2::IoType::IO_TYPE_FILE_CSV,
        "unsupport output format {}",
        v2::IoType_Name(ub_psi_config.output_config().type()));
    output_path_ = ub_psi_config.output_config().path();
  }
}

if (!std::filesystem::exists(ub_psi_config.cache_path())) {
Expand Down Expand Up @@ -265,6 +270,7 @@ std::shared_ptr<KeyInfo> JoinProcessor::GetUniqueKeysInfo() {
KeyInfo::StatInfo JoinProcessor::DealResultIndex(IndexReader& index) {
ResultDumper dumper(sorted_intersect_path_, sorted_except_path_);
auto stat = GetUniqueKeysInfo()->ApplyPeerDupCnt(index, dumper);
dumper.Flush();
if (is_input_key_unique_ && align_output_) {
if (!sorted_intersect_path_.empty()) {
Table::MakeFromCsv(sorted_intersect_path_)
Expand Down
9 changes: 9 additions & 0 deletions psi/utils/table_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,15 @@ void ResultDumper::Dump(const std::string& line, int64_t duplicate_cnt,
}
}

// Flushes the intersect file and then the except file. A null file handle is
// skipped, so calling this when either file is absent is harmless.
void ResultDumper::Flush() {
  const auto flush_if_present = [](const auto& file) {
    if (file) {
      file->flush();
    }
  };

  flush_if_present(intersect_file_);
  flush_if_present(except_file_);
}

// Returns the column list reported by the underlying table (table_->Columns()).
std::vector<std::string> KeyInfo::SourceFileColumns() const {
  std::vector<std::string> columns = table_->Columns();
  return columns;
}
Expand Down
2 changes: 2 additions & 0 deletions psi/utils/table_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ struct ResultDumper {
int64_t except_cnt() const { return except_cnt_; }
int64_t intersect_cnt() const { return intersect_cnt_; }

void Flush();

private:
void Dump(const std::string& line, int64_t duplicate_cnt,
std::shared_ptr<std::ofstream>& file, int64_t* total_dump_cnt);
Expand Down

0 comments on commit 6e67a19

Please sign in to comment.