Skip to content

Commit

Permalink
PR #16975: Add a few related optimization passes for fp8 gemm custom-…
Browse files Browse the repository at this point in the history
…calls.

Imported from GitHub PR #16975

This caused convergence issue for fp8 training, tested on GPT3 models:

Before:
```
NETWORK             BACKEND MATH SDPA XLA_EXTRAS      GPUs STEPS/SEC     LOSS
WALLSECS
GPT5B                   XLA  fp8   FA    8     1.064 11.019     1571
[PAX STATUS]: Starting training loop.
[PAX STATUS] step_i: 100, training loss: 11.015041
[PAX STATUS] step_i: 200, training loss: 11.016165
[PAX STATUS] step_i: 300, training loss: 11.016386
[PAX STATUS] step_i: 400, training loss: 11.014653
[PAX STATUS] step_i: 500, training loss: 11.014734
[PAX STATUS] step_i: 600, training loss: 11.01613
[PAX STATUS] step_i: 700, training loss: 11.009399
[PAX STATUS] step_i: 800, training loss: 11.017071
[PAX STATUS] step_i: 900, training loss: 11.014582
[PAX STATUS] step_i: 1000, training loss: 11.013434
[PAX STATUS] step_i: 1100, training loss: 11.021271
[PAX STATUS] step_i: 1200, training loss: 11.008364
[PAX STATUS] step_i: 1300, training loss: 11.0198145
[PAX STATUS] step_i: 1400, training loss: 11.01253
[PAX STATUS] step_i: 1500, training loss: 11.019016
```

After:
```
NETWORK             BACKEND MATH SDPA GPUs STEPS/SEC  LOSS WALLSECS
GPT5B                   XLA  fp8   FA    8     1.020 3.797     1647
[PAX STATUS]: Starting training loop.
[PAX STATUS] step_i: 100, training loss: 6.150083
[PAX STATUS] step_i: 200, training loss: 5.8871064
[PAX STATUS] step_i: 300, training loss: 5.4491887
[PAX STATUS] step_i: 400, training loss: 5.6384015
[PAX STATUS] step_i: 500, training loss: 5.273538
[PAX STATUS] step_i: 600, training loss: 5.2011905
[PAX STATUS] step_i: 700, training loss: 4.903013
[PAX STATUS] step_i: 800, training loss: 4.62972
[PAX STATUS] step_i: 900, training loss: 4.507727
[PAX STATUS] step_i: 1000, training loss: 4.625259
[PAX STATUS] step_i: 1100, training loss: 4.428066
[PAX STATUS] step_i: 1200, training loss: 4.252451
[PAX STATUS] step_i: 1300, training loss: 3.8448389
[PAX STATUS] step_i: 1400, training loss: 3.8578327
[PAX STATUS] step_i: 1500, training loss: 3.796958
```
Copybara import of the project:

--
90f5968 by Elfie Guo <[email protected]>:

Add a few related optimization pass for fp8 gemm rerwriter.

Merging this change closes #16975

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16975 from elfiegg:pass 90f5968
PiperOrigin-RevId: 675755585
  • Loading branch information
elfiegg authored and Google-ML-Automation committed Sep 17, 2024
1 parent 2b1c609 commit 6c1f764
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 0 deletions.
6 changes: 6 additions & 0 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1552,6 +1552,12 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
// Rewrite GEMMs with broadcasted inputs as strided GEMMs.
pipeline.AddPass<GemmBroadcastFoldingRewriter>();

pipeline.AddPass<LayoutNormalization>(&NormalizeLayoutForGpuCustomCalls);

// Layout normalization will create scatters that are not simplified and
// also have unsorted update_window_dims.
pipeline.AddPass<ScatterSimplifier>();

pipeline.AddPass<HostOffloadLegalize>(
static_cast<int64_t>(stream_executor::MemoryType::kHost),
/* after_layout= */ true);
Expand Down
56 changes: 56 additions & 0 deletions xla/service/gpu/gpu_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,62 @@ ENTRY main {
triton_disabled_module->computation_count());
}

