Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 665744848
  • Loading branch information
SiegeLordEx authored and tensorflower-gardener committed Aug 21, 2024
1 parent 4066a2a commit a485298
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
19 changes: 16 additions & 3 deletions tensorflow_probability/python/internal/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ============================================================================
"""Random samplers."""

import contextlib
import hashlib
import warnings

Expand Down Expand Up @@ -48,6 +49,18 @@

SEED_DTYPE = np.uint32 if JAX_MODE else np.int32

_old_salt = False


@contextlib.contextmanager
def enable_old_salt(enable):
global _old_salt
try:
_old_salt = enable
yield
finally:
_old_salt = False


def zeros_seed():
if JAX_MODE:
Expand Down Expand Up @@ -140,9 +153,9 @@ def sanitize_seed(seed, salt=None, name=None):
# discipline of splitting.

if salt is not None:
salt = int(hashlib.sha512(str(salt).encode('utf-8')).hexdigest(), 16) % (
2**31 - 1
)
salt = int(hashlib.sha512(str(salt).encode('utf-8')).hexdigest(), 16)
if not _old_salt:
salt = salt % (2**31 - 1)
seed = fold_in(seed, salt)

if JAX_MODE:
Expand Down
11 changes: 11 additions & 0 deletions tensorflow_probability/python/internal/samplers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from absl import flags
from absl.testing import parameterized
import numpy as np
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf

from tensorflow_probability.python.internal import samplers
Expand All @@ -40,6 +41,16 @@ def setUp(self):
from jax import config # pylint: disable=g-import-not-at-top
config.update('jax_default_prng_impl', FLAGS.test_tfp_jax_prng)

@test_util.substrate_disable_stateful_random_test
def test_old_salt(self):
if not tf1.control_flow_v2_enabled():
self.skipTest('TF2 only.')
with samplers.enable_old_salt(True):
seed = samplers.sanitize_seed(0, salt='nacl')
seed = samplers.sanitize_seed(seed, salt='kcl')
val = samplers.uniform([5], 0, 1000, dtype=tf.int32, seed=seed)
self.assertAllEqual([483, 61, 906, 125, 381], self.evaluate(val))

def test_new_style_jax_keys(self):
if not JAX_MODE:
self.skipTest('JAX-only distinction')
Expand Down

0 comments on commit a485298

Please sign in to comment.