From 9837af2b095a5d476bdad2e7c14d281763e88ac0 Mon Sep 17 00:00:00 2001 From: Atharva Dubey Date: Mon, 12 Aug 2024 13:28:37 +0100 Subject: [PATCH] Use googlebench in Benchmarks (#116) * Use googlebench for benchmarking --- benchmarks/CMakeLists.txt | 11 ++ ...pere_gemm_bf16_bf16_fp32_tensor_op_fp32.cu | 2 +- ...pere_gemm_fp16_fp16_fp32_tensor_op_fp32.cu | 2 +- ...pere_gemm_tf32_tf32_fp32_tensor_op_fp32.cu | 2 +- benchmarks/common/benchmark_runner.hpp | 127 +++++++----------- ...ench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp | 2 +- 6 files changed, 61 insertions(+), 85 deletions(-) diff --git a/benchmarks/CMakeLists.txt b/benchmarks/CMakeLists.txt index a074ebe429..8011efd9ba 100644 --- a/benchmarks/CMakeLists.txt +++ b/benchmarks/CMakeLists.txt @@ -26,6 +26,15 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +include(FetchContent) +FetchContent_Declare( + googlebenchmark + GIT_REPOSITORY https://github.com/google/benchmark.git + GIT_TAG main +) +FetchContent_MakeAvailable(googlebenchmark) + set(CUTLASS_BENCHMARKS_COMMON_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/common) add_custom_target(cutlass_benchmarks) @@ -54,6 +63,7 @@ function(cutlass_benchmark_add_executable NAME) cutlass_tools_util_includes $<$:nvidia::cublas> $<$:cuda> + benchmark::benchmark ) if (CUTLASS_ENABLE_SYCL) @@ -66,6 +76,7 @@ function(cutlass_benchmark_add_executable NAME) ) endfunction() + if(SYCL_INTEL_TARGET) add_subdirectory(pvc) endif() diff --git a/benchmarks/ampere/bench_ampere_gemm_bf16_bf16_fp32_tensor_op_fp32.cu b/benchmarks/ampere/bench_ampere_gemm_bf16_bf16_fp32_tensor_op_fp32.cu index 8dad127417..fb95c4bab2 100644 --- a/benchmarks/ampere/bench_ampere_gemm_bf16_bf16_fp32_tensor_op_fp32.cu +++ b/benchmarks/ampere/bench_ampere_gemm_bf16_bf16_fp32_tensor_op_fp32.cu @@ -145,7 +145,7 @@ int main(int argc, const char** argv) using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - BenchmarkRunner runner; + BenchmarkRunner runner("ampere_gemm_bf16_bf16_fp32_tensor_op_fp32"); runner.run(options, hw_info); diff --git a/benchmarks/ampere/bench_ampere_gemm_fp16_fp16_fp32_tensor_op_fp32.cu b/benchmarks/ampere/bench_ampere_gemm_fp16_fp16_fp32_tensor_op_fp32.cu index 69bc482f12..4b4f6704ce 100644 --- a/benchmarks/ampere/bench_ampere_gemm_fp16_fp16_fp32_tensor_op_fp32.cu +++ b/benchmarks/ampere/bench_ampere_gemm_fp16_fp16_fp32_tensor_op_fp32.cu @@ -145,7 +145,7 @@ int main(int argc, const char** argv) using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - BenchmarkRunner runner; + BenchmarkRunner runner("ampere_gemm_fp16_fp16_fp32_tensor_op_fp32"); runner.run(options, hw_info); diff --git a/benchmarks/ampere/bench_ampere_gemm_tf32_tf32_fp32_tensor_op_fp32.cu b/benchmarks/ampere/bench_ampere_gemm_tf32_tf32_fp32_tensor_op_fp32.cu index 7807ac4cee..eab97511f6 100644 --- a/benchmarks/ampere/bench_ampere_gemm_tf32_tf32_fp32_tensor_op_fp32.cu +++ b/benchmarks/ampere/bench_ampere_gemm_tf32_tf32_fp32_tensor_op_fp32.cu @@ -145,7 +145,7 @@ int main(int argc, const char** argv) using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - BenchmarkRunner runner; + BenchmarkRunner runner("ampere_gemm_tf32_tf32_fp32_tensor_op_fp32"); runner.run(options, hw_info); diff --git a/benchmarks/common/benchmark_runner.hpp b/benchmarks/common/benchmark_runner.hpp index 911055401c..edc87a1a4f 100644 --- a/benchmarks/common/benchmark_runner.hpp +++ b/benchmarks/common/benchmark_runner.hpp @@ -52,6 +52,8 @@ #include "cutlass/util/reference/device/tensor_compare.h" #include "cutlass/util/print_error.hpp" +#include + template static void fill_matrix(std::vector &M) { @@ -75,7 +77,7 @@ struct Options { Options(): help(false), error(false), - m(4096), n(4096), k(4096), l(1), iterations(100), + m(4096), n(4096), k(4096), l(1), alpha(1.f), beta(0.f) { } @@ -94,7 +96,6 @@ struct Options { cmd.get_cmd_line_argument("l", l, 1); cmd.get_cmd_line_argument("alpha", alpha, 1.f); cmd.get_cmd_line_argument("beta", beta, 0.f); - cmd.get_cmd_line_argument("iterations", iterations, 100); } /// Prints the usage statement. @@ -161,11 +162,19 @@ struct BenchmarkRunner { ElementOutput epsilon; ElementOutput nonzero_floor; - BenchmarkRunner() : epsilon(static_cast(0.1f)), - nonzero_floor(static_cast(0.1f)) {}; + BenchmarkRunner(std::string test_name) : epsilon(static_cast(0.1f)), + nonzero_floor(static_cast(0.1f)), test_name(test_name) { + int argc = 0; + benchmark::SetDefaultTimeUnit(benchmark::kMillisecond); + benchmark::Initialize(&argc, nullptr); + }; - BenchmarkRunner(ElementOutput epsilon, ElementOutput nonzeroFloor) : - epsilon(epsilon), nonzero_floor(nonzeroFloor) {} + BenchmarkRunner(ElementOutput epsilon, ElementOutput nonzeroFloor, std::string test_name) : + epsilon(epsilon), nonzero_floor(nonzeroFloor), test_name(test_name) { + int argc = 0; + benchmark::SetDefaultTimeUnit(benchmark::kMillisecond); + benchmark::Initialize(&argc, nullptr); + } // // Methods @@ -261,6 +270,7 @@ struct BenchmarkRunner { } virtual void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + benchmark::ClearRegisteredBenchmarks(); ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; initialize(problem_size); @@ -293,86 +303,41 @@ struct BenchmarkRunner { // Verify that the result is correct bool passed = verify(problem_size, options.alpha, options.beta); - std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; - - if (passed && options.iterations > 0) { - GPU_Clock timer; - timer.start(); - for (int i = 0; i < options.iterations; ++i) { - gemm_op.run(); - } - - float cute_time = timer.seconds() / options.iterations; - double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; - std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l - << std::endl; - printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time * 1000); + if(not passed) { + throw std::runtime_error("Disposition Failed."); } - } -}; -template -struct PvcBenchmarkRunner : BenchmarkRunner { - using Base = BenchmarkRunner; - - using ElementB = typename Base::ElementB; - - using ProblemShapeType = typename Base::ProblemShapeType; - - void initialize(const ProblemShapeType& problem_size) override { - Base::initialize(problem_size); - } + std::stringstream full_test_name; + full_test_name << test_name << "/"; + std::string test_name_suffix = std::to_string(options.m) + "x" + + std::to_string(options.n) + "x" + + std::to_string(options.k) + "x" + + std::to_string(options.l); + full_test_name << test_name_suffix; + benchmark::RegisterBenchmark(full_test_name.str().c_str(), run_benchmark, options, gemm_op) + ->UseManualTime(); + benchmark::RunSpecifiedBenchmarks(); + } + + ~BenchmarkRunner() { + benchmark::Shutdown(); + } - void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) override { - ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; - - initialize(problem_size); - - typename Gemm::GemmKernel::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - {Base::block_A.get(), Base::stride_A, Base::block_B.get(), Base::stride_B}, - { - {options.alpha, options.beta}, - Base::block_C.get(), Base::stride_C, Base::block_D.get(), Base::stride_D - }, - hw_info - }; - - Gemm gemm_op; - - size_t workspace_size = Gemm::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - gemm_op.can_implement(arguments); - - gemm_op.initialize(arguments, workspace.get()); - - // Run the GEMM - gemm_op.run(); - -#if defined(CUTLASS_ENABLE_SYCL) - syclcompat::wait(); -#else - cudaDeviceSynchronize(); -#endif - - // Verify that the result is correct - bool passed = Base::verify(problem_size, options.alpha, options.beta); - std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; - - if (passed && options.iterations > 0) { + private: + static void run_benchmark(benchmark::State& state, const Options& options, Gemm gemm_op) { + state.counters["runtime_ms"] = 0; + for(auto _ : state) { GPU_Clock timer; timer.start(); - for (int i = 0; i < options.iterations; ++i) { - gemm_op.run(); - } - - float cute_time = timer.seconds() / options.iterations; - double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; - std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; - printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + gemm_op.run(); + auto ms_elapsed = timer.milliseconds(); + state.counters["runtime_ms"] += ms_elapsed; + state.SetIterationTime(ms_elapsed / 1000); } + state.counters["runtime_ms"] /= state.iterations(); + state.counters["TFlops"] = ((2.0 * options.m * options.n * options.k * options.l) * 1e-12) / + (state.counters["runtime_ms"] / 1000); } -}; + std::string test_name; +}; diff --git a/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp b/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp index 3203e7f367..0a2d794650 100644 --- a/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp +++ b/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp @@ -137,7 +137,7 @@ int main(int argc, const char** argv) using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - PvcBenchmarkRunner runner; + BenchmarkRunner runner("pvc_gemm_bf16_bf16_fp32_dpas_fp32"); runner.run(options, hw_info);