
vmap function with NamedTuple as input #16641

Answered by jakevdp
Dalouvid asked this question in Q&A

The in_axes argument must match the pytree structure of the function's arguments. Since you're passing the inputs as a named tuple, in_axes should contain a matching named tuple, with one axis (or None) per field:

vapply = jax.vmap(tup_net, in_axes=(None, Input(0, 0, None)))

This change then leads to a different error caused by misusing Flax modules, but it should answer your question about the vmap issue.
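A minimal, self-contained sketch of the same pattern, without Flax. The `Input` fields (`x`, `y`, `scale`) and the body of `tup_net` are hypothetical, since the original definitions aren't shown; only the `in_axes=(None, Input(0, 0, None))` structure is taken from the answer above.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


# Hypothetical input container: field names and shapes are assumptions,
# since the question's own definition of `Input` is not shown.
class Input(NamedTuple):
    x: jnp.ndarray      # batched: shape (batch, 3)
    y: jnp.ndarray      # batched: shape (batch,)
    scale: jnp.ndarray  # shared scalar, not mapped over the batch


def tup_net(params, inp):
    # Hypothetical per-example computation: inp.x has shape (3,), inp.y is a scalar.
    return jnp.dot(params, inp.x) * inp.scale + inp.y


params = jnp.array([1.0, 2.0, 3.0])
batch = Input(x=jnp.ones((4, 3)), y=jnp.arange(4.0), scale=jnp.asarray(2.0))

# in_axes mirrors the argument structure: params is unmapped (None), and the
# NamedTuple argument gets a NamedTuple of axes, one entry per field.
vapply = jax.vmap(tup_net, in_axes=(None, Input(0, 0, None)))
out = vapply(params, batch)
print(out.shape)  # (4,)
```

Here `params` and `scale` are shared by every example (axis `None`), while `x` and `y` are mapped over their leading axis, so `out` holds one result per batch element.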

Answer selected by Dalouvid