Skip to content

Commit

Permalink
Adding Strictness level to PGLE accuracy checker.
Browse files Browse the repository at this point in the history
Two flags control the behavior now.

 * `xla_gpu_pgle_profile_file_or_directory_path` unspecified,
   `xla_gpu_strict_pgle_accuracy_checker` off: this means that PGLE will
not be used.
 * `xla_gpu_pgle_profile_file_or_directory_path` specified,
   `xla_gpu_strict_pgle_accuracy_checker` off: this means that PGLE will
warn about accuracy checker failures like missing instructions, but will
continue with them.
 * `xla_gpu_pgle_profile_file_or_directory_path` specified,
   `xla_gpu_strict_pgle_accuracy_checker` on: this means that PGLE will
error out if the accuracy checker fails.
 * `xla_gpu_pgle_profile_file_or_directory_path` unspecified,
   `xla_gpu_strict_pgle_accuracy_checker` on: this is an invalid flag
combination.
  • Loading branch information
shraiysh committed Sep 17, 2024
1 parent 536ba0b commit a414eb8
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 28 deletions.
8 changes: 4 additions & 4 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {

opts.set_xla_gpu_enable_triton_gemm_int4(false);

opts.set_xla_gpu_enable_pgle_accuracy_checker(false);
opts.set_xla_gpu_strict_pgle_accuracy_checker(false);

opts.set_xla_gpu_executable_warn_stuck_timeout_seconds(10);
opts.set_xla_gpu_executable_terminate_timeout_seconds(30);
Expand Down Expand Up @@ -1919,9 +1919,9 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"a training. The location of the marker (if any) is determined "
"by the option value of type DebugOptions::StepMarkerLocation."));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_pgle_accuracy_checker",
bool_setter_for(&DebugOptions::set_xla_gpu_enable_pgle_accuracy_checker),
debug_options->xla_gpu_enable_pgle_accuracy_checker(),
"xla_gpu_strict_pgle_accuracy_checker",
bool_setter_for(&DebugOptions::set_xla_gpu_strict_pgle_accuracy_checker),
debug_options->xla_gpu_strict_pgle_accuracy_checker(),
"Enables strict PGLE checking. If an FDO profile is specified and "
"latency hiding scheduler encounters missing instructions in the profile "
"compilation will halt."));
Expand Down
29 changes: 15 additions & 14 deletions xla/service/gpu/gpu_hlo_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -458,15 +458,24 @@ absl::StatusOr<ScheduleMetadata> ScheduleGpuModule(
VLOG(1) << "Fingerprint before LHS for module " << module->name() << "("
<< module->unique_id() << ") = " << fingerprint;

const DebugOptions& options = module->config().debug_options();

const bool enable_latency_hiding_scheduler =
module->config()
.debug_options()
.xla_gpu_enable_latency_hiding_scheduler();
options.xla_gpu_enable_latency_hiding_scheduler();

if (!enable_latency_hiding_scheduler) {
return ScheduleMetadata{memory_limit};
}

VLOG(0) << "Here";

if (options.xla_gpu_pgle_profile_file_or_directory_path().empty() &&
options.xla_gpu_strict_pgle_accuracy_checker()) {
return absl::InvalidArgumentError(
"xla_gpu_strict_pgle_accuracy_checker is turned on, but no profile "
"path specified in xla_gpu_pgle_profile_file_or_directory_path");
}

SchedulerConfig config = GetSchedulerConfig(memory_limit);
auto gpu_latency_estimator =
std::make_unique<GpuLatencyEstimator>(pointer_size);
Expand All @@ -476,9 +485,7 @@ absl::StatusOr<ScheduleMetadata> ScheduleGpuModule(
ReadPGLEProfile(module, fingerprint);

const bool enable_analytical_latency_estimator =
module->config()
.debug_options()
.xla_gpu_enable_analytical_latency_estimator();
options.xla_gpu_enable_analytical_latency_estimator();
HloPassPipeline pipeline("latency-hiding-scheduler");
if (profile.has_value()) {
auto aggregator = std::make_unique<GPUProfileStatisticsAggregator>();
Expand All @@ -487,11 +494,7 @@ absl::StatusOr<ScheduleMetadata> ScheduleGpuModule(
std::move(aggregator));
LOG(INFO) << "Found profile, using profile guided latency estimator";
VLOG(1) << "Profile:\n" << profile->DebugString();
if (module->config()
.debug_options()
.xla_gpu_enable_pgle_accuracy_checker()) {
pipeline.AddPass<PGLEAccuracyChecker>(*pg_latency_estimator);
}
pipeline.AddPass<PGLEAccuracyChecker>(*pg_latency_estimator);
latency_estimator = std::move(pg_latency_estimator);
} else if (enable_analytical_latency_estimator) {
latency_estimator = std::make_unique<AnalyticalLatencyEstimator>(
Expand All @@ -506,9 +509,7 @@ absl::StatusOr<ScheduleMetadata> ScheduleGpuModule(
}

auto async_tracker = [&]() -> std::unique_ptr<AsyncTracker> {
return module->config()
.debug_options()
.xla_gpu_lhs_enable_gpu_async_tracker()
return options.xla_gpu_lhs_enable_gpu_async_tracker()
? std::make_unique<GpuAsyncTracker>(config)
: std::make_unique<GpuAsyncTrackerBase>(config);
}();
Expand Down
27 changes: 26 additions & 1 deletion xla/service/gpu/gpu_hlo_schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ TEST_F(GpuHloScheduleTest, ProfileGuidedCostModelFailsWithIncompleteProfile) {

HloModuleConfig config(module->config());
DebugOptions dboptions(config.debug_options());
dboptions.set_xla_gpu_enable_pgle_accuracy_checker(true);
dboptions.set_xla_gpu_strict_pgle_accuracy_checker(true);
config.set_debug_options(dboptions);
module->set_config(config);

Expand Down Expand Up @@ -1637,5 +1637,30 @@ TEST_F(GpuHloScheduleTest, AsyncOps) {
HloOpcode::kAsyncDone, HloOpcode::kAdd));
}

TEST_F(GpuHloScheduleTest, InvalidPGLEOptions) {
const char* hlo = R"(
HloModule test
ENTRY add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = add(a,b)
}
)";

HloModuleConfig config;
DebugOptions options;
options.set_xla_gpu_strict_pgle_accuracy_checker(true);
options.set_xla_gpu_enable_latency_hiding_scheduler(true);
config.set_debug_options(options);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::VerifiedHloModule> module,
ParseAndReturnVerifiedModule(hlo, config));

GTEST_FLAG_SET(death_test_style, "threadsafe");
EXPECT_DEATH(
BuildHloOrdering(module.get()),
"xla_gpu_strict_pgle_accuracy_checker is turned on, but no profile path "
"specified in xla_gpu_pgle_profile_file_or_directory_path");
}

} // namespace gpu
} // namespace xla
20 changes: 12 additions & 8 deletions xla/service/profile_guided_latency_estimator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,16 +188,20 @@ absl::Status ProfileGuidedLatencyEstimator::CheckAccuracy(
ProfileStatisticsAggregator::Statistics stats = aggregator_->GetStats();
size_t missing_instructions_count = stats.missing_instructions.size();
if (missing_instructions_count > 0) {
LOG(ERROR) << "Found " << stats.found_instructions_count
<< " instructions from the profile.";
LOG(ERROR) << "Missing " << missing_instructions_count
<< " instructions from the profile.";
LOG(WARNING) << "Found " << stats.found_instructions_count
<< " instructions from the profile.";
LOG(WARNING) << "Missing " << missing_instructions_count
<< " instructions from the profile.";
for (const HloInstruction* instr : stats.missing_instructions) {
LOG(ERROR) << " " << instr->name();
LOG(WARNING) << " " << instr->name();
}
if (module.config()
.debug_options()
.xla_gpu_strict_pgle_accuracy_checker()) {
return absl::InvalidArgumentError(
absl::StrCat("Found ", missing_instructions_count,
" missing instructions. Discarding the profile."));
}
return absl::InvalidArgumentError(
absl::StrCat("Found ", missing_instructions_count,
" missing instructions. Discarding the profile."));
}
return absl::OkStatus();
}
Expand Down
2 changes: 1 addition & 1 deletion xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ message DebugOptions {
// Enables strict PGLE checking. If an FDO profile is specified and latency
// hiding scheduler encounters missing instructions in the profile
// compilation will halt.
bool xla_gpu_enable_pgle_accuracy_checker = 326;
bool xla_gpu_strict_pgle_accuracy_checker = 326;

// Timeouts for RendezvousSingle stuck warning and termination.
int32 xla_gpu_executable_warn_stuck_timeout_seconds = 327;
Expand Down

0 comments on commit a414eb8

Please sign in to comment.