Skip to content

Commit

Permalink
Switch TFP to use typed JAX PRNG keys.
Browse files Browse the repository at this point in the history
Sanitize the old-style keys to the new ones.

The sanitization is the main benefit here, since otherwise you'd get *very*
inscrutable errors if you passed new-style keys to TFP. The new style keys are
recommended (https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html) so
it makes sense to switch to them.

I also reverted the exception logic in csiszar_divergence, as this change revealed a counter example that made no reference to the seed argument.

PiperOrigin-RevId: 606320203
  • Loading branch information
SiegeLordEx authored and tensorflower-gardener committed Feb 12, 2024
1 parent 23a292a commit b597b1c
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 9 deletions.
10 changes: 7 additions & 3 deletions tensorflow_probability/python/internal/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
def zeros_seed():
if JAX_MODE:
import jax # pylint: disable=g-import-not-at-top
return jax.random.PRNGKey(0)
return jax.random.key(0)
return tf.constant([0, 0], dtype=SEED_DTYPE)


Expand Down Expand Up @@ -143,8 +143,12 @@ def sanitize_seed(seed, salt=None, name=None):
seed = fold_in(seed, salt)

if JAX_MODE:
# Seed must be a jax.PRNGKey -- so just return it.
return seed
import jax # pylint: disable=g-import-not-at-top
# Typed keys are returned as is, otherwise wrap them.
if jax.dtypes.issubdtype(seed.dtype, jax.dtypes.prng_key):
return seed
else:
return jax.random.wrap_key_data(seed)
return tf.convert_to_tensor(seed, dtype=SEED_DTYPE, name='seed')


Expand Down
8 changes: 8 additions & 0 deletions tensorflow_probability/python/internal/samplers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ def setUp(self):
from jax import config # pylint: disable=g-import-not-at-top
config.update('jax_default_prng_impl', FLAGS.test_tfp_jax_prng)

def test_new_style_jax_keys(self):
if not JAX_MODE:
self.skipTest('JAX-only distinction')
import jax # pylint: disable=g-import-not-at-top
seed1 = samplers.sanitize_seed(jax.random.PRNGKey(0))
seed2 = samplers.sanitize_seed(jax.random.key(0))
self.assertAllEqual(seed1, seed2)

@test_util.substrate_disable_stateful_random_test
def test_sanitize_int(self):
seed1 = samplers.sanitize_seed(seed=123)
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_probability/python/internal/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,7 +1532,7 @@ def test_seed(hardcoded_seed=None,
answer = answer % (2**32 - 1)
if JAX_MODE:
import jax # pylint: disable=g-import-not-at-top
answer = jax.random.PRNGKey(answer)
answer = jax.random.key(answer)
else:
answer = tf.constant([0, answer], dtype=tf.uint32)
answer = tf.bitcast(answer, tf.int32)
Expand Down
27 changes: 22 additions & 5 deletions tensorflow_probability/python/vi/csiszar_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import enum
import functools
import traceback
import warnings

# Dependency imports
Expand Down Expand Up @@ -58,13 +59,29 @@


def _call_fn_maybe_with_seed(fn, args, *, seed=None):
"""Try calling `fn` with or without a `seed` arg."""
try:
return nest_util.call_fn(functools.partial(fn, seed=seed), args)
except (TypeError, ValueError) as e:
if ("'seed'" in str(e) or ('one of *args or **kwargs' in str(e))):
return nest_util.call_fn(fn, args)
else:
raise e
except Exception as e1_: # pylint: disable=broad-exception-caught
e1 = e1_

# Don't call this inside the above except, so we don't get e1 to be in the
# context of e2, which is confusing.
try:
return nest_util.call_fn(fn, args)
except Exception as e2: # pylint: disable=broad-exception-caught
# For Python 3.9 compatibility, we call format_exception in this odd way.
tb1 = ''.join(
traceback.format_exception(None, value=e1, tb=e1.__traceback__)
)
tb2 = ''.join(
traceback.format_exception(None, value=e2, tb=e2.__traceback__)
)
raise RuntimeError(
f'Attempted to detect if {fn} requires a `seed`, but failed.\n'
f'Calling it with seed raised:\n\n{tb1}\n\n'
f'Calling it without the seed raised:\n\n{tb2}'
) from None


class GradientEstimators(enum.Enum):
Expand Down
58 changes: 58 additions & 0 deletions tensorflow_probability/python/vi/csiszar_divergence_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Tests for Csiszar divergences."""

import functools
import re

# Dependency imports
from absl.testing import parameterized
Expand Down Expand Up @@ -927,6 +928,63 @@ def target_log_prob_fn(x):
seed=seed)
self.assertAllClose(iwae_loss, loss, atol=0.03)

def test_seeded_target_log_prob_fn(self):
"""Call a tlp_fn that requires a seed."""
def target_log_prob_fn(x, seed):
del x, seed
return 0.

seed = test_util.test_seed(sampler_type='stateless')
cd.monte_carlo_variational_loss(
target_log_prob_fn,
surrogate_posterior=normal.Normal(loc=0.0, scale=1.0),
gradient_estimator=cd.GradientEstimators.REPARAMETERIZATION,
importance_sample_size=1,
sample_size=1,
seed=seed,
)

def test_seeded_target_log_prob_fn_with_seed_error(self):
"""Call a tlp_fn that takes a seed, but errors if it is set."""
def target_log_prob_fn(x, seed=None):
del x
if seed is not None:
raise ValueError('Inscrutable error.')
return 0.

seed = test_util.test_seed(sampler_type='stateless')
cd.monte_carlo_variational_loss(
target_log_prob_fn,
surrogate_posterior=normal.Normal(loc=0.0, scale=1.0),
gradient_estimator=cd.GradientEstimators.REPARAMETERIZATION,
importance_sample_size=1,
sample_size=1,
seed=seed,
)

def test_seeded_target_log_prob_fn_with_impl_error(self):
"""Call a tlp_fn that doesn't take a seed, but fails even without it."""
def target_log_prob_fn(x):
del x
raise ValueError('Implementation mistake.')

seed = test_util.test_seed(sampler_type='stateless')
with self.assertRaisesRegex(
RuntimeError,
re.compile(
r"unexpected keyword argument 'seed'.*Implementation mistake",
re.MULTILINE | re.DOTALL,
),
):
cd.monte_carlo_variational_loss(
target_log_prob_fn,
surrogate_posterior=normal.Normal(loc=0.0, scale=1.0),
gradient_estimator=cd.GradientEstimators.REPARAMETERIZATION,
importance_sample_size=1,
sample_size=1,
seed=seed,
)


@test_util.test_all_tf_execution_regimes
class CsiszarVIMCOTest(test_util.TestCase):
Expand Down

0 comments on commit b597b1c

Please sign in to comment.