Implement weighted sampling (#175)
Summary:
Pull Request resolved: fairinternal/CrypTen#175

Pull Request resolved: #41

- Implements weighted sampling and weighted indexing
- Utilizes weighted indexing for argmax

Reviewed By: vshobha

Differential Revision: D19551148

fbshipit-source-id: 9daaf244cb522f4c1fab713b1ac4e22ce040e5f7
knottb authored and facebook-github-bot committed Feb 3, 2020
1 parent 633d22f commit 4ad37cb
Showing 9 changed files with 142 additions and 84 deletions.
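
For intuition before the diff: weighted_index turns a tensor of nonnegative weights into a random one-hot tensor by taking a cumulative sum along the sampling dimension, drawing a uniform value below the total weight, and marking the first prefix sum that exceeds the draw; weighted_sample then recovers the sampled value by multiplying the weights by the one-hot result and summing. A minimal plaintext sketch of the same trick in plain PyTorch (weighted_index_plain is an illustrative name, not CrypTen API; the real implementation below runs on secret-shared tensors):

import torch

def weighted_index_plain(w, dim=0):
    # Cumulative weights along `dim`; the last slice holds the total weight.
    x = w.cumsum(dim)
    total = x.index_select(dim, torch.tensor([x.size(dim) - 1]))
    r = torch.rand(total.size()) * total    # uniform draw in [0, total)
    gt = (x > r).long()                     # 1 from the sampled index onward
    shifted = gt.roll(1, dims=dim)          # shift the step by one position
    shifted.index_fill_(dim, torch.tensor([0]), 0)
    return gt - shifted                     # one-hot at the sampled index

w = torch.tensor([1.0, 6.0])
idx = weighted_index_plain(w)               # one-hot; P(index 1) = 6 / 7
sample = (w * idx.float()).sum()            # mirrors weighted_sample: 1. or 6.
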
2 changes: 1 addition & 1 deletion crypten/__init__.py
@@ -345,7 +345,7 @@ def stack(tensors, dim=0):


# Top level tensor functions
__PASSTHROUGH_FUNCTIONS = ["bernoulli", "rand", "randperm"]
__PASSTHROUGH_FUNCTIONS = ["bernoulli", "rand"]


def __add_top_level_function(func_name):
11 changes: 0 additions & 11 deletions crypten/mpc/__init__.py
@@ -83,17 +83,6 @@ def bernoulli(tensor):
return rand(tensor.size()) < tensor


def randperm(size):
"""
Generate an MPCTensor with rows that contain values [1, 2, ... n]
where `n` is the length of each row (size[-1])
"""
result = MPCTensor(None)
result._tensor = __default_provider.randperm(size)
result.ptype = ptype.arithmetic
return result


# Set provider
__SUPPORTED_PROVIDERS = {
"TFP": provider.TrustedFirstParty,
106 changes: 87 additions & 19 deletions crypten/mpc/mpc.py
@@ -140,14 +140,24 @@ def __setitem__(self, index, value):

@property
def share(self):
"""Returns underlying _tensor"""
"""Returns underlying share"""
return self._tensor.share

@share.setter
def share(self, value):
"""Sets _tensor to value"""
"""Sets share to value"""
self._tensor.share = value

@property
def encoder(self):
"""Returns underlying encoder"""
return self._tensor.encoder

@encoder.setter
def encoder(self, value):
"""Sets encoder to value"""
self._tensor.encoder = value

def bernoulli(self):
"""Returns a tensor with elements in {0, 1}. The i-th element of the
output will be 1 with probability according to the i-th value of the
@@ -251,9 +261,9 @@ def _ltz(self, _scale=True):
shift = torch.iinfo(torch.long).bits - 1
result = (self >> shift).to(Ptype.arithmetic, bits=1)
if _scale:
return result * result._tensor.encoder._scale
return result * result.encoder._scale
else:
result._tensor.encoder._scale = 1
result.encoder._scale = 1
return result

@mode(Ptype.arithmetic)
@@ -303,6 +313,58 @@ def relu(self):
"""Compute a Rectified Linear function on the input tensor."""
return self * self.ge(0, _scale=False)

@mode(Ptype.arithmetic)
def weighted_index(self, dim=None):
"""
Returns a tensor with entries that are one-hot along dimension `dim`.
These one-hot entries are set at random with weights given by the input
`self`.
Examples::
>>> encrypted_tensor = MPCTensor(torch.tensor([1., 6.]))
>>> index = encrypted_tensor.weighted_index().get_plain_text()
# With 1 / 7 probability
torch.tensor([1., 0.])
# With 6 / 7 probability
torch.tensor([0., 1.])
"""
if dim is None:
return self.flatten().weighted_index(dim=0).view(self.size())

x = self.cumsum(dim)
max_weight = x.index_select(dim, torch.tensor(x.size(dim) - 1))
r = crypten.mpc.rand(max_weight.size()) * max_weight

gt = x.gt(r, _scale=False)
shifted = gt.roll(1, dims=dim)
shifted.share.index_fill_(dim, torch.tensor(0), 0)

return gt - shifted

@mode(Ptype.arithmetic)
def weighted_sample(self, dim=None):
"""
Samples a single value across dimension `dim` with weights corresponding
to the values in `self`
Returns the sample and the one-hot index of the sample.
Examples::
>>> encrypted_tensor = MPCTensor(torch.tensor([1., 6.]))
>>> index = encrypted_tensor.weighted_sample().get_plain_text()
# With 1 / 7 probability
(torch.tensor([1., 0.]), torch.tensor([1., 0.]))
# With 6 / 7 probability
(torch.tensor([0., 6.]), torch.tensor([0., 1.]))
"""
indices = self.weighted_index(dim)
sample = self.mul(indices).sum(dim)
return sample, indices

# max / min-related functions
def _argmax_helper(self, dim=None):
"""Returns 1 for all elements that have the highest value in the appropriate
@@ -320,10 +382,24 @@ def _argmax_helper(self, dim=None):
[self.roll(i + 1, dims=dim) for i in range(row_length - 1)]
)

# Sum of columns with all 1s will have value equal to (length - 1).
# Using >= since it requires 1 fewer comparison than !=
result = (a >= b).sum(dim=0)
return result >= (row_length - 1)
# Use either prod or sum & comparison depending on size
if row_length - 1 < torch.iinfo(torch.long).bits * 2:
pairwise_comparisons = a.ge(b, _scale=False)
result = pairwise_comparisons.prod(dim=0)
result.share *= self.encoder._scale
result.encoder = self.encoder
else:
# Sum of columns with all 1s will have value equal to (length - 1).
# Using ge() since it is slightly faster than eq()
pairwise_comparisons = a.ge(b)
result = pairwise_comparisons.sum(dim=0).ge(row_length - 1)
return result

"""
pairwise_comparisons = a.ge(b, _scale=False)
return result
"""

@mode(Ptype.arithmetic)
def argmax(self, dim=None, keepdim=False, one_hot=False):
@@ -334,19 +410,10 @@ def argmax(self, dim=None, keepdim=False, one_hot=False):
return MPCTensor(torch.ones(())) if one_hot else MPCTensor(torch.zeros(()))

input = self.flatten() if dim is None else self

result = input._argmax_helper(dim)

# Multiply by a random permutation to give each maximum a random priority
randperm_size = [x for x in input.size()]
if dim is not None:
randperm_size[-1] = input.size(dim)
randperm_size[dim] = input.size(-1)
randperm = crypten.mpc.randperm(randperm_size)
if dim is not None:
randperm = randperm.transpose(dim, -1)
result *= randperm
result = input._argmax_helper(dim)
# Break ties by using a uniform weighted sample among tied indices
result = result.weighted_index(dim)

result = result.view(self.size()) if dim is None else result
return result if one_hot else _one_hot_to_index(result, dim, keepdim)
@@ -1138,6 +1205,7 @@ def ib_wrapper_function(self, value, *args, **kwargs):
"unfold",
"flip",
"trace",
"prod",
"sum",
"cumsum",
"reshape",
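
The net effect on argmax in this file: _argmax_helper produces a 0/1 indicator that may mark several tied maxima, and weighted_index then samples uniformly among them. This replaces the old scheme of multiplying by a random permutation to give each maximum a random priority, which is why randperm disappears from the providers below. A plaintext sketch of the tie-breaking step, reusing weighted_index_plain from the sketch above (hypothetical helper, not CrypTen API):

def argmax_one_hot_plain(x, dim=0):
    # 0/1 indicator of entries equal to the maximum; may mark several ties.
    is_max = (x == x.max(dim, keepdim=True).values).float()
    # Uniform weights over the tied maxima: sample one uniformly at random.
    return weighted_index_plain(is_max, dim)

t = torch.tensor([3.0, 7.0, 7.0, 1.0])
print(argmax_one_hot_plain(t))  # one-hot at index 1 or 2, each with probability 1/2
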
25 changes: 25 additions & 0 deletions crypten/mpc/primitives/arithmetic.py
@@ -327,6 +327,31 @@ def matmul(self, y):
"""Perform matrix multiplication using some tensor"""
return self._arithmetic_function(y, "matmul")

def prod(self, dim=None, keepdim=False):
"""
Returns the product of each row of the `input` tensor in the given
dimension `dim`.
If `keepdim` is `True`, the output tensor is of the same size as `input`
except in the dimension `dim` where it is of size 1. Otherwise, `dim` is
squeezed, resulting in the output tensor having 1 fewer dimension than
`input`.
"""
if dim is None:
return self.flatten().prod(dim=0)

result = self.clone()
while result.size(dim) > 1:
size = result.size(dim)
x, y, remainder = result.split([size // 2, size // 2, size % 2], dim=dim)
result = x.mul_(y)
result.share = torch.cat([result.share, remainder.share], dim=dim)

# Squeeze result if necessary
if not keepdim:
result.share = result.share.squeeze(dim)
return result

def mean(self, *args, **kwargs):
"""Computes mean of given tensor"""
result = self.sum(*args, **kwargs)
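
The prod reduction above pairs up elements along dim and multiplies them, halving the size each iteration, so reducing n elements takes O(log n) rounds of multiplication instead of a length-(n - 1) chain; in MPC, where each multiplication typically costs a round of communication, that depth reduction is the point. A plaintext sketch of the same halving loop (hypothetical name; plain PyTorch, reusing the torch import from the first sketch):

def prod_halving(t, dim=0):
    # Multiply element pairs along `dim` until a single element remains.
    while t.size(dim) > 1:
        size = t.size(dim)
        x, y, rem = t.split([size // 2, size // 2, size % 2], dim=dim)
        t = torch.cat([x * y, rem], dim=dim)  # carry any odd element forward
    return t.squeeze(dim)

print(prod_halving(torch.arange(1.0, 6.0)))   # tensor(120.)
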
13 changes: 0 additions & 13 deletions crypten/mpc/provider/homomorphic_provider.py
@@ -38,16 +38,3 @@ def B2A_rng(size):
def rand(*sizes):
"""Generate random ArithmeticSharedTensor uniform on [0, 1]"""
raise NotImplementedError("HomomorphicProvider not implemented")

@staticmethod
def bernoulli(tensor):
"""Generate random ArithmeticSharedTensor bernoulli on {0, 1}"""
raise NotImplementedError("HomomorphicProvider not implemented")

@staticmethod
def randperm(tensor_size):
"""
Generate `tensor_size[:-1]` random ArithmeticSharedTensor permutations of
the first `tensor_size[-1]` whole numbers
"""
raise NotImplementedError("HomomorphicProvider not implemented")
15 changes: 0 additions & 15 deletions crypten/mpc/provider/tfp_provider.py
@@ -80,18 +80,3 @@ def rand(*sizes):
"""Generate random ArithmeticSharedTensor uniform on [0, 1]"""
samples = torch.rand(*sizes)
return ArithmeticSharedTensor(samples, src=0)

@staticmethod
def randperm(tensor_size):
"""
Generate `tensor_size[:-1]` random ArithmeticSharedTensor permutations of
the first `tensor_size[-1]` whole numbers
"""
tensor_len = tensor_size[-1]
nperms = int(torch.tensor(tensor_size[:-1]).prod().item())
random_permutation = (
torch.stack([torch.randperm(tensor_len) + 1 for _ in range(nperms)])
.view(tensor_size)
.float()
)
return ArithmeticSharedTensor(random_permutation, src=0)
26 changes: 1 addition & 25 deletions crypten/mpc/provider/ttp_provider.py
@@ -18,7 +18,7 @@
from crypten.mpc.primitives import ArithmeticSharedTensor, BinarySharedTensor


TTP_FUNCTIONS = ["additive", "square", "binary", "wraps", "B2A", "rand", "randperm"]
TTP_FUNCTIONS = ["additive", "square", "binary", "wraps", "B2A", "rand"]


class TrustedThirdParty:
@@ -133,20 +133,6 @@ def rand(*sizes, encoder=None):
samples = generate_random_ring_element(sizes, generator=generator)
return ArithmeticSharedTensor.from_shares(samples)

@staticmethod
def randperm(tensor_size, encoder=None):
"""
Generate `tensor_size[:-1]` random ArithmeticSharedTensor permutations of
the first `tensor_size[-1]` whole numbers
"""
generator = TTPClient.get().generator
if comm.get().get_rank() == 0:
# Request samples from TTP
samples = TTPClient.get().ttp_request("randperm", tensor_size)
else:
samples = generate_random_ring_element(tensor_size, generator=generator)
return ArithmeticSharedTensor.from_shares(samples)

@staticmethod
def _init():
TTPClient._init()
@@ -320,13 +306,3 @@ def rand(self, *sizes, encoder=None):
r = encoder.encode(torch.rand(*sizes))
r = r - self._get_additive_PRSS(sizes, remove_rank=True)
return r

def randperm(self, tensor_size, encoder=None):
tensor_len = tensor_size[-1]
nperms = int(torch.tensor(tensor_size[:-1]).prod().item())
random_permutation = torch.stack(
[torch.randperm(tensor_len) + 1 for _ in range(nperms)]
).view(tensor_size)
return random_permutation - self._get_additive_PRSS(
random_permutation.size(), remove_rank=True
)
15 changes: 15 additions & 0 deletions test/test_arithmetic.py
@@ -173,6 +173,7 @@ def test_arithmetic(self):
self._check(encrypted_out, reference, "right mul failed")

def test_sum(self):
"""Tests sum reduction on encrypted tensor."""
tensor = get_random_test_tensor(size=(5, 100, 100), is_float=True)
encrypted = ArithmeticSharedTensor(tensor)
self._check(encrypted.sum(), tensor.sum(), "sum failed")
@@ -184,6 +185,20 @@ def test_sum(self):
encrypted_out = encrypted.sum(dim)
self._check(encrypted_out, reference, "sum failed")

def test_prod(self):
"""Tests prod reduction on encrypted tensor."""
# Increasing size to reduce relative error due to quantization
tensor = get_random_test_tensor(size=(5, 5, 5), is_float=False)
encrypted = ArithmeticSharedTensor(tensor)
self._check(encrypted.prod(), tensor.prod().float(), "prod failed")

for dim in [0, 1, 2]:
reference = tensor.prod(dim).float()
with self.benchmark(type="prod", dim=dim) as bench:
for _ in bench.iters:
encrypted_out = encrypted.prod(dim)
self._check(encrypted_out, reference, "prod failed")

def test_div(self):
"""Tests division of encrypted tensor by scalar."""
for function in ["div", "div_"]:
13 changes: 13 additions & 0 deletions test/test_mpc.py
@@ -195,6 +195,19 @@ def test_sum(self):
encrypted_out = encrypted.sum(dim)
self._check(encrypted_out, reference, "sum failed")

def test_prod(self):
"""Tests prod reduction on encrypted tensor."""
tensor = get_random_test_tensor(size=(5, 5), is_float=False)
encrypted = MPCTensor(tensor)
self._check(encrypted.prod(), tensor.prod().float(), "prod failed")

for dim in [0, 1]:
reference = tensor.prod(dim).float()
with self.benchmark(type="prod", dim=dim) as bench:
for _ in bench.iters:
encrypted_out = encrypted.prod(dim)
self._check(encrypted_out, reference, "prod failed")

def test_div(self):
"""Tests division of encrypted tensor by scalar and tensor."""
for function in ["div", "div_"]:
