Skip to content

Commit

Permalink
Make unroll_config optional in WhileLoopUnroller::GetUnrollableLoops.
Browse files Browse the repository at this point in the history
This change allows more loops to be considered when using GetUnrollableLoops, e.g., in while_initializer_removal pass.

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16256 from Cjkkkk:priority_incremental_update fd212fd
PiperOrigin-RevId: 667782425
  • Loading branch information
fhoushmand authored and copybara-github committed Aug 27, 2024
1 parent d3d3c22 commit 2abfdb8
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 33 deletions.
10 changes: 10 additions & 0 deletions xla/service/gpu/model/gpu_performance_model_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,16 @@ std::optional<absl::Duration> GpuPerformanceModelCache::Get(
return std::nullopt;
}

// Returns a reference to the per-consumer duration map stored under
// `producer` in `fusion_runtime_data_`.
// NOTE(review): the lookup goes through operator[], which presumably
// default-inserts an empty map when `producer` has no entry yet -- callers are
// expected to guard with ContainsConsumers() first; confirm.
const absl::flat_hash_map<const HloInstruction*, absl::Duration>&
GpuPerformanceModelCache::GetAllConsumers(const HloInstruction& producer) {
  const HloInstruction* producer_key = &producer;
  return fusion_runtime_data_[producer_key];
}

bool GpuPerformanceModelCache::ContainsConsumers(
const HloInstruction& producer) {
return fusion_runtime_data_.contains(&producer);
}

