Skip to content

Commit

Permalink
Integer division by 2. (#4)
Browse files Browse the repository at this point in the history
Integer division with quotient and remainder. Only division by 2 (least significant bit) for now.
  • Loading branch information
lschoe authored Sep 30, 2018
1 parent feb0719 commit 6655c2a
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 74 deletions.
4 changes: 2 additions & 2 deletions demos/indextounitvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ def si1(x, n):
elif n==2:
return [x]
else:
b = mpc.lsb(x)
v = si1((x - b) / 2, (n + 1) // 2)
x2, b = divmod(x, 2)
v = si1(x2, (n + 1) // 2)
w = mpc.scalar_mul(b, v)
return [b-sum(w)] + [v[i//2]-w[i//2] if i%2==0 else w[i//2] for i in range(n-2)]
v = si1(x, n)
Expand Down
2 changes: 1 addition & 1 deletion mpyc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
are provided as well (e.g., some matrix-vector operations).
"""

__version__ = '0.3.6'
__version__ = '0.3.7'
__license__ = 'Apache License 2.0'
4 changes: 2 additions & 2 deletions mpyc/pfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def __sub__(self, other):
return NotImplemented

def __rsub__(self, other):
"""Subtraction (reflected argument version)."""
"""Subtraction (with reflected arguments)."""
if isinstance(other, int):
return type(self)(other - self.value)
else:
Expand Down Expand Up @@ -218,7 +218,7 @@ def __truediv__(self, other):
return NotImplemented

def __rtruediv__(self, other):
"""Division (reflected argument version)."""
"""Division (with reflected arguments)."""
if isinstance(other, int):
return type(self)(other) * self._reciprocal()
else:
Expand Down
2 changes: 1 addition & 1 deletion mpyc/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ async def reciprocal(self, a):
return r * (1 << stype.field.frac_length) / ar

def pow(self, a, b):
"""Secure exponentation of a to public power b."""
"""Secure exponentation a raised to the power of b, for public integer b."""
if b == 0:
return type(a)(1)
if b < 0:
Expand Down
165 changes: 119 additions & 46 deletions mpyc/sectypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,94 +32,168 @@ def __bool__(self):
"""Use of secret-shared values in Boolean expressions makes no sense."""
raise TypeError('cannot use secure type in Boolean expressions')

def __neg__(self):
"""Negation."""
return self.runtime.neg(self)

def __add__(self, other):
"""Addition."""
def _coerce(self, other):
if isinstance(other, Share):
if type(self) != type(other):
return NotImplemented
elif isinstance(other, int):
other = type(self)(other)
elif isinstance(other, float):
if type(self).field.frac_length == 0:
return NotImplemented
else:
other = type(self)(other)
elif isinstance(other, int):
other = type(self)(other)
elif type(self).field != type(other):
return NotImplemented
return self.runtime.add(self, other)

__radd__ = __add__
return other

def __sub__(self, other):
"""Subtraction."""
def _coerce2(self, other):
if isinstance(other, Share):
if type(self) != type(other):
return NotImplemented
elif isinstance(other, int):
# other <<= type(self).field.frac_length
pass
elif isinstance(other, float):
if type(self).field.frac_length == 0:
return NotImplemented
else:
other = type(self)(other)
elif isinstance(other, int):
other = type(self)(other)
elif type(self).field != type(other):
return NotImplemented
return other

def __neg__(self):
"""Negation."""
return self.runtime.neg(self)

def __add__(self, other):
"""Addition."""
other = self._coerce(other)
if other is NotImplemented:
return NotImplemented
return self.runtime.add(self, other)

__radd__ = __add__

def __sub__(self, other):
"""Subtraction."""
other = self._coerce(other)
if other is NotImplemented:
return NotImplemented
return self.runtime.sub(self, other)

def __rsub__(self, other):
"""Subtraction (reflected argument version)."""
if isinstance(other, Share):
if type(self) != type(other):
return NotImplemented
elif isinstance(other, float):
if type(self).field.frac_length == 0:
return NotImplemented
else:
other = type(self)(other)
elif isinstance(other, int):
other = type(self)(other)
elif type(self).field != type(other):
"""Subtraction (with reflected arguments)."""
other = self._coerce(other)
if other is NotImplemented:
return NotImplemented
return self.runtime.sub(other, self)

def __mul__(self, other):
"""Multiplication."""
if isinstance(other, Share):
if type(self) != type(other):
return NotImplemented
elif isinstance(other, float):
if type(self).field.frac_length == 0:
return NotImplemented
else:
other = type(self)(other)
elif isinstance(other, int):
# other <<= type(self).field.frac_length
pass
elif type(self).field != type(other):
other = self._coerce2(other)
if other is NotImplemented:
return NotImplemented
return self.runtime.mul(self, other)

__rmul__ = __mul__

def __truediv__(self, other):
"""Division."""
other = self._coerce(other)
if other is NotImplemented:
return NotImplemented
return self.runtime.div(self, other)

__floordiv__ = __truediv__

def __rtruediv__(self, other):
"""Division (reflected argument version)."""
"""Division (with reflected arguments)."""
other = self._coerce2(other)
if other is NotImplemented:
return NotImplemented
return self.runtime.div(other, self)

__rfloordiv__ = __rtruediv__
def __mod__(self, other):
"""Integer remainder."""
if type(self).__name__.startswith('SecFld'):
return NotImplemented
elif type(other).__name__.startswith('SecFld'):
return NotImplemented
other = self._coerce(other)
if other is NotImplemented:
return NotImplemented
# stub: only mod 2
assert other.df.value == 2, 'Least significant bit only, for now!'
r = self.runtime.lsb(self)
return r

def __rmod__(self, other):
"""Integer remainder (with reflected arguments)."""
if type(self).__name__.startswith('SecFld'):
return NotImplemented
elif type(other).__name__.startswith('SecFld'):
return NotImplemented
# stub: only mod 2
assert self.df.value == 2, 'Least significant bit only, for now!'
other = self._coerce(other)
if other is NotImplemented:
return NotImplemented
r = self.runtime.lsb(other)
return r

def __floordiv__(self, other):
"""Integer quotient."""
# stub: only div 2
r = self.__mod__(other)
if r is NotImplemented:
return NotImplemented
other = self._coerce(other) # avoid coercing twice
if other is NotImplemented:
return NotImplemented
q = (self - r) / other.df
return q

def __rfloordiv__(self, other):
"""Integer quotient (with reflected arguments)."""
# stub: only div 2
other = self._coerce(other)
if other is NotImplemented:
return NotImplemented
r = other.__mod__(self) # avoid coercing twice
if r is NotImplemented:
return NotImplemented
q = (other - r) / self.df
return q

def __divmod__(self, other):
"""Integer division."""
# stub: only divmod 2
r = self.__mod__(other)
if r is NotImplemented:
return NotImplemented
other = self._coerce(other) # avoid coercing twice
if other is NotImplemented:
return NotImplemented
q = (self - r) / other.df
return q, r

def __rdivmod__(self, other):
"""Integer division (with reflected arguments)."""
# stub: only divmod 2
other = self._coerce(other)
if other is NotImplemented:
return NotImplemented
r = other.__mod__(self) # avoid coercing twice
if r is NotImplemented:
return NotImplemented
q = (other - r) / self.df
return q, r

def __pow__(self, exponent):
def __pow__(self, other):
"""Exponentation with publicly known integer exponent."""
return self.runtime.pow(self, exponent)
if not isinstance(other, int):
return NotImplemented
return self.runtime.pow(self, other)

def __and__(self, other):
"""And 1-bit."""
Expand Down Expand Up @@ -187,7 +261,6 @@ def __ne__(self, other):
c = self - other
return 1 - self.runtime.is_zero(c)


_sectypes = {}

def SecFld(p=None, l=None):
Expand Down
54 changes: 32 additions & 22 deletions tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,21 @@ def test_secfld(self):
self.assertEqual(mpc.run(mpc.output(a == -b)), 1)
self.assertEqual(mpc.run(mpc.output(a**2 == b**2)), 1)
self.assertEqual(mpc.run(mpc.output(a != b)), 1)
with self.assertRaises(TypeError):
a < b
with self.assertRaises(TypeError):
a <= b
with self.assertRaises(TypeError):
a > b
with self.assertRaises(TypeError):
a >= b
for a, b in [(secfld(1), secfld(1)), (1, secfld(1)), (secfld(1), 1)]:
with self.assertRaises(TypeError):
a < b
with self.assertRaises(TypeError):
a <= b
with self.assertRaises(TypeError):
a > b
with self.assertRaises(TypeError):
a >= b
with self.assertRaises(TypeError):
a // b
with self.assertRaises(TypeError):
a % b
with self.assertRaises(TypeError):
divmod(a, b)
b = mpc.random_bit(secfld)
self.assertIn(mpc.run(mpc.output(b)), [0,1])
b = mpc.random_bit(secfld, signed=True)
Expand Down Expand Up @@ -76,17 +83,19 @@ def test_secint(self):
c = mpc.run(mpc.output(mpc.to_bits(secint(-2**31))))
self.assertEqual(c, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])

self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(-2**31)))), 0)
self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(-2**31 + 1)))), 1)
self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(-1)))), 1)
self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(0)))), 0)
self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(1)))), 1)
self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(2**31 - 1)))), 1)
self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(5)))), 1)
self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(-5)))), 1)
self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(50)))), 0)
self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(-50)))), 0)

self.assertEqual(mpc.run(mpc.output(secint(-2**31) % 2)), 0)
self.assertEqual(mpc.run(mpc.output(secint(-2**31 + 1) % 2)), 1)
self.assertEqual(mpc.run(mpc.output(secint(-1) % 2)), 1)
self.assertEqual(mpc.run(mpc.output(secint(0) % 2)), 0)
self.assertEqual(mpc.run(mpc.output(secint(1) % 2)), 1)
self.assertEqual(mpc.run(mpc.output(secint(2**31 - 1) % 2)), 1)
self.assertEqual(mpc.run(mpc.output(secint(5) % 2)), 1)
self.assertEqual(mpc.run(mpc.output(secint(-5) % 2)), 1)
self.assertEqual(mpc.run(mpc.output(secint(50) % 2)), 0)
self.assertEqual(mpc.run(mpc.output(secint(-50) % 2)), 0)
self.assertEqual(mpc.run(mpc.output(secint(5) // 2)), 2)
self.assertEqual(mpc.run(mpc.output(secint(50) // 2)), 25)

self.assertEqual(mpc.run(mpc.output(secint(3)**73)), 3**73)
b = mpc.random_bit(secint)
self.assertIn(mpc.run(mpc.output(b)), [0,1])
Expand Down Expand Up @@ -171,6 +180,7 @@ def test_secfxp(self):
ss2 = round(max(0, s[1]) * (1 << f))
self.assertEqual(mpc.run(mpc.output(mpc.max(0, y))), ss2)

self.assertEqual(mpc.run(mpc.output(mpc.lsb(secfxp(1)))), 0*(2**f))
self.assertEqual(mpc.run(mpc.output(mpc.lsb(secfxp(1/2**f)))), 1*(2**f))
self.assertEqual(mpc.run(mpc.output(mpc.lsb(secfxp(2/2**f)))), 0*(2**f))
self.assertEqual(mpc.run(mpc.output(secfxp(1) % 2**(1-f))), 0*(2**f))
self.assertEqual(mpc.run(mpc.output(secfxp(1/2**f) % 2**(1-f))), 1*(2**f))
self.assertEqual(mpc.run(mpc.output(secfxp(2/2**f) % 2**(1-f))), 0*(2**f))
self.assertEqual(mpc.run(mpc.output(secfxp(1) // 2**(1-f))), 2**(f-1))

0 comments on commit 6655c2a

Please sign in to comment.