Skip to content

Commit

Permalink
ihc support negative
Browse files Browse the repository at this point in the history
  • Loading branch information
cyjseagull committed Oct 25, 2024
1 parent 9e62fb7 commit ceea827
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 22 deletions.
4 changes: 2 additions & 2 deletions cpp/cmake/CompilerSettings.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ elseif("${CMAKE_CXX_COMPILER_ID}" MATCHES "MSVC")
message(FATAL_ERROR "Unsupported Visual Studio, supported list: [2017, 2019]. Current MSVC_TOOLSET_VERSION: ${MSVC_TOOLSET_VERSION}")
endif()

add_definitions(-DUSE_STD_RANGES)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /std:c++17")
#add_definitions(-DUSE_STD_RANGES)
add_compile_options(/std:c++latest)
add_compile_definitions(NOMINMAX)
add_compile_options(-bigobj)
# MSVC only support static build
Expand Down
1 change: 1 addition & 0 deletions cpp/wedpr-protocol/grpc/client/GrpcClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
*/
#include "GrpcClient.h"
#include "Common.h"
#include <memory>

using namespace ppc::protocol;
using namespace ppc::proto;
Expand Down
45 changes: 35 additions & 10 deletions python/ppc_common/ppc_crypto/ihc_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,43 @@
import secrets


class NumberCodec:
def __init__(self, key_length):
self.key_length = key_length
self.max_mod = 1 << key_length
self.non_negative_range = int((1 << key_length) / 3)
self.negative_range = int((1 << int(key_length + 1)) / 3)

def encode(self, value):
if value > 0:
if value > self.non_negative_range:
raise Exception(f"The value {value} out of range")
return value
return value % self.max_mod

def decode(self, value):
if value > self.negative_range:
return value - self.max_mod
return value


@dataclass
class IhcCiphertext():
__slots__ = ['c_left', 'c_right']
__slots__ = ['c_left', 'c_right', 'number_codec']

def __init__(self, c_left: int, c_right: int) -> None:
def __init__(self, c_left: int, c_right: int, number_codec: NumberCodec) -> None:
self.c_left = c_left
self.c_right = c_right
self.number_codec = number_codec

def __add__(self, other):
cipher_left = self.c_left + other.c_left
cipher_right = self.c_right + other.c_right
return IhcCiphertext(cipher_left, cipher_right)
return IhcCiphertext(cipher_left, cipher_right, self.number_codec)

def __mul__(self, num: int):
return IhcCiphertext(num * self.c_left, num * self.c_right)
num_value = self.number_codec.encode(num)
return IhcCiphertext(num_value * self.c_left, num_value * self.c_right, self.number_codec)

def __eq__(self, other):
return self.c_left == other.c_left and self.c_right == other.c_right
Expand Down Expand Up @@ -51,7 +73,7 @@ def decode(cls, encoded_data: bytes):
encoded_data[8:8 + len_c_left], byteorder='big')
c_right = int.from_bytes(
encoded_data[8 + len_c_left:8 + len_c_left + len_c_right], byteorder='big')
return cls(c_left, c_right)
return cls(c_left, c_right, cls.number_codec)


class IhcCipher(PheCipher):
Expand All @@ -64,17 +86,18 @@ def __init__(self, key_length: int = 256, iter_round: int = 16) -> None:
self.key_length = key_length

self.max_mod = 1 << key_length
self.number_codec = NumberCodec(self.key_length)

def encrypt(self, number: int) -> IhcCiphertext:
random_u = secrets.randbits(self.key_length)
x_this = number
x_this = self.number_codec.encode(number)
# print(f"###### x_this: {x_this}, number: {number}")
x_last = random_u
for i in range(0, self.iter_round):
x_tmp = (self.private_key * x_this - x_last) % self.max_mod
x_last = x_this
x_this = x_tmp
# cipher = IhcCiphertext(x_this, x_last, self.max_mod)
cipher = IhcCiphertext(x_this, x_last)
cipher = IhcCiphertext(x_this, x_last, self.number_codec)
return cipher

def decrypt(self, cipher: IhcCiphertext) -> int:
Expand All @@ -84,7 +107,9 @@ def decrypt(self, cipher: IhcCiphertext) -> int:
x_tmp = (self.private_key * x_this - x_last) % self.max_mod
x_last = x_this
x_this = x_tmp
return x_this
result = self.number_codec.decode(x_this)
# print(f"###### x_this: {x_this}, result: {result}")
return result

def encrypt_batch(self, numbers) -> list:
return [self.encrypt(num) for num in numbers]
Expand Down
67 changes: 57 additions & 10 deletions python/ppc_common/ppc_crypto/test/phe_unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,80 @@
from ppc_common.ppc_crypto.paillier_cipher import PaillierCipher


