diff --git a/examples/00_bmg_gemm/00_bmg_gemm.cpp b/examples/00_bmg_gemm/00_bmg_gemm.cpp
index 251a4d1f10..7715f262df 100644
--- a/examples/00_bmg_gemm/00_bmg_gemm.cpp
+++ b/examples/00_bmg_gemm/00_bmg_gemm.cpp
@@ -51,6 +51,9 @@
     executing Intel specific prefetch instructions for future iterations to ensure that the
     required blocks of A and B are resident in cache before they are needed.
 
+    B always has logical shape [N, K]. When it is row-major, the K mode is strided (the data is
+    physically laid out as [K, N]); when it is column-major, the K mode is contiguous (the [N, K] view has stride [K, 1]).
+
     To build & run this example (from your build dir):
 
       $ ninja 00_bmg_gemm
@@ -242,6 +245,7 @@ struct ExampleRunner {
     initialize_block(block_C, seed + 2021);
   }
 
+  template <typename B_Layout>
   cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
     ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};
 
@@ -288,8 +292,15 @@ struct ExampleRunner {
 
       float cute_time = timer.seconds() / options.iterations;
       double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12;
+      if constexpr (std::is_same_v<B_Layout, cutlass::layout::RowMajor>) {
+        std::cout << "Tensor B is row-major." << std::endl;
+      } else if constexpr (std::is_same_v<B_Layout, cutlass::layout::ColumnMajor>) {
+        std::cout << "Tensor B is column-major." << std::endl;
+      } else {
+        static_assert(cute::dependent_false<B_Layout>, "Should not reach this case");
+      }
       std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
-      printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000);
+      printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n\n", tflops / cute_time, cute_time*1000);
     }
 
     return cutlass::Status::kSuccess;
@@ -297,26 +308,11 @@ struct ExampleRunner {
 
 };
 
-int main(int argc, const char** argv)
-{
-  //
-  // Parse options
-  //
-
-  Options options;
-
-  options.parse(argc, argv);
-  if (options.help) {
-    options.print_usage(std::cout) << std::endl;
-    return 0;
-  }
-
-  if (options.error) {
-    std::cerr << "Aborting execution." << std::endl;
-    return -1;
-  }
+template <typename B_Layout>
+void launcher(Options& options)
+{
   //
   // Run examples
   //
@@ -340,13 +336,14 @@ int main(int argc, const char** argv)
   using ElementOutput = float;                        // <- data type of elements in output matrix D
 
   using LayoutA = cutlass::layout::RowMajor;
-  using LayoutB = cutlass::layout::RowMajor;
+  using LayoutB = B_Layout;
   using LayoutC = cutlass::layout::RowMajor;
   using LayoutD = cutlass::layout::RowMajor;
 
   // The 2D block copy operations used for the A and B matrices
   using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
-  using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
+  // TODO: use XE_2D_U16x32x32_LD_T once it's added
+  using GmemTiledCopyB = std::conditional_t<std::is_same_v<B_Layout, cutlass::layout::RowMajor>, XE_2D_U16x32x32_LD_V, XE_2D_U16x16x16_LD_T>;
 
   // Workgroup-level tile
   using TileShape = Shape<_256, _256, _32>;
@@ -422,7 +419,28 @@ int main(int argc, const char** argv)
 
   ExampleRunner<Gemm> runner;
 
-  CUTLASS_CHECK(runner.run(options, hw_info));
+  CUTLASS_CHECK(runner.template run<B_Layout>(options, hw_info));
+}
+
+int main(int argc, const char** argv) {
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, argv);
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  if (options.error) {
+    std::cerr << "Aborting execution." << std::endl;
+    return -1;
+  }
+  launcher<cutlass::layout::RowMajor>(options);
+  launcher<cutlass::layout::ColumnMajor>(options);
 
   return 0;
 }
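
Note on the layout comment added to the file header: the claim can be checked with a few lines of index arithmetic. Treating B as a logical [N, K] tensor, row-major storage (physical [K, N]) addresses element (n, k) at k*N + n, so stepping in k jumps N elements; column-major storage addresses it at n*K + k, so k is the contiguous mode with stride [K, 1]. A minimal standalone sketch, where the offset helpers are illustrative and not CUTLASS API:

#include <cassert>
#include <cstddef>

// Offset of logical element (n, k) of B when stored row-major as [K, N]:
// consecutive k values are N elements apart (stride [1, N]).
std::size_t offset_row_major(std::size_t n, std::size_t k, std::size_t N) {
  return k * N + n;
}

// Offset of the same element when B is column-major: K is the contiguous
// mode (stride [K, 1]).
std::size_t offset_col_major(std::size_t n, std::size_t k, std::size_t K) {
  return n * K + k;
}

int main() {
  constexpr std::size_t N = 4, K = 8;
  // Walking k at fixed n: row-major jumps by N, column-major by 1.
  assert(offset_row_major(1, 3, N) - offset_row_major(1, 2, N) == N);
  assert(offset_col_major(1, 3, K) - offset_col_major(1, 2, K) == 1);
  return 0;
}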
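
Note on the dispatch pattern in the patch: `launcher` is templated on B's layout, `std::conditional_t` selects the copy atom at compile time, and the unreachable `else` branch is guarded with a dependent false (a plain `static_assert(false, ...)` inside a discarded `if constexpr` branch is rejected by compilers before C++23, hence `cute::dependent_false` above). Below is a minimal, self-contained sketch of the same technique; the `RowMajor`/`ColumnMajor` tags and `LoadV`/`LoadT` atoms are stand-ins for the CUTLASS types, not the real ones:

#include <iostream>
#include <type_traits>

// Stand-in layout tags and copy atoms; the real example uses
// cutlass::layout::RowMajor / ColumnMajor and the XE_2D_* copy atoms.
struct RowMajor {};
struct ColumnMajor {};
struct LoadV { static constexpr const char* name = "LD_V (row-major path)"; };
struct LoadT { static constexpr const char* name = "LD_T (column-major path)"; };

// Dependent false: evaluation is deferred until instantiation, so the
// static_assert below fires only if an unexpected layout is supplied.
template <typename...>
inline constexpr bool dependent_false = false;

template <typename B_Layout>
void launcher() {
  // Compile-time selection of the copy atom, mirroring the patch.
  using GmemTiledCopyB =
      std::conditional_t<std::is_same_v<B_Layout, RowMajor>, LoadV, LoadT>;

  if constexpr (std::is_same_v<B_Layout, RowMajor>) {
    std::cout << "Tensor B is row-major, using " << GmemTiledCopyB::name << '\n';
  } else if constexpr (std::is_same_v<B_Layout, ColumnMajor>) {
    std::cout << "Tensor B is column-major, using " << GmemTiledCopyB::name << '\n';
  } else {
    static_assert(dependent_false<B_Layout>, "Should not reach this case");
  }
}

int main() {
  // As in the patch: run the same code path once per B layout.
  launcher<RowMajor>();
  launcher<ColumnMajor>();
}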