From 6c1f76442efa536d72378d8621811465a34ff2f1 Mon Sep 17 00:00:00 2001
From: Elfie Guo
Date: Tue, 17 Sep 2024 16:42:56 -0700
Subject: [PATCH] PR #16975: Add a few related optimization passes for fp8 gemm
 custom-calls.

Imported from GitHub PR https://github.com/openxla/xla/pull/16975

Without these passes, fp8 training showed convergence issues, tested on GPT3 models:

Before:
```
NETWORK  BACKEND  MATH  SDPA  XLA_EXTRAS  GPUs  STEPS/SEC  LOSS    WALLSECS
GPT5B    XLA      fp8   FA                8     1.064      11.019  1571
[PAX STATUS]: Starting training loop.
[PAX STATUS] step_i: 100, training loss: 11.015041
[PAX STATUS] step_i: 200, training loss: 11.016165
[PAX STATUS] step_i: 300, training loss: 11.016386
[PAX STATUS] step_i: 400, training loss: 11.014653
[PAX STATUS] step_i: 500, training loss: 11.014734
[PAX STATUS] step_i: 600, training loss: 11.01613
[PAX STATUS] step_i: 700, training loss: 11.009399
[PAX STATUS] step_i: 800, training loss: 11.017071
[PAX STATUS] step_i: 900, training loss: 11.014582
[PAX STATUS] step_i: 1000, training loss: 11.013434
[PAX STATUS] step_i: 1100, training loss: 11.021271
[PAX STATUS] step_i: 1200, training loss: 11.008364
[PAX STATUS] step_i: 1300, training loss: 11.0198145
[PAX STATUS] step_i: 1400, training loss: 11.01253
[PAX STATUS] step_i: 1500, training loss: 11.019016
```

After:
```
NETWORK  BACKEND  MATH  SDPA  GPUs  STEPS/SEC  LOSS   WALLSECS
GPT5B    XLA      fp8   FA    8     1.020      3.797  1647
[PAX STATUS]: Starting training loop.
[PAX STATUS] step_i: 100, training loss: 6.150083
[PAX STATUS] step_i: 200, training loss: 5.8871064
[PAX STATUS] step_i: 300, training loss: 5.4491887
[PAX STATUS] step_i: 400, training loss: 5.6384015
[PAX STATUS] step_i: 500, training loss: 5.273538
[PAX STATUS] step_i: 600, training loss: 5.2011905
[PAX STATUS] step_i: 700, training loss: 4.903013
[PAX STATUS] step_i: 800, training loss: 4.62972
[PAX STATUS] step_i: 900, training loss: 4.507727
[PAX STATUS] step_i: 1000, training loss: 4.625259
[PAX STATUS] step_i: 1100, training loss: 4.428066
[PAX STATUS] step_i: 1200, training loss: 4.252451
[PAX STATUS] step_i: 1300, training loss: 3.8448389
[PAX STATUS] step_i: 1400, training loss: 3.8578327
[PAX STATUS] step_i: 1500, training loss: 3.796958
```

Copybara import of the project:

--
90f596851f20459e37b713a10283499658ebf41e by Elfie Guo:

Add a few related optimization passes for the fp8 gemm rewriter.

Merging this change closes #16975

FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/16975 from elfiegg:pass 90f596851f20459e37b713a10283499658ebf41e
PiperOrigin-RevId: 675755585
---
 xla/service/gpu/gpu_compiler.cc             |  6 ++
 xla/service/gpu/gpu_compiler_test.cc        | 56 +++++++++++++++++++
 .../gpu_compiler_test_autotune_db.textproto | 23 ++++++++
 3 files changed, 85 insertions(+)

diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc
index cdbe98490bf47e..c5834482ff1ec2 100644
--- a/xla/service/gpu/gpu_compiler.cc
+++ b/xla/service/gpu/gpu_compiler.cc
@@ -1552,6 +1552,12 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
   // Rewrite GEMMs with broadcasted inputs as strided GEMMs.
   pipeline.AddPass<GemmBroadcastFoldingRewriter>();
 
+  pipeline.AddPass<LayoutNormalization>(&NormalizeLayoutForGpuCustomCalls);
+
+  // Layout normalization will create scatters that are not simplified and
+  // also have unsorted update_window_dims.
+  pipeline.AddPass<ScatterSimplifier>();
+
   pipeline.AddPass<HostOffloadLegalize>(
       static_cast<int64_t>(stream_executor::MemoryType::kHost),
       /* after_layout= */ true);
diff --git a/xla/service/gpu/gpu_compiler_test.cc b/xla/service/gpu/gpu_compiler_test.cc
index 51b459e8a81a02..c316314cb64e5b 100644
--- a/xla/service/gpu/gpu_compiler_test.cc
+++ b/xla/service/gpu/gpu_compiler_test.cc
@@ -473,6 +473,62 @@ ENTRY main {
             triton_disabled_module->computation_count());
 }
 
