@@ -15,17 +15,106 @@ namespace fbgemm_gpu {
 
 #if CUDART_VERSION >= 12000
 
-// FP8 Tensorwise grouped cutlass kernel dispatch.
+// BF16 grouped cutlass kernel dispatch.
 template <typename InputType>
 at::Tensor dispatch_bf16_grouped_kernel(
+    int G,
     int total_M,
+    int N,
+    int K,
     InputType X, // BF16
     InputType W, // BF16
     at::Tensor output,
     std::optional<at::Tensor> zero_start_index_M = std::nullopt,
     std::optional<at::Tensor> M_sizes = std::nullopt) {
   // Use heuristics to pick best kernel implementation.
 
+  // Llama4 128E
+  if (G == 128) {
+    if (N == 5120 && K == 1024) {
+      if (total_M <= 128) {
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 256) {
+        return bf16bf16bf16_grouped_128_32_128_2_1_1_t(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 2048) {
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 4096) {
+        return bf16bf16bf16_grouped_128_32_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 8192) {
+        return bf16bf16bf16_grouped_128_64_128_1_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 16384) {
+        return bf16bf16bf16_grouped_128_128_128_2_1_1_t(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else {
+        return bf16bf16bf16_grouped_128_256_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      }
+    }
+
+    if (N == 2048 && K == 5120) {
+      if (total_M <= 2048) {
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else {
+        return bf16bf16bf16_grouped_128_128_128_2_1_1_t(
+            X, W, output, zero_start_index_M, M_sizes);
+      }
+    }
+  }
+
+  // Llama4 64E
+  if (G == 16) {
+    if (N == 5120 && K == 1024) {
+      if (total_M <= 32) {
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 64) {
+        return bf16bf16bf16_grouped_128_32_128_2_1_1_t(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 256) {
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 512) {
+        return bf16bf16bf16_grouped_128_32_128_2_1_1_t(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 1024) {
+        return bf16bf16bf16_grouped_128_64_128_2_1_1_t(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else {
+        return bf16bf16bf16_grouped_128_256_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      }
+    }
+
+    if (N == 2048 && K == 5120) {
+      if (total_M <= 16) {
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 64) {
+        return bf16bf16bf16_grouped_128_32_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 256) {
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 512) {
+        return bf16bf16bf16_grouped_128_32_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 1024) {
+        return bf16bf16bf16_grouped_128_64_128_1_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else {
+        return bf16bf16bf16_grouped_128_128_128_2_1_1_t(
+            X, W, output, zero_start_index_M, M_sizes);
+      }
+    }
+  }
+
+  // Fallback to legacy heuristic for now.
   if (total_M <= 16) {
     return bf16bf16bf16_grouped_128_16_128_1_1_1_f(
         X, W, output, zero_start_index_M, M_sizes);
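The new routing is pure shape arithmetic, so it is easy to sanity-check on the host. The standalone sketch below is not part of the patch: the helper name select_bf16_grouped_tile is made up for illustration, and it restates only the (G == 128, N == 5120, K == 1024) ladder from the hunk above, returning the name of the kernel instantiation the dispatcher would pick; every other shape falls through to the legacy total_M-only heuristic.

#include <iostream>
#include <string>

// Illustrative restatement of one branch of the new dispatch heuristic.
// The returned strings name the kernel instantiations called in the diff.
std::string select_bf16_grouped_tile(int G, int total_M, int N, int K) {
  if (G == 128 && N == 5120 && K == 1024) {
    if (total_M <= 128)   return "bf16bf16bf16_grouped_128_16_128_2_1_1_f";
    if (total_M <= 256)   return "bf16bf16bf16_grouped_128_32_128_2_1_1_t";
    if (total_M <= 2048)  return "bf16bf16bf16_grouped_128_16_128_2_1_1_f";
    if (total_M <= 4096)  return "bf16bf16bf16_grouped_128_32_128_2_1_1_f";
    if (total_M <= 8192)  return "bf16bf16bf16_grouped_128_64_128_1_1_1_f";
    if (total_M <= 16384) return "bf16bf16bf16_grouped_128_128_128_2_1_1_t";
    return "bf16bf16bf16_grouped_128_256_128_2_1_1_f";
  }
  // Any other (G, N, K) keeps the previous total_M-only selection.
  return "legacy heuristic";
}

int main() {
  // total_M = 64     -> bf16bf16bf16_grouped_128_16_128_2_1_1_f
  std::cout << select_bf16_grouped_tile(128, 64, 5120, 1024) << "\n";
  // total_M = 32768  -> bf16bf16bf16_grouped_128_256_128_2_1_1_f
  std::cout << select_bf16_grouped_tile(128, 32768, 5120, 1024) << "\n";
  return 0;
}

Keying the choice on (G, N, K) as well as total_M lets the Llama4 expert shapes use tuned tile configurations while every other workload keeps the previous behaviour, as the "Fallback to legacy heuristic" comment notes.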
@@ -52,13 +141,18 @@ OutputType _bf16bf16bf16_grouped(at::TensorList X, at::TensorList W) {
   at::Tensor Y;
   int64_t total_M = 0;
   int64_t G = X.size();
+  int64_t max_N = 0;
+  int64_t max_K = 0;
 
   // Allocate output tensor.
   std::vector<int64_t> output_sizes;
   int64_t total_output_size = 0;
   for (int i = 0; i < G; ++i) {
     int64_t M = X[i].size(0);
     int64_t N = W[i].size(0);
+    int64_t K = W[i].size(1);
+    max_N = std::max(max_N, N);
+    max_K = std::max(max_K, K);
     total_M += M;
     const int64_t output_size = M * N;
     total_output_size += output_size;
@@ -67,8 +161,8 @@ OutputType _bf16bf16bf16_grouped(at::TensorList X, at::TensorList W) {
   Y = at::empty(total_output_size, X[0].options().dtype(at::kBFloat16));
 
   // Run kernel.
-  at::Tensor g_out =
-      dispatch_bf16_grouped_kernel<at::TensorList>(total_M, X, W, Y);
+  at::Tensor g_out = dispatch_bf16_grouped_kernel<at::TensorList>(
+      G, total_M, max_N, max_K, X, W, Y);
 
   // Return appropriate output type.
   if constexpr (std::is_same_v<OutputType, at::Tensor>) {
@@ -98,6 +192,7 @@ at::Tensor
 bf16bf16bf16_grouped_stacked(at::Tensor X, at::Tensor W, at::Tensor M_sizes) {
   int64_t total_M = X.size(0);
   int64_t N = W.size(1);
+  int64_t K = W.size(2);
   int64_t G = M_sizes.size(0);
   TORCH_CHECK(
       M_sizes.device() == X.device(),
@@ -111,7 +206,7 @@ bf16bf16bf16_grouped_stacked(at::Tensor X, at::Tensor W, at::Tensor M_sizes) {
   }
   // Return continuous view of output.
   at::Tensor out = dispatch_bf16_grouped_kernel<at::Tensor>(
-      total_M, X, W, Y, std::nullopt, M_sizes);
+      G, total_M, N, K, X, W, Y, std::nullopt, M_sizes);
   return out.view({total_M, N});
 }
 
@@ -125,13 +220,14 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
   int64_t G = X.size(0);
   int64_t M = X.size(1);
   int64_t N = W.size(1);
+  int64_t K = W.size(2);
   int64_t total_output_size = G * M * N;
   at::Tensor Y;
   Y = at::zeros(total_output_size, X.options().dtype(at::kBFloat16));
 
   // Return continuous view of output.
   at::Tensor output = dispatch_bf16_grouped_kernel<at::Tensor>(
-      G * M, X, W, Y, zero_start_index_M);
+      G, G * M, N, K, X, W, Y, zero_start_index_M);
   // View as proper shape.
   return output.view({G, M, N});
 }
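For completeness, here is a caller-side sketch of the stacked entry point under the shape conventions implied by this diff (X is [total_M, K], W is [G, N, K], M_sizes holds one row count per group). Only the function declaration is taken from the patch; the use of int64 group sizes on the same CUDA device as X and the specific Llama4-like sizes are assumptions for illustration.

#include <ATen/ATen.h>

namespace fbgemm_gpu {
// Declaration as it appears in this patch.
at::Tensor bf16bf16bf16_grouped_stacked(at::Tensor X, at::Tensor W, at::Tensor M_sizes);
} // namespace fbgemm_gpu

int main() {
  const int64_t G = 128, N = 2048, K = 5120;
  // Assumed: per-group row counts, int64, on the same device as X.
  at::Tensor M_sizes = at::full(
      {G}, /*fill_value=*/16, at::dtype(at::kLong).device(at::kCUDA));
  const int64_t total_M = M_sizes.sum().item<int64_t>(); // 128 * 16 = 2048

  at::Tensor X =
      at::randn({total_M, K}, at::dtype(at::kBFloat16).device(at::kCUDA));
  at::Tensor W =
      at::randn({G, N, K}, at::dtype(at::kBFloat16).device(at::kCUDA));

  // With these shapes the dispatcher sees G=128, N=2048, K=5120, total_M=2048
  // and should take the bf16bf16bf16_grouped_128_16_128_2_1_1_f branch.
  at::Tensor Y = fbgemm_gpu::bf16bf16bf16_grouped_stacked(X, W, M_sizes);
  return (Y.size(0) == total_M && Y.size(1) == N) ? 0 : 1;
}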