diff --git a/crypten/mpc/primitives/arithmetic.py b/crypten/mpc/primitives/arithmetic.py index 820fdad1..2a675dc1 100644 --- a/crypten/mpc/primitives/arithmetic.py +++ b/crypten/mpc/primitives/arithmetic.py @@ -34,6 +34,8 @@ class ArithmeticSharedTensor(object): the number of parties present in the protocol (world_size). """ + PUBLIC_COMPUTE_PARTY = 0 # Party responsible for public addition. + # constructors: def __init__( self, @@ -239,7 +241,7 @@ def pad(self, pad, mode="constant", value=0): result = self.shallow_copy() if isinstance(value, (int, float)): value = self.encoder.encode(value).item() - if result.rank == 0: + if result.rank == self.PUBLIC_COMPUTE_PARTY: result.share = torch.nn.functional.pad( result.share, pad, mode=mode, value=value ) @@ -362,7 +364,7 @@ def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs): # noqa:C y = result.encoder.encode(y, device=self.device) if additive_func: # ['add', 'sub'] - if result.rank == 0: + if result.rank == self.PUBLIC_COMPUTE_PARTY: result.share = getattr(result.share, op)(y) else: result.share = torch.broadcast_tensors(result.share, y)[0] @@ -514,7 +516,7 @@ def index_add_(self, dim, index, tensor): private = isinstance(tensor, ArithmeticSharedTensor) if public: enc_tensor = self.encoder.encode(tensor) - if self.rank == 0: + if self.rank == self.PUBLIC_COMPUTE_PARTY: self._tensor.index_add_(dim, index, enc_tensor) elif private: self._tensor.index_add_(dim, index, tensor._tensor) @@ -541,7 +543,7 @@ def scatter_add_(self, dim, index, other): public = isinstance(other, (int, float)) or is_tensor(other) private = isinstance(other, ArithmeticSharedTensor) if public: - if self.rank == 0: + if self.rank == self.PUBLIC_COMPUTE_PARTY: self.share.scatter_add_(dim, index, self.encoder.encode(other)) elif private: self.share.scatter_add_(dim, index, other.share) diff --git a/crypten/mpc/primitives/binary.py b/crypten/mpc/primitives/binary.py index 9a2ac779..ac299c4e 100644 --- a/crypten/mpc/primitives/binary.py +++ b/crypten/mpc/primitives/binary.py @@ -32,6 +32,8 @@ class BinarySharedTensor(object): where n is the number of parties present in the protocol (world_size). """ + PUBLIC_COMPUTE_PARTY = 0 # Party responsible for public XOR and NOT. + def __init__( self, tensor=None, size=None, broadcast_size=False, src=0, device=None ): @@ -214,7 +216,7 @@ def __nonzero__(self): def __ixor__(self, y): """Bitwise XOR operator (element-wise) in place""" if is_tensor(y) or isinstance(y, int): - if self.rank == 0: + if self.rank == self.PUBLIC_COMPUTE_PARTY: self.share ^= y elif isinstance(y, BinarySharedTensor): self.share ^= y.share @@ -267,7 +269,7 @@ def __or__(self, y): def __invert__(self): """Bitwise NOT operator (element-wise)""" result = self.clone() - if result.rank == 0: + if result.rank == self.PUBLIC_COMPUTE_PARTY: result.share ^= -1 return result diff --git a/crypten/mpc/provider/tfp_provider.py b/crypten/mpc/provider/tfp_provider.py index b7163d57..b1151f74 100644 --- a/crypten/mpc/provider/tfp_provider.py +++ b/crypten/mpc/provider/tfp_provider.py @@ -16,6 +16,7 @@ class TrustedFirstParty(TupleProvider): NAME = "TFP" + COMMUNICATING_PARTY = 0 def generate_additive_triple(self, size0, size1, op, device=None, *args, **kwargs): """Generate multiplicative triples of given sizes""" @@ -24,9 +25,9 @@ def generate_additive_triple(self, size0, size1, op, device=None, *args, **kwarg c = getattr(torch, op)(a, b, *args, **kwargs) - a = ArithmeticSharedTensor(a, precision=0, src=0) - b = ArithmeticSharedTensor(b, precision=0, src=0) - c = ArithmeticSharedTensor(c, precision=0, src=0) + a = ArithmeticSharedTensor(a, precision=0, src=self.COMMUNICATING_PARTY) + b = ArithmeticSharedTensor(b, precision=0, src=self.COMMUNICATING_PARTY) + c = ArithmeticSharedTensor(c, precision=0, src=self.COMMUNICATING_PARTY) return a, b, c @@ -37,7 +38,9 @@ def square(self, size, device=None): # Stack to vectorize scatter function stacked = torch_stack([r, r2]) - stacked = ArithmeticSharedTensor(stacked, precision=0, src=0) + stacked = ArithmeticSharedTensor( + stacked, precision=0, src=self.COMMUNICATING_PARTY + ) return stacked[0], stacked[1] def generate_binary_triple(self, size0, size1, device=None): @@ -46,9 +49,9 @@ def generate_binary_triple(self, size0, size1, device=None): b = generate_kbit_random_tensor(size1, device=device) c = a & b - a = BinarySharedTensor(a, src=0) - b = BinarySharedTensor(b, src=0) - c = BinarySharedTensor(c, src=0) + a = BinarySharedTensor(a, src=self.COMMUNICATING_PARTY) + b = BinarySharedTensor(b, src=self.COMMUNICATING_PARTY) + c = BinarySharedTensor(c, src=self.COMMUNICATING_PARTY) return a, b, c @@ -61,9 +64,11 @@ def wrap_rng(self, size, device=None): ] theta_r = count_wraps(r) - shares = comm.get().scatter(r, 0) + shares = comm.get().scatter(r, self.COMMUNICATING_PARTY) r = ArithmeticSharedTensor.from_shares(shares, precision=0) - theta_r = ArithmeticSharedTensor(theta_r, precision=0, src=0) + theta_r = ArithmeticSharedTensor( + theta_r, precision=0, src=self.COMMUNICATING_PARTY + ) return r, theta_r @@ -72,7 +77,7 @@ def B2A_rng(self, size, device=None): # generate random bit r = generate_kbit_random_tensor(size, bitlength=1, device=device) - rA = ArithmeticSharedTensor(r, precision=0, src=0) - rB = BinarySharedTensor(r, src=0) + rA = ArithmeticSharedTensor(r, precision=0, src=self.COMMUNICATING_PARTY) + rB = BinarySharedTensor(r, src=self.COMMUNICATING_PARTY) return rA, rB diff --git a/crypten/mpc/provider/ttp_provider.py b/crypten/mpc/provider/ttp_provider.py index 5a28d5da..f04afb23 100644 --- a/crypten/mpc/provider/ttp_provider.py +++ b/crypten/mpc/provider/ttp_provider.py @@ -23,6 +23,7 @@ class TrustedThirdParty(TupleProvider): NAME = "TTP" + COMMUNICATING_PARTY = 0 def generate_additive_triple(self, size0, size1, op, device=None, *args, **kwargs): """Generate multiplicative triples of given sizes""" @@ -30,7 +31,7 @@ def generate_additive_triple(self, size0, size1, op, device=None, *args, **kwarg a = generate_random_ring_element(size0, generator=generator, device=device) b = generate_random_ring_element(size1, generator=generator, device=device) - if comm.get().get_rank() == 0: + if comm.get().get_rank() == self.COMMUNICATING_PARTY: # Request c from TTP c = TTPClient.get().ttp_request( "additive", device, size0, size1, op, *args, **kwargs @@ -51,7 +52,7 @@ def square(self, size, device=None): generator = TTPClient.get().get_generator(device=device) r = generate_random_ring_element(size, generator=generator, device=device) - if comm.get().get_rank() == 0: + if comm.get().get_rank() == self.COMMUNICATING_PARTY: # Request r2 from TTP r2 = TTPClient.get().ttp_request("square", device, size) else: @@ -68,7 +69,7 @@ def generate_binary_triple(self, size0, size1, device=None): a = generate_kbit_random_tensor(size0, generator=generator, device=device) b = generate_kbit_random_tensor(size1, generator=generator, device=device) - if comm.get().get_rank() == 0: + if comm.get().get_rank() == self.COMMUNICATING_PARTY: # Request c from TTP c = TTPClient.get().ttp_request("binary", device, size0, size1) else: @@ -86,7 +87,7 @@ def wrap_rng(self, size, device=None): generator = TTPClient.get().get_generator(device=device) r = generate_random_ring_element(size, generator=generator, device=device) - if comm.get().get_rank() == 0: + if comm.get().get_rank() == self.COMMUNICATING_PARTY: # Request theta_r from TTP theta_r = TTPClient.get().ttp_request("wraps", device, size) else: @@ -107,7 +108,7 @@ def B2A_rng(self, size, device=None): size, bitlength=1, generator=generator, device=device ) - if comm.get().get_rank() == 0: + if comm.get().get_rank() == self.COMMUNICATING_PARTY: # Request rA from TTP rA = TTPClient.get().ttp_request("B2A", device, size) else: diff --git a/test/test_arithmetic.py b/test/test_arithmetic.py index b229fb6b..568f759d 100644 --- a/test/test_arithmetic.py +++ b/test/test_arithmetic.py @@ -120,40 +120,48 @@ def test_arithmetic(self): arithmetic_functions = ["add", "add_", "sub", "sub_", "mul", "mul_"] for func in arithmetic_functions: for tensor_type in [lambda x: x, ArithmeticSharedTensor]: - tensor1 = get_random_test_tensor(is_float=True) - tensor2 = get_random_test_tensor(is_float=True) - encrypted = ArithmeticSharedTensor(tensor1) - encrypted2 = tensor_type(tensor2) + for public_party in range(2): - reference = getattr(tensor1, func)(tensor2) - encrypted_out = getattr(encrypted, func)(encrypted2) - private_type = tensor_type == ArithmeticSharedTensor - self._check( - encrypted_out, - reference, - "%s %s failed" % ("private" if private_type else "public", func), - ) - if "_" in func: - # Check in-place op worked - self._check( - encrypted, - reference, - "%s %s failed" - % ("private" if private_type else "public", func), + # Vary which party performs public addition. + ArithmeticSharedTensor.PUBLIC_COMPUTE_PARTY = public_party + self.assertTrue( + ArithmeticSharedTensor.PUBLIC_COMPUTE_PARTY == public_party ) - else: - # Check original is not modified + + tensor1 = get_random_test_tensor(is_float=True) + tensor2 = get_random_test_tensor(is_float=True) + encrypted = ArithmeticSharedTensor(tensor1) + encrypted2 = tensor_type(tensor2) + + reference = getattr(tensor1, func)(tensor2) + encrypted_out = getattr(encrypted, func)(encrypted2) + private_type = tensor_type == ArithmeticSharedTensor self._check( - encrypted, - tensor1, - "%s %s failed" - % ( - "private" - if tensor_type == ArithmeticSharedTensor - else "public", - func, - ), + encrypted_out, + reference, + "%s %s failed" % ("private" if private_type else "public", func), ) + if "_" in func: + # Check in-place op worked + self._check( + encrypted, + reference, + "%s %s failed" + % ("private" if private_type else "public", func), + ) + else: + # Check original is not modified + self._check( + encrypted, + tensor1, + "%s %s failed" + % ( + "private" + if tensor_type == ArithmeticSharedTensor + else "public", + func, + ), + ) # Check encrypted vector with encrypted scalar works. tensor1 = get_random_test_tensor(is_float=True) @@ -270,35 +278,43 @@ def test_index_add(self): tensor_size2[dimension] = index.size(0) for func in index_add_functions: for tensor_type in [lambda x: x, ArithmeticSharedTensor]: - tensor1 = get_random_test_tensor(size=tensor_size1, is_float=True) - tensor2 = get_random_test_tensor(size=tensor_size2, is_float=True) - encrypted = ArithmeticSharedTensor(tensor1) - encrypted2 = tensor_type(tensor2) + for public_party in range(2): - reference = getattr(tensor1, func)(dimension, index, tensor2) - encrypted_out = getattr(encrypted, func)( - dimension, index, encrypted2 - ) - private = tensor_type == ArithmeticSharedTensor - self._check( - encrypted_out, - reference, - "%s %s failed" % ("private" if private else "public", func), - ) - if func.endswith("_"): - # Check in-place index_add worked - self._check( - encrypted, - reference, - "%s %s failed" % ("private" if private else "public", func), + # Vary which party performs public addition. + ArithmeticSharedTensor.PUBLIC_COMPUTE_PARTY = public_party + self.assertTrue( + ArithmeticSharedTensor.PUBLIC_COMPUTE_PARTY == public_party ) - else: - # Check original is not modified + + tensor1 = get_random_test_tensor(size=tensor_size1, is_float=True) + tensor2 = get_random_test_tensor(size=tensor_size2, is_float=True) + encrypted = ArithmeticSharedTensor(tensor1) + encrypted2 = tensor_type(tensor2) + + reference = getattr(tensor1, func)(dimension, index, tensor2) + encrypted_out = getattr(encrypted, func)( + dimension, index, encrypted2 + ) + private = tensor_type == ArithmeticSharedTensor self._check( - encrypted, - tensor1, + encrypted_out, + reference, "%s %s failed" % ("private" if private else "public", func), ) + if func.endswith("_"): + # Check in-place index_add worked + self._check( + encrypted, + reference, + "%s %s failed" % ("private" if private else "public", func), + ) + else: + # Check original is not modified + self._check( + encrypted, + tensor1, + "%s %s failed" % ("private" if private else "public", func), + ) def test_scatter(self): """Test scatter/scatter_add function of encrypted tensor""" diff --git a/test/test_binary.py b/test/test_binary.py index 1413284a..27cff54d 100644 --- a/test/test_binary.py +++ b/test/test_binary.py @@ -153,13 +153,21 @@ def test_permute(self): def test_XOR(self): """Test bitwise-XOR function on BinarySharedTensor""" for tensor_type in [lambda x: x, BinarySharedTensor]: - tensor = get_random_test_tensor(is_float=False) - tensor2 = get_random_test_tensor(is_float=False) - reference = tensor ^ tensor2 - encrypted_tensor = BinarySharedTensor(tensor) - encrypted_tensor2 = tensor_type(tensor2) - encrypted_out = encrypted_tensor ^ encrypted_tensor2 - self._check(encrypted_out, reference, "%s XOR failed" % tensor_type) + for public_party in range(2): + + # Vary which party performs public XOR. + BinarySharedTensor.PUBLIC_COMPUTE_PARTY = public_party + self.assertTrue( + BinarySharedTensor.PUBLIC_COMPUTE_PARTY == public_party + ) + + tensor = get_random_test_tensor(is_float=False) + tensor2 = get_random_test_tensor(is_float=False) + reference = tensor ^ tensor2 + encrypted_tensor = BinarySharedTensor(tensor) + encrypted_tensor2 = tensor_type(tensor2) + encrypted_out = encrypted_tensor ^ encrypted_tensor2 + self._check(encrypted_out, reference, "%s XOR failed" % tensor_type) def test_AND(self): """Test bitwise-AND function on BinarySharedTensor""" @@ -229,11 +237,20 @@ def test_bitwise_broadcasting(self): def test_invert(self): """Test bitwise-invert function on BinarySharedTensor""" - tensor = get_random_test_tensor(is_float=False) - encrypted_tensor = BinarySharedTensor(tensor) - reference = ~tensor - encrypted_out = ~encrypted_tensor - self._check(encrypted_out, reference, "invert failed") + for public_party in range(2): + + # Vary which party performs NOT operation. + BinarySharedTensor.PUBLIC_COMPUTE_PARTY = public_party + self.assertTrue( + BinarySharedTensor.PUBLIC_COMPUTE_PARTY == public_party + ) + + # Perform inversion test. + tensor = get_random_test_tensor(is_float=False) + encrypted_tensor = BinarySharedTensor(tensor) + reference = ~tensor + encrypted_out = ~encrypted_tensor + self._check(encrypted_out, reference, "invert failed") def test_add(self): """Tests add using binary shares""" @@ -332,52 +349,60 @@ def test_inplace(self): """Test inplace vs. out-of-place functions""" for op in ["__xor__", "__and__", "__or__"]: for tensor_type in [lambda x: x, BinarySharedTensor]: - tensor1 = get_random_test_tensor(is_float=False) - tensor2 = get_random_test_tensor(is_float=False) + for public_party in range(2): - reference = getattr(tensor1, op)(tensor2) + # Vary which party performs public XOR and NOT. + BinarySharedTensor.PUBLIC_COMPUTE_PARTY = public_party + self.assertTrue( + BinarySharedTensor.PUBLIC_COMPUTE_PARTY == public_party + ) - encrypted1 = BinarySharedTensor(tensor1) - encrypted2 = tensor_type(tensor2) + tensor1 = get_random_test_tensor(is_float=False) + tensor2 = get_random_test_tensor(is_float=False) - input_plain_id = id(encrypted1.share) - input_encrypted_id = id(encrypted1) + reference = getattr(tensor1, op)(tensor2) - # Test that out-of-place functions do not modify the input - private = isinstance(encrypted2, BinarySharedTensor) - encrypted_out = getattr(encrypted1, op)(encrypted2) - self._check( - encrypted1, - tensor1, - "%s out-of-place %s modifies input" - % ("private" if private else "public", op), - ) - self._check( - encrypted_out, - reference, - "%s out-of-place %s produces incorrect output" - % ("private" if private else "public", op), - ) - self.assertFalse(id(encrypted_out.share) == input_plain_id) - self.assertFalse(id(encrypted_out) == input_encrypted_id) + encrypted1 = BinarySharedTensor(tensor1) + encrypted2 = tensor_type(tensor2) - # Test that in-place functions modify the input - inplace_op = op[:2] + "i" + op[2:] - encrypted_out = getattr(encrypted1, inplace_op)(encrypted2) - self._check( - encrypted1, - reference, - "%s in-place %s does not modify input" - % ("private" if private else "public", inplace_op), - ) - self._check( - encrypted_out, - reference, - "%s in-place %s produces incorrect output" - % ("private" if private else "public", inplace_op), - ) - self.assertTrue(id(encrypted_out.share) == input_plain_id) - self.assertTrue(id(encrypted_out) == input_encrypted_id) + input_plain_id = id(encrypted1.share) + input_encrypted_id = id(encrypted1) + + # Test that out-of-place functions do not modify the input + private = isinstance(encrypted2, BinarySharedTensor) + encrypted_out = getattr(encrypted1, op)(encrypted2) + self._check( + encrypted1, + tensor1, + "%s out-of-place %s modifies input" + % ("private" if private else "public", op), + ) + self._check( + encrypted_out, + reference, + "%s out-of-place %s produces incorrect output" + % ("private" if private else "public", op), + ) + self.assertFalse(id(encrypted_out.share) == input_plain_id) + self.assertFalse(id(encrypted_out) == input_encrypted_id) + + # Test that in-place functions modify the input + inplace_op = op[:2] + "i" + op[2:] + encrypted_out = getattr(encrypted1, inplace_op)(encrypted2) + self._check( + encrypted1, + reference, + "%s in-place %s does not modify input" + % ("private" if private else "public", inplace_op), + ) + self._check( + encrypted_out, + reference, + "%s in-place %s produces incorrect output" + % ("private" if private else "public", inplace_op), + ) + self.assertTrue(id(encrypted_out.share) == input_plain_id) + self.assertTrue(id(encrypted_out) == input_encrypted_id) def test_control_flow_failure(self): """Tests that control flow fails as expected"""