Skip to content

Commit

Permalink
PR #16256: [XLA:GPU] Speed up priority fusion with incremental update
Browse files Browse the repository at this point in the history
Imported from GitHub PR #16256

* Use incremental updates for producers that already calculated priorities. This avoid looking at unchanged consumers.
* Add `operands_to_new_consumers_` to record mapping from operand to new consumers and add `operands_to_removed_consumers_runtimes` to record mapping from operand to the runtimes of removed consumers.
* Also deferred the cache invalidation a bit cause some cache entries are still needed in `ComputeRuntimesOfRemovedConsumers`.
Copybara import of the project:

--
ba5ceb8 by cjkkkk <[email protected]>:

rebased and squashed

--
270c6f8 by cjkkkk <[email protected]>:

address comments

--
5a5bc75 by cjkkkk <[email protected]>:

fix clang

--
1aea158 by cjkkkk <[email protected]>:

use const span

--
fd212fd by cjkkkk <[email protected]>:

address comments

Merging this change closes #16256

COPYBARA_INTEGRATE_REVIEW=#16256 from Cjkkkk:priority_incremental_update fd212fd
PiperOrigin-RevId: 668007902
  • Loading branch information
Cjkkkk authored and copybara-github committed Aug 27, 2024
1 parent d3d3c22 commit 6ea9438
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 20 deletions.
10 changes: 10 additions & 0 deletions xla/service/gpu/model/gpu_performance_model_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,16 @@ std::optional<absl::Duration> GpuPerformanceModelCache::Get(
return std::nullopt;
}

const absl::flat_hash_map<const HloInstruction*, absl::Duration>&
GpuPerformanceModelCache::GetAllConsumers(const HloInstruction& producer) {
return fusion_runtime_data_[&producer];
}

bool GpuPerformanceModelCache::ContainsConsumers(
const HloInstruction& producer) {
return fusion_runtime_data_.contains(&producer);
}

