PRNGKey handling with sharding + jit #24479
giovannicemin asked this question in Ideas
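Wrapping the function with `jax.vmap`, so that `jax.random.split` only ever sees a single key:

```python
from functools import partial
import jax

# random_choice: the per-key sampler, presumably from the question's example (not shown).
@partial(jax.vmap, in_axes=(0, None, None))  # map over keys only; a and b are shared
def f(key, a, b):
    k1, k2 = jax.random.split(key)  # inside vmap, key is a single key
    return random_choice(k1, a), random_choice(k2, b)
```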
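A self-contained usage sketch of the vmap pattern above. `random_choice` is a hypothetical stand-in built on `jax.random.choice`, and the arrays `a` and `b` are made up for illustration:

```python
from functools import partial

import jax
import jax.numpy as jnp


def random_choice(key, x):
    # Hypothetical helper: draw one element of x uniformly at random.
    return jax.random.choice(key, x)


@partial(jax.vmap, in_axes=(0, None, None))  # map over keys; a and b are shared
def f(key, a, b):
    k1, k2 = jax.random.split(key)  # a single key here, so split works
    return random_choice(k1, a), random_choice(k2, b)


keys = jax.random.split(jax.random.PRNGKey(0), 8)  # batch of 8 keys
a = jnp.arange(10)
b = jnp.arange(100, 120)

xs, ys = jax.jit(f)(keys, a, b)  # one draw from a and one from b per key
print(xs.shape, ys.shape)
```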
Hello everyone,
I recently ran into some issues handling random keys inside functions with sharded inputs. The main challenge is that `jax.random.split` only accepts a single key, while with sharding the function receives batched data, i.e. an array of keys. I came up with a workaround that I'd like to share:
If this solution works for you, great!
If you have a better approach, I'd love to hear your suggestions.
A minimal example is:
This throws the following error:

```
ValueError: split accepts a single key, but was given a key array of shape (8,) != (). Use jax.vmap for batching.
```
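A minimal sketch (not the original example) that should trigger the same error by passing a batch of 8 keys straight to `jax.random.split`. The exception type has varied across JAX versions (`ValueError` or `TypeError`), so the sketch catches both:

```python
import jax

# A batch of 8 legacy uint32 keys, shape (8, 2), i.e. a key array of shape (8,).
keys = jax.random.split(jax.random.PRNGKey(0), 8)

try:
    jax.random.split(keys)  # split expects a single key, not a batch
    message = None
except (ValueError, TypeError) as err:
    message = str(err)

print(message)
```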
The solution (or workaround) that I found is to replace `jax.random.split` with:

In this way, the above code outputs:
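The replacement used in the question is not shown. One common option, following the hint in the error message itself ("Use jax.vmap for batching"), is to map `jax.random.split` over the batch axis of the keys. The sketch below assumes a 1-D device mesh with a hypothetical axis name `"batch"` and a batch size chosen to divide evenly across the available devices:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n_dev = jax.device_count()
batch = 2 * n_dev  # hypothetical batch size, divisible by the device count

keys = jax.random.split(jax.random.PRNGKey(0), batch)  # shape (batch, 2)

# Shard the leading (batch) axis of the keys across all devices.
mesh = Mesh(np.array(jax.devices()), ("batch",))
keys = jax.device_put(keys, NamedSharding(mesh, P("batch")))


@jax.jit
def split_each(keys):
    # vmap applies split to every key in the batch independently.
    return jax.vmap(jax.random.split)(keys)


subkeys = split_each(keys)            # shape (batch, 2, 2): 2 subkeys per key
k1, k2 = subkeys[:, 0], subkeys[:, 1]  # each of shape (batch, 2)
print(subkeys.shape)
```

Each row of `subkeys` matches what `jax.random.split` would return for the corresponding single key, so `k1` and `k2` can be used exactly like the two outputs of an unbatched split.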