From 303a990fc5b24fdbb184eac7de9c55f01bf8e0ec Mon Sep 17 00:00:00 2001 From: Brian Knott Date: Thu, 20 Jan 2022 16:31:33 -0800 Subject: [PATCH] Fix bug in integer division Summary: In certain cases, when inputs are integers, division will statistically round rather than converting to floating point division. This happens when the encoder scale is set to 1, which will happen after a comparison operator (e.g. sign() etc.) I found this in the linear_svm example during evaluation as we compute `accuracy = correct.add(1).div(2).mean()`. The `div(2)` will round because the input encoder has scale 1 due to the sign in the SVM code. This diff corrects the issue and adds this edge case to testing. Differential Revision: D33686267 fbshipit-source-id: 06da5d3986120d8bf64b2db438586fcf3e4d9a6f --- crypten/mpc/primitives/arithmetic.py | 35 ++++++++++++++++++++-------- test/test_arithmetic.py | 3 --- test/test_mpc.py | 13 ++++++++--- 3 files changed, 35 insertions(+), 16 deletions(-) diff --git a/crypten/mpc/primitives/arithmetic.py b/crypten/mpc/primitives/arithmetic.py index fccd6b82..399ae2dc 100644 --- a/crypten/mpc/primitives/arithmetic.py +++ b/crypten/mpc/primitives/arithmetic.py @@ -317,7 +317,7 @@ def encode_(self, new_encoder): self.share *= scale_factor else: scale_factor = self.encoder.scale // new_encoder.scale - self = self.div_(scale_factor) + self = self._truncate_(scale=scale_factor) self.encoder = new_encoder return self @@ -390,12 +390,12 @@ def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs): # noqa:C if not additive_func: if public: # scale by self.encoder.scale if self.encoder.scale > 1: - return result.div_(result.encoder.scale) + return result._truncate_() else: result.encoder = self.encoder else: # scale by larger of self.encoder.scale and y.encoder.scale if self.encoder.scale > 1 and y.encoder.scale > 1: - return result.div_(result.encoder.scale) + return result._truncate_() elif self.encoder.scale > 1: result.encoder = self.encoder 
else: @@ -443,6 +443,19 @@ def div(self, y): result.share = torch.broadcast_tensors(result.share, y)[0].clone() return result.div_(y) + def _truncate_(self, scale=None): + """Rescales the result of a multiplication by dividing the input by the input scale""" + if scale is None: + scale = self.encoder._scale + + # Truncate protocol for dividing by public integers: + if comm.get().get_world_size() > 2: + protocol = globals()[cfg.mpc.protocol] + self.share = protocol.truncate(self, scale).share + else: + self.share = self.share.div_(scale, rounding_mode="trunc") + return self + def div_(self, y): """Divide two tensors element-wise""" # TODO: Add test coverage for this code path (next 4 lines) @@ -458,12 +471,13 @@ def div_(self, y): tolerance = 1.0 tensor = self.get_plain_text() - # Truncate protocol for dividing by public integers: - if comm.get().get_world_size() > 2: - protocol = globals()[cfg.mpc.protocol] - protocol.truncate(self, y) - else: - self.share = self.share.div_(y, rounding_mode="trunc") + # Re-encode if input has low precision + encoder = FixedPointEncoder() + if self.encoder._scale < encoder._scale: + self.encode_(encoder) + + # Use truncate protocol for dividing by public integers: + self._truncate_(scale=y) # Validate if validate: @@ -596,7 +610,8 @@ def neg(self): def square_(self): protocol = globals()[cfg.mpc.protocol] - self.share = protocol.square(self).div_(self.encoder.scale).share + square = protocol.square(self) + self.share = square._truncate_(scale=self.encoder._scale).share return self def square(self): diff --git a/test/test_arithmetic.py b/test/test_arithmetic.py index 6ad6d614..e92deb59 100644 --- a/test/test_arithmetic.py +++ b/test/test_arithmetic.py @@ -11,12 +11,9 @@ import unittest import crypten -import crypten.communicator as comm import torch import torch.nn.functional as F -from crypten.common.rng import generate_random_ring_element from crypten.common.tensor_types import is_float_tensor -from crypten.common.util import 
count_wraps from crypten.mpc.primitives import ArithmeticSharedTensor from test.multiprocess_test_case import MultiProcessTestCase, get_random_test_tensor diff --git a/test/test_mpc.py b/test/test_mpc.py index 39b79eb3..3aab8076 100644 --- a/test/test_mpc.py +++ b/test/test_mpc.py @@ -307,15 +307,22 @@ def test_div(self): reference = tensor.float().div(scalar) encrypted_tensor = MPCTensor(tensor) - encrypted_tensor = getattr(encrypted_tensor, function)(scalar) - self._check(encrypted_tensor, reference, "scalar division failed") + encrypted_out = getattr(encrypted_tensor, function)(scalar) + self._check(encrypted_out, reference, "scalar division failed") # multiply denominator by 10 to avoid dividing by small num divisor = self._get_random_test_tensor(is_float=True, ex_zero=True) * 10 reference = tensor.div(divisor) encrypted_tensor = MPCTensor(tensor) - encrypted_tensor = getattr(encrypted_tensor, function)(divisor) - self._check(encrypted_tensor, reference, "tensor division failed") + encrypted_out = getattr(encrypted_tensor, function)(divisor) + self._check(encrypted_out, reference, "tensor division failed") + + # Test int to float division + tensor = torch.ones((10,)) + reference = tensor.div(scalar) + encrypted_tensor = MPCTensor(tensor, precision=0) + encrypted_out = getattr(encrypted_tensor, function)(scalar) + self._check(encrypted_out, reference, "int tensor division failed") def test_mean(self): """Tests computing means of encrypted tensors."""