Skip to content

Commit

Permalink
Add direct computation as a method for computing sigmoid
Browse files Browse the repository at this point in the history
Summary:
Add an option to compute sigmoid using (1 + exp(-x)) ^ -1

This is faster but not quite as accurate: P128130956.

Reviewed By: lvdmaaten

Differential Revision: D20746654

fbshipit-source-id: 2404090be2b1fdf8b5f38cfdff77eb05c3c00b64
  • Loading branch information
knottb authored and facebook-github-bot committed Apr 14, 2020
1 parent bc139a8 commit 2389544
Showing 1 changed file with 77 additions and 37 deletions.
114 changes: 77 additions & 37 deletions crypten/mpc/mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,26 +626,53 @@ def where(self, condition, y):

# Logistic Functions
@mode(Ptype.arithmetic)
def sigmoid(self, maxval_tanh=6, terms_tanh=32, reciprocal_method=None):
"""Computes the sigmoid function as
sigmoid(x) = (tanh(x /2) + 1) / 2
Args:
maxval_tanh (int): interval width used for tanh chebyshev polynomials
terms_tanh (int): highest degree of Chebyshev polynomials for tanh.
Must be even and at least 6.
"""
if reciprocal_method:
warnings.warn(
"reciprocal_method is deprecated in favor of Chebyshev approximations",
DeprecationWarning,
)
def sigmoid(self, method="reciprocal", terms=32):
r"""Computes the sigmoid function using the following definition
.. math::
\sigma(x) = (1 + e^{-x})^{-1}
tanh_approx = self.div(2).tanh(maxval=maxval_tanh, terms=terms_tanh)
return tanh_approx.div(2) + 0.5
If a valid method is given, this function will compute sigmoid
using that method:
"chebyshev" - computes tanh via Chebyshev approximation with
truncation and uses the identity:
.. math::
\sigma(x) = \frac{1}{2}tanh(\frac{x}{2}) + \frac{1}{2}
Args:
terms (int): highest degree of Chebyshev polynomials for tanh
using Chebyshev approximation. Must be even and at least 6.
""" # noqa: W605
if method == "chebyshev":
tanh_approx = self.div(2).tanh(method=method, terms=terms)
return tanh_approx.div(2) + 0.5
elif method == "reciprocal":
ltz = self._ltz(_scale=True)
sign = 1 - 2 * ltz

input = self.mul(sign)
denominator = input.neg().exp(iterations=9).add(1)

pos_output = denominator.reciprocal(nr_iters=3, all_pos=True, initial=0.75)
result = pos_output * (1 - ltz) + ltz * (1 - pos_output)
# TODO: Support addition with different encoder scales
# result = pos_output + ltz - 2 * pos_output * ltz
return result
else:
raise ValueError(f"Unrecognized method {method} for sigmoid")

@mode(Ptype.arithmetic)
def tanh(self, maxval=6, terms=32, reciprocal_method=None):
r"""Computes tanh via Chebyshev approximation with truncation.
def tanh(self, method="reciprocal", terms=32):
r"""Computes the hyperbolic tangent function using the identity
.. math::
tanh(x) = 2\sigma(2x) - 1
If a valid method is given, this function will compute tanh using that method:
"chebyshev" - computes tanh via Chebyshev approximation with truncation.
.. math::
tanh(x) = \sum_{j=1}^{terms} c_{2j - 1} P_{2j - 1} (x / maxval)
Expand All @@ -654,25 +681,27 @@ def tanh(self, maxval=6, terms=32, reciprocal_method=None):
The approximation is truncated to +/-1 outside [-maxval, maxval].
Args:
maxval (int): interval width used for computing chebyshev polynomials
terms (int): highest degree of Chebyshev polynomials.
Must be even and at least 6.
"""
if reciprocal_method:
warnings.warn(
"reciprocal_method is deprecated in favor of Chebyshev approximations",
DeprecationWarning,
if method == "reciprocal":
return self.mul(2).sigmoid(method=method).mul(2).sub(1)
elif method == "chebyshev":
maxval = 6
coeffs = crypten.common.util.chebyshev_series(torch.tanh, maxval, terms)[
1::2
]
tanh_polys = self.div(maxval)._chebyshev_polynomials(terms)
tanh_polys_flipped = (
tanh_polys.unsqueeze(dim=-1).transpose(0, -1).squeeze(dim=0)
)
out = tanh_polys_flipped.matmul(coeffs)

coeffs = crypten.common.util.chebyshev_series(torch.tanh, maxval, terms)[1::2]
tanh_polys = self.div(maxval)._chebyshev_polynomials(terms)
tanh_polys_flipped = (
tanh_polys.unsqueeze(dim=-1).transpose(0, -1).squeeze(dim=0)
)
out = tanh_polys_flipped.matmul(coeffs)
# truncate outside [-maxval, maxval]
out = self._truncate_tanh(maxval, out)
return out
# truncate outside [-maxval, maxval]
out = self._truncate_tanh(maxval, out)
return out
else:
raise ValueError(f"Unrecognized method {method} for tanh")

def _truncate_tanh(self, maxval, out):
"""Truncates `out` to +/-1 when self is outside [-maxval, maxval].
Expand Down Expand Up @@ -818,12 +847,14 @@ def log(self, iterations=2, exp_iterations=8, order=8):
y -= h.polynomial([1 / (i + 1) for i in range(order)])
return y

def reciprocal(self, method="NR", nr_iters=10, log_iters=1, all_pos=False):
def reciprocal(
self, method="NR", nr_iters=10, log_iters=1, all_pos=False, initial=None
):
"""
Methods:
'NR' : `Newton-Raphson`_ method computes the reciprocal using iterations
of :math:`x_{i+1} = (2x_i - self * x_i^2)` and uses
:math:`3*exp(-(x-.5)) + 0.003` as an initial guess
:math:`3*exp(-(x-.5)) + 0.003` as an initial guess by default
'log' : Computes the reciprocal of the input from the observation that:
:math:`x^{-1} = exp(-log(x))`
Expand All @@ -836,6 +867,9 @@ def reciprocal(self, method="NR", nr_iters=10, log_iters=1, all_pos=False):
all_pos (bool): determines whether all elements
of the input are known to be positive, which optimizes
the step of computing the sign of the input.
initial (tensor): sets the initial value for the Newton-Raphson method. By
default, this will be set to :math:`3*exp(-(x-.5)) + 0.003` as
this allows the method to converge over a fairly large domain.
.. _Newton-Raphson:
https://en.wikipedia.org/wiki/Newton%27s_method
Expand All @@ -849,11 +883,17 @@ def reciprocal(self, method="NR", nr_iters=10, log_iters=1, all_pos=False):
return sgn * rec

if method == "NR":
# Initialization to a decent estimate (found by qualitative inspection):
# 1/x = 3exp(.5 - x) + 0.003
result = 3 * (0.5 - self).exp() + 0.003
if initial is None:
# Initialization to a decent estimate (found by qualitative inspection):
# 1/x = 3exp(.5 - x) + 0.003
result = 3 * (0.5 - self).exp() + 0.003
else:
result = initial
for _ in range(nr_iters):
result += result - result.square().mul_(self)
if isinstance(result, MPCTensor):
result += result - result.square().mul_(self)
else:
result = 2 * result - result * result * self
return result
elif method == "log":
return (-self.log(iterations=log_iters)).exp()
Expand Down

0 comments on commit 2389544

Please sign in to comment.