
vmap against list of pytree containers #5322

Answered by shoyer
bhchiang asked this question in Q&A

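vmap supports pytrees written as a "struct of arrays" rather than an "array of structs". vmap maps over a leading axis of each array leaf, so the batch dimension has to live inside the arrays. JAX/XLA only supports arrays of numeric dtypes (there is no struct or record dtype), so the "array of structs" form, such as a Python list of containers, can't be operated on efficiently in JAX.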
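To make the two layouts concrete, here is a minimal sketch that compares the pytree leaves each form produces. The Sphere class, with fields center and radius, is a hypothetical stand-in for the container in the original question.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


# Hypothetical stand-in for the question's Sphere container.
class Sphere(NamedTuple):
    center: jax.Array
    radius: jax.Array


# "Array of structs": a Python list holding one Sphere per object.
aos = [
    Sphere(jnp.zeros(3), jnp.array(1.0)),
    Sphere(jnp.array([0.0, 2.0, 0.0]), jnp.array(10.0)),
]

# "Struct of arrays": one Sphere whose fields carry a leading batch axis.
soa = Sphere(
    center=jnp.stack([s.center for s in aos]),  # shape (2, 3)
    radius=jnp.stack([s.radius for s in aos]),  # shape (2,)
)

# The list form has four unbatched leaves; the struct form has two batched ones.
aos_shapes = [leaf.shape for leaf in jax.tree_util.tree_leaves(aos)]
soa_shapes = [leaf.shape for leaf in jax.tree_util.tree_leaves(soa)]
print(aos_shapes)  # [(3,), (), (3,), ()]
print(soa_shapes)  # [(2, 3), (2,)]
```

Only the struct-of-arrays form gives every leaf a shared leading axis of the same size, which is what vmap maps over.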

In other words, you need to move the batch dimension inside the containers instead, e.g.:

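import jax.numpy as jnp

centers = jnp.stack([jnp.zeros(3), jnp.array([0.0, 2.0, 0.0])])  # shape (2, 3)
radiuses = jnp.array([1.0, 10.0])  # shape (2,)
surfaces = Sphere(centers, radiuses)  # Sphere: the namedtuple from the question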
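If you already have a Python list of containers, one way to build the batched form is to stack matching leaves with jax.tree_util.tree_map. This is a sketch: stack_pytrees is a hypothetical helper name, and the Sphere field names are assumed.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Sphere(NamedTuple):
    # Hypothetical field names; the question's container may differ.
    center: jax.Array
    radius: jax.Array


def stack_pytrees(trees):
    """Hypothetical helper: turn a list of identically structured pytrees
    into a single pytree whose leaves are stacked along a new axis 0."""
    # tree_map follows the structure of trees[0] and passes the matching
    # leaf from every element of `trees` to the lambda.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


spheres = [
    Sphere(jnp.zeros(3), jnp.array(1.0)),
    Sphere(jnp.array([0.0, 2.0, 0.0]), jnp.array(10.0)),
]
surfaces = stack_pytrees(spheres)
print(surfaces.center.shape, surfaces.radius.shape)  # (2, 3) (2,)
```

All list elements must share the same pytree structure and per-field shapes; otherwise tree_map or jnp.stack will raise an error.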

Then vmap(trace)(surfaces) should work fine, since surfaces is now a single namedtuple of arrays, each with a leading batch axis of size 2.
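A minimal end-to-end sketch of the vmap call follows. The question's trace function isn't shown here, so trace below is a hypothetical placeholder (signed distance from the origin to the sphere surface), and the Sphere field names are assumed. The second part shows that when the mapped function returns a namedtuple, vmap returns a namedtuple whose fields are batched arrays.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Sphere(NamedTuple):
    center: jax.Array
    radius: jax.Array


def trace(sphere: Sphere) -> jax.Array:
    # Hypothetical placeholder for the question's trace: operates on ONE
    # unbatched sphere (center shape (3,), radius shape ()).
    return jnp.linalg.norm(sphere.center) - sphere.radius


def shrink(sphere: Sphere) -> Sphere:
    # Hypothetical example that returns a container instead of an array.
    return Sphere(sphere.center, sphere.radius * 0.5)


centers = jnp.stack([jnp.zeros(3), jnp.array([0.0, 2.0, 0.0])])  # (2, 3)
radiuses = jnp.array([1.0, 10.0])  # (2,)
surfaces = Sphere(centers, radiuses)

# vmap maps over axis 0 of every leaf: (2, 3) -> (3,) and (2,) -> ().
distances = jax.vmap(trace)(surfaces)
print(distances.tolist())  # [-1.0, -8.0]

# A namedtuple output comes back as a namedtuple of batched arrays.
shrunk = jax.vmap(shrink)(surfaces)
print(type(shrunk).__name__, shrunk.radius.tolist())  # Sphere [0.5, 5.0]
```

Because every field already has the batch axis in position 0, the default in_axes=0 applies to all leaves without any per-field configuration.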
