Skip to content

Commit

Permalink
Vectorized input().
Browse files Browse the repository at this point in the history
Method input() now also handles lists, like output().
  • Loading branch information
lschoe committed Sep 24, 2018
1 parent 28ca4d1 commit 9897f6c
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 72 deletions.
13 changes: 6 additions & 7 deletions demos/ot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
n = len(mpc.parties)

if n % 2 == 0:
print('OT runs with odd number of parties only')
print('OT runs with odd number of parties only.')
sys.exit()

t = n // 2
Expand All @@ -16,20 +16,19 @@
print('You are the trusted third party.')
elif 1 <= mpc.id <= t:
message[mpc.id - 1] = (random.randint(0, 99), random.randint(0, 99))
print('You are sender %d holding messages %d and %d.' % (mpc.id, message[mpc.id - 1][0], message[mpc.id - 1][1]))
print(f'You are sender {mpc.id} holding messages {message[mpc.id - 1][0]} and {message[mpc.id - 1][1]}.')
else:
choice[mpc.id - t - 1] = random.randint(0, 1)
print('You are receiver %d with random choice bit %d.' % (mpc.id - t, choice[mpc.id - t - 1]))
print(f'You are receiver {mpc.id - t} with random choice bit {choice[mpc.id - t - 1]}.')

mpc.start()

secnum = mpc.SecInt()
for i in range(1, t + 1):
m0 = mpc.input(secnum(message[i - 1][0]), i)
m1 = mpc.input(secnum(message[i - 1][1]), i)
m = mpc.input([secnum(message[i - 1][0]), secnum(message[i - 1][1])], i)
b = mpc.input(secnum(choice[i - t - 1]), t + i)
m = mpc.run(mpc.output(m0 + b * (m1 - m0), t + i))
m = mpc.run(mpc.output(m[0] + b * (m[1] - m[0]), t + i))
if m:
print('You have received message %s.' % m)
print(f'You have received message {m}.')

