From 9897f6cd46867bd601e44aa9d11aa74ed4ce944c Mon Sep 17 00:00:00 2001 From: lschoe Date: Mon, 24 Sep 2018 17:34:42 +0200 Subject: [PATCH] Vectorized input(). Method input() now also handles lists, like output(). --- demos/ot.py | 13 ++--- demos/parallelsort.py | 62 ++++++++++++++++---- mpyc/runtime.py | 128 ++++++++++++++++++++++++------------------ tests/test_runtime.py | 4 ++ 4 files changed, 135 insertions(+), 72 deletions(-) diff --git a/demos/ot.py b/demos/ot.py index 1904dce1..0f6d63a1 100644 --- a/demos/ot.py +++ b/demos/ot.py @@ -6,7 +6,7 @@ n = len(mpc.parties) if n % 2 == 0: - print('OT runs with odd number of parties only') + print('OT runs with odd number of parties only.') sys.exit() t = n // 2 @@ -16,20 +16,19 @@ print('You are the trusted third party.') elif 1 <= mpc.id <= t: message[mpc.id - 1] = (random.randint(0, 99), random.randint(0, 99)) - print('You are sender %d holding messages %d and %d.' % (mpc.id, message[mpc.id - 1][0], message[mpc.id - 1][1])) + print(f'You are sender {mpc.id} holding messages {message[mpc.id - 1][0]} and {message[mpc.id - 1][1]}.') else: choice[mpc.id - t - 1] = random.randint(0, 1) - print('You are receiver %d with random choice bit %d.' % (mpc.id - t, choice[mpc.id - t - 1])) + print(f'You are receiver {mpc.id - t} with random choice bit {choice[mpc.id - t - 1]}.') mpc.start() secnum = mpc.SecInt() for i in range(1, t + 1): - m0 = mpc.input(secnum(message[i - 1][0]), i) - m1 = mpc.input(secnum(message[i - 1][1]), i) + m = mpc.input([secnum(message[i - 1][0]), secnum(message[i - 1][1])], i) b = mpc.input(secnum(choice[i - t - 1]), t + i) - m = mpc.run(mpc.output(m0 + b * (m1 - m0), t + i)) + m = mpc.run(mpc.output(m[0] + b * (m[1] - m[0]), t + i)) if m: - print('You have received message %s.' % m) + print(f'You have received message {m}.') mpc.shutdown() diff --git a/demos/parallelsort.py b/demos/parallelsort.py index 9abb8aa4..a8b08979 100644 --- a/demos/parallelsort.py +++ b/demos/parallelsort.py @@ -4,34 +4,76 @@ Then, the i-th party broadcasts the i-th smallest number to all parties. This effectively sorts the random numbers. -The protocol provides no secrecy but shows that a parallel computation can -be seen as a special case of a secure multiparty computation. Accordingly, +Next, each party broadcasts a list of m random numbers to all parties, +where m is the number of parties in the protocol. +The parties thus generate m**2 numbers in total. +Then, the i-th party broadcasts the sorted segment containing the +(m*i)-th smallest number up to (but not including) the (m*(i+1))-st +smallest number. This effectively sorts the m**2 random numbers. + +The protocols provide no secrecy but show that parallel computation can +be seen as a special case of secure multiparty computation. Accordingly, we set the threshold for the number of corrupt parties simply to 0. """ import random +import itertools from mpyc.runtime import mpc -mpc.threshold = 0 # no secret sharing +def quickselect(x, k, l=None): + """Stub for efficient selection of kth smallest element of x + up to (k+l-1)st smallest element of x, in arbitrary order. + + Efficient selection can be done using O(len(x)) comparisons. + """ + if l: + return sorted(x)[k:k+l] + else: + return sorted(x)[k] + +mpc.threshold = 0 # No secret sharing. +m = len(mpc.parties) mpc.start() secint = mpc.SecInt() print('Using secure integers:', secint) x = mpc.run(mpc.output(mpc.input(secint(random.randint(0, 99))))) -print('Random inputs, one per party: ', x) +print('Random inputs, one per party:', x) x = [a.signed() for a in x] -x.sort() -x = mpc.run(mpc.output(mpc.input(secint(x[mpc.id])))) -print('Sorted outputs, one per party: ', x) +a = quickselect(x, mpc.id) +x = mpc.run(mpc.output(mpc.input(secint(a)))) +print('Sorted outputs, one per party:', x) + +x = mpc.input([secint(random.randint(0, 999)) for _ in range(m)]) +x = list(itertools.chain(*x)) +x = mpc.run(mpc.output(x)) +print('Random inputs,', m, 'per party:', x) +x = [a.signed() for a in x] +x = sorted(quickselect(x, m * mpc.id, m)) +x = mpc.input([secint(a) for a in x]) +x = list(itertools.chain(*x)) +x = mpc.run(mpc.output(x)) +print('Sorted outputs,', m, 'per party:', x) secfxp = mpc.SecFxp() -print('Using secure fixed-point numbers:', secfxp) +print(f'Using secure fixed-point numbers: {secfxp}') x = mpc.run(mpc.output(mpc.input(secfxp(0.5 - random.randint(0, 99))))) print('Random inputs, one per party: ', x) x = [a.signed() for a in x] -x.sort() -x = mpc.run(mpc.output(mpc.input(secfxp(x[mpc.id])))) +a = quickselect(x, mpc.id) +x = mpc.run(mpc.output(mpc.input(secfxp(a)))) print('Sorted outputs, one per party: ', x) +x = mpc.input([secfxp(random.randint(0,8)/8 + random.randint(0, 99)) for _ in range(m)]) +x = list(itertools.chain(*x)) +x = mpc.run(mpc.output(x)) +print('Random inputs,', m, 'per party:', x) +x = [a.signed() for a in x] +x = sorted(quickselect(x, m * mpc.id, m)) +x = mpc.input([secfxp(a) for a in x]) +x = list(itertools.chain(*x)) +x = mpc.run(mpc.output(x)) +print('Sorted outputs,', m, 'per party:', x) + mpc.shutdown() diff --git a/mpyc/runtime.py b/mpyc/runtime.py index 5e80ecb2..326e50f5 100644 --- a/mpyc/runtime.py +++ b/mpyc/runtime.py @@ -155,7 +155,7 @@ def start(self): else: context = None server_hostname = None - connect = self._loop.create_connection(factory, peer.host, peer.port, + connect = self._loop.create_connection(factory, peer.host, peer.port, ssl=context, server_hostname=server_hostname) self.run(connect) break @@ -195,28 +195,47 @@ def shutdown(self): coroutine = staticmethod(mpc_coro) returnType = staticmethod(returnType) - def input(self, a, senders=None): - """Input a to the computation.""" + def input(self, x, senders=None): + """Input x to the computation. + + Value x is a secure number, or x list of secure numbers. + The senders are the parties that provide an input. + The default is to let every party be a sender. + """ + x_is_list = isinstance(x, list) + if x_is_list: + x = x[:] + else: + x = [x] if senders is None: senders = list(range(len(self.parties))) - if isinstance(senders, int): - return self._distribute(a, [senders])[0] + senders_is_list = isinstance(senders, list) + if not senders_is_list: + senders = [senders] + y = self._distribute(x, senders) + if not senders_is_list: + y = y[0] + if not x_is_list: + y = y[0] else: - return self._distribute(a, senders) + if not x_is_list: + y = [a[0] for a in y] + return y @mpc_coro - async def _distribute(self, a, senders): - """Distribute shares for each secret a provided by a sender.""" - value = a.df if not isinstance(a.df, Future) else None + async def _distribute(self, x, senders): + """Distribute shares for each x provided by a sender.""" + stype = type(x[0]) # all elts assumed of same type + field = stype.field + await returnType(stype, len(senders), len(x)) + value = x[0].df if not isinstance(x[0].df, Future) else None assert value is None or self.id in senders - stype = type(a) - await returnType(stype, len(senders)) - field = stype.field + x = [a.df for a in x] # Extract values from all elements of x. shares = [None] * len(senders) for i, peer_id in enumerate(senders): if peer_id == self.id: - in_shares = thresha.random_split([value], self.threshold, len(self.parties)) + in_shares = thresha.random_split(x, self.threshold, len(self.parties)) for other_id, data in enumerate(in_shares): data = field.to_bytes(data) if other_id == self.id: @@ -226,47 +245,50 @@ async def _distribute(self, a, senders): else: shares[i] = self._expect_share(peer_id) shares = await gather_shares(shares) - shares = [field(field.from_bytes(r)[0]) for r in shares] - return shares + return [[field(a) for a in field.from_bytes(r)] for r in shares] - def output(self, a, receivers=None, threshold=None): - """Output the value of a to the receivers specified. + def output(self, x, receivers=None, threshold=None): + """Output the value of x to the receivers specified. - Value a is a secure number, or a list of secure numbers. - The receivers are the parties that will obtain the result. - The default is to let everybody know the result. + Value x is a secure number, or a list of secure numbers. + The receivers are the parties that will obtain the output. + The default is to let every party be a receiver. """ - if isinstance(a, list): - a = a[:] + x_is_list = isinstance(x, list) + if x_is_list: + x = x[:] + else: + x = [x] if receivers is None: receivers = list(range(len(self.parties))) elif isinstance(receivers, int): receivers = [receivers] if threshold is None: threshold = self.threshold - return gather_shares(self._recombine(a, receivers, threshold)) + y = self._recombine(x, receivers, threshold) + if not x_is_list: + y = y[0] + return gather_shares(y) @mpc_coro - async def _recombine(self, a, receivers, threshold): - """Recombine shares of a.""" - if not isinstance(a, list): - a = tuple([a]) - sftype = type(a[0]) # all elts assumed of same type + async def _recombine(self, x, receivers, threshold): + """Recombine shares of elements of x.""" + sftype = type(x[0]) # all elts assumed of same type if issubclass(sftype, Share): if sftype.field.frac_length == 0: - await returnType(sftype) + await returnType(sftype, len(x)) else: - await returnType((sftype, a[0].integral)) - a = await gather_shares(a) - field = type(a[0]) + await returnType((sftype, x[0].integral), len(x)) + x = await gather_shares(x) + field = type(x[0]) else: - await returnType(Share) + await returnType(Share, len(x)) field = sftype # Send share to all successors in receivers. for peer_id in receivers: if 0 < (peer_id - self.id) % len(self.parties) <= threshold: - self._send_share(peer_id, field.to_bytes(list(a))) + self._send_share(peer_id, field.to_bytes(x)) # Receive and recombine shares if this party is a receiver. if self.id in receivers: shares = [None] * threshold @@ -274,43 +296,39 @@ async def _recombine(self, a, receivers, threshold): shares[i] = self._expect_share((self.id - threshold + i) % len(self.parties)) shares = await gather_shares(shares) shares = [((self.id - threshold + j) % len(self.parties) + 1, field.from_bytes(shares[j])) for j in range(threshold)] - shares.append((self.id + 1, list(a))) - b = thresha.recombine(field, shares) - if isinstance(a, tuple): - return b[0] - else: - return b + shares.append((self.id + 1, x)) + return thresha.recombine(field, shares) else: - return + return [None] * len(x) @mpc_coro - async def _reshare(self, a): - if not isinstance(a, list): - a = tuple([a]) - sftype = type(a[0]) # all elts assumed of same type + async def _reshare(self, x): + x_is_list = isinstance(x, list) + if not x_is_list: + x = [x] + sftype = type(x[0]) # all elts assumed of same type if issubclass(sftype, Share): if sftype.field.frac_length == 0: - await returnType(sftype, len(a)) + await returnType(sftype, len(x)) else: - await returnType((sftype, a[0].integral), len(a)) - a = await mpc.gather(a) + await returnType((sftype, x[0].integral), len(x)) + x = await mpc.gather(x) field = sftype.field else: await returnType(Share) field = sftype - in_shares = thresha.random_split(a, self.threshold, len(self.parties)) + in_shares = thresha.random_split(x, self.threshold, len(self.parties)) in_shares = [field.to_bytes(elts) for elts in in_shares] # Recombine the first 2t+1 output_shares. out_shares = await gather_shares(self._exchange_shares(in_shares)[:2*self.threshold+1]) - b = thresha.recombine(field, [(j + 1, field.from_bytes(out_shares[j])) for j in range(len(out_shares))]) + y = thresha.recombine(field, [(j + 1, field.from_bytes(out_shares[j])) for j in range(len(out_shares))]) if issubclass(sftype, Share): - b = [sftype(s) for s in b] - if isinstance(a, tuple): - return b[0] - else: - return b + y = [sftype(s) for s in y] + if not x_is_list: + y = y[0] + return y @mpc_coro async def trunc(self, a, f=None): diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 29f6e347..601ecbe0 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -46,6 +46,10 @@ def test_secint(self): secint = mpc.SecInt() a = secint(12) b = secint(13) + c = mpc.run(mpc.output(mpc.input(a, 0))) + self.assertEqual(c, 12) + c = mpc.run(mpc.output(mpc.input([a, b], 0))) + self.assertEqual(c, [12, 13]) c = mpc.run(mpc.output(a * b + b)) self.assertEqual(c, 12 * 13 + 13) c = mpc.run(mpc.output(a**11 * a**(-6) * a**(-5)))