62 changes: 40 additions & 22 deletions examples/00_bmg_gemm/00_bmg_gemm.cpp
@@ -51,6 +51,9 @@
executing Intel specific prefetch instructions for future iterations to ensure that the required
blocks of A and B are resident in cache before they are needed.

B always has logical shape [N, K]. When it is row-major, the logical view is discontiguous
(the physical layout is [K, N], giving a stride of [1, N]); when it is column-major, the view
is contiguous, with a stride of [K, 1].
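
As a sketch in CuTe terms (assuming cute::make_layout/make_shape/make_stride), the two
cases correspond to:

    auto b_colmajor = make_layout(make_shape(N, K), make_stride(K, Int<1>{}));  // contiguous
    auto b_rowmajor = make_layout(make_shape(N, K), make_stride(Int<1>{}, N));  // strided view of a [K, N] buffer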

To build & run this example (from your build dir):

$ ninja 00_bmg_gemm
@@ -242,6 +245,7 @@ struct ExampleRunner {
initialize_block(block_C, seed + 2021);
}

template<class B_Layout>
cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};

@@ -288,35 +292,27 @@

float cute_time = timer.seconds() / options.iterations;
double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12;
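// Each of the m*n*k*l fused multiply-adds counts as two floating-point ops, hence the factor of 2.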
if constexpr(std::is_same_v<B_Layout, cutlass::layout::RowMajor>) {
  std::cout << "Tensor B is row-major." << std::endl;
} else {
  // Note: a bare static_assert(false, ...) is ill-formed here before C++23 even though
  // the branch is discarded, so assert the only other supported layout instead.
  static_assert(std::is_same_v<B_Layout, cutlass::layout::ColumnMajor>, "B_Layout must be RowMajor or ColumnMajor");
  std::cout << "Tensor B is column-major." << std::endl;
}
std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000);
printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n\n", tflops / cute_time, cute_time*1000);
}

return cutlass::Status::kSuccess;
}

};

template<class B_Layout>
void launcher(Options& options)
{
//
// Run examples
//
@@ -340,13 +336,14 @@ int main(int argc, const char** argv)
using ElementOutput = float; // <- data type of elements in output matrix D

using LayoutA = cutlass::layout::RowMajor;
using LayoutB = B_Layout;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;

// The 2D block copy operations used for the A and B matrices
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
// TODO: use XE_2D_U16x32x32_LD_T once it's added
using GmemTiledCopyB = std::conditional_t<std::is_same_v<B_Layout, cutlass::layout::RowMajor>, XE_2D_U16x32x32_LD_V, XE_2D_U16x16x16_LD_T>;
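// Note (an assumption based on the atom naming): LD_N/LD_V/LD_T denote normal, VNNI-transforming,
// and transposing 2D block loads respectively; row-major B uses the VNNI load, while column-major B
// needs the transposing load (only a 16x16 LD_T atom is available for now, per the TODO above).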

// Workgroup-level tile
using TileShape = Shape<_256, _256, _32>;
@@ -422,7 +419,28 @@

ExampleRunner<Gemm> runner;

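// ExampleRunner<Gemm> is a dependent type inside launcher, so calling the member
// function template needs the 'template' disambiguator.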
CUTLASS_CHECK(runner.template run<B_Layout>(options, hw_info));
}

int main(int argc, const char** argv) {
//
// Parse options
//

Options options;

options.parse(argc, argv);

if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}

if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}
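// Run the example once per supported B layout; each call prints its own configuration and perf line.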
launcher<cutlass::layout::RowMajor>(options);
launcher<cutlass::layout::ColumnMajor>(options);
return 0;
}