diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index f6cd8bc0da51..e9d01eae0e83 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -18,6 +18,7 @@ """ from enum import Enum +from typing import Any import numpy as np from functools import partial @@ -29,14 +30,30 @@ standard_naryop, standard_unop, sub, _const, _dtype, _float, _nary_lower_hlo, _ones, _isnan, _reduce) -from jax._src.lax.control_flow import while_loop +from jax._src.lax.control_flow import cond, scan, while_loop +from jax._src import api from jax._src import dtypes from jax._src.interpreters import ad from jax._src.interpreters import mlir from jax._src.lib.mlir.dialects import chlo from jax._src.typing import Array, ArrayLike +def _while_loop_scan(cond_fun, body_fun, init_val, max_iter): + """Scan-based implementation (jit ok, reverse-mode autodiff ok).""" + def _iter(val): + next_val = body_fun(val) + next_cond = cond_fun(next_val) + return next_val, next_cond + + def _fun(tup, it): + val, _cond = tup + # When _cond is met, we start doing no-ops. + return cond(_cond, _iter, lambda x: (x, False), val), it + + init = (init_val, cond_fun(init_val)) + return scan(_fun, init, None, length=max_iter)[0][0] + def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise regularized incomplete beta integral.""" return regularized_incomplete_beta_p.bind(a, b, x) @@ -250,7 +267,7 @@ def _any(predicates: Array) -> Array: all_dimensions = tuple(range(len(predicates_shape))) return reduce(predicates, f, bitwise_or, all_dimensions) -def _igamma_series(ax, x, a, enabled, dtype, mode): +def _igamma_series(ax, x, a, enabled, dtype, mode, *, hessian: bool = False): def cond_fn(vals): return _any(vals[0]) @@ -285,7 +302,9 @@ def body_fn(vals): full_like(a, 0), ) - vals = while_loop(cond_fn, body_fn, init_vals) + vals = (_while_loop_scan(cond_fn, body_fn, init_vals, 256) + if hessian + else while_loop(cond_fn, body_fn, init_vals)) ans = vals[3] dans_da = vals[6] @@ -327,7 +346,9 @@ def igamma_impl(a, x, *, dtype): full_like(a, float('nan')), output) return output -def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode): +def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode, + *, + hessian: bool = False): eps = dtypes.finfo(dtype).eps def cond_fn(vals): @@ -418,7 +439,9 @@ def body_fn(vals): c, pkm1, qkm1, pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da) - vals = while_loop(cond_fn, body_fn, init_vals) + vals = (_while_loop_scan(cond_fn, body_fn, init_vals, 256) + if hessian + else while_loop(cond_fn, body_fn, init_vals)) ans = vals[1] if mode == IgammaMode.VALUE: return ans * ax @@ -470,7 +493,12 @@ def igamma_grad_a_impl(a, x, *, dtype): full_like(a, float('nan')), output) return output -def random_gamma_grad_impl(a, x, *, dtype): +def random_gamma_grad_impl(a: Array, + x: Array, + *, + dtype: Any, + hessian: bool = False + ) -> Array: is_nan = bitwise_or(_isnan(a), _isnan(x)) x_is_zero = eq(x, full_like(x,0)) domain_error = bitwise_or(lt(x, full_like(x,0)), le(a, full_like(a,0))) @@ -480,11 +508,13 @@ def random_gamma_grad_impl(a, x, *, dtype): ax = exp(ax) enabled = bitwise_not(bitwise_or(bitwise_or(bitwise_or (x_is_zero, domain_error), underflow), is_nan)) - output = select(use_igammac, - -_igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac), - dtype, IgammaMode.SAMPLE_DERIVATIVE), - _igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)), - dtype, IgammaMode.SAMPLE_DERIVATIVE)) + output = select( + use_igammac, + -_igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac), + dtype, IgammaMode.SAMPLE_DERIVATIVE, + hessian=hessian), + _igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)), + dtype, IgammaMode.SAMPLE_DERIVATIVE, hessian=hessian)) output = select(x_is_zero, full_like(output,0), output) output = select(bitwise_or(domain_error, is_nan), full_like(a, float('nan')), output) @@ -653,10 +683,21 @@ def bessel_i0e_impl(x): ad.defjvp(igammac_p, igammac_grada, igammac_gradx) +def random_gamma_hessian_a(g, a, x, *, dtype): + return api.grad(random_gamma_grad_impl, argnums=0)(a, x, dtype=dtype, + hessian=True) + +def random_gamma_hessian_x(g, a, x, *, dtype): + return api.grad(random_gamma_grad_impl, argnums=1)(a, x, dtype=dtype, + hessian=True) + random_gamma_grad_p = standard_naryop([_float, _float], 'random_gamma_grad') mlir.register_lowering(random_gamma_grad_p, mlir.lower_fun(_up_and_broadcast(random_gamma_grad_impl), multiple_results=False)) +ad.defjvp(random_gamma_grad_p, + _up_and_broadcast(random_gamma_hessian_a), + _up_and_broadcast(random_gamma_hessian_x)) zeta_p = standard_naryop([_float, _float], 'zeta') mlir.register_lowering(zeta_p, partial(_nary_lower_hlo, chlo.zeta)) diff --git a/tests/random_test.py b/tests/random_test.py index 941172f75278..dd2cb517d655 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1473,5 +1473,18 @@ def f(): jax.random.normal(jax.random.key(0), 1000) f() # don't crash +class SamplingDerivativeTest(jtu.JaxTestCase): + def test_gamma_hessian(self): + # Regression test for https://github.com/google/jax/issues/16076 + def hessian_sample(key: jax.Array) -> jax.Array: + ((retval,),) = jax.hessian(random.gamma, argnums=(1,))(key, 0.8) + return retval + + keys = random.split(random.key(0), 300) + x = jax.vmap(hessian_sample)(keys) + mean_x = jnp.mean(x, axis=-1) + self.assertArraysAllClose(mean_x, jnp.asarray(0.61), atol=0.1, rtol=0.4) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())