-
Notifications
You must be signed in to change notification settings - Fork 1
/
regev.py
210 lines (174 loc) · 8.23 KB
/
regev.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import numpy as np
import math
import secrets
from utils import mround, uniform, mod_between, gaussian, SeededRNG
from serialize import serialize_ndarray, deserialize_ndarray
class RegevPublicParameters:
    """Public parameters of the batched/packed Regev scheme.

    Attributes:
        n: secret dimension; the key's ``A`` has ``n`` columns and its secret
            matrix has ``n`` rows (see ``RegevKey.gen``).
        m: number of rows of ``A``; encryption combines ``m`` rows of ``A``.
        cipher_mod: ciphertext modulus ``q``; ciphertext arithmetic in this
            module is reduced modulo this value.
        bs: batch size, i.e. the number of message slots per ciphertext.
        bound: noise parameter handed to ``gaussian`` during encryption and
            used as the safety margin when packing.
    """

    def __init__(self, n: int, m: int, cipher_mod: int, bs: int, bound: int):
        self.n = n
        self.m = m
        self.cipher_mod = cipher_mod
        self.bs = bs
        self.bound = bound

    @classmethod
    def for_pack(cls, sec_param: int, num_add: int, num_mes: int, mes_mod: int,
                 *, n: int = 792, m: int = 32, bound: int = 20):
        """Build parameters for ciphertexts that will later be packed.

        Args:
            sec_param: security parameter. NOTE(review): currently ignored;
                ``n``, ``m`` and ``bound`` are fixed defaults rather than
                being derived from it.
            num_add: presumably the number of homomorphic additions the
                ciphertexts must tolerate (it scales the modulus).
            num_mes: number of message slots per ciphertext (becomes ``bs``).
            mes_mod: plaintext modulus.
            n, m, bound: keyword-only overrides for the lattice dimensions and
                noise bound; the defaults are the historical fixed values.

        Returns:
            An instance of ``cls`` with
            ``cipher_mod = 8 * mes_mod * num_mes * bound * (num_add + 1)``.
        """
        # The modulus grows linearly with the noise bound and with num_add,
        # presumably to leave headroom for noise accumulated over additions.
        cipher_mod = 8 * mes_mod * num_mes * bound * (num_add + 1)
        return cls(n, m, cipher_mod, num_mes, bound)
class RegevKey:
    """Secret key for batched/packed Regev encryption.

    Everything is derived from ``seed``: ``gen`` feeds it to ``SeededRNG`` to
    sample the secret ``sec`` and the matrix ``A``, and serialization stores
    only the seed. The seed is therefore secret key material.
    """

    def __init__(self, seed: bytes, A: np.ndarray, sec: np.ndarray):
        self.seed = seed
        # Used as ``r @ A`` with r of shape (1, pp.m) during encryption, so A
        # presumably has shape (pp.m, pp.n) -- TODO confirm uniform's 3rd arg.
        self.A = A
        # Multiplied as ``c1 @ sec``; presumably shape (pp.n, pp.bs).
        self.sec = sec

    @classmethod
    def gen(cls, pp: "RegevPublicParameters", seed: bytes = None):
        """Generate a key for batched/packed Regev encryption.

        Args:
            pp: public parameters fixing the shapes and the modulus.
            seed: optional seed; a fresh 32-byte seed from ``secrets`` is used
                when it is falsy.

        Returns:
            A key whose ``sec`` and ``A`` are sampled from ``SeededRNG(seed)``.
        """
        seed = seed or secrets.token_bytes(32)
        rng = SeededRNG(seed)
        # The sampling order (sec first, then A) must stay fixed: from_bytes
        # rebuilds a key by re-running this method with the stored seed.
        sec = uniform(pp.cipher_mod, rng, (pp.n, pp.bs))
        A = uniform(pp.cipher_mod, rng, (pp.m, pp.n))
        return cls(seed, A, sec)

    def __eq__(self, other):
        """Two keys are equal when their seeds are equal.

        Note that ``pp`` is not compared; returns ``NotImplemented`` for
        non-key operands so Python can fall back to its default handling.
        """
        if not isinstance(other, RegevKey):
            return NotImplemented
        return self.seed == other.seed

    def to_bytes(self) -> bytes:
        """Return the byte representation of the key, which is its seed."""
        return self.seed

    @classmethod
    def from_bytes(cls, pp: "RegevPublicParameters", b: bytes):
        """Reconstruct a key from the output of ``to_bytes``.

        ``pp`` must match the parameters the key was generated with, since the
        shapes of ``A`` and ``sec`` depend on it. This assumes ``SeededRNG`` is
        deterministic for a given seed.

        Raises:
            ValueError: if ``b`` is empty. Otherwise ``gen`` would silently
                replace the missing seed with a fresh random one and return an
                unrelated key.
        """
        if not b:
            raise ValueError("Cannot reconstruct a Regev key from an empty byte string")
        return cls.gen(pp, bytes(b))

    def __repr__(self):
        # NOTE(review): this exposes the seed, from which the secret is
        # derived; avoid logging key objects.
        return f"Regev Key:\nRandomness Seed: {self.seed}"
