Let's say I want to vmap a function over a list of pytree objects. For example, here I have a namedtuple as my pytree.
This entire snippet is running inside a vmap, and ideally I want to do everything in parallel, but I can't create the list of surfaces as a jnp.ndarray. Other options I was thinking about:
Any suggestions appreciated, thanks!
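A minimal sketch of the problem, assuming a hypothetical `Surface` namedtuple with scalar fields and a hypothetical per-surface function `area` (the question's own snippet was not preserved). Passing a Python list of namedtuples to `jax.vmap` does not map over the list: vmap treats the whole list as one pytree and tries to map axis 0 of every scalar leaf, which raises a `ValueError`.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

# Hypothetical stand-in for the question's namedtuple (original snippet not preserved).
Surface = namedtuple("Surface", ["radius", "height"])


def area(s):
    # Hypothetical per-surface function: lateral area of a cylinder.
    return 2 * jnp.pi * s.radius * s.height


# "Array of structs": a Python list holding one namedtuple per surface.
surfaces = [Surface(1.0, 2.0), Surface(3.0, 4.0), Surface(5.0, 6.0)]

# vmap flattens the list into its scalar leaves and tries to map axis 0 of each.
# Scalars have no axis 0, so vmap raises ValueError instead of looping over the list.
try:
    jax.vmap(area)(surfaces)
    vmap_failed = False
except ValueError:
    vmap_failed = True

# Sequential fallback: works, but gives up vmap's batching.
areas_loop = jnp.array([area(s) for s in surfaces])
```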
@shoyer I have the same issue, and I'm wondering about best practices for a scenario like this. In my case, I have a list of NamedTuples. I have functions that operate on these NamedTuples and return DeviceArrays. For readability, I'd like to avoid rewriting those functions. For instance:
won't work out of the box, as you explained. My question is whether there's a way of getting this to work without modifying the function.
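One way to keep an existing per-item function unchanged, sketched here under assumptions (hypothetical `Surface` and `area`; a common pattern, not necessarily what this thread settled on): convert the list of namedtuples into a single namedtuple of stacked arrays with `jax.tree_util.tree_map`, then `vmap` the original function over it.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

# Hypothetical namedtuple and per-item function, standing in for the reply's own.
Surface = namedtuple("Surface", ["radius", "height"])


def area(s):
    # Unmodified function, written for ONE Surface with scalar fields.
    return 2 * jnp.pi * s.radius * s.height


def stack_pytrees(trees):
    """Convert a list of same-structure pytrees into one pytree of stacked arrays."""
    # tree_map walks the leaves of all trees in lockstep and stacks matching leaves.
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.stack([jnp.asarray(x) for x in leaves]), *trees
    )


surfaces_list = [Surface(1.0, 2.0), Surface(3.0, 4.0), Surface(5.0, 6.0)]

# "Struct of arrays": Surface(radius=[1., 3., 5.], height=[2., 4., 6.])
surfaces = stack_pytrees(surfaces_list)

# The same `area` now runs batched over the leading axis of every field.
areas = jax.vmap(area)(surfaces)
```

The stacking step runs once outside the batched call; to recover item `i` afterwards, `jax.tree_util.tree_map(lambda x: x[i], surfaces)` indexes every leaf.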
`vmap` supports pytrees when they are written in the form of a "struct of arrays" rather than an "array of structs". This is because JAX/XLA only supports numeric array dtypes, so the "array of structs" form can't be operated on efficiently in JAX. In other words, you need to move the array dimension inside the containers instead, e.g.,
Then `vmap(trace)(surfaces)` should work just fine, and it returns a namedtuple of arrays.
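A minimal sketch of the struct-of-arrays layout, assuming hypothetical `Surface` and `Hit` namedtuples and a toy `trace` (the thread's actual definitions were not preserved). Each field of `surfaces` carries a leading batch axis, and because this `trace` returns a namedtuple, `vmap` returns a `Hit` whose fields are batched arrays.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

# Hypothetical containers; the thread's real definitions were not preserved.
Surface = namedtuple("Surface", ["center", "radius"])
Hit = namedtuple("Hit", ["distance", "inside"])


def trace(s):
    # Written for ONE surface: `center` has shape (2,), `radius` is a scalar.
    # Toy computation: signed distance from the origin to a circle.
    dist = jnp.linalg.norm(s.center) - s.radius
    return Hit(distance=dist, inside=dist < 0)


# "Struct of arrays": one Surface whose fields share a leading axis of size 3.
surfaces = Surface(
    center=jnp.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]]),  # shape (3, 2)
    radius=jnp.array([1.0, 2.0, 20.0]),  # shape (3,)
)

# vmap maps axis 0 of every leaf and returns a Hit namedtuple of arrays.
hits = jax.vmap(trace)(surfaces)
```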