void GpuPerformanceModelCache::Set(const HloInstruction& instruction,
const EstimateRunTimeData& runtime_data) {
instruction_runtime_data_[&instruction] = runtime_data;
Expand Down
6 changes: 5 additions & 1 deletion xla/service/gpu/model/gpu_performance_model_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ class GpuPerformanceModelCache {
std::optional<EstimateRunTimeData> Get(const HloInstruction& instruction);
std::optional<absl::Duration> Get(const HloInstruction& producer,
const HloInstruction& consumer);

// Returns cache entries for all consumers of this producer.
const absl::flat_hash_map<const HloInstruction*, absl::Duration>&
GetAllConsumers(const HloInstruction& producer);
// Returns true if producer-consumer pair cache entries exist for this
// producer.
bool ContainsConsumers(const HloInstruction& producer);
// Sets cache value for the instruction or producer-consumer pair.
void Set(const HloInstruction& instruction,
const EstimateRunTimeData& runtime_data);
Expand Down
102 changes: 83 additions & 19 deletions xla/service/gpu/transforms/priority_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ GpuBackendConfig GetTritonGpuBackendConfig(
// max-benefit producer to fuse, and update the estimated benefits of the fused
// nodes and their operands.
class PriorityFusionQueue {
using Priority = int64_t;
using Priority = absl::Duration;
using CanFuseCallback = std::function<FusionDecision(
HloInstruction* /*producer*/, int64_t /*consumer operand_index*/)>;

Expand Down Expand Up @@ -211,7 +211,7 @@ class PriorityFusionQueue {

// If the priority is negative, it's not helpful to perform fusion on this
// instruction.
if (priority < 0) {
if (priority < absl::ZeroDuration()) {
continue;
}

Expand Down Expand Up @@ -312,6 +312,8 @@ class PriorityFusionQueue {
to_update_priority_.begin(), to_update_priority_.end()});

to_update_priority_.clear();
operands_to_new_consumers_.clear();
operands_to_removed_consumers_runtimes_.clear();
return absl::OkStatus();
}

Expand All @@ -325,7 +327,6 @@ class PriorityFusionQueue {
consumer->name(), "| inside PriorityFusion"),
*consumer, producer);
}
InvalidateCaches(consumer);
}

// Invalidates all cached value related to this instruction. Called before the
Expand All @@ -352,6 +353,47 @@ class PriorityFusionQueue {
fusion_info_cache_.Invalidate(instruction);
}

// Accumulates `consumer`'s contribution into `runtimes`, but only when
// `consumer` has an entry in `original_consumers`; otherwise `runtimes` is left
// untouched.
//   * time_fused   += the duration stored for `consumer` in
//                     `original_consumers`.
//   * time_unfused += the exec_time cached for `consumer` itself.
// CHECK-fails if `consumer` is in `original_consumers` but has no
// per-instruction entry in the performance model cache.
void UpdateRuntimes(
    GpuPerformanceModel::RunTimes& runtimes, const HloInstruction* consumer,
    const absl::flat_hash_map<const HloInstruction*, absl::Duration>&
        original_consumers) {
  const auto entry = original_consumers.find(consumer);
  if (entry == original_consumers.end()) {
    return;
  }
  runtimes.time_fused += entry->second;
  const auto cached_consumer = gpu_performance_model_cache_.Get(*consumer);
  CHECK(cached_consumer.has_value());
  runtimes.time_unfused += cached_consumer->exec_time;
}

// Prepare for incremental updates
void ComputeRuntimesOfRemovedConsumers() {
for (auto pair : operands_to_new_consumers_) {
auto operand = pair.first;
// Checks if this producer's priority was calculated before. If so, we can
// do incremental update here.
if (!reverse_map_.contains(operand)) {
continue;
}
// Get all of this producer's original consumers. Bitcast/constant have
// priority calculated but they don't have cache entries.
if (!gpu_performance_model_cache_.ContainsConsumers(*operand)) {
continue;
}
const auto& original_consumers =
gpu_performance_model_cache_.GetAllConsumers(*operand);
GpuPerformanceModel::RunTimes runtimes;
for (auto consumer : current_consumers()) {
UpdateRuntimes(runtimes, consumer, original_consumers);
}
UpdateRuntimes(runtimes, current_producer(), original_consumers);
auto operand_cache_result = gpu_performance_model_cache_.Get(*operand);
runtimes.time_unfused += (*operand_cache_result).exec_time +
GpuPerformanceModel::kKernelLaunchOverhead;
operands_to_removed_consumers_runtimes_.emplace(operand, runtimes);
}
}

// Updates data for the new fusion instruction and its users and operands.
void OnFusingInstruction(HloInstruction* fusion,
HloInstruction* original_producer,
Expand Down Expand Up @@ -382,15 +424,6 @@ class PriorityFusionQueue {
RemoveInstruction(original_consumer);
}

// Detach 'original_producer' from its operands if it has no users.
// This avoids having it appear as a "phantom" user in subsequent priority
// calculations on 'fusion.operands' below, before it is finally removed
// in 'RemoveInstruction'.
if (original_producer->user_count() == 0) {
InvalidateCaches(original_producer);
original_producer->DetachFromOperandsAndUsers();
}

// Collect the instructions whose priorities need to be updated.
for (HloInstruction* operand : fusion->operands()) {
if (operand == original_producer ||
Expand All @@ -405,6 +438,9 @@ class PriorityFusionQueue {
}

to_update_priority_.insert(operand);
// Record the new fusion as a consumer of this operand so that the
// operand's priority can be updated incrementally.
operands_to_new_consumers_[operand].push_back(fusion);
}
to_update_priority_.insert(fusion);
}
Expand Down Expand Up @@ -445,13 +481,13 @@ class PriorityFusionQueue {
Priority CalculateProducerPriority(HloInstruction* producer) {
// Bitcasts should always be fused first, since they are no-ops.
if (producer->opcode() == HloOpcode::kBitcast) {
return std::numeric_limits<Priority>::max();
return absl::InfiniteDuration();
}
// We always fuse constants, but the cost model doesn't handle them very
// well: fusing constants changes costs significantly. Also, there's no
// point recomputing priorities. Therefore, we fuse all of them at the end.
if (producer->opcode() == HloOpcode::kConstant) {
return std::numeric_limits<Priority>::min();
return -absl::InfiniteDuration();
}

// Don't fuse if we can't fuse in all users.
Expand All @@ -464,15 +500,35 @@ class PriorityFusionQueue {
step->set_producer_name(std::string(producer->name()));
step->set_reason(fusion_decision.Explain());
}
return std::numeric_limits<Priority>::min();
return -absl::InfiniteDuration();
}

auto removed_consumers_runtime_it =
operands_to_removed_consumers_runtimes_.find(producer);
bool is_incremental_update = removed_consumers_runtime_it !=
operands_to_removed_consumers_runtimes_.end();
absl::Span<HloInstruction* const> fused_consumers =
is_incremental_update
? operands_to_new_consumers_.find(producer)->second
: absl::MakeConstSpan(producer->users());
GpuPerformanceModel::RunTimes run_times =
GpuPerformanceModel::EstimateRunTimesForPriorityFusion(
producer, *device_info_, &cost_analysis_,
GpuPerformanceModelOptions::PriorityFusion(
&fusion_analysis_cache_, &gpu_performance_model_cache_),
producer->users());
fused_consumers);
Priority current_priority;
if (is_incremental_update) {
// subtract the runtimes of removed consumers
const GpuPerformanceModel::RunTimes& removed_consumers_runtime =
removed_consumers_runtime_it->second;
run_times.time_unfused -= removed_consumers_runtime.time_unfused;
run_times.time_fused -= removed_consumers_runtime.time_fused;
// get the original priority
const PriorityQueue::iterator& queue_it =
FindOrDie(reverse_map_, producer);
current_priority = queue_it->first.first;
}

if (fusion_process_dump_) {
absl::MutexLock lock(&fusion_process_dump_mutex_);
Expand All @@ -485,8 +541,7 @@ class PriorityFusionQueue {
step->set_us_fused(absl::ToDoubleMicroseconds(run_times.time_fused));
step->set_us_unfused(absl::ToDoubleMicroseconds(run_times.time_unfused));
}
return absl::ToInt64Nanoseconds(run_times.time_unfused -
run_times.time_fused);
return current_priority + run_times.time_unfused - run_times.time_fused;
}

FusionDecision IsTritonSupported(const HloInstruction& instruction) {
Expand Down Expand Up @@ -739,7 +794,10 @@ class PriorityFusionQueue {
// avoid recomputing priorities multiple times before we dequeue a new
// producer.
absl::flat_hash_set<HloInstruction*> to_update_priority_;

absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>>
operands_to_new_consumers_;
absl::flat_hash_map<HloInstruction*, GpuPerformanceModel::RunTimes>
operands_to_removed_consumers_runtimes_;
// Proto with structured logs of fusion decisions. Used only for debugging. If
// null, logging is disabled.
FusionProcessDumpProto* fusion_process_dump_;
Expand Down Expand Up @@ -904,12 +962,18 @@ absl::StatusOr<bool> PriorityFusion::Run(
changed = true;
}

fusion_queue->ComputeRuntimesOfRemovedConsumers();
if (producer->user_count() == 0) {
fusion_queue->InvalidateCaches(producer);
producer->DetachFromOperandsAndUsers();
fusion_queue->RemoveInstruction(producer);
// Remove from computation.
TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer));
}

for (auto* consumer : fusion_queue->current_consumers()) {
fusion_queue->InvalidateCaches(consumer);
}
TF_RETURN_IF_ERROR(fusion_queue->UpdatePriorities());
}

Expand Down
2 changes: 1 addition & 1 deletion xla/service/hlo_unstacker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ struct UnstackerMetadata {
VLOG(3) << "Prepared module: " << module->name() << " for unstacking.";
}
std::vector<std::pair<HloInstruction*, WhileLoopConfig>> loops =
WhileLoopUnroller::GetUnrollableLoops(module, {});
WhileLoopUnroller::GetUnrollableLoops(module, {}, std::nullopt);
for (const auto& [instr, while_loop_config] : loops) {
metadata.unrollable_loop_bodies[instr->while_body()] = while_loop_config;
metadata.bodies[instr->while_body()] = instr;
Expand Down
3 changes: 2 additions & 1 deletion xla/service/scan_loop_accumulator_input_unification.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,8 @@ absl::StatusOr<bool> ScanLoopAccumulatorInputUnification::Run(
// accumulators and inputs that are by definition updated and read fully via
// dynamic-update-slice and dynamic-sliced within a loop.
std::vector<std::pair<HloInstruction*, WhileLoopConfig>> unrollable_loops =
WhileLoopUnroller::GetUnrollableLoops(module, execution_threads);
WhileLoopUnroller::GetUnrollableLoops(module, execution_threads,
std::nullopt);

// TODO(b/337883537): We might want to simplify compare instructions before
// this. It helps us identify more inputs and accumulators.
Expand Down
17 changes: 10 additions & 7 deletions xla/service/while_loop_unroller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ std::optional<int64_t> MatchShapeCoveringDynamicIndexInstruction(
WhileLoopUnroller::GetUnrollableLoops(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads,
const UnrollConfig& unroll_config) {
std::optional<UnrollConfig> unroll_config) {
// Processing the while loops in the reverse topological order. If the body
// of while loop A calls while loop B, B comes before A.
std::vector<HloInstruction*> all_while_ops;
Expand All @@ -676,13 +676,16 @@ WhileLoopUnroller::GetUnrollableLoops(
std::vector<std::pair<HloInstruction*, WhileLoopConfig>> while_loop_configs;
for (HloInstruction* instr : all_while_ops) {
std::optional<WhileLoopConfig> config = IsLoopUnrollable(instr);
if (config.has_value()) {
if (!InitialFeasibilityCheck(instr, config.value(), unroll_config)) {
VLOG(3) << "Initial feasibility check failed for " << instr->name();
continue;
}
while_loop_configs.emplace_back(instr, config.value());
if (!config.has_value()) {
continue;
}
if (unroll_config.has_value() &&
!InitialFeasibilityCheck(instr, config.value(),
unroll_config.value())) {
VLOG(3) << "Initial feasibility check failed for " << instr->name();
continue;
}
while_loop_configs.emplace_back(instr, config.value());
}
return while_loop_configs;
}
Expand Down
8 changes: 5 additions & 3 deletions xla/service/while_loop_unroller.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ struct WhileLoopConfig {
int64_t induction_var_idx;
};

// Config for unrollable while loops.
// Result for unrolled while loops.
struct UnrollResult {
// Whether it's unrolled.
bool unrolled = false;
Expand Down Expand Up @@ -120,12 +120,14 @@ class WhileLoopUnroller : public HloModulePass {
static std::optional<WhileLoopConfig> IsLoopUnrollable(
HloInstruction* while_op);

// Returns the list of unrollable loops in the given module
// Returns the list of unrollable loops in the given module. If
// `unroll_config` is provided, it is used to check feasibility according to
// the InitialFeasibilityCheck method.
static std::vector<std::pair<HloInstruction*, WhileLoopConfig>>
GetUnrollableLoops(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads,
const UnrollConfig& unroll_config = UnrollConfig());
std::optional<UnrollConfig> unroll_config);

// Unrolls the given while loop with the default behaviour set to full unroll.
// If wrap_in_trivial_loop is set, the unrolled body of the loop will be
Expand Down
2 changes: 1 addition & 1 deletion xla/service/while_loop_unroller_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ TEST_F(WhileLoopUnrollerTest, GetUnrollableLoops) {
ParseAndReturnVerifiedModule(hlo_string));

auto unrollable_loops =
WhileLoopUnroller::GetUnrollableLoops(module.get(), {});
WhileLoopUnroller::GetUnrollableLoops(module.get(), {}, std::nullopt);
// Only while1 and while2 are unrollable
EXPECT_EQ(unrollable_loops.size(), 2);
}
Expand Down

0 comments on commit 2abfdb8

Please sign in to comment.