diff --git a/tensorflow_probability/python/internal/samplers.py b/tensorflow_probability/python/internal/samplers.py index c2b696c894..0ce1ff8caf 100644 --- a/tensorflow_probability/python/internal/samplers.py +++ b/tensorflow_probability/python/internal/samplers.py @@ -51,7 +51,7 @@ def zeros_seed(): if JAX_MODE: import jax # pylint: disable=g-import-not-at-top - return jax.random.PRNGKey(0) + return jax.random.key(0) return tf.constant([0, 0], dtype=SEED_DTYPE) @@ -143,8 +143,12 @@ def sanitize_seed(seed, salt=None, name=None): seed = fold_in(seed, salt) if JAX_MODE: - # Seed must be a jax.PRNGKey -- so just return it. - return seed + import jax # pylint: disable=g-import-not-at-top + # Typed keys are returned as is, otherwise wrap them. + if jax.dtypes.issubdtype(seed.dtype, jax.dtypes.prng_key): + return seed + else: + return jax.random.wrap_key_data(seed) return tf.convert_to_tensor(seed, dtype=SEED_DTYPE, name='seed') diff --git a/tensorflow_probability/python/internal/samplers_test.py b/tensorflow_probability/python/internal/samplers_test.py index 3ae5fdfd0e..2aa7009f6d 100644 --- a/tensorflow_probability/python/internal/samplers_test.py +++ b/tensorflow_probability/python/internal/samplers_test.py @@ -40,6 +40,14 @@ def setUp(self): from jax import config # pylint: disable=g-import-not-at-top config.update('jax_default_prng_impl', FLAGS.test_tfp_jax_prng) + def test_new_style_jax_keys(self): + if not JAX_MODE: + self.skipTest('JAX-only distinction') + import jax # pylint: disable=g-import-not-at-top + seed1 = samplers.sanitize_seed(jax.random.PRNGKey(0)) + seed2 = samplers.sanitize_seed(jax.random.key(0)) + self.assertAllEqual(seed1, seed2) + @test_util.substrate_disable_stateful_random_test def test_sanitize_int(self): seed1 = samplers.sanitize_seed(seed=123) diff --git a/tensorflow_probability/python/internal/test_util.py b/tensorflow_probability/python/internal/test_util.py index 0da39d05b5..f148f00ce6 100644 --- a/tensorflow_probability/python/internal/test_util.py +++ b/tensorflow_probability/python/internal/test_util.py @@ -1532,7 +1532,7 @@ def test_seed(hardcoded_seed=None, answer = answer % (2**32 - 1) if JAX_MODE: import jax # pylint: disable=g-import-not-at-top - answer = jax.random.PRNGKey(answer) + answer = jax.random.key(answer) else: answer = tf.constant([0, answer], dtype=tf.uint32) answer = tf.bitcast(answer, tf.int32) diff --git a/tensorflow_probability/python/vi/csiszar_divergence.py b/tensorflow_probability/python/vi/csiszar_divergence.py index e8aec7504c..925fb38db0 100644 --- a/tensorflow_probability/python/vi/csiszar_divergence.py +++ b/tensorflow_probability/python/vi/csiszar_divergence.py @@ -16,6 +16,7 @@ import enum import functools +import traceback import warnings # Dependency imports @@ -58,13 +59,29 @@ def _call_fn_maybe_with_seed(fn, args, *, seed=None): + """Try calling `fn` with or without a `seed` arg.""" try: return nest_util.call_fn(functools.partial(fn, seed=seed), args) - except (TypeError, ValueError) as e: - if ("'seed'" in str(e) or ('one of *args or **kwargs' in str(e))): - return nest_util.call_fn(fn, args) - else: - raise e + except Exception as e1_: # pylint: disable=broad-exception-caught + e1 = e1_ + + # Don't call this inside the above except, so we don't get e1 to be in the + # context of e2, which is confusing. + try: + return nest_util.call_fn(fn, args) + except Exception as e2: # pylint: disable=broad-exception-caught + # For Python 3.9 compatibility, we call format_exception in this odd way. + tb1 = ''.join( + traceback.format_exception(None, value=e1, tb=e1.__traceback__) + ) + tb2 = ''.join( + traceback.format_exception(None, value=e2, tb=e2.__traceback__) + ) + raise RuntimeError( + f'Attempted to detect if {fn} requires a `seed`, but failed.\n' + f'Calling it with seed raised:\n\n{tb1}\n\n' + f'Calling it without the seed raised:\n\n{tb2}' + ) from None class GradientEstimators(enum.Enum): diff --git a/tensorflow_probability/python/vi/csiszar_divergence_test.py b/tensorflow_probability/python/vi/csiszar_divergence_test.py index 5862e6c4b5..eff6bf2215 100644 --- a/tensorflow_probability/python/vi/csiszar_divergence_test.py +++ b/tensorflow_probability/python/vi/csiszar_divergence_test.py @@ -15,6 +15,7 @@ """Tests for Csiszar divergences.""" import functools +import re # Dependency imports from absl.testing import parameterized @@ -927,6 +928,63 @@ def target_log_prob_fn(x): seed=seed) self.assertAllClose(iwae_loss, loss, atol=0.03) + def test_seeded_target_log_prob_fn(self): + """Call a tlp_fn that requires a seed.""" + def target_log_prob_fn(x, seed): + del x, seed + return 0. + + seed = test_util.test_seed(sampler_type='stateless') + cd.monte_carlo_variational_loss( + target_log_prob_fn, + surrogate_posterior=normal.Normal(loc=0.0, scale=1.0), + gradient_estimator=cd.GradientEstimators.REPARAMETERIZATION, + importance_sample_size=1, + sample_size=1, + seed=seed, + ) + + def test_seeded_target_log_prob_fn_with_seed_error(self): + """Call a tlp_fn that takes a seed, but errors if it is set.""" + def target_log_prob_fn(x, seed=None): + del x + if seed is not None: + raise ValueError('Inscrutable error.') + return 0. + + seed = test_util.test_seed(sampler_type='stateless') + cd.monte_carlo_variational_loss( + target_log_prob_fn, + surrogate_posterior=normal.Normal(loc=0.0, scale=1.0), + gradient_estimator=cd.GradientEstimators.REPARAMETERIZATION, + importance_sample_size=1, + sample_size=1, + seed=seed, + ) + + def test_seeded_target_log_prob_fn_with_impl_error(self): + """Call a tlp_fn that doesn't take a seed, but fails even without it.""" + def target_log_prob_fn(x): + del x + raise ValueError('Implementation mistake.') + + seed = test_util.test_seed(sampler_type='stateless') + with self.assertRaisesRegex( + RuntimeError, + re.compile( + r"unexpected keyword argument 'seed'.*Implementation mistake", + re.MULTILINE | re.DOTALL, + ), + ): + cd.monte_carlo_variational_loss( + target_log_prob_fn, + surrogate_posterior=normal.Normal(loc=0.0, scale=1.0), + gradient_estimator=cd.GradientEstimators.REPARAMETERIZATION, + importance_sample_size=1, + sample_size=1, + seed=seed, + ) + @test_util.test_all_tf_execution_regimes class CsiszarVIMCOTest(test_util.TestCase):