forked from DefTruth/CUDA-Learn-Notes
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsgemv.cu
102 lines (96 loc) · 3.58 KB
/
sgemv.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <torch/types.h>
#include <torch/extension.h>
// Number of threads in a warp on NVIDIA GPUs. The SGEMV kernels in this file are
// written for blockDim.x == WARP_SIZE (one block row == one warp).
#define WARP_SIZE 32
// Reinterpret an lvalue as an int4 / float4 so it can be read or written with a
// single 128-bit vector access. The address of `value` must be 16-byte aligned,
// otherwise the access is misaligned.
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
// -------------------------------------- FP32 --------------------------------------
// Warp-level sum reduction using an XOR (butterfly) shuffle.
//
// Reduces `val` across each aligned group of kWarpSize consecutive lanes. After
// the call every lane of a group holds that group's total (the XOR butterfly
// leaves the result in all lanes, not only the first). With kWarpSize < 32 the
// warp is split into 32 / kWarpSize independent groups; e.g. kWarpSize = 16
// reduces lanes 0-15 and lanes 16-31 separately.
//
// Preconditions:
//  - kWarpSize must be a power of two in [1, WARP_SIZE]; otherwise the XOR
//    partners cross group boundaries and the sums are silently wrong
//    (enforced at compile time below).
//  - All 32 lanes of the warp must call this function together, because the
//    shuffle uses the full 0xffffffff participation mask. Do not call it from
//    a branch that only part of the warp takes.
template<const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ float warp_reduce_sum_f32(float val) {
  static_assert(kWarpSize >= 1 && kWarpSize <= WARP_SIZE &&
                (kWarpSize & (kWarpSize - 1)) == 0,
                "kWarpSize must be a power of two no larger than WARP_SIZE");
  #pragma unroll
  for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
    val += __shfl_xor_sync(0xffffffff, val, mask);
  }
  return val;
}
// SGEMV, one warp per matrix row, scalar loads ("K32" variant).
//
// Computes y = a * x, where `a` is a row-major M x K matrix, `x` has K
// elements and `y` has M elements.
//
// Launch layout: block(WARP_SIZE, R), grid(ceil(M / R)); e.g. block(32, 4) with
// grid((M + 3) / 4). blockDim.x MUST equal WARP_SIZE so that each threadIdx.y
// row of the block is exactly one warp; that warp owns matrix row
// m = blockIdx.x * blockDim.y + threadIdx.y.
//
// Each lane accumulates a[m][k] * x[k] for k = lane, lane + 32, lane + 64, ...
// (adjacent lanes read adjacent addresses), then the warp reduces the 32
// partial sums and lane 0 stores y[m]. K does not need to be a multiple of 32:
// columns past K are skipped (K == 0 yields y[m] = 0).
__global__ void sgemv_k32(float* a, float* x, float* y, int M, int K) {
  int tx = threadIdx.x;       // 0 ~ 31
  int ty = threadIdx.y;       // 0 ~ blockDim.y - 1
  int bx = blockIdx.x;        // 0 ~ gridDim.x - 1
  int lane = tx % WARP_SIZE;  // 0 ~ 31
  int m = bx * blockDim.y + ty;
  // With blockDim.x == WARP_SIZE, m is uniform across the warp, so this branch
  // never splits a warp and the full-mask shuffle reduction below is safe.
  if (m < M) {
    // 64-bit row offset: m * K can exceed INT_MAX for large matrices.
    const size_t row = static_cast<size_t>(m) * K;
    float sum = 0.0f;
    int NUM_WARPS = (K + WARP_SIZE - 1) / WARP_SIZE;
    #pragma unroll
    for (int w = 0; w < NUM_WARPS; ++w) {
      int k = w * WARP_SIZE + lane;
      // Guard the tail chunk when K is not a multiple of WARP_SIZE.
      if (k < K) sum += a[row + k] * x[k];
    }
    sum = warp_reduce_sum_f32<WARP_SIZE>(sum);
    if (lane == 0) y[m] = sum;
  }
}
// SGEMV, one warp per matrix row, 128-bit (float4) vectorized loads.
//
// Computes y = a * x, where `a` is a row-major M x K matrix, `x` has K
// elements and `y` has M elements.
//
// Launch layout (same as sgemv_k32): block(WARP_SIZE, R), grid(ceil(M / R)),
// blockDim.x MUST equal WARP_SIZE; the warp at threadIdx.y owns row
// m = blockIdx.x * blockDim.y + threadIdx.y.
//
// Vector path (K % 4 == 0): each lane loads 4 consecutive floats per step, so a
// warp covers 4 * WARP_SIZE = 128 columns per iteration. float4 accesses need
// 16-byte alignment, so the base pointers `a` and `x` must be 16-byte aligned
// (NOTE(review): callers passing offset/sliced tensor views must check this).
// K does not need to be a multiple of 128; the tail is bounds-checked.
// Scalar fallback (K % 4 != 0): rows of `a` are then not 16-byte aligned, so
// each lane reads one float per step instead.
__global__ void sgemv_k128_f32x4(float* a, float* x, float* y, int M, int K) {
  int tx = threadIdx.x;       // 0 ~ 31
  int ty = threadIdx.y;       // 0 ~ blockDim.y - 1
  int bx = blockIdx.x;        // 0 ~ gridDim.x - 1
  int lane = tx % WARP_SIZE;  // 0 ~ 31
  int m = blockDim.y * bx + ty;
  // Warp-uniform branch (see layout above), so the full-mask shuffle is safe.
  if (m < M) {
    // 64-bit row offset: m * K can exceed INT_MAX for large matrices.
    float* a_row = a + static_cast<size_t>(m) * K;
    float sum = 0.0f;
    if (K % 4 == 0) {
      // k visits lane*4, lane*4 + 128, ... — the same order as the chunk index
      // w in k = (w * WARP_SIZE + lane) * 4. Because k and K are both multiples
      // of 4, k < K guarantees the whole float4 [k, k + 3] is in range.
      for (int k = lane * 4; k < K; k += WARP_SIZE * 4) {
        float4 reg_x = FLOAT4(x[k]);
        float4 reg_a = FLOAT4(a_row[k]);
        sum += (reg_a.x * reg_x.x + reg_a.y * reg_x.y
                + reg_a.z * reg_x.z + reg_a.w * reg_x.w);
      }
    } else {
      // Rows are not 16-byte aligned: fall back to scalar loads.
      for (int k = lane; k < K; k += WARP_SIZE) {
        sum += a_row[k] * x[k];
      }
    }
    sum = warp_reduce_sum_f32<WARP_SIZE>(sum);
    if (lane == 0) y[m] = sum;
  }
}
// SGEMV for short rows: several matrix rows per warp ("K16" variant).
//
// Computes y = A * x, where `A` is a row-major M x K matrix, `x` has K
// elements and `y` has M elements. The warp is split into ROW_PER_WARP lane
// groups of K_WARP_SIZE = WARP_SIZE / ROW_PER_WARP lanes (16 for the default
// ROW_PER_WARP = 2); each group owns one row. The kernel targets
// K == K_WARP_SIZE (one element per lane), but any K is handled: lanes stride
// over the row by K_WARP_SIZE, and when K < K_WARP_SIZE the extra lanes
// contribute 0 instead of reading out of bounds.
//
// Launch layout: NUM_WARPS = threads per block / WARP_SIZE (e.g. 128 / 32 = 4),
// NUM_ROWS = NUM_WARPS * ROW_PER_WARP, block(WARP_SIZE, NUM_WARPS),
// grid(ceil(M / NUM_ROWS)). blockDim.x MUST equal WARP_SIZE.
//
// Unlike the one-row-per-warp kernels, the lane groups of one warp can own rows
// on both sides of M, so the row guard is NOT warp-uniform. The full-mask
// shuffle reduction is therefore executed by every lane (out-of-range rows
// contribute 0) and only the final store is guarded; issuing the shuffle from
// inside the guard is undefined when M is not a multiple of ROW_PER_WARP.
template<const int ROW_PER_WARP = 2>
__global__ void sgemv_k16(float* A, float* x, float* y, int M, int K) {
  static_assert(ROW_PER_WARP >= 1 && WARP_SIZE % ROW_PER_WARP == 0,
                "ROW_PER_WARP must divide WARP_SIZE (a power of two <= 32)");
  constexpr int K_WARP_SIZE = (WARP_SIZE + ROW_PER_WARP - 1) / ROW_PER_WARP;
  int tx = threadIdx.x;        // 0 ~ 31
  int ty = threadIdx.y;        // 0 ~ NUM_WARPS - 1
  int bx = blockIdx.x;         // 0 ~ gridDim.x - 1
  int lane = tx % WARP_SIZE;   // 0 ~ 31
  int k = lane % K_WARP_SIZE;  // first column of this lane: 0 ~ K_WARP_SIZE - 1
  // Global row of A (M x K) and y (M x 1); blockDim.y == NUM_WARPS.
  int m = (blockDim.y * bx + ty) * ROW_PER_WARP + lane / K_WARP_SIZE;
  float sum = 0.0f;
  if (m < M) {
    // 64-bit row offset: m * K can exceed INT_MAX for large matrices.
    const size_t row = static_cast<size_t>(m) * K;
    for (int kk = k; kk < K; kk += K_WARP_SIZE) {
      sum += A[row + kk] * x[kk];
    }
  }
  // Every lane must reach the shuffle-based reduction (see header comment).
  sum = warp_reduce_sum_f32<K_WARP_SIZE>(sum);
  // Store from the first lane of each group (k == 0), not from lane == 0.
  if (m < M && k == 0) y[m] = sum;
}