
TypeError: Using a non-tuple sequence for multidimensional indexing is not allowed #15087

Answered by jakevdp
ShawonAshraf asked this question in Q&A

Could you edit your question to include full code that reproduces the issue (i.e., a minimal reproducible example)? Here's what I tried, making my best guess at what the function inputs might look like; I was unable to reproduce the error you report:

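import jax
import jax.numpy as jnp

# `forward` is the function from the original question (not shown here)
params = {
    'w1': jnp.ones((5, 4)),
    'w2': jnp.ones((4, 3)),
    'b1': 1.0,
    'b2': 1.0
}
features_train = jnp.ones((4, 5))
# Map over the rows of features_train; params is shared (not mapped)
jax.vmap(forward, in_axes=(None, 0))(params, features_train)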
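Since the question's `forward` isn't shown, here is a self-contained sketch that fills it in with a hypothetical two-layer MLP, chosen only because it is consistent with the parameter shapes above. It illustrates what `in_axes=(None, 0)` does: the whole `params` pytree is passed unchanged to every call, while `features_train` is split along axis 0, so `forward` sees one example of shape `(5,)` at a time.

```python
import jax
import jax.numpy as jnp

# Hypothetical forward pass: the real `forward` from the question is not shown,
# so this two-layer MLP is only a guess that matches the parameter shapes.
def forward(params, x):
    # x is a single example of shape (5,) because vmap strips the batch axis
    hidden = jax.nn.relu(x @ params['w1'] + params['b1'])  # shape (4,)
    return hidden @ params['w2'] + params['b2']            # shape (3,)

params = {
    'w1': jnp.ones((5, 4)),
    'w2': jnp.ones((4, 3)),
    'b1': 1.0,
    'b2': 1.0,
}
features_train = jnp.ones((4, 5))

# in_axes=(None, 0): params is broadcast to every call, features_train is mapped row by row
out = jax.vmap(forward, in_axes=(None, 0))(params, features_train)
print(out.shape)  # (4, 3)
```

With all-ones inputs every entry of `out` is 25.0 (each hidden unit is 5 + 1 = 6, and each output is 4 × 6 + 1 = 25), so this version runs without the `TypeError`.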

Replies: 1 comment, 4 replies

Answer selected by ShawonAshraf
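For context on the error in the title: the code that triggered it isn't reproduced in this thread, so the following is only a guess at a common cause. JAX raises this `TypeError` whenever a JAX array is indexed with a plain Python list (`arr[[0, 1]]`), instead of guessing, as older NumPy did, whether the list means "one index per dimension" or "an array of indices". The sketch below triggers the error and shows the two usual fixes; the array `x` is an illustrative assumption.

```python
import jax.numpy as jnp

x = jnp.arange(10).reshape(2, 5)

# Indexing with a plain Python list raises the TypeError from the title.
err_msg = ""
try:
    x[[0, 1]]
except TypeError as e:
    err_msg = str(e)
print(err_msg)

# Fix 1: convert the list to an array for integer-array ("fancy") indexing.
rows = x[jnp.array([0, 1])]   # selects rows 0 and 1, shape (2, 5)

# Fix 2: convert the list to a tuple if it was meant as one index per dimension.
elem = x[tuple([1, 3])]       # equivalent to x[1, 3]
print(rows.shape, int(elem))  # (2, 5) 8
```

Which fix is right depends on what the list was meant to express; the error message itself suggests `arr[array(seq)]` or `arr[tuple(seq)]` depending on the list's contents.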
Category: Q&A
2 participants