TypeError: Using a non-tuple sequence for multidimensional indexing is not allowed #15087
Greetings, I was trying out an ANN which has the following functions to init the parameters and run a forward pass.

import jax
import jax.numpy as jnp

# inits params
# w, b
# init strategy : Kaiming
def ann(in_features, hidden_features, out_features, *prngs):
    scale_factor = jnp.sqrt(2/in_features)
    # layer 1
    w1 = jax.random.normal(prngs[0], (in_features, hidden_features)) * scale_factor
    b1 = jax.random.normal(prngs[1], (1, hidden_features))
    # layer 2
    w2 = jax.random.normal(
        prngs[2], (hidden_features, out_features)) * scale_factor
    b2 = jax.random.normal(
        prngs[3], (1, out_features))
    return {
        "w1": w1,
        "b1": b1,
        "w2": w2,
        "b2": b2
    }

# forward pass
@jax.jit
def forward(params, x):
    # from layer 1
    # xW + b
    out1 = x @ params["w1"] + params["b1"]
    out1 = jax.nn.relu(out1)
    # layer 2
    out2 = out1 @ params["w2"] + params["b2"]
    out2 = jax.nn.relu(out2)
    # apply softmax to convert to probability dist
    # since the loss function is cross entropy
    logits = jax.nn.softmax(out2)
    return logits

I'm training the ANN with cross entropy as the loss. The problem I'm facing is that when I run jax.vmap(forward, in_axes=(None, 0))(params, features_train) on its own, I get the expected output without any errors. But when I call forward in my training loop I get the following error:
The error mentioned in the stack trace is about array indexing. I'm having a hard time understanding why it's thrown when I'm using a dict as a container for the parameters. What could be a solution here? Should I use pytrees in forward?
Replies: 1 comment 4 replies
Could you edit your question to include the full code that reproduces the issue (i.e. a minimal reproducible example)? Here's what I tried, making my best guess at what the function inputs might look like, and I was unable to reproduce the error you report:

params = {
    'w1': jnp.ones((5, 4)),
    'w2': jnp.ones((4, 3)),
    'b1': 1.0,
    'b2': 1.0
}
features_train = jnp.ones((4, 5))
jax.vmap(forward, in_axes=(None, 0))(params, features_train)
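
For what it's worth, the TypeError in the title is the one JAX raises when an array is indexed with a plain Python list, so the cause may be somewhere in the training loop rather than in forward or the params dict. This is only a guess, since the loop isn't shown. The sketch below is not taken from the question: `x`, `batch_idx`, and `index_with_list` are hypothetical names standing in for something like minibatch selection.

```python
import jax.numpy as jnp

# Hypothetical stand-ins for a feature matrix and a minibatch of row indices.
x = jnp.arange(12.0).reshape(3, 4)
batch_idx = [0, 2]  # a plain Python list, not an array


def index_with_list(arr, idx):
    """Return the TypeError message from arr[idx], or None if indexing works."""
    try:
        arr[idx]
    except TypeError as err:
        return str(err)
    return None


# Indexing a JAX array with a list is rejected (NumPy would accept it).
message = index_with_list(x, batch_idx)
print(message)

# Converting the list to an array first selects rows 0 and 2.
rows = x[jnp.array(batch_idx)]
print(rows.shape)  # (2, 4)
```

If the training loop does something similar, wrapping the index in jnp.array(...) (or using a tuple when indexing several dimensions at once) should avoid the error. The dict of parameters is already a valid pytree, so forward itself shouldn't need to change.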