diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc
index cdbe98490bf47e..bb550fcefe5df6 100644
--- a/xla/service/gpu/gpu_compiler.cc
+++ b/xla/service/gpu/gpu_compiler.cc
@@ -1552,6 +1552,12 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
     // Rewrite GEMMs with broadcasted inputs as strided GEMMs.
     pipeline.AddPass<GemmBroadcastFoldingRewriter>();
 
+    pipeline.AddPass<LayoutNormalization>(&NormalizeLayoutForGpuCustomCalls);
+
+    // Layout normalization will create scatters that are not simplified and
+    // also have unsorted update_window_dims.
+    pipeline.AddPass<ScatterSimplifier>();
+
     pipeline.AddPass<HostOffloadLegalize>(
         static_cast<int64_t>(stream_executor::MemoryType::kHost),
         /* after_layout= */ true);
diff --git a/xla/service/gpu/gpu_compiler_test.cc b/xla/service/gpu/gpu_compiler_test.cc
index 51b459e8a81a02..f31689c237a3d2 100644
--- a/xla/service/gpu/gpu_compiler_test.cc
+++ b/xla/service/gpu/gpu_compiler_test.cc
@@ -473,6 +473,62 @@ ENTRY main {
             triton_disabled_module->computation_count());
 }
 
+TEST_F(GpuCompilerTest, CublasF8NumericallySameWithTritonFallbackAndWithoutTriton) {
+  auto cc = backend()
+                .default_stream_executor()
+                ->GetDeviceDescription()
+                .cuda_compute_capability();
+  if (!cc.IsAtLeastAmpere()) {
+    GTEST_SKIP() << "Autotuning results have only been generated for Ampere "
+                 << "and Hopper GPUs";
+  }
+  const absl::string_view hlo_string = R"(
+HloModule test
+
+ENTRY main {
+  p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
+  p1 = f8e4m3fn[4096,16384]{0,1} parameter(1)
+  dot = bf16[12288,16384]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bitcast = bf16[] constant(0.956)
+  broadcast = bf16[12288,16384]{1,0} broadcast(bitcast), dimensions={}
+  ROOT multiply = bf16[12288,16384]{1,0} multiply(dot, broadcast)
+})";
+
+  HloModuleConfig config;
+  DebugOptions triton_enabled_debug_options = GetDebugOptionsForTest();
+  triton_enabled_debug_options
+      .set_xla_gpu_require_complete_aot_autotune_results(true);
+  config.set_debug_options(triton_enabled_debug_options);
+
+  // Load autotuning DB. We shouldn't depend on actual execution times in a unit
+  // test.
+  std::string path =
+      tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
+                        "gpu_compiler_test_autotune_db.textproto");
+  TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(path));
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string, config));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> triton_enabled_module,
+                          GetOptimizedModule(std::move(module)));
+
+  AutotunerUtil::ClearAutotuneResults();
+  DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest();
+  triton_disabled_debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
+  triton_disabled_debug_options.set_xla_gpu_enable_triton_gemm(false);
+  triton_disabled_debug_options.set_xla_gpu_cublas_fallback(true);
+  config.set_debug_options(triton_disabled_debug_options);
+
+  TF_ASSERT_OK_AND_ASSIGN(module,
+                          ParseAndReturnVerifiedModule(hlo_string, config));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> triton_disabled_module,
+                          GetOptimizedModule(std::move(module)));
+
+  EXPECT_TRUE(RunAndCompareTwoModules(std::move(triton_enabled_module),
+                                      std::move(triton_disabled_module),
+                                      ErrorSpec{1e-6, 1e-6}, false));
+}
+
 class FloatNormalizationTest : public GpuCompilerTest,
                                public ::testing::WithParamInterface<
                                    std::pair<PrimitiveType, PrimitiveType>> {};
diff --git a/xla/service/gpu/gpu_compiler_test_autotune_db.textproto b/xla/service/gpu/gpu_compiler_test_autotune_db.textproto
index 51caadb7bd2d06..699b397681682f 100644
--- a/xla/service/gpu/gpu_compiler_test_autotune_db.textproto
+++ b/xla/service/gpu/gpu_compiler_test_autotune_db.textproto
@@ -71,3 +71,26 @@ results {
     }
   }
 }
+results {
+  device: "CUDA: 9.0, Cores: 114, GPU clock: 1.755 GHz, Memory bandwidth: 2039 GB/s, L2 cache: 50 MB"
+  hlo: "(bf16[12288,16384]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[4096,12288]{0,1}, f8e4m3fn[4096,16384]{0,1}, f32[], f32[], f32[], f32[]), custom_call_target=\"__cublas$lt$matmul$f8\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":0.95703125,\"beta\":0,\"damax_output\":false,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[],\"lhs_contracting_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[],\"rhs_contracting_dimensions\":[\"0\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"50331648\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"67108864\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}"
+  result {
+    gemm {
+    }
+    run_time {
+      nanos: 1
+    }
+  }
+}
+results {
+  device: "CUDA: 9.0, Cores: 114, GPU clock: 1.755 GHz, Memory bandwidth: 2039 GB/s, L2 cache: 50 MB"
+  hlo: "{\n tmp_0 = f8e4m3fn[12288,4096]{0,1} parameter(0)\n tmp_1 = f8e4m3fn[4096,16384]{0,1} parameter(1)\n tmp_2 = bf16[12288,16384]{1,0} dot(f8e4m3fn[12288,4096]{0,1} tmp_0, f8e4m3fn[4096,16384]{0,1} tmp_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n tmp_3 = bf16[] constant({...})\n tmp_4 = bf16[12288,16384]{1,0} broadcast(bf16[] tmp_3), dimensions={}\n ROOT tmp_5 = bf16[12288,16384]{1,0} multiply(bf16[12288,16384]{1,0} tmp_2, bf16[12288,16384]{1,0} tmp_4)\n}"
+  result {
+    gemm {
+      algorithm: -1
+    }
+    run_time {
+      nanos: 1
+    }
+  }
+}
\ No newline at end of file