
Error with backward on RNN with cuda #220

Closed

lth456321 opened this issue Dec 27, 2024 · 5 comments · Fixed by #225
Labels
bug (Something isn't working), feat (New feature or request), package: autojac

Comments

@lth456321

Many thanks for the authors' work, but I ran into trouble when using torchjd with an RNN; it seems that RNNs are not supported in vmap. Could someone provide some suggestions? Thanks in advance.

@PierreQuinton
Contributor

Thank you for your interest in our library; we will investigate this and hopefully provide a solution.
In the meantime, could you provide us with a minimal example that illustrates your problem?
This would help considerably.

@PierreQuinton added the bug, feat, and package: autojac labels on Dec 27, 2024
@lth456321
Author

[screenshot: minimal example code]

Here is a minimal example that follows your example; when I delete the code related to model_rnn, the error doesn't occur.
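Since the screenshot did not survive, here is a sketch of what a minimal example along these lines could look like. It assumes torchjd's documented `backward(tensors, aggregator)` call and the `UPGrad` aggregator; the RNN dimensions and the two losses are made up for illustration, not taken from the original poster's code.

```python
import torch
from torchjd import backward
from torchjd.aggregation import UPGrad

device = "cuda"  # the failure is specific to cuda; "cpu" works

# A small RNN; removing this model (and its loss) makes the error go away.
model_rnn = torch.nn.RNN(input_size=8, hidden_size=16, batch_first=True).to(device)

x = torch.randn(4, 5, 8, device=device)  # (batch, sequence, features)
output, _ = model_rnn(x)

# Two scalar losses, as in torchjd's basic usage example.
loss1 = output.mean()
loss2 = output.abs().mean()

# Jacobian descent step: this call raises an error, because the cuda RNN
# backward is incompatible with the vmap-based batching torchjd uses internally.
backward([loss1, loss2], UPGrad())
```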

@PierreQuinton
Contributor

Thanks a lot, we will look into it.
By the way, you can use ``` to write code in Markdown (which GitHub uses). What I typically do is start with ```python, then copy-paste my code, and finish with another ```. Here is a preview:

  # This is an example of python code
  i = 0
  while True:
    i += 1

@ValerianRey
Contributor

ValerianRey commented Dec 28, 2024

> [screenshot: minimal example code] Here is a minimal example that follows your example; when I delete the code related to model_rnn, the error doesn't occur.

Thanks! I managed to reproduce your issue (see #221). It seems that the cuda implementation of RNNs is not compatible with batching, which we use internally in torchjd.

As a short-term solution, you can set your device to "cpu" instead of "cuda". This will, however, slow down the training.

@PierreQuinton, I think we should make backward not use the Jac transform when parallel_chunk_size=1 (and use stacked Grad transforms instead in that case). This way, we would have a way to avoid using vmap, and this issue would have a fairly easy solution. The alternative is to wait until we can stop relying on vmap altogether, but this is very hypothetical and could take a long time.

@ValerianRey changed the title from "issuse about using with rnn" to "Error with backward on RNN with cuda" on Dec 28, 2024
@ValerianRey
Contributor

ValerianRey commented Jan 2, 2025

With #222, we no longer rely on vmap when calling backward with parallel_chunk_size=1. With #225, we will have a usage example for RNNs that advises using parallel_chunk_size=1. So for me, this issue will be solved when we merge #225.

@lth456321 We will release v0.4.0 soon after, so you should be able to make this change and get your code working once you update torchjd. Let us know if this solves your problem.
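For illustration, a sketch of the advised change, reusing the hypothetical setup from the earlier snippet (the parallel_chunk_size keyword comes from the discussion above):

```python
# With parallel_chunk_size=1, backward avoids vmap (as of #222), so the
# cuda RNN backward is used directly and the error no longer occurs.
backward([loss1, loss2], UPGrad(), parallel_chunk_size=1)
```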

Thanks again for your issue, it has led to significant improvements of the library!
