Skip to content

Commit

Permalink
Cost model now considers the compute latency in addition to its throu…
Browse files Browse the repository at this point in the history
…ghput.

PiperOrigin-RevId: 671505921
  • Loading branch information
mehrdadkhani authored and Google-ML-Automation committed Sep 17, 2024
1 parent 8ace4ee commit 98fb24a
Show file tree
Hide file tree
Showing 19 changed files with 103 additions and 11 deletions.
1 change: 1 addition & 0 deletions xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ int64_t PriorityFusionShapeSize(const Shape& shape) {
HloCostAnalysis::Options PriorityFusionOptions() {
return {/*shape_size=*/PriorityFusionShapeSize,
/*per_second_rates=*/{},
/*min_latencies_seconds=*/{},
/*count_multiple_input_accesses=*/true};
}

Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/fusion_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ HloPassPipeline FusionPipeline(
GpuHloCostAnalysis::Options cost_analysis_options{
shape_size_bytes_function,
/*per_second_rates=*/{},
/*min_latencies_seconds=*/{},
/*count_multiple_input_accesses=*/true};
fusion.AddPass<PriorityFusion>(thread_pool, gpu_device_info,
std::move(cost_analysis_options));
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,7 @@ absl::Status RunFusionPasses(HloModule* hlo_module,
GpuHloCostAnalysis::Options cost_analysis_options{
shape_size_fn,
/*per_second_rates=*/{},
/*min_latencies_seconds=*/{},
/*count_multiple_input_accesses=*/true};

HloPassPipeline post_fusion_analysis("post_fusion_analysis");
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/model/analytical_latency_estimator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ AnalyticalLatencyEstimator::AnalyticalLatencyEstimator(
cost_analysis_.emplace(
GpuHloCostAnalysis::Options{shape_size_function_,
/*per_second_rates=*/{},
/*min_latencies_seconds=*/{},
/*count_multiple_input_accesses=*/true},
gpu_info_);
TF_CHECK_OK(computation->Accept(&cost_analysis_.value()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class GpuCostModelStatsCollectionTest : public HloTestBase {
TestGpuDeviceInfo::RTXA6000DeviceInfo(),
GpuHloCostAnalysis::Options{ShapeSizeBytesFunction(),
/*per_second_rates=*/{},
/*min_latencies_seconds=*/{},
/*count_multiple_input_accesses=*/true}};
};

Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class GpuHloCostAnalysisTest : public HloTestBase {
public:
HloCostAnalysis::Options options_{ShapeSizeBytesFunction(),
/*per_second_rates=*/{},
/*min_latencies_seconds=*/{},
/*count_multiple_input_accesses=*/true};
GpuHloCostAnalysis analysis_{options_};
GpuHloCostAnalysisTest() : HloTestBase() {}
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/model/gpu_indexing_performance_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase {
cost_analysis_(
GpuHloCostAnalysis::Options{shape_size_,
/*per_second_rates=*/{},
/*min_latencies_seconds=*/{},
/*count_multiple_input_accesses=*/true},
*device_info_),
mlir_context_(mlir_context) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,7 @@ class FlopsPerElementTest : public GpuIndexingPerformanceModelTest {
GpuHloCostAnalysis cost_analysis(
GpuHloCostAnalysis::Options{ShapeSizeBytesFunction(),
/*per_second_rates=*/{},
/*min_latencies_seconds=*/{},
/*count_multiple_input_accesses=*/true},
device_info_);

Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/model/gpu_performance_model_base_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class GpuPerformanceModelBaseTest : public HloTestBase {

GpuHloCostAnalysis::Options options_{ShapeSizeBytesFunction(),
/*per_second_rates=*/{},
/*min_latencies_seconds=*/{},
/*count_multiple_input_accesses=*/true};
// The reference times in the test cases below are measured
// on A6000 by profiling the execution of the HLOs.
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/model/gpu_performance_model_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class GpuPerformanceModelTest : public HloTestBase {
mlir::MLIRContext mlir_context_;
GpuHloCostAnalysis::Options options_{ShapeSizeBytesFunction(),
/*per_second_rates=*/{},
/*min_latencies_seconds=*/{},
/*count_multiple_input_accesses=*/true};
// The reference times in the test cases below are measured
// on A6000 by profiling the execution of the HLOs.
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/transforms/fusion_merger.cc
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) {
cost_analysis_.emplace(
GpuHloCostAnalysis::Options{shape_size_function_,
/*per_second_rates=*/{},
/*min_latencies_seconds=*/{},
/*count_multiple_input_accesses=*/true},
gpu_device_info_);
TF_CHECK_OK(computation_->Accept(&cost_analysis_.value()));
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/transforms/multi_output_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ absl::StatusOr<bool> MultiOutputFusion::DoMultiOutputFusion() {
RecomputeReachability();
GpuHloCostAnalysis cost_analysis({shape_size_function_,
/*per_second_rates=*/{},
/*min_latencies_seconds=*/{},
/*count_multiple_input_accesses=*/true},
device_info_);
TF_RETURN_IF_ERROR(computation_->Accept(&cost_analysis));
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/transforms/priority_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class PriorityFusionTest : public HloTestBase {
/*thread_pool=*/nullptr, TestGpuDeviceInfo::RTXA6000DeviceInfo(),
GpuHloCostAnalysis::Options{ShapeSizeBytesFunction(),
/*per_second_rates=*/{},
/*min_latencies_seconds=*/{},
/*count_multiple_input_accesses=*/true}};
};

Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/transforms/softmax_rewriter_triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ EstimateOptimizedHloRunTimeWithoutSoftMaxRewriterTriton(
GpuHloCostAnalysis::Options cost_analysis_options{
shape_size,
/*per_second_rates=*/{},
/*min_latencies_seconds=*/{},
/*count_multiple_input_accesses=*/true};
GpuHloCostAnalysis cost_analysis(cost_analysis_options, device_info);
TF_RETURN_IF_ERROR(entry_computation->Accept(&cost_analysis));
Expand Down
12 changes: 9 additions & 3 deletions xla/service/hlo_cost_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,13 @@ limitations under the License.
namespace xla {

HloCostAnalysis::HloCostAnalysis(const Options& options) : options_(options) {}
// TODO(mehrdadk): merge all constructors into HloCostAnalysis(const Options&
// options)
HloCostAnalysis::HloCostAnalysis(ShapeSizeFunction shape_size,
const Properties& per_second_rates)
: HloCostAnalysis(Options{shape_size, per_second_rates}) {}
const Properties& per_second_rates,
const Properties& min_latencies_seconds)
: HloCostAnalysis(
Options{shape_size, per_second_rates, min_latencies_seconds}) {}

absl::Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) {
// Set current instruction cost values to reasonable default values. Each
Expand Down Expand Up @@ -82,7 +86,9 @@ absl::Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) {
}
float per_second_rate = options_.per_second_rate(key);
if (per_second_rate != 0) {
optimal_seconds = std::max(optimal_seconds, val / per_second_rate);
float time_for_key =
std::max(val / per_second_rate, options_.min_latency_seconds(key));
optimal_seconds = std::max(optimal_seconds, time_for_key);
}
});
current_properties_[kOptimalSecondsKey] = optimal_seconds;
Expand Down
26 changes: 24 additions & 2 deletions xla/service/hlo_cost_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,11 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
// property is bytes accessed, this is the number of bytes that can be
// processed per second. Is empty if no rates have been set.
Properties per_second_rates = {};
// The minimum amount of time (in seconds) required to process per each
// property. Hardware design choices (e.g., clock speeds, memory access
// latencies) impose a lower bound on the duration of any operation, even
// the simplest ones.
Properties min_latencies_seconds;
// Operations like broadcast with reused inputs are not handled
// efficiently on some platforms. Depending on the goal of the analysis
// we may need to count or ignore them.
Expand All @@ -414,31 +419,44 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
void set_flops_per_second(float value) {
per_second_rates[kFlopsKey] = value;
}
void set_flops_min_latency_second(float value) {
min_latencies_seconds[kFlopsKey] = value;
}
void set_transcendentals_per_second(float value) {
per_second_rates[kTranscendentalsKey] = value;
}
void set_bytes_per_second(float value) {
per_second_rates[kBytesAccessedKey] = value;
}
void set_bytes_min_latency_second(float value) {
min_latencies_seconds[kBytesAccessedKey] = value;
}

// Returns the specified per-second rate used by cost analysis.
float per_second_rate(absl::string_view key) const {
return per_second_rates[key];
}

float min_latency_seconds(absl::string_view key) const {
return min_latencies_seconds[key];
}

std::string ToString() const {
return absl::StrFormat(
"HloCostAnalysis::Options{\n"
" per_second_rates: %s\n"
" min_latency_seconds: %s\n"
" count_multiple_input_accesses: %d\n"
"}",
per_second_rates.ToString(), count_multiple_input_accesses);
per_second_rates.ToString(), min_latencies_seconds.ToString(),
count_multiple_input_accesses);
}
};

explicit HloCostAnalysis(const Options& options);
explicit HloCostAnalysis(ShapeSizeFunction shape_size,
const Properties& per_second_rates = {});
const Properties& per_second_rates = {},
const Properties& min_latency_seconds = {});

// For all element-wise instruction we call HandleElementwiseOp. If necessary,
// override HandleElementwiseOp instead.
Expand Down Expand Up @@ -594,6 +612,10 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
float per_second_rate(absl::string_view key) const {
return options_.per_second_rate(key);
}
// Returns the specified minimum latency used by cost analysis.
float min_latency_seconds(absl::string_view key) const {
return options_.min_latency_seconds(key);
}

// Return the key that is used to index into Properties for the specified
// input/output at the shape index.
Expand Down
26 changes: 26 additions & 0 deletions xla/service/hlo_cost_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,32 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) {
EXPECT_EQ(conv_analysis.flop_count(), matmul_analysis.flop_count());
}

// No instruction can finish faster than the clock cycle
TEST_F(HloCostAnalysisTest, LatencyBoundedOptimalTime) {
absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true
ENTRY Entry {
param0 = f32[1,1] parameter(0)
param1 = f32[1,1] parameter(1)
ROOT add = f32[1,1] add(param0, param1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));

const HloInstruction* add = module->entry_computation()->root_instruction();
HloCostAnalysis::Options options{ShapeSize};
const float clock_cycle_seconds = 10.0f;
options.set_flops_per_second(1024);
options.set_bytes_per_second(1024);
options.set_transcendentals_per_second(1024);
options.set_flops_min_latency_second(clock_cycle_seconds);
HloCostAnalysis cost_analysis(options);
ASSERT_IS_OK(add->Accept(&cost_analysis));
EXPECT_EQ(cost_analysis.optimal_seconds(), clock_cycle_seconds);
}

using FusionCostAnalysis = HloTestBase;

TEST_F(FusionCostAnalysis, LoopFusionDynUpdateSlice) {
Expand Down
6 changes: 4 additions & 2 deletions xla/service/memory_space_assignment/cost_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,10 @@ float HloCostAnalysisCosts::BytesPerSecond() {

float HloCostAnalysisCosts::ComputeSeconds(const HloInstruction& instruction) {
return std::max(
static_cast<float>(hlo_cost_analysis_.flop_count(instruction)) /
hlo_cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey),
std::max(
hlo_cost_analysis_.min_latency_seconds(HloCostAnalysis::kFlopsKey),
static_cast<float>(hlo_cost_analysis_.flop_count(instruction)) /
hlo_cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey)),
static_cast<float>(hlo_cost_analysis_.transcendental_count(instruction)) /
hlo_cost_analysis_.per_second_rate(
HloCostAnalysis::kTranscendentalsKey));
Expand Down
30 changes: 26 additions & 4 deletions xla/service/memory_space_assignment/cost_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include "xla/service/memory_space_assignment/cost_analysis.h"

#include <algorithm>
#include <cstdint>
#include <memory>

Expand Down Expand Up @@ -56,6 +57,7 @@ class MemorySpaceAssignmentCostAnalysisTest : public HloTestBase {
options.set_flops_per_second(8);
options.set_bytes_per_second(32);
options.set_transcendentals_per_second(16);
options.set_flops_min_latency_second(1);
hlo_cost_analysis_ = std::make_unique<HloCostAnalysis>(options);
TF_RETURN_IF_ERROR(
module->entry_computation()->Accept(hlo_cost_analysis_.get()));
Expand Down Expand Up @@ -90,8 +92,9 @@ TEST_F(MemorySpaceAssignmentCostAnalysisTest, NoPipelineOverhead) {
TF_ASSERT_OK(Initialize(module.get()));

const HloInstruction* add = module->entry_computation()->root_instruction();
const float expected_compute_elapsed =
/*num_flops=*/8 / /*flops_per_second=*/8.0;
const float expected_compute_elapsed = std::max(
/*num_flops=*/8.0f / /*flops_per_second=*/8.0f,
hlo_cost_analysis_->min_latency_seconds(HloCostAnalysis::kFlopsKey));
LOG(INFO) << "Expected compute elapsed = " << expected_compute_elapsed;
EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToCompute(*add),
expected_compute_elapsed);
Expand Down Expand Up @@ -161,8 +164,9 @@ TEST_F(MemorySpaceAssignmentCostAnalysisTest, PipelineOverhead) {
/*pipeline_overhead_window_size_mib=*/(64.0 / 1024 / 1024)));

const HloInstruction* add = module->entry_computation()->root_instruction();
const float expected_compute_elapsed =
/*num_flops=*/8 / /*flops_per_second=*/8.0;
const float expected_compute_elapsed = std::max(
/*num_flops=*/8.0f / /*flops_per_second=*/8.0f,
hlo_cost_analysis_->min_latency_seconds(HloCostAnalysis::kFlopsKey));
LOG(INFO) << "Expected compute elapsed = " << expected_compute_elapsed;
EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToCompute(*add),
expected_compute_elapsed);
Expand Down Expand Up @@ -230,5 +234,23 @@ TEST_F(MemorySpaceAssignmentCostAnalysisTest, PipelineOverhead) {
expected_compute_elapsed);
}

TEST_F(MemorySpaceAssignmentCostAnalysisTest, LatencyBoundCompute) {
absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true
ENTRY Entry {
param0 = f32[2,2] parameter(0)
param1 = f32[2,2] parameter(1)
ROOT add = f32[2,2] add(param0, param1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK(Initialize(module.get()));

const HloInstruction* add = module->entry_computation()->root_instruction();
EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToCompute(*add), 1.0f);
}

} // namespace
} // namespace xla

0 comments on commit 98fb24a

Please sign in to comment.