From 6655c2aca8b16cbd6fc10a0cfaa15645e51171d4 Mon Sep 17 00:00:00 2001 From: Berry Schoenmakers Date: Mon, 1 Oct 2018 00:21:17 +0200 Subject: [PATCH] Integer division by 2. (#4) Integer division with quotient and remainder. Only division by 2 (least significant bit) for now. --- demos/indextounitvector.py | 4 +- mpyc/__init__.py | 2 +- mpyc/pfield.py | 4 +- mpyc/runtime.py | 2 +- mpyc/sectypes.py | 165 ++++++++++++++++++++++++++----------- tests/test_runtime.py | 54 +++++++----- 6 files changed, 157 insertions(+), 74 deletions(-) diff --git a/demos/indextounitvector.py b/demos/indextounitvector.py index 912c3022..66b01207 100644 --- a/demos/indextounitvector.py +++ b/demos/indextounitvector.py @@ -11,8 +11,8 @@ def si1(x, n): elif n==2: return [x] else: - b = mpc.lsb(x) - v = si1((x - b) / 2, (n + 1) // 2) + x2, b = divmod(x, 2) + v = si1(x2, (n + 1) // 2) w = mpc.scalar_mul(b, v) return [b-sum(w)] + [v[i//2]-w[i//2] if i%2==0 else w[i//2] for i in range(n-2)] v = si1(x, n) diff --git a/mpyc/__init__.py b/mpyc/__init__.py index 310fa2c7..335dacac 100644 --- a/mpyc/__init__.py +++ b/mpyc/__init__.py @@ -13,5 +13,5 @@ are provided as well (e.g., some matrix-vector operations). """ -__version__ = '0.3.6' +__version__ = '0.3.7' __license__ = 'Apache License 2.0' diff --git a/mpyc/pfield.py b/mpyc/pfield.py index 317b71b8..5de4242c 100644 --- a/mpyc/pfield.py +++ b/mpyc/pfield.py @@ -163,7 +163,7 @@ def __sub__(self, other): return NotImplemented def __rsub__(self, other): - """Subtraction (reflected argument version).""" + """Subtraction (with reflected arguments).""" if isinstance(other, int): return type(self)(other - self.value) else: @@ -218,7 +218,7 @@ def __truediv__(self, other): return NotImplemented def __rtruediv__(self, other): - """Division (reflected argument version).""" + """Division (with reflected arguments).""" if isinstance(other, int): return type(self)(other) * self._reciprocal() else: diff --git a/mpyc/runtime.py b/mpyc/runtime.py index e056364d..dc2d5c1f 100644 --- a/mpyc/runtime.py +++ b/mpyc/runtime.py @@ -474,7 +474,7 @@ async def reciprocal(self, a): return r * (1 << stype.field.frac_length) / ar def pow(self, a, b): - """Secure exponentation of a to public power b.""" + """Secure exponentation a raised to the power of b, for public integer b.""" if b == 0: return type(a)(1) if b < 0: diff --git a/mpyc/sectypes.py b/mpyc/sectypes.py index 5339dc83..66dc0aee 100644 --- a/mpyc/sectypes.py +++ b/mpyc/sectypes.py @@ -32,74 +32,68 @@ def __bool__(self): """Use of secret-shared values in Boolean expressions makes no sense.""" raise TypeError('cannot use secure type in Boolean expressions') - def __neg__(self): - """Negation.""" - return self.runtime.neg(self) - - def __add__(self, other): - """Addition.""" + def _coerce(self, other): if isinstance(other, Share): if type(self) != type(other): return NotImplemented + elif isinstance(other, int): + other = type(self)(other) elif isinstance(other, float): if type(self).field.frac_length == 0: return NotImplemented else: other = type(self)(other) - elif isinstance(other, int): - other = type(self)(other) elif type(self).field != type(other): return NotImplemented - return self.runtime.add(self, other) - - __radd__ = __add__ + return other - def __sub__(self, other): - """Subtraction.""" + def _coerce2(self, other): if isinstance(other, Share): if type(self) != type(other): return NotImplemented + elif isinstance(other, int): +# other <<= type(self).field.frac_length + pass elif isinstance(other, float): if type(self).field.frac_length == 0: return NotImplemented else: other = type(self)(other) - elif isinstance(other, int): - other = type(self)(other) elif type(self).field != type(other): return NotImplemented + return other + + def __neg__(self): + """Negation.""" + return self.runtime.neg(self) + + def __add__(self, other): + """Addition.""" + other = self._coerce(other) + if other is NotImplemented: + return NotImplemented + return self.runtime.add(self, other) + + __radd__ = __add__ + + def __sub__(self, other): + """Subtraction.""" + other = self._coerce(other) + if other is NotImplemented: + return NotImplemented return self.runtime.sub(self, other) def __rsub__(self, other): - """Subtraction (reflected argument version).""" - if isinstance(other, Share): - if type(self) != type(other): - return NotImplemented - elif isinstance(other, float): - if type(self).field.frac_length == 0: - return NotImplemented - else: - other = type(self)(other) - elif isinstance(other, int): - other = type(self)(other) - elif type(self).field != type(other): + """Subtraction (with reflected arguments).""" + other = self._coerce(other) + if other is NotImplemented: return NotImplemented return self.runtime.sub(other, self) def __mul__(self, other): """Multiplication.""" - if isinstance(other, Share): - if type(self) != type(other): - return NotImplemented - elif isinstance(other, float): - if type(self).field.frac_length == 0: - return NotImplemented - else: - other = type(self)(other) - elif isinstance(other, int): -# other <<= type(self).field.frac_length - pass - elif type(self).field != type(other): + other = self._coerce2(other) + if other is NotImplemented: return NotImplemented return self.runtime.mul(self, other) @@ -107,19 +101,99 @@ def __mul__(self, other): def __truediv__(self, other): """Division.""" + other = self._coerce(other) + if other is NotImplemented: + return NotImplemented return self.runtime.div(self, other) - __floordiv__ = __truediv__ - def __rtruediv__(self, other): - """Division (reflected argument version).""" + """Division (with reflected arguments).""" + other = self._coerce2(other) + if other is NotImplemented: + return NotImplemented return self.runtime.div(other, self) - __rfloordiv__ = __rtruediv__ + def __mod__(self, other): + """Integer remainder.""" + if type(self).__name__.startswith('SecFld'): + return NotImplemented + elif type(other).__name__.startswith('SecFld'): + return NotImplemented + other = self._coerce(other) + if other is NotImplemented: + return NotImplemented + # stub: only mod 2 + assert other.df.value == 2, 'Least significant bit only, for now!' + r = self.runtime.lsb(self) + return r + + def __rmod__(self, other): + """Integer remainder (with reflected arguments).""" + if type(self).__name__.startswith('SecFld'): + return NotImplemented + elif type(other).__name__.startswith('SecFld'): + return NotImplemented + # stub: only mod 2 + assert self.df.value == 2, 'Least significant bit only, for now!' + other = self._coerce(other) + if other is NotImplemented: + return NotImplemented + r = self.runtime.lsb(other) + return r + + def __floordiv__(self, other): + """Integer quotient.""" + # stub: only div 2 + r = self.__mod__(other) + if r is NotImplemented: + return NotImplemented + other = self._coerce(other) # avoid coercing twice + if other is NotImplemented: + return NotImplemented + q = (self - r) / other.df + return q + + def __rfloordiv__(self, other): + """Integer quotient (with reflected arguments).""" + # stub: only div 2 + other = self._coerce(other) + if other is NotImplemented: + return NotImplemented + r = other.__mod__(self) # avoid coercing twice + if r is NotImplemented: + return NotImplemented + q = (other - r) / self.df + return q + + def __divmod__(self, other): + """Integer division.""" + # stub: only divmod 2 + r = self.__mod__(other) + if r is NotImplemented: + return NotImplemented + other = self._coerce(other) # avoid coercing twice + if other is NotImplemented: + return NotImplemented + q = (self - r) / other.df + return q, r + + def __rdivmod__(self, other): + """Integer division (with reflected arguments).""" + # stub: only divmod 2 + other = self._coerce(other) + if other is NotImplemented: + return NotImplemented + r = other.__mod__(self) # avoid coercing twice + if r is NotImplemented: + return NotImplemented + q = (other - r) / self.df + return q, r - def __pow__(self, exponent): + def __pow__(self, other): """Exponentation with publicly known integer exponent.""" - return self.runtime.pow(self, exponent) + if not isinstance(other, int): + return NotImplemented + return self.runtime.pow(self, other) def __and__(self, other): """And 1-bit.""" @@ -187,7 +261,6 @@ def __ne__(self, other): c = self - other return 1 - self.runtime.is_zero(c) - _sectypes = {} def SecFld(p=None, l=None): diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 601ecbe0..22b8e2f9 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -29,14 +29,21 @@ def test_secfld(self): self.assertEqual(mpc.run(mpc.output(a == -b)), 1) self.assertEqual(mpc.run(mpc.output(a**2 == b**2)), 1) self.assertEqual(mpc.run(mpc.output(a != b)), 1) - with self.assertRaises(TypeError): - a < b - with self.assertRaises(TypeError): - a <= b - with self.assertRaises(TypeError): - a > b - with self.assertRaises(TypeError): - a >= b + for a, b in [(secfld(1), secfld(1)), (1, secfld(1)), (secfld(1), 1)]: + with self.assertRaises(TypeError): + a < b + with self.assertRaises(TypeError): + a <= b + with self.assertRaises(TypeError): + a > b + with self.assertRaises(TypeError): + a >= b + with self.assertRaises(TypeError): + a // b + with self.assertRaises(TypeError): + a % b + with self.assertRaises(TypeError): + divmod(a, b) b = mpc.random_bit(secfld) self.assertIn(mpc.run(mpc.output(b)), [0,1]) b = mpc.random_bit(secfld, signed=True) @@ -76,17 +83,19 @@ def test_secint(self): c = mpc.run(mpc.output(mpc.to_bits(secint(-2**31)))) self.assertEqual(c, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]) - self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(-2**31)))), 0) - self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(-2**31 + 1)))), 1) - self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(-1)))), 1) - self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(0)))), 0) - self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(1)))), 1) - self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(2**31 - 1)))), 1) - self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(5)))), 1) - self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(-5)))), 1) - self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(50)))), 0) - self.assertEqual(mpc.run(mpc.output(mpc.lsb(secint(-50)))), 0) - + self.assertEqual(mpc.run(mpc.output(secint(-2**31) % 2)), 0) + self.assertEqual(mpc.run(mpc.output(secint(-2**31 + 1) % 2)), 1) + self.assertEqual(mpc.run(mpc.output(secint(-1) % 2)), 1) + self.assertEqual(mpc.run(mpc.output(secint(0) % 2)), 0) + self.assertEqual(mpc.run(mpc.output(secint(1) % 2)), 1) + self.assertEqual(mpc.run(mpc.output(secint(2**31 - 1) % 2)), 1) + self.assertEqual(mpc.run(mpc.output(secint(5) % 2)), 1) + self.assertEqual(mpc.run(mpc.output(secint(-5) % 2)), 1) + self.assertEqual(mpc.run(mpc.output(secint(50) % 2)), 0) + self.assertEqual(mpc.run(mpc.output(secint(-50) % 2)), 0) + self.assertEqual(mpc.run(mpc.output(secint(5) // 2)), 2) + self.assertEqual(mpc.run(mpc.output(secint(50) // 2)), 25) + self.assertEqual(mpc.run(mpc.output(secint(3)**73)), 3**73) b = mpc.random_bit(secint) self.assertIn(mpc.run(mpc.output(b)), [0,1]) @@ -171,6 +180,7 @@ def test_secfxp(self): ss2 = round(max(0, s[1]) * (1 << f)) self.assertEqual(mpc.run(mpc.output(mpc.max(0, y))), ss2) - self.assertEqual(mpc.run(mpc.output(mpc.lsb(secfxp(1)))), 0*(2**f)) - self.assertEqual(mpc.run(mpc.output(mpc.lsb(secfxp(1/2**f)))), 1*(2**f)) - self.assertEqual(mpc.run(mpc.output(mpc.lsb(secfxp(2/2**f)))), 0*(2**f)) + self.assertEqual(mpc.run(mpc.output(secfxp(1) % 2**(1-f))), 0*(2**f)) + self.assertEqual(mpc.run(mpc.output(secfxp(1/2**f) % 2**(1-f))), 1*(2**f)) + self.assertEqual(mpc.run(mpc.output(secfxp(2/2**f) % 2**(1-f))), 0*(2**f)) + self.assertEqual(mpc.run(mpc.output(secfxp(1) // 2**(1-f))), 2**(f-1))