Windows support? #692

Open · wants to merge 4 commits into base: main

4 changes: 3 additions & 1 deletion .gitignore
@@ -3,4 +3,6 @@
 build/
 **.so
 *.hip
-*_hip.*
+*_hip.*
+.idea/
+dist/
70 changes: 28 additions & 42 deletions csrc/selective_scan/selective_scan_bwd_kernel.cuh
@@ -1,6 +1,6 @@
 /******************************************************************************
 * Copyright (c) 2023, Tri Dao.
-******************************************************************************/
+*****************************************************************************/
 
 #pragma once
 
@@ -9,6 +9,10 @@
 #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
 #include <ATen/cuda/Atomic.cuh>  // For atomicAdd on complex
 
+#ifndef M_LOG2E
+#define M_LOG2E 1.4426950408889634074f
+#endif
+
 #ifndef USE_ROCM
 #include <cub/block/block_load.cuh>
 #include <cub/block/block_store.cuh>
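Aside (not part of the diff): MSVC's <cmath> only provides M_LOG2E (log2 e) when _USE_MATH_DEFINES is defined before inclusion, which is presumably why the guard above adds a portable fallback for Windows builds. The constant drives the fast-exponential idiom exp(x) = exp2(x * log2 e) that this kernel family relies on; a minimal sketch, with fast_expf as a hypothetical name:

__device__ __forceinline__ float fast_expf(float x) {
    // exp(x) == exp2(x * log2(e)); exp2f lowers to the GPU's hardware ex2 instruction.
    return exp2f(x * M_LOG2E);
}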
@@ -28,6 +32,20 @@ template<typename scalar_t> __device__ __forceinline__ scalar_t conj(scalar_t x)
 template<> __device__ __forceinline__ float conj<float>(float x) { return x; }
 template<> __device__ __forceinline__ complex_t conj<complex_t>(complex_t x) { return std::conj(x); }
 
+// Helper: set the kernel's max dynamic shared memory size.
+// This helper is defined at global scope so that preprocessor directives are not inside lambdas.
+template<typename KernelT>
+__host__ inline void setDynamicSharedMemoryAttr(KernelT kernel, int smemSize) {
+    if (smemSize >= 48 * 1024) {
+#ifndef USE_ROCM
+        C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize));
+#else
+        C10_CUDA_CHECK(cudaFuncSetAttribute((void *)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize));
+        std::cerr << "Warning (selective_scan_bwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU, which is currently a no-op (in ROCm versions <= 6.1). This might lead to undefined behavior.\n" << std::endl;
+#endif
+    }
+}
+
 template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kIsVariableB_, bool kIsVariableC_,
          bool kDeltaSoftplus_, bool kHasZ_, typename input_t_, typename weight_t_>
 struct Selective_Scan_bwd_kernel_traits {
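Aside: the hoisting described in the helper's comment matters because the C++ standard leaves preprocessor directives inside macro arguments undefined, and MSVC rejects them outright; the BOOL_SWITCH dispatch macros used later in this file receive lambdas as macro arguments. A self-contained sketch of the failure mode and the fix, with DISPATCH and backend_banner as hypothetical stand-ins:

#include <cstdio>

#define DISPATCH(COND, LAMBDA) do { if (COND) { LAMBDA(); } } while (0)

// Portable: the #ifdef lives in an ordinary function body, not in a macro argument.
inline void backend_banner() {
#ifdef USE_ROCM
    std::printf("ROCm build\n");
#else
    std::printf("CUDA build\n");
#endif
}

int main() {
    DISPATCH(true, [&] { backend_banner(); });  // fine: the lambda body holds no directives
    // Placing the #ifdef directly inside the lambda that DISPATCH receives would
    // put a directive inside a macro argument: undefined behavior, and a hard
    // error under MSVC.
    return 0;
}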
@@ -94,10 +112,6 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
 
     // Shared memory.
     extern __shared__ char smem_[];
-    // cast to lvalue reference of expected type
-    // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
-    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
-    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
     auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
     auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
     auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
@@ -158,7 +172,6 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
         u -= kChunkSize;
         __syncthreads();
         load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
-        // Will reload delta at the same location if kDeltaSoftplus
         if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
         __syncthreads();
         load_input<Ktraits>(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
@@ -198,13 +211,10 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
             }
             __syncthreads();
             store_output<Ktraits>(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize);
-            if (params.out_z_ptr != nullptr) { // Recompute and store out_z
+            if (params.out_z_ptr != nullptr) {
                 float out_z_vals[kNItems];
                 #pragma unroll
                 for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; }
-                // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
-                //     printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]);
-                // }
                 input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
                     + dim_id * params.out_z_d_stride + chunk * kChunkSize;
                 __syncthreads();
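Aside: z_silu_vals caches what appears to be silu(z) = z * sigmoid(z), so the recomputation above amounts to out_z = out * silu(z). The same arithmetic in scalar form, for reference:

#include <cmath>

// Scalar reference (illustrative only): out_z = out * silu(z), with silu(z) = z * sigmoid(z).
inline float silu(float z) { return z / (1.0f + std::exp(-z)); }
inline float out_z_ref(float out_val, float z) { return out_val * silu(z); }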
@@ -245,7 +255,6 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
                 load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
                     smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
             }
-            // const weight_t A_val = smem_a[state_idx];
             scan_t thread_data[kNItems], thread_reverse_data[kNItems];
             if constexpr (!kIsComplex) {
                 #pragma unroll
@@ -266,7 +275,6 @@
                 thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
                     ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
                     : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
-                // Initialize running total
                 scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
                 SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
                 typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
@@ -289,9 +297,9 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
                     ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a;
                     dA_val += dx * delta_vals[i] * a;
                     if constexpr (!kIsVariableB || !kIsVariableC) {
-                        if constexpr (!kIsVariableB) { // dBC_val is dB_val
+                        if constexpr (!kIsVariableB) {
                             dBC_val += dout_vals[i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]);
-                        } else { // dBC_val is dC_val
+                        } else {
                             dBC_val += dout_vals[i] * thread_data[i].y;
                         }
                     }
@@ -300,7 +308,6 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
                         dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y);
                     }
                 }
-                // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
                 if constexpr (kIsVariableB || kIsVariableC) {
                     if constexpr (kIsVariableB) {
                         typename Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
@@ -336,7 +343,6 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
             } else {
                 #pragma unroll
                 for (int i = 0; i < kNItems; ++i) {
-                    // Pytorch's implementation of complex exp (which calls thrust) is very slow
                     complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled);
                     weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]);
                     thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
@@ -359,7 +365,6 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
                     : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
                 thread_reverse_data[kNItems - 1].x = delta_a_exp.real_;
                 thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_;
-                // Initialize running total
                 scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
                 SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
                 typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
@@ -379,9 +384,9 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
                     complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w);
                     float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_;
                     if constexpr (!kIsVariableB || !kIsVariableC) {
-                        if constexpr (!kIsVariableB) { // dBC_val is dB_val
+                        if constexpr (!kIsVariableB) {
                             dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]);
-                        } else { // dBC_val is dC_val
+                        } else {
                             dBC_val += (2 * dout_vals[i]) * conj(x);
                         }
                     }
@@ -394,7 +399,6 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
                         dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? x * B_val : x);
                     }
                 }
-                // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
                 if constexpr (kIsVariableB || kIsVariableC) {
                     float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2];
                     if constexpr (kIsVariableB) {
@@ -431,7 +435,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
                 dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y);
                 dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w);
                 if (threadIdx.x == 0) {
-                    smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx];
+                    smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx];
                 }
             } else {
                 dA_val = typename Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val);
@@ -502,27 +506,10 @@ void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
             BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
                 BOOL_SWITCH(params.z_ptr != nullptr, kHasZ, [&] {
                     using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
-                    // using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
-                    // TODO: check this
                     constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t);
-
                     dim3 grid(params.batch, params.dim);
-
                     auto kernel = &selective_scan_bwd_kernel<Ktraits>;
-
-                    if (kSmemSize >= 48 * 1024) {
-
-                    #ifndef USE_ROCM
-                        C10_CUDA_CHECK(cudaFuncSetAttribute(
-                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-                    #else
-                        C10_CUDA_CHECK(cudaFuncSetAttribute(
-                            (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-                        std::cerr << "Warning (selective_scan_bwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU, which is currently a no-op (in ROCm versions <= 6.1). This might lead to undefined behavior.\n" << std::endl;
-                    #endif
-
-                    }
-
+                    setDynamicSharedMemoryAttr(kernel, kSmemSize);
                     kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
                     C10_CUDA_KERNEL_LAUNCH_CHECK();
                 });
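Background for the call above: CUDA caps dynamic shared memory at 48 KB per block by default, and a kernel that wants more must opt in once via cudaFuncSetAttribute before launching; that check-and-set is exactly what setDynamicSharedMemoryAttr wraps. A standalone sketch of the pattern, with big_smem_kernel as a hypothetical kernel:

#include <cuda_runtime.h>

__global__ void big_smem_kernel(float *out) {
    extern __shared__ float smem[];           // dynamic shared memory
    smem[threadIdx.x] = float(threadIdx.x);
    __syncthreads();
    if (threadIdx.x == 0) { out[blockIdx.x] = smem[0]; }
}

void launch_big_smem(float *out, cudaStream_t stream) {
    const int smem_bytes = 64 * 1024;         // above the 48 KB default limit
    // Required once per kernel before any launch that requests more than 48 KB.
    cudaFuncSetAttribute(big_smem_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    big_smem_kernel<<<1, 256, smem_bytes, stream>>>(out);
}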
@@ -534,7 +521,6 @@ void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
 
 template<typename input_t, typename weight_t>
 void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
-
     #ifndef USE_ROCM
     if (params.seqlen <= 128) {
         selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
@@ -547,7 +533,7 @@ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
     } else {
         selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
     }
-    #else
+    #else
     if (params.seqlen <= 256) {
         selective_scan_bwd_launch<64, 4, input_t, weight_t>(params, stream);
     } else if (params.seqlen <= 512) {
@@ -558,4 +544,4 @@ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
         selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
     }
     #endif
-}
+}
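Closing aside: each instantiation selective_scan_bwd_launch<kNThreads, kNItems> walks the sequence in chunks of kNThreads * kNItems elements, which is why the dispatch above widens the block configuration as seqlen grows. Illustrative chunk arithmetic:

// Illustrative only: one block of kNThreads threads, each holding kNItems
// elements, consumes one chunk of the sequence per iteration of the main loop.
constexpr int kNThreads = 128, kNItems = 16;
constexpr int kChunkSize = kNThreads * kNItems;   // 2048 elements per chunk
inline int n_chunks(int seqlen) { return (seqlen + kChunkSize - 1) / kChunkSize; }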