replacing ifdefs on host code with those on kernels
Aleksandr Malyshev committed Aug 1, 2024
1 parent 00600ac commit e2b0310
Showing 2 changed files with 110 additions and 59 deletions.
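In outline, the pattern this commit applies looks like the HIP sketch below: the #if defined(__gfx942__) guard moves from the host wrapper functions onto the kernel definitions themselves, and an assert(false) stub kernel is compiled for other targets so the now-unguarded host launcher still builds and links everywhere. The kernel name, signature, and launcher here are illustrative placeholders, not code from this diff.

#include <cassert>
#include <hip/hip_runtime.h>

#if defined(__gfx942__)  // TODO: Add NAVI support
// Real kernel body, compiled only for gfx942.
__global__ void example_kernel(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * 2.0f;
}
#else  // !defined(__gfx942__)
// Stub so the translation unit still compiles and links on other
// architectures; it must never actually be launched there.
__global__ void example_kernel(const float* in, float* out, int n) {
  assert(false);
}
#endif  // defined(__gfx942__)

// The host launcher no longer needs its own #ifdef: it is compiled for
// every architecture and only launches the real kernel on gfx942 builds.
void launch_example(const float* in, float* out, int n, hipStream_t stream) {
  dim3 block(256);
  dim3 grid((n + block.x - 1) / block.x);
  example_kernel<<<grid, block, 0, stream>>>(in, out, n);
}
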
88 changes: 60 additions & 28 deletions csrc/custom/custom_kernels.cu
@@ -5,8 +5,6 @@
#include <stdexcept>
#include <algorithm>

#if defined(__gfx942__)

constexpr int WARP_SIZE = 64;

template <typename T>
@@ -332,6 +330,8 @@ __device__ __forceinline__ T loadnt(T* addr) {
#define M 1
#define DTYPE half

#if defined(__gfx942__) // TODO: Add NAVI support

__global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
const int CuCount) {
@@ -459,6 +459,18 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B,
}
}

#else // !defined(__gfx942__) TODO: Add NAVI support

__global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
const int CuCount) {
assert(false);
}

#endif // defined(__gfx942__) TODO: Add NAVI support

#if defined(__gfx942__) // TODO: Add NAVI support

__global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
const int CuCount) {
@@ -804,6 +816,16 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B,
}
}

#else // !defined(__gfx942__) TODO: Add NAVI support

__global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
const int CuCount) {
assert(false);
}

#endif // defined(__gfx942__) TODO: Add NAVI support

#undef YTILE
#undef UNRL
#undef M
@@ -812,6 +834,8 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B,
#define UNRL 2
#define M 2

#if defined(__gfx942__) // TODO: Add NAVI support

__global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
const int CuCount) {
@@ -1157,6 +1181,16 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B,
}
}

#else // !defined(__gfx942__) TODO: Add NAVI support

__global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
const int CuCount) {
assert(false);
}

#endif // defined(__gfx942__) TODO: Add NAVI support

#undef YTILE
#undef UNRL
#undef M
@@ -1165,6 +1199,8 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B,
#define UNRL 2
#define M 3

#if defined(__gfx942__) // TODO: Add NAVI support

__global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
const int CuCount) {
@@ -1510,6 +1546,16 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B,
}
}

#else // !defined(__gfx942__) TODO: Add NAVI support

__global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
const int CuCount) {
assert(false);
}

#endif // defined(__gfx942__) TODO: Add NAVI support

#undef YTILE
#undef UNRL
#undef M
@@ -1518,6 +1564,8 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B,
#define UNRL 1
#define M 4

#if defined(__gfx942__) // TODO: Add NAVI support

__global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
const int CuCount) {
@@ -1863,6 +1911,16 @@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B,
}
}

#else // !defined(__gfx942__) TODO: Add NAVI support

__global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
const int CuCount) {
assert(false);
}

#endif // defined(__gfx942__) TODO: Add NAVI support

void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in,
const int K_in, const int N_in, cudaStream_t stream,
const int CuCount = 0) {
@@ -1904,29 +1962,3 @@ void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in,
throw std::runtime_error("CUDA kernel failed : " + std::to_string(err));
}
}

#else // defined(__gfx942__)

void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K,
cudaStream_t stream, const int rows_per_block = 4) {
TORCH_CHECK(false, "LLGemm1 not supported on current arch");
}

void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K,
cudaStream_t stream, const int solidx = 0) {
TORCH_CHECK(false, "LLGemmZZ not supported on current arch");
}

void MMGPUKernel(float* in_a, float* in_b, float* out_c, int numARows,
int numAColumns, int numBRows, int numBColumns, int numCRows,
int numCColumns, cudaStream_t stream) {
TORCH_CHECK(false, "MMGPUKernel not supported on current arch");
}

void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in,
const int K_in, const int N_in, cudaStream_t stream,
const int CuCount = 0) {
TORCH_CHECK(false, "wvSpltK not supported on current arch");
}

#endif
81 changes: 50 additions & 31 deletions csrc/custom/paged_attention/attention_ll4mi.cu
@@ -5,13 +5,13 @@

#include <algorithm>

#if defined(__gfx942__)

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#define WARP_SIZE 64

#if defined(__gfx942__) // TODO: Add NAVI support

#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
#define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16

@@ -746,6 +746,54 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
out_ptr[threadIdx.x] = (scalar_t)acc;
}

#else // !defined(__gfx942__) TODO: Add NAVI support

template <typename scalar_t, int BLOCK_SIZE, int HEAD_SIZE, int NUM_THREADS,
int GQA_RATIO>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size]
#if 0
scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size]
#endif
int max_ctx_blocks) {
assert(false);
}

// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_partitions) {
assert(false);
}

#endif // defined(__gfx942__) TODO: Add NAVI support

#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \
paged_attention_ll4mi_QKV_kernel<T, BLOCK_SIZE, HEAD_SIZE, NTHR, GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
@@ -948,32 +996,3 @@ void paged_attention_custom(
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP

#else //defined(__gfx942__)

void paged_attention_custom(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int num_kv_heads, float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
int block_size, int max_context_len,
#if 0
torch::Tensor& qk_out,
torch::Tensor& softmax_out,
#endif
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype) {
TORCH_CHECK(false, "paged_attention_custom not supported on current arch");
}


#endif
