
How to dump stablehlo for pmap? #15990

Answered by mattjj
Young768 asked this question in General

Thanks for the question!

The arguments you pass to jax.pmap(f).lower should look like the arguments you'd pass to jax.pmap(f) itself. In particular, the argument (0,) can't be mapped over (its only leaf, 0, is a scalar with no leading axis), hence the error. Also, pass in_axes and axis_name to jax.pmap as usual, not to .lower.

Maybe you want something more like this?

```python
example_x = jnp.zeros(8)  # the leading axis (size 8) is mapped over devices, so this needs at least 8 devices
lowered = jax.pmap(_wrapped_step_fun, in_axes=(0,), axis_name=...).lower(example_x)
```
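Since the question is about dumping StableHLO, here is a minimal self-contained sketch that applies the advice above end to end. The `step` function, the axis name `"devices"`, and the `(n_devices, 4)` input shape are hypothetical stand-ins for `_wrapped_step_fun` and the real inputs. It sizes the mapped axis from `jax.local_device_count()` so it also runs on a single CPU, and it assumes a JAX version where `Lowered.compiler_ir` accepts `dialect="stablehlo"`.

```python
import jax
import jax.numpy as jnp


def step(x):
    # Hypothetical per-device step: scale this device's shard, then sum across devices.
    return jax.lax.psum(x * 2.0, axis_name="devices")


# pmap maps over the leading axis, so its size must not exceed the device count.
n_devices = jax.local_device_count()
example_x = jnp.zeros((n_devices, 4), dtype=jnp.float32)

# in_axes and axis_name go to jax.pmap; .lower only takes example arguments.
pmapped = jax.pmap(step, in_axes=(0,), axis_name="devices")
lowered = pmapped.lower(example_x)

# Request the StableHLO dialect explicitly and render the MLIR module as text.
stablehlo_text = str(lowered.compiler_ir(dialect="stablehlo"))
print(stablehlo_text)

# Passing (0,) as the argument fails: its leaf is rank 0, so there is no axis 0 to map over.
try:
    pmapped.lower((0,))
    scalar_error = None
except ValueError as err:
    scalar_error = err
print("lowering (0,) raised:", type(scalar_error).__name__)
```

In recent JAX releases `lowered.as_text()` should give the same StableHLO text; passing the dialect to `compiler_ir` just makes the choice explicit. Write `stablehlo_text` to a file if you need a dump on disk.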

Answer selected by Young768