
Commit 9932686

Chris Thi authored and facebook-github-bot committed
Refactor Cutlass BF16 Grouped GEMM (#4124)
Summary: Pull Request resolved: #4124

X-link: facebookresearch/FBGEMM#1205

We plan to make some changes to the kernel heuristics to improve performance on this kernel. Do a quick refactor first to parallelize kernel compilation, similar to [cutlass FP8 rowwise](https://www.internalfb.com/code/fbsource/fbcode/deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/), to keep the next diffs smaller. No functional changes in this diff.

Reviewed By: jianyuh

Differential Revision: D74760416

fbshipit-source-id: 138fbc8b62e6d22ed60448e79050c4d1ebd470aa
1 parent 52f07e7 commit 9932686

10 files changed: +877, -478 lines

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu

Lines changed: 8 additions & 478 deletions
Large diffs are not rendered by default.
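The refactored dispatch logic in this file is not rendered above. As a rough illustration only, a heuristic that forwards to the per-configuration entry points added in this diff might look like the sketch below; the function name dispatch_bf16bf16bf16_grouped and the thresholds are hypothetical, and the real selection logic may differ:

// Purely illustrative dispatcher sketch; not the actual code in
// bf16bf16bf16_grouped.cu, whose diff is not rendered here.
static at::Tensor dispatch_bf16bf16bf16_grouped(
    int64_t total_M, // total rows across all groups (hypothetical heuristic input)
    at::Tensor X,
    at::Tensor W,
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  // Smaller problems favor smaller tile configurations; larger ones favor
  // larger tiles. Thresholds here are placeholders, not the tuned heuristic.
  if (total_M <= 16) {
    return bf16bf16bf16_grouped_128_16_128_1_1_1_f(
        X, W, output, zero_start_index_M, M_sizes);
  } else if (total_M <= 64) {
    return bf16bf16bf16_grouped_128_64_128_1_1_1_f(
        X, W, output, zero_start_index_M, M_sizes);
  }
  return bf16bf16bf16_grouped_256_128_128_2_1_1_f(
      X, W, output, zero_start_index_M, M_sizes);
}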
@@ -0,0 +1,40 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "bf16bf16bf16_grouped_common.cuh"

namespace fbgemm_gpu {

at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_f(
    at::Tensor X, // BF16
    at::Tensor W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 1, 1, false>(
      X, W, output, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_f(
    at::TensorList X, // BF16
    at::TensorList W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<
      at::TensorList,
      128,
      128,
      128,
      1,
      1,
      1,
      false>(X, W, output, zero_start_index_M, M_sizes);
}

} // namespace fbgemm_gpu
@@ -0,0 +1,40 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "bf16bf16bf16_grouped_common.cuh"

namespace fbgemm_gpu {

at::Tensor bf16bf16bf16_grouped_128_16_128_1_1_1_f(
    at::Tensor X, // BF16
    at::Tensor W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 16, 128, 1, 1, 1, false>(
      X, W, output, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_128_16_128_1_1_1_f(
    at::TensorList X, // BF16
    at::TensorList W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<
      at::TensorList,
      128,
      16,
      128,
      1,
      1,
      1,
      false>(X, W, output, zero_start_index_M, M_sizes);
}

} // namespace fbgemm_gpu
@@ -0,0 +1,40 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "bf16bf16bf16_grouped_common.cuh"

namespace fbgemm_gpu {

at::Tensor bf16bf16bf16_grouped_128_256_128_1_1_1_f(
    at::Tensor X, // BF16
    at::Tensor W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 256, 128, 1, 1, 1, false>(
      X, W, output, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_128_256_128_1_1_1_f(
    at::TensorList X, // BF16
    at::TensorList W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<
      at::TensorList,
      128,
      256,
      128,
      1,
      1,
      1,
      false>(X, W, output, zero_start_index_M, M_sizes);
}

} // namespace fbgemm_gpu
@@ -0,0 +1,40 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "bf16bf16bf16_grouped_common.cuh"

namespace fbgemm_gpu {

at::Tensor bf16bf16bf16_grouped_128_256_128_2_1_1_f(
    at::Tensor X, // BF16
    at::Tensor W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 256, 128, 2, 1, 1, false>(
      X, W, output, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_128_256_128_2_1_1_f(
    at::TensorList X, // BF16
    at::TensorList W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<
      at::TensorList,
      128,
      256,
      128,
      2,
      1,
      1,
      false>(X, W, output, zero_start_index_M, M_sizes);
}

} // namespace fbgemm_gpu
@@ -0,0 +1,40 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "bf16bf16bf16_grouped_common.cuh"

namespace fbgemm_gpu {

at::Tensor bf16bf16bf16_grouped_128_32_128_1_1_1_f(
    at::Tensor X, // BF16
    at::Tensor W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 32, 128, 1, 1, 1, false>(
      X, W, output, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_128_32_128_1_1_1_f(
    at::TensorList X, // BF16
    at::TensorList W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<
      at::TensorList,
      128,
      32,
      128,
      1,
      1,
      1,
      false>(X, W, output, zero_start_index_M, M_sizes);
}

} // namespace fbgemm_gpu
@@ -0,0 +1,40 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "bf16bf16bf16_grouped_common.cuh"

namespace fbgemm_gpu {

at::Tensor bf16bf16bf16_grouped_128_64_128_1_1_1_f(
    at::Tensor X, // BF16
    at::Tensor W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 64, 128, 1, 1, 1, false>(
      X, W, output, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_128_64_128_1_1_1_f(
    at::TensorList X, // BF16
    at::TensorList W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<
      at::TensorList,
      128,
      64,
      128,
      1,
      1,
      1,
      false>(X, W, output, zero_start_index_M, M_sizes);
}

} // namespace fbgemm_gpu
@@ -0,0 +1,40 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "bf16bf16bf16_grouped_common.cuh"

namespace fbgemm_gpu {

at::Tensor bf16bf16bf16_grouped_256_128_128_2_1_1_f(
    at::Tensor X, // BF16
    at::Tensor W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<at::Tensor, 256, 128, 128, 2, 1, 1, false>(
      X, W, output, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_256_128_128_2_1_1_f(
    at::TensorList X, // BF16
    at::TensorList W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<
      at::TensorList,
      256,
      128,
      128,
      2,
      1,
      1,
      false>(X, W, output, zero_start_index_M, M_sizes);
}

} // namespace fbgemm_gpu
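For the per-configuration entry points above to be callable from the dispatcher, their declarations presumably get collected in a shared header, analogous to the f8f8bf16_rowwise layout. A hypothetical sketch of such a declarations header follows; the file name and layout are assumptions, and the signatures are copied from the definitions above:

// Hypothetical declarations header; its name and layout are assumptions and it
// is not part of the rendered diff. Signatures match the definitions above.
#pragma once

#include <ATen/ATen.h>
#include <optional>

namespace fbgemm_gpu {

at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_f(
    at::Tensor X,
    at::Tensor W,
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes);

at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_f(
    at::TensorList X,
    at::TensorList W,
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes);

// ...one pair of declarations (at::Tensor and at::TensorList overloads) for
// each of the other tile configurations added in this diff.

} // namespace fbgemm_gpu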