+TEST_F(GpuCompilerTest,
+       CublasF8NumericallySameWithTritonFallbackAndWithoutTriton) {
+  auto cc = backend()
+                .default_stream_executor()
+                ->GetDeviceDescription()
+                .cuda_compute_capability();
+  if (!cc.IsAtLeastAmpere()) {
+    GTEST_SKIP() << "Autotuning results have only been generated for Ampere "
+                 << "and Hopper GPUs";
+  }
+  const absl::string_view hlo_string = R"(
+HloModule test
+
+ENTRY main {
+  p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
+  p1 = f8e4m3fn[4096,16384]{0,1} parameter(1)
+  dot = bf16[12288,16384]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bitcast = bf16[] constant(0.956)
+  broadcast = bf16[12288,16384]{1,0} broadcast(bitcast), dimensions={}
+  ROOT multiply = bf16[12288,16384]{1,0} multiply(dot, broadcast)
+  })";
+
+  HloModuleConfig config;
+  DebugOptions triton_enabled_debug_options = GetDebugOptionsForTest();
+  triton_enabled_debug_options
+      .set_xla_gpu_require_complete_aot_autotune_results(true);
+  config.set_debug_options(triton_enabled_debug_options);
+
+  // Load autotuning DB. We shouldn't depend on actual execution times in a unit
+  // test.
+  std::string path =
+      tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
+                        "gpu_compiler_test_autotune_db.textproto");
+  TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(path));
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string, config));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> triton_enabled_module,
+                          GetOptimizedModule(std::move(module)));
+
+  AutotunerUtil::ClearAutotuneResults();
+  DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest();
+  triton_disabled_debug_options.set_xla_gpu_enable_triton_gemm(false);
+  triton_disabled_debug_options.set_xla_gpu_cublas_fallback(true);
+  config.set_debug_options(triton_disabled_debug_options);
+
+  TF_ASSERT_OK_AND_ASSIGN(module,
+                          ParseAndReturnVerifiedModule(hlo_string, config));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> triton_disabled_module,
+                          GetOptimizedModule(std::move(module)));
+
+  EXPECT_TRUE(RunAndCompareTwoModules(std::move(triton_enabled_module),
+                                      std::move(triton_disabled_module),
+                                      ErrorSpec{1e-6, 1e-6}, false));
+}
+
 class FloatNormalizationTest : public GpuCompilerTest,
                                public ::testing::WithParamInterface<
                                    std::pair<PrimitiveType, PrimitiveType>> {};
diff --git a/xla/service/gpu/gpu_compiler_test_autotune_db.textproto b/xla/service/gpu/gpu_compiler_test_autotune_db.textproto
index 51caadb7bd2d06..699b397681682f 100644
--- a/xla/service/gpu/gpu_compiler_test_autotune_db.textproto
+++ b/xla/service/gpu/gpu_compiler_test_autotune_db.textproto
@@ -71,3 +71,26 @@ results {
     }
   }
 }
+results {
+  device: "CUDA: 9.0, Cores: 114, GPU clock: 1.755 GHz, Memory bandwidth: 2039 GB/s, L2 cache: 50 MB"
+  hlo: "(bf16[12288,16384]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[4096,12288]{0,1}, f8e4m3fn[4096,16384]{0,1}, f32[], f32[], f32[], f32[]), custom_call_target=\"__cublas$lt$matmul$f8\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":0.95703125,\"beta\":0,\"damax_output\":false,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[],\"lhs_contracting_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[],\"rhs_contracting_dimensions\":[\"0\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"50331648\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"67108864\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}"
+  result {
+    gemm {
+    }
+    run_time {
+      nanos: 1
+    }
+  }
+}
+results {
+  device: "CUDA: 9.0, Cores: 114, GPU clock: 1.755 GHz, Memory bandwidth: 2039 GB/s, L2 cache: 50 MB"
+  hlo: "{\n tmp_0 = f8e4m3fn[12288,4096]{0,1} parameter(0)\n tmp_1 = f8e4m3fn[4096,16384]{0,1} parameter(1)\n tmp_2 = bf16[12288,16384]{1,0} dot(f8e4m3fn[12288,4096]{0,1} tmp_0, f8e4m3fn[4096,16384]{0,1} tmp_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n tmp_3 = bf16[] constant({...})\n tmp_4 = bf16[12288,16384]{1,0} broadcast(bf16[] tmp_3), dimensions={}\n ROOT tmp_5 = bf16[12288,16384]{1,0} multiply(bf16[12288,16384]{1,0} tmp_2, bf16[12288,16384]{1,0} tmp_4)\n}"
+  result {
+    gemm {
+      algorithm: -1
+    }
+    run_time {
+      nanos: 1
+    }
+  }
+}
\ No newline at end of file