diff --git a/crypten/__init__.py b/crypten/__init__.py
index 91beb018..8c224cf9 100644
--- a/crypten/__init__.py
+++ b/crypten/__init__.py
@@ -345,7 +345,7 @@ def stack(tensors, dim=0):
 
 
 # Top level tensor functions
-__PASSTHROUGH_FUNCTIONS = ["bernoulli", "rand", "randperm"]
+__PASSTHROUGH_FUNCTIONS = ["bernoulli", "rand"]
 
 
 def __add_top_level_function(func_name):
diff --git a/crypten/mpc/__init__.py b/crypten/mpc/__init__.py
index f6539e58..ecdd3d87 100644
--- a/crypten/mpc/__init__.py
+++ b/crypten/mpc/__init__.py
@@ -83,17 +83,6 @@ def bernoulli(tensor):
     return rand(tensor.size()) < tensor
 
 
-def randperm(size):
-    """
-    Generate an MPCTensor with rows that contain values [1, 2, ... n]
-    where `n` is the length of each row (size[-1])
-    """
-    result = MPCTensor(None)
-    result._tensor = __default_provider.randperm(size)
-    result.ptype = ptype.arithmetic
-    return result
-
-
 # Set provider
 __SUPPORTED_PROVIDERS = {
     "TFP": provider.TrustedFirstParty,
diff --git a/crypten/mpc/mpc.py b/crypten/mpc/mpc.py
index 3c0277e2..090f8c63 100644
--- a/crypten/mpc/mpc.py
+++ b/crypten/mpc/mpc.py
@@ -140,14 +140,24 @@ def __setitem__(self, index, value):
 
     @property
     def share(self):
-        """Returns underlying _tensor"""
+        """Returns underlying share"""
         return self._tensor.share
 
     @share.setter
     def share(self, value):
-        """Sets _tensor to value"""
+        """Sets share to value"""
         self._tensor.share = value
 
+    @property
+    def encoder(self):
+        """Returns underlying encoder"""
+        return self._tensor.encoder
+
+    @encoder.setter
+    def encoder(self, value):
+        """Sets encoder to value"""
+        self._tensor.encoder = value
+
     def bernoulli(self):
         """Returns a tensor with elements in {0, 1}. The i-th element of the
         output will be 1 with probability according to the i-th value of the
@@ -251,9 +261,9 @@ def _ltz(self, _scale=True):
         shift = torch.iinfo(torch.long).bits - 1
         result = (self >> shift).to(Ptype.arithmetic, bits=1)
         if _scale:
-            return result * result._tensor.encoder._scale
+            return result * result.encoder._scale
         else:
-            result._tensor.encoder._scale = 1
+            result.encoder._scale = 1
             return result
 
     @mode(Ptype.arithmetic)
@@ -303,6 +313,58 @@ def relu(self):
         """Compute a Rectified Linear function on the input tensor."""
         return self * self.ge(0, _scale=False)
 
+    @mode(Ptype.arithmetic)
+    def weighted_index(self, dim=None):
+        """
+        Returns a tensor with entries that are one-hot along dimension `dim`.
+        These one-hot entries are set at random with weights given by the input
+        `self`.
+
+        Examples::
+
+            >>> encrypted_tensor = MPCTensor(torch.tensor([1., 6.]))
+            >>> index = encrypted_tensor.weighted_index().get_plain_text()
+            # With 1 / 7 probability
+            torch.tensor([1., 0.])
+
+            # With 6 / 7 probability
+            torch.tensor([0., 1.])
+        """
+        if dim is None:
+            return self.flatten().weighted_index(dim=0).view(self.size())
+
+        x = self.cumsum(dim)
+        max_weight = x.index_select(dim, torch.tensor(x.size(dim) - 1))
+        r = crypten.mpc.rand(max_weight.size()) * max_weight
+
+        gt = x.gt(r, _scale=False)
+        shifted = gt.roll(1, dims=dim)
+        shifted.share.index_fill_(dim, torch.tensor(0), 0)
+
+        return gt - shifted
+
+    @mode(Ptype.arithmetic)
+    def weighted_sample(self, dim=None):
+        """
+        Samples a single value across dimension `dim` with weights corresponding
+        to the values in `self`.
+
+        Returns the sample and the one-hot index of the sample.
+
+        Examples::
+
+            >>> encrypted_tensor = MPCTensor(torch.tensor([1., 6.]))
+            >>> sample, index = encrypted_tensor.weighted_sample()
+            # With 1 / 7 probability
+            (torch.tensor([1., 0.]), torch.tensor([1., 0.]))
+
+            # With 6 / 7 probability
+            (torch.tensor([0., 6.]), torch.tensor([0., 1.]))
+        """
+        indices = self.weighted_index(dim)
+        sample = self.mul(indices).sum(dim)
+        return sample, indices
+
     # max / min-related functions
     def _argmax_helper(self, dim=None):
         """Returns 1 for all elements that have the highest value in the appropriate
@@ -320,10 +382,18 @@ def _argmax_helper(self, dim=None):
             [self.roll(i + 1, dims=dim) for i in range(row_length - 1)]
         )
 
-        # Sum of columns with all 1s will have value equal to (length - 1).
-        # Using >= since it requires 1-fewer comparrison than !=
-        result = (a >= b).sum(dim=0)
-        return result >= (row_length - 1)
+        # Use either prod or sum & comparison depending on size
+        if row_length - 1 < torch.iinfo(torch.long).bits * 2:
+            pairwise_comparisons = a.ge(b, _scale=False)
+            result = pairwise_comparisons.prod(dim=0)
+            result.share *= self.encoder._scale
+            result.encoder = self.encoder
+        else:
+            # Sum of columns with all 1s will have value equal to (length - 1).
+            # Using ge() since it is slightly faster than eq()
+            pairwise_comparisons = a.ge(b)
+            result = pairwise_comparisons.sum(dim=0).ge(row_length - 1)
+        return result
 
     @mode(Ptype.arithmetic)
     def argmax(self, dim=None, keepdim=False, one_hot=False):
@@ -334,19 +404,10 @@ def argmax(self, dim=None, keepdim=False, one_hot=False):
             return MPCTensor(torch.ones(())) if one_hot else MPCTensor(torch.zeros(()))
 
         input = self.flatten() if dim is None else self
-
         result = input._argmax_helper(dim)
-        # Multiply by a random permutation to give each maximum a random priority
-        randperm_size = [x for x in input.size()]
-        if dim is not None:
-            randperm_size[-1] = input.size(dim)
-            randperm_size[dim] = input.size(-1)
-
-        randperm = crypten.mpc.randperm(randperm_size)
-        if dim is not None:
-            randperm = randperm.transpose(dim, -1)
-        result *= randperm
-        result = result._argmax_helper(dim)
+
+        # Break ties by using a uniform weighted sample among tied indices
+        result = result.weighted_index(dim)
 
         result = result.view(self.size()) if dim is None else result
         return result if one_hot else _one_hot_to_index(result, dim, keepdim)
@@ -1138,6 +1199,7 @@ def ib_wrapper_function(self, value, *args, **kwargs):
     "unfold",
     "flip",
     "trace",
+    "prod",
     "sum",
     "cumsum",
    "reshape",
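
The inverse-CDF sampling inside `weighted_index` above can be sanity-checked
against a plaintext sketch (plain PyTorch on public tensors, no secret sharing;
the helper name `weighted_index_plain` is illustrative and not part of the
patch)::

    import torch

    def weighted_index_plain(w, dim=0):
        # Running sum of the weights gives an (unnormalized) CDF.
        x = torch.cumsum(w, dim)
        total = x.index_select(dim, torch.tensor([x.size(dim) - 1]))
        # Draw r uniformly on [0, total); index i is the first crossing
        # of the CDF with probability w_i / total.
        r = torch.rand(total.size()) * total
        gt = (x > r).float()
        # Subtracting the shifted indicator keeps only the first crossing,
        # i.e. a one-hot vector along `dim`.
        shifted = torch.roll(gt, 1, dims=dim)
        shifted.index_fill_(dim, torch.tensor([0]), 0)
        return gt - shifted

For `w = torch.tensor([1., 6.])` this yields `[1., 0.]` with probability 1/7
and `[0., 1.]` with probability 6/7, matching the docstring example above.
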
diff --git a/crypten/mpc/primitives/arithmetic.py b/crypten/mpc/primitives/arithmetic.py
index 2ce4742c..a9c1a3fc 100644
--- a/crypten/mpc/primitives/arithmetic.py
+++ b/crypten/mpc/primitives/arithmetic.py
@@ -327,6 +327,31 @@ def matmul(self, y):
         """Perform matrix multiplication using some tensor"""
         return self._arithmetic_function(y, "matmul")
 
+    def prod(self, dim=None, keepdim=False):
+        """
+        Returns the product of each row of the `input` tensor in the given
+        dimension `dim`.
+
+        If `keepdim` is `True`, the output tensor is of the same size as `input`
+        except in the dimension `dim` where it is of size 1. Otherwise, `dim` is
+        squeezed, resulting in the output tensor having 1 fewer dimension than
+        `input`.
+        """
+        if dim is None:
+            return self.flatten().prod(dim=0)
+
+        result = self.clone()
+        while result.size(dim) > 1:
+            size = result.size(dim)
+            x, y, remainder = result.split([size // 2, size // 2, size % 2], dim=dim)
+            result = x.mul_(y)
+            result.share = torch.cat([result.share, remainder.share], dim=dim)
+
+        # Squeeze result if necessary
+        if not keepdim:
+            result.share = result.share.squeeze(dim)
+        return result
+
     def mean(self, *args, **kwargs):
         """Computes mean of given tensor"""
         result = self.sum(*args, **kwargs)
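
The halving loop in `prod` keeps the multiplication depth at ceil(log2(n))
rounds rather than the n - 1 rounds of a sequential scan, which matters under
MPC because every multiplication of shared tensors costs a round of
communication. The same recursion in plaintext (plain PyTorch; the function
name is illustrative only)::

    import torch

    def prod_by_halving(t, dim=0):
        # Split `dim` into two equal halves plus an odd remainder, multiply
        # the halves elementwise, and repeat until one slice remains.
        while t.size(dim) > 1:
            size = t.size(dim)
            x, y, remainder = t.split([size // 2, size // 2, size % 2], dim=dim)
            t = torch.cat([x * y, remainder], dim=dim)
        return t.squeeze(dim)

    assert prod_by_halving(torch.tensor([2.0, 3.0, 4.0])) == 24.0
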
+ """ + if dim is None: + return self.flatten().prod(dim=0) + + result = self.clone() + while result.size(dim) > 1: + size = result.size(dim) + x, y, remainder = result.split([size // 2, size // 2, size % 2], dim=dim) + result = x.mul_(y) + result.share = torch.cat([result.share, remainder.share], dim=dim) + + # Squeeze result if necessary + if not keepdim: + result.share = result.share.squeeze(dim) + return result + def mean(self, *args, **kwargs): """Computes mean of given tensor""" result = self.sum(*args, **kwargs) diff --git a/crypten/mpc/provider/homomorphic_provider.py b/crypten/mpc/provider/homomorphic_provider.py index 74ee0592..baf16962 100644 --- a/crypten/mpc/provider/homomorphic_provider.py +++ b/crypten/mpc/provider/homomorphic_provider.py @@ -38,16 +38,3 @@ def B2A_rng(size): def rand(*sizes): """Generate random ArithmeticSharedTensor uniform on [0, 1]""" raise NotImplementedError("HomomorphicProvider not implemented") - - @staticmethod - def bernoulli(tensor): - """Generate random ArithmeticSharedTensor bernoulli on {0, 1}""" - raise NotImplementedError("HomomorphicProvider not implemented") - - @staticmethod - def randperm(tensor_size): - """ - Generate `tensor_size[:-1]` random ArithmeticSharedTensor permutations of - the first `tensor_size[-1]` whole numbers - """ - raise NotImplementedError("HomomorphicProvider not implemented") diff --git a/crypten/mpc/provider/tfp_provider.py b/crypten/mpc/provider/tfp_provider.py index 8dc53863..ba31fa88 100644 --- a/crypten/mpc/provider/tfp_provider.py +++ b/crypten/mpc/provider/tfp_provider.py @@ -80,18 +80,3 @@ def rand(*sizes): """Generate random ArithmeticSharedTensor uniform on [0, 1]""" samples = torch.rand(*sizes) return ArithmeticSharedTensor(samples, src=0) - - @staticmethod - def randperm(tensor_size): - """ - Generate `tensor_size[:-1]` random ArithmeticSharedTensor permutations of - the first `tensor_size[-1]` whole numbers - """ - tensor_len = tensor_size[-1] - nperms = int(torch.tensor(tensor_size[:-1]).prod().item()) - random_permutation = ( - torch.stack([torch.randperm(tensor_len) + 1 for _ in range(nperms)]) - .view(tensor_size) - .float() - ) - return ArithmeticSharedTensor(random_permutation, src=0) diff --git a/crypten/mpc/provider/ttp_provider.py b/crypten/mpc/provider/ttp_provider.py index 62675297..c276f2ca 100644 --- a/crypten/mpc/provider/ttp_provider.py +++ b/crypten/mpc/provider/ttp_provider.py @@ -18,7 +18,7 @@ from crypten.mpc.primitives import ArithmeticSharedTensor, BinarySharedTensor -TTP_FUNCTIONS = ["additive", "square", "binary", "wraps", "B2A", "rand", "randperm"] +TTP_FUNCTIONS = ["additive", "square", "binary", "wraps", "B2A", "rand"] class TrustedThirdParty: @@ -133,20 +133,6 @@ def rand(*sizes, encoder=None): samples = generate_random_ring_element(sizes, generator=generator) return ArithmeticSharedTensor.from_shares(samples) - @staticmethod - def randperm(tensor_size, encoder=None): - """ - Generate `tensor_size[:-1]` random ArithmeticSharedTensor permutations of - the first `tensor_size[-1]` whole numbers - """ - generator = TTPClient.get().generator - if comm.get().get_rank() == 0: - # Request samples from TTP - samples = TTPClient.get().ttp_request("randperm", tensor_size) - else: - samples = generate_random_ring_element(tensor_size, generator=generator) - return ArithmeticSharedTensor.from_shares(samples) - @staticmethod def _init(): TTPClient._init() @@ -320,13 +306,3 @@ def rand(self, *sizes, encoder=None): r = encoder.encode(torch.rand(*sizes)) r = r - self._get_additive_PRSS(sizes, 
diff --git a/test/test_arithmetic.py b/test/test_arithmetic.py
index e9267236..a25a3dcb 100644
--- a/test/test_arithmetic.py
+++ b/test/test_arithmetic.py
@@ -173,6 +173,7 @@ def test_arithmetic(self):
         self._check(encrypted_out, reference, "right mul failed")
 
     def test_sum(self):
+        """Tests sum reduction on encrypted tensor."""
         tensor = get_random_test_tensor(size=(5, 100, 100), is_float=True)
         encrypted = ArithmeticSharedTensor(tensor)
         self._check(encrypted.sum(), tensor.sum(), "sum failed")
@@ -184,6 +185,20 @@ def test_sum(self):
             encrypted_out = encrypted.sum(dim)
             self._check(encrypted_out, reference, "sum failed")
 
+    def test_prod(self):
+        """Tests prod reduction on encrypted tensor."""
+        # Increasing size to reduce relative error due to quantization
+        tensor = get_random_test_tensor(size=(5, 5, 5), is_float=False)
+        encrypted = ArithmeticSharedTensor(tensor)
+        self._check(encrypted.prod(), tensor.prod().float(), "prod failed")
+
+        for dim in [0, 1, 2]:
+            reference = tensor.prod(dim).float()
+            with self.benchmark(type="prod", dim=dim) as bench:
+                for _ in bench.iters:
+                    encrypted_out = encrypted.prod(dim)
+            self._check(encrypted_out, reference, "prod failed")
+
     def test_div(self):
         """Tests division of encrypted tensor by scalar."""
         for function in ["div", "div_"]:
diff --git a/test/test_mpc.py b/test/test_mpc.py
index 8fc1fe60..5bde604b 100644
--- a/test/test_mpc.py
+++ b/test/test_mpc.py
@@ -195,6 +195,19 @@ def test_sum(self):
             encrypted_out = encrypted.sum(dim)
             self._check(encrypted_out, reference, "sum failed")
 
+    def test_prod(self):
+        """Tests prod reduction on encrypted tensor."""
+        tensor = get_random_test_tensor(size=(5, 5), is_float=False)
+        encrypted = MPCTensor(tensor)
+        self._check(encrypted.prod(), tensor.prod().float(), "prod failed")
+
+        for dim in [0, 1]:
+            reference = tensor.prod(dim).float()
+            with self.benchmark(type="prod", dim=dim) as bench:
+                for _ in bench.iters:
+                    encrypted_out = encrypted.prod(dim)
+            self._check(encrypted_out, reference, "prod failed")
+
     def test_div(self):
         """Tests division of encrypted tensor by scalar and tensor."""
         for function in ["div", "div_"]:
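
For completeness, an end-to-end sketch of the new sampling API in the style of
these tests (assumes a CrypTen build with this patch and an initialized
runtime; not part of the patch itself)::

    import torch

    import crypten
    from crypten.mpc import MPCTensor

    crypten.init()

    weights = MPCTensor(torch.tensor([1.0, 6.0]))
    sample, index = weights.weighted_sample(dim=0)
    # `sample` decrypts to 1.0 with probability 1/7 and to 6.0 with
    # probability 6/7; `index` decrypts to the matching one-hot vector.
    print(sample.get_plain_text(), index.get_plain_text())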