Hey guys, first of all, thanks for your work! I have a problem: I am not able to `vmap` a NN over the batch dimension of a named-tuple field.

```python
from typing import NamedTuple, Optional

import jax
import jax.numpy as np
import flax.linen as nn


class Input(NamedTuple):
    a: Optional[np.ndarray]
    b: Optional[np.ndarray]
    c: Optional[np.ndarray]
```

The network looks roughly like this:

```python
class TupleNetwork(nn.Module):
    @nn.compact
    def __call__(self, tuple_inp):
        a = tuple_inp.a
        b = tuple_inp.b
        c = tuple_inp.c
        a_1 = nn.Dense(4)(a)
        b_1 = nn.Dense(4)(b)
        c_1 = nn.Dense(4)(c)
        return a_1 + b_1 + c_1


tup_net = TupleNetwork()
```

I initialize the network with:

```python
dummy_input = Input(np.array([10.0]),
                    np.array([10.0]),
                    np.array([10.0]))
key, _ = jax.random.split(jax.random.PRNGKey(0))
params = tup_net.init(key, dummy_input)
```

So far, so good. Calling the network works just fine:

```python
output = tup_net.apply(params, dummy_input)
```

But now I want to vmap the apply function over the first dimension of `a` and `b` in the named tuple. A possible input for that would look like:

```python
batch_input = Input(10 * np.ones(5), 10 * np.ones(5), np.array([10.0]))
```

To vmap over `a` and `b`, I do:

```python
vapply = jax.vmap(tup_net, in_axes=(None, (0, 0, None)))
```

When calling the function with the batched input

```python
batch_output = vapply(params, batch_input)
```

I get the following error:

My question is: how do I have to define my
The `in_axes` must match the structure of the arguments. A plain tuple and a named tuple are different pytree node types, so `(0, 0, None)` does not line up with an `Input` argument. Since you're passing a named tuple of arguments, you should also pass a named tuple of `in_axes`:

```python
vapply = jax.vmap(tup_net, in_axes=(None, Input(0, 0, None)))
```

This leads to another error from misusing Flax modules, but that should answer your question regarding the `vmap` issue.