-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ROCm] Implement RNN support #25755
[ROCm] Implement RNN support #25755
Conversation
@dfm and @superbobry could you please take a look? |
0b07837
to
36d037e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dfm want to have a look as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good overall - thanks! My main high level comment is that it would be useful to move as much of the #ifdef JAX_GPU_HIP
logic into vendor.h
rather than in rnn_kernels.cc
directly. It's ok to have some, but the more we can move, the better. Can you look into redefining some of the macros in vendor.h
to consolidate the logic there?
jax/experimental/rnn.py
Outdated
mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_fwd_lowering, platform='cuda') | ||
mlir.register_lowering(rnn_fwd_p, gpu_rnn.miopen_rnn_fwd_lowering, platform='rocm') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since gpu_rnn
is in jaxlib, these changes will cause problems with version skew. JAX always needs to work with the most recent stable release of jaxlib. Perhaps you could protect this using hasattr(gpu_rnn, "miopen_rnn_fwd_lowering")
?
jax/experimental/rnn.py
Outdated
mlir.register_lowering( | ||
rnn_bwd_p, gpu_rnn.miopen_rnn_bwd_lowering, platform='rocm') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similarly, this needs to be protected against old version of jaxlib.
2e86003
to
18cc2d2
Compare
@dfm still not sure why this error wouldn't go away. I have protections in place. Probably it is how you test this in your internal CI? Seems like you are getting the jaxlib from upstream and that is why the related tests fail? |
@dfm thanks. I see what you mean. However, miopen apis are quiet different from cudnn. For e.g.
I checked to see how many of |
Yes! We require that Also: It looks like this has introduced some build issues for the CUDA CI. Can you take a look at those too? |
a909942
to
dfd1a65
Compare
dfd1a65
to
fe68eb8
Compare
@dfm I just fixed the patch. Could you please approve? thanks! |
Created from: ROCm#171