mpc.shutdown()
62 changes: 52 additions & 10 deletions demos/parallelsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,76 @@
Then, the i-th party broadcasts the i-th smallest number to all parties.
This effectively sorts the random numbers.
The protocol provides no secrecy but shows that a parallel computation can
be seen as a special case of a secure multiparty computation. Accordingly,
Next, each party broadcasts a list of m random numbers to all parties,
where m is the number of parties in the protocol.
The parties thus generate m**2 numbers in total.
Then, the i-th party broadcasts the sorted segment containing the
(m*i)-th smallest number up to (but not including) the (m*(i+1))-st
smallest number. This effectively sorts the m**2 random numbers.
The protocols provide no secrecy but show that parallel computation can
be seen as a special case of secure multiparty computation. Accordingly,
we set the threshold for the number of corrupt parties simply to 0.
"""

import random
import itertools
from mpyc.runtime import mpc

mpc.threshold = 0 # no secret sharing
def quickselect(x, k, l=None):
"""Stub for efficient selection of kth smallest element of x
up to (k+l-1)st smallest element of x, in arbitrary order.
Efficient selection can be done using O(len(x)) comparisons.
"""
if l:
return sorted(x)[k:k+l]
else:
return sorted(x)[k]

mpc.threshold = 0 # No secret sharing.
m = len(mpc.parties)

mpc.start()

secint = mpc.SecInt()
print('Using secure integers:', secint)
x = mpc.run(mpc.output(mpc.input(secint(random.randint(0, 99)))))
print('Random inputs, one per party: ', x)
print('Random inputs, one per party:', x)
x = [a.signed() for a in x]
x.sort()
x = mpc.run(mpc.output(mpc.input(secint(x[mpc.id]))))
print('Sorted outputs, one per party: ', x)
a = quickselect(x, mpc.id)
x = mpc.run(mpc.output(mpc.input(secint(a))))
print('Sorted outputs, one per party:', x)

x = mpc.input([secint(random.randint(0, 999)) for _ in range(m)])
x = list(itertools.chain(*x))
x = mpc.run(mpc.output(x))
print('Random inputs,', m, 'per party:', x)
x = [a.signed() for a in x]
x = sorted(quickselect(x, m * mpc.id, m))
x = mpc.input([secint(a) for a in x])
x = list(itertools.chain(*x))
x = mpc.run(mpc.output(x))
print('Sorted outputs,', m, 'per party:', x)

secfxp = mpc.SecFxp()
print('Using secure fixed-point numbers:', secfxp)
print(f'Using secure fixed-point numbers: {secfxp}')
x = mpc.run(mpc.output(mpc.input(secfxp(0.5 - random.randint(0, 99)))))
print('Random inputs, one per party: ', x)
x = [a.signed() for a in x]
x.sort()
x = mpc.run(mpc.output(mpc.input(secfxp(x[mpc.id]))))
a = quickselect(x, mpc.id)
x = mpc.run(mpc.output(mpc.input(secfxp(a))))
print('Sorted outputs, one per party: ', x)

x = mpc.input([secfxp(random.randint(0,8)/8 + random.randint(0, 99)) for _ in range(m)])
x = list(itertools.chain(*x))
x = mpc.run(mpc.output(x))
print('Random inputs,', m, 'per party:', x)
x = [a.signed() for a in x]
x = sorted(quickselect(x, m * mpc.id, m))
x = mpc.input([secfxp(a) for a in x])
x = list(itertools.chain(*x))
x = mpc.run(mpc.output(x))
print('Sorted outputs,', m, 'per party:', x)

mpc.shutdown()
128 changes: 73 additions & 55 deletions mpyc/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def start(self):
else:
context = None
server_hostname = None
connect = self._loop.create_connection(factory, peer.host, peer.port,
connect = self._loop.create_connection(factory, peer.host, peer.port,
ssl=context, server_hostname=server_hostname)
self.run(connect)
break
Expand Down Expand Up @@ -195,28 +195,47 @@ def shutdown(self):
coroutine = staticmethod(mpc_coro)
returnType = staticmethod(returnType)

def input(self, a, senders=None):
"""Input a to the computation."""
def input(self, x, senders=None):
"""Input x to the computation.
Value x is a secure number, or x list of secure numbers.
The senders are the parties that provide an input.
The default is to let every party be a sender.
"""
x_is_list = isinstance(x, list)
if x_is_list:
x = x[:]
else:
x = [x]
if senders is None:
senders = list(range(len(self.parties)))
if isinstance(senders, int):
return self._distribute(a, [senders])[0]
senders_is_list = isinstance(senders, list)
if not senders_is_list:
senders = [senders]
y = self._distribute(x, senders)
if not senders_is_list:
y = y[0]
if not x_is_list:
y = y[0]
else:
return self._distribute(a, senders)
if not x_is_list:
y = [a[0] for a in y]
return y

@mpc_coro
async def _distribute(self, a, senders):
"""Distribute shares for each secret a provided by a sender."""
value = a.df if not isinstance(a.df, Future) else None
async def _distribute(self, x, senders):
"""Distribute shares for each x provided by a sender."""
stype = type(x[0]) # all elts assumed of same type
field = stype.field
await returnType(stype, len(senders), len(x))
value = x[0].df if not isinstance(x[0].df, Future) else None
assert value is None or self.id in senders
stype = type(a)
await returnType(stype, len(senders))

field = stype.field
x = [a.df for a in x] # Extract values from all elements of x.
shares = [None] * len(senders)
for i, peer_id in enumerate(senders):
if peer_id == self.id:
in_shares = thresha.random_split([value], self.threshold, len(self.parties))
in_shares = thresha.random_split(x, self.threshold, len(self.parties))
for other_id, data in enumerate(in_shares):
data = field.to_bytes(data)
if other_id == self.id:
Expand All @@ -226,91 +245,90 @@ async def _distribute(self, a, senders):
else:
shares[i] = self._expect_share(peer_id)
shares = await gather_shares(shares)
shares = [field(field.from_bytes(r)[0]) for r in shares]
return shares
return [[field(a) for a in field.from_bytes(r)] for r in shares]

def output(self, a, receivers=None, threshold=None):
"""Output the value of a to the receivers specified.
def output(self, x, receivers=None, threshold=None):
"""Output the value of x to the receivers specified.
Value a is a secure number, or a list of secure numbers.
The receivers are the parties that will obtain the result.
The default is to let everybody know the result.
Value x is a secure number, or a list of secure numbers.
The receivers are the parties that will obtain the output.
The default is to let every party be a receiver.
"""
if isinstance(a, list):
a = a[:]
x_is_list = isinstance(x, list)
if x_is_list:
x = x[:]
else:
x = [x]
if receivers is None:
receivers = list(range(len(self.parties)))
elif isinstance(receivers, int):
receivers = [receivers]
if threshold is None:
threshold = self.threshold
return gather_shares(self._recombine(a, receivers, threshold))
y = self._recombine(x, receivers, threshold)
if not x_is_list:
y = y[0]
return gather_shares(y)