class PaillierUtilsTest(unittest.TestCase):
class PaillierTest:
def __init__(self, ut):
self.paillier = PaillierCipher(key_length=1024)
self.ut = ut

def test_enc_and_dec_parallel(self):
paillier = PaillierCipher(key_length=1024)
inputs = np.random.randint(1, 10001, size=10)
def test_enc_and_dec_parallel(self, test_size, start, end):
inputs = np.random.randint(start, end, size=test_size)

# start_time = time.time()
# paillier.encrypt_batch(inputs)
# end_time = time.time()
# print("enc:", end_time - start_time, "seconds")

start_time = time.time()
ciphers = paillier.encrypt_batch_parallel(inputs)
ciphers = self.paillier.encrypt_batch_parallel(inputs)
end_time = time.time()
print("enc_p:", end_time - start_time, "seconds")

start_time = time.time()
outputs = paillier.decrypt_batch_parallel(ciphers)
outputs = self.paillier.decrypt_batch_parallel(ciphers)
end_time = time.time()
print("dec_p:", end_time - start_time, "seconds")

self.assertListEqual(list(inputs), list(outputs))
self.ut.assertListEqual(list(inputs), list(outputs))
self.test_ihc_mul_enc_and_dec(ciphers, inputs, 10)
# test add and enc dec
inputs2 = np.random.randint(start, end, size=test_size)
ciphers2 = self.paillier.encrypt_batch_parallel(inputs2)
self.test_ihc_add_enc_and_desc(ciphers, ciphers2, inputs, inputs2)

def test_ihc_mul_enc_and_dec(self, ciphers, inputs, mul_value):
start_time = time.time()
mul_ciphers = []
for cipher in ciphers:
cipher.__mul__(mul_value)
mul_ciphers.append(cipher * (mul_value))
# decrypt
outputs = self.paillier.decrypt_batch_parallel(mul_ciphers)
mul_result = []
for input in inputs:
mul_result.append(mul_value * input)
self.ut.assertListEqual(mul_result, list(outputs))
end_time = time.time()
print(
f"#### test_ihc_mul_enc_and_desc passed, time: {end_time - start_time} seconds")

def test_ihc_add_enc_and_desc(self, ciphers1, ciphers2, inputs1, inputs2):
start_time = time.time()
add_ciphers = []
i = 0
expected_result = []
for cipher in ciphers1:
add_ciphers.append(cipher + ciphers2[i])
expected_result.append(inputs1[i] + inputs2[i])
i += 1
outputs = self.paillier.decrypt_batch_parallel(add_ciphers)
self.ut.assertListEqual(expected_result, list(outputs))
end_time = time.time()
print(
f"#### test_ihc_add_enc_and_desc passed, time: {end_time - start_time} seconds, size: {len(inputs1)}")


class PaillierUtilsTest(unittest.TestCase):

def test_enc_and_dec_parallel(self):
paillier_test = PaillierTest(self)
paillier_test.test_enc_and_dec_parallel(10, -20, -1)
paillier_test.test_enc_and_dec_parallel(10000, -20000, 20000)
paillier_test.test_enc_and_dec_parallel(10000, 0, 20000)

def test_ihc_enc_and_dec_parallel(self):
ihc = IhcCipher(key_length=256)
try_size = 100000
inputs = np.random.randint(1, 10001, size=try_size)
inputs = np.random.randint(-10001, 10001, size=try_size)
expected = np.sum(inputs)

start_time = time.time()
Expand All @@ -50,7 +97,7 @@ def test_ihc_enc_and_dec_parallel(self):
cipher_left = (cipher_start.c_left + ciphers[i].c_left)
cipher_right = (cipher_start.c_right + ciphers[i].c_right)
# IhcCiphertext(cipher_left, cipher_right, cipher_start.max_mod)
IhcCiphertext(cipher_left, cipher_right)
IhcCiphertext(cipher_left, cipher_right, ihc.number_codec)
end_time = time.time()
print(f"size:{try_size}, add_p raw with class: {end_time - start_time} seconds, average times: {(end_time - start_time)/try_size * 1000 * 1000} us")

Expand Down Expand Up @@ -86,7 +133,7 @@ def test_ihc_enc_and_dec_parallel(self):
def test_ihc_code(self):
ihc = IhcCipher(key_length=256)
try_size = 100000
inputs = np.random.randint(1, 10001, size=try_size)
inputs = np.random.randint(-10001, 10001, size=try_size)
start_time = time.time()
ciphers = ihc.encrypt_batch_parallel(inputs)
end_time = time.time()
Expand Down

0 comments on commit ceea827

Please sign in to comment.