diff --git a/tensorflow_probability/python/distributions/autoregressive.py b/tensorflow_probability/python/distributions/autoregressive.py index 038d7524fe..7188589353 100644 --- a/tensorflow_probability/python/distributions/autoregressive.py +++ b/tensorflow_probability/python/distributions/autoregressive.py @@ -294,11 +294,18 @@ def _sample_n(self, n, seed=None): if num_steps_static is not None: for _ in range(num_steps_static): # pylint: disable=not-callable - samples = self.distribution_fn(samples).sample(seed=seed) + samples = self.distribution_fn(samples).sample( + seed=samplers.clone_seed(seed) + ) else: # pylint: disable=not-callable - samples = tf.foldl(lambda s, _: self.distribution_fn(s).sample(seed=seed), - elems=tf.range(0, num_steps), initializer=samples) + samples = tf.foldl( + lambda s, _: self.distribution_fn(s).sample( + seed=samplers.clone_seed(seed) + ), + elems=tf.range(0, num_steps), + initializer=samples, + ) return samples def _log_prob(self, value): diff --git a/tensorflow_probability/python/internal/samplers.py b/tensorflow_probability/python/internal/samplers.py index cbb076024d..3e5da63ebc 100644 --- a/tensorflow_probability/python/internal/samplers.py +++ b/tensorflow_probability/python/internal/samplers.py @@ -31,6 +31,7 @@ __all__ = [ 'categorical', + 'clone_seed', 'fold_in', 'gamma', 'is_stateful_seed', @@ -229,6 +230,16 @@ def split_seed(seed, n=2, salt=None, name=None): return seeds +def clone_seed(seed): + """Clones a seed so it can be reused without causing a JAX KeyReuseError.""" + if JAX_MODE: + from jax import random as jaxrand # pylint: disable=g-import-not-at-top + if hasattr(jaxrand, 'clone'): + # JAX v0.4.26+ + return jaxrand.clone(seed) + return seed + + def categorical( logits, num_samples,