Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in integer division #345

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions crypten/mpc/primitives/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def encode_(self, new_encoder):
self.share *= scale_factor
else:
scale_factor = self.encoder.scale // new_encoder.scale
self = self.div_(scale_factor)
self = self._truncate_(scale=scale_factor)
self.encoder = new_encoder
return self

Expand Down Expand Up @@ -390,12 +390,12 @@ def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs): # noqa:C
if not additive_func:
if public: # scale by self.encoder.scale
if self.encoder.scale > 1:
return result.div_(result.encoder.scale)
return result._truncate_()
else:
result.encoder = self.encoder
else: # scale by larger of self.encoder.scale and y.encoder.scale
if self.encoder.scale > 1 and y.encoder.scale > 1:
return result.div_(result.encoder.scale)
return result._truncate_()
elif self.encoder.scale > 1:
result.encoder = self.encoder
else:
Expand Down Expand Up @@ -443,6 +443,19 @@ def div(self, y):
result.share = torch.broadcast_tensors(result.share, y)[0].clone()
return result.div_(y)

def _truncate_(self, scale=None):
    """Rescale this share in place by an integer ``scale`` factor.

    Used to undo the extra fixed-point scale picked up by a
    multiplication. When ``scale`` is omitted, the encoder's own
    scale (``self.encoder._scale``) is used.

    Returns ``self`` to allow call chaining.
    """
    divisor = self.encoder._scale if scale is None else scale

    # With more than two parties, plain local division of shares is not
    # sound, so defer to the configured MPC protocol's truncation.
    if comm.get().get_world_size() > 2:
        protocol = globals()[cfg.mpc.protocol]
        self.share = protocol.truncate(self, divisor).share
    else:
        # Two parties (or one): truncating local division is safe.
        self.share = self.share.div_(divisor, rounding_mode="trunc")
    return self

def div_(self, y):
"""Divide two tensors element-wise"""
# TODO: Add test coverage for this code path (next 4 lines)
Expand All @@ -458,12 +471,13 @@ def div_(self, y):
tolerance = 1.0
tensor = self.get_plain_text()

# Truncate protocol for dividing by public integers:
if comm.get().get_world_size() > 2:
protocol = globals()[cfg.mpc.protocol]
protocol.truncate(self, y)
else:
self.share = self.share.div_(y, rounding_mode="trunc")
# Re-encode if input has low precision
encoder = FixedPointEncoder()
if self.encoder._scale < encoder._scale:
self.encode_(encoder)

# Use truncate protocol for dividing by public integers:
self._truncate_(scale=y)

# Validate
if validate:
Expand Down Expand Up @@ -596,7 +610,8 @@ def neg(self):

def square_(self):
    """Square this tensor in place.

    Computes the square via the configured MPC protocol, then rescales
    the doubly-scaled product back down by the encoder scale using the
    truncation protocol (``_truncate_``).

    Returns ``self`` to allow call chaining.
    """
    # NOTE: the diff rendering duplicated the pre-PR line
    # (`protocol.square(self).div_(...)`) alongside its replacement;
    # only the new `_truncate_`-based rescaling is kept here.
    protocol = globals()[cfg.mpc.protocol]
    square = protocol.square(self)
    self.share = square._truncate_(scale=self.encoder._scale).share
    return self

def square(self):
Expand Down
3 changes: 0 additions & 3 deletions test/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,9 @@
import unittest

import crypten
import crypten.communicator as comm
import torch
import torch.nn.functional as F
from crypten.common.rng import generate_random_ring_element
from crypten.common.tensor_types import is_float_tensor
from crypten.common.util import count_wraps
from crypten.mpc.primitives import ArithmeticSharedTensor
from test.multiprocess_test_case import MultiProcessTestCase, get_random_test_tensor

Expand Down
13 changes: 10 additions & 3 deletions test/test_mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,15 +307,22 @@ def test_div(self):

reference = tensor.float().div(scalar)
encrypted_tensor = MPCTensor(tensor)
encrypted_tensor = getattr(encrypted_tensor, function)(scalar)
encrypted_out = getattr(encrypted_tensor, function)(scalar)
self._check(encrypted_tensor, reference, "scalar division failed")

# multiply denominator by 10 to avoid dividing by small num
divisor = self._get_random_test_tensor(is_float=True, ex_zero=True) * 10
reference = tensor.div(divisor)
encrypted_tensor = MPCTensor(tensor)
encrypted_tensor = getattr(encrypted_tensor, function)(divisor)
self._check(encrypted_tensor, reference, "tensor division failed")
encrypted_out = getattr(encrypted_tensor, function)(divisor)
self._check(encrypted_out, reference, "tensor division failed")

# Test int to float division
tensor = torch.ones((10,))
reference = tensor.div(scalar)
encrypted_tensor = MPCTensor(tensor, precision=0)
encrypted_out = getattr(encrypted_tensor, function)(scalar)
self._check(encrypted_out, reference, "int tensor division failed")

def test_mean(self):
"""Tests computing means of encrypted tensors."""
Expand Down