[XLA:GPU] Simplify AutotuneOneConvRunner() parameter.
All we care about is the instruction string, not the cache key. Note that we
actually want the regular output of ToString(), including metadata, rather
than the string used for the cache key, which excludes metadata.

PiperOrigin-RevId: 696105764
akuegel authored and Google-ML-Automation committed Nov 13, 2024
1 parent 4e783ed commit 619adc0
Showing 2 changed files with 5 additions and 13 deletions.
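
To make the description concrete, here is a minimal sketch (not part of this commit) contrasting the two strings involved: the full HloInstruction::ToString() dump, including metadata, which the call sites below now pass as instr_str, and the metadata-free HLO string kept by the autotune cache key. The function name LogAutotuningStrings, the include paths, and the use of Abseil LOG are illustrative assumptions; only ToString() and GetHlo() come from the diff itself.

// Illustrative sketch only -- not code from this commit.
// Include paths are assumptions; the real locations may differ.
#include <string>

#include "absl/log/log.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/autotuning/autotuner_util.h"  // assumed home of AutotuneCacheKey

namespace xla::gpu {

// Contrasts the string now passed to AutotuneOneConvRunner() with the
// metadata-free string used for the autotune cache key.
void LogAutotuningStrings(const HloInstruction& instr,
                          const AutotuneCacheKey& key) {
  // Full instruction dump, including metadata; this is what the changed
  // call sites now pass as `instr_str` via instr->ToString().
  const std::string with_metadata = instr.ToString();

  // The cache key stores an HLO string that excludes metadata (per the
  // commit description), so it is the wrong thing to print when debugging.
  const absl::string_view without_metadata = key.GetHlo();

  LOG(INFO) << "instr_str (with metadata): " << with_metadata;
  LOG(INFO) << "cache-key HLO (no metadata): " << without_metadata;
}

}  // namespace xla::gpu
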
16 changes: 4 additions & 12 deletions xla/service/gpu/autotuning/conv_algorithm_picker.cc
@@ -521,8 +521,7 @@ static const DisabledAlgorithm kDisabledAlgorithms[] = {
 absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::AutotuneOneConvRunner(
     GenericConvRunner* const runner,
     std::optional<ReferenceResult>* reference_result,
-    absl::Span<const AlgorithmDesc> disabled_algos,
-    std::optional<AutotuneCacheKey> instruction_info,
+    absl::Span<const AlgorithmDesc> disabled_algos, absl::string_view instr_str,
     const AutotuneRuntimeArguments& runtime_arguments) {
   auto alg = runner->ToAlgorithmDesc();
 
@@ -543,10 +542,6 @@ absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::AutotuneOneConvRunner(
 
   AlgorithmDesc alg_key(alg.algo_id(), alg.tensor_ops_enabled(), std::nullopt);
 
-  std::string instr_str = instruction_info.has_value()
-                              ? std::string(instruction_info->GetHlo())
-                              : "<unknown>";
-
   for (const auto& disabled_algo : kDisabledAlgorithms) {
     if (disabled_algo.cudnn_version_range.IsInRange(
             GetCudnnVersion(stream_exec)) &&
@@ -773,10 +768,7 @@ absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::AutotuneOneConvRunner(
           << instr_str << " for " << (*reference_result)->algorithm.ToString()
           << " vs " << alg.ToString();
       PrintPlatformInfo(stream);
-      if (instruction_info.has_value()) {
-        VLOG(2) << "Full module on failure: \n"
-                << instruction_info->GetModelStr();
-      }
+      VLOG(2) << "Full module on failure: \n" << instr_str;
       auto* fail = result.mutable_failure();
       fail->set_kind(AutotuneResult::WRONG_RESULT);
       fail->set_buffer_address(
@@ -863,7 +855,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
     TF_ASSIGN_OR_RETURN(
         auto result,
         AutotuneOneConvRunner(&runner_cache, &reference_result, disabled_algos,
-                              instruction_info, runtime_arguments));
+                              instr->ToString(), runtime_arguments));
     profile_results.emplace_back(std::move(result));
   }
 
@@ -885,7 +877,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
   for (auto& runner_cache : fallback_runners) {
     TF_ASSIGN_OR_RETURN(
         auto result, AutotuneOneConvRunner(&runner_cache, &reference_result,
-                                           disabled_algos, instruction_info,
+                                           disabled_algos, instr->ToString(),
                                            runtime_arguments));
     profile_results.emplace_back(std::move(result));
   }
2 changes: 1 addition & 1 deletion xla/service/gpu/autotuning/conv_algorithm_picker.h
@@ -134,7 +134,7 @@ class GpuConvAlgorithmPicker : public HloModulePass {
       GenericConvRunner* runner,
       std::optional<ReferenceResult>* reference_result,
       absl::Span<const stream_executor::dnn::AlgorithmDesc> disabled_algos,
-      std::optional<AutotuneCacheKey> instruction_info,
+      absl::string_view instr_str,
       const AutotuneRuntimeArguments& runtime_arguments);
 
   // Pick the best algorithm for CUDA platform.
