From feb0719b77f2d754c610e03bef098a59faee4592 Mon Sep 17 00:00:00 2001 From: Berry Schoenmakers Date: Sat, 29 Sep 2018 14:31:28 +0200 Subject: [PATCH] Transition from n-party to m-party. (#3) --- demos/.config/generate-certificates.py | 6 +- demos/.config/generate-config-files.py | 10 +- demos/ot.py | 14 +- demos/secretsanta.py | 14 +- demos/sort.py | 2 +- mpyc/__init__.py | 2 +- mpyc/asyncoro.py | 2 +- mpyc/pfield.py | 26 ++-- mpyc/runtime.py | 195 +++++++++++++------------ mpyc/sectypes.py | 6 +- mpyc/thresha.py | 81 +++++----- tests/test_thresha.py | 26 ++-- 12 files changed, 199 insertions(+), 185 deletions(-) diff --git a/demos/.config/generate-certificates.py b/demos/.config/generate-certificates.py index 2e6939e5..fc7e9ae8 100644 --- a/demos/.config/generate-certificates.py +++ b/demos/.config/generate-certificates.py @@ -59,9 +59,9 @@ def save_certificate(cert, filename): help='output filename prefix') parser.add_argument('-k', '--key-size', type=int, help='key size') - parser.add_argument('-n', '--parties', dest='n', type=int, + parser.add_argument('-m', '--parties', dest='m', type=int, help='number of parties') - parser.set_defaults(n=5, key_size=2048, prefix='party_') + parser.set_defaults(m=5, key_size=2048, prefix='party_') options = parser.parse_args() # self-signed certificate @@ -71,7 +71,7 @@ def save_certificate(cert, filename): save_key(ca_key, 'mpyc_ca.key') save_certificate(ca_cert, 'mpyc_ca.crt') - for i in range(options.n): + for i in range(options.m): key = create_key(options.key_size) req = create_request(key, f'MPyC party {i}') cert = create_certificate(req, ca_cert, ca_key, 256 + i) diff --git a/demos/.config/generate-config-files.py b/demos/.config/generate-config-files.py index c57f4892..a6a5b09d 100644 --- a/demos/.config/generate-config-files.py +++ b/demos/.config/generate-config-files.py @@ -22,21 +22,21 @@ parser = argparse.ArgumentParser() parser.add_argument('-p', '--prefix', help='output filename prefix') -parser.add_argument('-n', '--parties', dest='n', type=int, +parser.add_argument('-m', '--parties', dest='m', type=int, help='number of parties') parser.add_argument('args', nargs='*') -parser.set_defaults(n=3, prefix='party') +parser.set_defaults(m=3, prefix='party') options = parser.parse_args() args = options.args -if len(args) != options.n: +if len(args) != options.m: parser.error('A hostname:port argument required for each party.') addresses = [arg.split(':', 1) for arg in args] -configs = generate_configs(options.n, addresses) +configs = generate_configs(options.m, addresses) for party, config in enumerate(configs): - filename = f'{options.prefix}{options.n}_{party}.ini' + filename = f'{options.prefix}{options.m}_{party}.ini' config.write(open(filename, 'w')) diff --git a/demos/ot.py b/demos/ot.py index 0f6d63a1..41ab6134 100644 --- a/demos/ot.py +++ b/demos/ot.py @@ -3,13 +3,13 @@ import sys from mpyc.runtime import mpc -n = len(mpc.parties) +m = len(mpc.parties) -if n % 2 == 0: +if m % 2 == 0: print('OT runs with odd number of parties only.') sys.exit() -t = n // 2 +t = m // 2 message = [(None, None)] * t choice = [None] * t if mpc.id == 0: @@ -25,10 +25,10 @@ secnum = mpc.SecInt() for i in range(1, t + 1): - m = mpc.input([secnum(message[i - 1][0]), secnum(message[i - 1][1])], i) + x = mpc.input([secnum(message[i - 1][0]), secnum(message[i - 1][1])], i) b = mpc.input(secnum(choice[i - t - 1]), t + i) - m = mpc.run(mpc.output(m[0] + b * (m[1] - m[0]), t + i)) - if m: - print(f'You have received message {m}.') + a = mpc.run(mpc.output(x[0] + b * (x[1] - x[0]), t + i)) + if a: + print(f'You have received message {a}.') mpc.shutdown() diff --git a/demos/secretsanta.py b/demos/secretsanta.py index 5e98a171..21557725 100644 --- a/demos/secretsanta.py +++ b/demos/secretsanta.py @@ -50,26 +50,26 @@ async def random_derangement(n, sectype): def main(): if sys.argv[1:]: - m = int(sys.argv[1]) + N = int(sys.argv[1]) else: - m = 8 - print('Setting input to default =', m) + N = 8 + print('Setting input to default =', N) mpc.start() - secfld = mpc.SecFld(l=max(len(mpc.parties), (m - 1)).bit_length() + 1) + secfld = mpc.SecFld(l=max(len(mpc.parties), (N - 1)).bit_length() + 1) print('Using secure fields:', secfld) - for n in range(2, m + 1): + for n in range(2, N + 1): print(n, mpc.run(mpc.output(random_derangement(n, secfld)))) secint = mpc.SecInt() print('Using secure integers:', secint) - for n in range(2, m + 1): + for n in range(2, N + 1): print(n, mpc.run(mpc.output(random_derangement(n, secint)))) secfxp = mpc.SecFxp() print('Using secure fixed-point numbers:', secfxp) - for n in range(2, m + 1): + for n in range(2, N + 1): print(n, mpc.run(mpc.output(random_derangement(n, secfxp)))) mpc.shutdown() diff --git a/demos/sort.py b/demos/sort.py index 22d6bc2a..e9258436 100644 --- a/demos/sort.py +++ b/demos/sort.py @@ -29,7 +29,7 @@ def bitonic_compare(i, j, up): def main(): if sys.argv[1:]: - m = int(sys.argv[1]) + n = int(sys.argv[1]) else: n = 5 print('Setting input to default =', n) diff --git a/mpyc/__init__.py b/mpyc/__init__.py index 95115556..310fa2c7 100644 --- a/mpyc/__init__.py +++ b/mpyc/__init__.py @@ -13,5 +13,5 @@ are provided as well (e.g., some matrix-vector operations). """ -__version__ = '0.3.5' +__version__ = '0.3.6' __license__ = 'Apache License 2.0' diff --git a/mpyc/asyncoro.py b/mpyc/asyncoro.py index c828e094..9a528e00 100644 --- a/mpyc/asyncoro.py +++ b/mpyc/asyncoro.py @@ -180,7 +180,7 @@ def returnType(rettype=None, *args): Used in first await expression in an MPC coroutine. """ - if rettype == None: + if rettype is None: return _afuture(None) if isinstance(rettype, Future): return _afuture(Future) diff --git a/mpyc/pfield.py b/mpyc/pfield.py index 7d8cdc0b..317b71b8 100644 --- a/mpyc/pfield.py +++ b/mpyc/pfield.py @@ -105,34 +105,32 @@ def __abs__(self): return self.unsigned() @classmethod - def to_bytes(cls, values): - """Return an array of bytes representing the given list of values. + def to_bytes(cls, x): + """Return an array of bytes representing the given list of values x. Values are either integers or field elements. """ - m = len(values) r = (cls.modulus.bit_length() + 7) // 8 - data = bytearray(2 + m * r) + data = bytearray(2 + len(x) * r) data[:2] = r.to_bytes(2, byteorder='little') j = 2 - for i in range(m): - e = values[i] - if not isinstance(e, int): e = e.value - data[j:j + r] = e.to_bytes(r, byteorder='little') + for v in x: + if not isinstance(v, int): v = v.value + data[j:j + r] = v.to_bytes(r, byteorder='little') j += r return data @staticmethod def from_bytes(data): - """Return the list of elements represented by the given array of bytes.""" + """Return the list of integers represented by the given array of bytes.""" r = int.from_bytes(data[:2], byteorder='little') - m = (len(data) - 2) // r - elements = [None] * m + n = (len(data) - 2) // r + x = [None] * n j = 2 - for i in range(m): - elements[i] = int.from_bytes(data[j:j + r], byteorder='little') + for i in range(n): + x[i] = int.from_bytes(data[j:j + r], byteorder='little') j += r - return elements + return x def __add__(self, other): """Addition.""" diff --git a/mpyc/runtime.py b/mpyc/runtime.py index 27f5b4a3..e056364d 100644 --- a/mpyc/runtime.py +++ b/mpyc/runtime.py @@ -34,7 +34,7 @@ class Runtime: the list of parties, etc., and handles secret-shared values of type Share. 1-party case is supported (with option to disable asynchronous evaluation). - Threshold 0 (no corrupted parties) is supported for n-party case as well + Threshold 0 (no corrupted parties) is supported for m-party case as well to enable distributed computation (without secret sharing). """ @@ -45,10 +45,10 @@ def __init__(self, id, parties, options): self.options = options self.threshold = options.threshold self._program_counter = [0] - n = len(self.parties) + m = len(self.parties) t = self.threshold - #caching (n choose t): - self._bincoef = math.factorial(n) // math.factorial(t) // math.factorial(n - t) + #caching (m choose t): + self._bincoef = math.factorial(m) // math.factorial(t) // math.factorial(m - t) def _increment_pc(self): """Increment the program counter.""" @@ -116,10 +116,11 @@ def start(self): """ logging.info(f'Start MPyC runtime v{self.version}') self.start_time = time.time() - if len(self.parties) == 1: + m = len(self.parties) + if m == 1: return - # len(self.parties) > 1 + # m > 1 for peer in self.parties: peer.protocol = Future() if peer.id == self.id else None factory = lambda: asyncoro.SharesExchanger(self) @@ -176,7 +177,8 @@ def shutdown(self): Close all connections, if any. """ - if len(self.parties) > 1: + m = len(self.parties) + if m > 1: # Wait for all parties. self.run(self.output(self.input(sectypes.SecFld(101)(self.id)))) # Close connections to all parties. @@ -208,7 +210,8 @@ def input(self, x, senders=None): else: x = [x] if senders is None: - senders = list(range(len(self.parties))) + m = len(self.parties) + senders = list(range(m)) senders_is_list = isinstance(senders, list) if not senders_is_list: senders = [senders] @@ -231,11 +234,13 @@ async def _distribute(self, x, senders): value = x[0].df if not isinstance(x[0].df, Future) else None assert value is None or self.id in senders + m = len(self.parties) + t = self.threshold x = [a.df for a in x] # Extract values from all elements of x. shares = [None] * len(senders) for i, peer_id in enumerate(senders): if peer_id == self.id: - in_shares = thresha.random_split(x, self.threshold, len(self.parties)) + in_shares = thresha.random_split(x, t, m) for other_id, data in enumerate(in_shares): data = field.to_bytes(data) if other_id == self.id: @@ -260,7 +265,8 @@ def output(self, x, receivers=None, threshold=None): else: x = [x] if receivers is None: - receivers = list(range(len(self.parties))) + m = len(self.parties) + receivers = list(range(m)) elif isinstance(receivers, int): receivers = [receivers] if threshold is None: @@ -271,7 +277,7 @@ def output(self, x, receivers=None, threshold=None): return gather_shares(y) @mpc_coro - async def _recombine(self, x, receivers, threshold): + async def _recombine(self, x, receivers, t): """Recombine shares of elements of x.""" sftype = type(x[0]) # all elts assumed of same type if issubclass(sftype, Share): @@ -285,17 +291,18 @@ async def _recombine(self, x, receivers, threshold): await returnType(Share, len(x)) field = sftype + m = len(self.parties) # Send share to all successors in receivers. for peer_id in receivers: - if 0 < (peer_id - self.id) % len(self.parties) <= threshold: + if 0 < (peer_id - self.id) % m <= t: self._send_share(peer_id, field.to_bytes(x)) # Receive and recombine shares if this party is a receiver. if self.id in receivers: - shares = [None] * threshold - for i in range(threshold): - shares[i] = self._expect_share((self.id - threshold + i) % len(self.parties)) + shares = [None] * t + for i in range(t): + shares[i] = self._expect_share((self.id - t + i) % m) shares = await gather_shares(shares) - shares = [((self.id - threshold + j) % len(self.parties) + 1, field.from_bytes(shares[j])) for j in range(threshold)] + shares = [((self.id - t + j) % m + 1, field.from_bytes(shares[j])) for j in range(t)] shares.append((self.id + 1, x)) return thresha.recombine(field, shares) else: @@ -318,10 +325,12 @@ async def _reshare(self, x): await returnType(Share) field = sftype - in_shares = thresha.random_split(x, self.threshold, len(self.parties)) + m = len(self.parties) + t = self.threshold + in_shares = thresha.random_split(x, t, m) in_shares = [field.to_bytes(elts) for elts in in_shares] # Recombine the first 2t+1 output_shares. - out_shares = await gather_shares(self._exchange_shares(in_shares)[:2*self.threshold+1]) + out_shares = await gather_shares(self._exchange_shares(in_shares)[:2 * t + 1]) y = thresha.recombine(field, [(j + 1, field.from_bytes(out_shares[j])) for j in range(len(out_shares))]) if issubclass(sftype, Share): @@ -350,7 +359,7 @@ async def trunc(self, a, f=None): r_modf += r_bits[i].value r_divf = self.random(Zp, 1<<(k + l - f)) a = await gather_shares(a) - c = await self.output(a + (1< Future: """Secure public zero test of a.""" stype = type(a) + m = len(self.parties) + t = self.threshold if stype.__name__.startswith('SecFld'): prfs = self.parties[self.id].prfs(stype.field.modulus) while True: r, s = self.randoms(stype.field, 2) - z = thresha.pseudorandom_share_zero(stype.field, len(self.parties), self.id, prfs, self._prss_uci(), 1) - if await self.output(r * s + z[0], threshold=2*self.threshold): + z = thresha.pseudorandom_share_zero(stype.field, m, self.id, prfs, self._prss_uci(), 1) + if await self.output(r * s + z[0], threshold=2 * t): break else: r = self.random(stype.field) #failure shared r is 0 with prob. 1/p a = await gather_shares(a) if stype.__name__.startswith('SecFld'): - z = thresha.pseudorandom_share_zero(stype.field, len(self.parties), self.id, prfs, self._prss_uci(), 1) + z = thresha.pseudorandom_share_zero(stype.field, m, self.id, prfs, self._prss_uci(), 1) b = a * r + z[0] else: b = a * r - c = await self.output(b, threshold=2*self.threshold) + c = await self.output(b, threshold=2 * t) return c == 0 @mpc_coro @@ -431,7 +442,7 @@ async def mul(self, a, b): else: a, b = await gather_shares(a, b) if stype.field.frac_length > 0 and a_integral: - a = a / (1< 0 and not a_integral: c = self.trunc(stype(c.df)) # c.df @@ -460,23 +471,23 @@ async def reciprocal(self, a): ar = await self.output(a * r, threshold=2*self.threshold) if ar == 0: return self.reciprocal(a) - return r * (1<> i) & 1 - e[i] = Zp(s_sign + r_bits[i].value - c_i + 3*sumXors) + e[i] = Zp(s_sign + r_bits[i].value - c_i + 3 * sumXors) sumXors += 1 - r_bits[i].value if c_i else r_bits[i].value - e[l] = Zp(s_sign - 1 + 3*sumXors) + e[l] = Zp(s_sign - 1 + 3 * sumXors) f = await self.is_zero_public(self.prod(e)) UF = s_bit if f == 1 else 1 - s_bit z = (a_rmodl - (c + UF * (1< 0 and x_integral: - f1 = 1 / stype.field(1< 0 and not x_integral: @@ -759,7 +770,7 @@ async def scalar_mul(self, a, x): await returnType((stype, a_integral and x[0].integral), len(x)) a, x = await gather_shares(a, x) if stype.field.frac_length > 0 and a_integral: - a = a / (1<> 1 # q = 1/2 mod p + m = len(self.parties) + t = self.threshold prfs = self.parties[self.id].prfs(p) - h = m + h = n while h > 0: - rs = thresha.pseudorandom_share(field, len(self.parties), self.id, prfs, self._prss_uci(), h) + rs = thresha.pseudorandom_share(field, m, self.id, prfs, self._prss_uci(), h) # Compute and open the squares and compute square roots. r2s = [r * r for r in rs] if prss0: - z = thresha.pseudorandom_share_zero(field, len(self.parties), self.id, prfs, self._prss_uci(), h) + z = thresha.pseudorandom_share_zero(field, m, self.id, prfs, self._prss_uci(), h) for i in range(h): r2s[i] += z[i] - r2s = await self.output(r2s, threshold=2*self.threshold) + r2s = await self.output(r2s, threshold=2 * t) for r, r2 in zip(rs, r2s): if r2.value != 0: h -= 1 @@ -953,7 +967,7 @@ async def to_bits(self, a): r_modl += r_bits[i].value r_divl = self.random(Zp, 1<> i) & 1) for i in range(l)] @@ -1017,18 +1031,18 @@ def __repr__(self): else: return f'<_Party {self.id}: {self.host}:{self.port}>' -def generate_configs(n, addresses): +def generate_configs(m, addresses): """Generate party configurations. - Generates n party configurations with thresholds 0 up to (n-1)//2. + Generates m party configurations with thresholds 0 up to (m-1)//2. addresses is a list of '(host, port)' pairs, specifying the hostnames and port numbers for each party. Moreover, the keys used in pseudorandom secret sharing (PRSS) are generated. - The n party configurations are returned as a list of ConfigParser - instances, which be saved in n separate INI-files. + The m party configurations are returned as a list of ConfigParser + instances, which be saved in m separate INI-files. """ - parties = range(n) + parties = range(m) configs = [configparser.ConfigParser() for _ in parties] for p in parties: host, port = addresses[p] @@ -1039,8 +1053,8 @@ def generate_configs(n, addresses): config.set(f'Party {p}', 'host', host) config.set(f'Party {p}', 'port', port) - for t in range((n + 1) // 2): - for subset in itertools.combinations(parties, n - t): + for t in range((m + 1) // 2): + for subset in itertools.combinations(parties, m - t): key = hex(secrets.randbits(128)) # 128-bit key subset_str = ' '.join(map(str, subset)) for p in subset: @@ -1048,7 +1062,7 @@ def generate_configs(n, addresses): return configs def _load_config(filename, t=None): - """Load n-party configuration file using threshold t (default (n-1) // 2). + """Load m-party configuration file using threshold t (default (m-1) // 2). Configuration files are simple INI-files containing information (hostname and port number) about the other parties in the protocol. @@ -1060,10 +1074,10 @@ def _load_config(filename, t=None): """ config = configparser.ConfigParser() config.read_file(open(filename, 'r')) - n = len(config.sections()) + m = len(config.sections()) if t is None: - t = (n - 1) // 2 - parties = [None] * n + t = (m - 1) // 2 + parties = [None] * m for party in config.sections(): id = int(party[6:]) # strip 'Party ' prefix host = config.get(party, 'host') @@ -1075,7 +1089,7 @@ def _load_config(filename, t=None): for option in config.options(party): if not option in ['host', 'port']: subset = frozenset(map(int, option.split())) - if len(subset) == n - t: + if len(subset) == m - t: keys[subset] = config.get(party, option) parties[my_id] = _Party(my_id, host, port, keys) else: @@ -1089,21 +1103,21 @@ def setup(): help=f'show -h help message for {sys.argv[0]}, if any') group = parser.add_argument_group('MPyC') group.add_argument('-c', '--config', metavar='C', - help='party configuration file C, which defines N') + help='party configuration file C, which defines M') group.add_argument('-t', '--threshold', type=int, metavar='T', - help='threshold T, 2T+1<=N') + help='threshold T, 2T+1<=M') group.add_argument('-l', '--bit-length', type=int, metavar='L', - help='maximum bit length L (for comparisons etc.)') + help='maximum bit length L (for comparisons etc.)') group.add_argument('-k', '--security-parameter', type=int, metavar='K', - help='security parameter K for leakage probability 1/2**K') + help='security parameter K for leakage probability 1/2**K') group.add_argument('--no-log', action='store_true', - default=False, help='disable logging') + default=False, help='disable logging') group.add_argument('--ssl', action='store_true', - default=False, help='enable SSL connections') + default=False, help='enable SSL connections') group.add_argument('--no-async', action='store_true', - default=False, help='disable asynchronous evaluation') + default=False, help='disable asynchronous evaluation') group.add_argument('-f', type=str, - default='', help='consume IPython string') + default='', help='consume IPython string') parser.set_defaults(bit_length=32, security_parameter=30) options, args = parser.parse_known_args() if options.HELP: @@ -1126,9 +1140,10 @@ def setup(): else: options.config = os.path.join('.config', options.config) id, parties = _load_config(options.config, options.threshold) + m = len(parties) if options.threshold is None: - options.threshold = (len(parties) - 1) // 2 - assert 2 * options.threshold < len(parties) + options.threshold = (m - 1) // 2 + assert 2 * options.threshold < m runtime = Runtime(id, parties, options) runtime.parser = parser diff --git a/mpyc/sectypes.py b/mpyc/sectypes.py index 58b8ae65..5339dc83 100644 --- a/mpyc/sectypes.py +++ b/mpyc/sectypes.py @@ -210,7 +210,7 @@ def __init__(self, value=None): super().__init__(field, value) SecureFld.field = field SecureFld.bit_length = l - name = f'SecFld{SecureFld.bit_length}({SecureFld.field.modulus})' + name = f'SecFld{SecureFld.bit_length}({SecureFld.field.modulus})' _sectypes[(l, p)] = type(name, (SecureFld,), {'__slots__':()}) return _sectypes[(l, p)] @@ -246,7 +246,7 @@ def SecInt(l=None, p=None, n=2): def SecFxp(l=None, f=None, p=None, n=2): """Secure l-bit fixed-point numbers with f-bit fractional part. - + NB: if dividing secure fixed-point numbers, make sure that l =~ 2f. """ if l is None: @@ -257,7 +257,7 @@ def SecFxp(l=None, f=None, p=None, n=2): if (l, f, p, n) not in _sectypes: SecureFxp = _SecNum(l, f, p, n) if p is None: - name = f'SecFxp{l}:{f}' + name = f'SecFxp{l}:{f}' else: name = f'SecFxp{l}:{f}({p})' def init(self, value=None, integral=False): diff --git a/mpyc/thresha.py b/mpyc/thresha.py index a60a1566..fbfd5457 100644 --- a/mpyc/thresha.py +++ b/mpyc/thresha.py @@ -15,22 +15,22 @@ import hashlib import secrets -def random_split(s, d, n): - """Split each secret given in s into n random Shamir shares. +def random_split(s, t, m): + """Split each secret given in s into m random Shamir shares. - The (maximum) degree for the Shamir polynomials is d, 0 <= d < n. + The (maximum) degree for the Shamir polynomials is t, 0 <= t < n. Return matrix of shares, one row per party. """ p = s[0].modulus - m = len(s) - shares = [[None] * m for _ in range(n)] - for h in range(m): - c = [secrets.randbelow(p) for _ in range(d)] - # polynomial f(x) = s[h] + c[d-1] x + c[d-2] x^2 + ... + c[0] x^d - for i in range(n): + n = len(s) + shares = [[None] * n for _ in range(m)] + for h in range(n): + c = [secrets.randbelow(p) for _ in range(t)] + # polynomial f(x) = s[h] + c[t-1] x + c[t-2] x^2 + ... + c[0] x^t + for i in range(m): y = 0 - for c_k in c: - y += c_k + for c_j in c: + y += c_j y *= i + 1 shares[i][h] = (y + s[h].value) % p return shares @@ -63,16 +63,17 @@ def recombine(field, points, x_rs=0): coefficient *= (x_r - x_j) / (x_i - x_j) vector[r].append(coefficient.value) _recombination_vectors[(field, xs, x_r.value)] = vector[r] - m = len(shares[0]) - sums = [[0] * m for _ in range(len(x_rs))] - for i in range(len(shares)): - for h in range(m): + m = len(shares) + n = len(shares[0]) + sums = [[0] * n for _ in range(len(x_rs))] + for i in range(m): + for h in range(n): s = shares[i][h] if not isinstance(s, int): s = s.value for r in range(len(sums)): sums[r][h] += s * vector[r][i] - for h in range(m): + for h in range(n): for r in range(len(sums)): sums[r][h] = field(sums[r][h]) if isinstance(x_rs, tuple): @@ -84,58 +85,58 @@ def recombine(field, points, x_rs=0): #the party concerned, and the subset. _f_in_i_cache = {} -def pseudorandom_share(field, n, i, prfs, uci, m): - """Return pseudorandom Shamir shares for party i for m random numbers. +def pseudorandom_share(field, m, i, prfs, uci, n): + """Return pseudorandom Shamir shares for party i for n random numbers. The shares are based on the pseudorandom functions for party i, given in prfs, which maps subsets of parties to PRF instances. Input uci is used to evaluate the PRFs on a unique common input. """ s = str(uci) - sums = [0] * m - # iterate over (n-1 choose d) subsets for degree d. + sums = [0] * n + # iterate over (m-1 choose t) subsets for degree t. for subset, prf in prfs.items(): try: f_in_i = _f_in_i_cache[(field, i, subset)] except KeyError: - complement = frozenset(range(n)) - subset + complement = frozenset(range(m)) - subset points = [(0, [1])] + [(x + 1, [0]) for x in complement] f_in_i = recombine(field, points, i + 1)[0].value _f_in_i_cache[(field, i, subset)] = f_in_i - prl = prf(s, m) - for h in range(m): + prl = prf(s, n) + for h in range(n): sums[h] += prl[h] * f_in_i - for h in range(m): + for h in range(n): sums[h] = field(sums[h]) return sums -def pseudorandom_share_zero(field, n, i, prfs, uci, m): - """Return pseudorandom Shamir shares for party i for m sharings of 0. +def pseudorandom_share_zero(field, m, i, prfs, uci, n): + """Return pseudorandom Shamir shares for party i for n sharings of 0. The shares are based on the pseudorandom functions for party i, given in prfs, which maps subsets of parties to PRF instances. Input uci is used to evaluate the PRFs on a unique common input. """ s = str(uci) - sums = [0] * m - # iterate over (n-1 choose d) subsets for degree d. + sums = [0] * n + # iterate over (m-1 choose t) subsets for degree t. for subset, prf in prfs.items(): try: f_in_i = _f_in_i_cache[(field, i, subset)] except KeyError: - complement = frozenset(range(n)) - subset + complement = frozenset(range(m)) - subset points = [(0, [1])] + [(x + 1, [0]) for x in complement] f_in_i = recombine(field, points, i + 1)[0].value _f_in_i_cache[(field, i, subset)] = f_in_i - d = n - len(subset) - prl = prf(s, m * d) - for h in range(m): + d = m - len(subset) + prl = prf(s, n * d) + for h in range(n): y = 0 - for k in range(d): - y += prl[h * d + k] + for j in range(d): + y += prl[h * d + j] y *= i + 1 sums[h] += y * f_in_i - for h in range(m): + for h in range(n): sums[h] = field(sums[h]) return sums @@ -155,10 +156,10 @@ def __init__(self, key, bound): self.max = bound self.byte_length = len(self.key) + ((bound-1).bit_length() + 7) // 8 - def __call__(self, s, m=None): + def __call__(self, s, n=None): """Return a number or list of numbers in range(self.max) for input string s.""" - n = m if m else 1 + n_ = n if n else 1 l = self.byte_length - dk = hashlib.pbkdf2_hmac('sha1', self.key, s.encode(), 1, n * l) - x = [int.from_bytes(dk[i * l: (i+1) * l], byteorder='little') % self.max for i in range(n)] - return x if m else x[0] + dk = hashlib.pbkdf2_hmac('sha1', self.key, s.encode(), 1, n_ * l) + x = [int.from_bytes(dk[i * l: (i+1) * l], byteorder='little') % self.max for i in range(n_)] + return x if n else x[0] diff --git a/tests/test_thresha.py b/tests/test_thresha.py index e1368505..1080fedb 100644 --- a/tests/test_thresha.py +++ b/tests/test_thresha.py @@ -11,25 +11,25 @@ def setUp(self): def test_secretsharing(self): field = self.f2 t = 0 - n = 1 + m = 1 a = [field(0), field(1)] - shares = thresha.random_split(a, t, n) + shares = thresha.random_split(a, t, m) b = thresha.recombine(field, [(j + 1, shares[j]) for j in range(len(shares))]) self.assertEqual(a, b) field = self.f19 for t in range(8): - n = 2 * t + 1 + m = 2 * t + 1 for i in range(t): a = [field(i), field(-i), field(i**2), field(-i**2)] - shares = thresha.random_split(a, t, n) + shares = thresha.random_split(a, t, m) b = thresha.recombine(field, [(j + 1, shares[j]) for j in range(len(shares))]) self.assertEqual(a, b) - n = 17 - for t in range((n + 1) // 2): + m = 17 + for t in range((m + 1) // 2): for i in range(t): a = [field(i), field(-i), field(i**2), field(-i**2)] - shares = thresha.random_split(a, t, n) + shares = thresha.random_split(a, t, m) b = thresha.recombine(field, [(j + 1, shares[j]) for j in range(len(shares))]) self.assertEqual(a, b) @@ -48,16 +48,16 @@ def test_prss(self): key = b'00112233445566778899aabbccddeeff' max = field.modulus F = thresha.PRF(key, max) - n = 1 + m = 1 id = 0 prfs = {frozenset([0]): F} uci = 'test uci' - m = 8 - a = F(uci, m) - shares = thresha.pseudorandom_share(field, n, id, prfs, uci, m) + n = 8 + a = F(uci, n) + shares = thresha.pseudorandom_share(field, m, id, prfs, uci, n) b = thresha.recombine(field, [(1, shares)]) self.assertEqual(a, b) - a = [0] * m - shares = thresha.pseudorandom_share_zero(field, n, id, prfs, uci, m) + a = [0] * n + shares = thresha.pseudorandom_share_zero(field, m, id, prfs, uci, n) b = thresha.recombine(field, [(1, shares)]) self.assertEqual(a, b)