Commit 05a77f3

Browse files
Chris Thifacebook-github-bot
Chris Thi
authored andcommitted
Update heuristic for Cutlass BF16 Grouped GEMM
Summary:
X-link: facebookresearch/FBGEMM#1220

This diff updates the heuristic used for the Cutlass BF16 grouped GEMM, improving performance on some important shapes.

Differential Revision: D74836650
1 parent ffbec71 commit 05a77f3

7 files changed (+371, -5 lines)
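
To make "some important shapes" concrete before the per-file diffs: the following is a minimal standalone C++ sketch, not FBGEMM code, that distills the new G == 128 ("Llama4 128E") branch of the heuristic into a pure function returning the kernel it would dispatch to. The thresholds and kernel names are copied verbatim from the diff below; the function name and structure are ours.

#include <cstdint>
#include <iostream>
#include <string>

// Sketch of the new Llama4-128E branch of dispatch_bf16_grouped_kernel,
// reduced to a pure decision function. pick_kernel_llama4_128e is a
// hypothetical helper; thresholds and kernel names match the diff.
std::string pick_kernel_llama4_128e(int64_t total_M, int64_t N, int64_t K) {
  if (N == 5120 && K == 1024) {
    if (total_M <= 128)   return "bf16bf16bf16_grouped_128_16_128_2_1_1_f";
    if (total_M <= 256)   return "bf16bf16bf16_grouped_128_32_128_2_1_1_t";
    if (total_M <= 2048)  return "bf16bf16bf16_grouped_128_16_128_2_1_1_f";
    if (total_M <= 4096)  return "bf16bf16bf16_grouped_128_32_128_2_1_1_f";
    if (total_M <= 8192)  return "bf16bf16bf16_grouped_128_64_128_1_1_1_f";
    if (total_M <= 16384) return "bf16bf16bf16_grouped_128_128_128_2_1_1_t";
    return "bf16bf16bf16_grouped_128_256_128_2_1_1_f";
  }
  if (N == 2048 && K == 5120) {
    return total_M <= 2048 ? "bf16bf16bf16_grouped_128_16_128_2_1_1_f"
                           : "bf16bf16bf16_grouped_128_128_128_2_1_1_t";
  }
  return "<legacy heuristic>";
}

int main() {
  // e.g. total_M = 4096 over the N = 5120, K = 1024 projection:
  std::cout << pick_kernel_llama4_128e(4096, 5120, 1024) << "\n";
}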

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu

Lines changed: 101 additions & 5 deletions
@@ -15,17 +15,106 @@ namespace fbgemm_gpu {
 
 #if CUDART_VERSION >= 12000
 
-// FP8 Tensorwise grouped cutlass kernel dispatch.
+// BF16 grouped cutlass kernel dispatch.
 template <typename InputType>
 at::Tensor dispatch_bf16_grouped_kernel(
+    int G,
     int total_M,
+    int N,
+    int K,
     InputType X, // BF16
     InputType W, // BF16
     at::Tensor output,
     std::optional<at::Tensor> zero_start_index_M = std::nullopt,
     std::optional<at::Tensor> M_sizes = std::nullopt) {
   // Use heuristics to pick best kernel implementation.
 
+  // Llama4 128E
+  if (G == 128) {
+    if (N == 5120 && K == 1024) {
+      if (total_M <= 128) {
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 256) {
+        return bf16bf16bf16_grouped_128_32_128_2_1_1_t(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 2048) {
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 4096) {
+        return bf16bf16bf16_grouped_128_32_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 8192) {
+        return bf16bf16bf16_grouped_128_64_128_1_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 16384) {
+        return bf16bf16bf16_grouped_128_128_128_2_1_1_t(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else {
+        return bf16bf16bf16_grouped_128_256_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      }
+    }
+
+    if (N == 2048 && K == 5120) {
+      if (total_M <= 2048) {
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else {
+        return bf16bf16bf16_grouped_128_128_128_2_1_1_t(
+            X, W, output, zero_start_index_M, M_sizes);
+      }
+    }
+  }
+
+  // Llama4 64E
+  if (G == 16) {
+    if (N == 5120 && K == 1024) {
+      if (total_M <= 32) {
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 64) {
+        return bf16bf16bf16_grouped_128_32_128_2_1_1_t(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 256) {
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 512) {
+        return bf16bf16bf16_grouped_128_32_128_2_1_1_t(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 1024) {
+        return bf16bf16bf16_grouped_128_64_128_2_1_1_t(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else {
+        return bf16bf16bf16_grouped_128_256_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      }
+    }
+
+    if (N == 2048 && K == 5120) {
+      if (total_M <= 16) {
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 64) {
+        return bf16bf16bf16_grouped_128_32_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 256) {
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 512) {
+        return bf16bf16bf16_grouped_128_32_128_2_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else if (total_M <= 1024) {
+        return bf16bf16bf16_grouped_128_64_128_1_1_1_f(
+            X, W, output, zero_start_index_M, M_sizes);
+      } else {
+        return bf16bf16bf16_grouped_128_128_128_2_1_1_t(
+            X, W, output, zero_start_index_M, M_sizes);
+      }
+    }
+  }
+
+  // Fallback to legacy heuristic for now.
   if (total_M <= 16) {
     return bf16bf16bf16_grouped_128_16_128_1_1_1_f(
         X, W, output, zero_start_index_M, M_sizes);

@@ -52,13 +141,18 @@ OutputType _bf16bf16bf16_grouped(at::TensorList X, at::TensorList W) {
   at::Tensor Y;
   int64_t total_M = 0;
   int64_t G = X.size();
+  int64_t max_N = 0;
+  int64_t max_K = 0;
 
   // Allocate output tensor.
   std::vector<int64_t> output_sizes;
   int64_t total_output_size = 0;
   for (int i = 0; i < G; ++i) {
     int64_t M = X[i].size(0);
     int64_t N = W[i].size(0);
+    int64_t K = W[i].size(1);
+    max_N = std::max(max_N, N);
+    max_K = std::max(max_K, K);
     total_M += M;
     const int64_t output_size = M * N;
     total_output_size += output_size;

@@ -67,8 +161,8 @@ OutputType _bf16bf16bf16_grouped(at::TensorList X, at::TensorList W) {
   Y = at::empty(total_output_size, X[0].options().dtype(at::kBFloat16));
 
   // Run kernel.
-  at::Tensor g_out =
-      dispatch_bf16_grouped_kernel<at::TensorList>(total_M, X, W, Y);
+  at::Tensor g_out = dispatch_bf16_grouped_kernel<at::TensorList>(
+      G, total_M, max_N, max_K, X, W, Y);
 
   // Return appropriate output type.
   if constexpr (std::is_same_v<OutputType, at::Tensor>) {

@@ -98,6 +192,7 @@ at::Tensor
 bf16bf16bf16_grouped_stacked(at::Tensor X, at::Tensor W, at::Tensor M_sizes) {
   int64_t total_M = X.size(0);
   int64_t N = W.size(1);
+  int64_t K = W.size(2);
   int64_t G = M_sizes.size(0);
   TORCH_CHECK(
       M_sizes.device() == X.device(),

@@ -111,7 +206,7 @@ bf16bf16bf16_grouped_stacked(at::Tensor X, at::Tensor W, at::Tensor M_sizes) {
   }
   // Return continuous view of output.
   at::Tensor out = dispatch_bf16_grouped_kernel<at::Tensor>(
-      total_M, X, W, Y, std::nullopt, M_sizes);
+      G, total_M, N, K, X, W, Y, std::nullopt, M_sizes);
   return out.view({total_M, N});
 }
 
@@ -125,13 +220,14 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
   int64_t G = X.size(0);
   int64_t M = X.size(1);
   int64_t N = W.size(1);
+  int64_t K = W.size(2);
   int64_t total_output_size = G * M * N;
   at::Tensor Y;
   Y = at::zeros(total_output_size, X.options().dtype(at::kBFloat16));
 
   // Return continuous view of output.
   at::Tensor output = dispatch_bf16_grouped_kernel<at::Tensor>(
-      G * M, X, W, Y, zero_start_index_M);
+      G, G * M, N, K, X, W, Y, zero_start_index_M);
   // View as proper shape.
   return output.view({G, M, N});
 }
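
For orientation, a hypothetical caller of the stacked entry point above might look like the sketch below. The shapes are inferred from the size() accesses in the diff (X is [total_M, K], W is [G, N, K], M_sizes has one entry per group on X's device); the M_sizes dtype and the forward declaration standing in for an FBGEMM header are assumptions.

#include <ATen/ATen.h>

// Declaration copied from the diff; in practice this comes from an FBGEMM
// header rather than a local forward declaration.
namespace fbgemm_gpu {
at::Tensor bf16bf16bf16_grouped_stacked(
    at::Tensor X, at::Tensor W, at::Tensor M_sizes);
} // namespace fbgemm_gpu

int main() {
  // One of the Llama4 128E shapes from the heuristic: G = 128, N = 5120, K = 1024.
  const int64_t G = 128, N = 5120, K = 1024;
  const int64_t rows_per_group = 32;           // assumed workload
  const int64_t total_M = G * rows_per_group;  // = 4096

  auto opts = at::TensorOptions().device(at::kCUDA);
  at::Tensor X = at::randn({total_M, K}, opts).to(at::kBFloat16);
  at::Tensor W = at::randn({G, N, K}, opts).to(at::kBFloat16);
  at::Tensor M_sizes =
      at::full({G}, rows_per_group, opts.dtype(at::kLong));  // dtype assumed

  // Per the diff, the result is returned as a [total_M, N] view.
  at::Tensor out = fbgemm_gpu::bf16bf16bf16_grouped_stacked(X, W, M_sizes);
  return out.size(1) == N ? 0 : 1;
}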
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "bf16bf16bf16_grouped_common.cuh"
+
+namespace fbgemm_gpu {
+
+at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_t(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor output,
+    std::optional<at::Tensor> zero_start_index_M,
+    std::optional<at::Tensor> M_sizes) {
+  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 2, 1, 1, true>(
+      X, W, output, zero_start_index_M, M_sizes);
+}
+
+at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_t(
+    at::TensorList X, // BF16
+    at::TensorList W, // BF16
+    at::Tensor output,
+    std::optional<at::Tensor> zero_start_index_M,
+    std::optional<at::Tensor> M_sizes) {
+  return bf16bf16bf16_grouped_impl<
+      at::TensorList,
+      128,
+      128,
+      128,
+      2,
+      1,
+      1,
+      true>(X, W, output, zero_start_index_M, M_sizes);
+}
+
+} // namespace fbgemm_gpu
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "bf16bf16bf16_grouped_common.cuh"
+
+namespace fbgemm_gpu {
+
+at::Tensor bf16bf16bf16_grouped_128_16_128_2_1_1_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor output,
+    std::optional<at::Tensor> zero_start_index_M,
+    std::optional<at::Tensor> M_sizes) {
+  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 16, 128, 2, 1, 1, false>(
+      X, W, output, zero_start_index_M, M_sizes);
+}
+
+at::Tensor bf16bf16bf16_grouped_128_16_128_2_1_1_f(
+    at::TensorList X, // BF16
+    at::TensorList W, // BF16
+    at::Tensor output,
+    std::optional<at::Tensor> zero_start_index_M,
+    std::optional<at::Tensor> M_sizes) {
+  return bf16bf16bf16_grouped_impl<
+      at::TensorList,
+      128,
+      16,
+      128,
+      2,
+      1,
+      1,
+      false>(X, W, output, zero_start_index_M, M_sizes);
+}
+
+} // namespace fbgemm_gpu
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "bf16bf16bf16_grouped_common.cuh"
+
+namespace fbgemm_gpu {
+
+at::Tensor bf16bf16bf16_grouped_128_32_128_2_1_1_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor output,
+    std::optional<at::Tensor> zero_start_index_M,
+    std::optional<at::Tensor> M_sizes) {
+  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 32, 128, 2, 1, 1, false>(
+      X, W, output, zero_start_index_M, M_sizes);
+}
+
+at::Tensor bf16bf16bf16_grouped_128_32_128_2_1_1_f(
+    at::TensorList X, // BF16
+    at::TensorList W, // BF16
+    at::Tensor output,
+    std::optional<at::Tensor> zero_start_index_M,
+    std::optional<at::Tensor> M_sizes) {
+  return bf16bf16bf16_grouped_impl<
+      at::TensorList,
+      128,
+      32,
+      128,
+      2,
+      1,
+      1,
+      false>(X, W, output, zero_start_index_M, M_sizes);
+}
+
+} // namespace fbgemm_gpu
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "bf16bf16bf16_grouped_common.cuh"
+
+namespace fbgemm_gpu {
+
+at::Tensor bf16bf16bf16_grouped_128_32_128_2_1_1_t(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor output,
+    std::optional<at::Tensor> zero_start_index_M,
+    std::optional<at::Tensor> M_sizes) {
+  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 32, 128, 2, 1, 1, true>(
+      X, W, output, zero_start_index_M, M_sizes);
+}
+
+at::Tensor bf16bf16bf16_grouped_128_32_128_2_1_1_t(
+    at::TensorList X, // BF16
+    at::TensorList W, // BF16
+    at::Tensor output,
+    std::optional<at::Tensor> zero_start_index_M,
+    std::optional<at::Tensor> M_sizes) {
+  return bf16bf16bf16_grouped_impl<at::TensorList, 128, 32, 128, 2, 1, 1, true>(
+      X, W, output, zero_start_index_M, M_sizes);
+}
+
+} // namespace fbgemm_gpu
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "bf16bf16bf16_grouped_common.cuh"
+
+namespace fbgemm_gpu {
+
+at::Tensor bf16bf16bf16_grouped_128_64_128_2_1_1_t(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor output,
+    std::optional<at::Tensor> zero_start_index_M,
+    std::optional<at::Tensor> M_sizes) {
+  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 64, 128, 2, 1, 1, true>(
+      X, W, output, zero_start_index_M, M_sizes);
+}
+
+at::Tensor bf16bf16bf16_grouped_128_64_128_2_1_1_t(
+    at::TensorList X, // BF16
+    at::TensorList W, // BF16
+    at::Tensor output,
+    std::optional<at::Tensor> zero_start_index_M,
+    std::optional<at::Tensor> M_sizes) {
+  return bf16bf16bf16_grouped_impl<at::TensorList, 128, 64, 128, 2, 1, 1, true>(
+      X, W, output, zero_start_index_M, M_sizes);
+}
+
+} // namespace fbgemm_gpu
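
One pattern worth noting across the new files above: each function name encodes the template arguments it passes to bf16bf16bf16_grouped_impl. Reading bf16bf16bf16_grouped_128_64_128_2_1_1_t against bf16bf16bf16_grouped_impl<..., 128, 64, 128, 2, 1, 1, true> suggests the first three integers are a tile shape, the next three a cluster shape, and the trailing t/f mirrors the final boolean (likely a kernel-schedule toggle); that reading is our inference, not something this diff states. Splitting each configuration into its own translation unit plausibly keeps per-file CUTLASS template instantiation (and compile time) bounded, though again the diff does not say so. A trivial sketch of the name mapping:

#include <cstdio>

// Hypothetical helper: format the kernel name implied by a configuration,
// mirroring bf16bf16bf16_grouped_impl<T, tb_m, tb_n, tb_k, c_m, c_n, c_k, flag>.
void print_kernel_name(
    int tb_m, int tb_n, int tb_k, int c_m, int c_n, int c_k, bool flag) {
  std::printf(
      "bf16bf16bf16_grouped_%d_%d_%d_%d_%d_%d_%c\n",
      tb_m, tb_n, tb_k, c_m, c_n, c_k, flag ? 't' : 'f');
}

int main() {
  // Prints: bf16bf16bf16_grouped_128_64_128_2_1_1_t
  print_kernel_name(128, 64, 128, 2, 1, 1, true);
}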
