diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index 6035e2dce8da9..6720611a9476d 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -288,8 +288,12 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_cudnn_gemm_max_plans(5); + // TODO: remove this as it is replaced by xla_gpu_pgle_accuracy_checker. opts.set_xla_gpu_enable_pgle_accuracy_checker(false); + opts.set_xla_gpu_pgle_accuracy_checker( + DebugOptions::PGLE_STRICTNESS_LEVEL_WARN); + opts.set_xla_gpu_executable_warn_stuck_timeout_seconds(10); opts.set_xla_gpu_executable_terminate_timeout_seconds(30); opts.set_xla_gpu_experimental_disable_binary_libraries(false); @@ -701,6 +705,18 @@ void MakeDebugOptionsFlags(std::vector* flag_list, return true; }; + // Custom "sub-parser" lambda for xla_gpu_pgle_accuracy_checker. + auto setter_for_xla_gpu_pgle_accuracy_checker = + [debug_options](const std::string& value) { + DebugOptions::PGLEStrictnessLevel strictness_level; + if (!DebugOptions::PGLEStrictnessLevel_Parse(value, + &strictness_level)) { + return false; + } + debug_options->set_xla_gpu_pgle_accuracy_checker(strictness_level); + return true; + }; + // Don't use an initializer list for initializing the vector; this would // create a temporary copy, and exceeds the stack space when compiling with // certain configurations. @@ -1975,12 +1991,13 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "a training. The location of the marker (if any) is determined " "by the option value of type DebugOptions::StepMarkerLocation.")); flag_list->push_back(tsl::Flag( - "xla_gpu_enable_pgle_accuracy_checker", - bool_setter_for(&DebugOptions::set_xla_gpu_enable_pgle_accuracy_checker), - debug_options->xla_gpu_enable_pgle_accuracy_checker(), - "Enables strict PGLE checking. If an FDO profile is specified and " - "latency hiding scheduler encounters missing instructions in the profile " - "compilation will halt.")); + "xla_gpu_pgle_accuracy_checker", setter_for_xla_gpu_pgle_accuracy_checker, + DebugOptions::PGLEStrictnessLevel_Name( + debug_options->xla_gpu_pgle_accuracy_checker()), + "If an FDO profile is specified and latency hiding scheduler encounters " + "missing instructions in the profile, then the compilation will halt " + "(ERROR), or a warning will be emitted (WARN), or the checker is " + "disabled (OFF)")); flag_list->push_back(tsl::Flag( "xla_gpu_executable_warn_stuck_timeout", diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index b7759f78ee6ba..997da70039ba5 100755 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -2633,7 +2633,8 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( pipeline.AddPass(); } - if (module->config().debug_options().xla_gpu_enable_pgle_accuracy_checker()) { + if (module->config().debug_options().xla_gpu_pgle_accuracy_checker() == + DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR) { AddHloVerifier( &main_pipeline, module->config().debug_options().xla_experimental_ignore_channel_id(), diff --git a/xla/service/gpu/gpu_hlo_schedule.cc b/xla/service/gpu/gpu_hlo_schedule.cc index f4ef235c37634..d89635fe871e9 100644 --- a/xla/service/gpu/gpu_hlo_schedule.cc +++ b/xla/service/gpu/gpu_hlo_schedule.cc @@ -449,15 +449,32 @@ absl::StatusOr ScheduleGpuModule( VLOG(1) << "Fingerprint before LHS for module " << module->name() << "(" << module->unique_id() << ") = " << fingerprint; + const DebugOptions& options = module->config().debug_options(); const bool enable_latency_hiding_scheduler = - module->config() - .debug_options() - .xla_gpu_enable_latency_hiding_scheduler(); + options.xla_gpu_enable_latency_hiding_scheduler(); if (!enable_latency_hiding_scheduler) { return ScheduleMetadata{memory_limit}; } + if (options.xla_gpu_pgle_profile_file_or_directory_path().empty() && + module->config().fdo_profile().empty() && + options.xla_gpu_pgle_accuracy_checker() == + DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR) { + return absl::InvalidArgumentError( + "xla_gpu_pgle_accuracy_checker is set to ERROR, but no profile " + "path specified in xla_gpu_pgle_profile_file_or_directory_path"); + } + + if (options.xla_gpu_pgle_profile_file_or_directory_path().empty() && + module->config().fdo_profile().empty() && + options.xla_gpu_pgle_accuracy_checker(), + DebugOptions::PGLE_STRICTNESS_LEVEL_WARN) { + LOG(WARNING) + << "xla_gpu_pgle_accuracy_checker is set to WARN, but no profile path " + "specified in xla_gpu_pgle_profile_file_or_directory_path"; + } + SchedulerConfig config = GetSchedulerConfig( memory_limit, module->config() @@ -481,9 +498,7 @@ absl::StatusOr ScheduleGpuModule( ReadPGLEProfile(module, fingerprint); const bool enable_analytical_latency_estimator = - module->config() - .debug_options() - .xla_gpu_enable_analytical_latency_estimator(); + options.xla_gpu_enable_analytical_latency_estimator(); HloPassPipeline pipeline("latency-hiding-scheduler"); if (profile.has_value()) { auto aggregator = std::make_unique(); @@ -492,9 +507,10 @@ absl::StatusOr ScheduleGpuModule( std::move(aggregator)); LOG(INFO) << "Found profile, using profile guided latency estimator"; VLOG(1) << "Profile:\n" << profile->DebugString(); - if (module->config() - .debug_options() - .xla_gpu_enable_pgle_accuracy_checker()) { + if (options.xla_gpu_pgle_accuracy_checker() == + DebugOptions::PGLE_STRICTNESS_LEVEL_WARN || + options.xla_gpu_pgle_accuracy_checker() == + DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR) { pipeline.AddPass(*pg_latency_estimator); } latency_estimator = std::move(pg_latency_estimator); @@ -511,9 +527,7 @@ absl::StatusOr ScheduleGpuModule( } auto async_tracker = [&]() -> std::unique_ptr { - return module->config() - .debug_options() - .xla_gpu_lhs_enable_gpu_async_tracker() + return options.xla_gpu_lhs_enable_gpu_async_tracker() ? std::make_unique(config) : std::make_unique(config); }(); diff --git a/xla/service/gpu/gpu_hlo_schedule_test.cc b/xla/service/gpu/gpu_hlo_schedule_test.cc index 70a91ca387265..2f4756c3ed71e 100644 --- a/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -566,7 +566,8 @@ TEST_F(GpuHloScheduleTest, ProfileGuidedCostModelFailsWithIncompleteProfile) { HloModuleConfig config(module->config()); DebugOptions dboptions(config.debug_options()); - dboptions.set_xla_gpu_enable_pgle_accuracy_checker(true); + dboptions.set_xla_gpu_pgle_accuracy_checker( + DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR); config.set_debug_options(dboptions); module->set_config(config); @@ -1696,5 +1697,30 @@ TEST_F(GpuHloScheduleTest, CopyStartDoneScheduled) { )")); } +TEST_F(GpuHloScheduleTest, InvalidPGLEOptions) { + const char* hlo = R"( + HloModule test + ENTRY add { + a = s32[] parameter(0) + b = s32[] parameter(1) + ROOT add = add(a,b) + } + )"; + + HloModuleConfig config; + DebugOptions options; + options.set_xla_gpu_pgle_accuracy_checker( + DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR); + options.set_xla_gpu_enable_latency_hiding_scheduler(true); + config.set_debug_options(options); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo, config)); + + GTEST_FLAG_SET(death_test_style, "threadsafe"); + EXPECT_DEATH(BuildHloOrdering(module.get()), + "xla_gpu_pgle_accuracy_checker is set to ERROR, but no profile " + "path specified in xla_gpu_pgle_profile_file_or_directory_path"); +} + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc b/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc index 7d947c23f881a..42e05cf9db71c 100644 --- a/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc +++ b/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc @@ -55,17 +55,17 @@ int GetIndexByName(absl::Span instruction_sequence, class GpuLatencyHidingSchedulerBaseTest : public HloTestBase { protected: absl::StatusOr ScheduleModule( - HloModule* module, int64_t num_parallel_resources = 1) { + HloModule* module, int64_t num_parallel_resources = 1, + DebugOptions::PGLEStrictnessLevel strictness = + DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR) { auto& test_backend = backend(); const auto& gpu_device_info = test_backend.default_stream_executor()->GetDeviceDescription(); - HloModuleConfig config(module->config()); - DebugOptions dboptions(config.debug_options()); - dboptions.set_xla_gpu_enable_pgle_accuracy_checker(true); - dboptions.set_xla_gpu_experimental_parallel_collective_overlap_limit( + DebugOptions& options = module->mutable_config().mutable_debug_options(); + options.set_xla_gpu_experimental_parallel_collective_overlap_limit( num_parallel_resources); - config.set_debug_options(dboptions); - module->set_config(config); + options.set_xla_gpu_pgle_accuracy_checker(strictness); + TF_RETURN_IF_ERROR( ScheduleGpuModule(module, /*pointer_size=*/8, gpu_device_info) .status()); diff --git a/xla/service/gpu/transforms/pgle_accuracy_checker_test.cc b/xla/service/gpu/transforms/pgle_accuracy_checker_test.cc index 3f2d1ab6426fd..d1994a1f6f068 100644 --- a/xla/service/gpu/transforms/pgle_accuracy_checker_test.cc +++ b/xla/service/gpu/transforms/pgle_accuracy_checker_test.cc @@ -148,6 +148,10 @@ TEST_F(PGLEAccuracyCheckerTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); *module->mutable_config().mutable_fdo_profile() = kProfileString; + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_pgle_accuracy_checker( + DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR); auto pgle_estimator = GetProfileGuidedLatencyEstimator(profile); PGLEAccuracyChecker pgle_accuracy_checker(*pgle_estimator); diff --git a/xla/service/profile_guided_latency_estimator.cc b/xla/service/profile_guided_latency_estimator.cc index d8e20f2445c4a..d66f533493318 100644 --- a/xla/service/profile_guided_latency_estimator.cc +++ b/xla/service/profile_guided_latency_estimator.cc @@ -188,16 +188,19 @@ absl::Status ProfileGuidedLatencyEstimator::CheckAccuracy( ProfileStatisticsAggregator::Statistics stats = aggregator_->GetStats(); size_t missing_instructions_count = stats.missing_instructions.size(); if (missing_instructions_count > 0) { - LOG(ERROR) << "Found " << stats.found_instructions_count - << " instructions from the profile."; - LOG(ERROR) << "Missing " << missing_instructions_count - << " instructions from the profile."; + LOG(WARNING) << "Found " << stats.found_instructions_count + << " instructions from the profile."; + LOG(WARNING) << "Missing " << missing_instructions_count + << " instructions from the profile."; for (const HloInstruction* instr : stats.missing_instructions) { - LOG(ERROR) << " " << instr->name(); + LOG(WARNING) << " " << instr->name(); + } + if (module.config().debug_options().xla_gpu_pgle_accuracy_checker() == + DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR) { + return absl::InvalidArgumentError( + absl::StrCat("Found ", missing_instructions_count, + " missing instructions. Discarding the profile.")); } - return absl::InvalidArgumentError( - absl::StrCat("Found ", missing_instructions_count, - " missing instructions. Discarding the profile.")); } return absl::OkStatus(); } diff --git a/xla/xla.proto b/xla/xla.proto index 6e76f68c19e78..ac9d110aa0346 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -994,6 +994,7 @@ message DebugOptions { // Enables strict PGLE checking. If an FDO profile is specified and latency // hiding scheduler encounters missing instructions in the profile // compilation will halt. + // TODO: remove this field - it is replaced by xla_gpu_pgle_accuracy_checker. bool xla_gpu_enable_pgle_accuracy_checker = 326; // Timeouts for RendezvousSingle stuck warning and termination. @@ -1026,7 +1027,17 @@ message DebugOptions { // coll.2-done = collective(coll.2-start) int32 xla_gpu_experimental_parallel_collective_overlap_limit = 336; - // Next id: 339 + // Enables strict PGLE checking. If an FDO profile is specified and latency + // hiding scheduler encounters missing instructions in the profile + // compilation will halt or warn depending on the value of this option. + enum PGLEStrictnessLevel { + PGLE_STRICTNESS_LEVEL_OFF = 0; + PGLE_STRICTNESS_LEVEL_WARN = 1; + PGLE_STRICTNESS_LEVEL_ERROR = 2; + } + PGLEStrictnessLevel xla_gpu_pgle_accuracy_checker = 339; + + // Next id: 340 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.