Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable customization of party performing public computations #393

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions crypten/mpc/primitives/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class ArithmeticSharedTensor(object):
the number of parties present in the protocol (world_size).
"""

PUBLIC_COMPUTE_PARTY = 0 # Party responsible for public addition.

# constructors:
def __init__(
self,
Expand Down Expand Up @@ -239,7 +241,7 @@ def pad(self, pad, mode="constant", value=0):
result = self.shallow_copy()
if isinstance(value, (int, float)):
value = self.encoder.encode(value).item()
if result.rank == 0:
if result.rank == self.PUBLIC_COMPUTE_PARTY:
result.share = torch.nn.functional.pad(
result.share, pad, mode=mode, value=value
)
Expand Down Expand Up @@ -362,7 +364,7 @@ def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs): # noqa:C
y = result.encoder.encode(y, device=self.device)

if additive_func: # ['add', 'sub']
if result.rank == 0:
if result.rank == self.PUBLIC_COMPUTE_PARTY:
result.share = getattr(result.share, op)(y)
else:
result.share = torch.broadcast_tensors(result.share, y)[0]
Expand Down Expand Up @@ -514,7 +516,7 @@ def index_add_(self, dim, index, tensor):
private = isinstance(tensor, ArithmeticSharedTensor)
if public:
enc_tensor = self.encoder.encode(tensor)
if self.rank == 0:
if self.rank == self.PUBLIC_COMPUTE_PARTY:
self._tensor.index_add_(dim, index, enc_tensor)
elif private:
self._tensor.index_add_(dim, index, tensor._tensor)
Expand All @@ -541,7 +543,7 @@ def scatter_add_(self, dim, index, other):
public = isinstance(other, (int, float)) or is_tensor(other)
private = isinstance(other, ArithmeticSharedTensor)
if public:
if self.rank == 0:
if self.rank == self.PUBLIC_COMPUTE_PARTY:
self.share.scatter_add_(dim, index, self.encoder.encode(other))
elif private:
self.share.scatter_add_(dim, index, other.share)
Expand Down
6 changes: 4 additions & 2 deletions crypten/mpc/primitives/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class BinarySharedTensor(object):
where n is the number of parties present in the protocol (world_size).
"""

PUBLIC_COMPUTE_PARTY = 0 # Party responsible for public XOR and NOT.

def __init__(
self, tensor=None, size=None, broadcast_size=False, src=0, device=None
):
Expand Down Expand Up @@ -214,7 +216,7 @@ def __nonzero__(self):
def __ixor__(self, y):
"""Bitwise XOR operator (element-wise) in place"""
if is_tensor(y) or isinstance(y, int):
if self.rank == 0:
if self.rank == self.PUBLIC_COMPUTE_PARTY:
self.share ^= y
elif isinstance(y, BinarySharedTensor):
self.share ^= y.share
Expand Down Expand Up @@ -267,7 +269,7 @@ def __or__(self, y):
def __invert__(self):
"""Bitwise NOT operator (element-wise)"""
result = self.clone()
if result.rank == 0:
if result.rank == self.PUBLIC_COMPUTE_PARTY:
result.share ^= -1
return result

Expand Down
27 changes: 16 additions & 11 deletions crypten/mpc/provider/tfp_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

class TrustedFirstParty(TupleProvider):
NAME = "TFP"
COMMUNICATING_PARTY = 0

def generate_additive_triple(self, size0, size1, op, device=None, *args, **kwargs):
"""Generate multiplicative triples of given sizes"""
Expand All @@ -24,9 +25,9 @@ def generate_additive_triple(self, size0, size1, op, device=None, *args, **kwarg

c = getattr(torch, op)(a, b, *args, **kwargs)

a = ArithmeticSharedTensor(a, precision=0, src=0)
b = ArithmeticSharedTensor(b, precision=0, src=0)
c = ArithmeticSharedTensor(c, precision=0, src=0)
a = ArithmeticSharedTensor(a, precision=0, src=self.COMMUNICATING_PARTY)
b = ArithmeticSharedTensor(b, precision=0, src=self.COMMUNICATING_PARTY)
c = ArithmeticSharedTensor(c, precision=0, src=self.COMMUNICATING_PARTY)

return a, b, c

Expand All @@ -37,7 +38,9 @@ def square(self, size, device=None):

# Stack to vectorize scatter function
stacked = torch_stack([r, r2])
stacked = ArithmeticSharedTensor(stacked, precision=0, src=0)
stacked = ArithmeticSharedTensor(
stacked, precision=0, src=self.COMMUNICATING_PARTY
)
return stacked[0], stacked[1]

def generate_binary_triple(self, size0, size1, device=None):
Expand All @@ -46,9 +49,9 @@ def generate_binary_triple(self, size0, size1, device=None):
b = generate_kbit_random_tensor(size1, device=device)
c = a & b

a = BinarySharedTensor(a, src=0)
b = BinarySharedTensor(b, src=0)
c = BinarySharedTensor(c, src=0)
a = BinarySharedTensor(a, src=self.COMMUNICATING_PARTY)
b = BinarySharedTensor(b, src=self.COMMUNICATING_PARTY)
c = BinarySharedTensor(c, src=self.COMMUNICATING_PARTY)

return a, b, c

Expand All @@ -61,9 +64,11 @@ def wrap_rng(self, size, device=None):
]
theta_r = count_wraps(r)

