xmap properly compiles but then raises an axis size error #14879
cmunna0052
asked this question in
Q&A
I am not sure if this is a bug or simply a misunderstanding on my part. I am trying to shard a feed-forward network with an arbitrary number of hidden layers using xmap. My goal is the following: each weight matrix is split into 10 groups of columns (one group per device, for 10 devices). At each step, the current x-vector is sent in full to all 10 devices. On each device it is multiplied by that device's 1/10th of the weight matrix columns, added to the matching 1/10th of the bias vector, and passed through the activation to give 1/10th of the next x-vector. Then jax.lax.all_gather is called to recombine the sharded x-vectors into the full next x-vector. This repeats for every layer until the end. Here is the code:
When I run this code, the compilation of forward completes (I can tell by adding print statements throughout), but the actual computation then fails at this assertion:
assert axis_size == frame_size, "axis size doesn't match"
Does anyone see what is going on here? Is there a better/easier way to do this? I can see one solution where I split up the parameter matrices manually, but I thought it would be possible to have xmap handle that for me.
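For reference, the scheme described above, using the manual-split route, can be sketched roughly as follows. This is not the code from the question and it does not diagnose the assertion. It is a minimal sketch that assumes a tanh activation, weight matrices of shape (n_in, n_out) with n_out divisible by 10, and it uses jax.vmap with a named axis as a single-device stand-in for the 10 devices, so that jax.lax.all_gather can be exercised without xmap. The helper names (init_params, shard_columns, forward_shard, forward_reference) are hypothetical.

```python
import jax
import jax.numpy as jnp

N_SHARDS = 10  # stands in for the 10 devices in the question


def init_params(key, sizes):
    """Random (W, b) per layer; W has shape (n_in, n_out). Illustrative only."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((w, jnp.zeros((n_out,))))
    return params


def shard_columns(params, n_shards):
    """Manually split each W into contiguous column groups, and each b to match."""
    sharded = []
    for w, b in params:
        n_in, n_out = w.shape
        assert n_out % n_shards == 0, "layer width must divide evenly"
        k = n_out // n_shards
        # (n_in, n_out) -> (n_shards, n_in, k); shard s holds columns s*k .. (s+1)*k - 1
        w_s = w.reshape(n_in, n_shards, k).transpose(1, 0, 2)
        b_s = b.reshape(n_shards, k)
        sharded.append((w_s, b_s))
    return sharded


def forward_shard(x, shard_params):
    """Per-shard body: x is the full vector, shard_params are this shard's slices."""
    for w, b in shard_params:
        y = jnp.tanh(x @ w + b)  # this shard's 1/10th of the next x-vector
        # Collect every shard's slice: (n_shards, k) -> full vector of length n_out
        x = jax.lax.all_gather(y, "shards").reshape(-1)
    return x


# x is replicated (in_axes=None); parameters are split along their leading axis.
forward = jax.vmap(forward_shard, in_axes=(None, 0), axis_name="shards")


def forward_reference(x, params):
    """Unsharded network, used only to check the sharded version."""
    for w, b in params:
        x = jnp.tanh(x @ w + b)
    return x


params = init_params(jax.random.PRNGKey(0), [20, 30, 40, 10])
x = jnp.ones((20,))
out = forward(x, shard_columns(params, N_SHARDS))  # one identical row per shard
```

Every row of out holds the same full output vector, since each shard ends with the gathered result. Under these assumptions, the same forward_shard body should carry over to a real multi-device mapping, with the named axis bound to devices instead of a vmap axis.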