Adding Initial Value Support to Selective Scan Forward Kernel #285
Comments
The initial value is set in the prefix_op. You probably want to change this line:
to something like (I haven't tested this):
|
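One way to read the suggestion about prefix_op: the scan works on pairs (a_t, b_t) combined with an associative operator, and the prefix carried into the first chunk plays the role of h_0. The sketch below is a plain host-side C++ reference, not the kernel code. `Pair`, `combine`, and `scan_with_initial` are hypothetical names, and the operator is an assumption about what the kernel applies to its float2 values, so it should be checked against the repository. Seeding the prefix with (1, h0) instead of (1, 0) makes the second component follow h_t = a_t * h_{t-1} + b_t starting from h0.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One scan element: a is the multiplier, b is the additive term.
// After scanning, b holds the hidden state.
struct Pair {
    float a;
    float b;
};

// Associative combine: apply `lhs` first, then `rhs`.
// (a0, b0) then (a1, b1) -> (a1 * a0, a1 * b0 + b1)
inline Pair combine(Pair lhs, Pair rhs) {
    return Pair{rhs.a * lhs.a, rhs.a * lhs.b + rhs.b};
}

// Inclusive scan seeded with the prefix (1, h0). Returns h_1..h_T.
std::vector<float> scan_with_initial(const std::vector<float>& a,
                                     const std::vector<float>& b,
                                     float h0) {
    assert(a.size() == b.size());
    Pair running{1.0f, h0};  // (1, 0) reproduces the zero-initial-state case
    std::vector<float> h(a.size());
    for (std::size_t t = 0; t < a.size(); ++t) {
        running = combine(running, Pair{a[t], b[t]});
        h[t] = running.b;
    }
    return h;
}
```

With a = {0.5, 2, 0.25}, b = {1, -1, 3}, and h0 = 4, the hidden states are 3, 5, and 4.25, which matches stepping the recurrence by hand.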
If initial_value has shape (B, D, N), is the following way of accessing it correct? Here is the code where I get the pointer to the initial value:
I then access the initial value in the following way:
I would appreciate it if you could review this code and provide any feedback or suggestions. I am particularly interested in whether I am accessing the initial value correctly given its shape. Thank you in advance for your help. |
You can print stuff out with printf to see if you're accessing the right indices. |
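Before adding printf calls, it can help to write out the offset arithmetic the kernel is expected to perform. The following is a minimal sketch assuming initial_value is a contiguous row-major tensor of shape (B, D, N). `Strides3`, `contiguous_strides`, and `offset_bdn` are hypothetical helpers, not part of the repository. How the batch, dim, and state indices map to block and thread indices depends on the launch configuration and is not shown here.

```cpp
#include <cassert>
#include <cstddef>

// Element strides for a 3-D tensor indexed as [batch][dim][state].
struct Strides3 {
    std::size_t batch;
    std::size_t dim;
    std::size_t state;
};

// Strides of a contiguous row-major tensor of shape (B, D, N).
// B does not appear because the outermost extent never affects strides.
inline Strides3 contiguous_strides(std::size_t D, std::size_t N) {
    assert(D > 0 && N > 0);
    return Strides3{D * N, N, 1};
}

// Flat element offset of entry (b, d, n).
inline std::size_t offset_bdn(const Strides3& s, std::size_t b,
                              std::size_t d, std::size_t n) {
    return b * s.batch + d * s.dim + n * s.state;
}
```

For B = 2, D = 3, N = 4, entry (1, 2, 3) is the last element and lands at offset 1 * 12 + 2 * 4 + 3 = 23. Printing the offset computed in the kernel next to this value is one way to confirm the indexing.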
Thanks😀 |
Hello @tridao, in the forward file (fwd.cuh), only InclusiveScan is used. However, bwd.cuh uses two operations: InclusiveScan and InclusiveReverseScan. The relevant code snippet is as follows:
I am unsure about the necessary modifications if I want to pass a d_initial_value into the backward interface. Could you please tell me how to adjust running_prefix and postfix_op? I would greatly appreciate any advice or guidance on this matter. Thank you in advance for your assistance. |
@EricLina I might be a little late to this, but I was looking at something similar. The forward scan is used to recalculate hidden states, while the reverse scan is used to propagate gradients backward. The gradients for the initial hidden state should just be

    if (chunk == 0 && threadIdx.x == 0){
        dInitialState = thread_reverse_data[0] * smem_delta_a[state_idx + (chunk % 2) * MAX_DSTATE]
    }

If your input initial state is of the shape (B, D, N), then you should be able to store this directly. Let me know if that makes sense. |
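The relation described in this comment can be checked on the host. The sketch below assumes the scalar recurrence h_t = a_t * h_{t-1} + b_t with a loss L = sum_t c_t * h_t. `loss` and `grad_initial_state` are hypothetical names, and the code is a reference for the math rather than the kernel. The reverse pass accumulates g_t = dL/dh_t = c_t + a_{t+1} * g_{t+1}, and the gradient for the initial state is a_1 * g_1: the first reverse-scan value times the first step's multiplier.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Forward pass: L = sum_t c[t] * h_t with h_t = a[t] * h_{t-1} + b[t].
double loss(const std::vector<double>& a, const std::vector<double>& b,
            const std::vector<double>& c, double h0) {
    assert(a.size() == b.size() && a.size() == c.size());
    double h = h0;
    double total = 0.0;
    for (std::size_t t = 0; t < a.size(); ++t) {
        h = a[t] * h + b[t];
        total += c[t] * h;
    }
    return total;
}

// Reverse pass: returns dL/dh0.
double grad_initial_state(const std::vector<double>& a,
                          const std::vector<double>& c) {
    assert(a.size() == c.size());
    double g_next = 0.0;  // gradient w.r.t. the state one step later
    double a_next = 0.0;  // multiplier of the step one step later
    for (std::size_t t = a.size(); t-- > 0;) {
        g_next = c[t] + a_next * g_next;  // g_t = c_t + a_{t+1} * g_{t+1}
        a_next = a[t];
    }
    // g_next now holds the gradient w.r.t. the first state and a_next the
    // first multiplier, so their product is dL/dh0.
    return a_next * g_next;
}
```

Because L is linear in h0, the difference loss(h0 + 1) - loss(h0) equals the gradient, which gives a simple check against the reverse pass.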
Thanks!🤓 |
Hi! Have you been able to resolve the issue? I'm also currently working on integrating an initial value into the selective scan forward kernel. To access the initial_state, I used the following code:
And to load the initial weight, I implemented this logic:
I’ve verified that the loaded running_prefix matches the initial_state input. However, when using the modified kernel for Mamba forward, the result doesn't match the one produced by the Mamba step function when providing an initial ssm_state. Even though this might be an older issue, I’d greatly appreciate your insights or suggestions on what might be causing the discrepancy. Thank you in advance! |
This looks like you are loading in a new value for

    running_prefix = (threadIdx.x % 32 == 0)
        ? make_float2(
            1.f,
            hidden_prev[state_idx * params.hidden_prev_dim_stride + r * params.hidden_prev_n_stride]
        )
        : make_float2(1.f, 0.f);

The layout for the CUDA kernel float2 is (exp(delta * A), hidden_state). I think there is also a PR for this, #488. I haven't tried it yet, but I would recommend taking a look at it. |
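Given the (multiplier, hidden_state) layout, two things seem worth checking when results disagree with a step-by-step reference: that the initial state goes into the second slot with 1 in the first, and that the prefix is seeded only once and then carried from chunk to chunk rather than reloaded. The host-side sketch below illustrates that carry. `Pair`, `combine`, and `chunked_scan` are hypothetical names, and the chunking only mimics the idea of processing the sequence in pieces. It is not a model of the kernel's thread layout.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// (multiplier, hidden state or additive term)
struct Pair {
    float a;
    float b;
};

// Apply `lhs` first, then `rhs`.
inline Pair combine(Pair lhs, Pair rhs) {
    return Pair{rhs.a * lhs.a, rhs.a * lhs.b + rhs.b};
}

// Scan the sequence in chunks of `chunk_size`, carrying the running prefix.
std::vector<float> chunked_scan(const std::vector<float>& a,
                                const std::vector<float>& b,
                                float h0, std::size_t chunk_size) {
    assert(a.size() == b.size() && chunk_size > 0);
    std::vector<float> h(a.size());
    Pair running_prefix{1.0f, h0};  // seeded once, before the first chunk
    for (std::size_t start = 0; start < a.size(); start += chunk_size) {
        const std::size_t end = std::min(start + chunk_size, a.size());
        Pair local = running_prefix;  // the prefix enters the chunk unchanged
        for (std::size_t t = start; t < end; ++t) {
            local = combine(local, Pair{a[t], b[t]});
            h[t] = local.b;
        }
        // Carry the prefix combined with the whole chunk into the next one.
        running_prefix = local;
    }
    return h;
}
```

The result does not depend on the chunk size: chunk sizes of 1, 2, or the full sequence length all give the same hidden states as the plain recurrence started at h0.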
Thank you so much! 😀 It works after following #488.
@tridao
Hello!
I am currently working with the selective scan forward kernel, specifically the step h_t = A*h_{t-1} + Bx, where h_0 is currently set to 0. I would like to modify this behavior to allow h_0 to be an initial value (init_value).
Upon reviewing the code, I noticed that the InclusiveScan function from Ktraits::BlockScanT(smem_scan) does not seem to support an initial value option. Here is the relevant line of code: Ktraits::BlockScanT(smem_scan).InclusiveScan()
Could you provide some guidance on how to modify the code to support an initial value for h_0? Any help would be greatly appreciated.