class BatchedRegevCiphertext:
    """Batched Regev ciphertext encrypting ``pp.bs`` messages at once.

    ``c1`` is the mask ``r @ A`` and ``c2`` holds the noisy, scaled messages;
    decryption evaluates ``c2 - c1 @ sec`` modulo ``pp.cipher_mod``.
    ``from_bytes`` reads ``c1`` back with shape ``(num_batches, pp.n)`` and
    ``c2`` with shape ``(num_batches, pp.bs)``, where
    ``num_batches = c1.shape[0]``, so both must have the same number of rows.
    """

    def __init__(self, c1: np.ndarray, c2: np.ndarray, mes_mod: int):
        self.c1 = c1
        self.c2 = c2
        self.mes_mod = mes_mod  # plaintext modulus shared by all slots

    @classmethod
    def encrypt_raw(cls, pp: "RegevPublicParameters", k: "RegevKey", mes: np.ndarray,
                    mes_mod: int = 2, seed=None):
        """Encrypt a vector of ``pp.bs`` messages under the secret key ``k``.

        Args:
            pp: public parameters.
            k: secret key.
            mes: 1-D integer array of length ``pp.bs``; entries are reduced
                modulo ``mes_mod``.
            mes_mod: plaintext modulus.
            seed: optional randomness seed; a fresh 32-byte seed from
                ``secrets`` is used when it is falsy.

        Returns:
            A ciphertext whose ``c1`` and ``c2`` each have a single row.

        Raises:
            MessageWrongDimensions: if ``mes`` is not one-dimensional.
            MessageWrongSize: if ``mes`` does not have ``pp.bs`` entries.
        """
        rng = SeededRNG(seed or secrets.token_bytes(32))
        if mes.ndim != 1:
            raise MessageWrongDimensions()
        if mes.shape[0] != pp.bs:
            raise MessageWrongSize(
                f"Expected message size {pp.bs}, got {mes.shape[0]}")
        mes = mes % mes_mod
        r = uniform(2, rng, lbound=-1, shape=(1, pp.m))
        c1 = r @ k.A % pp.cipher_mod
        # One noise sample per slot and per c1 row, so that c2 has exactly as
        # many rows as c1 (to_bytes/from_bytes depend on that layout).
        noise = gaussian(pp.bound, rng, shape=(c1.shape[0], pp.bs))
        b = (c1 @ k.sec + noise) % pp.cipher_mod
        # Message value j is encoded as j * delta with delta ~ q / mes_mod.
        delta = mround(pp.cipher_mod / mes_mod)
        c2 = (b + delta * mes) % pp.cipher_mod
        return cls(c1, c2, mes_mod)

    def __repr__(self):
        return f"bRegev Ciphertext:\nMessage Modulus: {self.mes_mod}\n{self.c1}\n{self.c2}"

    def __add__(self, other):
        """Add two ciphertexts slot-wise (homomorphic addition).

        Both components are summed over the integers. The ciphertext modulus
        is not stored on the object, so the sums are left unreduced. They are
        congruent to the proper sum modulo ``pp.cipher_mod``, and ``decrypt``,
        ``pack`` and ``to_bytes`` all reduce modulo it. The result is intended
        to decrypt to the slot-wise sum modulo ``mes_mod`` while the
        accumulated noise stays small.

        Raises:
            ValueError: if the operands use different message moduli.
        """
        if not isinstance(other, BatchedRegevCiphertext):
            return NotImplemented
        if self.mes_mod != other.mes_mod:
            raise ValueError(
                f"Cannot add ciphertexts with message moduli {self.mes_mod} and {other.mes_mod}")
        return BatchedRegevCiphertext(self.c1 + other.c1, self.c2 + other.c2, self.mes_mod)

    def __eq__(self, other):
        """Component-wise equality (shapes must match exactly)."""
        if not isinstance(other, BatchedRegevCiphertext):
            return NotImplemented
        return (np.array_equal(self.c1, other.c1)
                and np.array_equal(self.c2, other.c2)
                and self.mes_mod == other.mes_mod)

    def to_bytes(self, pp: "RegevPublicParameters") -> bytearray:
        """Serialize the ciphertext.

        Layout: 8-byte little-endian row count, then ``mes_mod``, ``c1`` and
        ``c2``, each element stored in ``ceil(bitlen(q) / 8)`` bytes.
        """
        cipher_mod_len = math.ceil(pp.cipher_mod.bit_length() / 8)
        res = bytearray()
        res.extend(self.c1.shape[0].to_bytes(8, "little"))
        res.extend(self.mes_mod.to_bytes(cipher_mod_len, "little"))
        # Reduce first: sums produced by __add__ are not reduced and might not
        # fit into cipher_mod_len bytes otherwise.
        res.extend(serialize_ndarray(self.c1 % pp.cipher_mod, cipher_mod_len))
        res.extend(serialize_ndarray(self.c2 % pp.cipher_mod, cipher_mod_len))
        return res

    @classmethod
    def from_bytes(cls, pp: "RegevPublicParameters", b: bytes):
        """Recover a batched ciphertext from the output of ``to_bytes``."""
        cipher_mod_len = math.ceil(pp.cipher_mod.bit_length() / 8)
        num_batches = int.from_bytes(b[:8], "little")
        offset = 8
        mes_mod = int.from_bytes(b[offset:offset + cipher_mod_len], "little")
        offset += cipher_mod_len
        c1 = deserialize_ndarray(b[offset:], (num_batches, pp.n), cipher_mod_len)
        offset += num_batches * pp.n * cipher_mod_len
        c2 = deserialize_ndarray(b[offset:], (num_batches, pp.bs), cipher_mod_len)
        return cls(c1, c2, mes_mod)

    def decrypt(self, pp: "RegevPublicParameters", k: "RegevKey", mes_mod: int = None) -> np.ndarray:
        """Decrypt every slot.

        Args:
            pp: public parameters used at encryption time.
            k: secret key.
            mes_mod: plaintext modulus to decode with; falls back to
                ``self.mes_mod`` when falsy.

        Returns:
            ``round(((c2 - c1 @ sec) mod q) * mes_mod / q) mod mes_mod``,
            computed element-wise.
        """
        mes_mod = mes_mod or self.mes_mod
        noisy_message = (
            ((self.c2 - self.c1 @ k.sec) % pp.cipher_mod) * mes_mod) / pp.cipher_mod
        return mround(noisy_message) % mes_mod

    def pack(self, pp: "RegevPublicParameters", seed: bytes = None):
        """Encode the ciphertext more densely as a ``PackedRegevCiphertext``.

        A shift ``r`` is drawn until every entry of ``(c2 + r) mod q`` lies
        farther than ``pp.bound + 1`` from every rounding boundary of the
        ``mes_mod``-level encoding. Then ``c2`` is replaced by its rounded
        value ``w`` in ``[0, mes_mod)``.
        """
        rng = SeededRNG(seed or secrets.token_bytes(32))
        while True:
            r = uniform(pp.cipher_mod, rng)
            shifted = (self.c2 + r) % pp.cipher_mod
            # Presumably the margin ensures noise of magnitude <= bound cannot
            # flip the rounding at decryption time.
            if PackedRegevCiphertext._near_mes(shifted, pp.bound, pp.cipher_mod, self.mes_mod):
                break
        w = mround(shifted * self.mes_mod / pp.cipher_mod) % self.mes_mod
        return PackedRegevCiphertext(self.c1 % pp.cipher_mod, w, r, self.mes_mod)