TEST_F(GpuCompilerTest,
CublasF8NumericallySameWithTritonFallbackAndWithoutTriton) {
auto cc = backend()
.default_stream_executor()
->GetDeviceDescription()
.cuda_compute_capability();
if (!cc.IsAtLeastAmpere()) {
GTEST_SKIP() << "Autotuning results have only been generated for Ampere "
<< "and Hopper GPUs";
}
const absl::string_view hlo_string = R"(
HloModule test
ENTRY main {
p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
p1 = f8e4m3fn[4096,16384]{0,1} parameter(1)
dot = bf16[12288,16384]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
bitcast = bf16[] constant(0.956)
broadcast = bf16[12288,16384]{1,0} broadcast(bitcast), dimensions={}
ROOT multiply = bf16[12288,16384]{1,0} multiply(dot, broadcast)
})";

HloModuleConfig config;
DebugOptions triton_enabled_debug_options = GetDebugOptionsForTest();
triton_enabled_debug_options
.set_xla_gpu_require_complete_aot_autotune_results(true);
config.set_debug_options(triton_enabled_debug_options);

// Load autotuning DB. We shouldn't depend on actual execution times in a unit
// test.
std::string path =
tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
"gpu_compiler_test_autotune_db.textproto");
TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(path));

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string, config));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> triton_enabled_module,
GetOptimizedModule(std::move(module)));

AutotunerUtil::ClearAutotuneResults();
DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest();
triton_disabled_debug_options.set_xla_gpu_enable_triton_gemm(false);
triton_disabled_debug_options.set_xla_gpu_cublas_fallback(true);
config.set_debug_options(triton_disabled_debug_options);

TF_ASSERT_OK_AND_ASSIGN(module,
ParseAndReturnVerifiedModule(hlo_string, config));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> triton_disabled_module,
GetOptimizedModule(std::move(module)));

EXPECT_TRUE(RunAndCompareTwoModules(std::move(triton_enabled_module),
std::move(triton_disabled_module),
ErrorSpec{1e-6, 1e-6}, false));
}

class FloatNormalizationTest : public GpuCompilerTest,
public ::testing::WithParamInterface<
std::pair<PrimitiveType, PrimitiveType>> {};
Expand Down
23 changes: 23 additions & 0 deletions xla/service/gpu/gpu_compiler_test_autotune_db.textproto
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,26 @@ results {
}
}
}
results {
device: "CUDA: 9.0, Cores: 114, GPU clock: 1.755 GHz, Memory bandwidth: 2039 GB/s, L2 cache: 50 MB"
hlo: "(bf16[12288,16384]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[4096,12288]{0,1}, f8e4m3fn[4096,16384]{0,1}, f32[], f32[], f32[], f32[]), custom_call_target=\"__cublas$lt$matmul$f8\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":0.95703125,\"beta\":0,\"damax_output\":false,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[],\"lhs_contracting_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[],\"rhs_contracting_dimensions\":[\"0\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"50331648\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"67108864\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}"
result {
gemm {
}
run_time {
nanos: 1
}
}
}
results {
device: "CUDA: 9.0, Cores: 114, GPU clock: 1.755 GHz, Memory bandwidth: 2039 GB/s, L2 cache: 50 MB"
hlo: "{\n tmp_0 = f8e4m3fn[12288,4096]{0,1} parameter(0)\n tmp_1 = f8e4m3fn[4096,16384]{0,1} parameter(1)\n tmp_2 = bf16[12288,16384]{1,0} dot(f8e4m3fn[12288,4096]{0,1} tmp_0, f8e4m3fn[4096,16384]{0,1} tmp_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n tmp_3 = bf16[] constant({...})\n tmp_4 = bf16[12288,16384]{1,0} broadcast(bf16[] tmp_3), dimensions={}\n ROOT tmp_5 = bf16[12288,16384]{1,0} multiply(bf16[12288,16384]{1,0} tmp_2, bf16[12288,16384]{1,0} tmp_4)\n}"
result {
gemm {
algorithm: -1
}
run_time {
nanos: 1
}
}
}

0 comments on commit 6c1f764

Please sign in to comment.