void GpuPerformanceModelCache::Set(const HloInstruction& instruction,
const EstimateRunTimeData& runtime_data) {
instruction_runtime_data_[&instruction] = runtime_data;
Expand Down
6 changes: 5 additions & 1 deletion xla/service/gpu/model/gpu_performance_model_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ class GpuPerformanceModelCache {
std::optional<EstimateRunTimeData> Get(const HloInstruction& instruction);
std::optional<absl::Duration> Get(const HloInstruction& producer,
const HloInstruction& consumer);

const absl::flat_hash_map<const HloInstruction*, absl::Duration>&
// Returns cache entries for all consumers of this producer.
GetAllConsumers(const HloInstruction& producer);
// Checks if producer-consumer pair cache entries exist for this producer.
bool ContainsConsumers(const HloInstruction& producer);
// Sets cache value for the instruction or producer-consumer pair.
void Set(const HloInstruction& instruction,
const EstimateRunTimeData& runtime_data);
Expand Down
102 changes: 83 additions & 19 deletions xla/service/gpu/transforms/priority_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ GpuBackendConfig GetTritonGpuBackendConfig(
// max-benefit producer to fuse, and update the estimated benefits of the fused
// nodes and their operands.
class PriorityFusionQueue {
using Priority = int64_t;
using Priority = absl::Duration;
using CanFuseCallback = std::function<FusionDecision(
HloInstruction* /*producer*/, int64_t /*consumer operand_index*/)>;

Expand Down Expand Up @@ -211,7 +211,7 @@ class PriorityFusionQueue {

// If the priority is negative, it's not helpful to perform fusion on this
// instruction.
if (priority < 0) {
if (priority < absl::ZeroDuration()) {
continue;
}

Expand Down Expand Up @@ -312,6 +312,8 @@ class PriorityFusionQueue {
to_update_priority_.begin(), to_update_priority_.end()});

to_update_priority_.clear();
operands_to_new_consumers_.clear();
operands_to_removed_consumers_runtimes_.clear();
return absl::OkStatus();
}

Expand All @@ -325,7 +327,6 @@ class PriorityFusionQueue {
consumer->name(), "| inside PriorityFusion"),
*consumer, producer);
}
InvalidateCaches(consumer);
}

// Invalidates all cached value related to this instruction. Called before the
Expand All @@ -352,6 +353,47 @@ class PriorityFusionQueue {
fusion_info_cache_.Invalidate(instruction);
}

void UpdateRuntimes(
GpuPerformanceModel::RunTimes& runtimes, const HloInstruction* consumer,
const absl::flat_hash_map<const HloInstruction*, absl::Duration>&
original_consumers) {
auto it = original_consumers.find(consumer);
if (it != original_consumers.end()) {
runtimes.time_fused += it->second;
auto consumer_cache_result = gpu_performance_model_cache_.Get(*consumer);
CHECK(consumer_cache_result.has_value());
runtimes.time_unfused += (*consumer_cache_result).exec_time;
}
}

// Prepare for incremental updates
void ComputeRuntimesOfRemovedConsumers() {
for (auto pair : operands_to_new_consumers_) {
auto operand = pair.first;
// Checks if this producer's priority was calculated before. If so, we can
// do incremental update here.
if (!reverse_map_.contains(operand)) {
continue;
}
// Get all of this producer's original consumers. Bitcast/constant have
// priority calculated but they don't have cache entries.
if (!gpu_performance_model_cache_.ContainsConsumers(*operand)) {
continue;
}
const auto& original_consumers =
gpu_performance_model_cache_.GetAllConsumers(*operand);
GpuPerformanceModel::RunTimes runtimes;
for (auto consumer : current_consumers()) {
UpdateRuntimes(runtimes, consumer, original_consumers);
}
UpdateRuntimes(runtimes, current_producer(), original_consumers);
auto operand_cache_result = gpu_performance_model_cache_.Get(*operand);
runtimes.time_unfused += (*operand_cache_result).exec_time +
GpuPerformanceModel::kKernelLaunchOverhead;
operands_to_removed_consumers_runtimes_.emplace(operand, runtimes);
}
}

// Updates data for the new fusion instruction and its users and operands.
void OnFusingInstruction(HloInstruction* fusion,
HloInstruction* original_producer,
Expand Down Expand Up @@ -382,15 +424,6 @@ class PriorityFusionQueue {
RemoveInstruction(original_consumer);
}

// Detach 'original_producer' from its operands if it has no users.
// This avoids having it appear as a "phantom" user in subsequent priority
// calculations on 'fusion.operands' below, before it is finally removed
// in 'RemoveInstruction'.
if (original_producer->user_count() == 0) {
InvalidateCaches(original_producer);
original_producer->DetachFromOperandsAndUsers();
}

// Collect the instructions whose priorities need to be updated.
for (HloInstruction* operand : fusion->operands()) {
if (operand == original_producer ||
Expand All @@ -405,6 +438,9 @@ class PriorityFusionQueue {
}

to_update_priority_.insert(operand);
// update the consumers of this operand that we care about,
// so we can do incremental update of the operand
operands_to_new_consumers_[operand].push_back(fusion);
}
to_update_priority_.insert(fusion);
}
Expand Down Expand Up @@ -445,13 +481,13 @@ class PriorityFusionQueue {
Priority CalculateProducerPriority(HloInstruction* producer) {
// Bitcasts should always be fused first, since they are no-ops.
if (producer->opcode() == HloOpcode::kBitcast) {
return std::numeric_limits<Priority>::max();
return absl::InfiniteDuration();
}
// We always fuse constants, but the cost model doesn't handle them very
// well: fusing constants changes costs significantly. Also, there's no
// point recomputing priorities. Therefore, we fuse all of them at the end.
if (producer->opcode() == HloOpcode::kConstant) {
return std::numeric_limits<Priority>::min();
return -absl::InfiniteDuration();
}

// Don't fuse if we can't fuse in all users.
Expand All @@ -464,15 +500,35 @@ class PriorityFusionQueue {
step->set_producer_name(std::string(producer->name()));
step->set_reason(fusion_decision.Explain());
}
return std::numeric_limits<Priority>::min();
return -absl::InfiniteDuration();
}

auto removed_consumers_runtime_it =
operands_to_removed_consumers_runtimes_.find(producer);
bool is_incremental_update = removed_consumers_runtime_it !=
operands_to_removed_consumers_runtimes_.end();
absl::Span<HloInstruction* const> fused_consumers =
is_incremental_update
? operands_to_new_consumers_.find(producer)->second
: absl::MakeConstSpan(producer->users());
GpuPerformanceModel::RunTimes run_times =
GpuPerformanceModel::EstimateRunTimesForPriorityFusion(
producer, *device_info_, &cost_analysis_,
GpuPerformanceModelOptions::PriorityFusion(
&fusion_analysis_cache_, &gpu_performance_model_cache_),
producer->users());
fused_consumers);
Priority current_priority;
if (is_incremental_update) {
// subtract the runtimes of removed consumers
const GpuPerformanceModel::RunTimes& removed_consumers_runtime =
removed_consumers_runtime_it->second;
run_times.time_unfused -= removed_consumers_runtime.time_unfused;
run_times.time_fused -= removed_consumers_runtime.time_fused;
// get the original priority
const PriorityQueue::iterator& queue_it =
FindOrDie(reverse_map_, producer);
current_priority = queue_it->first.first;
}

if (fusion_process_dump_) {
absl::MutexLock lock(&fusion_process_dump_mutex_);
Expand All @@ -485,8 +541,7 @@ class PriorityFusionQueue {
step->set_us_fused(absl::ToDoubleMicroseconds(run_times.time_fused));
step->set_us_unfused(absl::ToDoubleMicroseconds(run_times.time_unfused));
}
return absl::ToInt64Nanoseconds(run_times.time_unfused -
run_times.time_fused);
return current_priority + run_times.time_unfused - run_times.time_fused;
}

FusionDecision IsTritonSupported(const HloInstruction& instruction) {
Expand Down Expand Up @@ -739,7 +794,10 @@ class PriorityFusionQueue {
// avoid recomputing priorities multiple times before we dequeue a new
// producer.
absl::flat_hash_set<HloInstruction*> to_update_priority_;

absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>>
operands_to_new_consumers_;
absl::flat_hash_map<HloInstruction*, GpuPerformanceModel::RunTimes>
operands_to_removed_consumers_runtimes_;
// Proto with structured logs of fusion decisions. Used only for debugging. If
// null, logging is disabled.
FusionProcessDumpProto* fusion_process_dump_;
Expand Down Expand Up @@ -904,12 +962,18 @@ absl::StatusOr<bool> PriorityFusion::Run(
changed = true;
}

fusion_queue->ComputeRuntimesOfRemovedConsumers();
if (producer->user_count() == 0) {
fusion_queue->InvalidateCaches(producer);
producer->DetachFromOperandsAndUsers();
fusion_queue->RemoveInstruction(producer);
// Remove from computation.
TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer));
}

for (auto* consumer : fusion_queue->current_consumers()) {
fusion_queue->InvalidateCaches(consumer);
}
TF_RETURN_IF_ERROR(fusion_queue->UpdatePriorities());
}

Expand Down

0 comments on commit 6ea9438

Please sign in to comment.