@mpc_coro
async def _recombine(self, a, receivers, threshold):
"""Recombine shares of a."""
if not isinstance(a, list):
a = tuple([a])
sftype = type(a[0]) # all elts assumed of same type
async def _recombine(self, x, receivers, threshold):
"""Recombine shares of elements of x."""
sftype = type(x[0]) # all elts assumed of same type
if issubclass(sftype, Share):
if sftype.field.frac_length == 0:
await returnType(sftype)
await returnType(sftype, len(x))
else:
await returnType((sftype, a[0].integral))
a = await gather_shares(a)
field = type(a[0])
await returnType((sftype, x[0].integral), len(x))
x = await gather_shares(x)
field = type(x[0])
else:
await returnType(Share)
await returnType(Share, len(x))
field = sftype

# Send share to all successors in receivers.
for peer_id in receivers:
if 0 < (peer_id - self.id) % len(self.parties) <= threshold:
self._send_share(peer_id, field.to_bytes(list(a)))
self._send_share(peer_id, field.to_bytes(x))
# Receive and recombine shares if this party is a receiver.
if self.id in receivers:
shares = [None] * threshold
for i in range(threshold):
shares[i] = self._expect_share((self.id - threshold + i) % len(self.parties))
shares = await gather_shares(shares)
shares = [((self.id - threshold + j) % len(self.parties) + 1, field.from_bytes(shares[j])) for j in range(threshold)]
shares.append((self.id + 1, list(a)))
b = thresha.recombine(field, shares)
if isinstance(a, tuple):
return b[0]
else:
return b
shares.append((self.id + 1, x))
return thresha.recombine(field, shares)
else:
return
return [None] * len(x)

@mpc_coro
async def _reshare(self, a):
if not isinstance(a, list):
a = tuple([a])
sftype = type(a[0]) # all elts assumed of same type
async def _reshare(self, x):
x_is_list = isinstance(x, list)
if not x_is_list:
x = [x]
sftype = type(x[0]) # all elts assumed of same type
if issubclass(sftype, Share):
if sftype.field.frac_length == 0:
await returnType(sftype, len(a))
await returnType(sftype, len(x))
else:
await returnType((sftype, a[0].integral), len(a))
a = await mpc.gather(a)
await returnType((sftype, x[0].integral), len(x))
x = await mpc.gather(x)
field = sftype.field
else:
await returnType(Share)
field = sftype

in_shares = thresha.random_split(a, self.threshold, len(self.parties))
in_shares = thresha.random_split(x, self.threshold, len(self.parties))
in_shares = [field.to_bytes(elts) for elts in in_shares]
# Recombine the first 2t+1 output_shares.
out_shares = await gather_shares(self._exchange_shares(in_shares)[:2*self.threshold+1])
b = thresha.recombine(field, [(j + 1, field.from_bytes(out_shares[j])) for j in range(len(out_shares))])
y = thresha.recombine(field, [(j + 1, field.from_bytes(out_shares[j])) for j in range(len(out_shares))])

if issubclass(sftype, Share):
b = [sftype(s) for s in b]
if isinstance(a, tuple):
return b[0]
else:
return b
y = [sftype(s) for s in y]
if not x_is_list:
y = y[0]
return y

@mpc_coro
async def trunc(self, a, f=None):
Expand Down
4 changes: 4 additions & 0 deletions tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def test_secint(self):
secint = mpc.SecInt()
a = secint(12)
b = secint(13)
c = mpc.run(mpc.output(mpc.input(a, 0)))
self.assertEqual(c, 12)
c = mpc.run(mpc.output(mpc.input([a, b], 0)))
self.assertEqual(c, [12, 13])
c = mpc.run(mpc.output(a * b + b))
self.assertEqual(c, 12 * 13 + 13)
c = mpc.run(mpc.output(a**11 * a**(-6) * a**(-5)))
Expand Down

0 comments on commit 9897f6c

Please sign in to comment.