Commit e135dd5

Authored by lalala-sh, AMD-dteng, solin, sjfeng, and others
[CK_TILE] Add mxfp4 flatmm (#3080)
* Squashed commit of the following:

  - 3e1a851 Ding, Yi (2025-10-23): Fix & clean after rebase
  - 1edf485 Ding, Yi (2025-10-22): Squashed commit of the following:
    - 0b6b9db mtgu0705 (2025-09-22): fix bandwidth calculation
    - 9aebf53 mtgu0705 (2025-09-22): updates
    - 62607de mtgu0705 (2025-09-19): fix a bug, set the A DS_read preload size to 4 for MXFP4
    - 92ad6fc mtgu0705 (2025-09-18): fix a_wrap preload issue for large MPerBlock
    - f2db447 mtgu0705 (2025-09-17): optimized the VGPR repack issue for MXFP4
    - 346a400 Gino Lu (2025-09-17): fix time error
    - 80c1743 mtgu0705 (2025-09-17): updated, function passed
    - ce26d90 mtgu0705 (2025-09-16): fix, function partially passed
    - 0a89ed1 mtgu0705 (2025-09-16): fix, reference function passed, next check kernel function
    - ec9bcef Gino Lu (2025-09-16): let pack/unpack return pk_fp4_t
    - a333206 mtgu0705 (2025-09-15): fix
    - 3893c06 Gino Lu (2025-09-15): fix bug
    - 8052bea mtgu0705 (2025-09-15): fix core dump issue, function is not correct
    - 9ceb3fd mtgu0705 (2025-09-15): updates, build pass
    - cc94eb6 mtgu0705 (2025-09-15): updates
    - 22586c3 Gino Lu (2025-09-14): fix bug
    - e92e67b Gino Lu (2025-09-12): fix interface
    - 8b1dd60 Gino Lu (2025-09-12): add interface in warp_gemm_impl
    - c6135f6 mtgu0705 (2025-09-10): updates, some fixes
    - b0d71b8 mtgu0705 (2025-09-09): fix after merge of ginolu/add_wgmfma_dispatcher
    - f119c30 mtgu0705 (2025-09-08, merge): Merge remote-tracking branch 'origin/ginolu/add_wgmfma_dispatcher' into mtgu/cktile_mxfp4_flatmm_dev
    - c5030e6 mtgu0705 (2025-09-08): update mx flatmm tail pipeline
    - 72c8ef8 Gino Lu (2025-09-08, merge): Merge branch 'develop' into ginolu/add_wgmfma_dispatcher
    - 9661bb4 Gino Lu (2025-09-08): fix type error
    - 0509597 mtgu0705 (2025-09-08): update hotloop pipeline
    - 754ae04 Gino Lu (2025-09-05, merge): Merge branch 'develop' into ginolu/add_wgmfma_dispatcher
    - 15d4440 Gino Lu (2025-09-05): fix clang format
    - 146963d mtgu0705 (2025-09-03): some updates
    - 12526b6 asleepzzz (2025-09-03, merge): Merge branch 'develop' into ginolu/add_wgmfma_dispatcher
    - 47cee04 Gino Lu (2025-09-01): fix vec size error
    - d289292 Gino Lu (2025-09-01): fix format error
    - 16993ac mtgu0705 (2025-08-30): update codes
    - 9c37e55 mtgu0705 (2025-08-29): init ck_tile mxfp4 flatmm
    - 5c484a5 Feng Shijie (2025-08-28): Add bias for f16xf4 moe_flatmm
    - dd6539f Feng Shijie (2025-08-27): update case construction
    - 65b7024 Feng Shijie (2025-08-26): support swiglu activation and use rcpf to accelerate silu
    - b422e41 Gino Lu (2025-08-26): first commit
    - d05eed9 root (2025-08-22): add line to last
    - d69cab7 root (2025-08-22): adjust A_LDS descriptor to avoid bank conflict
    - 65989e9 root (2025-08-21): enable hotloop
    - c378e9b Feng Shijie (2025-08-21): support atomic_pk_add_bf16 on gfx950
    - 85976b0 Feng Shijie (2025-08-21): use int64_t as expert stride to avoid overflow
    - 9fbcc8f Feng Shijie (2025-08-20): use v4i32 as the storage type for B to avoid repack operation
    - 81899bd Feng Shijie (2025-08-20): add pk_fp4_t and e8m0_t support for amd_buffer_load_impl
    - c27eb07 Feng Shijie (2025-08-20): optimize cvt_pkf4_to_f16 implementation
    - 3ca0bd5 Feng Shijie (2025-08-19): optimize A_LDS descriptor to avoid bank conflict
    - f7f0306 Feng Shijie (2025-08-18): fix gate-up when GU_NRepeat > 1
    - be55c0f Feng Shijie (2025-08-18): add fp16xf4 moe
    - 599e1f5 Feng Shijie (2025-08-17): rename example
    - 7899fb4 Feng Shijie (2025-08-15): remove additional check when e8m0->float
    - 714b341 Feng Shijie (2025-08-14): eliminate repeat dequant
    - 53e8c0c Feng Shijie (2025-08-13, merge): Merge remote-tracking branch 'origin/moe_flatmm' into feat-mixed_input_flatmm
    - 5de6208 Feng Shijie (2025-08-13): update f16xMXF4
    - 732ebde Feng Shijie (2025-08-13): update scale-preshuffle for MXF4
    - edb58d0 Feng Shijie (2025-08-11): update
    - cc9c7b9 Feng Shijie (2025-08-11): optimize gemm2 atomic_add pattern
    - 200a11a Feng Shijie (2025-08-11): update scale for mxfp4
    - 87aed56 Feng Shijie (2025-08-11): update case construction
    - 8b85fa6 Feng Shijie (2025-08-11): update granularity control
    - 1b8c709 Feng Shijie (2025-08-11): fix TileConfig
    - 8ba1c70 Gino Lu (2025-08-07): Add e8m0 scaled convert into CK_TILE (#2617): first commit; remove redundant code; modify according to comments; fix type_convert error with scaled_type_convert
    - f788d3d Feng Shijie (2025-08-08): add mixed_prec fp16xfp4
    - 3dea10a Feng Shijie (2025-08-07): debug mixed_prec flatmm
    - 0ba513b lalala-sh (2025-08-06, merge): Merge pull request #2626 from ROCm/felix/flatmm_fix_splitk (fix split k)
    - 6d3cbc7 Feng Shijie (2025-08-06): add moe_flatmm
    - c0cb4d0 coderfeli (2025-08-06): fix split k
    - 90e910f Feng Shijie (2025-08-04): fix flatmm with scaling when WarpTileM == 32
    - aa5e008 Feng Shijie (2025-08-01): optimize scaling epilogue
    - ac5908c Feng Shijie (2025-08-01): fix wrong config for fp8 scaling
    - 3f43b84 Feng Shijie (2025-07-30): prune debug message
    - 2e5d4c7 Feng Shijie (2025-07-30): fix compile error
    - c117a19 Feng Shijie (2025-07-29): Add persistent option on flatmm for tuning
    - a587701 AMD-dteng (2025-07-29): update pipeline v1: add atomic IGLP schedule
    - f9e4814 lalala-sh (2025-07-24): fix error log throwing
    - 1b6d7cf Feng Shijie (2025-07-28): crz idea
    - 5473f06 Feng Shijie (2025-07-27): Add permuteN optimization when NRepeat % 2 == 0 on flatmm
    - bfb9f40 sjfeng (2025-07-27): try to remove c_shuffle_lds
    - 1264f4d Feng Shijie (2025-07-25): fix loop-dim mismatch and improve c_shuffle alu parallelism
    - 1239d8a lalala-sh (2025-07-24, merge): merge flatmm-scale
    - 4066454 lalala-sh (2025-07-24): revert delete of inc file
    - 6839098 solin (2025-07-24): reorg flatmm code
    - b908f5e Feng Shijie (2025-07-23): fix flatmm syntax error on gfx950
    - 5a1183e Feng Shijie (2025-07-23): support flatmm scaling
    - 89fa639 valarLip (2025-07-23): merge flatmm pipe v0 from dteng_flatmm_opt
    - 3f7d848 lalala-sh (2025-07-23): build pass
    - 6dacf83 lalala-sh (2025-07-23): fix bug
    - 7e1bd4b lalala-sh (2025-07-23): sync
    - 46a538e valarLip (2025-07-22): adaptive scheduler instead of Macro definition
    - 9aa3396 lalala-sh (2025-07-17): fix tail handler bug
    - fb76450 lalala-sh (2025-07-16): merge from dteng_flatmm_opt

* Fix crash on small M

* Apply suggestion from @Copilot

Co-authored-by: lalala-sh <[email protected]>
Co-authored-by: AMD-dteng <[email protected]>
Co-authored-by: solin <[email protected]>
Co-authored-by: sjfeng <[email protected]>
Co-authored-by: valarLip <[email protected]>
Co-authored-by: asleepzzz <[email protected]>
Co-authored-by: Feng Shijie <[email protected]>
Co-authored-by: coderfeli <[email protected]>
Co-authored-by: Gino Lu <[email protected]>
Co-authored-by: mtgu0705 <[email protected]>
Parent: b387249

File tree: 13 files changed (+2953, -6 lines)


example/ck_tile/18_flatmm/CMakeLists.txt

Lines changed: 2 additions & 1 deletion

@@ -14,6 +14,7 @@ if(has_supported_gpu)
     add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp)
     add_executable(tile_example_a16w4_moe_flatmm EXCLUDE_FROM_ALL mixed_prec/a16w4_moe_flatmm.cpp)
     add_executable(tile_example_grouped_flatmm EXCLUDE_FROM_ALL grouped_flatmm.cpp)
+    add_executable(tile_example_mx_flatmm EXCLUDE_FROM_ALL mxgemm/mx_flatmm.cpp) # TODO: 950 only

     set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
     set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS)
@@ -27,6 +28,6 @@ if(has_supported_gpu)
     target_compile_options(tile_example_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
     target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
     target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
-
+    target_compile_options(tile_example_mx_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) # TODO: 950 only
 endif()

example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp

Lines changed: 506 additions & 0 deletions (large diff not rendered)
Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <string>
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host/kernel_launch.hpp"
+#include "ck_tile/ops/epilogue.hpp"
+#include "ck_tile/ops/flatmm.hpp"
+#include "ck_tile/ops/gemm.hpp"
+
+#include "mxfp4_flatmm.hpp"
Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+
+// GEMM config with 16x16 warp tile
+struct MXfp4_FlatmmConfig16
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 512;
+    static constexpr ck_tile::index_t K_Tile = 256;
+
+    static constexpr ck_tile::index_t M_Warp = 1;
+    static constexpr ck_tile::index_t N_Warp = 4;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 16;
+    static constexpr ck_tile::index_t N_Warp_Tile = 16;
+    static constexpr ck_tile::index_t K_Warp_Tile = 128;
+
+    static constexpr bool kPadM = false;
+    static constexpr bool kPadN = false;
+    static constexpr bool kPadK = false;
+
+    static constexpr bool TransposeC           = false;
+    static constexpr bool UseStructuredSparsity = false;
+
+    static constexpr int kBlockPerCu            = 1;
+    static constexpr int TileParitionerGroupNum = 8;
+    static constexpr int TileParitionerM01      = 4;
+    static constexpr auto Scheduler             = ck_tile::GemmPipelineScheduler::Default;
+    static constexpr ck_tile::index_t NumWaveGroups = 1;
+    static constexpr bool DoubleSmemBuffer          = false;
+
+    static constexpr int N_Repeat          = N_Tile / N_Warp_Tile / N_Warp;
+    static constexpr bool TiledMMAPermuteN = false;
+};
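The config above fixes the block tile at 128x512x256 (M, N, K) and splits it across a 1x4x1 warp grid of 16x16x128 warp tiles, so each warp sweeps its sub-tile several times per block. A minimal host-side sketch (not part of the commit; the Cfg mirror struct is illustrative) that reproduces the repeat arithmetic, including the N_Repeat = 512 / 16 / 4 = 8 value the struct computes:

#include <cstdio>

struct Cfg // mirrors the constants in MXfp4_FlatmmConfig16
{
    static constexpr int M_Tile = 128, N_Tile = 512, K_Tile = 256;
    static constexpr int M_Warp = 1, N_Warp = 4, K_Warp = 1;
    static constexpr int M_Warp_Tile = 16, N_Warp_Tile = 16, K_Warp_Tile = 128;
};

int main()
{
    // Per-warp iteration counts inside one block tile.
    constexpr int m_repeat = Cfg::M_Tile / Cfg::M_Warp_Tile / Cfg::M_Warp; // 128/16/1 = 8
    constexpr int n_repeat = Cfg::N_Tile / Cfg::N_Warp_Tile / Cfg::N_Warp; // 512/16/4 = 8
    constexpr int k_repeat = Cfg::K_Tile / Cfg::K_Warp_Tile / Cfg::K_Warp; // 256/128/1 = 2
    static_assert(n_repeat == 8, "matches N_Repeat in the config");
    std::printf("repeats: M=%d N=%d K=%d\n", m_repeat, n_repeat, k_repeat);
}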
Lines changed: 167 additions & 0 deletions

@@ -0,0 +1,167 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+template <typename PrecActType,
+          typename PrecWeightType,
+          typename CDataType,
+          typename FlatmmConfig,
+          bool UsePersistentKernel = false,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout>
+int run_mx_flatmm_with_layouts(int argc,
+                               char* argv[],
+                               const ALayout a_layout = ALayout{},
+                               const BLayout b_layout = BLayout{},
+                               const CLayout c_layout = CLayout{})
+{
+    auto [result, arg_parser] = create_args(argc, argv);
+    if(!result)
+        return -1;
+
+    using ADataType   = PrecActType;
+    using BDataType   = PrecWeightType;
+    using AccDataType = float;
+
+    using ScaleType = ck_tile::e8m0_t;
+
+    constexpr int ScaleGranularityM = 1;
+    constexpr int ScaleGranularityN = 1;
+    constexpr int ScaleGranularityK = 32;
+
+    ck_tile::index_t M = arg_parser.get_int("m");
+    ck_tile::index_t N = arg_parser.get_int("n");
+    ck_tile::index_t K = arg_parser.get_int("k");
+
+    ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
+    ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
+    ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
+
+    ck_tile::index_t kbatch      = arg_parser.get_int("split_k");
+    ck_tile::index_t init_method = arg_parser.get_int("init");
+    ck_tile::index_t n_warmup    = arg_parser.get_int("warmup");
+    ck_tile::index_t n_repeat    = arg_parser.get_int("repeat");
+
+    stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
+    stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
+    stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout));
+
+    auto scale_stride_A = ck_tile::get_default_stride(
+        M / ScaleGranularityM, K / ScaleGranularityK, 0, is_row_major(a_layout));
+    auto scale_stride_B = ck_tile::get_default_stride(
+        K / ScaleGranularityK, N / ScaleGranularityN, 0, is_row_major(b_layout));
+
+    if(K % ScaleGranularityK != 0)
+        throw std::runtime_error("wrong! K must be multiple of ScaleGranularityK.");
+    if(K % ck_tile::numeric_traits<ADataType>::PackedSize != 0 ||
+       K % ck_tile::numeric_traits<BDataType>::PackedSize != 0)
+        throw std::runtime_error("wrong! K must be multiple of packed size.");
+
+    ck_tile::HostTensor<ADataType> a_host(
+        ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
+    ck_tile::HostTensor<BDataType> b_origin_host(
+        ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
+    ck_tile::HostTensor<CDataType> c_rslt_host(
+        ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
+
+    ck_tile::HostTensor<ScaleType> scale_a(ck_tile::host_tensor_descriptor(
+        M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout)));
+    ck_tile::HostTensor<ScaleType> scale_b(ck_tile::host_tensor_descriptor(
+        K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout)));
+
+    if(init_method == 0)
+    {
+        ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
+        ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
+        ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_a);
+        ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_b);
+    }
+    else if(init_method == 1)
+    {
+        ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
+        ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
+        ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_a);
+        ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
+    }
+    else
+    {
+        throw std::runtime_error("wrong! Unexpected init_method");
+    }
+
+    ck_tile::HostTensor<BDataType> b_shuffled_host(
+        ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
+    preShuffleWeight<FlatmmConfig>(b_origin_host.begin(), b_shuffled_host.begin(), N, K);
+
+    const auto scale_a_shuffled = preShuffleScale<FlatmmConfig, true>(scale_a);
+    const auto scale_b_shuffled = preShuffleScale<FlatmmConfig, false>(scale_b);
+
+    ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem b_shuffled_dev_buf(b_shuffled_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
+
+    ck_tile::DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes());
+
+    a_dev_buf.ToDevice(a_host.data());
+    b_shuffled_dev_buf.ToDevice(b_shuffled_host.data());
+    c_rslt_host.SetZero();
+    scale_a_dev_buf.ToDevice(scale_a_shuffled.data());
+    scale_b_dev_buf.ToDevice(scale_b_shuffled.data());
+
+    auto scale_a_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK>{
+        static_cast<float*>(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM};
+    auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
+        static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
+
+    invoke_mx_flatmm<FlatmmConfig,
+                     ADataType,
+                     BDataType,
+                     ck_tile::tuple<>,
+                     AccDataType,
+                     CDataType,
+                     ALayout,
+                     BLayout,
+                     ck_tile::tuple<>,
+                     CLayout,
+                     decltype(scale_a_dev_ptr),
+                     decltype(scale_b_dev_ptr),
+                     UsePersistentKernel>(a_dev_buf,
+                                          b_shuffled_dev_buf,
+                                          c_dev_buf,
+                                          M,
+                                          N,
+                                          K,
+                                          stride_A,
+                                          stride_B,
+                                          stride_C,
+                                          kbatch,
+                                          scale_a_dev_ptr,
+                                          scale_b_dev_ptr,
+                                          n_warmup,
+                                          n_repeat);
+
+    c_dev_buf.FromDevice(c_rslt_host.data());
+
+    bool pass = true;
+    if(arg_parser.get_int("v") == 1)
+    {
+        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
+            ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
+        c_m_n_host_ref.SetZero();
+
+        ck_tile::reference_mx_gemm<ADataType, BDataType, ScaleType, AccDataType, CDataType>(
+            a_host, b_origin_host, c_m_n_host_ref, scale_a, scale_b);
+
+        const float rtol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
+        const float atol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
+
+        pass = ck_tile::check_err(
+            c_rslt_host, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
+
+        std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
+                  << std::endl;
+        std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
+    }
+
+    return pass;
+}
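The runner keeps one ck_tile::e8m0_t scale per 32 consecutive K elements (ScaleGranularityK = 32, the MX block size), at per-row granularity in M and per-column granularity in N. An e8m0 value stores only a biased 8-bit exponent, i.e. 2^(E - 127), with E = 0xFF reserved for NaN per the OCP MX spec. A standalone sketch (not from the commit; decode_e8m0 and the sample values are illustrative) of how such a scale decodes and applies to one K-block:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode an e8m0 byte: a bare biased exponent, value = 2^(E - 127).
float decode_e8m0(std::uint8_t e)
{
    if(e == 0xFF)
        return std::nanf(""); // reserved NaN encoding
    return std::ldexp(1.0f, int(e) - 127);
}

int main()
{
    constexpr int kScaleGranularityK = 32; // one scale per MX block along K
    float block[kScaleGranularityK] = {1.5f, -2.0f}; // stand-ins for dequantized fp4 values

    const std::uint8_t scale = 125; // encodes 2^(125-127) = 0.25

    for(float& v : block)
        v *= decode_e8m0(scale); // one shared scale across the whole K-block

    std::printf("scaled[0]=%g scaled[1]=%g\n", block[0], block[1]); // 0.375 -0.5
}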

include/ck_tile/host/reference/reference_gemm.hpp

Lines changed: 87 additions & 0 deletions

@@ -382,6 +382,93 @@ reference_gemm_multiple_abd(const std::array<HostTensor<ADataType>, AsDataType::
     make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
 }

+template <typename ADataType,
+          typename BDataType,
+          typename ScaleDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename AElementOp   = ck_tile::identity,
+          typename BElementOp   = ck_tile::identity,
+          typename ACCElementOp = ck_tile::identity>
+CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
+                                    const HostTensor<BDataType>& b_k_n,
+                                    HostTensor<CDataType>& c_m_n,
+                                    const HostTensor<ScaleDataType>& scale_a,
+                                    const HostTensor<ScaleDataType>& scale_b,
+                                    const AElementOp& = {},
+                                    const BElementOp& = {},
+                                    const ACCElementOp& = {})
+{
+    static_assert(std::is_same_v<AElementOp, ck_tile::identity>);
+    static_assert(std::is_same_v<BElementOp, ck_tile::identity>);
+    static_assert(std::is_same_v<ACCElementOp, ck_tile::identity>);
+
+    const std::size_t M = a_m_k.get_length(0);
+    const std::size_t N = b_k_n.get_length(1);
+    const std::size_t K = a_m_k.get_length(1);
+
+    const std::size_t ScaleBlockSize = K / scale_a.get_length(1);
+
+    HostTensor<AccDataType> a_m_k_scaled({std::size_t(M), std::size_t(K)},
+                                         {std::size_t(K), std::size_t(1)});
+    HostTensor<AccDataType> b_k_n_scaled({std::size_t(K), std::size_t(N)},
+                                         {std::size_t(1), std::size_t(K)});
+
+    for(std::size_t m = 0; m < M; ++m)
+    {
+        for(std::size_t k = 0; k < K; ++k)
+        {
+            if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
+            {
+                if(k % 2 == 1)
+                    continue; // skip odd k
+
+                auto a_f4x2  = a_m_k(m, k);
+                auto a_scale = ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
+                auto a_f4_lo =
+                    ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<0>{}));
+                auto a_f4_hi =
+                    ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<1>{}));
+
+                a_m_k_scaled(m, k)     = a_f4_lo * a_scale;
+                a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
+            }
+        }
+    }
+
+    for(std::size_t n = 0; n < N; n++)
+    {
+        for(std::size_t k = 0; k < K; k++)
+        {
+            if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
+            {
+                if(k % 2 == 1)
+                    continue; // skip odd k
+
+                auto b_f4x2  = b_k_n(k, n);
+                auto b_scale = ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
+                auto b_f4_lo =
+                    ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<0>{}));
+                auto b_f4_hi =
+                    ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<1>{}));
+
+                b_k_n_scaled(k, n)     = b_f4_lo * b_scale;
+                b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
+            }
+            else
+            {
+                b_k_n_scaled(k, n) =
+                    ck_tile::type_convert<AccDataType>((b_k_n(k, n))) *
+                    ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
+            }
+        }
+    }
+
+    // call reference gemm
+    reference_gemm<AccDataType, AccDataType, AccDataType, CDataType>(
+        a_m_k_scaled, b_k_n_scaled, c_m_n);
+}
+
 template <typename ADataType,
           typename BDataType,
           typename DsDataType,
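reference_mx_gemm dequantizes A and B to float first and then reuses the plain reference_gemm, which is why the pk_fp4_t branches step K by two: each stored element is a packed pair of fp4 (e2m1) values that share one e8m0 scale per ScaleBlockSize. A self-contained sketch (not the CK_TILE implementation; the lookup table and the low-nibble-is-even-k ordering are assumptions for illustration) of that pairwise unpack-and-scale:

#include <cstdint>
#include <cstdio>

// All 16 values representable in fp4 e2m1, indexed by the 4-bit code
// (codes 8..15 are the negated values of codes 0..7).
constexpr float kE2M1Lut[16] = {0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
                                -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

// Unpack one packed byte into the (k, k+1) pair, applying the shared block scale.
void dequant_pair(std::uint8_t packed, float scale, float out[2])
{
    out[0] = kE2M1Lut[packed & 0xF] * scale;        // even k (low nibble, assumed)
    out[1] = kE2M1Lut[(packed >> 4) & 0xF] * scale; // odd k (high nibble, assumed)
}

int main()
{
    float pair[2];
    // low nibble 0xA -> -1.0, high nibble 0x3 -> 1.5, scale 0.25
    dequant_pair(0x3A, 0.25f, pair);
    std::printf("k, k+1 = %g, %g\n", pair[0], pair[1]); // -0.25, 0.375
}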

include/ck_tile/ops/flatmm.hpp

Lines changed: 3 additions & 0 deletions

@@ -13,11 +13,14 @@
 #include "ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp"
 #include "ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp"
 #include "ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp"
+#include "ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp"
 #include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
 #include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
 #include "ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
 #include "ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
 #include "ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp"
+#include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
+#include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
 #include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp

Lines changed: 2 additions & 2 deletions

@@ -902,8 +902,8 @@ struct FlatmmKernel
 {
     const auto [iM, iN] =
         TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
-    const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
-    const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
+    const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
+    const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);

     const SplitKBatchOffset splitk_batch_offset(kargs);
     // options
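The only change to the existing flatmm kernel swaps the raw compiler builtin for ck_tile's amd_wave_read_first_lane wrapper; both broadcast lane 0's value across the wavefront so the tile offsets are provably uniform and can live in scalar registers. A minimal sketch of the underlying builtin (HIP device code, illustrative only; the function name is an assumption):

// Sketch, assuming compilation as HIP device code for an AMD GPU target.
// __builtin_amdgcn_readfirstlane broadcasts one lane's 32-bit value to the
// whole wavefront, making the result wavefront-uniform (SGPR-resident).
__device__ int uniform_block_offset(int tile_idx, int per_block)
{
    // Every lane computes the same product, but the compiler cannot always
    // prove that; the broadcast makes the uniformity explicit.
    return __builtin_amdgcn_readfirstlane(tile_idx * per_block);
}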
