Adding Initial Value Support to Selective Scan Forward Kernel #285

Open
EricLina opened this issue Apr 9, 2024 · 10 comments
@EricLina

EricLina commented Apr 9, 2024

@tridao
Hello!

I am currently working with the selective scan forward kernel, specifically the step h_t = A*h_{t-1} + Bx, where h_0 is currently set to 0. I would like to modify this behavior to allow h_0 to be an initial value (init_value).

Upon reviewing the code, I noticed that the InclusiveScan function from Ktraits::BlockScanT(smem_scan) does not seem to support an initial value option. Here is the relevant line of code: Ktraits::BlockScanT(smem_scan).InclusiveScan()

Could you provide some guidance on how to modify the code to support an initial value for h_0? Any help would be greatly appreciated.
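To make the behavior I'm after concrete, here is a host-side Python sketch of the recurrence (the `h0` argument and function name are mine, not the kernel's):

```python
# Reference semantics of h_t = A*h_{t-1} + B*x_t, with an optional
# initial state h0 (the kernel currently fixes h0 = 0).
def selective_scan_ref(a, bx, h0=0.0):
    """a[t] is the discretized A at step t; bx[t] is B_t * x_t."""
    h, out = h0, []
    for t in range(len(a)):
        h = a[t] * h + bx[t]
        out.append(h)
    return out
```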

@tridao
Collaborator

tridao commented Apr 9, 2024

The initial value is set in the prefix_op. You probably want to change this line:

running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);

to something like (I haven't tested this):

if (chunk == 0) {
    running_prefix = threadIdx.x % 32 == 0 ? initial_value[(r * params.n_chunks + chunk) * params.dstate + state_idx] : make_float2(1.f, 0.f);
} else {
    running_prefix = threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);
}
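The reason this works: each scan element (a, b) represents the affine map h -> a*h + b, and the identity prefix (1, 0) corresponds to h0 = 0, so seeding the prefix with (1, h0) injects the initial state. A host-side Python sketch of the combine rule (illustrative names, not the kernel's exact code):

```python
# Mirrors the SSM scan operator: composing h -> a1*h + b1 then
# h -> a2*h + b2 gives h -> (a2*a1)*h + (a2*b1 + b2).
def ssm_combine(prev, cur):
    a1, b1 = prev
    a2, b2 = cur
    return (a2 * a1, a2 * b1 + b2)

def inclusive_scan(elems, prefix):
    """Inclusive scan seeded with a running prefix, like prefix_op."""
    out, acc = [], prefix
    for e in elems:
        acc = ssm_combine(acc, e)
        out.append(acc)
    return out

# With prefix (1, 0) the scan starts from h0 = 0; with (1, h0) every
# h_t picks up the extra a_1*...*a_t * h0 term.
```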

@EricLina
Author

If initial_value's shape is (B, D, N), is the following access pattern correct?

Here is the code where I get the pointer to the initial value:

weight_t *initial_value = reinterpret_cast<weight_t *>(params.init_state_ptr) + batch_id * params.init_state_batch_stride
        + dim_id * kNRows * params.init_state_d_stride; 

I then access the initial value in the following way:

for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
    ....
    for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
        ....
        for (int r = 0; r < kNRows; ++r) {
            ....
            if (chunk == 0) {
                running_prefix = threadIdx.x % 32 == 0 ? initial_value[state_idx * params.init_state_d_stride + r * params.init_state_n_stride] : make_float2(1.f, 0.f);
            } else {
                running_prefix = threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);
            }
            ....
        }
        ....
    }
    ....
}

I would appreciate it if you could review this code and provide any feedback or suggestions. I am particularly interested in whether I am accessing the initial value correctly given its shape. Thank you in advance for your help.
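To double-check the stride arithmetic on the host, I wrote a small Python sketch for a contiguous (B, D, N) tensor (hypothetical helper; contiguity is assumed):

```python
# Flat offset into a contiguous tensor of shape (B, D, N): the strides
# are (D*N, N, 1), so the pointer arithmetic should reduce to
# b*D*N + d*N + n.
def offset(b, d, n, D, N):
    batch_stride, d_stride, n_stride = D * N, N, 1
    return b * batch_stride + d * d_stride + n * n_stride

# Element (b=1, d=2, n=3) of a (2, 3, 4) tensor sits at flat index 23.
```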

@tridao
Collaborator

tridao commented Apr 10, 2024

You can print stuff out with printf to see if you're accessing the right indices.

@EricLina
Author

Thanks😀

@EricLina
Author

Hello, Tridao @tridao
I am currently working on the selective_scan_bwd_kernel.cuh file and have encountered a challenge regarding the implementation of backward propagation for d_init_state.

In the corresponding forward file (fwd.cuh), only InclusiveScan is utilized. However, in bwd.cuh, two operations are present: InclusiveScan and InclusiveReverseScan. The relevant code snippet is as follows:

                // Initialize running total
                scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
                SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
                Ktraits::BlockScanT(smem_scan).InclusiveScan(
                    thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
                );
                scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
                SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
                Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
                    thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
                );

I am unsure about the necessary modifications if I want to pass a d_initial_value into the backward interface. Could you please tell me how to adjust running_prefix and postfix_op?

I would greatly appreciate any advice or guidance on this matter.

Thank you in advance for your assistance.
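For reference, my current understanding of what the reverse scan computes, as a scalar Python sketch (my own reading, not the kernel's exact layout):

```python
def hidden_state_grads(a, c, dy):
    """g[t] = dL/dh_t for the recurrence h_t = a[t]*h_{t-1} + b[t] with
    output y_t = c[t]*h_t, computed via the reverse recurrence
    g[t] = c[t]*dy[t] + a[t+1]*g[t+1] (what the reverse scan yields)."""
    T = len(a)
    g, acc = [0.0] * T, 0.0
    for t in reversed(range(T)):
        acc = c[t] * dy[t] + (a[t + 1] * acc if t + 1 < T else 0.0)
        g[t] = acc
    return g
```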

@EricLina EricLina reopened this Apr 22, 2024
@Hprairie
Contributor

@EricLina I might be a little late to this, but I was looking at something similar. The forward scan is used to recalculate hidden states, while the reverse scan is used to propagate gradients backward. The gradient for the initial hidden state should just be dh * exp(delta_0 * A), where dh is the gradient of the hidden state at timestep 0 and delta_0 is the delta of timestep 0. The values of thread_reverse_data are just dy * C. To propagate the gradients to each step, they need to be multiplied by exp(delta * A) for every timestep at which the hidden state was used. You could probably do something like this to calculate the gradient of the initial hidden state.

if (chunk == 0 && threadIdx.x == 0) {
    dInitialState = thread_reverse_data[0] * smem_delta_a[state_idx + (chunk % 2) * MAX_DSTATE];
}

If your input initial state is of the shape (B D N), then you should be able to store this directly.

Let me know if that makes sense.
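A quick scalar sanity check of that formula in Python (toy loss, hypothetical names): the analytic initial-state gradient a[0] * dh_0 should match a central finite difference of the loss with respect to the initial state.

```python
# Toy forward pass: h_t = a[t]*h_{t-1} + b[t], loss L = sum_t c[t]*h_t,
# so dL/dy_t = 1. Here a[t] plays the role of exp(delta_t * A).
def loss(h_init, a, b, c):
    h, total = h_init, 0.0
    for t in range(len(a)):
        h = a[t] * h + b[t]
        total += c[t] * h
    return total

a, b, c = [0.5, 0.25], [1.0, 1.0], [1.0, 1.0]
eps = 1e-6
fd = (loss(2.0 + eps, a, b, c) - loss(2.0 - eps, a, b, c)) / (2 * eps)
# dh_0 = c[0] + a[1]*c[1]; the initial-state gradient is a[0] * dh_0.
analytic = a[0] * (c[0] + a[1] * c[1])
```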

@EricLina
Author


Thanks!🤓

@TianCuteQY

TianCuteQY commented Dec 12, 2024

Hi! Have you been able to resolve the issue? I'm also currently working on integrating an initial value into the selective scan forward kernel.

To access the initial_state, I used the following code:

weight_t *hidden_prev = reinterpret_cast<weight_t *>(params.hidden_prev_ptr) 
                              + batch_id * params.hidden_prev_batch_stride 
                              + dim_id * kNRows * params.hidden_prev_n_stride;

And to load the initial weight, I implemented this logic:

if (chunk == 0) {
    if constexpr (kIsComplex) {
        running_prefix = (threadIdx.x % 32 == 0) 
            ? make_float4(
                hidden_prev[state_idx * params.hidden_prev_dim_stride + r * params.hidden_prev_n_stride].real_, 
                hidden_prev[state_idx * params.hidden_prev_dim_stride + r * params.hidden_prev_n_stride].imag_, 
                0.f, 0.f
            )
            : make_float4(1.f, 0.f, 0.f, 0.f);

    } else {
        running_prefix = (threadIdx.x % 32 == 0) 
            ? make_float2(
                hidden_prev[state_idx * params.hidden_prev_dim_stride + r * params.hidden_prev_n_stride], 
                0.f
            )
            : make_float2(1.f, 0.f);

        if (threadIdx.x % 32 == 0) {
            printf(
                "Loaded from hidden_prev: Chunk: %d, Batch: %d, Dim: %d, State: %d, Row: %d, RunningPrefix: (%f, %f)\n", 
                chunk, batch_id, dim_id, state_idx, r, 
                running_prefix.x, running_prefix.y
            );
        }
    }
} else {
    ...
}

I’ve verified that the loaded running_prefix matches the initial_state input. However, when using the modified kernel for Mamba forward, the result doesn't match the one produced by the Mamba step function when providing an initial ssm_state.

Even though this might be an older issue, I’d greatly appreciate your insights or suggestions on what might be causing the discrepancy. Thank you in advance!

@Hprairie
Contributor

This looks like you are loading your value into the exp(delta * A) slot while leaving the hidden-state slot as 0. You should probably instead load it in as:

running_prefix = (threadIdx.x % 32 == 0) 
            ? make_float2(
                1.f,
                hidden_prev[state_idx * params.hidden_prev_dim_stride + r * params.hidden_prev_n_stride]
            )
            : make_float2(1.f, 0.f);

The layout for the CUDA kernel float2 is (exp(delta * A), hidden_state).

I think there is also a PR for this #488. I haven't tried it yet, but I would recommend taking a look into this.
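A tiny host-side check (Python sketch, illustrative names) of why the slot order matters: with the pair laid out as (exp(delta * A), hidden_state) and composed by the usual SSM operator, only the seed (1, h0) reproduces the expected first hidden state.

```python
# Composing h -> a1*h + b1 then h -> a2*h + b2, on pairs laid out as
# (multiplier, hidden_state) -- the float2 layout described above.
def combine(prev, cur):
    return (cur[0] * prev[0], cur[0] * prev[1] + cur[1])

step = (0.5, 1.0)                 # one timestep: h -> 0.5*h + 1
right = combine((1.0, 2.0), step)  # seed (1, h0) with h0 = 2
wrong = combine((2.0, 0.0), step)  # h0 placed in the multiplier slot
# right[1] == 0.5*2 + 1 == 2.0, the expected first hidden state;
# wrong[1] == 1.0, silently dropping the initial state (and wrong[0]
# corrupts the running multiplier for later chunks).
```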

@TianCuteQY


Thank you so much! 😀 It works by following #488.
