Some questions about Bi-Mamba2 and Hydra #1
I want to raise the same question and check how much it reduces GPU VRAM usage. Additionally, I would like to experiment with Hydra in computer vision based on this GitHub repository, as the other Hydra implementation is too slow and not as meaningful as Mamba 1.
Hey, thanks for reaching out,
I'm glad people are wanting to use this, as the subquadratic nature of SSMs really starts to show with optimized bi-directional kernels. Let me know if you have any other questions, or if you find any issues with the kernel. I think that it should be very stable right now (i.e. passing all the tests I created), but there might be something tiny I overlooked.
Oh, I also want to note that this kernel appears to be much faster on NVIDIA GPUs due to tensor cores and better Triton optimizations. So for now I would recommend using them for the best performance; however, AMD will still work.
Hi Hayden, thanks a lot for your reply!
I printed out the shape of the dt variable in ssd/bi/ssd_combined.py:
The print result is:
It is worth noting that the shape of dt during forward propagation is correct ([32, 400, 8]).
I further printed the dt dimensions in the forward function of MambaChunkScanCombinedFn and found that its shape there was unexpected:
Print result: MCSCF fwd dt shape: torch.Size([32, 8, 2, 256]), but I think it should be [32, 400, 8].
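For reference, a [32, 8, 2, 256] tensor is exactly what a chunked (batch, nheads, nchunks, chunk_size) layout would look like for a [32, 400, 8] dt with a chunk size of 256 (400 padded up to 512 gives 2 chunks). A minimal sketch of that rearrangement, assuming chunk_size=256; the actual kernel's internal layout may differ:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

batch, seqlen, nheads, chunk_size = 32, 400, 8, 256

dt = torch.randn(batch, seqlen, nheads)            # (b, l, h) -> [32, 400, 8]

# Pad the sequence length up to a multiple of chunk_size (400 -> 512), then chunk.
pad = (chunk_size - seqlen % chunk_size) % chunk_size
dt_padded = F.pad(dt, (0, 0, 0, pad))              # [32, 512, 8]
dt_chunked = rearrange(dt_padded, "b (c cs) h -> b h c cs", cs=chunk_size)
print(dt_chunked.shape)                            # torch.Size([32, 8, 2, 256])
```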
Hmm, interesting. dt is in the correct chunked shape, but that isn't what should be coming out there. Let me take a look and try to fix it / create a working API.
Can you also paste a script so that I can reproduce your error?
Okay, some updates. I have fixed the issue with the shape; I was accidentally saving the wrong dt for the bwd pass. This should be good now. After looking through the Hydra module, I have created my own module which replicates a similar module without extra compute/parameters. I have added a module for BiMamba2, which people can use now. I want to note again that I have only tested these modules for numeric results and don't know about the training dynamics. Hopefully, people can try using it and let me know!
Hi Hayden, thank you very much for fixing the bug mentioned above. I also tried replacing Hydra with Bi-Mamba2 in my network, and it is a great improvement in speed over Hydra; this is amazing work.

I am currently studying how to combine Bi-Mamba2 with YOLOv8, but in my preliminary experiments I found that no matter how I adjust the optimizer type and the learning rate, the loss is always NaN. It is worth noting that the loss was normal when using Hydra before. In further attempts, I found that if I set AMP=True the loss can be computed, but this is not what I want. So I would like to ask whether Bi-Mamba2 involves a choice between FP16 and FP32 in the underlying kernel code. Can you give me some suggestions to help me achieve normal training with AMP=False?
Hmm, very weird. Are the gradients becoming NaN, or the outputs? I.e., can you identify where the NaN is coming from? I have a couple of things I can check, but it is a little weird that it doesn't work with fp32. I know the original Triton kernels used bf16 or fp32, so if training is stable with both of these then I don't know how much I can do. Giving me more information will be useful, as I can't reproduce the issue. A toy training script where this happens would be amazing.
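One way to narrow down whether the first NaN appears in the forward outputs or only in the gradients is PyTorch's built-in anomaly detection plus a quick scan of parameter gradients. A rough sketch, where the model, loss, and data below are toy stand-ins for the real training setup:

```python
import torch
import torch.nn as nn

# Toy stand-ins; replace with the real model, loss, and batch.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1))
criterion = nn.MSELoss()
batch, target = torch.randn(8, 16), torch.randn(8, 1)

# Re-runs backward with extra checks and reports the forward op that produced
# the first NaN/Inf gradient; enable only for a few iterations (it is slow).
torch.autograd.set_detect_anomaly(True)

out = model(batch)
loss = criterion(out, target)

# Is the NaN already present in the forward outputs or the loss?
print("NaN in outputs:", torch.isnan(out).any().item())
print("NaN in loss:", torch.isnan(loss).any().item())

loss.backward()  # ...or does it only show up during backward?
for name, p in model.named_parameters():
    if p.grad is not None and torch.isnan(p.grad).any():
        print("NaN grad in:", name)
```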
Yes, I use a stable segsum calculation method. Currently, my method is arithmetically correct, meaning that you can download this repo and run pytest; come back in a couple of hours and it should be passing everything. I have not been able to recreate the NaN. Also, as a side note: all SSD kernels use some subtraction for an approximate answer. I won't go into it in this comment, but it should be fine. I really need to know which kernel to look at, i.e., where the NaN gradients are first appearing. For example, they could be appearing
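For context on what a stable segment sum looks like, here is the segsum helper from the reference Mamba2 minimal SSD code, which builds the segment sums with masking plus a single cumulative sum rather than subtracting two cumulative sums (the subtraction variant is the approximate form mentioned above). This is the reference PyTorch version, not the Triton kernel in this repo:

```python
import torch
from einops import repeat

def segsum(x):
    """Stable segment sum: out[..., i, j] = sum_{j < k <= i} x[..., k]; -inf for i < j."""
    T = x.size(-1)
    x = repeat(x, "... d -> ... d e", e=T)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
    return x_segsum.masked_fill(~mask, -torch.inf)
```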
Thanks to Hayden and Trương for their attention to the questions raised. After many attempts, I found that, compared to the original Hydra architecture, I needed a larger batch size to solve the NaN loss problem (NaN loss is generally caused by exploding or vanishing gradients). But this is a challenge for my GPU memory. Since the entire network model I designed is based on the YOLOv8 project, I am not sure how to provide you with a more detailed example to help you find the problem. Anyway, I am very grateful for your patient answers and help!
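If a larger batch size stabilizes training but doesn't fit in VRAM, gradient accumulation is one common workaround: several smaller batches contribute to a single optimizer step, giving a larger effective batch size. A minimal sketch with toy stand-ins for the model, optimizer, data, and loss; the optional gradient clipping is an extra guard against the spikes that often precede NaNs:

```python
import torch
import torch.nn as nn

# Toy stand-ins; replace with the real detector, dataloader, and loss.
model = nn.Linear(32, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataloader = [(torch.randn(8, 32), torch.randn(8, 4)) for _ in range(8)]
compute_loss = nn.MSELoss()

accum_steps = 4  # effective batch size = per-step batch size * accum_steps

optimizer.zero_grad()
for step, (images, targets) in enumerate(dataloader):
    loss = compute_loss(model(images), targets)
    (loss / accum_steps).backward()      # scale so accumulated grads average out
    if (step + 1) % accum_steps == 0:
        # Optional: clip to guard against gradient spikes before they become NaNs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
```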
Hello @Hprairie, I did some tests on the NaN gradients caused by increasing the number of Bi-Mamba2 blocks. I tried printing the gradient values of dx, ddt, dA, etc. during the backpropagation process. The specific code of my test is as follows:
Test results:
I found that at the beginning of training the gradient values of dx, dB, dC, and dD were not NaN, but after several batches of training they all became NaN. In addition, someone in this issue also faced the NaN gradient problem (they found that this line of code causes NaN in dx). I don't know if this information helps you.
Okay, this is super helpful. The NaNs appear in ddt and dA first, which is interesting. I'll look into this.
Can you help me out by printing each variable at this line: https://github.com/Hprairie/Bi-Mamba2/blob/08b3cd3cf6d60ee2f4e712f6efecebe86ec15f92/src/ssd/bi/ssd_combined.py#L466C18-L466C19 That will identify which kernel the NaNs are coming from.
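For anyone reproducing this, a small helper like the following can be pasted next to the linked line to report shape and NaN/Inf status for each tensor there. The tensor names in the demo are only illustrative; the actual local variables at that line may differ:

```python
import torch

def nan_report(name, t):
    """Print shape plus NaN/Inf status for a tensor (skips non-tensors/None)."""
    if torch.is_tensor(t):
        print(f"{name}: shape={tuple(t.shape)}, "
              f"nan={torch.isnan(t).any().item()}, "
              f"inf={torch.isinf(t).any().item()}")

# Demo with stand-in tensors; in practice, call nan_report on each tensor
# computed just before the return at the linked line.
demo = {"dx": torch.randn(2, 4), "ddt": torch.full((2, 4), float("nan"))}
for name, t in demo.items():
    nan_report(name, t)
```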
OK, I just tested it and the results are as follows:
The results show that when training starts, the gradients of ddA_prev_b and ddA_prev_f are NaN.
Hi Hayden, thank you very much for this exciting and excellent work! I have a few questions I would like your help with.
(1) First, your Bi-Mamba2 quasiseparable matrix omits the shift() operation compared to the Hydra method. In this case, can your method share the parameters of the two matrix mixers like Hydra does? Are the other processing parameters B, C, etc. consistent with the Hydra paper?
(2) How can I use your Bi-Mamba2? Can I just replace the mamba_chunk_scan_combined() call in Hydra with bimamba_chunk_scan_combined() to run it? (See the sketch after this message.)
Looking forward to your reply, thank you again for your excellent work!
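Regarding question (2), a rough drop-in sketch is below. It assumes bimamba_chunk_scan_combined keeps the same signature and tensor layouts as mamba_ssm's mamba_chunk_scan_combined (x: (batch, seqlen, nheads, headdim), dt: (batch, seqlen, nheads), A: (nheads,), B/C: (batch, seqlen, ngroups, dstate)); the import path is a guess based on this repo's layout, so check both before using it:

```python
import torch
# Assumed import path -- adjust to wherever the kernel lives in this repo.
from ssd.bi.ssd_combined import bimamba_chunk_scan_combined

batch, seqlen, nheads, headdim = 2, 512, 8, 64
ngroups, dstate, chunk_size = 1, 128, 256
device, dtype = "cuda", torch.bfloat16  # the Triton kernel expects a CUDA device

x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
dt = torch.rand(batch, seqlen, nheads, device=device, dtype=dtype)   # positive step sizes
A = -torch.rand(nheads, device=device)                               # negative real A
B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)
C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)
D = torch.randn(nheads, device=device)

# Same call site as Hydra's mamba_chunk_scan_combined, but with the bi-directional kernel.
y = bimamba_chunk_scan_combined(x, dt, A, B, C, chunk_size=chunk_size, D=D)
print(y.shape)  # expected: (batch, seqlen, nheads, headdim)
```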