// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

template <typename PrecActType,
          typename PrecWeightType,
          typename CDataType,
          typename FlatmmConfig,
          bool UsePersistentKernel = false,
          typename ALayout,
          typename BLayout,
          typename CLayout>
int run_mx_flatmm_with_layouts(int argc,
                               char* argv[],
                               const ALayout a_layout = ALayout{},
                               const BLayout b_layout = BLayout{},
                               const CLayout c_layout = CLayout{})
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

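    // Element types: activations and weights use the requested MX precisions,
    // accumulation is done in float, and the block scales are e8m0 (shared-exponent) values.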
    using ADataType   = PrecActType;
    using BDataType   = PrecWeightType;
    using AccDataType = float;

    using ScaleType = ck_tile::e8m0_t;

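    // MX block-scale granularity: one scale per element along M and N,
    // and one shared scale per 32 consecutive elements along K (the MX block size).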
    constexpr int ScaleGranularityM = 1;
    constexpr int ScaleGranularityN = 1;
    constexpr int ScaleGranularityK = 32;

    ck_tile::index_t M = arg_parser.get_int("m");
    ck_tile::index_t N = arg_parser.get_int("n");
    ck_tile::index_t K = arg_parser.get_int("k");

    ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
    ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
    ck_tile::index_t stride_C = arg_parser.get_int("stride_c");

    ck_tile::index_t kbatch      = arg_parser.get_int("split_k");
    ck_tile::index_t init_method = arg_parser.get_int("init");
    ck_tile::index_t n_warmup    = arg_parser.get_int("warmup");
    ck_tile::index_t n_repeat    = arg_parser.get_int("repeat");

    stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
    stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
    stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout));

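    // The scale tensors hold one value per granularity block, so their shapes are
    // (M/ScaleGranularityM) x (K/ScaleGranularityK) for A and
    // (K/ScaleGranularityK) x (N/ScaleGranularityN) for B; derive packed default strides for them.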
    auto scale_stride_A = ck_tile::get_default_stride(
        M / ScaleGranularityM, K / ScaleGranularityK, 0, is_row_major(a_layout));
    auto scale_stride_B = ck_tile::get_default_stride(
        K / ScaleGranularityK, N / ScaleGranularityN, 0, is_row_major(b_layout));

    if(K % ScaleGranularityK != 0)
        throw std::runtime_error("wrong! K must be multiple of ScaleGranularityK.");
    if(K % ck_tile::numeric_traits<ADataType>::PackedSize != 0 ||
       K % ck_tile::numeric_traits<BDataType>::PackedSize != 0)
        throw std::runtime_error("wrong! K must be multiple of packed size.");

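    // Host-side tensors for the activation (M x K), the original weight (K x N), the output
    // (M x N), and the per-block scale tensors for A and B.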
    ck_tile::HostTensor<ADataType> a_host(
        ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
    ck_tile::HostTensor<BDataType> b_origin_host(
        ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
    ck_tile::HostTensor<CDataType> c_rslt_host(
        ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));

    ck_tile::HostTensor<ScaleType> scale_a(ck_tile::host_tensor_descriptor(
        M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout)));
    ck_tile::HostTensor<ScaleType> scale_b(ck_tile::host_tensor_descriptor(
        K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout)));

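    // init 0: random inputs and scales; init 1: all ones (useful for debugging numerics).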
    if(init_method == 0)
    {
        ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
        ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
        ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_a);
        ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_b);
    }
    else if(init_method == 1)
    {
        ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
        ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
        ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_a);
        ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
    }
    else
    {
        throw std::runtime_error("wrong! Unexpected init_method");
    }

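    // The flatmm kernel consumes the weight matrix and both scale tensors in a pre-shuffled
    // layout, so reorder them on the host before uploading.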
    ck_tile::HostTensor<BDataType> b_shuffled_host(
        ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
    preShuffleWeight<FlatmmConfig>(b_origin_host.begin(), b_shuffled_host.begin(), N, K);

    const auto scale_a_shuffled = preShuffleScale<FlatmmConfig, true>(scale_a);
    const auto scale_b_shuffled = preShuffleScale<FlatmmConfig, false>(scale_b);

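    // Allocate device buffers and copy the activation, shuffled weight, and shuffled scales over.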
    ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem b_shuffled_dev_buf(b_shuffled_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());

    ck_tile::DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes());
    ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes());

    a_dev_buf.ToDevice(a_host.data());
    b_shuffled_dev_buf.ToDevice(b_shuffled_host.data());
    c_rslt_host.SetZero();
    scale_a_dev_buf.ToDevice(scale_a_shuffled.data());
    scale_b_dev_buf.ToDevice(scale_b_shuffled.data());

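    // Wrap the device scale buffers in FlatmmScalePointer views, which carry the scaling
    // granularities as compile-time parameters alongside the runtime extent.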
    auto scale_a_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK>{
        static_cast<float*>(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM};
    auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
        static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};

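    // Launch the MX flatmm kernel; the empty ck_tile::tuple<> arguments indicate that no
    // additional D operands take part in the epilogue.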
    invoke_mx_flatmm<FlatmmConfig,
                     ADataType,
                     BDataType,
                     ck_tile::tuple<>,
                     AccDataType,
                     CDataType,
                     ALayout,
                     BLayout,
                     ck_tile::tuple<>,
                     CLayout,
                     decltype(scale_a_dev_ptr),
                     decltype(scale_b_dev_ptr),
                     UsePersistentKernel>(a_dev_buf,
                                          b_shuffled_dev_buf,
                                          c_dev_buf,
                                          M,
                                          N,
                                          K,
                                          stride_A,
                                          stride_B,
                                          stride_C,
                                          kbatch,
                                          scale_a_dev_ptr,
                                          scale_b_dev_ptr,
                                          n_warmup,
                                          n_repeat);

    c_dev_buf.FromDevice(c_rslt_host.data());

    bool pass = true;
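    // Optional verification (-v=1): compare the GPU result against a host reference MX GEMM
    // that applies the same block scales to the unshuffled A and B tensors.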
    if(arg_parser.get_int("v") == 1)
    {
        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
            ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
        c_m_n_host_ref.SetZero();

        ck_tile::reference_mx_gemm<ADataType, BDataType, ScaleType, AccDataType, CDataType>(
            a_host, b_origin_host, c_m_n_host_ref, scale_a, scale_b);

        const float rtol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
        const float atol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;

        pass = ck_tile::check_err(
            c_rslt_host, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);

        std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
                  << std::endl;
        std::cout << "The GPU verification result is: " << (pass ? "correct" : "incorrect")
                  << std::endl;
    }

    return pass;
}
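
// Usage sketch (hypothetical, for illustration only): a main() in the surrounding example would
// pick concrete MX precisions, a FlatmmConfig, and layouts, then dispatch here. The type names
// ActType, WeightType, OutType, and MyFlatmmConfig below are placeholders, not part of this file:
//
//   using Row = ck_tile::tensor_layout::gemm::RowMajor;
//   using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
//   return run_mx_flatmm_with_layouts<ActType, WeightType, OutType, MyFlatmmConfig,
//                                     /*UsePersistentKernel=*/false, Row, Col, Row>(
//       argc, argv, Row{}, Col{}, Row{});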