From c4cc6cd20a1e3833dbdf0683e35a748a218a2072 Mon Sep 17 00:00:00 2001 From: cyjseagull Date: Mon, 28 Oct 2024 10:31:09 +0800 Subject: [PATCH] ihc support negative (#66) --- cpp/cmake/CompilerSettings.cmake | 35 +++++++--- cpp/tools/install_depends.sh | 19 +++++- cpp/wedpr-protocol/grpc/client/GrpcClient.cpp | 1 + python/ppc_common/ppc_crypto/ihc_cipher.py | 45 ++++++++++--- .../ppc_crypto/test/phe_unittest.py | 67 ++++++++++++++++--- 5 files changed, 137 insertions(+), 30 deletions(-) diff --git a/cpp/cmake/CompilerSettings.cmake b/cpp/cmake/CompilerSettings.cmake index 9596d8ff..61396933 100644 --- a/cpp/cmake/CompilerSettings.cmake +++ b/cpp/cmake/CompilerSettings.cmake @@ -1,3 +1,12 @@ +set(CMAKE_CXX_STANDARD 20) +set(Boost_NO_WARN_NEW_VERSIONS ON) +message(STATUS "COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") + +# export windows dll symbol +if(WIN32) + message(STATUS "Compile on Windows") + set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS "ON") +endif() if (("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") OR ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")) find_program(CCACHE_PROGRAM ccache) if(CCACHE_PROGRAM) @@ -123,6 +132,7 @@ if (("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") OR ("${CMAKE_CXX_COMPILER_ID}" MA endif() endif () elseif("${CMAKE_CXX_COMPILER_ID}" MATCHES "MSVC") + add_compile_definitions(NOMINMAX) # Only support visual studio 2017 and visual studio 2019 set(MSVC_MIN_VERSION "1914") # VS2017 15.7, for full-ish C++17 support @@ -137,16 +147,23 @@ elseif("${CMAKE_CXX_COMPILER_ID}" MATCHES "MSVC") message(FATAL_ERROR "Unsupported Visual Studio, supported list: [2017, 2019]. Current MSVC_TOOLSET_VERSION: ${MSVC_TOOLSET_VERSION}") endif() - add_definitions(-DUSE_STD_RANGES) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /std:c++17") - add_compile_definitions(NOMINMAX) + add_compile_options(/std:c++latest) add_compile_options(-bigobj) - # MSVC only support static build - set(CMAKE_CXX_FLAGS_DEBUG "/MTd /DEBUG") - set(CMAKE_CXX_FLAGS_MINSIZEREL "/MT /Os") - set(CMAKE_CXX_FLAGS_RELEASE "/MT") - set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "/MT /DEBUG") - link_libraries(ws2_32 Crypt32 userenv) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc") + if(BUILD_SHARED_LIBS) + if(CMAKE_BUILD_TYPE MATCHES "Debug") + add_compile_options(/MDd) + else() + add_compile_options(/MD) + endif() + else () + if(CMAKE_BUILD_TYPE MATCHES "Debug") + add_compile_options(/MTd) + else() + add_compile_options(/MT) + endif () + endif () + link_libraries(ws2_32 Crypt32 userenv) else () message(WARNING "Your compiler is not tested, if you run into any issues, we'd welcome any patches.") endif () diff --git a/cpp/tools/install_depends.sh b/cpp/tools/install_depends.sh index 108a6aae..a4c467aa 100644 --- a/cpp/tools/install_depends.sh +++ b/cpp/tools/install_depends.sh @@ -26,6 +26,23 @@ install_gsasl_depend() return fi LOG_INFO "download and install gsasl..." + wget --no-check-certificate https://ftp.gnu.org/gnu/gsasl/gsasl-1.8.0.tar.gz && tar -xvf gsasl-1.8.0.tar.gz + + # centos + if [[ "${os_type}" == "centos" ]];then + cd gsasl-1.8.0 && ./configure --with-pic && make -j4 && make install + fi + # macos + if [[ "${os_type}" == "macos" ]];then + cd gsasl-1.8.0 && ./configure --with-pic && make -j4 && make install + fi + # ubuntu + if [[ "${os_type}" == "ubuntu" ]];then + cd gsasl-1.8.0 && ./configure --with-pic && make -j4 && make install + fi + LOG_INFO "download and install gsasl success..." + + LOG_INFO "download and install libgsasl..." wget --no-check-certificate https://ftp.gnu.org/gnu/gsasl/libgsasl-1.8.0.tar.gz && tar -xvf libgsasl-1.8.0.tar.gz # centos @@ -40,7 +57,7 @@ install_gsasl_depend() if [[ "${os_type}" == "ubuntu" ]];then cd libgsasl-1.8.0 && ./configure --with-pic && make -j4 && make install fi - LOG_INFO "download and install gsasl success..." + LOG_INFO "download and install libgsasl success..." } install_nasm_depend() diff --git a/cpp/wedpr-protocol/grpc/client/GrpcClient.cpp b/cpp/wedpr-protocol/grpc/client/GrpcClient.cpp index c0b3d0a5..9eb3429c 100644 --- a/cpp/wedpr-protocol/grpc/client/GrpcClient.cpp +++ b/cpp/wedpr-protocol/grpc/client/GrpcClient.cpp @@ -19,6 +19,7 @@ */ #include "GrpcClient.h" #include "Common.h" +#include using namespace ppc::protocol; using namespace ppc::proto; diff --git a/python/ppc_common/ppc_crypto/ihc_cipher.py b/python/ppc_common/ppc_crypto/ihc_cipher.py index a0f7c132..4adfeb1c 100644 --- a/python/ppc_common/ppc_crypto/ihc_cipher.py +++ b/python/ppc_common/ppc_crypto/ihc_cipher.py @@ -7,21 +7,43 @@ import secrets +class NumberCodec: + def __init__(self, key_length): + self.key_length = key_length + self.max_mod = 1 << key_length + self.non_negative_range = int((1 << key_length) / 3) + self.negative_range = int((1 << int(key_length + 1)) / 3) + + def encode(self, value): + if value > 0: + if value > self.non_negative_range: + raise Exception(f"The value {value} out of range") + return value + return value % self.max_mod + + def decode(self, value): + if value > self.negative_range: + return value - self.max_mod + return value + + @dataclass class IhcCiphertext(): - __slots__ = ['c_left', 'c_right'] + __slots__ = ['c_left', 'c_right', 'number_codec'] - def __init__(self, c_left: int, c_right: int) -> None: + def __init__(self, c_left: int, c_right: int, number_codec: NumberCodec) -> None: self.c_left = c_left self.c_right = c_right + self.number_codec = number_codec def __add__(self, other): cipher_left = self.c_left + other.c_left cipher_right = self.c_right + other.c_right - return IhcCiphertext(cipher_left, cipher_right) - + return IhcCiphertext(cipher_left, cipher_right, self.number_codec) + def __mul__(self, num: int): - return IhcCiphertext(num * self.c_left, num * self.c_right) + num_value = self.number_codec.encode(num) + return IhcCiphertext(num_value * self.c_left, num_value * self.c_right, self.number_codec) def __eq__(self, other): return self.c_left == other.c_left and self.c_right == other.c_right @@ -51,7 +73,7 @@ def decode(cls, encoded_data: bytes): encoded_data[8:8 + len_c_left], byteorder='big') c_right = int.from_bytes( encoded_data[8 + len_c_left:8 + len_c_left + len_c_right], byteorder='big') - return cls(c_left, c_right) + return cls(c_left, c_right, cls.number_codec) class IhcCipher(PheCipher): @@ -64,17 +86,18 @@ def __init__(self, key_length: int = 256, iter_round: int = 16) -> None: self.key_length = key_length self.max_mod = 1 << key_length + self.number_codec = NumberCodec(self.key_length) def encrypt(self, number: int) -> IhcCiphertext: random_u = secrets.randbits(self.key_length) - x_this = number + x_this = self.number_codec.encode(number) + # print(f"###### x_this: {x_this}, number: {number}") x_last = random_u for i in range(0, self.iter_round): x_tmp = (self.private_key * x_this - x_last) % self.max_mod x_last = x_this x_this = x_tmp - # cipher = IhcCiphertext(x_this, x_last, self.max_mod) - cipher = IhcCiphertext(x_this, x_last) + cipher = IhcCiphertext(x_this, x_last, self.number_codec) return cipher def decrypt(self, cipher: IhcCiphertext) -> int: @@ -84,7 +107,9 @@ def decrypt(self, cipher: IhcCiphertext) -> int: x_tmp = (self.private_key * x_this - x_last) % self.max_mod x_last = x_this x_this = x_tmp - return x_this + result = self.number_codec.decode(x_this) + # print(f"###### x_this: {x_this}, result: {result}") + return result def encrypt_batch(self, numbers) -> list: return [self.encrypt(num) for num in numbers] diff --git a/python/ppc_common/ppc_crypto/test/phe_unittest.py b/python/ppc_common/ppc_crypto/test/phe_unittest.py index a8ef12a7..b7381a1e 100644 --- a/python/ppc_common/ppc_crypto/test/phe_unittest.py +++ b/python/ppc_common/ppc_crypto/test/phe_unittest.py @@ -8,11 +8,13 @@ from ppc_common.ppc_crypto.paillier_cipher import PaillierCipher -class PaillierUtilsTest(unittest.TestCase): +class PaillierTest: + def __init__(self, ut): + self.paillier = PaillierCipher(key_length=1024) + self.ut = ut - def test_enc_and_dec_parallel(self): - paillier = PaillierCipher(key_length=1024) - inputs = np.random.randint(1, 10001, size=10) + def test_enc_and_dec_parallel(self, test_size, start, end): + inputs = np.random.randint(start, end, size=test_size) # start_time = time.time() # paillier.encrypt_batch(inputs) @@ -20,21 +22,66 @@ def test_enc_and_dec_parallel(self): # print("enc:", end_time - start_time, "seconds") start_time = time.time() - ciphers = paillier.encrypt_batch_parallel(inputs) + ciphers = self.paillier.encrypt_batch_parallel(inputs) end_time = time.time() print("enc_p:", end_time - start_time, "seconds") start_time = time.time() - outputs = paillier.decrypt_batch_parallel(ciphers) + outputs = self.paillier.decrypt_batch_parallel(ciphers) end_time = time.time() print("dec_p:", end_time - start_time, "seconds") - self.assertListEqual(list(inputs), list(outputs)) + self.ut.assertListEqual(list(inputs), list(outputs)) + self.test_ihc_mul_enc_and_dec(ciphers, inputs, 10) + # test add and enc dec + inputs2 = np.random.randint(start, end, size=test_size) + ciphers2 = self.paillier.encrypt_batch_parallel(inputs2) + self.test_ihc_add_enc_and_desc(ciphers, ciphers2, inputs, inputs2) + + def test_ihc_mul_enc_and_dec(self, ciphers, inputs, mul_value): + start_time = time.time() + mul_ciphers = [] + for cipher in ciphers: + cipher.__mul__(mul_value) + mul_ciphers.append(cipher * (mul_value)) + # decrypt + outputs = self.paillier.decrypt_batch_parallel(mul_ciphers) + mul_result = [] + for input in inputs: + mul_result.append(mul_value * input) + self.ut.assertListEqual(mul_result, list(outputs)) + end_time = time.time() + print( + f"#### test_ihc_mul_enc_and_desc passed, time: {end_time - start_time} seconds") + + def test_ihc_add_enc_and_desc(self, ciphers1, ciphers2, inputs1, inputs2): + start_time = time.time() + add_ciphers = [] + i = 0 + expected_result = [] + for cipher in ciphers1: + add_ciphers.append(cipher + ciphers2[i]) + expected_result.append(inputs1[i] + inputs2[i]) + i += 1 + outputs = self.paillier.decrypt_batch_parallel(add_ciphers) + self.ut.assertListEqual(expected_result, list(outputs)) + end_time = time.time() + print( + f"#### test_ihc_add_enc_and_desc passed, time: {end_time - start_time} seconds, size: {len(inputs1)}") + + +class PaillierUtilsTest(unittest.TestCase): + + def test_enc_and_dec_parallel(self): + paillier_test = PaillierTest(self) + paillier_test.test_enc_and_dec_parallel(10, -20, -1) + paillier_test.test_enc_and_dec_parallel(10000, -20000, 20000) + paillier_test.test_enc_and_dec_parallel(10000, 0, 20000) def test_ihc_enc_and_dec_parallel(self): ihc = IhcCipher(key_length=256) try_size = 100000 - inputs = np.random.randint(1, 10001, size=try_size) + inputs = np.random.randint(-10001, 10001, size=try_size) expected = np.sum(inputs) start_time = time.time() @@ -50,7 +97,7 @@ def test_ihc_enc_and_dec_parallel(self): cipher_left = (cipher_start.c_left + ciphers[i].c_left) cipher_right = (cipher_start.c_right + ciphers[i].c_right) # IhcCiphertext(cipher_left, cipher_right, cipher_start.max_mod) - IhcCiphertext(cipher_left, cipher_right) + IhcCiphertext(cipher_left, cipher_right, ihc.number_codec) end_time = time.time() print(f"size:{try_size}, add_p raw with class: {end_time - start_time} seconds, average times: {(end_time - start_time)/try_size * 1000 * 1000} us") @@ -86,7 +133,7 @@ def test_ihc_enc_and_dec_parallel(self): def test_ihc_code(self): ihc = IhcCipher(key_length=256) try_size = 100000 - inputs = np.random.randint(1, 10001, size=try_size) + inputs = np.random.randint(-10001, 10001, size=try_size) start_time = time.time() ciphers = ihc.encrypt_batch_parallel(inputs) end_time = time.time()