diff --git a/crypten/mpc/primitives/arithmetic.py b/crypten/mpc/primitives/arithmetic.py index fccd6b82..399ae2dc 100644 --- a/crypten/mpc/primitives/arithmetic.py +++ b/crypten/mpc/primitives/arithmetic.py @@ -317,7 +317,7 @@ def encode_(self, new_encoder): self.share *= scale_factor else: scale_factor = self.encoder.scale // new_encoder.scale - self = self.div_(scale_factor) + self = self._truncate_(scale=scale_factor) self.encoder = new_encoder return self @@ -390,12 +390,12 @@ def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs): # noqa:C if not additive_func: if public: # scale by self.encoder.scale if self.encoder.scale > 1: - return result.div_(result.encoder.scale) + return result._truncate_() else: result.encoder = self.encoder else: # scale by larger of self.encoder.scale and y.encoder.scale if self.encoder.scale > 1 and y.encoder.scale > 1: - return result.div_(result.encoder.scale) + return result._truncate_() elif self.encoder.scale > 1: result.encoder = self.encoder else: @@ -443,6 +443,19 @@ def div(self, y): result.share = torch.broadcast_tensors(result.share, y)[0].clone() return result.div_(y) + def _truncate_(self, scale=None): + """Rescales the result of a multiplication by dividing the input by the input scale""" + if scale is None: + scale = self.encoder._scale + + # Truncate protocol for dividing by public integers: + if comm.get().get_world_size() > 2: + protocol = globals()[cfg.mpc.protocol] + self.share = protocol.truncate(self, scale).share + else: + self.share = self.share.div_(scale, rounding_mode="trunc") + return self + def div_(self, y): """Divide two tensors element-wise""" # TODO: Add test coverage for this code path (next 4 lines) @@ -458,12 +471,13 @@ def div_(self, y): tolerance = 1.0 tensor = self.get_plain_text() - # Truncate protocol for dividing by public integers: - if comm.get().get_world_size() > 2: - protocol = globals()[cfg.mpc.protocol] - protocol.truncate(self, y) - else: - self.share = self.share.div_(y, rounding_mode="trunc") + # Re-encode if input has low precision + encoder = FixedPointEncoder() + if self.encoder._scale < encoder._scale: + self.encode_(encoder) + + # Use truncate protocol for dividing by public integers: + self._truncate_(scale=y) # Validate if validate: @@ -596,7 +610,8 @@ def neg(self): def square_(self): protocol = globals()[cfg.mpc.protocol] - self.share = protocol.square(self).div_(self.encoder.scale).share + square = protocol.square(self) + self.share = square._truncate_(scale=self.encoder._scale).share return self def square(self): diff --git a/test/test_arithmetic.py b/test/test_arithmetic.py index 6ad6d614..e92deb59 100644 --- a/test/test_arithmetic.py +++ b/test/test_arithmetic.py @@ -11,12 +11,9 @@ import unittest import crypten -import crypten.communicator as comm import torch import torch.nn.functional as F -from crypten.common.rng import generate_random_ring_element from crypten.common.tensor_types import is_float_tensor -from crypten.common.util import count_wraps from crypten.mpc.primitives import ArithmeticSharedTensor from test.multiprocess_test_case import MultiProcessTestCase, get_random_test_tensor diff --git a/test/test_mpc.py b/test/test_mpc.py index 39b79eb3..3aab8076 100644 --- a/test/test_mpc.py +++ b/test/test_mpc.py @@ -307,15 +307,22 @@ def test_div(self): reference = tensor.float().div(scalar) encrypted_tensor = MPCTensor(tensor) - encrypted_tensor = getattr(encrypted_tensor, function)(scalar) + encrypted_out = getattr(encrypted_tensor, function)(scalar) self._check(encrypted_tensor, reference, "scalar division failed") # multiply denominator by 10 to avoid dividing by small num divisor = self._get_random_test_tensor(is_float=True, ex_zero=True) * 10 reference = tensor.div(divisor) encrypted_tensor = MPCTensor(tensor) - encrypted_tensor = getattr(encrypted_tensor, function)(divisor) - self._check(encrypted_tensor, reference, "tensor division failed") + encrypted_out = getattr(encrypted_tensor, function)(divisor) + self._check(encrypted_out, reference, "tensor division failed") + + # Test int to float division + tensor = torch.ones((10,)) + reference = tensor.div(scalar) + encrypted_tensor = MPCTensor(tensor, precision=0) + encrypted_out = getattr(encrypted_tensor, function)(scalar) + self._check(encrypted_out, reference, "int tensor division failed") def test_mean(self): """Tests computing means of encrypted tensors."""