Skip to content

Commit

Permalink
Add softplus inverse to rules.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 557920930
  • Loading branch information
ColCarroll authored and The oryx Authors committed Aug 17, 2023
1 parent 0a92aae commit 4ab2815
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
9 changes: 9 additions & 0 deletions oryx/core/interpreters/inverse/inverse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,15 @@ def f(x):
np.testing.assert_allclose(x, jnp.ones((2, 2)))
np.testing.assert_allclose(ildj_, 0.)

def test_softplus_inverse_ildj(self):
softplus_inv = core.inverse_and_ildj(jax.nn.softplus)
softplus_bij = tfb.Softplus()
x, ildj = softplus_inv(0.1)
np.testing.assert_allclose(x,
softplus_bij.inverse(0.1))
np.testing.assert_allclose(ildj,
softplus_bij.inverse_log_det_jacobian(0.1))

def test_sigmoid_ildj(self):

def naive_sigmoid(x):
Expand Down
11 changes: 11 additions & 0 deletions oryx/core/interpreters/inverse/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,14 @@ def logit_ildj(y):
return -jax.nn.softplus(-y) - jax.nn.softplus(y)


def softplus_inv(y):
return np.log(-np.expm1(-y)) + y


def softplus_ildj(y):
return -np.log1p(-np.exp(-y))


def convert_element_type_ildj(incells, outcells, *, new_dtype, **params):
"""InverseAndILDJ rule for convert_element_type primitive."""
incell, = incells
Expand All @@ -249,8 +257,11 @@ def convert_element_type_ildj(incells, outcells, *, new_dtype, **params):
jax.scipy.special.expit = custom_inverse(jax.scipy.special.expit)
jax.scipy.special.logit = custom_inverse(jax.scipy.special.logit)
jax.nn.sigmoid = jax.scipy.special.expit
jax.nn.softplus = custom_inverse(jax.nn.softplus)
jax.scipy.special.expit.def_inverse_unary(f_inv=jax.scipy.special.logit,
f_ildj=expit_ildj)
jax.scipy.special.logit.def_inverse_unary(f_inv=jax.scipy.special.expit,
f_ildj=logit_ildj)
jax.nn.softplus.def_inverse_unary(f_inv=softplus_inv,
f_ildj=softplus_ildj)

0 comments on commit 4ab2815

Please sign in to comment.