class PackedRegevCiphertext:
    """Compressed ciphertext produced by ``BatchedRegevCiphertext.pack``.

    Keeps the mask ``c1`` and a shift ``r``. Instead of full elements modulo
    ``cipher_mod``, each slot stores only ``w``, the rounded value of
    ``(c2 + r) mod q`` at ``mes_mod`` precision.
    """

    def __init__(self, c1: np.ndarray, w: np.ndarray, r: np.ndarray, mes_mod: int):
        self.c1 = c1
        self.w = w
        # to_bytes reads r[0], so r is expected to be a length-1 array.
        self.r = r
        self.mes_mod = mes_mod

    def __repr__(self):
        return f"pRegev Ciphertext:\n{self.c1}\n{self.w}\n{self.r}"

    def __eq__(self, other):
        """Component-wise equality (shapes must match exactly)."""
        if not isinstance(other, PackedRegevCiphertext):
            return NotImplemented
        return (np.array_equal(self.c1, other.c1)
                and np.array_equal(self.w, other.w)
                and np.array_equal(self.r, other.r)
                and self.mes_mod == other.mes_mod)

    def to_bytes(self, pp: "RegevPublicParameters") -> bytearray:
        """Serialize the ciphertext.

        Layout: 8-byte little-endian row count; ``mes_mod``, ``r`` and ``c1``
        with ``ceil(bitlen(q) / 8)`` bytes per element; then ``w`` with
        ``ceil(bitlen(mes_mod) / 8)`` bytes per element.
        """
        cipher_mod_len = math.ceil(pp.cipher_mod.bit_length() / 8)
        res = bytearray()
        res.extend(self.c1.shape[0].to_bytes(8, "little"))
        res.extend(self.mes_mod.to_bytes(cipher_mod_len, "little"))
        res.extend(int(self.r[0]).to_bytes(cipher_mod_len, "little"))
        # Reduce defensively so unreduced mask entries still fit the width.
        res.extend(serialize_ndarray(self.c1 % pp.cipher_mod, cipher_mod_len))
        mes_mod_len = math.ceil(self.mes_mod.bit_length() / 8)
        res.extend(serialize_ndarray(self.w, mes_mod_len))
        return res

    @classmethod
    def from_bytes(cls, pp: "RegevPublicParameters", b: bytes):
        """Recover a packed ciphertext from the output of ``to_bytes``."""
        cipher_mod_len = math.ceil(pp.cipher_mod.bit_length() / 8)
        num_batches = int.from_bytes(b[:8], "little")
        offset = 8
        mes_mod = int.from_bytes(b[offset:offset + cipher_mod_len], "little")
        offset += cipher_mod_len
        r = np.array([int.from_bytes(b[offset:offset + cipher_mod_len], "little")], dtype=int)
        offset += cipher_mod_len
        c1 = deserialize_ndarray(b[offset:], (num_batches, pp.n), cipher_mod_len)
        offset += num_batches * pp.n * cipher_mod_len
        mes_mod_len = math.ceil(mes_mod.bit_length() / 8)
        w = deserialize_ndarray(b[offset:], (num_batches, pp.bs), mes_mod_len)
        return cls(c1, w, r, mes_mod)

    @classmethod
    def _near_mes_scalar(cls, x: int, bound: int, cipher_mod: int, mes_mod: int):
        """Return True if ``x`` is clear of every rounding boundary.

        The boundaries checked are ``(j + 1/2) * cipher_mod / mes_mod`` for
        ``j`` in ``range(mes_mod)``. ``x`` fails the check when it falls into
        the window of ``bound + 1`` on either side of one of them, taken
        modulo ``cipher_mod``. This assumes ``mod_between(x, lo, hi)`` tests
        cyclic membership of ``x`` in ``[lo, hi]`` -- TODO confirm.
        """
        frac = cipher_mod / mes_mod
        check_around = frac / 2
        for _ in range(mes_mod):
            if mod_between(x, (check_around - bound - 1) % cipher_mod, (check_around + bound + 1) % cipher_mod):
                return False
            check_around += frac
        return True

    @classmethod
    def _near_mes(cls, arr: np.ndarray, bound: int, cipher_mod: int, mes_mod: int):
        """Return True if every entry of ``arr`` passes ``_near_mes_scalar``.

        Stops at the first failing entry. ``.tolist()`` hands plain Python
        scalars to the check. An empty array trivially passes.
        """
        return all(
            cls._near_mes_scalar(x, bound, cipher_mod, mes_mod)
            for x in np.asarray(arr).ravel().tolist()
        )

    def decrypt(self, pp: "RegevPublicParameters", k: "RegevKey", mes_mod=None) -> np.ndarray:
        """Decrypt every slot as ``w - round(((c1 @ sec + r) mod q) * mes_mod / q)``.

        The result is reduced modulo ``mes_mod``, which falls back to
        ``self.mes_mod`` when falsy.
        """
        mes_mod = mes_mod or self.mes_mod
        return (self.w - mround((((self.c1 @ k.sec + self.r) % pp.cipher_mod) * mes_mod) / pp.cipher_mod)) % mes_mod
class MessageWrongDimensions(Exception):
    """Raised by encryption when the message array is not one-dimensional."""
class MessageWrongSize(Exception):
    """Raised by encryption when the message length differs from ``pp.bs``."""