shares = comm.get().scatter(r, 0)
shares = comm.get().scatter(r, self.COMMUNICATING_PARTY)
r = ArithmeticSharedTensor.from_shares(shares, precision=0)
theta_r = ArithmeticSharedTensor(theta_r, precision=0, src=0)
theta_r = ArithmeticSharedTensor(
theta_r, precision=0, src=self.COMMUNICATING_PARTY
)

return r, theta_r

Expand All @@ -72,7 +77,7 @@ def B2A_rng(self, size, device=None):
# generate random bit
r = generate_kbit_random_tensor(size, bitlength=1, device=device)

rA = ArithmeticSharedTensor(r, precision=0, src=0)
rB = BinarySharedTensor(r, src=0)
rA = ArithmeticSharedTensor(r, precision=0, src=self.COMMUNICATING_PARTY)
rB = BinarySharedTensor(r, src=self.COMMUNICATING_PARTY)

return rA, rB
11 changes: 6 additions & 5 deletions crypten/mpc/provider/ttp_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@

class TrustedThirdParty(TupleProvider):
NAME = "TTP"
COMMUNICATING_PARTY = 0

def generate_additive_triple(self, size0, size1, op, device=None, *args, **kwargs):
"""Generate multiplicative triples of given sizes"""
generator = TTPClient.get().get_generator(device=device)

a = generate_random_ring_element(size0, generator=generator, device=device)
b = generate_random_ring_element(size1, generator=generator, device=device)
if comm.get().get_rank() == 0:
if comm.get().get_rank() == self.COMMUNICATING_PARTY:
# Request c from TTP
c = TTPClient.get().ttp_request(
"additive", device, size0, size1, op, *args, **kwargs
Expand All @@ -51,7 +52,7 @@ def square(self, size, device=None):
generator = TTPClient.get().get_generator(device=device)

r = generate_random_ring_element(size, generator=generator, device=device)
if comm.get().get_rank() == 0:
if comm.get().get_rank() == self.COMMUNICATING_PARTY:
# Request r2 from TTP
r2 = TTPClient.get().ttp_request("square", device, size)
else:
Expand All @@ -68,7 +69,7 @@ def generate_binary_triple(self, size0, size1, device=None):
a = generate_kbit_random_tensor(size0, generator=generator, device=device)
b = generate_kbit_random_tensor(size1, generator=generator, device=device)

if comm.get().get_rank() == 0:
if comm.get().get_rank() == self.COMMUNICATING_PARTY:
# Request c from TTP
c = TTPClient.get().ttp_request("binary", device, size0, size1)
else:
Expand All @@ -86,7 +87,7 @@ def wrap_rng(self, size, device=None):
generator = TTPClient.get().get_generator(device=device)

r = generate_random_ring_element(size, generator=generator, device=device)
if comm.get().get_rank() == 0:
if comm.get().get_rank() == self.COMMUNICATING_PARTY:
# Request theta_r from TTP
theta_r = TTPClient.get().ttp_request("wraps", device, size)
else:
Expand All @@ -107,7 +108,7 @@ def B2A_rng(self, size, device=None):
size, bitlength=1, generator=generator, device=device
)

if comm.get().get_rank() == 0:
if comm.get().get_rank() == self.COMMUNICATING_PARTY:
# Request rA from TTP
rA = TTPClient.get().ttp_request("B2A", device, size)
else:
Expand Down
124 changes: 70 additions & 54 deletions test/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,40 +120,48 @@ def test_arithmetic(self):
arithmetic_functions = ["add", "add_", "sub", "sub_", "mul", "mul_"]
for func in arithmetic_functions:
for tensor_type in [lambda x: x, ArithmeticSharedTensor]:
tensor1 = get_random_test_tensor(is_float=True)
tensor2 = get_random_test_tensor(is_float=True)
encrypted = ArithmeticSharedTensor(tensor1)
encrypted2 = tensor_type(tensor2)
for public_party in range(2):

