Invalid gradient and Dx? #3
Comments
Whoops, that's a great catch. I'll push something in a couple of minutes to fix it. All you need to do is set `D=D` in the module layer when calling the function. As for the NaN gradients, I'm looking for the parts of the kernel that cause this. I haven't identified anything yet, but I'll let you know.
It should be the same for `z`. My classes resumed, so I was rather quick to push the layer out. Thanks for the catch :)
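For readers following along, here is a minimal sketch of what passing the skip and gate terms through to the kernel might look like. The exact signature of `bimamba_chunk_scan_combined()`, the import path, and the variable names are assumptions modeled on a mamba_ssm-style chunked-scan API, not the repo's actual code:

```python
# Hypothetical call site inside BiMamba2.forward (variable names, shapes, and the
# kernel signature are assumed; check the repo's kernel module for the real one).
# from <repo package> import bimamba_chunk_scan_combined  # actual import path not shown

y = bimamba_chunk_scan_combined(
    x,            # input after the in-projection, e.g. (batch, seqlen, nheads, headdim)
    dt, A, B, C,  # SSM parameters produced earlier in forward()
    chunk_size=chunk_size,
    D=self.D,     # pass the learned skip weight instead of leaving the default None
    z=z,          # likewise for the gating branch, per the follow-up comment
)
```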
Alright, I pushed a fix. Let me know if NaNs still occur, as this is something I haven't been able to test for personally.
Thanks @Hprairie; after passing [...]. Also, it seems that [...]
Thanks again for pointing this out, I learned something new.
Thanks @Hprairie, [...]
Hmmm, okay, I'll try to block out some time to look into the NaN problem.
Yes, I have the same problem; no matter how I set the batch size or learning rate, the result is the same. Colab link (a ViT with the attention removed and replaced with Bi-Mamba 2): https://colab.research.google.com/drive/1rgXkwnlevzZ0YPbefQS8qHRe7gFlb4J-?authuser=3. The loss is always NaN; Bi-Mamba 2 produces gradients that are far too large.
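One generic way to localize this kind of failure (independent of the Colab notebook above) is PyTorch's anomaly detection, which makes the backward pass raise at the first operation whose gradient becomes NaN. A minimal sketch, assuming a standard classification loop with placeholder `model`, `optimizer`, and `train_loader` objects:

```python
import torch
import torch.nn.functional as F

# Slow, debugging-only switch: backward() will raise a RuntimeError pointing at
# the op whose gradient first became NaN, with a traceback to the forward op.
torch.autograd.set_detect_anomaly(True)

for images, labels in train_loader:          # placeholder DataLoader
    optimizer.zero_grad()                    # placeholder optimizer
    logits = model(images)                   # placeholder ViT-with-Bi-Mamba-2 model
    loss = F.cross_entropy(logits, labels)
    loss.backward()                          # raises here if a NaN gradient appears
    optimizer.step()
```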
I am looking into this and attempting to fix the NaN gradients.
Hi @Hprairie, I previously built mamba-2/hydra-based models, and I am now trying to replace those layers with your bi-mamba2 module. However, I found that the new model easily gets invalid gradients (e.g., an infinite gradient norm) that never appeared with mamba-2/hydra.
I tested with both `torch==2.1.0, triton==3.0.0, cu122` and `torch==2.4.0, triton==3.0.0, cu121`. It seems that the more bi-mamba2 layers I stack, or the more processes I use, the more easily the model hits this problem. Any ideas?
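A quick way to see which layer blows up first is to inspect each parameter's gradient right after `backward()`. A minimal, self-contained diagnostic sketch (the `model` here is whatever stack of bi-mamba2 layers is being trained, not code from the repo):

```python
import torch

def report_bad_grads(model: torch.nn.Module) -> None:
    """Print every parameter whose gradient contains NaN/Inf, with its norm."""
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        if not torch.isfinite(param.grad).all():
            print(f"non-finite gradient in {name} "
                  f"(norm={param.grad.norm().item():.3e})")

# Typical use inside the training loop:
#   loss.backward()
#   report_bad_grads(model)
#   optimizer.step()
```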
Besides, you mentioned that the kernel implements y = SS(x) + flip(SS(flip(x))) + Dx, but in `BiMamba2()` Line 108 the skip parameters `self.D` and `self.fc_D` are not used for the Dx term. Can I ask how to pass these parameters to `bimamba_chunk_scan_combined()`, or should we do something similar to what Hydra does? Thanks!!!
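For context while the kernel-side fix is discussed, one workaround is to add the skip term outside the scan, loosely in the spirit of how Hydra handles its residual path. The sketch below makes assumptions about the shapes and about how `self.D` / `self.fc_D` are meant to be combined; it is not the repo's or Hydra's actual implementation:

```python
from typing import Optional

import torch
import torch.nn as nn

def add_skip(y_scan: torch.Tensor, x: torch.Tensor,
             D: torch.Tensor, fc_D: Optional[nn.Module] = None) -> torch.Tensor:
    """Add the Dx term manually: y = SS(x) + flip(SS(flip(x))) + D * x (+ fc_D(x)).

    y_scan: scan output, assumed shape (batch, seqlen, d_inner)
    x:      the input that entered the scan, same shape
    D:      per-channel skip weight, assumed shape (d_inner,)
    fc_D:   optional learned linear skip, e.g. nn.Linear(d_inner, d_inner)
    """
    y = y_scan + x * D      # D broadcasts over the batch and sequence dimensions
    if fc_D is not None:
        y = y + fc_D(x)     # extra learned skip, analogous in spirit to Hydra's fc_D path
    return y
```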