How to avoid double random.fold_in when indexing into two different sources? #15240

Answered by froystig
cgarciae asked this question in Q&A

I suspect that iterating fold_in (the hash) is often fine, and won't often present a bottleneck. If this does need optimizing, we could have some fun:

You could use a pairing function, which enumerates the integer grid, and compose it with fold_in. Here is an example using Cantor's pairing function:

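import jax

def cantor(a, b):
  # Shift to 1-based indices, then enumerate the grid along anti-diagonals.
  a, b = a + 1, b + 1
  return (a * a + 2 * a * b + b * b - a - 3 * b + 2) // 2

def fold_in2(key, a, b):
  # One fold_in over the paired index replaces two nested fold_in calls.
  return jax.random.fold_in(key, cantor(a, b))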

Here's how the enumeration looks:

>>> import numpy as np
>>> np.array([cantor(i, j) for i in range(5) for j in range(5)]).reshape(5, 5)
array([[ 1,  2,  4,  7, 11],
       [ 3,  5,  8, 12, 17],
       [ 6,  9, 13, 18, 24],
  …
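
As a rough sanity check of the idea above, here is a minimal sketch that compares the single-call `fold_in2` with the nested version it is meant to replace. The names `fold_in_nested`, `n`, `codes`, and `keys` are illustrative and not from the thread, and the sketch assumes the default PRNG implementation with a legacy `jax.random.PRNGKey` key. It only checks that distinct index pairs map to distinct integers and distinct derived keys on a small grid; it says nothing about speed.

```python
import jax
import numpy as np

def cantor(a, b):
    # Same pairing function as in the answer: 1-based Cantor enumeration.
    a, b = a + 1, b + 1
    return (a * a + 2 * a * b + b * b - a - 3 * b + 2) // 2

def fold_in2(key, a, b):
    # Single hash call over the paired index.
    return jax.random.fold_in(key, cantor(a, b))

def fold_in_nested(key, a, b):
    # Hypothetical baseline: the "double fold_in" from the question title.
    return jax.random.fold_in(jax.random.fold_in(key, a), b)

key = jax.random.PRNGKey(0)
n = 8  # small illustrative grid size

# Every (i, j) pair should receive its own integer code.
codes = {cantor(i, j) for i in range(n) for j in range(n)}

# Every (i, j) pair should therefore receive its own derived key.
keys = {
    tuple(np.asarray(fold_in2(key, i, j)).tolist())
    for i in range(n)
    for j in range(n)
}

print(len(codes), len(keys))
```

Note that the paired value grows roughly quadratically in `a + b`, and `jax.random.fold_in` documents its `data` argument as a 32-bit integer, so this approach is probably only safe while `cantor(a, b)` stays within that range.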

Replies: 2 comments 7 replies

cgarciae Mar 29, 2023
Collaborator Author

Answer selected by cgarciae