reference = getattr(tensor1, func)(tensor2)
encrypted_out = getattr(encrypted, func)(encrypted2)
private_type = tensor_type == ArithmeticSharedTensor
self._check(
encrypted_out,
reference,
"%s %s failed" % ("private" if private_type else "public", func),
)
if "_" in func:
# Check in-place op worked
self._check(
encrypted,
reference,
"%s %s failed"
% ("private" if private_type else "public", func),
# Vary which party performs public addition.
ArithmeticSharedTensor.PUBLIC_COMPUTE_PARTY = public_party
self.assertTrue(
ArithmeticSharedTensor.PUBLIC_COMPUTE_PARTY == public_party
)
else:
# Check original is not modified

tensor1 = get_random_test_tensor(is_float=True)
tensor2 = get_random_test_tensor(is_float=True)
encrypted = ArithmeticSharedTensor(tensor1)
encrypted2 = tensor_type(tensor2)

reference = getattr(tensor1, func)(tensor2)
encrypted_out = getattr(encrypted, func)(encrypted2)
private_type = tensor_type == ArithmeticSharedTensor
self._check(
encrypted,
tensor1,
"%s %s failed"
% (
"private"
if tensor_type == ArithmeticSharedTensor
else "public",
func,
),
encrypted_out,
reference,
"%s %s failed" % ("private" if private_type else "public", func),
)
if "_" in func:
# Check in-place op worked
self._check(
encrypted,
reference,
"%s %s failed"
% ("private" if private_type else "public", func),
)
else:
# Check original is not modified
self._check(
encrypted,
tensor1,
"%s %s failed"
% (
"private"
if tensor_type == ArithmeticSharedTensor
else "public",
func,
),
)

# Check encrypted vector with encrypted scalar works.
tensor1 = get_random_test_tensor(is_float=True)
Expand Down Expand Up @@ -270,35 +278,43 @@ def test_index_add(self):
tensor_size2[dimension] = index.size(0)
for func in index_add_functions:
for tensor_type in [lambda x: x, ArithmeticSharedTensor]:
tensor1 = get_random_test_tensor(size=tensor_size1, is_float=True)
tensor2 = get_random_test_tensor(size=tensor_size2, is_float=True)
encrypted = ArithmeticSharedTensor(tensor1)
encrypted2 = tensor_type(tensor2)
for public_party in range(2):

reference = getattr(tensor1, func)(dimension, index, tensor2)
encrypted_out = getattr(encrypted, func)(
dimension, index, encrypted2
)
private = tensor_type == ArithmeticSharedTensor
self._check(
encrypted_out,
reference,
"%s %s failed" % ("private" if private else "public", func),
)
if func.endswith("_"):
# Check in-place index_add worked
self._check(
encrypted,
reference,
"%s %s failed" % ("private" if private else "public", func),
# Vary which party performs public addition.
ArithmeticSharedTensor.PUBLIC_COMPUTE_PARTY = public_party
self.assertTrue(
ArithmeticSharedTensor.PUBLIC_COMPUTE_PARTY == public_party
)
else:
# Check original is not modified

tensor1 = get_random_test_tensor(size=tensor_size1, is_float=True)
tensor2 = get_random_test_tensor(size=tensor_size2, is_float=True)
encrypted = ArithmeticSharedTensor(tensor1)
encrypted2 = tensor_type(tensor2)

reference = getattr(tensor1, func)(dimension, index, tensor2)
encrypted_out = getattr(encrypted, func)(
dimension, index, encrypted2
)
private = tensor_type == ArithmeticSharedTensor
self._check(
encrypted,
tensor1,
encrypted_out,
reference,
"%s %s failed" % ("private" if private else "public", func),
)
if func.endswith("_"):
# Check in-place index_add worked
self._check(
encrypted,
reference,
"%s %s failed" % ("private" if private else "public", func),
)
else:
# Check original is not modified
self._check(
encrypted,
tensor1,
"%s %s failed" % ("private" if private else "public", func),
)

def test_scatter(self):
"""Test scatter/scatter_add function of encrypted tensor"""
Expand Down
Loading